Skip to content

Commit 546bead

Browse files
authored
Merge pull request #103 from SFI-Visual-Intelligence/mag-branch
Fixed a small bug in metric wrapper
2 parents a56e224 + bc1dd7a commit 546bead

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

CollaborativeCoding/load_metric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def _get_metric(self, key):
7979
raise ValueError(f"Metric {key} not supported")
8080

8181
def __call__(self, y_true, y_pred):
82+
y_true, y_pred = y_true.detach().cpu(), y_pred.detach().cpu()
8283
for key in self.metrics:
8384
self.metrics[key](y_true, y_pred)
8485

CollaborativeCoding/metrics/EntropyPred.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,13 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
3636
assert y_logits.size(-1) == self.num_classes, (
3737
f"y_logit class length: {y_logits.size(-1)}, expected: {self.num_classes}"
3838
)
39-
4039
y_pred = nn.Softmax(dim=1)(y_logits)
41-
print(f"y_pred: {y_pred}")
4240
entropy_values = entropy(y_pred, axis=1)
4341
entropy_values = th.from_numpy(entropy_values)
4442

4543
# Fix numerical errors for perfect guesses
4644
entropy_values[entropy_values == th.inf] = 0
4745
entropy_values = th.nan_to_num(entropy_values)
48-
print(f"Entropy Values: {entropy_values}")
4946
for sample in entropy_values:
5047
self.stored_entropy_values.append(sample.item())
5148

0 commit comments

Comments
 (0)