Skip to content

Commit fa9d031

Browse files
authored
[Hotfix] hotfix normalization (#6163)
* [fix] hotfix normalization * [hotfix] force doc ci test * [hotfix] fallback doc
1 parent 130229f commit fa9d031

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

colossalai/shardformer/layer/normalization.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,24 @@ def forward(self, input):
7676

7777
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
7878
else:
79-
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
80-
81-
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
82-
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
83-
super().__init__(normalized_shape, eps, elementwise_affine)
84-
85-
def forward(self, input):
86-
output = super().forward(input)
87-
output = hook_parameter_in_backward(output, self.weight)
88-
return output
79+
try:
80+
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
81+
82+
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
83+
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
84+
super().__init__(normalized_shape, eps, elementwise_affine)
85+
86+
def forward(self, input):
87+
output = super().forward(input)
88+
output = hook_parameter_in_backward(output, self.weight)
89+
return output
90+
91+
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
92+
except ImportError:
93+
warnings.warn(
94+
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
95+
)
8996

90-
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
9197

9298
FAST_LAYERNORM_SUPPORTED_SIZE = [
9399
1024,

0 commit comments

Comments
 (0)