Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ UNBOUND_API_KEY=
SiliconFLOW_ENDPOINT=https://api.siliconflow.cn/v1/
SiliconFLOW_API_KEY=

XINFERENCE_OPENAI_ENDPOINT=https://api.xinference.com/v1
# Name of a model deployed on your Xinference server, e.g. qwen2.5-instruct or deepseek-r1
XINFERENCE_MODEL=
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing documentation of the expected values for the XINFERENCE_MODEL environment variable


# Set to false to disable anonymized telemetry
ANONYMIZED_TELEMETRY=false

Expand Down
16 changes: 16 additions & 0 deletions src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ def get_llm_model(provider: str, **kwargs):
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
temperature=kwargs.get("temperature", 0.0),
)
elif provider == "xinference":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing entry for "xinference" in PROVIDER_DISPLAY_NAMES dictionary

if not kwargs.get("base_url", ""):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API key retrieval uses a different environment variable pattern than what's defined in .env.example

base_url = os.getenv("XINFERENCE_OPENAI_ENDPOINT", "https://api.xinference.com/v1")
else:
base_url = kwargs.get("base_url")

return ChatOpenAI(
model=kwargs.get("model_name", "gpt-4o"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default model for Xinference provider is set to "gpt-4o" which is not in the xinference model list

temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
else:
raise ValueError(f"Unsupported provider: {provider}")

Expand Down Expand Up @@ -234,6 +246,10 @@ def get_llm_model(provider: str, **kwargs):
"Pro/THUDM/chatglm3-6b",
"Pro/THUDM/glm-4-9b-chat",
],
"xinference": ["qwen2.5-instruct", "qwen2.5", "qwen2.5-coder", "qwen2.5-coder-instruct", "qwen2.5-instruct-1m",
"qwen2.5-vl-instruct", "deepseek", "deepseek-chat", "deepseek-coder", "deepseek-coder-instruct",
"deepseek-r1", "deepseek-v2", "deepseek-v2-chat", "deepseek-v2-chat-0628", "deepseek-v2.5",
"deepseek-v3", "deepseek-vl-chat", "deepseek-vl2"]
}


Expand Down