Skip to content

Commit 187131c

Browse files
author
Erick Friis
authored
Revert "integrations[patch]: remove non-required chat param defaults" (#29048)
Reverts #26730; to be re-landed once the best way to release the default-value changes (especially the OpenAI temperature default) has been discussed.
1 parent 3d7ae8b commit 187131c

File tree

15 files changed

+43
-51
lines changed

15 files changed

+43
-51
lines changed

libs/partners/anthropic/langchain_anthropic/chat_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
307307
Key init args — client params:
308308
timeout: Optional[float]
309309
Timeout for requests.
310-
max_retries: Optional[int]
310+
max_retries: int
311311
Max number of retries if a request fails.
312312
api_key: Optional[str]
313313
Anthropic API key. If not passed in will be read from env var ANTHROPIC_API_KEY.
@@ -558,7 +558,8 @@ class Joke(BaseModel):
558558
default_request_timeout: Optional[float] = Field(None, alias="timeout")
559559
"""Timeout for requests to Anthropic Completion API."""
560560

561-
max_retries: Optional[int] = None
561+
# sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
562+
max_retries: int = 2
562563
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
563564

564565
stop_sequences: Optional[List[str]] = Field(None, alias="stop")
@@ -661,10 +662,9 @@ def _client_params(self) -> Dict[str, Any]:
661662
client_params: Dict[str, Any] = {
662663
"api_key": self.anthropic_api_key.get_secret_value(),
663664
"base_url": self.anthropic_api_url,
665+
"max_retries": self.max_retries,
664666
"default_headers": (self.default_headers or None),
665667
}
666-
if self.max_retries is not None:
667-
client_params["max_retries"] = self.max_retries
668668
# value <= 0 indicates the param should be ignored. None is a meaningful value
669669
# for Anthropic client and treated differently than not specifying the param at
670670
# all.

libs/partners/fireworks/langchain_fireworks/chat_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def is_lc_serializable(cls) -> bool:
316316
default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
317317
)
318318
"""Model name to use."""
319-
temperature: Optional[float] = None
319+
temperature: float = 0.0
320320
"""What sampling temperature to use."""
321321
stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
322322
"""Default stop sequences."""

libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
'request_timeout': 60.0,
2323
'stop': list([
2424
]),
25-
'temperature': 0.0,
2625
}),
2726
'lc': 1,
2827
'name': 'ChatFireworks',

libs/partners/groq/langchain_groq/chat_models.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
119119
Key init args — client params:
120120
timeout: Union[float, Tuple[float, float], Any, None]
121121
Timeout for requests.
122-
max_retries: Optional[int]
122+
max_retries: int
123123
Max number of retries.
124124
api_key: Optional[str]
125125
Groq API key. If not passed in will be read from env var GROQ_API_KEY.
@@ -303,7 +303,7 @@ class Joke(BaseModel):
303303
async_client: Any = Field(default=None, exclude=True) #: :meta private:
304304
model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
305305
"""Model name to use."""
306-
temperature: Optional[float] = None
306+
temperature: float = 0.7
307307
"""What sampling temperature to use."""
308308
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
309309
"""Default stop sequences."""
@@ -327,11 +327,11 @@ class Joke(BaseModel):
327327
)
328328
"""Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
329329
None."""
330-
max_retries: Optional[int] = None
330+
max_retries: int = 2
331331
"""Maximum number of retries to make when generating."""
332332
streaming: bool = False
333333
"""Whether to stream the results or not."""
334-
n: Optional[int] = None
334+
n: int = 1
335335
"""Number of chat completions to generate for each prompt."""
336336
max_tokens: Optional[int] = None
337337
"""Maximum number of tokens to generate."""
@@ -379,11 +379,10 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
379379
@model_validator(mode="after")
380380
def validate_environment(self) -> Self:
381381
"""Validate that api key and python package exists in environment."""
382-
if self.n is not None and self.n < 1:
382+
if self.n < 1:
383383
raise ValueError("n must be at least 1.")
384-
elif self.n is not None and self.n > 1 and self.streaming:
384+
if self.n > 1 and self.streaming:
385385
raise ValueError("n must be 1 when streaming.")
386-
387386
if self.temperature == 0:
388387
self.temperature = 1e-8
389388

@@ -393,11 +392,10 @@ def validate_environment(self) -> Self:
393392
),
394393
"base_url": self.groq_api_base,
395394
"timeout": self.request_timeout,
395+
"max_retries": self.max_retries,
396396
"default_headers": self.default_headers,
397397
"default_query": self.default_query,
398398
}
399-
if self.max_retries is not None:
400-
client_params["max_retries"] = self.max_retries
401399

402400
try:
403401
import groq

libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
'max_retries': 2,
1818
'max_tokens': 100,
1919
'model_name': 'mixtral-8x7b-32768',
20+
'n': 1,
2021
'request_timeout': 60.0,
2122
'stop': list([
2223
]),

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,8 @@ def _create_retry_decorator(
9595
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
9696

9797
errors = [httpx.RequestError, httpx.StreamError]
98-
kwargs: dict = dict(
99-
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
100-
)
10198
return create_base_retry_decorator(
102-
**{k: v for k, v in kwargs.items() if v is not None}
99+
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
103100
)
104101

105102

@@ -383,13 +380,13 @@ class ChatMistralAI(BaseChatModel):
383380
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
384381
)
385382
endpoint: Optional[str] = Field(default=None, alias="base_url")
386-
max_retries: Optional[int] = None
387-
timeout: Optional[int] = None
388-
max_concurrent_requests: Optional[int] = None
383+
max_retries: int = 5
384+
timeout: int = 120
385+
max_concurrent_requests: int = 64
389386
model: str = Field(default="mistral-small", alias="model_name")
390-
temperature: Optional[float] = None
387+
temperature: float = 0.7
391388
max_tokens: Optional[int] = None
392-
top_p: Optional[float] = None
389+
top_p: float = 1
393390
"""Decode using nucleus sampling: consider the smallest set of tokens whose
394391
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
395392
random_seed: Optional[int] = None

libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
]),
1010
'kwargs': dict({
1111
'endpoint': 'boo',
12+
'max_concurrent_requests': 64,
1213
'max_retries': 2,
1314
'max_tokens': 100,
1415
'mistral_api_key': dict({
@@ -21,6 +22,7 @@
2122
'model': 'mistral-small',
2223
'temperature': 0.0,
2324
'timeout': 60,
25+
'top_p': 1,
2426
}),
2527
'lc': 1,
2628
'name': 'ChatMistralAI',

libs/partners/openai/langchain_openai/chat_models/azure.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
7979
https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
8080
timeout: Union[float, Tuple[float, float], Any, None]
8181
Timeout for requests.
82-
max_retries: Optional[int]
82+
max_retries: int
8383
Max number of retries.
8484
organization: Optional[str]
8585
OpenAI organization ID. If not passed in will be read from env
@@ -586,9 +586,9 @@ def is_lc_serializable(cls) -> bool:
586586
@model_validator(mode="after")
587587
def validate_environment(self) -> Self:
588588
"""Validate that api key and python package exists in environment."""
589-
if self.n is not None and self.n < 1:
589+
if self.n < 1:
590590
raise ValueError("n must be at least 1.")
591-
elif self.n is not None and self.n > 1 and self.streaming:
591+
if self.n > 1 and self.streaming:
592592
raise ValueError("n must be 1 when streaming.")
593593

594594
if self.disabled_params is None:
@@ -641,11 +641,10 @@ def validate_environment(self) -> Self:
641641
"organization": self.openai_organization,
642642
"base_url": self.openai_api_base,
643643
"timeout": self.request_timeout,
644+
"max_retries": self.max_retries,
644645
"default_headers": self.default_headers,
645646
"default_query": self.default_query,
646647
}
647-
if self.max_retries is not None:
648-
client_params["max_retries"] = self.max_retries
649648
if not self.client:
650649
sync_specific = {"http_client": self.http_client}
651650
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
409409
root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
410410
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
411411
"""Model name to use."""
412-
temperature: Optional[float] = None
412+
temperature: float = 0.7
413413
"""What sampling temperature to use."""
414414
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
415415
"""Holds any model parameters valid for `create` call not explicitly specified."""
@@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
430430
)
431431
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
432432
None."""
433-
max_retries: Optional[int] = None
433+
max_retries: int = 2
434434
"""Maximum number of retries to make when generating."""
435435
presence_penalty: Optional[float] = None
436436
"""Penalizes repeated tokens."""
@@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
448448
"""Modify the likelihood of specified tokens appearing in the completion."""
449449
streaming: bool = False
450450
"""Whether to stream the results or not."""
451-
n: Optional[int] = None
451+
n: int = 1
452452
"""Number of chat completions to generate for each prompt."""
453453
top_p: Optional[float] = None
454454
"""Total probability mass of tokens to consider at each step."""
@@ -532,9 +532,9 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
532532
@model_validator(mode="after")
533533
def validate_environment(self) -> Self:
534534
"""Validate that api key and python package exists in environment."""
535-
if self.n is not None and self.n < 1:
535+
if self.n < 1:
536536
raise ValueError("n must be at least 1.")
537-
elif self.n is not None and self.n > 1 and self.streaming:
537+
if self.n > 1 and self.streaming:
538538
raise ValueError("n must be 1 when streaming.")
539539

540540
# Check OPENAI_ORGANIZATION for backwards compatibility.
@@ -551,12 +551,10 @@ def validate_environment(self) -> Self:
551551
"organization": self.openai_organization,
552552
"base_url": self.openai_api_base,
553553
"timeout": self.request_timeout,
554+
"max_retries": self.max_retries,
554555
"default_headers": self.default_headers,
555556
"default_query": self.default_query,
556557
}
557-
if self.max_retries is not None:
558-
client_params["max_retries"] = self.max_retries
559-
560558
if self.openai_proxy and (self.http_client or self.http_async_client):
561559
openai_proxy = self.openai_proxy
562560
http_client = self.http_client
@@ -611,14 +609,14 @@ def _default_params(self) -> Dict[str, Any]:
611609
"stop": self.stop or None, # also exclude empty list for this
612610
"max_tokens": self.max_tokens,
613611
"extra_body": self.extra_body,
614-
"n": self.n,
615-
"temperature": self.temperature,
616612
"reasoning_effort": self.reasoning_effort,
617613
}
618614

619615
params = {
620616
"model": self.model_name,
621617
"stream": self.streaming,
618+
"n": self.n,
619+
"temperature": self.temperature,
622620
**{k: v for k, v in exclude_if_none.items() if v is not None},
623621
**self.model_kwargs,
624622
}
@@ -1567,7 +1565,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
15671565
15681566
timeout: Union[float, Tuple[float, float], Any, None]
15691567
Timeout for requests.
1570-
max_retries: Optional[int]
1568+
max_retries: int
15711569
Max number of retries.
15721570
api_key: Optional[str]
15731571
OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.

libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
}),
1616
'max_retries': 2,
1717
'max_tokens': 100,
18+
'n': 1,
1819
'openai_api_key': dict({
1920
'id': list([
2021
'AZURE_OPENAI_API_KEY',

0 commit comments

Comments (0)