Skip to content

Commit a37b08c

Browse files
committed
Fixed thing
1 parent 6a04084 commit a37b08c

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

main.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def main():
107107
optimizer.step()
108108
optimizer.zero_grad(set_to_none=True)
109109

110-
preds = th.argmax(logits, dim=1)
111-
metrics(y, preds)
110+
metrics(y, logits)
112111

113112
break
114113
print(metrics.accumulate())
@@ -134,8 +133,7 @@ def main():
134133
optimizer.zero_grad(set_to_none=True)
135134
trainingloss.append(loss.item())
136135

137-
preds = th.argmax(logits, dim=1)
138-
metrics(y, preds)
136+
metrics(y, logits)
139137

140138
wandb.log(metrics.accumulate(str_prefix="Train "))
141139
metrics.reset()
@@ -150,8 +148,7 @@ def main():
150148
loss = criterion(logits, y)
151149
evalloss.append(loss.item())
152150

153-
preds = th.argmax(logits, dim=1)
154-
metrics(y, preds)
151+
metrics(y, logits)
155152

156153
wandb.log(metrics.accumulate(str_prefix="Evaluation "))
157154
metrics.reset()

utils/metrics/EntropyPred.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
1+
import numpy as np
12
import torch.nn as nn
3+
from scipy.stats import entropy
24

35

46
class EntropyPrediction(nn.Module):
5-
def __init__(self):
7+
def __init__(self, averages: str = 'average'):
8+
"""
9+
Initializes the EntropyPrediction module.
10+
Args:
11+
averages (str): Specifies the method of aggregation for entropy values.
12+
Must be either 'average' or 'sum'.
13+
Raises:
14+
AssertionError: If the averages parameter is not 'average' or 'sum'.
15+
"""
616
super().__init__()
7-
17+
18+
assert averages == 'average' or averages == 'sum'
19+
self.averages = averages
20+
self.stored_entropy_values = []
21+
822
def __call__(self, y_true, y_false_logits):
9-
return
10-
11-
def __reset__(self):
12-
pass
23+
"""
24+
Computes the entropy between true labels and predicted logits, storing the results.
25+
Args:
26+
y_true: The true labels.
27+
y_false_logits: The predicted logits.
28+
Side Effects:
29+
Appends the computed entropy values to the stored_entropy_values list.
30+
"""
31+
entropy_values = entropy(y_true, qk=y_false_logits)
32+
return entropy_values

0 commit comments

Comments
 (0)