
Commit 7ee569b

[hotfix] Fixed fused layernorm bug without apex (#5609)

* fixed fused layernorm bug without apex
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* same for flash attn
* remove flash attn check

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0d0a582 commit 7ee569b

File tree

2 files changed: +16 -2 lines changed


colossalai/shardformer/layer/normalization.py

Lines changed: 7 additions & 1 deletion
@@ -225,7 +225,13 @@ def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *
                 # fall back to the normal fused layernorm is not built
                 ApexFusedLayerNorm = FusedLayerNormWithHook
             else:
-                ApexFusedLayerNorm = FusedLayerNormWithHook
+                try:
+                    ApexFusedLayerNorm = FusedLayerNormWithHook
+                except NameError:
+                    warnings.warn(
+                        "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                    )
+                    return module

             layernorm = (
                 ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
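
For readers without the surrounding file, below is a minimal, self-contained sketch of the fallback pattern this hunk relies on. It is an illustration under assumptions, not the actual ColossalAI module: the convert_layernorm helper and the placeholder FusedLayerNormWithHook subclass are hypothetical stand-ins; only the optional apex import and the NameError guard mirror the diff above.

import warnings

import torch.nn as nn

try:
    # Optional dependency: this wrapper name only exists when apex imports cleanly.
    from apex.normalization import FusedLayerNorm as _ApexFusedLayerNorm

    class FusedLayerNormWithHook(_ApexFusedLayerNorm):
        """Hypothetical stand-in for the real hook-registering wrapper."""

except ImportError:
    # apex is absent, so the name FusedLayerNormWithHook is simply never bound.
    pass


def convert_layernorm(module: nn.LayerNorm) -> nn.Module:
    """Swap in apex's fused layernorm when available, otherwise keep the vanilla module."""
    try:
        ApexFusedLayerNorm = FusedLayerNormWithHook  # raises NameError when apex was not importable
    except NameError:
        warnings.warn(
            "Please install Apex from source to use fused kernels, "
            "or disable fused normalization. Using vanilla layernorm instead."
        )
        return module
    return ApexFusedLayerNorm(
        module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine
    )

With this guard, converting an nn.LayerNorm degrades gracefully on machines without apex instead of raising NameError at conversion time, which is the bug the commit title describes.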

colossalai/shardformer/shard/shard_config.py

Lines changed: 9 additions & 1 deletion
@@ -120,7 +120,15 @@ def _turn_on_all_optimization(self):
         Turn on all optimization.
         """
         # you can add all the optimization flag here
-        self.enable_fused_normalization = True
+        try:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm  # noqa
+
+            apex_avail = True
+        except ImportError:
+            apex_avail = False
+            warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
+
+        self.enable_fused_normalization = apex_avail
         self.enable_flash_attention = True
         self.enable_jit_fused = True
         # This can cause non-in-place param sharding when used without ZeRO.
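
The second hunk applies the same idea one level higher: probe for apex before turning the fused-normalization flag on. Below is a minimal sketch of that probe written as a free function rather than the ShardConfig method; the function name and the config object in the usage comment are illustrative assumptions, not ColossalAI API.

import warnings


def fused_normalization_available() -> bool:
    """Return True only when apex's fused layernorm kernels can actually be imported."""
    try:
        from apex.normalization import FusedLayerNorm  # noqa: F401
        return True
    except ImportError:
        warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
        return False


# Usage sketch against a hypothetical config object carrying the same flags as ShardConfig:
# config.enable_fused_normalization = fused_normalization_available()
# config.enable_flash_attention = True
# config.enable_jit_fused = True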
