Skip to content

Commit 106d8ba

Browse files
committed
mypy fix
1 parent b2c21d9 commit 106d8ba

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

autointent/configs/_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class CrossEncoderConfig(HFModelConfig):
123123

124124
class RNNConfig(BaseModel):
125125
model_config = ConfigDict(extra="forbid")
126-
device: str = Field(None, description="Torch notation for CPU or CUDA.")
126+
device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
127127
max_seq_length: int = Field(128, description="Maximum sequence length.")
128128
padding_idx: int = Field(0, description="Index used for padding.")
129129
batch_size: PositiveInt = Field(32, description="Batch size for model inference.")

autointent/modules/scoring/_rnn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def clear_cache(self) -> None:
198198
if hasattr(self, "_model"):
199199
del self._model
200200

201+
@property
202+
def device(self) -> str:
203+
"""Get device used for model computations."""
204+
return self._device
201205

202206
class SupervisedRNNClassifier(nn.Module):
203207
def __init__(

0 commit comments

Comments
 (0)