Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions colossalai/shardformer/layer/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,24 @@ def forward(self, input):

FusedRMSNormWithHook = NPUFusedRMSNormWithHook
else:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)

def forward(self, input):
output = super().forward(input)
output = hook_parameter_in_backward(output, self.weight)
return output
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)

def forward(self, input):
output = super().forward(input)
output = hook_parameter_in_backward(output, self.weight)
return output

FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
except ImportError:
warnings.warn(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
)

FusedRMSNormWithHook = CUDAFusedRMSNormWithHook

FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
Expand Down
Loading