Skip to content

Commit 12f9d57

Browse files
committed
Add manifold tests
1 parent 3591857 commit 12f9d57

File tree

4 files changed

+453
-15
lines changed

4 files changed

+453
-15
lines changed

manify/manifolds.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -399,16 +399,35 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
399399
return orig_manifold, *points # type: ignore
400400

401401
# Inverse projection for points
402-
norm_squared = [(Y**2).sum(dim=1, keepdim=True) for Y in points]
403-
sign = torch.sign(self.curvature) # type: ignore
402+
out = []
403+
for X in points:
404+
# Calculate squared norm
405+
# let σ = sign(K) and λ = sqrt(|K|)
406+
sign = torch.sign(torch.tensor(self.curvature, device=self.device))
407+
lam = abs(self.curvature) ** 0.5
404408

405-
X0 = (1 + sign * norm_squared) / (1 - sign * norm_squared)
406-
Xi = 2 * points / (1 - sign * norm_squared)
409+
# compute the ‖·‖² in the *scaled* ball
410+
norm2 = torch.sum((lam * X) ** 2, dim=1)
407411

408-
inv_points = [torch.cat([x0, xi], dim=1) for x0, xi in zip(X0, Xi)]
409-
assert all([orig_manifold.manifold.check_point(X) for X in inv_points])
412+
# inverse‐stereographic denom must be (1 + σ⋅‖y‖²), *not* (1 – σ⋅‖y‖²)
413+
denom = 1.0 + sign * norm2
414+
# clamp to avoid blow‐up at the boundary
415+
denom = torch.clamp_min(denom.abs(), 1e-6) * denom.sign()
410416

411-
return orig_manifold, *inv_points # type: ignore
417+
# then
418+
X0 = (1.0 - sign * norm2) / denom
419+
Xi = 2.0 * lam * X / denom.unsqueeze(1)
420+
421+
# Combine into full coordinates
422+
inv_points = torch.cat([X0.unsqueeze(1), Xi], dim=1)
423+
424+
# Let the manifold class validate the points
425+
if not orig_manifold.manifold.check_point(inv_points):
426+
raise ValueError("Generated points do not lie on the target manifold")
427+
428+
out.append(inv_points)
429+
430+
return orig_manifold, *out # type: ignore
412431

413432
def apply(self, f: Callable) -> Callable:
414433
"""

manify/utils/benchmarks.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99
from sklearn.base import BaseEstimator
1010
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
1111
from sklearn.linear_model import SGDClassifier, SGDRegressor
12-
from sklearn.metrics import (
13-
accuracy_score,
14-
f1_score,
15-
mean_squared_error,
16-
root_mean_squared_error,
17-
)
12+
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, root_mean_squared_error
1813
from sklearn.model_selection import train_test_split
1914
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
2015
from sklearn.svm import SVC, SVR

notebooks/60_pytest_scratch.ipynb

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 10,
6+
"id": "cfd9764c",
7+
"metadata": {},
8+
"outputs": [
9+
{
10+
"name": "stdout",
11+
"output_type": "stream",
12+
"text": [
13+
"The autoreload extension is already loaded. To reload it, use:\n",
14+
" %reload_ext autoreload\n"
15+
]
16+
}
17+
],
18+
"source": [
19+
"%load_ext autoreload\n",
20+
"%autoreload 2\n",
21+
"\n",
22+
"import manify\n",
23+
"from manify.manifolds import Manifold\n",
24+
"import torch\n",
25+
"import geoopt"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": 75,
31+
"id": "5a128e26",
32+
"metadata": {},
33+
"outputs": [],
34+
"source": [
35+
"\n",
36+
"for curv, dim in [(-1, 2), (0, 2), (1, 2), (-1, 64), (0, 64), (1, 64)]:\n",
37+
" M = Manifold(curvature=curv, dim=dim)\n",
38+
"\n",
39+
" # Does device switching work?\n",
40+
" M.to(\"cpu\")\n",
41+
"\n",
42+
" # Do attributes work correctly?\n",
43+
" if curv < 0:\n",
44+
" assert M.type == \"H\" and isinstance(M.manifold.base, geoopt.Lorentz)\n",
45+
" elif curv == 0:\n",
46+
" assert M.type == \"E\" and isinstance(M.manifold.base, geoopt.Euclidean)\n",
47+
" else:\n",
48+
" assert M.type == \"S\" and isinstance(M.manifold.base, geoopt.Sphere)\n",
49+
"\n",
50+
" # get some vectors via gaussian mixture\n",
51+
" cov = torch.eye(M.dim) / M.dim / 100\n",
52+
" means = torch.vstack([M.mu0] * 10)\n",
53+
" covs = torch.stack([cov] * 10)\n",
54+
" X1, _ = M.sample(z_mean=means, sigma=covs)\n",
55+
" X2, _ = M.sample(z_mean=means[:5], sigma=covs[:5])\n",
56+
"\n",
57+
" # Verify points are on manifold\n",
58+
" assert M.manifold.check_point(X1), \"X1 is not on the manifold\"\n",
59+
" assert M.manifold.check_point(X2), \"X2 is not on the manifold\"\n",
60+
"\n",
61+
" # Inner products\n",
62+
" ip_11 = M.inner(X1, X1)\n",
63+
" assert ip_11.shape == (10, 10), \"Inner product shape mismatch for X1\"\n",
64+
" ip_12 = M.inner(X1, X2)\n",
65+
" assert ip_12.shape == (10, 5), \"Inner product shape mismatch for X1 and X2\"\n",
66+
" if curv == 0:\n",
67+
" assert torch.allclose(ip_11, X1 @ X1.T), \"Euclidean inner products do not match for X1\"\n",
68+
" assert torch.allclose(ip_12, X1 @ X2.T), \"Euclidean inner products do not match for X1 and X2\"\n",
69+
"\n",
70+
" # Dists\n",
71+
" dists_11 = M.dist(X1, X1)\n",
72+
" assert dists_11.shape == (10, 10), \"Distance shape mismatch for X1\"\n",
73+
" dists_12 = M.dist(X1, X2)\n",
74+
" assert dists_12.shape == (10, 5), \"Distance shape mismatch for X1 and X2\"\n",
75+
" if curv == 0:\n",
76+
" assert torch.allclose(\n",
77+
" dists_12, torch.linalg.norm(X1[:, None] - X2[None, :], dim=-1)\n",
78+
" ), \"Euclidean distances do not match for X1 and X2\"\n",
79+
" assert torch.allclose(\n",
80+
" dists_11, torch.linalg.norm(X1[:, None] - X1[None, :], dim=-1)\n",
81+
" ), \"Euclidean distances do not match for X1\"\n",
82+
" assert (dists_11.triu(1) >= 0).all(), \"Distances for X1 should be non-negative\"\n",
83+
" assert (dists_12.triu(1) >= 0).all(), \"Distances for X2 should be non-negative\"\n",
84+
" assert torch.allclose(dists_11.triu(1), M.pdist(X1).triu(1)), \"dist and pdist diverge for X1\"\n",
85+
"\n",
86+
" # Square dists\n",
87+
" sqdists_11 = M.dist2(X1, X1)\n",
88+
" assert sqdists_11.shape == (10, 10), \"Squared distance shape mismatch for X1\"\n",
89+
" sqdists_12 = M.dist2(X1, X2)\n",
90+
" assert sqdists_12.shape == (10, 5), \"Squared distance shape mismatch for X1 and X2\"\n",
91+
" if curv == 0:\n",
92+
" assert torch.allclose(\n",
93+
" sqdists_12, torch.linalg.norm(X1[:, None] - X2[None, :], dim=-1) ** 2\n",
94+
" ), \"Euclidean squared distances do not match for X1 and X2\"\n",
95+
" assert torch.allclose(\n",
96+
" sqdists_11, torch.linalg.norm(X1[:, None] - X1[None, :], dim=-1) ** 2\n",
97+
" ), \"Euclidean squared distances do not match for X1\"\n",
98+
" assert (sqdists_11.triu(1) >= 0).all(), \"Squared distances for X1 should be non-negative\"\n",
99+
" assert (sqdists_12.triu(1) >= 0).all(), \"Squared distances for X1 and X2 should be non-negative\"\n",
100+
" assert torch.allclose(sqdists_11.triu(1), M.pdist2(X1).triu(1)), \"sqdists_11 and pdist2 diverge for X1\"\n",
101+
"\n",
102+
" # Log-likelihood\n",
103+
" lls = M.log_likelihood(X1)\n",
104+
" if curv == 0:\n",
105+
" # Evaluate as ll of gaussian with mean 0, variance 1:\n",
106+
" assert torch.allclose(\n",
107+
" lls,\n",
108+
" -0.5 * (torch.sum(X1**2, dim=-1) + X1.size(-1) * math.log(2 * math.pi)),\n",
109+
" ), \"Log-likelihood mismatch for Gaussian\"\n",
110+
" assert (lls <= 0).all(), \"Log-likelihood should be non-positive\"\n",
111+
"\n",
112+
" # Logmap and expmap\n",
113+
" logmap_x1 = M.logmap(X1)\n",
114+
" assert M.manifold.check_vector(logmap_x1), \"Logmap point should be in the tangent plane\"\n",
115+
" expmap_x1 = M.expmap(logmap_x1)\n",
116+
" assert M.manifold.check_point(expmap_x1), \"Expmap point should be on the manifold\"\n",
117+
" assert torch.allclose(expmap_x1, X1, atol=1e-5), \"Expmap does not return the original points\"\n",
118+
"\n",
119+
" # Stereographic conversions\n",
120+
" M_stereo, X1_stereo, X2_stereo = M.stereographic(X1, X2)\n",
121+
" assert M_stereo.is_stereographic\n",
122+
" X_inv_stereo, X1_inv_stereo, X2_inv_stereo = M_stereo.inverse_stereographic(X1_stereo, X2_stereo)\n",
123+
" assert not X_inv_stereo.is_stereographic\n",
124+
" assert torch.allclose(X1_inv_stereo, X1), \"Inverse stereographic conversion mismatch for X1\"\n",
125+
" assert torch.allclose(X2_inv_stereo, X2), \"Inverse stereographic conversion mismatch for X2\"\n",
126+
"\n",
127+
" # Apply\n",
128+
" @M.apply\n",
129+
" def apply_function(x):\n",
130+
" return torch.nn.functional.relu(x)\n",
131+
"\n",
132+
" result = apply_function(X1)\n",
133+
" assert result.shape == X1.shape, \"Result shape mismatch for apply_function\"\n",
134+
" assert M.manifold.check_point(result)"
135+
]
136+
},
137+
{
138+
"cell_type": "code",
139+
"execution_count": 72,
140+
"id": "84491262",
141+
"metadata": {},
142+
"outputs": [
143+
{
144+
"data": {
145+
"text/plain": [
146+
"tensor([[ 1.0011, 0.0458, 0.0055],\n",
147+
" [ 1.0001, -0.0142, -0.0087],\n",
148+
" [ 1.0121, 0.1557, 0.0073],\n",
149+
" [ 1.0099, -0.0979, 0.1019],\n",
150+
" [ 1.0033, 0.0339, 0.0737],\n",
151+
" [ 1.0008, 0.0300, 0.0255],\n",
152+
" [ 1.0006, 0.0211, 0.0289],\n",
153+
" [ 1.0040, -0.0701, -0.0553],\n",
154+
" [ 1.0160, 0.1332, -0.1208],\n",
155+
" [ 1.0026, 0.0174, 0.0700]], grad_fn=<CatBackward0>)"
156+
]
157+
},
158+
"execution_count": 72,
159+
"metadata": {},
160+
"output_type": "execute_result"
161+
}
162+
],
163+
"source": [
164+
"expmap_x1"
165+
]
166+
},
167+
{
168+
"cell_type": "markdown",
169+
"id": "be661a96",
170+
"metadata": {},
171+
"source": []
172+
},
173+
{
174+
"cell_type": "code",
175+
"execution_count": 73,
176+
"id": "a3edd57e",
177+
"metadata": {},
178+
"outputs": [
179+
{
180+
"data": {
181+
"text/plain": [
182+
"tensor([[ 1.0011, 0.0458, 0.0055],\n",
183+
" [ 1.0001, -0.0142, -0.0087],\n",
184+
" [ 1.0121, 0.1557, 0.0073],\n",
185+
" [ 1.0099, -0.0979, 0.1019],\n",
186+
" [ 1.0033, 0.0339, 0.0737],\n",
187+
" [ 1.0008, 0.0300, 0.0255],\n",
188+
" [ 1.0006, 0.0211, 0.0289],\n",
189+
" [ 1.0040, -0.0701, -0.0553],\n",
190+
" [ 1.0160, 0.1332, -0.1208],\n",
191+
" [ 1.0026, 0.0174, 0.0700]], grad_fn=<CatBackward0>)"
192+
]
193+
},
194+
"execution_count": 73,
195+
"metadata": {},
196+
"output_type": "execute_result"
197+
}
198+
],
199+
"source": [
200+
"X1"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": 59,
206+
"id": "2ea982c6",
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"# make a stack of (10, 2, 2) from this\n",
211+
"my_stack = torch.stack([cov] * 10, dim=0) # create a stack of 10 copies of cov"
212+
]
213+
},
214+
{
215+
"cell_type": "code",
216+
"execution_count": 66,
217+
"id": "8cb6c755",
218+
"metadata": {},
219+
"outputs": [
220+
{
221+
"data": {
222+
"text/plain": [
223+
"torch.Size([1, 10, 3])"
224+
]
225+
},
226+
"execution_count": 66,
227+
"metadata": {},
228+
"output_type": "execute_result"
229+
}
230+
],
231+
"source": [
232+
"torch.stack([M.mu0] * 10, dim=1).shape"
233+
]
234+
},
235+
{
236+
"cell_type": "code",
237+
"execution_count": 67,
238+
"id": "585ed32e",
239+
"metadata": {},
240+
"outputs": [
241+
{
242+
"data": {
243+
"text/plain": [
244+
"tensor([[1., 0., 0.]])"
245+
]
246+
},
247+
"execution_count": 67,
248+
"metadata": {},
249+
"output_type": "execute_result"
250+
}
251+
],
252+
"source": [
253+
"M.mu0"
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": 70,
259+
"id": "72cdf03f",
260+
"metadata": {},
261+
"outputs": [
262+
{
263+
"data": {
264+
"text/plain": [
265+
"tensor([[1., 0., 0.],\n",
266+
" [1., 0., 0.],\n",
267+
" [1., 0., 0.],\n",
268+
" [1., 0., 0.],\n",
269+
" [1., 0., 0.],\n",
270+
" [1., 0., 0.],\n",
271+
" [1., 0., 0.],\n",
272+
" [1., 0., 0.],\n",
273+
" [1., 0., 0.],\n",
274+
" [1., 0., 0.]])"
275+
]
276+
},
277+
"execution_count": 70,
278+
"metadata": {},
279+
"output_type": "execute_result"
280+
}
281+
],
282+
"source": [
283+
"torch.vstack([M.mu0] * 10)"
284+
]
285+
},
286+
{
287+
"cell_type": "code",
288+
"execution_count": null,
289+
"id": "46564821",
290+
"metadata": {},
291+
"outputs": [],
292+
"source": []
293+
}
294+
],
295+
"metadata": {
296+
"kernelspec": {
297+
"display_name": "manify",
298+
"language": "python",
299+
"name": "python3"
300+
},
301+
"language_info": {
302+
"codemirror_mode": {
303+
"name": "ipython",
304+
"version": 3
305+
},
306+
"file_extension": ".py",
307+
"mimetype": "text/x-python",
308+
"name": "python",
309+
"nbconvert_exporter": "python",
310+
"pygments_lexer": "ipython3",
311+
"version": "3.10.0"
312+
}
313+
},
314+
"nbformat": 4,
315+
"nbformat_minor": 5
316+
}

0 commit comments

Comments
 (0)