Skip to content

Commit 603c9be

Browse files
committed
Fixed a small bug in metric wrapper
1 parent a56e224 commit 603c9be

File tree

2 files changed

+1
-1
lines changed

2 files changed

+1
-1
lines changed

CollaborativeCoding/load_metric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def _get_metric(self, key):
7979
raise ValueError(f"Metric {key} not supported")
8080

8181
def __call__(self, y_true, y_pred):
82+
y_true, y_pred = y_true.detach().cpu(), y_pred.detach().cpu()
8283
for key in self.metrics:
8384
self.metrics[key](y_true, y_pred)
8485

CollaborativeCoding/metrics/EntropyPred.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
3636
assert y_logits.size(-1) == self.num_classes, (
3737
f"y_logit class length: {y_logits.size(-1)}, expected: {self.num_classes}"
3838
)
39-
4039
y_pred = nn.Softmax(dim=1)(y_logits)
4140
print(f"y_pred: {y_pred}")
4241
entropy_values = entropy(y_pred, axis=1)

0 commit comments

Comments
 (0)