Skip to content

Commit 65716cf

Browse files
feat(perplexity): Created Dedicated Output Parser to Support Reasoning Model Output for perplexity (#33670)
1 parent 1b77a19 commit 65716cf

File tree

3 files changed

+451
-3
lines changed

3 files changed

+451
-3
lines changed

libs/partners/perplexity/langchain_perplexity/chat_models.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
UsageMetadata,
3434
subtract_usage,
3535
)
36-
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
3736
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
3837
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
3938
from langchain_core.utils import get_pydantic_field_names, secret_from_env
@@ -42,6 +41,11 @@
4241
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
4342
from typing_extensions import Self
4443

44+
from langchain_perplexity.output_parsers import (
45+
ReasoningJsonOutputParser,
46+
ReasoningStructuredOutputParser,
47+
)
48+
4549
_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[BaseModel]
4650
_DictOrPydantic: TypeAlias = dict | BaseModel
4751

@@ -521,9 +525,9 @@ def with_structured_output(
521525
},
522526
)
523527
output_parser = (
524-
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
528+
ReasoningStructuredOutputParser(pydantic_object=schema) # type: ignore[arg-type]
525529
if is_pydantic_schema
526-
else JsonOutputParser()
530+
else ReasoningJsonOutputParser()
527531
)
528532
else:
529533
raise ValueError(
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import re
2+
from typing import Any, Generic
3+
4+
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
5+
from langchain_core.outputs import Generation
6+
from langchain_core.utils.pydantic import TBaseModel
7+
8+
9+
def strip_think_tags(text: str) -> str:
    """Remove `<think>` reasoning blocks and outer markdown code fences.

    Every complete `<think>...</think>` block is deleted, keeping the text
    before, between, and after the blocks. If the remaining text still begins
    with `<think>`, that block was never closed (for example a truncated or
    still-streaming response), so all of it is reasoning and an empty string
    is returned. Otherwise a leading ```` ```json ```` or ```` ``` ```` marker
    and a trailing ```` ``` ```` marker are stripped.

    An unterminated `<think>` that is *not* at the start of the remaining text
    is left untouched, so a literal `<think>` inside the payload is preserved.

    Args:
        text: The input text that may contain think tags.

    Returns:
        The text with reasoning blocks and outer code fence markers removed,
        stripped of surrounding whitespace.
    """
    # Non-greedy so each block stops at its own closing tag; DOTALL lets the
    # reasoning span multiple lines.
    result = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

    # A leading <think> that survived the substitution has no closing tag.
    # Everything after it is reasoning, not answer content.
    if result.startswith("<think>"):
        return ""

    # Check the longer marker first and stop at the first hit so that only a
    # single opening fence is removed.
    for fence in ("```json", "```"):
        if result.startswith(fence):
            result = result[len(fence) :].strip()
            break

    if result.endswith("```"):
        result = result[:-3].strip()

    return result
36+
37+
38+
class ReasoningJsonOutputParser(JsonOutputParser):
    """JSON output parser for models that emit `<think>` reasoning blocks.

    The text of the first generation is passed through `strip_think_tags`
    and the cleaned text is then handed to `JsonOutputParser.parse_result`.
    """

    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        """Strip reasoning from the first generation, then parse it as JSON.

        Args:
            result: The generations returned by the LLM call. Only the first
                entry is read, so the list must not be empty.
            partial: Forwarded unchanged to the parent parser. Whether to
                parse partial JSON objects.

        Returns:
            Whatever `JsonOutputParser.parse_result` returns for the cleaned
            text.

        Raises:
            OutputParserException: If the output is not valid JSON.
        """
        cleaned = strip_think_tags(result[0].text)
        # Wrap the cleaned text in a fresh Generation so the parent parser
        # never sees the reasoning content.
        return super().parse_result([Generation(text=cleaned)], partial=partial)
65+
66+
67+
class ReasoningStructuredOutputParser(
    PydanticOutputParser[TBaseModel], Generic[TBaseModel]
):
    """Pydantic output parser for models that emit `<think>` reasoning blocks.

    The text of the first generation is passed through `strip_think_tags`
    and the cleaned text is then handed to `PydanticOutputParser.parse_result`.
    """

    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        """Strip reasoning from the first generation, then parse it.

        Args:
            result: The generations returned by the LLM call. Only the first
                entry is read, so the list must not be empty.
            partial: Forwarded unchanged to the parent parser. Whether to
                parse partial JSON objects.

        Returns:
            Whatever `PydanticOutputParser.parse_result` returns for the
            cleaned text.
        """
        cleaned = strip_think_tags(result[0].text)
        # Wrap the cleaned text in a fresh Generation so the parent parser
        # never sees the reasoning content.
        return super().parse_result([Generation(text=cleaned)], partial=partial)

0 commit comments

Comments
 (0)