Skip to content

Commit 8d8e103

Browse files
committed
multiclass fix
1 parent b0fe5c0 commit 8d8e103

File tree

2 files changed

+49
-22
lines changed

2 files changed

+49
-22
lines changed

autointent/modules/scoring/_gcn/gcn_scorer.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
9292
self._label_embedder = Embedder(self.label_embedder_config)
9393

9494
x_tensor = torch.tensor(self._embedder.embed(utterances, TaskTypeEnum.classification))
95-
y_tensor = torch.tensor(labels, dtype=torch.float)
95+
y_tensor_dtype = torch.float if self._multilabel else torch.long
96+
y_tensor = torch.tensor(labels, dtype=y_tensor_dtype)
9697

9798
intent_texts = [f"intent {i}" for i in range(self._n_classes)]
9899
self._label_embeddings = torch.tensor(
@@ -107,14 +108,16 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
107108
p_reweight=self.p_reweight,
108109
tau_threshold=self.tau_threshold,
109110
)
110-
self._model.set_correlation_matrix(y_tensor)
111-
self._train_model(x_tensor, y_tensor)
112111

113-
def _train_model(self, train_x: torch.Tensor, train_y: torch.Tensor) -> None:
112+
y_corr_tensor = y_tensor if self._multilabel else torch.nn.functional.one_hot(y_tensor, self._n_classes)
113+
self._model.set_correlation_matrix(y_corr_tensor.float())
114+
115+
criterion = nn.BCEWithLogitsLoss() if self._multilabel else nn.CrossEntropyLoss()
116+
self._train_model(x_tensor, y_tensor, criterion)
117+
118+
def _train_model(self, train_x: torch.Tensor, train_y: torch.Tensor, criterion: nn.Module) -> None:
114119
train_dataset = TensorDataset(train_x, train_y)
115120
train_dataloader = DataLoader(train_dataset, batch_size=self.torch_config.batch_size, shuffle=True)
116-
117-
criterion = nn.BCEWithLogitsLoss()
118121
optimizer = torch.optim.Adam(self._model.parameters(), lr=self.torch_config.learning_rate)
119122

120123
self._model.to(self.torch_config.device)
@@ -142,7 +145,10 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
142145
for i in range(0, len(x_tensor), self.torch_config.batch_size):
143146
batch_x = x_tensor[i : i + self.torch_config.batch_size].to(self.torch_config.device)
144147
outputs = self._model(batch_x, self._label_embeddings)
145-
probs = torch.sigmoid(outputs).cpu().numpy()
148+
if self._multilabel:
149+
probs = torch.sigmoid(outputs).cpu().numpy()
150+
else:
151+
probs = torch.softmax(outputs, dim=1).cpu().numpy()
146152
all_probs.append(probs)
147153

148154
return np.concatenate(all_probs, axis=0)

tests/modules/scoring/test_gcn_scorer.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,36 +23,57 @@ def multilabel_dataset():
2323
return Dataset.from_dict(data)
2424

2525

26-
def test_gcn_scorer_fit_predict(multilabel_dataset):
27-
scorer = GCNScorer(
28-
embedder_config="prajjwal1/bert-tiny",
29-
num_train_epochs=1,
30-
batch_size=2,
31-
)
26+
@pytest.fixture
27+
def multiclass_dataset():
28+
data = {
29+
"train": [
30+
{"utterance": "utterance 1", "label": 0},
31+
{"utterance": "utterance 2", "label": 1},
32+
{"utterance": "utterance 3", "label": 2},
33+
{"utterance": "utterance 4", "label": 0},
34+
],
35+
"intents": [
36+
{"id": 0, "name": "intent_0"},
37+
{"id": 1, "name": "intent_1"},
38+
{"id": 2, "name": "intent_2"},
39+
],
40+
}
41+
return Dataset.from_dict(data)
42+
43+
44+
def test_gcn_scorer_multilabel(multilabel_dataset):
45+
scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2)
3246
train_utterances = multilabel_dataset["train"]["utterance"]
3347
train_labels = multilabel_dataset["train"]["label"]
34-
3548
scorer.fit(train_utterances, train_labels)
49+
test_utterances = ["test 1", "test 2"]
50+
predictions = scorer.predict(test_utterances)
3651

37-
test_utterances = ["test utterance 1", "test utterance 2"]
52+
assert isinstance(predictions, np.ndarray)
53+
assert predictions.shape == (2, 3)
54+
assert np.all((predictions >= 0) & (predictions <= 1))
55+
56+
57+
def test_gcn_scorer_multiclass(multiclass_dataset):
58+
scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2)
59+
train_utterances = multiclass_dataset["train"]["utterance"]
60+
train_labels = multiclass_dataset["train"]["label"]
61+
scorer.fit(train_utterances, train_labels)
62+
test_utterances = ["test 1", "test 2"]
3863
predictions = scorer.predict(test_utterances)
3964

4065
assert isinstance(predictions, np.ndarray)
4166
assert predictions.shape == (2, 3)
4267
assert np.all((predictions >= 0) & (predictions <= 1))
68+
np.testing.assert_allclose(predictions.sum(axis=1), 1.0, atol=1e-6)
4369

4470

4571
def test_gcn_scorer_dump_load(tmp_path, multilabel_dataset):
46-
scorer = GCNScorer(
47-
embedder_config="prajjwal1/bert-tiny",
48-
num_train_epochs=1,
49-
batch_size=2,
50-
)
72+
scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2)
5173
train_utterances = multilabel_dataset["train"]["utterance"]
5274
train_labels = multilabel_dataset["train"]["label"]
5375
scorer.fit(train_utterances, train_labels)
54-
55-
test_utterances = ["test utterance 1", "test utterance 2"]
76+
test_utterances = ["test utterance 1"]
5677
original_predictions = scorer.predict(test_utterances)
5778

5879
scorer.dump(str(tmp_path))

0 commit comments

Comments
 (0)