diff --git a/pyproject.toml b/pyproject.toml index ea7fa4ae1..777d80ec1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "datasets", "colorlog>=6.9.0", "langsmith", + "langchain-xai>=0.2.1", ] license = { text = "Apache-2.0" } diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index 1aafce31c..d2cec02a3 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -13,6 +13,7 @@ from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langchain_openai import ChatOpenAI +from langchain_xai import ChatXAI from pydantic import Field @@ -76,6 +77,9 @@ def _get_model_kwargs(self) -> dict[str, Any]: if self.model_provider == "anthropic": return {**base_kwargs, "model": self.model_name} + elif self.model_provider == "xai": + xai_api_base = os.getenv("XAI_API_BASE", "https://api.x.ai/v1/") + return {**base_kwargs, "model": self.model_name, "xai_api_base": xai_api_base} else: # openai return {**base_kwargs, "model": self.model_name} @@ -93,7 +97,13 @@ def _get_model(self) -> BaseChatModel: raise ValueError(msg) return ChatOpenAI(**self._get_model_kwargs()) - msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai" + elif self.model_provider == "xai": + if not os.getenv("XAI_API_KEY"): + msg = "XAI_API_KEY not found in environment. Please set it in your .env file or environment variables." + raise ValueError(msg) + return ChatXAI(**self._get_model_kwargs()) + + msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai, xai" raise ValueError(msg) def _generate( diff --git a/uv.lock b/uv.lock index 86bc5b166..9796f29bb 100644 --- a/uv.lock +++ b/uv.lock @@ -561,6 +561,7 @@ dependencies = [ { name = "langchain-anthropic" }, { name = "langchain-core" }, { name = "langchain-openai" }, + { name = "langchain-xai" }, { name = "langgraph" }, { name = "langgraph-prebuilt" }, { name = "langsmith" }, @@ -689,6 +690,7 @@ requires-dist = [ { name = "langchain-anthropic", specifier = ">=0.3.7" }, { name = "langchain-core" }, { name = "langchain-openai" }, + { name = "langchain-xai" }, { name = "langgraph" }, { name = "langgraph-prebuilt" }, { name = "langsmith" }, @@ -2109,6 +2111,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/f8/6b82af988e65af9697f6a2f25373fb173fd32d48b62772a8773c5184c870/langchain_text_splitters-0.3.6-py3-none-any.whl", hash = "sha256:e5d7b850f6c14259ea930be4a964a65fa95d9df7e1dbdd8bad8416db72292f4e", size = 31197 }, ] +[[package]] +name = "langchain-xai" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "langchain-core" }, + { name = "langchain-openai" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/94/a633bf1b4bbf66e4516f4188adc1174480c465ae12fb98f06c3e23c98519/langchain_xai-0.2.1.tar.gz", hash = "sha256:143a6f52be7617b5e5c68ab10c9b7df90914f54a6b3098566ce22b5d8fd89da5", size = 7788 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/88/d8050e610fadabf97c1745d24f0987b3e53b72fca63c8038ab1e0c103da9/langchain_xai-0.2.1-py3-none-any.whl", hash = "sha256:87228125cb15131663979d627210fca47dcd6b9a28462e8b5fee47f73bbed9f4", size = 6263 }, +] + [[package]] name = "langgraph" version = "0.3.2"