In `pytorch-fid/src/pytorch_fid/fid_score.py`, line 225 (commit b9c1811):

```python
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
```
I kept getting "Imaginary component" errors from this line, and they drove me crazy. I'm not sure whether a particular NumPy version causes this, but I found a more numerically stable way to compute FID.
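For reference, here is a sketch of the linear algebra the code below relies on (standard identities, not something pytorch-fid itself documents). The Fréchet distance between $\mathcal{N}(\mu_1, \Sigma_1)$ and $\mathcal{N}(\mu_2, \Sigma_2)$ is

$$
d^2 = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2) - 2\,\operatorname{Tr}\!\left((\Sigma_1 \Sigma_2)^{1/2}\right)
$$

$\Sigma_1 \Sigma_2$ is generally not symmetric, so `sqrtm` works on a non-symmetric matrix and round-off can leave small imaginary parts. Assuming $\Sigma_1$ is positive definite (which the `eps` shift is meant to ensure), $\Sigma_1 \Sigma_2 = \Sigma_1^{1/2} M \Sigma_1^{-1/2}$ is similar to the symmetric positive semidefinite matrix $M$, so they share eigenvalues and

$$
\operatorname{Tr}\!\left((\Sigma_1 \Sigma_2)^{1/2}\right) = \operatorname{Tr}\!\left(M^{1/2}\right) = \sum_i \sqrt{\lambda_i(M)}, \qquad M = \Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2}
$$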
```python
import warnings

import numpy as np
from scipy import linalg


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2

    # Ensure covariance matrices are positive definite
    # (note: this also adds eps * dim to each trace term in the result)
    sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
    sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

    # sqrtm of a symmetric PSD matrix can still pick up tiny imaginary
    # parts from round-off, so keep only the real part in double precision
    sigma1_sqrt = np.real(linalg.sqrtm(sigma1)).astype(np.float64)

    # tr(sqrt(sigma1 @ sigma2)) == tr(sqrt(sigma1^1/2 @ sigma2 @ sigma1^1/2)),
    # and the right-hand matrix is symmetric PSD
    product = sigma1_sqrt @ sigma2 @ sigma1_sqrt
    # Symmetrize to remove round-off asymmetry before calling eigvalsh
    product = (product + product.T) / 2

    # Trace of the square root = sum of square roots of the eigenvalues;
    # clip tiny negative eigenvalues caused by round-off
    try:
        s = np.linalg.eigvalsh(product)  # real eigenvalues of a symmetric matrix
        trace_sqrt = np.sum(np.sqrt(np.maximum(s, 0)))
    except np.linalg.LinAlgError:
        # Fall back to the original sqrtm approach with a warning
        warnings.warn("Eigendecomposition failed, falling back to sqrtm")
        covmean = linalg.sqrtm(product)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        trace_sqrt = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * trace_sqrt
```
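To sanity-check the eigenvalue route, here is a small self-contained comparison, assuming NumPy and SciPy are installed. `fid_naive`, `fid_eig` and `fid_diagonal` are hypothetical helper names for this illustration, not pytorch-fid functions. `fid_eig` follows the code above but leaves out the `eps` shift so the numbers are directly comparable. On well-conditioned covariances both routes should agree up to round-off, and for diagonal covariances they should match the closed form $\lVert \mu_1 - \mu_2 \rVert^2 + \sum_i (\sqrt{a_i} - \sqrt{b_i})^2$.

```python
import numpy as np
from scipy import linalg


def fid_naive(mu1, sigma1, mu2, sigma2):
    # Original formulation: sqrtm of the (generally non-symmetric) product
    covmean = np.real(linalg.sqrtm(sigma1 @ sigma2))  # drop round-off imaginary part
    diff = mu1 - mu2
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)


def fid_eig(mu1, sigma1, mu2, sigma2):
    # Eigenvalue formulation from this issue, without the eps shift
    s1_sqrt = np.real(linalg.sqrtm(sigma1))
    product = s1_sqrt @ sigma2 @ s1_sqrt
    product = (product + product.T) / 2
    trace_sqrt = np.sum(np.sqrt(np.maximum(np.linalg.eigvalsh(product), 0)))
    diff = mu1 - mu2
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * trace_sqrt


def fid_diagonal(mu1, var1, mu2, var2):
    # Closed form when both covariances are diagonal
    diff = mu1 - mu2
    return diff @ diff + np.sum((np.sqrt(var1) - np.sqrt(var2)) ** 2)


rng = np.random.default_rng(0)
dim = 16

# Full-rank sample covariances (more samples than dimensions)
x = rng.standard_normal((200, dim))
y = rng.standard_normal((200, dim)) * 1.5 + 0.3
mu1, sigma1 = x.mean(axis=0), np.cov(x, rowvar=False)
mu2, sigma2 = y.mean(axis=0), np.cov(y, rowvar=False)
print(fid_naive(mu1, sigma1, mu2, sigma2), fid_eig(mu1, sigma1, mu2, sigma2))

# Diagonal covariances checked against the closed form
var1 = rng.uniform(0.5, 2.0, dim)
var2 = rng.uniform(0.5, 2.0, dim)
m1, m2 = rng.standard_normal(dim), rng.standard_normal(dim)
print(fid_eig(m1, np.diag(var1), m2, np.diag(var2)), fid_diagonal(m1, var1, m2, var2))
```

If the two routes disagree noticeably on real Inception statistics, that may point to ill-conditioned covariances (for example, fewer samples than feature dimensions) rather than to the trace identity itself.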