@@ -657,7 +657,7 @@ def __init__(self, config: TransformerArgs) -> None:
             self.layers[str(layer_id)] = TransformerBlock(config)

         if config.stage_idx == config.n_stages - 1:
-            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+            self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
             if config.tie_word_embeddings:
                 self.output.weight = self.tok_embeddings.weight
@@ -751,8 +751,8 @@ def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
         self.feed_forward = FeedForward(config)
-        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
-        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
+        self.attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
         # None for llama architecture, set for granite architectures
         self.residual_multiplier = (
             config.residual_multiplier
@@ -928,20 +928,6 @@ def forward(self, x: Tensor) -> Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))


-class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
-
-    def forward(self, x: Tensor) -> Tensor:
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
 def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
     # Check for the presence of the required keys
     required_keys = {
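The deleted class and torch.nn.RMSNorm compute the same normalization (x * rsqrt(mean(x^2) + eps), scaled by a learned weight initialized to ones), and every call site touched above passes config.norm_eps explicitly, so the swap is behavior-preserving. A minimal sanity-check sketch, not part of this commit: it assumes PyTorch >= 2.4 (where nn.RMSNorm was introduced) and uses the hypothetical name HandRolledRMSNorm for a copy of the deleted class.

# Sanity-check sketch (not part of the commit): compare nn.RMSNorm against the
# deleted hand-rolled implementation. Assumes PyTorch >= 2.4.
import torch
import torch.nn as nn


class HandRolledRMSNorm(nn.Module):
    # Copy of the class removed in this diff, kept only for comparison.
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._norm(x.float()).type_as(x) * self.weight


if __name__ == "__main__":
    torch.manual_seed(0)
    dim, eps = 64, 1e-5
    x = torch.randn(2, 8, dim)

    old = HandRolledRMSNorm(dim, eps=eps)
    new = nn.RMSNorm(dim, eps=eps)  # weight is initialized to ones, as above

    # With identical (ones) weights and the same explicit eps, the outputs
    # should agree to float32 tolerance.
    torch.testing.assert_close(old(x), new(x), rtol=1e-5, atol=1e-5)
    print("nn.RMSNorm matches the removed implementation")

One footnote on defaults: the deleted class defaulted eps to 1e-5, whereas nn.RMSNorm defaults eps to None (falling back to the input dtype's machine epsilon); since the call sites in this diff always pass config.norm_eps, the differing defaults never come into play.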