
Commit c050590

type stubs for some third-party libraries (#3443)

lars20070 and DouweM authored
Co-authored-by: Douwe Maan <[email protected]>
1 parent 05fbaaa · commit c050590

28 files changed: +274 −50 lines

Makefile

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ typecheck-pyright:
 .PHONY: typecheck-mypy
 typecheck-mypy:
 	uv run mypy
+	uv run mypy typings/ --strict

 .PHONY: typecheck
 typecheck: typecheck-pyright ## Run static type checking
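
The new target checks the stubs themselves under mypy's strict mode, so every entry in typings/ must stay fully annotated. As a rough sketch (hypothetical stub lines, not from this commit), strict mode enables --disallow-untyped-defs, which rejects unannotated definitions:

# Hypothetical stub fragment: the first form fails `mypy --strict`,
# the second passes, because strict mode disallows untyped defs.
from typing import Any

def load(model_path): ...                    # rejected: untyped definition
def load_typed(model_path: str) -> Any: ...  # accepted: fully annotated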

pydantic_ai_slim/pydantic_ai/models/outlines.py

Lines changed: 2 additions & 2 deletions
@@ -60,7 +60,7 @@
 )
 from outlines.models.vllm_offline import (
     VLLMOffline,
-    from_vllm_offline,  # pyright: ignore[reportUnknownVariableType]
+    from_vllm_offline,
 )
 from outlines.types.dsl import JsonSchema
 from PIL import Image as PILImage
@@ -393,7 +393,7 @@ def _format_vllm_offline_inference_kwargs(  # pragma: no cover
     self, model_settings: dict[str, Any]
 ) -> dict[str, Any]:
     """Select the model settings supported by the vLLMOffline model."""
-    from vllm.sampling_params import SamplingParams  # pyright: ignore
+    from vllm.sampling_params import SamplingParams

     supported_args = [
         'max_tokens',
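
These pyright suppressions become unnecessary once the checker can resolve the vllm names through stubs. A hypothetical sketch of the kind of stub that makes the second ignore redundant (the actual typings/vllm stub is among this commit's 28 files but is not shown in this excerpt):

# typings/vllm/sampling_params.pyi — illustrative sketch, not the shipped stub
from typing import Any

class SamplingParams:
    def __init__(self, *args: Any, **kwargs: Any) -> None: ...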

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ bedrock = ["boto3>=1.40.14"]
 huggingface = ["huggingface-hub[inference]>=0.33.5"]
 outlines-transformers = ["outlines[transformers]>=1.0.0, <1.3.0; (sys_platform != 'darwin' or platform_machine != 'x86_64')", "transformers>=4.0.0", "pillow", "torch; (sys_platform != 'darwin' or platform_machine != 'x86_64')"]
 outlines-llamacpp = ["outlines[llamacpp]>=1.0.0, <1.3.0"]
-outlines-mlxlm = ["outlines[mlxlm]>=1.0.0, <1.3.0; (sys_platform != 'darwin' or platform_machine != 'x86_64')"]
+outlines-mlxlm = ["outlines[mlxlm]>=1.0.0, <1.3.0; platform_system == 'Darwin' and platform_machine == 'arm64'"]
 outlines-sglang = ["outlines[sglang]>=1.0.0, <1.3.0", "pillow"]
 outlines-vllm-offline = ["vllm; python_version < '3.12' and (sys_platform != 'darwin' or platform_machine != 'x86_64')", "torch; (sys_platform != 'darwin' or platform_machine != 'x86_64')", "outlines>=1.0.0, <1.3.0"]
 # Tools
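
The outlines-mlxlm marker flips from excluding Intel macOS to explicitly requiring Apple Silicon macOS, which is where MLX actually runs. Markers like this follow PEP 508 and can be evaluated against the current interpreter with the packaging library; a minimal sketch:

# Evaluate the new PEP 508 environment marker; prints True only on
# Apple Silicon macOS.
from packaging.markers import Marker

marker = Marker("platform_system == 'Darwin' and platform_machine == 'arm64'")
print(marker.evaluate())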

pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -56,7 +56,7 @@ dbos = ["pydantic-ai-slim[dbos]=={{ version }}"]
 prefect = ["pydantic-ai-slim[prefect]=={{ version }}"]
 outlines-transformers = ["pydantic-ai-slim[outlines-transformers]=={{ version }}"]
 outlines-llamacpp = ["pydantic-ai-slim[outlines-llamacpp]=={{ version }}"]
-outlines-mlxlm = ["pydantic-ai-slim[outlines-mlxlm]=={{ version }}"]
+outlines-mlxlm = ["pydantic-ai-slim[outlines-mlxlm]=={{ version }}; platform_system == 'Darwin' and platform_machine == 'arm64'"]
 outlines-sglang = ["pydantic-ai-slim[outlines-sglang]=={{ version }}"]
 outlines-vllm-offline = ["pydantic-ai-slim[outlines-vllm-offline]=={{ version }}"]

@@ -142,6 +142,7 @@ include = [
     "clai/**/*.py",
     "tests/**/*.py",
     "docs/**/*.py",
+    "typings/**/*.pyi",
 ]

 [tool.ruff.lint]
@@ -186,8 +187,10 @@ quote-style = "single"
 "examples/**/*.py" = ["D101", "D103"]
 "tests/**/*.py" = ["D"]
 "docs/**/*.py" = ["D"]
+"typings/**/*.pyi" = ["F401", "PYI044", "PYI035", "ANN401"]

 [tool.pyright]
+stubPath = "typings"
 pythonVersion = "3.12"
 typeCheckingMode = "strict"
 reportMissingTypeStubs = false
@@ -217,6 +220,7 @@ exclude = [
 [tool.mypy]
 files = "tests/typed_agent.py,tests/typed_graph.py"
 strict = true
+mypy_path = "typings"

 [tool.pytest.ini_options]
 testpaths = ["tests", "docs/.hooks"]
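
stubPath (pyright) and mypy_path (mypy) both point at the same directory, so one set of stubs serves both checkers. The layout these settings resolve against is the one added in this commit (the stub files shown below, among others in the full 28-file change):

typings/
├── README.md
├── llama_cpp.pyi
├── mlx/
│   ├── __init__.pyi
│   └── nn.pyi
└── mlx_lm.pyi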

tests/models/test_outlines.py

Lines changed: 33 additions & 38 deletions
@@ -44,7 +44,7 @@
 with try_import() as imports_successful:
     import outlines

-    from pydantic_ai.models.outlines import OutlinesModel
+    from pydantic_ai.models.outlines import OutlinesAsyncBaseModel, OutlinesModel
     from pydantic_ai.providers.outlines import OutlinesProvider

 with try_import() as transformer_imports_successful:
@@ -54,11 +54,11 @@
     import llama_cpp

 with try_import() as vllm_imports_successful:
-    import vllm  # type: ignore[reportMissingImports]
+    import vllm

     # We try to load the vllm model to ensure it is available
     try:  # pragma: no lax cover
-        vllm.LLM('microsoft/Phi-3-mini-4k-instruct')  # type: ignore
+        vllm.LLM('microsoft/Phi-3-mini-4k-instruct')
     except RuntimeError as e:  # pragma: lax no cover
         if 'Found no NVIDIA driver' in str(e) or 'Device string must not be empty' in str(e):
             # Treat as import failure
@@ -97,68 +97,67 @@

 @pytest.fixture
 def mock_async_model() -> OutlinesModel:
-    class MockOutlinesAsyncModel(outlines.models.base.AsyncModel):
+    class MockOutlinesAsyncModel(OutlinesAsyncBaseModel):
         """Mock an OutlinesAsyncModel because no Outlines local models have an async version.

         The `__call__` and `stream` methods will be called by the Pydantic AI model while the other methods are
         only implemented because they are abstract methods in the OutlinesAsyncModel class.
         """

-        async def __call__(self, model_input, output_type, backend, **inference_kwargs):  # type: ignore[reportMissingParameterType]
+        async def __call__(self, model_input: Any, output_type: Any, backend: Any, **inference_kwargs: Any) -> str:
             return 'test'

-        async def stream(self, model_input, output_type, backend, **inference_kwargs):  # type: ignore[reportMissingParameterType]
+        async def stream(self, model_input: Any, output_type: Any, backend: Any, **inference_kwargs: Any):
             for _ in range(2):
                 yield 'test'

-        async def generate(self, model_input, output_type, **inference_kwargs):  # type: ignore[reportMissingParameterType]
-            ...  # pragma: no cover
+        async def generate(self, model_input: Any, output_type: Any, **inference_kwargs: Any): ...  # pragma: no cover

-        async def generate_batch(self, model_input, output_type, **inference_kwargs):  # type: ignore[reportMissingParameterType]
-            ...  # pragma: no cover
+        async def generate_batch(
+            self, model_input: Any, output_type: Any, **inference_kwargs: Any
+        ): ...  # pragma: no cover

-        async def generate_stream(self, model_input, output_type, **inference_kwargs):  # type: ignore[reportMissingParameterType]
-            ...  # pragma: no cover
+        async def generate_stream(
+            self, model_input: Any, output_type: Any, **inference_kwargs: Any
+        ): ...  # pragma: no cover

     return OutlinesModel(MockOutlinesAsyncModel(), provider=OutlinesProvider())


 @pytest.fixture
 def transformers_model() -> OutlinesModel:
-    hf_model = transformers.AutoModelForCausalLM.from_pretrained(  # type: ignore
+    hf_model = transformers.AutoModelForCausalLM.from_pretrained(
         'erwanf/gpt2-mini',
         device_map='cpu',
     )
-    hf_tokenizer = transformers.AutoTokenizer.from_pretrained('erwanf/gpt2-mini')  # type: ignore
+    hf_tokenizer = transformers.AutoTokenizer.from_pretrained('erwanf/gpt2-mini')
     chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}'
     hf_tokenizer.chat_template = chat_template
     outlines_model = outlines.models.transformers.from_transformers(
-        hf_model,  # type: ignore[reportUnknownArgumentType]
-        hf_tokenizer,  # type: ignore
+        hf_model,
+        hf_tokenizer,
     )
     return OutlinesModel(outlines_model, provider=OutlinesProvider())


 @pytest.fixture
 def transformers_multimodal_model() -> OutlinesModel:
-    hf_model = transformers.LlavaForConditionalGeneration.from_pretrained(  # type: ignore
+    hf_model = transformers.LlavaForConditionalGeneration.from_pretrained(
         'trl-internal-testing/tiny-LlavaForConditionalGeneration',
         device_map='cpu',
     )
-    hf_processor = transformers.AutoProcessor.from_pretrained(  # type: ignore
-        'trl-internal-testing/tiny-LlavaForConditionalGeneration'
-    )
+    hf_processor = transformers.AutoProcessor.from_pretrained('trl-internal-testing/tiny-LlavaForConditionalGeneration')
     outlines_model = outlines.models.transformers.from_transformers(
         hf_model,
-        hf_processor,  # type: ignore
+        hf_processor,
     )
     return OutlinesModel(outlines_model, provider=OutlinesProvider())


 @pytest.fixture
 def llamacpp_model() -> OutlinesModel:
     outlines_model_llamacpp = outlines.models.llamacpp.from_llamacpp(
-        llama_cpp.Llama.from_pretrained(  # type: ignore
+        llama_cpp.Llama.from_pretrained(
             repo_id='M4-ai/TinyMistral-248M-v2-Instruct-GGUF',
             filename='TinyMistral-248M-v2-Instruct.Q4_K_M.gguf',
         )
@@ -168,9 +167,7 @@ def llamacpp_model() -> OutlinesModel:

 @pytest.fixture
 def mlxlm_model() -> OutlinesModel:  # pragma: no cover
-    outlines_model = outlines.models.mlxlm.from_mlxlm(
-        *mlx_lm.load('mlx-community/SmolLM-135M-Instruct-4bit')  # type: ignore
-    )
+    outlines_model = outlines.models.mlxlm.from_mlxlm(*mlx_lm.load('mlx-community/SmolLM-135M-Instruct-4bit'))
     return OutlinesModel(outlines_model, provider=OutlinesProvider())


@@ -184,9 +181,7 @@ def sglang_model() -> OutlinesModel:

 @pytest.fixture
 def vllm_model_offline() -> OutlinesModel:  # pragma: no cover
-    outlines_model = outlines.models.vllm_offline.from_vllm_offline(  # type: ignore
-        vllm.LLM('microsoft/Phi-3-mini-4k-instruct')  # type: ignore
-    )
+    outlines_model = outlines.models.vllm_offline.from_vllm_offline(vllm.LLM('microsoft/Phi-3-mini-4k-instruct'))
     return OutlinesModel(outlines_model, provider=OutlinesProvider())


@@ -201,18 +196,18 @@ def binary_image() -> BinaryImage:
     pytest.param(
         'from_transformers',
         lambda: (
-            transformers.AutoModelForCausalLM.from_pretrained(  # type: ignore
+            transformers.AutoModelForCausalLM.from_pretrained(
                 'erwanf/gpt2-mini',
                 device_map='cpu',
             ),
-            transformers.AutoTokenizer.from_pretrained('erwanf/gpt2-mini'),  # type: ignore
+            transformers.AutoTokenizer.from_pretrained('erwanf/gpt2-mini'),
         ),
         marks=skip_if_transformers_imports_unsuccessful,
     ),
     pytest.param(
         'from_llamacpp',
         lambda: (
-            llama_cpp.Llama.from_pretrained(  # type: ignore
+            llama_cpp.Llama.from_pretrained(
                 repo_id='M4-ai/TinyMistral-248M-v2-Instruct-GGUF',
                 filename='TinyMistral-248M-v2-Instruct.Q4_K_M.gguf',
             ),
@@ -221,7 +216,7 @@ def binary_image() -> BinaryImage:
     ),
     pytest.param(
         'from_mlxlm',
-        lambda: mlx_lm.load('mlx-community/SmolLM-135M-Instruct-4bit'),  # type: ignore
+        lambda: mlx_lm.load('mlx-community/SmolLM-135M-Instruct-4bit'),
         marks=skip_if_mlxlm_imports_unsuccessful,
     ),
     pytest.param(
@@ -231,7 +226,7 @@ def binary_image() -> BinaryImage:
     ),
     pytest.param(
         'from_vllm_offline',
-        lambda: (vllm.LLM('microsoft/Phi-3-mini-4k-instruct'),),  # type: ignore
+        lambda: (vllm.LLM('microsoft/Phi-3-mini-4k-instruct'),),
         marks=skip_if_vllm_imports_unsuccessful,
     ),
 ]
@@ -260,18 +255,18 @@ def test_init(model_loading_function_name: str, args: Callable[[], tuple[Any]])
     pytest.param(
         'from_transformers',
         lambda: (
-            transformers.AutoModelForCausalLM.from_pretrained(  # type: ignore
+            transformers.AutoModelForCausalLM.from_pretrained(
                 'erwanf/gpt2-mini',
                 device_map='cpu',
             ),
-            transformers.AutoTokenizer.from_pretrained('erwanf/gpt2-mini'),  # type: ignore
+            transformers.AutoTokenizer.from_pretrained('erwanf/gpt2-mini'),
         ),
         marks=skip_if_transformers_imports_unsuccessful,
     ),
     pytest.param(
         'from_llamacpp',
         lambda: (
-            llama_cpp.Llama.from_pretrained(  # type: ignore
+            llama_cpp.Llama.from_pretrained(
                 repo_id='M4-ai/TinyMistral-248M-v2-Instruct-GGUF',
                 filename='TinyMistral-248M-v2-Instruct.Q4_K_M.gguf',
             ),
@@ -280,7 +275,7 @@ def test_init(model_loading_function_name: str, args: Callable[[], tuple[Any]])
     ),
     pytest.param(
         'from_mlxlm',
-        lambda: mlx_lm.load('mlx-community/SmolLM-135M-Instruct-4bit'),  # type: ignore
+        lambda: mlx_lm.load('mlx-community/SmolLM-135M-Instruct-4bit'),
         marks=skip_if_mlxlm_imports_unsuccessful,
     ),
     pytest.param(
@@ -290,7 +285,7 @@ def test_init(model_loading_function_name: str, args: Callable[[], tuple[Any]])
     ),
     pytest.param(
         'from_vllm_offline',
-        lambda: (vllm.LLM('microsoft/Phi-3-mini-4k-instruct'),),  # type: ignore
+        lambda: (vllm.LLM('microsoft/Phi-3-mini-4k-instruct'),),
         marks=skip_if_vllm_imports_unsuccessful,
     ),
 ]

typings/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+Stub files (`*.pyi`) contain type hints used only by type checkers, not at
+runtime. They were introduced in
+[PEP 484](https://peps.python.org/pep-0484/#stub-files). For example, the
+[`typeshed`](https://github.com/python/typeshed) repository maintains a
+collection of such stubs for the Python standard library and some third-party
+libraries.
+
+The `./typings` folder contains type information only for the parts of
+third-party dependencies used in the `pydantic-ai` codebase. These stubs must be
+manually maintained. When a dependency's API changes, both the codebase and the
+stubs need to be updated. There are two ways to update the stubs:
+
+1. **Manual update:** Check the dependency's source code and copy the type
+   information to `./typings`. For example, take the `from_pretrained()` method
+   of the `Llama` class in `llama-cpp-python`. The
+   [source code](https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py#L2240)
+   contains the type information that is copied to `./typings/llama_cpp.pyi`.
+   This eliminates the need for `# type: ignore` comments in the codebase.
+
+2. **Update with AI coding assistants:** Most dependencies maintain `llms.txt`
+   and `llms-full.txt` files with their documentation. This information is
+   compiled by [Context7](https://context7.com). For example, the
+   `llama-cpp-python` library is documented
+   [here](https://github.com/abetlen/llama-cpp-python). MCP servers such as
+   [this one by Upstash](https://github.com/upstash/context7) provide AI coding
+   assistants access to Context7. AI coding assistants such as VS Code Copilot
+   or Cursor can reliably generate and update the stubs.
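
As a concrete illustration of the first workflow, a sketch of a call site before and after this commit's typings/llama_cpp.pyi (model identifiers taken from the test suite):

# Before: without a stub, pyright cannot type llama_cpp, so the call
# site needs a suppression comment.
llama = llama_cpp.Llama.from_pretrained(...)  # type: ignore

# After: with typings/llama_cpp.pyi on the stub path, the call is fully
# typed and the comment can be dropped.
from llama_cpp import Llama

llama = Llama.from_pretrained(
    repo_id='M4-ai/TinyMistral-248M-v2-Instruct-GGUF',
    filename='TinyMistral-248M-v2-Instruct.Q4_K_M.gguf',
)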

typings/llama_cpp.pyi

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+from collections.abc import Sequence
+from os import PathLike
+from typing import Any, Literal
+
+from typing_extensions import Self
+
+class Llama:
+    def __init__(self, *args: Any, **kwargs: Any) -> None: ...
+    @classmethod
+    def from_pretrained(
+        cls,
+        repo_id: str,
+        filename: str | None = None,
+        additional_files: Sequence[str] | None = None,
+        local_dir: str | PathLike[str] | None = None,
+        local_dir_use_symlinks: bool | Literal['auto'] = 'auto',
+        cache_dir: str | PathLike[str] | None = None,
+        **kwargs: Any,
+    ) -> Self: ...
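
Returning typing_extensions.Self rather than Llama means the inferred type follows subclasses; a small sketch (the subclass is hypothetical):

# Because from_pretrained() is annotated -> Self, the checker infers the
# subclass type at the call site, not the base class.
from llama_cpp import Llama

class MyLlama(Llama): ...

m = MyLlama.from_pretrained(repo_id='M4-ai/TinyMistral-248M-v2-Instruct-GGUF')
# reveal_type(m)  ->  MyLlama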

typings/mlx/__init__.pyi

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+from typing import Any
+
+from . import nn
+
+# mlx is imported as a package, primarily for mlx.nn
+__all__: list[str] = []

typings/mlx/nn.pyi

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from typing import Any
+
+class Module: ...

typings/mlx_lm.pyi

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from typing import Any
+
+from mlx.nn import Module
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+def load(model_path: str | None = None, *args: Any, **kwargs: Any) -> tuple[Module, PreTrainedTokenizer]: ...
+def generate_step(*args: Any, **kwargs: Any) -> Any: ...
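
With this stub, the tuple returned by mlx_lm.load() unpacks into typed values, which is how the test fixture feeds outlines.models.mlxlm.from_mlxlm. A usage sketch (runs on Apple Silicon only):

# load() is declared to return (Module, PreTrainedTokenizer), so the
# unpacking below is fully typed with no ignore comments.
import mlx_lm

model, tokenizer = mlx_lm.load('mlx-community/SmolLM-135M-Instruct-4bit')
# model: mlx.nn.Module, tokenizer: transformers PreTrainedTokenizer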
