Skip to content

Commit 3d10192

Browse files
committed
codestyle
1 parent 7fa2543 commit 3d10192

File tree

3 files changed

+13
-36
lines changed

3 files changed

+13
-36
lines changed

autointent/modules/scoring/_bert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _initialize_model(self) -> Any: # noqa: ANN401
7676
label2id = {i: i for i in range(self._n_classes)}
7777
id2label = {i: i for i in range(self._n_classes)}
7878

79-
return AutoModelForSequenceClassification.from_pretrained( # type: ignore[no-untyped-call]
79+
return AutoModelForSequenceClassification.from_pretrained( # type: ignore[no-untyped-call]
8080
self.classification_model_config.model_name,
8181
trust_remote_code=self.classification_model_config.trust_remote_code,
8282
num_labels=self._n_classes,

autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,13 @@ def __init__(
5252

5353
if pretrained_embs is not None:
5454
_, embed_dim = pretrained_embs.shape
55-
self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True) # type: ignore[no-untyped-call]
55+
self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True) # type: ignore[no-untyped-call]
5656
else:
57-
self.embedding = nn.Embedding(
58-
num_embeddings=vocab_size,
59-
embedding_dim=embed_dim,
60-
padding_idx=padding_idx
61-
)
62-
63-
self.convs = nn.ModuleList([
64-
nn.Conv1d(
65-
in_channels=embed_dim,
66-
out_channels=num_filters,
67-
kernel_size=k
68-
) for k in kernel_sizes
69-
])
57+
self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=padding_idx)
58+
59+
self.convs = nn.ModuleList(
60+
[nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k) for k in kernel_sizes]
61+
)
7062
self.dropout = nn.Dropout(dropout)
7163
self.fc = nn.Linear(num_filters * len(kernel_sizes), n_classes)
7264

@@ -77,7 +69,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7769
conved: list[torch.Tensor] = [F.relu(conv(embedded)).max(dim=2)[0] for conv in self.convs]
7870
concatenated: torch.Tensor = torch.cat(conved, dim=1)
7971
dropped: torch.Tensor = self.dropout(concatenated)
80-
return self.fc(dropped) # type: ignore[no-any-return]
72+
return self.fc(dropped) # type: ignore[no-any-return]
8173

8274
def dump(self, path: Path) -> None:
8375
metadata = {
@@ -87,7 +79,7 @@ def dump(self, path: Path) -> None:
8779
"kernel_sizes": self.kernel_sizes,
8880
"num_filters": self.num_filters,
8981
"dropout": self.dropout_rate,
90-
"padding_idx": self.padding_idx
82+
"padding_idx": self.padding_idx,
9183
}
9284
with (path / self._metadata_dict_name).open("w") as file:
9385
json.dump(metadata, file, indent=4)

tests/modules/scoring/test_cnn.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_cnn_prediction(dataset):
2121
embed_dim=128,
2222
kernel_sizes=(3, 4, 5),
2323
num_filters=100,
24-
dropout=0.1
24+
dropout=0.1,
2525
)
2626
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
2727

@@ -57,12 +57,7 @@ def test_cnn_cache_clearing(dataset):
5757
"""Test that the CNN model properly handles cache clearing."""
5858
data_handler = DataHandler(dataset)
5959

60-
scorer = CNNScorer(
61-
max_seq_length=50,
62-
num_train_epochs=1,
63-
batch_size=8,
64-
learning_rate=5e-5
65-
)
60+
scorer = CNNScorer(max_seq_length=50, num_train_epochs=1, batch_size=8, learning_rate=5e-5)
6661
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
6762

6863
test_data = ["test text"]
@@ -86,12 +81,7 @@ def test_cnn_scorer_dump_load(dataset):
8681
data_handler = DataHandler(dataset)
8782

8883
# Create and train scorer
89-
scorer = CNNScorer(
90-
max_seq_length=50,
91-
num_train_epochs=1,
92-
batch_size=8,
93-
learning_rate=5e-5
94-
)
84+
scorer = CNNScorer(max_seq_length=50, num_train_epochs=1, batch_size=8, learning_rate=5e-5)
9585
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
9686

9787
# Test data
@@ -110,12 +100,7 @@ def test_cnn_scorer_dump_load(dataset):
110100
scorer.dump(str(temp_dir_path))
111101

112102
# Create a new scorer and load saved model
113-
scorer_loaded = CNNScorer(
114-
max_seq_length=50,
115-
num_train_epochs=1,
116-
batch_size=8,
117-
learning_rate=5e-5
118-
)
103+
scorer_loaded = CNNScorer(max_seq_length=50, num_train_epochs=1, batch_size=8, learning_rate=5e-5)
119104
scorer_loaded = scorer_loaded.load(str(temp_dir_path))
120105

121106
# Verify model is loaded

0 commit comments

Comments
 (0)