Skip to content

Commit eb32ac4

Browse files
committed
stubs for mlx, vllm, outlines and transformers
1 parent eea1136 commit eb32ac4

File tree

16 files changed

+172
-17
lines changed

16 files changed

+172
-17
lines changed

pydantic_ai_slim/pydantic_ai/models/outlines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161
from outlines.models.vllm_offline import (
6262
VLLMOffline,
63-
from_vllm_offline, # pyright: ignore[reportUnknownVariableType]
63+
from_vllm_offline,
6464
)
6565
from outlines.types.dsl import JsonSchema
6666
from PIL import Image as PILImage
@@ -393,7 +393,7 @@ def _format_vllm_offline_inference_kwargs( # pragma: no cover
393393
self, model_settings: dict[str, Any]
394394
) -> dict[str, Any]:
395395
"""Select the model settings supported by the vLLMOffline model."""
396-
from vllm.sampling_params import SamplingParams # pyright: ignore
396+
from vllm.sampling_params import SamplingParams
397397

398398
supported_args = [
399399
'max_tokens',

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ quote-style = "single"
188188
"docs/**/*.py" = ["D"]
189189

190190
[tool.pyright]
191+
stubPath = "stubs"
191192
pythonVersion = "3.12"
192193
typeCheckingMode = "strict"
193194
reportMissingTypeStubs = false
@@ -217,6 +218,7 @@ exclude = [
217218
[tool.mypy]
218219
files = "tests/typed_agent.py,tests/typed_graph.py"
219220
strict = true
221+
mypy_path = "stubs"
220222

221223
[tool.pytest.ini_options]
222224
testpaths = ["tests", "docs/.hooks"]

stubs/mlx/__init__.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from typing import Any
2+
3+
# mlx is imported as a package, primarily for mlx.nn
4+
__all__: list[str]

stubs/mlx/nn.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from typing import Any
2+
3+
4+
class Module:
5+
...

stubs/mlx_lm.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from typing import Any
2+
3+
from mlx.nn import Module
4+
from transformers.tokenization_utils import PreTrainedTokenizer
5+
6+
def load(model_path: str) -> tuple[Module, PreTrainedTokenizer]: ...
7+
def generate_step(*args: Any, **kwargs: Any) -> Any: ...

stubs/outlines/models/base.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing import Any, AsyncIterable, Iterable
2+
3+
4+
class Model:
5+
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
6+
def stream(self, *args: Any, **kwargs: Any) -> Iterable[Any]: ...
7+
8+
9+
class AsyncModel(Model):
10+
async def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
11+
def stream(self, *args: Any, **kwargs: Any) -> AsyncIterable[Any]: ...

stubs/outlines/models/mlxlm.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Any
2+
from outlines.models.base import Model
3+
4+
from mlx.nn import Module
5+
from transformers.tokenization_utils import PreTrainedTokenizer
6+
7+
8+
class MLXLM(Model):
9+
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
10+
11+
12+
def from_mlxlm(model: Module, tokenizer: PreTrainedTokenizer) -> MLXLM: ...
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Any
2+
from outlines.models.base import Model
3+
4+
from transformers.modeling_utils import PreTrainedModel
5+
from transformers.processing_utils import ProcessorMixin
6+
from transformers.tokenization_utils import PreTrainedTokenizer
7+
8+
9+
class Transformers(Model):
10+
...
11+
12+
13+
class TransformersMultiModal(Model):
14+
...
15+
16+
17+
def from_transformers(
18+
model: PreTrainedModel,
19+
tokenizer_or_processor: PreTrainedTokenizer | ProcessorMixin,
20+
*,
21+
device_dtype: Any = None,
22+
) -> Transformers | TransformersMultiModal: ...
23+
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from typing import Any
2+
from outlines.models.base import Model
3+
4+
5+
class VLLMOffline(Model):
6+
...
7+
8+
9+
def from_vllm_offline(model: Any) -> VLLMOffline: ...

stubs/transformers/__init__.pyi

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Any
2+
3+
from . import modeling_utils, processing_utils, tokenization_utils
4+
from .modeling_utils import PreTrainedModel
5+
from .processing_utils import ProcessorMixin
6+
from .tokenization_utils import PreTrainedTokenizer
7+
8+
9+
class AutoModelForCausalLM(PreTrainedModel):
10+
@classmethod
11+
def from_pretrained(cls, *args: Any, **kwargs: Any) -> PreTrainedModel: ...
12+
13+
14+
class AutoTokenizer(PreTrainedTokenizer):
15+
@classmethod
16+
def from_pretrained(cls, *args: Any, **kwargs: Any) -> PreTrainedTokenizer: ...
17+
18+
19+
class AutoProcessor(ProcessorMixin):
20+
@classmethod
21+
def from_pretrained(cls, *args: Any, **kwargs: Any) -> ProcessorMixin: ...
22+
23+
24+
class LlavaForConditionalGeneration(PreTrainedModel):
25+
@classmethod
26+
def from_pretrained(cls, *args: Any, **kwargs: Any) -> PreTrainedModel: ...
27+
28+
29+
def from_pretrained(*args: Any, **kwargs: Any) -> Any: ...

0 commit comments

Comments
 (0)