Skip to content

Commit 7443323

Browse files
committed
Fix tests
1 parent 6052f76 commit 7443323

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

manify/manifolds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def sample(
234234
z_mean = self.mu0 if z_mean is None else z_mean
235235
z_mean = torch.Tensor(z_mean).reshape(-1, self.ambient_dim).to(self.device)
236236
n = z_mean.shape[0]
237+
237238
sigma = torch.stack([torch.eye(self.dim)] * n).to(self.device) if sigma is None else sigma
238239
sigma = torch.Tensor(sigma).reshape(-1, self.dim, self.dim).to(self.device)
239240
assert sigma.shape == (
@@ -250,7 +251,7 @@ def sample(
250251

251252
# Sample initial vector from N(0, sigma)
252253
N = torch.distributions.MultivariateNormal(
253-
loc=torch.zeros((n, self.dim), device=self.device), covariance_matrix=sigma
254+
loc=torch.zeros((n * n_samples, self.dim), device=self.device), covariance_matrix=sigma
254255
)
255256
v = N.sample()
256257

tests/test_clustering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ def test_riemannian_fuzzy_k_means():
99
X, _ = pm.gaussian_mixture(num_points=100)
1010

1111
for optimizer in ["adam", "adan"]:
12-
kmeans = RiemannianFuzzyKMeans(manifold=pm, n_clusters=5, random_state=42)
12+
kmeans = RiemannianFuzzyKMeans(pm=pm, n_clusters=5, random_state=42)
1313
kmeans.fit(X)
1414
preds = kmeans.predict(X)
1515
assert preds.shape == (100,), f"Predictions should have shape (100,) (optimizer: {optimizer})"
1616

1717
# Also test with X as a numpy array
1818
X_np = X.numpy()
19-
kmeans = RiemannianFuzzyKMeans(manifold=pm, n_clusters=5, random_state=42)
19+
kmeans = RiemannianFuzzyKMeans(pm=pm, n_clusters=5, random_state=42)
2020
kmeans.fit(X_np)
2121
preds_np = kmeans.predict(X_np)
2222
assert torch.tensor(preds_np).shape == (100,), f"Predictions should have shape (100,) (optimizer: {optimizer})"
@@ -25,7 +25,7 @@ def test_riemannian_fuzzy_k_means():
2525
)
2626

2727
# Also do a single manifold
28-
kmeans = RiemannianFuzzyKMeans(manifold=pm.P[0], n_clusters=5, optimizer=optimizer, random_state=42)
28+
kmeans = RiemannianFuzzyKMeans(pm=pm.P[0], n_clusters=5, optimizer=optimizer, random_state=42)
2929
X0 = pm.factorize(X)[0]
3030
kmeans.fit(X0)
3131
preds = kmeans.predict(X0)

tests/test_manifolds.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ def _shared_tests(M, X1, X2, is_euclidean):
2323
assert torch.allclose(ip_11, X1 @ X1.T, atol=1e-5), "Euclidean inner products do not match for X1"
2424
assert torch.allclose(ip_12, X1 @ X2.T, atol=1e-5), "Euclidean inner products do not match for X1 and X2"
2525

26+
# Sampling shapes should support a variety of inputs
27+
stacked_means = torch.stack([M.mu0] * 5)
28+
s1 = M.sample(100)
29+
assert s1.shape == (100, M.ambient_dim), "Sampled points should have the correct shape"
30+
s2 = M.sample(100, z_mean=M.mu0)
31+
assert s2.shape == (100, M.ambient_dim), "Sampled points should have the correct shape"
32+
s3 = M.sample(z_mean=stacked_means)
33+
assert s3.shape == (5, M.ambient_dim), "Sampled points should have the correct shape"
34+
s3 = M.sample(100, z_mean=stacked_means)
35+
assert s3.shape == (500, M.ambient_dim), "Sampled points should have the correct shape"
36+
2637
# Dists
2738
dists_11 = M.dist(X1, X1)
2839
assert dists_11.shape == (10, 10), "Distance shape mismatch for X1"

0 commit comments

Comments
 (0)