Skip to content

Commit 06cfb3a

Browse files
Raiyan Khan
authored and committed
Mixed-Curvature Transformers implementation
1 parent ebc6410 commit 06cfb3a

File tree

1 file changed

+231
-13
lines changed

1 file changed

+231
-13
lines changed

manify/predictors/nn/layers.py

Lines changed: 231 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
    # NOTE: Union and List are not exported by jaxtyping; beartype.typing provides them.
    from beartype.typing import Callable, List, Union
    from jaxtyping import Float

from ...manifolds import Manifold, ProductManifold
1616

@@ -340,41 +340,259 @@ def forward(
340340

341341

342342
class StereographicLayerNorm(nn.Module):
    """Stereographic Layer Normalization.

    Applies a Euclidean ``nn.LayerNorm`` lifted onto the manifold via
    ``manifold.apply`` (tangent-space application), then re-projects the
    result onto the stereographic ball.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        embedding_dim: Embedding dimension of the input points.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.

    Attributes:
        manifold: The manifold object for geometric operations.
        stereographic_norm: Stereographic layernorm used in the tangent space.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
    """

    def __init__(
        self,
        manifold: Manifold | ProductManifold,
        embedding_dim: int,
        # Fixed: `torch.Tensor["num_heads 1 1"]` is not a valid annotation;
        # use the jaxtyping Float form consistent with the rest of the file.
        curvatures: Float[torch.Tensor, "num_heads 1 1"],
    ):
        super().__init__()

        self.manifold = manifold
        # LayerNorm is wrapped by the manifold so it acts in the tangent space.
        self.stereographic_norm = self.manifold.apply(nn.LayerNorm(embedding_dim))
        self.curvatures = curvatures

    def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
        """Apply layer normalization on the stereographic manifold."""
        norm_X = self.stereographic_norm(X)
        # Fixed: geoopt's stereographic `project` takes curvature as the
        # keyword-only argument `k`; passing it positionally raises TypeError.
        output = geoopt.manifolds.stereographic.math.project(norm_X, k=self.curvatures)
        return output
371+
372+
373+
class GeometricLinearizedAttention(nn.Module):
    """Geometric Linearized Attention.

    Args:
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
        num_heads: Number of attention heads.
        head_dim: Dimension of each attention head.

    Attributes:
        num_heads: Number of attention heads.
        head_dim: Dimension of each attention head.
        epsilon: Small epsilon for masking inverse denominator (constant).
        clamp_epsilon: Minimum clamp value for numerical stability in gamma denominator (constant).
    """
    def __init__(self, curvatures: Union[float, List[float]], num_heads: int, head_dim: int):

        super().__init__()

        self.num_heads = num_heads
        self.curvatures = curvatures

        self.head_dim = head_dim
        # _epsilon replaces exact zeros in the normalization term before inversion.
        self._epsilon = 1e-5
        # _clamp_epsilon keeps |gamma - 1| away from zero in the denominator.
        self._clamp_epsilon = 1e-10

    def forward(
        self,
        Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"]
    ) -> Float[torch.Tensor, "batch_size n_nodes dim"]:
        """Linearized (kernelized) attention computed via tangent-space transport.

        Queries/keys are pulled back to the origin's tangent space, an
        ELU+1 feature map makes the attention kernel linearizable, and the
        result is Mobius-halved and projected back onto the ball.

        NOTE(review): `mask` is annotated as [1, 1, n_nodes, n_nodes] but is
        consumed as ``mask[None, :, None]`` below, which only broadcasts
        against the [B, H, N, D] tensors if the mask is effectively 1-D over
        nodes — confirm the intended mask shape with callers.
        """

        # Transport V's tangent representation back along Q and K.
        v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
        v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)

        # Conformal factor lambda_x at V; denominator clamped for stability.
        gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
        denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)

        x = ((gamma / denominator) * V) * mask[None, :, None]

        # ELU + 1 feature map keeps the kernel features strictly positive.
        v1 = (nn.functional.elu(v1) + 1)
        v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]

        # Linearized approximation
        v2_cumsum = v2.sum(dim=-2) # [B, H, D]
        D = torch.einsum('...nd,...d->...n', v1, v2_cumsum.type_as(v1)) # normalization terms
        # In-place masked_fill_ mutates D; zeros are replaced before inversion.
        D_inv = 1./D.masked_fill_(D == 0, self._epsilon)
        context = torch.einsum('...nd,...ne->...de', v2, x)
        X = torch.einsum('...de,...nd,...n->...ne', context, v1, D_inv)


        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        # Mobius halving (scalar multiply by 0.5) compensates the doubled
        # magnitude introduced by the gamma/(gamma-1) weighting above.
        X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(torch.tensor(0.5, dtype=X.dtype, device=X.device),
                                                                  X,
                                                                  k=self.curvatures,
                                                                  dim=-1)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
351434

352435

353436
class StereographicAttention(nn.Module):
    """Stereographic Attention Layer.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Embedding dimension of the input points.
        head_dim: Dimension of each attention head.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
        num_heads: Number of attention heads.
        head_dim: Dimensionality of each attention head.
        W_q: Linear layer projecting inputs to query vectors.
        W_k: Linear layer projecting inputs to key vectors.
        W_v: Manifold-aware linear layer projecting to value vectors.
        attn: Stereographic multi-head attention module.
        ff: Manifold-aware linear layer for the feedforward output.
    """

    def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
        super().__init__()

        self.manifold = manifold
        self.num_heads = num_heads
        self.head_dim = head_dim
        # One curvature per head, broadcastable as [num_heads, 1, 1].
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)

        # Q/K are Euclidean projections; V goes through the manifold-aware layer.
        self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads*self.head_dim)
        self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads*self.head_dim)
        self.W_v = KappaGCNLayer(in_features=dim,
                                 out_features=self.num_heads*self.head_dim,
                                 manifold=self.manifold)

        self.attn = GeometricLinearizedAttention(curvatures=self.curvatures,
                                                 num_heads=self.num_heads,
                                                 head_dim=self.head_dim)
        self.ff = KappaGCNLayer(in_features=self.num_heads*self.head_dim,
                                out_features=dim,
                                manifold=self.manifold)

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass for the stereographic attention layer.

        NOTE(review): `mask` defaults to None but `mask.unsqueeze(0)` below is
        called unconditionally — calling this with mask=None raises
        AttributeError. Verify whether callers always supply a mask or a
        None-guard is missing.
        """
        Q = self._split_heads(self.W_q(X)) # [B, H, N, D]
        K = self._split_heads(self.W_k(X))
        V = self._split_heads(self.W_v(X=X))

        attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0))
        attn_out = self._combine_heads(attn_out)

        out = self.ff(X=attn_out)

        return out

    def _combine_heads(
        self,
        # NOTE(review): annotation says [n_nodes, num_heads, head_dim] but the
        # transpose(0, 1) below implies the input is [num_heads, n_nodes, head_dim]
        # (the inverse of _split_heads) — confirm the intended layout.
        X: Float[torch.Tensor, "n_nodes num_heads head_dim"]
    ) -> Float[torch.Tensor, "n_nodes num_heads * head_dim"]:
        """Combines multi-head tensor by merging head and feature dimensions."""

        X = X.transpose(0, 1)
        X = X.reshape(X.size(0), self.num_heads * self.head_dim)
        return X

    def _split_heads(
        self,
        X: Float[torch.Tensor, "n_nodes num_heads * head_dim"]
    ) -> Float[torch.Tensor, "num_heads n_nodes head_dim"]:
        """
        Splits the last dimension of the input into (num_heads, head_dim) and transposes
        to prepare for multi-head attention computation.
        """
        X = X.reshape(X.size(0), self.num_heads, self.head_dim)
        X = X.transpose(0, 1)
        return X
366517

367518

368519
class StereographicTransformer(nn.Module):
    """Stereographic Transformer Block.

    Pre-norm transformer block operating on a stereographic manifold:
    attention and MLP sub-layers are combined with the residual stream via
    Mobius addition instead of Euclidean addition.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Dimensionality of the input features.
        head_dim: Dimensionality of each attention head.
        use_layer_norm: Whether to apply layer normalization in tangent space.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Manifold curvatures reshaped to [num_heads, 1, 1] for broadcasting.
        mha: Multi-head stereographic attention module.
        norm1: First normalization layer (can be Identity or StereographicLayerNorm).
        norm2: Second normalization layer.
        mlpblock: Feedforward network in stereographic space.
        stereographic_activation: Activation wrapped to operate in tangent space.

    Raises:
        ValueError: If the manifold is not in its stereographic parameterization.
    """

    def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True):
        # Fixed: zero-argument super() — Python 3 idiom used elsewhere in the file.
        super().__init__()

        # Check that manifold is stereographic
        if not manifold.is_stereographic:
            raise ValueError(
                "Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
            )

        self.manifold = manifold
        # NOTE(review): curvatures is a plain tensor, not a registered buffer,
        # so it will not follow the module across .to(device)/state_dict —
        # confirm whether register_buffer is intended here.
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
        self.stereographic_activation = self.manifold.apply(nn.ReLU())
        self.mha = StereographicAttention(manifold=self.manifold,
                                          num_heads=num_heads,
                                          dim=dim,
                                          head_dim=head_dim)

        if use_layer_norm:
            self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
            self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        # Manifold-aware two-layer MLP with a tangent-space ReLU in between.
        self.mlpblock = nn.Sequential(KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
                                      self.stereographic_activation,
                                      KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold)
                                      )

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass through the stereographic transformer block."""
        # Fixed: geoopt's mobius_add/project take curvature as the keyword-only
        # argument `k`; the original positional calls raise TypeError.
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
581+
582+
583+
def _reshape_curvatures(curvatures: Union[float, List[float]], num_heads: int) -> Float[torch.Tensor, "num_heads 1 1"]:
584+
"""Helper function to reshape curvature(s) for use in multi-head stereographic attention. """
585+
if isinstance(curvatures, float):
586+
output_curvatures = torch.tensor([curvatures] * num_heads, dtype=torch.float)
587+
else:
588+
output_curvatures = torch.tensor(curvatures, dtype=torch.float)
589+
return output_curvatures[:, None, None]
590+
591+
def _get_curvatures(manifold: Union[Manifold, ProductManifold]) -> Union[float, list]:
    """Helper function that extracts the curvature value(s) carried by a manifold.

    A ``ProductManifold`` exposes one curvature per component factor, while a
    plain ``Manifold`` carries a single scalar curvature.
    """
    # ProductManifold is checked first: it may itself satisfy the Manifold check.
    for kind, attr in ((ProductManifold, "curvatures"), (Manifold, "curvature")):
        if isinstance(manifold, kind):
            return getattr(manifold, attr)
    raise TypeError("Expected a Manifold or ProductManifold class.")

0 commit comments

Comments
 (0)