
Commit 390e47c

Merge pull request #14416 from xprilion/wandb-inference
Add W&B Inference to LiteLLM
2 parents fa20abf + 0f1de92

12 files changed: +395 -0 lines changed

litellm/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -60,6 +60,7 @@
     empower_models,
     together_ai_models,
     baseten_models,
+    WANDB_MODELS,
    REPEATED_STREAMING_CHUNK_LIMIT,
     request_timeout,
     open_ai_embedding_models,
@@ -242,6 +243,7 @@
 snowflake_key: Optional[str] = None
 gradient_ai_api_key: Optional[str] = None
 nebius_key: Optional[str] = None
+wandb_key: Optional[str] = None
 heroku_key: Optional[str] = None
 cometapi_key: Optional[str] = None
 ovhcloud_key: Optional[str] = None
@@ -524,6 +526,7 @@ def identify(event_details):
 oci_models: Set = set()
 vercel_ai_gateway_models: Set = set()
 volcengine_models: Set = set()
+wandb_models: Set = set(WANDB_MODELS)
 ovhcloud_models: Set = set()
 ovhcloud_embedding_models: Set = set()
@@ -740,6 +743,8 @@ def add_known_models():
             oci_models.add(key)
         elif value.get("litellm_provider") == "volcengine":
             volcengine_models.add(key)
+        elif value.get("litellm_provider") == "wandb":
+            wandb_models.add(key)
         elif value.get("litellm_provider") == "ovhcloud":
             ovhcloud_models.add(key)
         elif value.get("litellm_provider") == "ovhcloud-embedding-models":
@@ -838,6 +843,7 @@ def add_known_models():
     | heroku_models
     | vercel_ai_gateway_models
     | volcengine_models
+    | wandb_models
     | ovhcloud_models
 )
@@ -920,6 +926,7 @@ def add_known_models():
     "cometapi": cometapi_models,
     "oci": oci_models,
     "volcengine": volcengine_models,
+    "wandb": wandb_models,
     "ovhcloud": ovhcloud_models | ovhcloud_embedding_models,
 }
@@ -1259,6 +1266,7 @@ def add_known_models():
 from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
 from .llms.github_copilot.chat.transformation import GithubCopilotConfig
 from .llms.nebius.chat.transformation import NebiusConfig
+from .llms.wandb.chat.transformation import WandbConfig
 from .llms.dashscope.chat.transformation import DashScopeChatConfig
 from .llms.moonshot.chat.transformation import MoonshotChatConfig
 from .llms.v0.chat.transformation import V0ChatConfig
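
For context, a minimal sketch of what this wiring exposes at import time. It assumes the provider-to-models dict patched in the @@ -920 hunk is the public litellm.models_by_provider; that name is not shown in the hunk itself.

import litellm

# wandb_models is seeded from the static WANDB_MODELS constant and extended
# by add_known_models() for any "wandb" entries in the model-cost map.
assert "openai/gpt-oss-120b" in litellm.wandb_models

# "wandb" now resolves in the provider -> models mapping (assumed name).
assert "wandb" in litellm.models_by_provider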

litellm/constants.py

Lines changed: 36 additions & 0 deletions
@@ -313,6 +313,7 @@
     "morph",
     "lambda_ai",
     "vercel_ai_gateway",
+    "wandb",
     "ovhcloud",
 ]
@@ -448,6 +449,7 @@
     "https://api.lambda.ai/v1",
     "https://api.hyperbolic.xyz/v1",
     "https://ai-gateway.vercel.sh/v1",
+    "https://api.inference.wandb.ai/v1",
 ]
@@ -492,6 +494,7 @@
     "hyperbolic",
     "vercel_ai_gateway",
     "aiml",
+    "wandb",
 ]
 openai_text_completion_compatible_providers: List = (
     [  # providers that support `/v1/completions`
@@ -507,6 +510,7 @@
         "v0",
         "lambda_ai",
         "hyperbolic",
+        "wandb",
     ]
 )
 _openai_like_providers: List = [
@@ -757,6 +761,38 @@
     ]
 )

+WANDB_MODELS: set = set(
+    [
+        # openai models
+        "openai/gpt-oss-120b",
+        "openai/gpt-oss-20b",
+
+        # zai-org models
+        "zai-org/GLM-4.5",
+
+        # Qwen models
+        "Qwen/Qwen3-235B-A22B-Instruct-2507",
+        "Qwen/Qwen3-Coder-480B-A35B-Instruct",
+        "Qwen/Qwen3-235B-A22B-Thinking-2507",
+
+        # moonshotai
+        "moonshotai/Kimi-K2-Instruct",
+
+        # meta models
+        "meta-llama/Llama-3.1-8B-Instruct",
+        "meta-llama/Llama-3.3-70B-Instruct",
+        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+
+        # deepseek-ai
+        "deepseek-ai/DeepSeek-V3.1",
+        "deepseek-ai/DeepSeek-R1-0528",
+        "deepseek-ai/DeepSeek-V3-0324",
+
+        # microsoft
+        "microsoft/Phi-4-mini-instruct",
+    ]
+)
+
 BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
     "cohere",
     "anthropic",

litellm/litellm_core_utils/get_llm_provider_logic.py

Lines changed: 10 additions & 0 deletions
@@ -252,6 +252,9 @@ def get_llm_provider(  # noqa: PLR0915
         elif endpoint == "https://ai-gateway.vercel.sh/v1":
             custom_llm_provider = "vercel_ai_gateway"
             dynamic_api_key = get_secret_str("VERCEL_AI_GATEWAY_API_KEY")
+        elif endpoint == "https://api.inference.wandb.ai/v1":
+            custom_llm_provider = "wandb"
+            dynamic_api_key = get_secret_str("WANDB_API_KEY")

     if api_base is not None and not isinstance(api_base, str):
         raise Exception(
@@ -773,6 +776,13 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
         ) = litellm.AIMLChatConfig()._get_openai_compatible_provider_info(
             api_base, api_key
         )
+    elif custom_llm_provider == "wandb":
+        api_base = (
+            api_base
+            or get_secret("WANDB_API_BASE")
+            or "https://api.inference.wandb.ai/v1"
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("WANDB_API_KEY")

     if api_base is not None and not isinstance(api_base, str):
         raise Exception("api base needs to be a string. api_base={}".format(api_base))

litellm/litellm_core_utils/get_supported_openai_params.py

Lines changed: 3 additions & 0 deletions
@@ -149,6 +149,9 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "nebius":
         if request_type == "chat_completion":
             return litellm.NebiusConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "wandb":
+        if request_type == "chat_completion":
+            return litellm.WandbConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "replicate":
         return litellm.ReplicateConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "huggingface":
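
A small usage sketch of the new branch. Since WandbConfig subclasses OpenAIGPTConfig, the returned list is whatever the standard OpenAI chat config advertises:

from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="openai/gpt-oss-120b",
    custom_llm_provider="wandb",
    request_type="chat_completion",
)
assert "temperature" in params  # standard OpenAI chat params pass through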

litellm/llms/wandb/__init__.py

Whitespace-only changes.

litellm/llms/wandb/chat/__init__.py

Whitespace-only changes.
litellm/llms/wandb/chat/transformation.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+"""
+Wandb Chat Completions API - Transformation
+
+This is OpenAI compatible - no translation needed / occurs
+"""
+
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class WandbConfig(OpenAIGPTConfig):
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        """
+        map max_completion_tokens param to max_tokens
+        """
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        for param, value in non_default_params.items():
+            if param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            elif param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
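
Concretely, the only W&B-specific translation is renaming max_completion_tokens; every other supported parameter passes through unchanged. A minimal invocation of the config above:

from litellm import WandbConfig

optional_params = WandbConfig().map_openai_params(
    non_default_params={"max_completion_tokens": 256, "temperature": 0.2},
    optional_params={},
    model="openai/gpt-oss-120b",
    drop_params=False,
)
assert optional_params == {"max_tokens": 256, "temperature": 0.2}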

litellm/main.py

Lines changed: 22 additions & 0 deletions
@@ -1981,6 +1981,7 @@ def completion(  # type: ignore # noqa: PLR0915
         or custom_llm_provider == "openai"
         or custom_llm_provider == "together_ai"
         or custom_llm_provider == "nebius"
+        or custom_llm_provider == "wandb"
         or custom_llm_provider in litellm.openai_compatible_providers
         or "ft:gpt-3.5-turbo" in model  # finetune gpt-3.5-turbo
     ):  # allow user to make an openai call with a custom base
@@ -4400,6 +4401,27 @@ def embedding(  # noqa: PLR0915
             or "api.studio.nebius.ai/v1"
         )

+        response = openai_chat_completions.embedding(
+            model=model,
+            input=input,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
+        )
+    elif custom_llm_provider == "wandb":
+        api_key = api_key or litellm.api_key or get_secret_str("WANDB_API_KEY")
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret_str("WANDB_API_BASE")
+            or "https://api.inference.wandb.ai/v1"
+        )
+
         response = openai_chat_completions.embedding(
             model=model,
             input=input,
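
With the routing above in place, the standard entry points work with a wandb/ prefix once WANDB_API_KEY is set. A hedged end-to-end sketch (the key value is a placeholder; embedding() resolves the key and base the same way, falling back to https://api.inference.wandb.ai/v1):

import os
import litellm

os.environ["WANDB_API_KEY"] = "<your-wandb-key>"  # placeholder

# Chat completion, routed through the OpenAI-compatible path enabled in completion().
response = litellm.completion(
    model="wandb/openai/gpt-oss-120b",
    messages=[{"role": "user", "content": "Say hello"}],
    max_completion_tokens=64,  # WandbConfig maps this to max_tokens
)
print(response.choices[0].message.content)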

litellm/types/utils.py

Lines changed: 1 addition & 0 deletions
@@ -2397,6 +2397,7 @@ class LlmProviders(str, Enum):
     AUTO_ROUTER = "auto_router"
     VERCEL_AI_GATEWAY = "vercel_ai_gateway"
     DOTPROMPT = "dotprompt"
+    WANDB = "wandb"
     OVHCLOUD = "ovhcloud"

litellm/utils.py

Lines changed: 16 additions & 0 deletions
@@ -3275,6 +3275,7 @@ def pre_process_optional_params(
         and custom_llm_provider != "openrouter"
         and custom_llm_provider != "vercel_ai_gateway"
         and custom_llm_provider != "nebius"
+        and custom_llm_provider != "wandb"
         and custom_llm_provider not in litellm.openai_compatible_providers
     ):
         if custom_llm_provider == "ollama":
@@ -4446,6 +4447,9 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     # nebius
     elif llm_provider == "nebius":
         api_key = api_key or litellm.nebius_key or get_secret("NEBIUS_API_KEY")
+    # wandb
+    elif llm_provider == "wandb":
+        api_key = api_key or litellm.wandb_key or get_secret("WANDB_API_KEY")
     return api_key
@@ -5530,6 +5534,11 @@ def validate_environment(  # noqa: PLR0915
             keys_in_environment = True
         else:
             missing_keys.append("NEBIUS_API_KEY")
+    elif custom_llm_provider == "wandb":
+        if "WANDB_API_KEY" in os.environ:
+            keys_in_environment = True
+        else:
+            missing_keys.append("WANDB_API_KEY")
     elif custom_llm_provider == "dashscope":
         if "DASHSCOPE_API_KEY" in os.environ:
             keys_in_environment = True
@@ -5644,6 +5653,11 @@ def validate_environment(  # noqa: PLR0915
             keys_in_environment = True
         else:
             missing_keys.append("NEBIUS_API_KEY")
+    elif model in litellm.wandb_models:
+        if "WANDB_API_KEY" in os.environ:
+            keys_in_environment = True
+        else:
+            missing_keys.append("WANDB_API_KEY")

     def filter_missing_keys(keys: List[str], exclude_pattern: str) -> List[str]:
         """Filter out keys that contain the exclude_pattern (case insensitive)."""
@@ -7046,6 +7060,8 @@ def get_provider_chat_config(  # noqa: PLR0915
         return litellm.NovitaConfig()
     elif litellm.LlmProviders.NEBIUS == provider:
         return litellm.NebiusConfig()
+    elif litellm.LlmProviders.WANDB == provider:
+        return litellm.WandbConfig()
     elif litellm.LlmProviders.DASHSCOPE == provider:
         return litellm.DashScopeChatConfig()
     elif litellm.LlmProviders.MOONSHOT == provider:
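
Finally, a sketch exercising the two helper paths patched above; validate_environment returns the keys_in_environment / missing_keys dict built in the hunks, and get_api_key consults litellm.wandb_key before the environment variable:

import litellm
from litellm.utils import get_api_key

# Reports WANDB_API_KEY as missing when it is not exported.
print(litellm.validate_environment(model="wandb/openai/gpt-oss-120b"))

# The module-level key takes precedence over the environment variable.
litellm.wandb_key = "<placeholder-key>"  # hypothetical value for illustration
assert get_api_key(llm_provider="wandb", dynamic_api_key=None) == "<placeholder-key>"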
