Skip to content

Commit 3ec45c4

Browse files
tomsun28 authored and Wauplin committed
feat: support zai as inference provider (#3395)
* feat: support zai as inference provider * update * update * update * update
1 parent 0e46a48 commit 3ec45c4

File tree

6 files changed

+48
-2
lines changed

6 files changed

+48
-2
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class InferenceClient:
130130
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
131131
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
132132
provider (`str`, *optional*):
133-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
133+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
134134
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
135135
If model is a URL or `base_url` is passed, then `provider` is not used.
136136
token (`str`, *optional*):

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class AsyncInferenceClient:
118118
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
119119
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
120120
provider (`str`, *optional*):
121-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
121+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
122122
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
123123
If model is a URL or `base_url` is passed, then `provider` is not used.
124124
token (`str`, *optional*):

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
4242
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
4343
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
44+
from .zai_org import ZaiConversationalTask
4445

4546

4647
logger = logging.get_logger(__name__)
@@ -65,6 +66,7 @@
6566
"sambanova",
6667
"scaleway",
6768
"together",
69+
"zai-org",
6870
]
6971

7072
PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]
@@ -170,6 +172,9 @@
170172
"conversational": TogetherConversationalTask(),
171173
"text-generation": TogetherTextGenerationTask(),
172174
},
175+
"zai-org": {
176+
"conversational": ZaiConversationalTask(),
177+
},
173178
}
174179

175180

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"sambanova": {},
3636
"scaleway": {},
3737
"together": {},
38+
"zai-org": {},
3839
}
3940

4041

src/huggingface_hub/inference/_providers/zai_org.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Any, Dict
2+
3+
from huggingface_hub.inference._providers._common import BaseConversationalTask
4+
5+
6+
class ZaiConversationalTask(BaseConversationalTask):
    """Conversational task helper for the ``zai-org`` inference provider.

    Configures the base helper with the ``zai-org`` provider name and the
    ``https://api.z.ai`` base URL, adds two provider-specific request
    headers, and pins the chat-completions route.
    """

    def __init__(self):
        super().__init__(base_url="https://api.z.ai", provider="zai-org")

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
        """Return the base headers extended with the Z.ai-specific ones.

        Args:
            headers: Caller-supplied headers, forwarded to the parent helper.
            api_key: API key, forwarded to the parent helper.

        Returns:
            The mapping built by the parent helper, with ``Accept-Language``
            and ``x-source-channel`` set (overwriting any existing values).
        """
        prepared = super()._prepare_headers(headers, api_key)
        prepared.update(
            {
                "Accept-Language": "en-US,en",
                "x-source-channel": "hugging_face",
            }
        )
        return prepared

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the chat-completions path; it ignores the model and the key."""
        return "/api/paas/v4/chat/completions"

tests/test_inference_providers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
5454
from huggingface_hub.inference._providers.scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
5555
from huggingface_hub.inference._providers.together import TogetherTextToImageTask
56+
from huggingface_hub.inference._providers.zai_org import ZaiConversationalTask
5657

5758
from .testing_utils import assert_in_logs
5859

@@ -1412,6 +1413,28 @@ def test_text_to_image_get_response(self):
14121413
assert response == b"image_bytes"
14131414

14141415

1416+
class TestZaiProvider:
    """Unit tests for the ``zai-org`` conversational provider helper."""

    def test_prepare_route(self):
        """The chat-completions route is fixed for any model and key."""
        helper = ZaiConversationalTask()
        route = helper._prepare_route("test-model", "zai_token")
        assert route == "/api/paas/v4/chat/completions"

    def test_prepare_headers(self):
        """Both provider-specific headers are present in the prepared headers."""
        helper = ZaiConversationalTask()
        headers = helper._prepare_headers({}, "test_key")
        assert headers["Accept-Language"] == "en-US,en"
        # The helper also tags requests with a source-channel header; check it
        # too so a regression on either header is caught.
        assert headers["x-source-channel"] == "hugging_face"

    def test_prepare_url(self):
        """A provider key targets the Z.ai API; an HF token targets the HF router."""
        helper = ZaiConversationalTask()
        assert helper.task == "conversational"
        url = helper._prepare_url("zai_token", "test-model")
        assert url == "https://api.z.ai/api/paas/v4/chat/completions"

        # Test with HF token (should route through HF proxy)
        url = helper._prepare_url("hf_token", "test-model")
        assert url.startswith("https://router.huggingface.co/zai-org")
1436+
1437+
14151438
class TestBaseConversationalTask:
14161439
def test_prepare_route(self):
14171440
helper = BaseConversationalTask(provider="test-provider", base_url="https://api.test.com")

0 commit comments

Comments
 (0)