
Commit 736dccd

Fixing transformers typing
1 parent: ea03d45

2 files changed, 7 insertions(+), 5 deletions(-)

rigging/generator/transformers_.py

Lines changed: 6 additions & 4 deletions
@@ -80,21 +80,23 @@ def llm(self) -> AutoModelForCausalLM:
                     "load_in_4bit",
                 },
             )
-            self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs)  # type: ignore [no-untyped-call, unused-ignore] # nosec
+            self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs)  # type: ignore [no-untyped-call, assignment, unused-ignore] # nosec
+            if self._llm is None:
+                raise ValueError(f"Failed to load model '{self.model}'")
         return self._llm

     @property
     def tokenizer(self) -> AutoTokenizer:
         """The underlying `AutoTokenizer` instance."""
         if self._tokenizer is None:
-            self._tokenizer = AutoTokenizer.from_pretrained(self.model)  # nosec
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model)  # type: ignore [no-untyped-call, unused-ignore] # nosec
         return self._tokenizer

     @property
     def pipeline(self) -> TextGenerationPipeline:
         """The underlying `TextGenerationPipeline` instance."""
         if self._pipeline is None:
-            self._pipeline = transformers.pipeline(  # type: ignore [attr-defined, assignment, unused-ignore]
+            self._pipeline = transformers.pipeline(  # type: ignore [attr-defined, call-overload, assignment, unused-ignore]
                 "text-generation",
                 return_full_text=False,
                 model=self.llm,  # type: ignore [arg-type, unused-ignore]
@@ -160,7 +162,7 @@ def _generate(
         if any(k in kwargs for k in ["temperature", "top_k", "top_p"]):
             kwargs["do_sample"] = True

-        outputs = self.pipeline(inputs, **kwargs)
+        outputs = self.pipeline(inputs, **kwargs)  # type: ignore [call-overload]

         # TODO: We do strip() here as it's often needed, but I think
         # we should return and standardize this behavior.
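
For orientation, both hunks above touch the same lazy-initialization pattern: the transformers from_pretrained factories are loosely typed, so mypy needs scoped ignores on the assignment, and the new None check keeps the property's declared return type honest. A minimal standalone sketch of that pattern (class name and constructor are simplified, not the exact rigging implementation):

from __future__ import annotations

from transformers import AutoModelForCausalLM


class Generator:
    # Simplified stand-in for rigging's generator; only the lazy-load
    # property pattern from the diff above is reproduced here.
    def __init__(self, model: str) -> None:
        self.model = model
        self._llm: AutoModelForCausalLM | None = None

    @property
    def llm(self) -> AutoModelForCausalLM:
        if self._llm is None:
            # from_pretrained is loosely typed, so the scoped comment silences
            # only the listed mypy error codes on this line.
            self._llm = AutoModelForCausalLM.from_pretrained(self.model)  # type: ignore[no-untyped-call, assignment]
            # The added guard makes the declared return type hold even if
            # loading ever yields None.
            if self._llm is None:
                raise ValueError(f"Failed to load model '{self.model}'")
        return self._llm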

rigging/tokenizer/transformers_.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class TransformersTokenizer(Tokenizer):
     def tokenizer(self) -> "PreTrainedTokenizer":
         """The underlying `PreTrainedTokenizer` instance."""
         if self._tokenizer is None:
-            self._tokenizer = AutoTokenizer.from_pretrained(self.model)  # nosec
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model)  # type: ignore[no-untyped-call] # nosec
         return self._tokenizer

     @classmethod
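
As a side note on the mechanism this commit leans on: the bracketed identifiers after "type: ignore" are mypy error codes, so only those specific diagnostics are suppressed on that line, while "# nosec" is the equivalent per-line suppression for the bandit linter. A small, hypothetical illustration (not code from this repository; the return annotation mirrors the one used in the file above):

from transformers import AutoTokenizer, PreTrainedTokenizer


def load_tokenizer(name: str) -> PreTrainedTokenizer:
    # Only the "no-untyped-call" diagnostic is suppressed; any other mypy
    # error on this line would still be reported, and "# nosec" separately
    # tells bandit to skip this line.
    return AutoTokenizer.from_pretrained(name)  # type: ignore[no-untyped-call] # nosec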
