Skip to content

Commit 3c43c2d

Browse files
authored
Add grok-4 and groq kimi-k2 models (#2235)
1 parent babc23b commit 3c43c2d

File tree

10 files changed

+62
-2
lines changed

10 files changed

+62
-2
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,22 @@
185185
'gpt-4o-mini-search-preview-2025-03-11',
186186
'gpt-4o-search-preview',
187187
'gpt-4o-search-preview-2025-03-11',
188+
'grok:grok-4',
189+
'grok:grok-4-0709',
190+
'grok:grok-3',
191+
'grok:grok-3-mini',
192+
'grok:grok-3-fast',
193+
'grok:grok-3-mini-fast',
194+
'grok:grok-2-vision-1212',
195+
'grok:grok-2-image-1212',
188196
'groq:distil-whisper-large-v3-en',
189197
'groq:gemma2-9b-it',
190198
'groq:llama-3.3-70b-versatile',
191199
'groq:llama-3.1-8b-instant',
192200
'groq:llama-guard-3-8b',
193201
'groq:llama3-70b-8192',
194202
'groq:llama3-8b-8192',
203+
'groq:moonshotai/kimi-k2-instruct',
195204
'groq:whisper-large-v3',
196205
'groq:whisper-large-v3-turbo',
197206
'groq:playai-tts',

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
'llama-3.2-3b-preview',
8080
'llama-3.2-11b-vision-preview',
8181
'llama-3.2-90b-vision-preview',
82+
'moonshotai/kimi-k2-instruct',
8283
]
8384
"""Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""
8485

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from __future__ import annotations as _annotations
2+
3+
from . import ModelProfile
4+
5+
6+
def moonshotai_model_profile(model_name: str) -> ModelProfile | None:
7+
"""Get the model profile for a MoonshotAI model."""
8+
return None

pydantic_ai_slim/pydantic_ai/providers/grok.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import os
4-
from typing import overload
4+
from typing import Literal, overload
55

66
from httpx import AsyncClient as AsyncHTTPClient
77
from openai import AsyncOpenAI
@@ -21,6 +21,18 @@
2121
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
2222
) from _import_error
2323

24+
# https://docs.x.ai/docs/models
25+
GrokModelName = Literal[
26+
'grok-4',
27+
'grok-4-0709',
28+
'grok-3',
29+
'grok-3-mini',
30+
'grok-3-fast',
31+
'grok-3-mini-fast',
32+
'grok-2-vision-1212',
33+
'grok-2-image-1212',
34+
]
35+
2436

2537
class GrokProvider(Provider[AsyncOpenAI]):
2638
"""Provider for Grok API."""

pydantic_ai_slim/pydantic_ai/providers/groq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic_ai.profiles.google import google_model_profile
1313
from pydantic_ai.profiles.meta import meta_model_profile
1414
from pydantic_ai.profiles.mistral import mistral_model_profile
15+
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
1516
from pydantic_ai.profiles.qwen import qwen_model_profile
1617
from pydantic_ai.providers import Provider
1718

@@ -47,6 +48,7 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
4748
'qwen': qwen_model_profile,
4849
'deepseek': deepseek_model_profile,
4950
'mistral': mistral_model_profile,
51+
'moonshotai/': moonshotai_model_profile,
5052
}
5153

5254
for prefix, profile_func in prefix_to_profile.items():

tests/models/cassettes/test_model_names/test_known_model_names.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ interactions:
1515
response:
1616
headers:
1717
content-length:
18-
- '545'
18+
- '550'
1919
content-security-policy:
2020
- default-src 'none'; frame-ancestors 'none'
2121
content-type:
@@ -46,6 +46,7 @@ interactions:
4646
- text-to-text
4747
- model_id: claude-4-sonnet
4848
regions:
49+
- eu
4950
- us
5051
type:
5152
- text-to-text

tests/models/test_model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@
7272
'github',
7373
'OpenAIModel',
7474
),
75+
(
76+
'GROK_API_KEY',
77+
'grok:grok-3',
78+
'grok-3',
79+
'grok',
80+
'grok',
81+
'OpenAIModel',
82+
),
83+
(
84+
'GROK_API_KEY',
85+
'grok-4', # Note that the provider and model name are both "grok", so the plain string grok with no prefix works because its also the provider name
86+
'grok-4',
87+
'grok',
88+
'grok',
89+
'OpenAIModel',
90+
),
7591
]
7692

7793

tests/models/test_model_names.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pydantic_ai.models.huggingface import HuggingFaceModelName
2020
from pydantic_ai.models.mistral import MistralModelName
2121
from pydantic_ai.models.openai import OpenAIModelName
22+
from pydantic_ai.providers.grok import GrokModelName
2223

2324
pytestmark = [
2425
pytest.mark.skipif(not imports_successful(), reason='some model package was not installed'),
@@ -48,6 +49,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
4849
google_names = [f'google-gla:{n}' for n in get_model_names(GeminiModelName)] + [
4950
f'google-vertex:{n}' for n in get_model_names(GeminiModelName)
5051
]
52+
grok_names = [f'grok:{n}' for n in get_model_names(GrokModelName)]
5153
groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)]
5254
mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)]
5355
openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
@@ -63,6 +65,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
6365
anthropic_names
6466
+ cohere_names
6567
+ google_names
68+
+ grok_names
6669
+ groq_names
6770
+ mistral_names
6871
+ openai_names

tests/providers/test_groq.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile
1313
from pydantic_ai.profiles.meta import meta_model_profile
1414
from pydantic_ai.profiles.mistral import mistral_model_profile
15+
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
1516
from pydantic_ai.profiles.qwen import qwen_model_profile
1617

1718
from ..conftest import TestEnv, try_import
@@ -74,6 +75,7 @@ def test_groq_provider_model_profile(mocker: MockerFixture):
7475
google_model_profile_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile)
7576
mistral_model_profile_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
7677
qwen_model_profile_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile)
78+
moonshotai_model_profile_mock = mocker.patch(f'{ns}.moonshotai_model_profile', wraps=moonshotai_model_profile)
7779

7880
meta_profile = provider.model_profile('meta-llama/Llama-Guard-4-12B')
7981
meta_model_profile_mock.assert_called_with('llama-guard-4-12b')
@@ -103,5 +105,10 @@ def test_groq_provider_model_profile(mocker: MockerFixture):
103105
assert qwen_profile is not None
104106
assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer
105107

108+
# MoonshotAI model should remove the "moonshotai/" prefix before passing to profile
109+
moonshotai_profile = provider.model_profile('moonshotai/kimi-k2-instruct')
110+
moonshotai_model_profile_mock.assert_called_with('kimi-k2-instruct')
111+
assert moonshotai_profile is None
112+
106113
unknown_profile = provider.model_profile('unknown-model')
107114
assert unknown_profile is None

tests/test_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def test_list_models(capfd: CaptureFixture[str]):
144144
'cohere',
145145
'deepseek',
146146
'heroku',
147+
'grok',
147148
'huggingface',
148149
)
149150
models = {line.strip().split(' ')[0] for line in output[3:]}

0 commit comments

Comments
 (0)