Skip to content

Commit 16e2462

Browse files
committed
Fix calculate_kl
1 parent 81e6a35 commit 16e2462

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def acc(outputs, targets):
2424
return np.mean(outputs.cpu().numpy().argmax(axis=1) == targets.data.cpu().numpy())
2525

2626

27-
def calculate_kl(mu_p, sig_p, mu_q, sig_q):
27+
def calculate_kl(mu_q, sig_q, mu_p, sig_p):
2828
kl = 0.5 * (2 * torch.log(sig_p / sig_q) - 1 + (sig_q / sig_p).pow(2) + ((mu_p - mu_q) / sig_p).pow(2)).sum()
2929
return kl
3030

0 commit comments

Comments
 (0)