Skip to content

Commit 39e8b4f

Browse files
committed
Linting changes; readthedocs setup
1 parent 4d51ad6 commit 39e8b4f

File tree

3 files changed

+100
-107
lines changed

3 files changed

+100
-107
lines changed
File renamed without changes.

manify/manifolds.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ def _to_tangent_plane_mu0(
185185
x = torch.Tensor(x).reshape(-1, self.dim)
186186
if self.type == "E":
187187
return x
188-
else:
189-
return torch.cat([torch.zeros((x.shape[0], 1), device=self.device), x], dim=1)
188+
return torch.cat([torch.zeros((x.shape[0], 1), device=self.device), x], dim=1)
190189

191190
def sample(
192191
self,
@@ -277,31 +276,30 @@ def log_likelihood(
277276
if self.type == "E":
278277
return torch.distributions.MultivariateNormal(mu, sigma).log_prob(z)
279278

280-
else:
281-
u = self.manifold.logmap(x=mu, y=z) # Map z to tangent space at mu
282-
v = self.manifold.transp(x=mu, y=self.mu0, v=u) # Parallel transport to origin
283-
# assert torch.allclose(v[:, 0], torch.Tensor([0.])) # For tangent vectors at origin this should be true
284-
# OK, so this assertion doesn't actually pass, but it's spiritually true
285-
if torch.isnan(v).any():
286-
print("NANs in parallel transport")
287-
v = torch.nan_to_num(v, nan=0.0)
288-
N = torch.distributions.MultivariateNormal(torch.zeros(self.dim, device=self.device), sigma)
289-
ll = N.log_prob(v[:, 1:])
290-
291-
# For convenience
292-
R = self.scale
293-
n = self.dim
294-
295-
# Final formula (epsilon to avoid log(0))
296-
if self.type == "S":
297-
sin_M = torch.sin
298-
u_norm = self.manifold.norm(x=mu, u=u)
279+
u = self.manifold.logmap(x=mu, y=z) # Map z to tangent space at mu
280+
v = self.manifold.transp(x=mu, y=self.mu0, v=u) # Parallel transport to origin
281+
# assert torch.allclose(v[:, 0], torch.Tensor([0.])) # For tangent vectors at origin this should be true
282+
# OK, so this assertion doesn't actually pass, but it's spiritually true
283+
if torch.isnan(v).any():
284+
print("NANs in parallel transport")
285+
v = torch.nan_to_num(v, nan=0.0)
286+
N = torch.distributions.MultivariateNormal(torch.zeros(self.dim, device=self.device), sigma)
287+
ll = N.log_prob(v[:, 1:])
288+
289+
# For convenience
290+
R = self.scale
291+
n = self.dim
292+
293+
# Final formula (epsilon to avoid log(0))
294+
if self.type == "S":
295+
sin_M = torch.sin
296+
u_norm = self.manifold.norm(x=mu, u=u)
299297

300-
else:
301-
sin_M = torch.sinh
302-
u_norm = self.manifold.base.norm(u=u) # Horrible workaround needed for geoopt bug # type: ignore
298+
else:
299+
sin_M = torch.sinh
300+
u_norm = self.manifold.base.norm(u=u) # Horrible workaround needed for geoopt bug # type: ignore
303301

304-
return ll - (n - 1) * torch.log(R * torch.abs(sin_M(u_norm / R) / u_norm) + 1e-8)
302+
return ll - (n - 1) * torch.log(R * torch.abs(sin_M(u_norm / R) / u_norm) + 1e-8)
305303

306304
def logmap(
307305
self, x: Float[torch.Tensor, "n_points n_dim"], base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None
@@ -366,7 +364,7 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
366364
for X in denom:
367365
X[X.abs() < 1e-6] = 1e-6 # Avoid division by zero
368366
stereo_points = [n / d for n, d in zip(num, denom)]
369-
assert all([stereo_manifold.manifold.check_point(X) for X in stereo_points])
367+
assert all(stereo_manifold.manifold.check_point(X) for X in stereo_points)
370368

371369
return stereo_manifold, *stereo_points
372370

@@ -579,7 +577,7 @@ def sample(
579577
for M, sigma in zip(self.P, sigma_factorized)
580578
]
581579

582-
assert sum([sigma.shape == (n, M.dim, M.dim) for M, sigma in zip(self.P, sigma_factorized)]) == len(self.P)
580+
assert all(sigma.shape == (n, M.dim, M.dim) for M, sigma in zip(self.P, sigma_factorized))
583581
assert z_mean.shape[-1] == self.ambient_dim
584582

585583
# Sample initial vector from N(0, sigma)
@@ -637,7 +635,7 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
637635
stereo_points = [
638636
torch.hstack([M.stereographic(x)[1] for x, M in zip(self.factorize(X), self.P)]) for X in points
639637
]
640-
assert all([stereo_manifold.manifold.check_point(X) for X in stereo_points])
638+
assert all(stereo_manifold.manifold.check_point(X) for X in stereo_points)
641639

642640
return stereo_manifold, *stereo_points
643641

@@ -652,7 +650,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
652650
orig_points = [
653651
torch.hstack([M.inverse_stereographic(x)[1] for x, M in zip(self.factorize(X), self.P)]) for X in points
654652
]
655-
assert all([orig_manifold.manifold.check_point(X) for X in orig_points])
653+
assert all(orig_manifold.manifold.check_point(X) for X in orig_points)
656654

657655
return orig_manifold, *orig_points
658656

0 commit comments

Comments
 (0)