Skip to content

Commit bc81082

Browse files
committed
Thingy
1 parent be7d12a commit bc81082

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

main.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def main():
107107
optimizer.step()
108108
optimizer.zero_grad(set_to_none=True)
109109

110-
preds = th.argmax(logits, dim=1)
111-
metrics(y, preds)
110+
metrics(y, logits)
112111

113112
break
114113
print(metrics.accumulate())
@@ -134,8 +133,7 @@ def main():
134133
optimizer.zero_grad(set_to_none=True)
135134
trainingloss.append(loss.item())
136135

137-
preds = th.argmax(logits, dim=1)
138-
metrics(y, preds)
136+
metrics(y, logits)
139137

140138
wandb.log(metrics.accumulate(str_prefix="Train "))
141139
metrics.reset()
@@ -150,8 +148,7 @@ def main():
150148
loss = criterion(logits, y)
151149
evalloss.append(loss.item())
152150

153-
preds = th.argmax(logits, dim=1)
154-
metrics(y, preds)
151+
metrics(y, logits)
155152

156153
wandb.log(metrics.accumulate(str_prefix="Evaluation "))
157154
metrics.reset()

test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from torch.utils.data import Dataset
2-
from torchvision.datasets import SVHN
1+
import numpy as np
2+
from scipy.stats import entropy
33

4-
SVHN('data/', download=True)
4+
a = np.array([0,0,0,1,1,5,4,3]).reshape(8,1)
5+
b = np.random.rand(8,10)
6+
7+
print(entropy(a, qk=b))

utils/metrics/EntropyPred.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
1+
import numpy as np
12
import torch.nn as nn
3+
from scipy.stats import entropy
24

35

46
class EntropyPrediction(nn.Module):
5-
def __init__(self):
7+
def __init__(self, averages: str = 'average'):
8+
"""
9+
Initializes the EntropyPrediction module.
10+
Args:
11+
averages (str): Specifies the method of aggregation for entropy values.
12+
Must be either 'average' or 'sum'.
13+
Raises:
14+
AssertionError: If the averages parameter is not 'average' or 'sum'.
15+
"""
616
super().__init__()
7-
8-
def __call__(self, y_true, y_false_logits):
9-
return
10-
11-
def __reset__(self):
12-
pass
17+
18+
assert averages == 'average' or averages == 'sum'
19+
self.averages = averages
20+
self.stored_entropy_values = []
21+
22+
def __call__(self, y_true, y_pred_logits):
23+
"""
24+
Computes the entropy between true labels and predicted logits, storing the results.
25+
Args:
26+
y_true: The true labels.
27+
y_pred_logits: The predicted logits.
28+
Side Effects:
29+
Appends the computed entropy values to the stored_entropy_values list.
30+
"""
31+
entropy_values = entropy(y_true, qk=y_pred_logits)
32+
return np.mean(entropy_values)

0 commit comments

Comments
 (0)