@@ -340,41 +340,272 @@ def forward(
340340
341341
class StereographicLayerNorm(nn.Module):
    """Stereographic Layer Normalization.

    Applies a Euclidean ``nn.LayerNorm`` lifted onto the manifold via
    ``manifold.apply`` (i.e. in the tangent space) and then projects the
    result back onto the stereographic ball for numerical stability.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        embedding_dim: Embedding dimension of the input points.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.

    Attributes:
        manifold: The manifold object for geometric operations.
        stereographic_norm: Stereographic layernorm used in the tangent space.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
    """

    def __init__(
        self, manifold: Manifold | ProductManifold, embedding_dim: int, curvatures: Float[torch.Tensor, "num_heads 1 1"]
    ):
        super().__init__()

        self.manifold = manifold
        self.stereographic_norm = self.manifold.apply(nn.LayerNorm(embedding_dim))
        self.curvatures = curvatures

    def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
        """Apply layer normalization on the stereographic manifold."""
        norm_X = self.stereographic_norm(X)
        # Fix: `k` is keyword-only in geoopt's stereographic math module, so it
        # must be passed by name; the positional call raised a TypeError.
        # NOTE(review): curvatures has shape [num_heads, 1, 1] while X is
        # [n_nodes, dim]; broadcasting could introduce a head axis — confirm
        # this matches the intended per-head usage.
        output = geoopt.manifolds.stereographic.math.project(norm_X, k=self.curvatures)
        return output
372+
373+
class GeometricLinearizedAttention(nn.Module):
    """Geometric Linearized Attention.

    Linear-complexity attention on the stereographic manifold: values are
    parallel-transported back to the origin's tangent space along the query
    and key points, an ``ELU + 1`` feature map keeps the kernel positive
    (linearized softmax approximation), and the aggregated result is mapped
    back onto the manifold.

    Args:
        curvatures: Curvature value(s) used per head in geometric
            computations; in practice a tensor of shape [num_heads, 1, 1]
            broadcastable against the per-head inputs.
        num_heads: Number of attention heads.
        head_dim: Dimension of each attention head.

    Attributes:
        num_heads: Number of attention heads.
        curvatures: Curvature value(s) stored for geometric computations.
        head_dim: Dimension of each attention head.
        _epsilon: Small epsilon substituted for zero normalization terms
            before inversion (constant).
        _clamp_epsilon: Minimum clamp value for numerical stability in the
            gamma denominator (constant).
    """

    def __init__(
        self, curvatures: float | list[float] | Float[torch.Tensor, "num_heads 1 1"], num_heads: int, head_dim: int
    ):
        super().__init__()

        self.num_heads = num_heads
        self.curvatures = curvatures

        self.head_dim = head_dim
        self._epsilon = 1e-5  # replaces zero normalization terms before taking 1/D
        self._clamp_epsilon = 1e-10  # lower bound on |gamma - 1| to avoid division blow-up

    def forward(
        self,
        Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
        mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"],
    ) -> Float[torch.Tensor, "batch_size n_nodes dim"]:
        """Forward pass for the geometric linearized attention layer.

        Args:
            Q: Query tensor.
            K: Key tensor.
            V: Value tensor.
            mask: Mask tensor for attention (applied multiplicatively).

        Returns:
            Output tensor after applying attention.
        """
        # Transport the values to the tangent space at the origin along the
        # query / key points (tangent-space linearization of attention).
        v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
        v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)

        # Conformal factor lambda_x at the value points; |gamma - 1| is
        # clamped away from zero before it is used as a denominator.
        gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
        denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)

        # NOTE(review): mask[None, :, None] inserts axes around the mask's
        # first dimension; confirm the resulting shape broadcasts as intended
        # against the [num_heads, n_nodes, head_dim] inputs produced by the
        # caller's _split_heads.
        x = ((gamma / denominator) * V) * mask[None, :, None]

        # ELU + 1 keeps kernel feature values strictly positive.
        v1 = nn.functional.elu(v1) + 1
        v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]

        # Linearized approximation
        v2_cumsum = v2.sum(dim=-2)  # [B, H, D]
        D = torch.einsum("...nd,...d->...n", v1, v2_cumsum.type_as(v1))  # normalization terms
        D_inv = 1.0 / D.masked_fill_(D == 0, self._epsilon)  # in-place fill of exact zeros
        context = torch.einsum("...nd,...ne->...de", v2, x)
        X = torch.einsum("...de,...nd,...n->...ne", context, v1, D_inv)

        # Map the aggregate back onto the ball; Möbius scaling by 0.5 halves
        # the gyrovector, with a final projection for numerical stability.
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(
            torch.tensor(0.5, dtype=X.dtype, device=X.device), X, k=self.curvatures, dim=-1
        )
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
351443
352444
class StereographicAttention(nn.Module):
    """Stereographic Attention Layer.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Embedding dimension of the input points.
        head_dim: Dimension of each attention head.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
            value used per head in geometric computations.
        num_heads: Number of attention heads.
        head_dim: Dimensionality of each attention head.
        W_q: Linear layer projecting inputs to query vectors.
        W_k: Linear layer projecting inputs to key vectors.
        W_v: Manifold-aware linear layer projecting to value vectors.
        attn: Stereographic multi-head attention module.
        ff: Manifold-aware linear layer for the feedforward output.
    """

    def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
        super().__init__()

        self.manifold = manifold
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)

        self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
        self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
        self.W_v = KappaGCNLayer(in_features=dim, out_features=self.num_heads * self.head_dim, manifold=self.manifold)

        self.attn = GeometricLinearizedAttention(
            curvatures=self.curvatures, num_heads=self.num_heads, head_dim=self.head_dim
        )
        self.ff = KappaGCNLayer(in_features=self.num_heads * self.head_dim, out_features=dim, manifold=self.manifold)

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass for the stereographic attention layer.

        Args:
            X: Input node embeddings.
            mask: Optional attention mask; ``None`` means full (unmasked)
                attention.

        Returns:
            Attention output projected back to ``dim`` features.
        """
        Q = self._split_heads(self.W_q(X))  # [H, N, D]
        K = self._split_heads(self.W_k(X))
        V = self._split_heads(self.W_v(X=X))

        # Fix: the signature allows mask=None, but the original code called
        # mask.unsqueeze(...) unconditionally (suppressed with `# type: ignore`)
        # and crashed. A missing mask means full attention, which the
        # multiplicative masking downstream realizes with an all-ones mask.
        if mask is None:
            mask = torch.ones(X.size(0), X.size(0), dtype=X.dtype, device=X.device)

        attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0))
        attn_out = self._combine_heads(attn_out)

        out = self.ff(X=attn_out)

        return out

    def _combine_heads(
        self, X: Float[torch.Tensor, "num_heads n_nodes head_dim"]
    ) -> Float[torch.Tensor, "n_nodes num_heads * head_dim"]:
        """Combines multi-head tensor by merging head and feature dimensions.

        Args:
            X: Input tensor with shape (num_heads, n_nodes, head_dim).

        Returns:
            X: Reshaped tensor with shape (n_nodes, num_heads * head_dim).
        """
        X = X.transpose(0, 1)
        X = X.reshape(X.size(0), self.num_heads * self.head_dim)
        return X

    def _split_heads(
        self, X: Float[torch.Tensor, "n_nodes num_heads * head_dim"]
    ) -> Float[torch.Tensor, "num_heads n_nodes head_dim"]:
        """Splits the last dimension of the input into (num_heads, head_dim) and transposes to prepare for attention.

        Args:
            X: Input tensor with shape (n_nodes, num_heads * head_dim).

        Returns:
            X: Reshaped tensor with shape (num_heads, n_nodes, head_dim).
        """
        X = X.reshape(X.size(0), self.num_heads, self.head_dim)
        X = X.transpose(0, 1)
        return X
366530
367531
class StereographicTransformer(nn.Module):
    """Stereographic Transformer Block.

    Args:
        manifold: Manifold or ProductManifold object defining the geometry.
        num_heads: Number of attention heads.
        dim: Dimensionality of the input features.
        head_dim: Dimensionality of each attention head.
        use_layer_norm: Whether to apply layer normalization in tangent space.

    Attributes:
        manifold: The manifold object for geometric operations.
        curvatures: Manifold curvatures reshaped to [num_heads, 1, 1] for broadcasting.
        mha: Multi-head stereographic attention module.
        norm1: First normalization layer (can be Identity or StereographicLayerNorm).
        norm2: Second normalization layer.
        mlpblock: Feedforward network in stereographic space.
        stereographic_activation: Activation wrapped to operate in tangent space.

    Raises:
        ValueError: If the manifold is not stereographic.
    """

    def __init__(
        self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True
    ):
        super().__init__()

        # Check that manifold is stereographic
        if not manifold.is_stereographic:
            raise ValueError(
                "Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
            )

        self.manifold = manifold
        self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
        self.stereographic_activation = self.manifold.apply(nn.ReLU())
        self.mha = StereographicAttention(manifold=self.manifold, num_heads=num_heads, dim=dim, head_dim=head_dim)

        if use_layer_norm:
            self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
            self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        self.mlpblock = nn.Sequential(
            KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
            self.stereographic_activation,
            KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
        )

    def forward(
        self,
        X: Float[torch.Tensor, "n_nodes dim"],
        mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    ) -> Float[torch.Tensor, "n_nodes dim"]:
        """Forward pass through the stereographic transformer block.

        Pre-norm residual layout: the attention and MLP sub-blocks are
        combined with the residual stream via Möbius addition (the gyrovector
        analogue of ``x + f(norm(x))``), each followed by a projection back
        onto the ball for numerical stability.
        """
        # Fix: `k` is keyword-only in geoopt's stereographic math functions
        # (mobius_add, project); passing the curvatures positionally raised a
        # TypeError at runtime.
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, k=self.curvatures)
        X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

        return X
593+
594+
595+ def _reshape_curvatures (curvatures : float | list [float ], num_heads : int ) -> Float [torch .Tensor , "num_heads 1 1" ]:
596+ """Helper function to reshape curvature(s) for use in multi-head stereographic attention."""
597+ if isinstance (curvatures , float ):
598+ output_curvatures = torch .tensor ([curvatures ] * num_heads , dtype = torch .float )
599+ else :
600+ output_curvatures = torch .tensor (curvatures , dtype = torch .float )
601+ return output_curvatures [:, None , None ]
602+
603+
def _get_curvatures(manifold: Manifold | ProductManifold) -> float | list[float]:
    """Helper function to retrieve curvature(s) from a Manifold or ProductManifold."""
    # Reject anything that is neither kind of manifold up front.
    if not isinstance(manifold, (Manifold, ProductManifold)):
        raise TypeError("Expected a Manifold or ProductManifold class.")
    # ProductManifold is dispatched first (before the plain-Manifold case):
    # a product exposes one curvature per component factor, while a single
    # manifold exposes a scalar curvature.
    if isinstance(manifold, ProductManifold):
        return manifold.curvatures
    return manifold.curvature
0 commit comments