@@ -76,18 +76,24 @@ def forward(self, input):
 
     FusedRMSNormWithHook = NPUFusedRMSNormWithHook
 else:
-    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
-
-    class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
-        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-            super().__init__(normalized_shape, eps, elementwise_affine)
-
-        def forward(self, input):
-            output = super().forward(input)
-            output = hook_parameter_in_backward(output, self.weight)
-            return output
+    try:
+        from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+
+        class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
+            def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+                super().__init__(normalized_shape, eps, elementwise_affine)
+
+            def forward(self, input):
+                output = super().forward(input)
+                output = hook_parameter_in_backward(output, self.weight)
+                return output
+
+        FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
+    except ImportError:
+        warnings.warn(
+            "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
+        )
 
-    FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
0 commit comments