Skip to content

Commit e1def46

Browse files
committed
Ruff and mypy fixes
1 parent 06cfb3a commit e1def46

File tree

1 file changed

+106
-93
lines changed

1 file changed

+106
-93
lines changed

manify/predictors/nn/layers.py

Lines changed: 106 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
1212
from beartype.typing import Callable
13-
from jaxtyping import Float, Union, List
13+
from jaxtyping import Float
1414

1515
from ...manifolds import Manifold, ProductManifold
1616

@@ -341,21 +341,23 @@ def forward(
341341

342342
class StereographicLayerNorm(nn.Module):
343343
"""Stereographic Layer Normalization.
344-
344+
345345
Args:
346346
manifold: Manifold or ProductManifold object defining the geometry.
347347
embedding_dim: Embedding dimension of the input points.
348-
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
348+
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
349349
value used per head in geometric computations.
350350
351-
Attributes:
351+
Attributes:
352352
manifold: The manifold object for geometric operations.
353353
stereographic_norm: Stereographic layernorm used in the tangent space.
354-
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
355-
value used per head in geometric computations.
354+
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
355+
value used per head in geometric computations.
356356
"""
357357

358-
def __init__(self, manifold: Manifold | ProductManifold, embedding_dim: int, curvatures: torch.Tensor["num_heads 1 1"]):
358+
def __init__(
359+
self, manifold: Manifold | ProductManifold, embedding_dim: int, curvatures: torch.Tensor["num_heads 1 1"]
360+
):
359361
super().__init__()
360362

361363
self.manifold = manifold
@@ -364,152 +366,163 @@ def __init__(self, manifold: Manifold | ProductManifold, embedding_dim: int, cur
364366

365367
def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
366368
"""Apply layer normalization on the stereographic manifold."""
367-
368369
norm_X = self.stereographic_norm(X)
369370
output = geoopt.manifolds.stereographic.math.project(norm_X, self.curvatures)
370371
return output
371372

372373

373374
class GeometricLinearizedAttention(nn.Module):
374375
"""Geometric Linearized Attention.
375-
376+
376377
Args:
377-
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
378+
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
378379
value used per head in geometric computations.
379380
num_heads: Number of attention heads.
380381
head_dim: Dimension of each attention head.
381382
382-
Attributes:
383+
Attributes:
383384
num_heads: Number of attention heads.
384385
head_dim: Dimension of each attention head.
385386
epsilon: Small epsilon for masking inverse denominator (constant).
386387
clamp_epsilon: Minimum clamp value for numerical stability in gamma denominator (constant).
387388
"""
388-
def __init__(self, curvatures: Union[float, List[float]], num_heads: int, head_dim: int):
389-
389+
390+
def __init__(self, curvatures: float | list[float], num_heads: int, head_dim: int):
390391
super().__init__()
391392

392393
self.num_heads = num_heads
393-
self.curvatures = curvatures
394-
395-
self.head_dim = head_dim
394+
self.curvatures = curvatures
395+
396+
self.head_dim = head_dim
396397
self._epsilon = 1e-5
397398
self._clamp_epsilon = 1e-10
398-
399+
399400
def forward(
400-
self,
401-
Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
401+
self,
402+
Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
402403
K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
403404
V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
404-
mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"]
405+
mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"],
405406
) -> Float[torch.Tensor, "batch_size n_nodes dim"]:
407+
"""Forward pass for the geometric linearized attention layer.
406408
409+
Args:
410+
Q: Query tensor.
411+
K: Key tensor.
412+
V: Value tensor.
413+
mask: Mask tensor for attention.
414+
415+
Returns:
416+
Output tensor after applying attention.
417+
"""
407418
v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
408419
v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)
409-
420+
410421
gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
411422
denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)
412-
423+
413424
x = ((gamma / denominator) * V) * mask[None, :, None]
414-
415-
v1 = (nn.functional.elu(v1) + 1)
425+
426+
v1 = nn.functional.elu(v1) + 1
416427
v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]
417428

418429
# Linearized approximation
419-
v2_cumsum = v2.sum(dim=-2) # [B, H, D]
420-
D = torch.einsum('...nd,...d->...n', v1, v2_cumsum.type_as(v1)) # normalization terms
421-
D_inv = 1./D.masked_fill_(D == 0, self._epsilon)
422-
context = torch.einsum('...nd,...ne->...de', v2, x)
423-
X = torch.einsum('...de,...nd,...n->...ne', context, v1, D_inv)
424-
425-
430+
v2_cumsum = v2.sum(dim=-2) # [B, H, D]
431+
D = torch.einsum("...nd,...d->...n", v1, v2_cumsum.type_as(v1)) # normalization terms
432+
D_inv = 1.0 / D.masked_fill_(D == 0, self._epsilon)
433+
context = torch.einsum("...nd,...ne->...de", v2, x)
434+
X = torch.einsum("...de,...nd,...n->...ne", context, v1, D_inv)
435+
426436
X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
427-
X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(torch.tensor(0.5, dtype=X.dtype, device=X.device),
428-
X,
429-
k=self.curvatures,
430-
dim=-1)
437+
X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(
438+
torch.tensor(0.5, dtype=X.dtype, device=X.device), X, k=self.curvatures, dim=-1
439+
)
431440
X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
432-
441+
433442
return X
434443

435444

436445
class StereographicAttention(nn.Module):
437446
"""Stereographic Attention Layer.
438-
447+
439448
Args:
440449
manifold: Manifold or ProductManifold object defining the geometry.
441450
num_heads: Number of attention heads.
442451
dim: Embedding dimension of the input points.
443-
head_dim: Dimension of each attention head.
452+
head_dim: Dimension of each attention head.
444453
445-
Attributes:
454+
Attributes:
446455
manifold: The manifold object for geometric operations.
447-
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
456+
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
448457
value used per head in geometric computations.
449458
num_heads: Number of attention heads.
450459
head_dim: Dimensionality of each attention head.
451460
W_q: Linear layer projecting inputs to query vectors.
452461
W_k: Linear layer projecting inputs to key vectors.
453462
W_v: Manifold-aware linear layer projecting to value vectors.
454463
attn: Stereographic multi-head attention module.
455-
ff: Manifold-aware linear layer for the feedforward output.
464+
ff: Manifold-aware linear layer for the feedforward output.
456465
"""
457466

458467
def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
459468
super().__init__()
460-
469+
461470
self.manifold = manifold
462471
self.num_heads = num_heads
463472
self.head_dim = head_dim
464473
self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)
465-
466-
self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads*self.head_dim)
467-
self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads*self.head_dim)
468-
self.W_v = KappaGCNLayer(in_features=dim,
469-
out_features=self.num_heads*self.head_dim,
470-
manifold=self.manifold)
471-
472-
self.attn = GeometricLinearizedAttention(curvatures=self.curvatures,
473-
num_heads=self.num_heads,
474-
head_dim=self.head_dim)
475-
self.ff = KappaGCNLayer(in_features=self.num_heads*self.head_dim,
476-
out_features=dim,
477-
manifold=self.manifold)
474+
475+
self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
476+
self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
477+
self.W_v = KappaGCNLayer(in_features=dim, out_features=self.num_heads * self.head_dim, manifold=self.manifold)
478+
479+
self.attn = GeometricLinearizedAttention(
480+
curvatures=self.curvatures, num_heads=self.num_heads, head_dim=self.head_dim
481+
)
482+
self.ff = KappaGCNLayer(in_features=self.num_heads * self.head_dim, out_features=dim, manifold=self.manifold)
478483

479484
def forward(
480485
self,
481486
X: Float[torch.Tensor, "n_nodes dim"],
482487
mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
483488
) -> Float[torch.Tensor, "n_nodes dim"]:
484489
"""Forward pass for the stereographic attention layer."""
485-
Q = self._split_heads(self.W_q(X)) # [B, H, N, D]
490+
Q = self._split_heads(self.W_q(X)) # [B, H, N, D]
486491
K = self._split_heads(self.W_k(X))
487492
V = self._split_heads(self.W_v(X=X))
488-
489-
attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0))
493+
494+
attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0)) # type: ignore
490495
attn_out = self._combine_heads(attn_out)
491-
496+
492497
out = self.ff(X=attn_out)
493-
498+
494499
return out
495-
500+
496501
def _combine_heads(
497-
self,
498-
X: Float[torch.Tensor, "n_nodes num_heads head_dim"]
502+
self, X: Float[torch.Tensor, "n_nodes num_heads head_dim"]
499503
) -> Float[torch.Tensor, "n_nodes num_heads * head_dim"]:
500-
"""Combines multi-head tensor by merging head and feature dimensions."""
504+
"""Combines multi-head tensor by merging head and feature dimensions.
505+
506+
Args:
507+
X: Input tensor with shape.
501508
509+
Returns:
510+
X: Reshaped tensor with shape (n_nodes, num_heads * head_dim).
511+
"""
502512
X = X.transpose(0, 1)
503513
X = X.reshape(X.size(0), self.num_heads * self.head_dim)
504514
return X
505-
515+
506516
def _split_heads(
507-
self,
508-
X: Float[torch.Tensor, "n_nodes num_heads * head_dim"]
517+
self, X: Float[torch.Tensor, "n_nodes num_heads * head_dim"]
509518
) -> Float[torch.Tensor, "num_heads n_nodes head_dim"]:
510-
"""
511-
Splits the last dimension of the input into (num_heads, head_dim) and transposes
512-
to prepare for multi-head attention computation.
519+
"""Splits the last dimension of the input into (num_heads, head_dim) and transposes to prepare for attention.
520+
521+
Args:
522+
X: Input tensor with shape (n_nodes, num_heads * head_dim).
523+
524+
Returns:
525+
X: Reshaped tensor with shape (num_heads, n_nodes, head_dim).
513526
"""
514527
X = X.reshape(X.size(0), self.num_heads, self.head_dim)
515528
X = X.transpose(0, 1)
@@ -518,7 +531,7 @@ def _split_heads(
518531

519532
class StereographicTransformer(nn.Module):
520533
"""Stereographic Transformer Block.
521-
534+
522535
Args:
523536
manifold: Manifold or ProductManifold object defining the geometry.
524537
num_heads: Number of attention heads.
@@ -533,36 +546,36 @@ class StereographicTransformer(nn.Module):
533546
norm1: First normalization layer (can be Identity or StereographicLayerNorm).
534547
norm2: Second normalization layer.
535548
mlpblock: Feedforward network in stereographic space.
536-
stereographic_activation: Activation wrapped to operate in tangent space.
549+
stereographic_activation: Activation wrapped to operate in tangent space.
537550
"""
538551

539-
def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True):
540-
super(StereographicTransformer, self).__init__()
552+
def __init__(
553+
self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True
554+
):
555+
super().__init__()
541556

542-
# Check that manifold is stereographic
557+
# Check that manifold is stereographic
543558
if not manifold.is_stereographic:
544559
raise ValueError(
545560
"Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
546-
)
547-
561+
)
562+
548563
self.manifold = manifold
549564
self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
550565
self.stereographic_activation = self.manifold.apply(nn.ReLU())
551-
self.mha = StereographicAttention(manifold=self.manifold,
552-
num_heads=num_heads,
553-
dim=dim,
554-
head_dim=head_dim)
566+
self.mha = StereographicAttention(manifold=self.manifold, num_heads=num_heads, dim=dim, head_dim=head_dim)
555567

556568
if use_layer_norm:
557569
self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
558570
self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
559571
else:
560572
self.norm1 = nn.Identity()
561573
self.norm2 = nn.Identity()
562-
563-
self.mlpblock = nn.Sequential(KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
564-
self.stereographic_activation,
565-
KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold)
574+
575+
self.mlpblock = nn.Sequential(
576+
KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
577+
self.stereographic_activation,
578+
KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
566579
)
567580

568581
def forward(
@@ -571,28 +584,28 @@ def forward(
571584
mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
572585
) -> Float[torch.Tensor, "n_nodes dim"]:
573586
"""Forward pass through the stereographic transformer block."""
574-
575-
X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, self.curvatures)
587+
X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, self.curvatures)
576588
X = geoopt.manifolds.stereographic.math.project(X, self.curvatures)
577589
X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, self.curvatures)
578-
X = geoopt.manifolds.stereographic.math.project(X, self.curvatures)
579-
590+
X = geoopt.manifolds.stereographic.math.project(X, self.curvatures)
591+
580592
return X
581593

582594

583-
def _reshape_curvatures(curvatures: Union[float, List[float]], num_heads: int) -> Float[torch.Tensor, "num_heads 1 1"]:
584-
"""Helper function to reshape curvature(s) for use in multi-head stereographic attention. """
595+
def _reshape_curvatures(curvatures: float | list[float], num_heads: int) -> Float[torch.Tensor, "num_heads 1 1"]:
596+
"""Helper function to reshape curvature(s) for use in multi-head stereographic attention."""
585597
if isinstance(curvatures, float):
586598
output_curvatures = torch.tensor([curvatures] * num_heads, dtype=torch.float)
587599
else:
588600
output_curvatures = torch.tensor(curvatures, dtype=torch.float)
589601
return output_curvatures[:, None, None]
590602

591-
def _get_curvatures(manifold: Union[Manifold, ProductManifold]) -> Union[float, list]:
603+
604+
def _get_curvatures(manifold: Manifold | ProductManifold) -> float | list[float]:
592605
"""Helper function to retrieve curvature(s) from a Manifold or ProductManifold."""
593606
if isinstance(manifold, ProductManifold):
594607
return manifold.curvatures
595608
elif isinstance(manifold, Manifold):
596609
return manifold.curvature
597610
else:
598-
raise TypeError("Expected a Manifold or ProductManifold class.")
611+
raise TypeError("Expected a Manifold or ProductManifold class.")

0 commit comments

Comments
 (0)