Skip to content

Commit c62dd75

Browse files
authored
Merge pull request #18 from pchlenski/implement_mixed_curv_transformers
Implemented mixed curvature transformers
2 parents a411a5e + e1def46 commit c62dd75

File tree

1 file changed

+243
-12
lines changed

1 file changed

+243
-12
lines changed

manify/predictors/nn/layers.py

Lines changed: 243 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -340,41 +340,272 @@ def forward(
340340

341341

342342
class StereographicLayerNorm(nn.Module):
    """Stereographic Layer Normalization.

    Applies a Euclidean ``nn.LayerNorm`` wrapped by ``manifold.apply`` (per this
    file's convention, the wrapped module operates in the tangent space), then
    projects the result back onto the stereographic ball.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        embedding_dim: Embedding dimension of the input points.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.

    Attributes:
        manifold: The manifold object for geometric operations.
        stereographic_norm: Stereographic layernorm used in the tangent space.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
    """

    def __init__(
        self,
        manifold: Manifold | ProductManifold,
        embedding_dim: int,
        # Fix: use jaxtyping's Float[...] like every other annotation in this file;
        # the original `torch.Tensor["num_heads 1 1"]` is not a valid annotation.
        curvatures: Float[torch.Tensor, "num_heads 1 1"],
    ):
        super().__init__()

        self.manifold = manifold
        self.stereographic_norm = self.manifold.apply(nn.LayerNorm(embedding_dim))
        self.curvatures = curvatures

    def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
        """Apply layer normalization on the stereographic manifold."""
        norm_X = self.stereographic_norm(X)
        # Fix: geoopt's stereographic `project` takes the curvature as the
        # keyword-only argument `k`; passing it positionally (as the original
        # did) raises a TypeError. This also matches the call style used in
        # GeometricLinearizedAttention in this file.
        # NOTE(review): curvatures is [num_heads, 1, 1] while X is annotated
        # [n_nodes, dim]; broadcasting will prepend head dims — confirm the
        # intended input rank against the caller.
        output = geoopt.manifolds.stereographic.math.project(norm_X, k=self.curvatures)
        return output
372+
373+
374+
class GeometricLinearizedAttention(nn.Module):
    """Geometric Linearized Attention.

    Linearized (kernelized) attention over points in a stereographic model:
    positive ELU+1 feature maps replace the softmax, so attention reduces to
    two einsum contractions plus a per-row normalization.

    Args:
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
        num_heads: Number of attention heads.
        head_dim: Dimension of each attention head.

    Attributes:
        num_heads: Number of attention heads.
        curvatures: Per-head curvature tensor of shape [num_heads, 1, 1].
        head_dim: Dimension of each attention head.
        _epsilon: Small epsilon substituted for exact-zero normalization terms
            before inversion (constant).
        _clamp_epsilon: Minimum clamp value for numerical stability in the
            gamma denominator (constant).
    """

    # Annotation improved: the caller (StereographicAttention) passes the
    # reshaped [num_heads, 1, 1] curvature tensor, matching the docstring;
    # the original `float | list[float]` annotation was inconsistent with both.
    def __init__(self, curvatures: Float[torch.Tensor, "num_heads 1 1"], num_heads: int, head_dim: int):
        super().__init__()

        self.num_heads = num_heads
        self.curvatures = curvatures

        self.head_dim = head_dim
        self._epsilon = 1e-5
        self._clamp_epsilon = 1e-10

    def forward(
        self,
        Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"],
    ) -> Float[torch.Tensor, "batch_size n_nodes dim"]:
        """Forward pass for the geometric linearized attention layer.

        Args:
            Q: Query tensor.
            K: Key tensor.
            V: Value tensor.
            mask: Mask tensor for attention.
                NOTE(review): `mask[None, :, None]` below only broadcasts
                cleanly against [..., n_nodes, head_dim] tensors if mask is
                effectively 1-D over nodes, yet the annotation (and the caller,
                which passes mask.unsqueeze(0).unsqueeze(0)) suggests a 2-D
                [n_nodes, n_nodes] mask — confirm the intended mask rank.

        Returns:
            Output tensor after applying attention.
        """
        # Pull Q and K into the tangent space at the origin via geoopt's
        # parallel_transport0back (treating them as tangent vectors at V).
        v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
        v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)

        # lambda_x is geoopt's conformal factor at V; gamma / (gamma - 1)
        # rescales V, with the denominator clamped away from zero.
        gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
        denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)

        x = ((gamma / denominator) * V) * mask[None, :, None]

        # ELU(.) + 1 > 0 everywhere, giving strictly positive kernel features.
        v1 = nn.functional.elu(v1) + 1
        v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]

        # Linearized approximation
        v2_cumsum = v2.sum(dim=-2)  # [B, H, D]
        D = torch.einsum("...nd,...d->...n", v1, v2_cumsum.type_as(v1))  # normalization terms
        # In-place fill: exact-zero normalizers become _epsilon before inversion,
        # avoiding division by zero (note this mutates D).
        D_inv = 1.0 / D.masked_fill_(D == 0, self._epsilon)
        context = torch.einsum("...nd,...ne->...de", v2, x)  # per-head key/value summary
        X = torch.einsum("...de,...nd,...n->...ne", context, v1, D_inv)

        # Project onto the ball, halve in the Moebius sense, then re-project
        # for numerical safety.
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(
            torch.tensor(0.5, dtype=X.dtype, device=X.device), X, k=self.curvatures, dim=-1
        )
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
351443

352444

353445
class StereographicAttention(nn.Module):
    """Stereographic Attention Layer.

    Multi-head attention over points on a stereographic manifold. Q and K are
    produced by plain Euclidean linear maps; V and the output feedforward use
    the manifold-aware KappaGCNLayer.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Embedding dimension of the input points.
        head_dim: Dimension of each attention head.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
        num_heads: Number of attention heads.
        head_dim: Dimensionality of each attention head.
        W_q: Linear layer projecting inputs to query vectors.
        W_k: Linear layer projecting inputs to key vectors.
        W_v: Manifold-aware linear layer projecting to value vectors.
        attn: Stereographic multi-head attention module.
        ff: Manifold-aware linear layer for the feedforward output.
    """

    def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
        super().__init__()

        self.manifold = manifold
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)

        self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
        self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
        self.W_v = KappaGCNLayer(in_features=dim, out_features=self.num_heads * self.head_dim, manifold=self.manifold)

        self.attn = GeometricLinearizedAttention(
            curvatures=self.curvatures, num_heads=self.num_heads, head_dim=self.head_dim
        )
        self.ff = KappaGCNLayer(in_features=self.num_heads * self.head_dim, out_features=dim, manifold=self.manifold)

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass for the stereographic attention layer.

        Args:
            X: Input node features.
            mask: Optional attention mask. When None, an all-ones (no-op) mask
                is used.

        Returns:
            Output tensor of shape (n_nodes, dim).
        """
        Q = self._split_heads(self.W_q(X))  # [H, N, D] (see _split_heads; no batch dim)
        K = self._split_heads(self.W_k(X))
        V = self._split_heads(self.W_v(X=X))

        # Fix: the signature allows mask=None, but the original unconditionally
        # called mask.unsqueeze(...), crashing with AttributeError (hidden by a
        # `# type: ignore`). Default to an all-ones mask — i.e. no masking —
        # matching the annotated [n_nodes, n_nodes] shape.
        if mask is None:
            mask = torch.ones(X.size(0), X.size(0), dtype=X.dtype, device=X.device)

        attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0))
        attn_out = self._combine_heads(attn_out)

        out = self.ff(X=attn_out)

        return out

    def _combine_heads(
        self, X: Float[torch.Tensor, "n_nodes num_heads head_dim"]
    ) -> Float[torch.Tensor, "n_nodes num_heads * head_dim"]:
        """Combines multi-head tensor by merging head and feature dimensions.

        Inverse of _split_heads.

        Args:
            X: Input tensor with shape (num_heads, n_nodes, head_dim).

        Returns:
            X: Reshaped tensor with shape (n_nodes, num_heads * head_dim).
        """
        X = X.transpose(0, 1)  # -> (n_nodes, num_heads, head_dim)
        X = X.reshape(X.size(0), self.num_heads * self.head_dim)
        return X

    def _split_heads(
        self, X: Float[torch.Tensor, "n_nodes num_heads * head_dim"]
    ) -> Float[torch.Tensor, "num_heads n_nodes head_dim"]:
        """Splits the last dimension of the input into (num_heads, head_dim) and transposes to prepare for attention.

        Args:
            X: Input tensor with shape (n_nodes, num_heads * head_dim).

        Returns:
            X: Reshaped tensor with shape (num_heads, n_nodes, head_dim).
        """
        X = X.reshape(X.size(0), self.num_heads, self.head_dim)
        X = X.transpose(0, 1)  # -> (num_heads, n_nodes, head_dim)
        return X
366530

367531

368532
class StereographicTransformer(nn.Module):
    """Stereographic Transformer Block.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Dimensionality of the input features.
        head_dim: Dimensionality of each attention head.
        use_layer_norm: Whether to apply layer normalization in tangent space.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Manifold curvatures reshaped to [num_heads, 1, 1] for broadcasting.
        mha: Multi-head stereographic attention module.
        norm1: First normalization layer (can be Identity or StereographicLayerNorm).
        norm2: Second normalization layer.
        mlpblock: Feedforward network in stereographic space.
        stereographic_activation: Activation wrapped to operate in tangent space.

    Raises:
        ValueError: If the supplied manifold is not stereographic.
    """

    def __init__(
        self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True
    ):
        super().__init__()

        # Check that manifold is stereographic
        if not manifold.is_stereographic:
            raise ValueError(
                "Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
            )

        self.manifold = manifold
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
        self.stereographic_activation = self.manifold.apply(nn.ReLU())
        self.mha = StereographicAttention(manifold=self.manifold, num_heads=num_heads, dim=dim, head_dim=head_dim)

        if use_layer_norm:
            self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
            self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        self.mlpblock = nn.Sequential(
            KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
            self.stereographic_activation,
            KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
        )

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass through the stereographic transformer block.

        Pre-norm residual structure: each sub-block's output is combined with
        its input via Moebius addition (the manifold analogue of the residual
        sum), then projected back onto the ball for numerical stability.
        Moebius addition is non-commutative; the sub-block output is kept as
        the left operand as in the original.
        """
        # Fix: geoopt's stereographic math functions take the curvature as the
        # keyword-only argument `k`; the original passed it positionally, which
        # raises a TypeError. Now consistent with GeometricLinearizedAttention.
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
593+
594+
595+
def _reshape_curvatures(curvatures: float | list[float], num_heads: int) -> Float[torch.Tensor, "num_heads 1 1"]:
596+
"""Helper function to reshape curvature(s) for use in multi-head stereographic attention."""
597+
if isinstance(curvatures, float):
598+
output_curvatures = torch.tensor([curvatures] * num_heads, dtype=torch.float)
599+
else:
600+
output_curvatures = torch.tensor(curvatures, dtype=torch.float)
601+
return output_curvatures[:, None, None]
602+
603+
604+
def _get_curvatures(manifold: Manifold | ProductManifold) -> float | list[float]:
    """Helper function to retrieve curvature(s) from a Manifold or ProductManifold.

    Returns the per-factor curvature list of a ProductManifold, or the single
    curvature of a plain Manifold.
    """
    # ProductManifold is tested before Manifold, preserving the original
    # dispatch order (relevant if ProductManifold subclasses Manifold).
    if isinstance(manifold, ProductManifold):
        return manifold.curvatures
    if isinstance(manifold, Manifold):
        return manifold.curvature
    raise TypeError("Expected a Manifold or ProductManifold class.")

0 commit comments

Comments
 (0)