Skip to content

Commit fd166f6

Browse files
committed
Fix types
1 parent 948ffa4 commit fd166f6

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

manify/clustering/fuzzy_kmeans.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def _init_centers(self, X: Float[torch.Tensor, "n_points n_features"]) -> None:
120120
# IMPORTANT: Use self.manifold.manifold for ManifoldParameter,
121121
# as self.manifold is our wrapper and self.manifold.manifold is the geoopt object.
122122
self.mu_ = ManifoldParameter(
123-
centers.clone().detach(), manifold=self.manifold.manifold
123+
centers.clone().detach(), # type: ignore
124+
manifold=self.manifold.manifold,
124125
) # Ensure centers are detached
125126
self.mu_.requires_grad_(True)
126127

manify/manifolds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ def gaussian_mixture(
829829

830830
# Generate cluster means
831831
cluster_means = self.sample(num_clusters, sigma_factorized=[torch.eye(M.dim) * cov_scale_means for M in self.P])
832-
assert cluster_means.shape == (num_clusters, self.ambient_dim), "Cluster means shape mismatch."
832+
assert cluster_means.shape == (num_clusters, self.ambient_dim), "Cluster means shape mismatch." # type: ignore
833833

834834
# Generate class assignments
835835
cluster_probs = torch.rand(num_clusters)

0 commit comments

Comments
 (0)