@@ -6,15 +6,7 @@
 import warnings
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from operator import itemgetter
-from typing import (
-    Any,
-    Callable,
-    Literal,
-    Optional,
-    TypedDict,
-    Union,
-    cast,
-)
+from typing import Any, Callable, Literal, Optional, TypedDict, Union, cast
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -46,10 +38,7 @@
     ToolMessage,
     ToolMessageChunk,
 )
-from langchain_core.output_parsers import (
-    JsonOutputParser,
-    PydanticOutputParser,
-)
+from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
@@ -60,23 +49,13 @@
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
-from langchain_core.utils import (
-    from_env,
-    get_pydantic_field_names,
-    secret_from_env,
-)
+from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
 from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
 from langchain_core.utils.pydantic import is_basemodel_subclass
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    SecretStr,
-    model_validator,
-)
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
 from langchain_groq.version import __version__
@@ -122,7 +101,7 @@ class ChatGroq(BaseChatModel):
 
         See the `Groq documentation
         <https://console.groq.com/docs/reasoning#reasoning>`__ for more
-        details and a list of supported reasoning models.
+        details and a list of supported models.
     model_kwargs: Dict[str, Any]
         Holds any model parameters valid for create call not
         explicitly specified.
@@ -328,20 +307,15 @@ class Joke(BaseModel):
     overridden in ``reasoning_effort``.
 
     See the `Groq documentation <https://console.groq.com/docs/reasoning#reasoning>`__
-    for more details and a list of supported reasoning models.
+    for more details and a list of supported models.
     """
-    reasoning_effort: Optional[Literal["none", "default"]] = Field(default=None)
+    reasoning_effort: Optional[str] = Field(default=None)
     """The level of effort the model will put into reasoning. Groq will default to
-    enabling reasoning if left undefined. If set to ``none``, ``reasoning_format`` will
-    not apply and ``reasoning_content`` will not be returned.
-
-    - ``'none'``: Disable reasoning. The model will not use any reasoning tokens when
-      generating a response.
-    - ``'default'``: Enable reasoning.
+    enabling reasoning if left undefined.
 
     See the `Groq documentation
     <https://console.groq.com/docs/reasoning#options-for-reasoning-effort>`__ for more
-    details and a list of models that support setting a reasoning effort.
+    details and a list of options and models that support setting a reasoning effort.
     """
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
@@ -601,6 +575,11 @@ def _stream(
                     generation_info["system_fingerprint"] = system_fingerprint
                 service_tier = params.get("service_tier") or self.service_tier
                 generation_info["service_tier"] = service_tier
+                reasoning_effort = (
+                    params.get("reasoning_effort") or self.reasoning_effort
+                )
+                if reasoning_effort:
+                    generation_info["reasoning_effort"] = reasoning_effort
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
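
The async hunk below repeats the same resolution pattern. The precedence it encodes: a per-call ``reasoning_effort`` in ``params`` wins, and the instance attribute is only a fallback. A toy sketch of that ``or`` chain (the helper name is illustrative, not part of the module):

```python
from typing import Any, Optional

# Illustrative helper mirroring the fallback used in _stream/_astream above.
def resolve_effort(params: dict[str, Any], default: Optional[str]) -> Optional[str]:
    return params.get("reasoning_effort") or default

assert resolve_effort({"reasoning_effort": "high"}, "low") == "high"  # per-call wins
assert resolve_effort({}, "low") == "low"                             # fallback
assert resolve_effort({"reasoning_effort": ""}, "low") == "low"       # falsy falls through
```

One consequence of using ``or`` rather than an explicit ``None`` check: a falsy per-call value cannot clear a truthy instance default.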
@@ -644,6 +623,11 @@ async def _astream(
                     generation_info["system_fingerprint"] = system_fingerprint
                 service_tier = params.get("service_tier") or self.service_tier
                 generation_info["service_tier"] = service_tier
+                reasoning_effort = (
+                    params.get("reasoning_effort") or self.reasoning_effort
+                )
+                if reasoning_effort:
+                    generation_info["reasoning_effort"] = reasoning_effort
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -714,6 +698,9 @@ def _create_chat_result(
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
         llm_output["service_tier"] = params.get("service_tier") or self.service_tier
+        reasoning_effort = params.get("reasoning_effort") or self.reasoning_effort
+        if reasoning_effort:
+            llm_output["reasoning_effort"] = reasoning_effort
         return ChatResult(generations=generations, llm_output=llm_output)
 
     def _create_message_dicts(
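With ``reasoning_effort`` now recorded in ``generation_info`` (streaming) and ``llm_output`` (non-streaming), callers should be able to see which effort was in effect. A hedged sketch of reading it back, reusing the ``llm`` from the earlier sketch and assuming these dicts surface through ``response_metadata`` as fields like ``service_tier`` do in other LangChain chat models:

```python
# Assumption: generation_info / llm_output flow into response_metadata;
# .get() keeps this safe if the key is absent on a given path.
msg = llm.invoke("Say hello.")
print(msg.response_metadata.get("reasoning_effort"))

# Streaming: generation_info is attached on the chunk carrying finish_reason;
# accumulating chunks with + merges it into the final message's metadata.
final = None
for chunk in llm.stream("Say hello."):
    final = chunk if final is None else final + chunk
print(final.response_metadata.get("reasoning_effort"))
```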