We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 81e6a35 commit 16e2462Copy full SHA for 16e2462
metrics.py
@@ -24,7 +24,7 @@ def acc(outputs, targets):
24
return np.mean(outputs.cpu().numpy().argmax(axis=1) == targets.data.cpu().numpy())
25
26
27
-def calculate_kl(mu_p, sig_p, mu_q, sig_q):
+def calculate_kl(mu_q, sig_q, mu_p, sig_p):
28
kl = 0.5 * (2 * torch.log(sig_p / sig_q) - 1 + (sig_q / sig_p).pow(2) + ((mu_p - mu_q) / sig_p).pow(2)).sum()
29
return kl
30
0 commit comments