Skip to content

Commit ad6b275

Browse files
committed
remove further type ignore
1 parent cda13d9 commit ad6b275

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any
22

33
from outlines.models.base import Model
4+
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlavaForConditionalGeneration
45
from transformers.modeling_utils import PreTrainedModel
56
from transformers.processing_utils import ProcessorMixin
67
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -9,8 +10,8 @@ class Transformers(Model): ...
910
class TransformersMultiModal(Model): ...
1011

1112
def from_transformers(
12-
model: PreTrainedModel,
13-
tokenizer_or_processor: PreTrainedTokenizer | ProcessorMixin,
13+
model: PreTrainedModel | AutoModelForCausalLM | LlavaForConditionalGeneration,
14+
tokenizer_or_processor: PreTrainedTokenizer | ProcessorMixin | AutoTokenizer | AutoProcessor,
1415
*,
1516
device_dtype: Any = None,
1617
) -> Transformers | TransformersMultiModal: ...

tests/models/test_outlines.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ def transformers_model() -> OutlinesModel:
129129
hf_tokenizer = transformers.AutoTokenizer.from_pretrained('erwanf/gpt2-mini')
130130
chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}'
131131
hf_tokenizer.chat_template = chat_template
132-
outlines_model = outlines.models.transformers.from_transformers( # type: ignore[reportUnknownMemberType]
133-
hf_model, # type: ignore[reportUnknownArgumentType]
134-
hf_tokenizer, # type: ignore
132+
outlines_model = outlines.models.transformers.from_transformers(
133+
hf_model,
134+
hf_tokenizer,
135135
)
136136
return OutlinesModel(outlines_model, provider=OutlinesProvider())
137137

@@ -143,9 +143,9 @@ def transformers_multimodal_model() -> OutlinesModel:
143143
device_map='cpu',
144144
)
145145
hf_processor = transformers.AutoProcessor.from_pretrained('trl-internal-testing/tiny-LlavaForConditionalGeneration')
146-
outlines_model = outlines.models.transformers.from_transformers( # type: ignore[reportUnknownMemberType]
146+
outlines_model = outlines.models.transformers.from_transformers(
147147
hf_model,
148-
hf_processor, # type: ignore
148+
hf_processor,
149149
)
150150
return OutlinesModel(outlines_model, provider=OutlinesProvider())
151151

0 commit comments

Comments
 (0)