
Commit 8b1f54c

1 parent 443341a · commit 8b1f54c

File tree: 2 files changed, +44 -0 lines changed


libs/partners/xai/langchain_xai/chat_models.py

Lines changed: 23 additions & 0 deletions
@@ -275,6 +275,8 @@ class Joke(BaseModel):
     """
     xai_api_base: str = Field(default="https://api.x.ai/v1/")
     """Base URL path for API requests."""
+    search_parameters: Optional[dict[str, Any]] = None
+    """Parameters for search requests. Example: ``{"mode": "auto"}``."""
 
     openai_api_key: Optional[SecretStr] = None
     openai_api_base: Optional[str] = None
@@ -371,6 +373,18 @@ def validate_environment(self) -> Self:
             )
         return self
 
+    @property
+    def _default_params(self) -> dict[str, Any]:
+        """Get default parameters."""
+        params = super()._default_params
+        if self.search_parameters:
+            if "extra_body" in params:
+                params["extra_body"]["search_parameters"] = self.search_parameters
+            else:
+                params["extra_body"] = {"search_parameters": self.search_parameters}
+
+        return params
+
     def _create_chat_result(
         self,
         response: Union[dict, openai.BaseModel],
@@ -386,6 +400,11 @@ def _create_chat_result(
                 response.choices[0].message.reasoning_content  # type: ignore
             )
 
+        if hasattr(response, "citations"):
+            rtn.generations[0].message.additional_kwargs["citations"] = (
+                response.citations
+            )
+
         return rtn
 
     def _convert_chunk_to_generation_chunk(
@@ -407,6 +426,10 @@ def _convert_chunk_to_generation_chunk(
                         reasoning_content
                     )
 
+        if (citations := chunk.get("citations")) and generation_chunk:
+            if isinstance(generation_chunk.message, AIMessageChunk):
+                generation_chunk.message.additional_kwargs["citations"] = citations
+
         return generation_chunk
 
     def with_structured_output(
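
For context, a minimal usage sketch of what this change enables (a sketch only, mirroring the integration test below; it assumes XAI_API_KEY is set in the environment):

from langchain_xai import ChatXAI

# search_parameters is forwarded to the xAI API via extra_body (see _default_params above).
llm = ChatXAI(
    model="grok-3-latest",
    search_parameters={"mode": "auto", "max_search_results": 3},
)

response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
print(response.content)
# When the API returns citations, they are surfaced on the message's additional_kwargs.
print(response.additional_kwargs.get("citations", []))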

libs/partners/xai/tests/integration_tests/test_chat_models_standard.py

Lines changed: 21 additions & 0 deletions
@@ -48,3 +48,24 @@ def test_reasoning_content() -> None:
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
     assert full.additional_kwargs["reasoning_content"]
+
+
+def test_web_search() -> None:
+    llm = ChatXAI(
+        model="grok-3-latest",
+        search_parameters={"mode": "auto", "max_search_results": 3},
+    )
+
+    # Test invoke
+    response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
+    assert response.content
+    assert response.additional_kwargs["citations"]
+    assert len(response.additional_kwargs["citations"]) <= 3
+
+    # Test streaming
+    full = None
+    for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.additional_kwargs["citations"]
+    assert len(full.additional_kwargs["citations"]) <= 3
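
As a quick check of how the new property shapes the request, a rough sketch of inspecting _default_params (a sketch, assuming XAI_API_KEY is set and no extra_body is already configured; the exact contents of the base class's params may differ):

from langchain_xai import ChatXAI

llm = ChatXAI(model="grok-3-latest", search_parameters={"mode": "auto"})
params = llm._default_params
# With no pre-existing extra_body, search_parameters is nested under extra_body,
# which the underlying OpenAI client forwards as extra request-body fields:
#   {"search_parameters": {"mode": "auto"}}
print(params["extra_body"])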
