
More stable Fréchet distance computation #117

@AIWanderer-X


The existing code takes the matrix square root of the covariance product directly:

    covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

I kept getting "Imaginary component" errors from this, which drove me crazy. I'm not sure whether particular NumPy versions cause it, but I found a more stable way to compute FID:

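    import warnings

    import numpy as np
    from scipy import linalg


    # Hypothetical signature: the issue only shows the function body.
    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
        diff = mu1 - mu2

        # Regularize both covariances so they are strictly positive definite
        # (the result is the distance between these slightly regularized Gaussians)
        sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
        sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

        # Square root of the symmetric matrix sigma1; keep the real part explicitly,
        # since .astype(np.float64) on a complex array only warns and drops it
        sigma1_sqrt = np.real(linalg.sqrtm(sigma1)).astype(np.float64)

        # sigma1_sqrt @ sigma2 @ sigma1_sqrt has the same eigenvalues as sigma1 @ sigma2,
        # but unlike that product it is symmetric positive semi-definite
        product = sigma1_sqrt @ sigma2 @ sigma1_sqrt

        # Remove round-off asymmetry so eigvalsh sees a symmetric matrix
        product = (product + product.T) / 2

        # Tr(sqrtm(sigma1 @ sigma2)) is the sum of square roots of these eigenvalues
        try:
            s = np.linalg.eigvalsh(product)  # already real for symmetric input
            trace_sqrt = np.sum(np.sqrt(np.maximum(s, 0)))  # clip tiny negative round-off
        except np.linalg.LinAlgError:
            # Fall back to the sqrtm approach with a warning
            warnings.warn("Eigendecomposition failed, falling back to sqrtm")
            covmean = linalg.sqrtm(product)
            if np.iscomplexobj(covmean):
                covmean = covmean.real
            trace_sqrt = np.trace(covmean)

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * trace_sqrt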
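
The approach rests on the identity $\operatorname{Tr}\sqrt{\Sigma_1\Sigma_2} = \sum_i \sqrt{\lambda_i}$, where $\lambda_i$ are the eigenvalues of the symmetric PSD matrix $\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2}$; for positive definite $\Sigma_1$ this matrix is similar to $\Sigma_1\Sigma_2$, so both share eigenvalues. A minimal sanity check, sketched under stated assumptions: it compares the eigenvalue-based trace with a direct `sqrtm` of the product on random, well-conditioned covariances. The helper names `fid_stable` and `fid_direct`, the dimensions, and the random data are hypothetical and only for illustration.

```python
import numpy as np
from scipy import linalg


def fid_stable(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Hypothetical helper: eigenvalue-based trace term, mirroring the code above."""
    diff = mu1 - mu2
    sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
    sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps
    sigma1_sqrt = np.real(linalg.sqrtm(sigma1))
    product = sigma1_sqrt @ sigma2 @ sigma1_sqrt
    product = (product + product.T) / 2  # symmetrize away round-off
    s = np.linalg.eigvalsh(product)  # real eigenvalues of a symmetric matrix
    trace_sqrt = np.sum(np.sqrt(np.maximum(s, 0)))
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * trace_sqrt


def fid_direct(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Hypothetical reference: sqrtm of the non-symmetric product, same regularization."""
    diff = mu1 - mu2
    sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
    sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    trace_sqrt = np.trace(covmean).real  # keep only the real part of the trace
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * trace_sqrt


# Hypothetical test data: two well-conditioned 16-dimensional Gaussians
rng = np.random.default_rng(0)
x = rng.standard_normal((500, 16))
y = rng.standard_normal((500, 16)) * 1.5 + 0.3
mu1, sigma1 = x.mean(axis=0), np.cov(x, rowvar=False)
mu2, sigma2 = y.mean(axis=0), np.cov(y, rowvar=False)

print(fid_stable(mu1, sigma1, mu2, sigma2))
print(fid_direct(mu1, sigma1, mu2, sigma2))
```

On well-conditioned inputs like these, the two values should agree up to round-off. The eigenvalue version never forms the non-symmetric product $\Sigma_1\Sigma_2$, which is the matrix whose `sqrtm` can pick up spurious imaginary parts.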
