File tree Expand file tree Collapse file tree 2 files changed +5
-1
lines changed
Expand file tree Collapse file tree 2 files changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -123,7 +123,7 @@ class CrossEncoderConfig(HFModelConfig):
123123
124124class RNNConfig (BaseModel ):
125125 model_config = ConfigDict (extra = "forbid" )
126- device : str = Field (None , description = "Torch notation for CPU or CUDA." )
126+ device : str | None = Field (None , description = "Torch notation for CPU or CUDA." )
127127 max_seq_length : int = Field (128 , description = "Maximum sequence length." )
128128 padding_idx : int = Field (0 , description = "Index used for padding." )
129129 batch_size : PositiveInt = Field (32 , description = "Batch size for model inference." )
Original file line number Diff line number Diff line change @@ -198,6 +198,10 @@ def clear_cache(self) -> None:
198198 if hasattr (self , "_model" ):
199199 del self ._model
200200
201+ @property
202+ def device (self ) -> str :
203+ """Get device used for model computations."""
204+ return self ._device
201205
202206class SupervisedRNNClassifier (nn .Module ):
203207 def __init__ (
You can’t perform that action at this time.
0 commit comments