1010
if TYPE_CHECKING:
    from beartype.typing import Callable, List, Union

    from jaxtyping import Float

from ...manifolds import Manifold, ProductManifold
@@ -340,41 +340,259 @@ def forward(
340340
341341
class StereographicLayerNorm(nn.Module):
    """Stereographic Layer Normalization.

    Applies a Euclidean ``nn.LayerNorm`` lifted onto the manifold via
    ``manifold.apply`` (i.e. acting in the tangent space), then re-projects
    the result onto the stereographic ball.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        embedding_dim: Embedding dimension of the input points.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.

    Attributes:
        manifold: The manifold object for geometric operations.
        stereographic_norm: Stereographic layernorm used in the tangent space.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
    """

    def __init__(
        self,
        manifold: Manifold | ProductManifold,
        embedding_dim: int,
        curvatures: Float[torch.Tensor, "num_heads 1 1"],
    ):
        super().__init__()

        self.manifold = manifold
        # `manifold.apply` wraps the Euclidean LayerNorm so that it operates
        # in the tangent space of the manifold.
        self.stereographic_norm = self.manifold.apply(nn.LayerNorm(embedding_dim))
        self.curvatures = curvatures

    def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
        """Apply layer normalization on the stereographic manifold."""
        norm_X = self.stereographic_norm(X)
        # `k` is keyword-only in geoopt's stereographic math module; passing
        # the curvatures positionally (as before) raises a TypeError.
        return geoopt.manifolds.stereographic.math.project(norm_X, k=self.curvatures)
371+
372+
class GeometricLinearizedAttention(nn.Module):
    """Geometric Linearized Attention.

    Linear-time attention in the stereographic model: values are parallel
    transported back to the origin's tangent space along the queries/keys,
    an ``elu(.) + 1`` feature map linearizes the softmax kernel, and the
    aggregated result is mapped back onto the manifold (with a 0.5 Mobius
    scalar multiplication implementing midpoint-style averaging).

    Args:
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
        num_heads: Number of attention heads.
        head_dim: Dimension of each attention head.

    Attributes:
        num_heads: Number of attention heads.
        curvatures: Per-head curvature tensor of shape [num_heads, 1, 1].
        head_dim: Dimension of each attention head.
        _epsilon: Small epsilon used to mask zero normalization terms.
        _clamp_epsilon: Minimum clamp value for numerical stability in the
            gamma denominator.
    """

    def __init__(self, curvatures: Float[torch.Tensor, "num_heads 1 1"], num_heads: int, head_dim: int):
        super().__init__()

        self.num_heads = num_heads
        self.curvatures = curvatures
        self.head_dim = head_dim
        self._epsilon = 1e-5
        self._clamp_epsilon = 1e-10

    def forward(
        self,
        Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"],
    ) -> Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"]:
        """Compute linearized attention on the stereographic manifold.

        Returns the per-head attention output (heads are combined by the
        caller).
        """
        # Transport the values back to the origin's tangent space along Q / K.
        v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
        v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)

        # Conformal factor lambda_x; clamp (gamma - 1) away from zero so the
        # division below stays numerically stable.
        gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
        denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)

        # NOTE(review): mask is indexed as mask[None, :, None]; this assumes a
        # layout broadcast-compatible with the tensors above — confirm against
        # the mask shape actually produced by callers.
        x = ((gamma / denominator) * V) * mask[None, :, None]

        # Positive kernel feature map for the linearized softmax.
        v1 = nn.functional.elu(v1) + 1
        v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]

        # Linearized approximation: O(N) aggregation instead of the full
        # N x N attention matrix.
        v2_cumsum = v2.sum(dim=-2)  # [B, H, D]
        D = torch.einsum("...nd,...d->...n", v1, v2_cumsum.type_as(v1))  # normalization terms
        D_inv = 1.0 / D.masked_fill_(D == 0, self._epsilon)
        context = torch.einsum("...nd,...ne->...de", v2, x)
        X = torch.einsum("...de,...nd,...n->...ne", context, v1, D_inv)

        # Map back onto the ball; the 0.5 Mobius scalar multiple implements
        # the midpoint-style averaging of the aggregated tangent vectors.
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(
            torch.tensor(0.5, dtype=X.dtype, device=X.device),
            X,
            k=self.curvatures,
            dim=-1,
        )
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
351434
352435
class StereographicAttention(nn.Module):
    """Stereographic Attention Layer.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Embedding dimension of the input points.
        head_dim: Dimension of each attention head.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
        num_heads: Number of attention heads.
        head_dim: Dimensionality of each attention head.
        W_q: Linear layer projecting inputs to query vectors.
        W_k: Linear layer projecting inputs to key vectors.
        W_v: Manifold-aware linear layer projecting to value vectors.
        attn: Stereographic multi-head attention module.
        ff: Manifold-aware linear layer for the feedforward output.
    """

    def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
        super().__init__()

        self.manifold = manifold
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)

        self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
        self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
        # Values live on the manifold, so they use the manifold-aware layer.
        self.W_v = KappaGCNLayer(
            in_features=dim,
            out_features=self.num_heads * self.head_dim,
            manifold=self.manifold,
        )

        self.attn = GeometricLinearizedAttention(
            curvatures=self.curvatures,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
        )
        self.ff = KappaGCNLayer(
            in_features=self.num_heads * self.head_dim,
            out_features=dim,
            manifold=self.manifold,
        )

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass for the stereographic attention layer.

        Args:
            X: Input node embeddings.
            mask: Optional [n_nodes, n_nodes] attention mask; ``None`` means
                every node may attend to every node.

        Returns:
            Output node embeddings with the same shape as ``X``.
        """
        Q = self._split_heads(self.W_q(X))  # [num_heads, n_nodes, head_dim]
        K = self._split_heads(self.W_k(X))
        V = self._split_heads(self.W_v(X=X))

        # The signature allows mask=None, but the previous implementation
        # crashed on `mask.unsqueeze`; fall back to an all-ones mask.
        if mask is None:
            mask = torch.ones(X.size(0), X.size(0), dtype=X.dtype, device=X.device)

        attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0))
        attn_out = self._combine_heads(attn_out)

        return self.ff(X=attn_out)

    def _combine_heads(
        self,
        X: Float[torch.Tensor, "num_heads n_nodes head_dim"],
    ) -> Float[torch.Tensor, "n_nodes num_heads*head_dim"]:
        """Combines multi-head tensor by merging head and feature dimensions."""
        X = X.transpose(0, 1)  # [n_nodes, num_heads, head_dim]
        X = X.reshape(X.size(0), self.num_heads * self.head_dim)
        return X

    def _split_heads(
        self,
        X: Float[torch.Tensor, "n_nodes num_heads*head_dim"],
    ) -> Float[torch.Tensor, "num_heads n_nodes head_dim"]:
        """
        Splits the last dimension of the input into (num_heads, head_dim) and
        transposes to prepare for multi-head attention computation.
        """
        X = X.reshape(X.size(0), self.num_heads, self.head_dim)
        X = X.transpose(0, 1)
        return X
366517
367518
class StereographicTransformer(nn.Module):
    """Stereographic Transformer Block.

    Pre-norm transformer block on the stereographic model: attention and
    MLP sub-layers are combined with their inputs via Mobius addition (the
    manifold analogue of the residual connection), projecting back onto the
    ball after every addition.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Dimensionality of the input features.
        head_dim: Dimensionality of each attention head.
        use_layer_norm: Whether to apply layer normalization in tangent space.

    Raises:
        ValueError: If ``manifold`` is not stereographic.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Manifold curvatures reshaped to [num_heads, 1, 1] for broadcasting.
        mha: Multi-head stereographic attention module.
        norm1: First normalization layer (can be Identity or StereographicLayerNorm).
        norm2: Second normalization layer.
        mlpblock: Feedforward network in stereographic space.
        stereographic_activation: Activation wrapped to operate in tangent space.
    """

    def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True):
        super().__init__()

        # The Mobius operations below are only defined on stereographic models.
        if not manifold.is_stereographic:
            raise ValueError(
                "Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
            )

        self.manifold = manifold
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
        self.stereographic_activation = self.manifold.apply(nn.ReLU())
        self.mha = StereographicAttention(
            manifold=self.manifold,
            num_heads=num_heads,
            dim=dim,
            head_dim=head_dim,
        )

        if use_layer_norm:
            self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
            self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        self.mlpblock = nn.Sequential(
            KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
            self.stereographic_activation,
            KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
        )

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass through the stereographic transformer block."""
        # `k` is keyword-only in geoopt's stereographic math functions; the
        # previous positional `self.curvatures` argument raised a TypeError.
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
581+
582+
583+ def _reshape_curvatures (curvatures : Union [float , List [float ]], num_heads : int ) -> Float [torch .Tensor , "num_heads 1 1" ]:
584+ """Helper function to reshape curvature(s) for use in multi-head stereographic attention. """
585+ if isinstance (curvatures , float ):
586+ output_curvatures = torch .tensor ([curvatures ] * num_heads , dtype = torch .float )
587+ else :
588+ output_curvatures = torch .tensor (curvatures , dtype = torch .float )
589+ return output_curvatures [:, None , None ]
590+
def _get_curvatures(manifold: Union[Manifold, ProductManifold]) -> Union[float, list]:
    """Helper function to retrieve curvature(s) from a Manifold or ProductManifold."""
    # ProductManifold must be tested first: it carries one curvature per
    # component manifold, exposed as `curvatures`.
    if isinstance(manifold, ProductManifold):
        return manifold.curvatures
    if isinstance(manifold, Manifold):
        return manifold.curvature
    raise TypeError("Expected a Manifold or ProductManifold class.")
0 commit comments