Skip to content

Commit 6d3a379

Browse files
committed
Undoing a massive error
This reverts commit bc81082.
1 parent bc81082 commit 6d3a379

File tree

3 files changed

+16
-36
lines changed

3 files changed

+16
-36
lines changed

main.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def main():
107107
optimizer.step()
108108
optimizer.zero_grad(set_to_none=True)
109109

110-
metrics(y, logits)
110+
preds = th.argmax(logits, dim=1)
111+
metrics(y, preds)
111112

112113
break
113114
print(metrics.accumulate())
@@ -133,7 +134,8 @@ def main():
133134
optimizer.zero_grad(set_to_none=True)
134135
trainingloss.append(loss.item())
135136

136-
metrics(y, logits)
137+
preds = th.argmax(logits, dim=1)
138+
metrics(y, preds)
137139

138140
wandb.log(metrics.accumulate(str_prefix="Train "))
139141
metrics.reset()
@@ -148,7 +150,8 @@ def main():
148150
loss = criterion(logits, y)
149151
evalloss.append(loss.item())
150152

151-
metrics(y, logits)
153+
preds = th.argmax(logits, dim=1)
154+
metrics(y, preds)
152155

153156
wandb.log(metrics.accumulate(str_prefix="Evaluation "))
154157
metrics.reset()

test.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
import numpy as np
2-
from scipy.stats import entropy
1+
from torch.utils.data import Dataset
2+
from torchvision.datasets import SVHN
33

4-
a = np.array([0,0,0,1,1,5,4,3]).reshape(8,1)
5-
b = np.random.rand(8,10)
6-
7-
print(entropy(a, qk=b))
4+
SVHN('data/', download=True)

utils/metrics/EntropyPred.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,12 @@
1-
import numpy as np
21
import torch.nn as nn
3-
from scipy.stats import entropy
42

53

64
class EntropyPrediction(nn.Module):
7-
def __init__(self, averages: str = 'average'):
8-
"""
9-
Initializes the EntropyPrediction module.
10-
Args:
11-
averages (str): Specifies the method of aggregation for entropy values.
12-
Must be either 'average' or 'sum'.
13-
Raises:
14-
AssertionError: If the averages parameter is not 'average' or 'sum'.
15-
"""
5+
def __init__(self):
166
super().__init__()
17-
18-
assert averages == 'average' or averages == 'sum'
19-
self.averages = averages
20-
self.stored_entropy_values = []
21-
22-
def __call__(self, y_true, y_pred_logits):
23-
"""
24-
Computes the entropy between true labels and predicted logits, storing the results.
25-
Args:
26-
y_true: The true labels.
27-
y_pred_logits: The predicted logits.
28-
Side Effects:
29-
Appends the computed entropy values to the stored_entropy_values list.
30-
"""
31-
entropy_values = entropy(y_true, qk=y_pred_logits)
32-
return np.mean(entropy_values)
7+
8+
def __call__(self, y_true, y_false_logits):
9+
return
10+
11+
def __reset__(self):
12+
pass

0 commit comments

Comments
 (0)