@@ -76,18 +76,24 @@ def forward(self, input):
 
     FusedRMSNormWithHook = NPUFusedRMSNormWithHook
 else:
-    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
-
-    class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
-        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-            super().__init__(normalized_shape, eps, elementwise_affine)
-
-        def forward(self, input):
-            output = super().forward(input)
-            output = hook_parameter_in_backward(output, self.weight)
-            return output
+    try:
+        from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+
+        class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
+            def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+                super().__init__(normalized_shape, eps, elementwise_affine)
+
+            def forward(self, input):
+                output = super().forward(input)
+                output = hook_parameter_in_backward(output, self.weight)
+                return output
+
+        FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
+    except ImportError:
+        warnings.warn(
+            "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
+        )
 
-    FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
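
As a rough sketch of the pattern this commit adopts (the HAS_APEX flag and the norm variable below are illustrative, not taken from the repository), the optional apex import is wrapped in try/except so that environments without apex fall through with a warning instead of failing at module import time:

import warnings

try:
    # apex is an optional dependency; the fused RMSNorm kernel lives here
    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

    HAS_APEX = True  # illustrative flag, not part of the original module
except ImportError:
    HAS_APEX = False
    warnings.warn(
        "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
    )

if HAS_APEX:
    # FusedRMSNorm(normalized_shape, eps, elementwise_affine) matches the
    # constructor arguments forwarded by CUDAFusedRMSNormWithHook in the diff
    norm = ApexFusedRMSNorm(1024, eps=1e-5, elementwise_affine=True)

With this fallback, the module can still be imported on machines without apex; only the CUDA fused-RMSNorm path becomes unavailable.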