1010
1111if TYPE_CHECKING :
1212 from beartype .typing import Callable
13- from jaxtyping import Float , Union , List
13+ from jaxtyping import Float
1414
1515from ...manifolds import Manifold , ProductManifold
1616
@@ -341,21 +341,23 @@ def forward(
341341
342342class StereographicLayerNorm (nn .Module ):
343343 """Stereographic Layer Normalization.
344-
344+
345345 Args:
346346 manifold: Manifold or ProductManifold object defining the geometry.
347347 embedding_dim: Embedding dimension of the input points.
348- curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
348+ curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
349349 value used per head in geometric computations.
350350
351- Attributes:
351+ Attributes:
352352 manifold: The manifold object for geometric operations.
353353 stereographic_norm: Stereographic layernorm used in the tangent space.
354- curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
355- value used per head in geometric computations.
354+ curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
355+ value used per head in geometric computations.
356356 """
357357
358- def __init__ (self , manifold : Manifold | ProductManifold , embedding_dim : int , curvatures : torch .Tensor ["num_heads 1 1" ]):
358+ def __init__ (
359+ self , manifold : Manifold | ProductManifold , embedding_dim : int , curvatures : torch .Tensor ["num_heads 1 1" ]
360+ ):
359361 super ().__init__ ()
360362
361363 self .manifold = manifold
@@ -364,152 +366,163 @@ def __init__(self, manifold: Manifold | ProductManifold, embedding_dim: int, cur
364366
365367 def forward (self , X : Float [torch .Tensor , "n_nodes dim" ]) -> Float [torch .Tensor , "n_nodes dim" ]:
366368 """Apply layer normalization on the stereographic manifold."""
367-
368369 norm_X = self .stereographic_norm (X )
369370 output = geoopt .manifolds .stereographic .math .project (norm_X , self .curvatures )
370371 return output
371372
372373
373374class GeometricLinearizedAttention (nn .Module ):
374375 """Geometric Linearized Attention.
375-
376+
376377 Args:
377- curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
378+ curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
378379 value used per head in geometric computations.
379380 num_heads: Number of attention heads.
380381 head_dim: Dimension of each attention head.
381382
382- Attributes:
383+ Attributes:
383384 num_heads: Number of attention heads.
384385 head_dim: Dimension of each attention head.
385386 epsilon: Small epsilon for masking inverse denominator (constant).
386387 clamp_epsilon: Minimum clamp value for numerical stability in gamma denominator (constant).
387388 """
388- def __init__ ( self , curvatures : Union [ float , List [ float ]], num_heads : int , head_dim : int ):
389-
389+
390+ def __init__ ( self , curvatures : float | list [ float ], num_heads : int , head_dim : int ):
390391 super ().__init__ ()
391392
392393 self .num_heads = num_heads
393- self .curvatures = curvatures
394-
395- self .head_dim = head_dim
394+ self .curvatures = curvatures
395+
396+ self .head_dim = head_dim
396397 self ._epsilon = 1e-5
397398 self ._clamp_epsilon = 1e-10
398-
399+
399400 def forward (
400- self ,
401- Q : Float [torch .Tensor , "batch_size num_heads n_nodes head_dim" ],
401+ self ,
402+ Q : Float [torch .Tensor , "batch_size num_heads n_nodes head_dim" ],
402403 K : Float [torch .Tensor , "batch_size num_heads n_nodes head_dim" ],
403404 V : Float [torch .Tensor , "batch_size num_heads n_nodes head_dim" ],
404- mask : Float [torch .Tensor , "1 1 n_nodes n_nodes" ]
405+ mask : Float [torch .Tensor , "1 1 n_nodes n_nodes" ],
405406 ) -> Float [torch .Tensor , "batch_size n_nodes dim" ]:
407+ """Forward pass for the geometric linearized attention layer.
406408
409+ Args:
410+ Q: Query tensor.
411+ K: Key tensor.
412+ V: Value tensor.
413+ mask: Mask tensor for attention.
414+
415+ Returns:
416+ Output tensor after applying attention.
417+ """
407418 v1 = geoopt .manifolds .stereographic .math .parallel_transport0back (V , Q , k = self .curvatures )
408419 v2 = geoopt .manifolds .stereographic .math .parallel_transport0back (V , K , k = self .curvatures )
409-
420+
410421 gamma = geoopt .manifolds .stereographic .math .lambda_x (x = V , k = self .curvatures , keepdim = True , dim = - 1 )
411422 denominator = geoopt .utils .clamp_abs ((gamma - 1 ), self ._clamp_epsilon )
412-
423+
413424 x = ((gamma / denominator ) * V ) * mask [None , :, None ]
414-
415- v1 = ( nn .functional .elu (v1 ) + 1 )
425+
426+ v1 = nn .functional .elu (v1 ) + 1
416427 v2 = (denominator * (nn .functional .elu (v2 ) + 1 )) * mask [None , :, None ]
417428
418429 # Linearized approximation
419- v2_cumsum = v2 .sum (dim = - 2 ) # [B, H, D]
420- D = torch .einsum ('...nd,...d->...n' , v1 , v2_cumsum .type_as (v1 )) # normalization terms
421- D_inv = 1. / D .masked_fill_ (D == 0 , self ._epsilon )
422- context = torch .einsum ('...nd,...ne->...de' , v2 , x )
423- X = torch .einsum ('...de,...nd,...n->...ne' , context , v1 , D_inv )
424-
425-
430+ v2_cumsum = v2 .sum (dim = - 2 ) # [B, H, D]
431+ D = torch .einsum ("...nd,...d->...n" , v1 , v2_cumsum .type_as (v1 )) # normalization terms
432+ D_inv = 1.0 / D .masked_fill_ (D == 0 , self ._epsilon )
433+ context = torch .einsum ("...nd,...ne->...de" , v2 , x )
434+ X = torch .einsum ("...de,...nd,...n->...ne" , context , v1 , D_inv )
435+
426436 X = geoopt .manifolds .stereographic .math .project (X , k = self .curvatures )
427- X = geoopt .manifolds .stereographic .math .mobius_scalar_mul (torch .tensor (0.5 , dtype = X .dtype , device = X .device ),
428- X ,
429- k = self .curvatures ,
430- dim = - 1 )
437+ X = geoopt .manifolds .stereographic .math .mobius_scalar_mul (
438+ torch .tensor (0.5 , dtype = X .dtype , device = X .device ), X , k = self .curvatures , dim = - 1
439+ )
431440 X = geoopt .manifolds .stereographic .math .project (X , k = self .curvatures )
432-
441+
433442 return X
434443
435444
436445class StereographicAttention (nn .Module ):
437446 """Stereographic Attention Layer.
438-
447+
439448 Args:
440449 manifold: Manifold or ProductManifold object defining the geometry.
441450 num_heads: Number of attention heads.
442451 dim: Embedding dimension of the input points.
443- head_dim: Dimension of each attention head.
452+ head_dim: Dimension of each attention head.
444453
445- Attributes:
454+ Attributes:
446455 manifold: The manifold object for geometric operations.
447- curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
456+ curvatures: Tensor of shape [num_heads, 1, 1] representing the curvature
448457 value used per head in geometric computations.
449458 num_heads: Number of attention heads.
450459 head_dim: Dimensionality of each attention head.
451460 W_q: Linear layer projecting inputs to query vectors.
452461 W_k: Linear layer projecting inputs to key vectors.
453462 W_v: Manifold-aware linear layer projecting to value vectors.
454463 attn: Stereographic multi-head attention module.
455- ff: Manifold-aware linear layer for the feedforward output.
464+ ff: Manifold-aware linear layer for the feedforward output.
456465 """
457466
458467 def __init__ (self , manifold : Manifold | ProductManifold , num_heads : int , dim : int , head_dim : int ):
459468 super ().__init__ ()
460-
469+
461470 self .manifold = manifold
462471 self .num_heads = num_heads
463472 self .head_dim = head_dim
464473 self .curvatures = _reshape_curvatures (_get_curvatures (self .manifold ), self .num_heads )
465-
466- self .W_q = nn .Linear (in_features = dim , out_features = self .num_heads * self .head_dim )
467- self .W_k = nn .Linear (in_features = dim , out_features = self .num_heads * self .head_dim )
468- self .W_v = KappaGCNLayer (in_features = dim ,
469- out_features = self .num_heads * self .head_dim ,
470- manifold = self .manifold )
471-
472- self .attn = GeometricLinearizedAttention (curvatures = self .curvatures ,
473- num_heads = self .num_heads ,
474- head_dim = self .head_dim )
475- self .ff = KappaGCNLayer (in_features = self .num_heads * self .head_dim ,
476- out_features = dim ,
477- manifold = self .manifold )
474+
475+ self .W_q = nn .Linear (in_features = dim , out_features = self .num_heads * self .head_dim )
476+ self .W_k = nn .Linear (in_features = dim , out_features = self .num_heads * self .head_dim )
477+ self .W_v = KappaGCNLayer (in_features = dim , out_features = self .num_heads * self .head_dim , manifold = self .manifold )
478+
479+ self .attn = GeometricLinearizedAttention (
480+ curvatures = self .curvatures , num_heads = self .num_heads , head_dim = self .head_dim
481+ )
482+ self .ff = KappaGCNLayer (in_features = self .num_heads * self .head_dim , out_features = dim , manifold = self .manifold )
478483
479484 def forward (
480485 self ,
481486 X : Float [torch .Tensor , "n_nodes dim" ],
482487 mask : Float [torch .Tensor , "n_nodes n_nodes" ] | None = None ,
483488 ) -> Float [torch .Tensor , "n_nodes dim" ]:
484489 """Forward pass for the stereographic attention layer."""
485- Q = self ._split_heads (self .W_q (X )) # [B, H, N, D]
490+ Q = self ._split_heads (self .W_q (X )) # [B, H, N, D]
486491 K = self ._split_heads (self .W_k (X ))
487492 V = self ._split_heads (self .W_v (X = X ))
488-
489- attn_out = self .attn (Q , K , V , mask .unsqueeze (0 ).unsqueeze (0 ))
493+
494+ attn_out = self .attn (Q , K , V , mask .unsqueeze (0 ).unsqueeze (0 )) # type: ignore
490495 attn_out = self ._combine_heads (attn_out )
491-
496+
492497 out = self .ff (X = attn_out )
493-
498+
494499 return out
495-
500+
496501 def _combine_heads (
497- self ,
498- X : Float [torch .Tensor , "n_nodes num_heads head_dim" ]
502+ self , X : Float [torch .Tensor , "n_nodes num_heads head_dim" ]
499503 ) -> Float [torch .Tensor , "n_nodes num_heads * head_dim" ]:
500- """Combines multi-head tensor by merging head and feature dimensions."""
504+ """Combines multi-head tensor by merging head and feature dimensions.
505+
506+ Args:
507+ X: Input tensor with shape.
501508
509+ Returns:
510+ X: Reshaped tensor with shape (n_nodes, num_heads * head_dim).
511+ """
502512 X = X .transpose (0 , 1 )
503513 X = X .reshape (X .size (0 ), self .num_heads * self .head_dim )
504514 return X
505-
515+
506516 def _split_heads (
507- self ,
508- X : Float [torch .Tensor , "n_nodes num_heads * head_dim" ]
517+ self , X : Float [torch .Tensor , "n_nodes num_heads * head_dim" ]
509518 ) -> Float [torch .Tensor , "num_heads n_nodes head_dim" ]:
510- """
511- Splits the last dimension of the input into (num_heads, head_dim) and transposes
512- to prepare for multi-head attention computation.
519+ """Splits the last dimension of the input into (num_heads, head_dim) and transposes to prepare for attention.
520+
521+ Args:
522+ X: Input tensor with shape (n_nodes, num_heads * head_dim).
523+
524+ Returns:
525+ X: Reshaped tensor with shape (num_heads, n_nodes, head_dim).
513526 """
514527 X = X .reshape (X .size (0 ), self .num_heads , self .head_dim )
515528 X = X .transpose (0 , 1 )
@@ -518,7 +531,7 @@ def _split_heads(
518531
519532class StereographicTransformer (nn .Module ):
520533 """Stereographic Transformer Block.
521-
534+
522535 Args:
523536 manifold: Manifold or ProductManifold object defining the geometry.
524537 num_heads: Number of attention heads.
@@ -533,36 +546,36 @@ class StereographicTransformer(nn.Module):
533546 norm1: First normalization layer (can be Identity or StereographicLayerNorm).
534547 norm2: Second normalization layer.
535548 mlpblock: Feedforward network in stereographic space.
536- stereographic_activation: Activation wrapped to operate in tangent space.
549+ stereographic_activation: Activation wrapped to operate in tangent space.
537550 """
538551
539- def __init__ (self , manifold : Manifold | ProductManifold , num_heads : int , dim : int , head_dim : int , use_layer_norm : bool = True ):
540- super (StereographicTransformer , self ).__init__ ()
552+ def __init__ (
553+ self , manifold : Manifold | ProductManifold , num_heads : int , dim : int , head_dim : int , use_layer_norm : bool = True
554+ ):
555+ super ().__init__ ()
541556
542- # Check that manifold is stereographic
557+ # Check that manifold is stereographic
543558 if not manifold .is_stereographic :
544559 raise ValueError (
545560 "Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
546- )
547-
561+ )
562+
548563 self .manifold = manifold
549564 self .curvatures = _reshape_curvatures (_get_curvatures (self .manifold ), num_heads )
550565 self .stereographic_activation = self .manifold .apply (nn .ReLU ())
551- self .mha = StereographicAttention (manifold = self .manifold ,
552- num_heads = num_heads ,
553- dim = dim ,
554- head_dim = head_dim )
566+ self .mha = StereographicAttention (manifold = self .manifold , num_heads = num_heads , dim = dim , head_dim = head_dim )
555567
556568 if use_layer_norm :
557569 self .norm1 = StereographicLayerNorm (manifold = self .manifold , embedding_dim = dim , curvatures = self .curvatures )
558570 self .norm2 = StereographicLayerNorm (manifold = self .manifold , embedding_dim = dim , curvatures = self .curvatures )
559571 else :
560572 self .norm1 = nn .Identity ()
561573 self .norm2 = nn .Identity ()
562-
563- self .mlpblock = nn .Sequential (KappaGCNLayer (in_features = dim , out_features = dim , manifold = self .manifold ),
564- self .stereographic_activation ,
565- KappaGCNLayer (in_features = dim , out_features = dim , manifold = self .manifold )
574+
575+ self .mlpblock = nn .Sequential (
576+ KappaGCNLayer (in_features = dim , out_features = dim , manifold = self .manifold ),
577+ self .stereographic_activation ,
578+ KappaGCNLayer (in_features = dim , out_features = dim , manifold = self .manifold ),
566579 )
567580
568581 def forward (
@@ -571,28 +584,28 @@ def forward(
571584 mask : Float [torch .Tensor , "n_nodes n_nodes" ] | None = None ,
572585 ) -> Float [torch .Tensor , "n_nodes dim" ]:
573586 """Forward pass through the stereographic transformer block."""
574-
575- X = geoopt .manifolds .stereographic .math .mobius_add (self .mha (self .norm1 (X ), mask ), X , self .curvatures )
587+ X = geoopt .manifolds .stereographic .math .mobius_add (self .mha (self .norm1 (X ), mask ), X , self .curvatures )
576588 X = geoopt .manifolds .stereographic .math .project (X , self .curvatures )
577589 X = geoopt .manifolds .stereographic .math .mobius_add (self .mlpblock (self .norm2 (X )), X , self .curvatures )
578- X = geoopt .manifolds .stereographic .math .project (X , self .curvatures )
579-
590+ X = geoopt .manifolds .stereographic .math .project (X , self .curvatures )
591+
580592 return X
581593
582594
583- def _reshape_curvatures (curvatures : Union [ float , List [float ] ], num_heads : int ) -> Float [torch .Tensor , "num_heads 1 1" ]:
584- """Helper function to reshape curvature(s) for use in multi-head stereographic attention. """
595+ def _reshape_curvatures (curvatures : float | list [float ], num_heads : int ) -> Float [torch .Tensor , "num_heads 1 1" ]:
596+ """Helper function to reshape curvature(s) for use in multi-head stereographic attention."""
585597 if isinstance (curvatures , float ):
586598 output_curvatures = torch .tensor ([curvatures ] * num_heads , dtype = torch .float )
587599 else :
588600 output_curvatures = torch .tensor (curvatures , dtype = torch .float )
589601 return output_curvatures [:, None , None ]
590602
591- def _get_curvatures (manifold : Union [Manifold , ProductManifold ]) -> Union [float , list ]:
603+
604+ def _get_curvatures (manifold : Manifold | ProductManifold ) -> float | list [float ]:
592605 """Helper function to retrieve curvature(s) from a Manifold or ProductManifold."""
593606 if isinstance (manifold , ProductManifold ):
594607 return manifold .curvatures
595608 elif isinstance (manifold , Manifold ):
596609 return manifold .curvature
597610 else :
598- raise TypeError ("Expected a Manifold or ProductManifold class." )
611+ raise TypeError ("Expected a Manifold or ProductManifold class." )
0 commit comments