Skip to content

Commit 2a2fc83

Browse files
committed
address torch cov in fid calculation #213
1 parent 81b4a5d commit 2a2fc83

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ def calculate_activation_statistics(self, samples):
953953
features = rearrange(features, '... 1 1 -> ...')
954954

955955
mu = torch.mean(features, dim = 0).cpu()
956-
sigma = torch.cov(features).cpu()
956+
sigma = torch.cov(rearrange(features, '... i j -> ... j i')).cpu()
957957
return mu, sigma
958958

959959
def fid_score(self, real_samples, fake_samples):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.5.6'
1+
__version__ = '1.5.7'

0 commit comments

Comments
 (0)