255 changes: 243 additions & 12 deletions manify/predictors/nn/layers.py
@@ -340,41 +340,272 @@ def forward(


class StereographicLayerNorm(nn.Module):
"""Stereographic Layer Normalization."""
"""Stereographic Layer Normalization.

Args:
manifold: Manifold or ProductManifold object defining the geometry.
embedding_dim: Embedding dimension of the input points.
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
value used per head in geometric computations.

Attributes:
manifold: The manifold object for geometric operations.
stereographic_norm: Stereographic layernorm used in the tangent space.
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
value used per head in geometric computations.
"""

def __init__(
        self, manifold: Manifold | ProductManifold, embedding_dim: int, curvatures: Float[torch.Tensor, "num_heads 1 1"]
):
super().__init__()

self.manifold = manifold
self.stereographic_norm = self.manifold.apply(nn.LayerNorm(embedding_dim))
self.curvatures = curvatures

def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
"""Apply layer normalization on the stereographic manifold."""
norm_X = self.stereographic_norm(X)
        output = geoopt.manifolds.stereographic.math.project(norm_X, k=self.curvatures)
return output
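# A minimal usage sketch for StereographicLayerNorm (illustrative only: the
# manifold construction is not part of this diff; `manifold.stereographic()`
# is the conversion mentioned in the error message in StereographicTransformer):
#
#     pm = manifold.stereographic()
#     curvatures = _reshape_curvatures(_get_curvatures(pm), num_heads=4)
#     norm = StereographicLayerNorm(manifold=pm, embedding_dim=16, curvatures=curvatures)
#     X_normed = norm(X)  # X: [n_nodes, 16] points on the manifold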


class GeometricLinearizedAttention(nn.Module):
"""Geometric Linearized Attention.

Args:
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
value used per head in geometric computations.
num_heads: Number of attention heads.
head_dim: Dimension of each attention head.

Attributes:
num_heads: Number of attention heads.
head_dim: Dimension of each attention head.
        curvatures: Tensor of shape [num_heads, 1, 1] with the per-head curvatures.
        _epsilon: Small constant used to replace zero entries in the attention normalizer (constant).
        _clamp_epsilon: Minimum clamp value for numerical stability in the gamma denominator (constant).
"""

    def __init__(self, curvatures: Float[torch.Tensor, "num_heads 1 1"], num_heads: int, head_dim: int):
super().__init__()

self.num_heads = num_heads
self.curvatures = curvatures

self.head_dim = head_dim
self._epsilon = 1e-5
self._clamp_epsilon = 1e-10

def forward(
self,
Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        mask: Float[torch.Tensor, "n_nodes"],
    ) -> Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"]:
"""Forward pass for the geometric linearized attention layer.

Args:
Q: Query tensor.
K: Key tensor.
V: Value tensor.
            mask: Per-node mask (1.0 keeps a node, 0.0 drops it), broadcast over heads and feature dims.

Returns:
Output tensor after applying attention.
"""
v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)

gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)

        x = ((gamma / denominator) * V) * mask[None, :, None]  # mask broadcasts as [1, n_nodes, 1]

v1 = nn.functional.elu(v1) + 1
v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]

# Linearized approximation
v2_cumsum = v2.sum(dim=-2) # [B, H, D]
D = torch.einsum("...nd,...d->...n", v1, v2_cumsum.type_as(v1)) # normalization terms
D_inv = 1.0 / D.masked_fill_(D == 0, self._epsilon)
context = torch.einsum("...nd,...ne->...de", v2, x)
X = torch.einsum("...de,...nd,...n->...ne", context, v1, D_inv)

X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(
torch.tensor(0.5, dtype=X.dtype, device=X.device), X, k=self.curvatures, dim=-1
)
X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

return X
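# Note on the linearization above: with the kernel feature map phi(u) = elu(u) + 1,
# softmax attention weights are approximated by w_ij = phi(q_i)^T phi(k_j), so the
# key/value summaries ("v2_cumsum" and "context") are accumulated once instead of
# materializing the full n_nodes x n_nodes score matrix (cf. Katharopoulos et al.,
# 2020). The gamma factors realize an attention-weighted Mobius gyromidpoint:
# the numerator aggregates gamma_j * v_j, the denominator aggregates (gamma_j - 1),
# and the final mobius_scalar_mul by 0.5 completes the midpoint on the manifold.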


class StereographicAttention(nn.Module):
"""Stereographic Attention Layer."""
"""Stereographic Attention Layer.

Args:
manifold: Manifold or ProductManifold object defining the geometry.
num_heads: Number of attention heads.
dim: Embedding dimension of the input points.
head_dim: Dimension of each attention head.

Attributes:
manifold: The manifold object for geometric operations.
curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
value used per head in geometric computations.
num_heads: Number of attention heads.
head_dim: Dimensionality of each attention head.
W_q: Linear layer projecting inputs to query vectors.
W_k: Linear layer projecting inputs to key vectors.
W_v: Manifold-aware linear layer projecting to value vectors.
attn: Stereographic multi-head attention module.
ff: Manifold-aware linear layer for the feedforward output.
"""

def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
super().__init__()

self.manifold = manifold
self.num_heads = num_heads
self.head_dim = head_dim
self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)

self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
self.W_v = KappaGCNLayer(in_features=dim, out_features=self.num_heads * self.head_dim, manifold=self.manifold)

self.attn = GeometricLinearizedAttention(
curvatures=self.curvatures, num_heads=self.num_heads, head_dim=self.head_dim
)
self.ff = KappaGCNLayer(in_features=self.num_heads * self.head_dim, out_features=dim, manifold=self.manifold)

def forward(
self,
X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes dim"]:
"""Forward pass for the stereographic attention layer."""
        Q = self._split_heads(self.W_q(X))  # [num_heads, n_nodes, head_dim]
K = self._split_heads(self.W_k(X))
V = self._split_heads(self.W_v(X=X))

        if mask is None:
            # With no mask provided, keep every node.
            mask = torch.ones(X.size(0), dtype=X.dtype, device=X.device)
        attn_out = self.attn(Q, K, V, mask)
attn_out = self._combine_heads(attn_out)

out = self.ff(X=attn_out)

return out

def _combine_heads(
        self, X: Float[torch.Tensor, "num_heads n_nodes head_dim"]
) -> Float[torch.Tensor, "n_nodes num_heads * head_dim"]:
"""Combines multi-head tensor by merging head and feature dimensions.

Args:
            X: Input tensor with shape (num_heads, n_nodes, head_dim).

Returns:
X: Reshaped tensor with shape (n_nodes, num_heads * head_dim).
"""
X = X.transpose(0, 1)
X = X.reshape(X.size(0), self.num_heads * self.head_dim)
return X

def _split_heads(
self, X: Float[torch.Tensor, "n_nodes num_heads * head_dim"]
) -> Float[torch.Tensor, "num_heads n_nodes head_dim"]:
"""Splits the last dimension of the input into (num_heads, head_dim) and transposes to prepare for attention.

Args:
X: Input tensor with shape (n_nodes, num_heads * head_dim).

Returns:
X: Reshaped tensor with shape (num_heads, n_nodes, head_dim).
"""
X = X.reshape(X.size(0), self.num_heads, self.head_dim)
X = X.transpose(0, 1)
return X
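    # Shape walk-through for the head helpers with num_heads=4, head_dim=8:
    #   _split_heads:   [n_nodes, 32] -> reshape -> [n_nodes, 4, 8] -> transpose -> [4, n_nodes, 8]
    #   _combine_heads: [4, n_nodes, 8] -> transpose -> [n_nodes, 4, 8] -> reshape -> [n_nodes, 32]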


class StereographicTransformer(nn.Module):
"""Stereographic Transformer Block."""
"""Stereographic Transformer Block.

Args:
manifold: Manifold or ProductManifold object defining the geometry.
num_heads: Number of attention heads.
dim: Dimensionality of the input features.
head_dim: Dimensionality of each attention head.
use_layer_norm: Whether to apply layer normalization in tangent space.

Attributes:
manifold: The manifold object for geometric operations.
curvatures: Manifold curvatures reshaped to [num_heads, 1, 1] for broadcasting.
mha: Multi-head stereographic attention module.
norm1: First normalization layer (can be Identity or StereographicLayerNorm).
norm2: Second normalization layer.
mlpblock: Feedforward network in stereographic space.
stereographic_activation: Activation wrapped to operate in tangent space.
"""

def __init__(
self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True
):
super().__init__()

# Check that manifold is stereographic
if not manifold.is_stereographic:
raise ValueError(
"Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
)

self.manifold = manifold
self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
self.stereographic_activation = self.manifold.apply(nn.ReLU())
self.mha = StereographicAttention(manifold=self.manifold, num_heads=num_heads, dim=dim, head_dim=head_dim)

if use_layer_norm:
self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
else:
self.norm1 = nn.Identity()
self.norm2 = nn.Identity()

self.mlpblock = nn.Sequential(
KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
self.stereographic_activation,
KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
)

def forward(
self,
X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes dim"]:
"""Forward pass through the stereographic transformer block."""
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

return X
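# A minimal usage sketch for StereographicTransformer (illustrative: only the
# names defined in this file are assumed, plus a stereographic manifold `pm`):
#
#     block = StereographicTransformer(manifold=pm, num_heads=2, dim=8, head_dim=4)
#     mask = torch.ones(X.size(0))  # per-node mask that keeps every node
#     out = block(X, mask)          # X, out: [n_nodes, 8]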


def _reshape_curvatures(curvatures: float | list[float], num_heads: int) -> Float[torch.Tensor, "num_heads 1 1"]:
"""Helper function to reshape curvature(s) for use in multi-head stereographic attention."""
if isinstance(curvatures, float):
output_curvatures = torch.tensor([curvatures] * num_heads, dtype=torch.float)
else:
output_curvatures = torch.tensor(curvatures, dtype=torch.float)
return output_curvatures[:, None, None]
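# Example: _reshape_curvatures([-1.0, 0.0, 1.0], num_heads=3) returns a [3, 1, 1]
# tensor, so each head's curvature broadcasts against that head's
# [n_nodes, head_dim] slice; a single float is replicated across all heads first.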


def _get_curvatures(manifold: Manifold | ProductManifold) -> float | list[float]:
"""Helper function to retrieve curvature(s) from a Manifold or ProductManifold."""
if isinstance(manifold, ProductManifold):
return manifold.curvatures
elif isinstance(manifold, Manifold):
return manifold.curvature
else:
        raise TypeError("Expected a Manifold or ProductManifold instance.")
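# Example: for a ProductManifold with a hyperbolic and a spherical component,
# _get_curvatures returns the component curvature list (e.g. [-1.0, 1.0]);
# for a single Manifold it returns that manifold's curvature as a float.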