Skip to content

Commit abf16fb

Browse files
committed
codestyle
1 parent 848147d commit abf16fb

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

autointent/_wrappers/base_torch_module.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ def load(cls, path: Path, device: str | None = None) -> Self:
3333
device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
3434
"""
3535

36-
3736
@property
3837
def device(self) -> torch.device:
3938
"""Torch device object where this module resides."""
4039
if not hasattr(self, "_device"):
41-
self._device = next(self.parameters()).device
40+
self._device = next(self.parameters()).device
4241
return self._device

autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class TextCNNDumpMetadata(TypedDict):
2323
padding_idx: int
2424
vocab: dict[str, int]
2525
max_seq_length: int
26-
vocab: dict[str, int]
26+
max_vocab_size: int | None
2727

2828

2929
class TextCNN(BaseTorchModule):
@@ -115,7 +115,7 @@ def text_to_indices(self, utterances: list[str]) -> list[list[int]]:
115115

116116
def forward(self, x: torch.Tensor) -> torch.Tensor:
117117
"""Forward pass of the model."""
118-
if self._vocab is None:
118+
if not hasattr(self, "_vocab"):
119119
msg = "Model not initialized. Call build_vocab() first."
120120
raise ValueError(msg)
121121

tests/modules/scoring/test_cnn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ def test_cnn_scorer_dump_load(dataset):
100100
scorer.dump(str(temp_dir_path))
101101

102102
# Create a new scorer and load saved model
103-
scorer_loaded = CNNScorer(max_seq_length=50, num_train_epochs=1, batch_size=8, learning_rate=5e-5)
104-
scorer_loaded = scorer_loaded.load(str(temp_dir_path))
103+
scorer_loaded = CNNScorer.load(str(temp_dir_path))
105104

106105
# Verify model is loaded
107106
assert hasattr(scorer_loaded, "_model")

0 commit comments

Comments
 (0)