Skip to content

Commit 541b967

Browse files
authored
[apex FusedAdam] crash workaround (#249)
* [apex FusedAdam] crash workaround * Trigger CI * Trigger CI * fix * fix
1 parent 406022e commit 541b967

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

megatron/optimizer/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,20 @@ def _get_params_for_weight_decay_optimization(modules):
     [p for n, p in list(module_._parameters.items())
      if p is not None and n == 'bias'])

-    return weight_decay_params, no_weight_decay_params
+    # XXX: temp hack to workaround the crash in apex FusedAdam's multi_tensor_applier
+    #
+    # it crashes when the param count is larger than a certain size which we hit at 200B over 80
+    # A100 gpus - I think around 2.7B per gpu, so halving it works around the issue
+    param_count = len(weight_decay_params['params'])
+    first_half = weight_decay_params['params'][:param_count // 2]
+    second_half = weight_decay_params['params'][param_count // 2:]
+
+    first_half = { 'params': first_half }
+    second_half = { 'params': second_half }
+
+    return first_half, second_half, no_weight_decay_params
+
+    #return weight_decay_params, no_weight_decay_params


 def get_megatron_optimizer(model):

0 commit comments

Comments (0)