File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -44,7 +44,20 @@ def _get_params_for_weight_decay_optimization(modules):
4444 [p for n , p in list (module_ ._parameters .items ())
4545 if p is not None and n == 'bias' ])
4646
47- return weight_decay_params , no_weight_decay_params
47+ # XXX: temp hack to workaround the crash in apex FusedAdam's multi_tensor_applier
48+ #
49+ # it crashes when the param count is larger than a certain size which we hit at 200B over 80
50+ # A100 gpus - I think around 2.7B per gpu, so halving it works around the issue
51+ param_count = len (weight_decay_params ['params' ])
52+ first_half = weight_decay_params ['params' ][:param_count // 2 ]
53+ second_half = weight_decay_params ['params' ][param_count // 2 :]
54+
55+ first_half = { 'params' : first_half }
56+ second_half = { 'params' : second_half }
57+
58+ return first_half , second_half , no_weight_decay_params
59+
60+ #return weight_decay_params, no_weight_decay_params
4861
4962
5063def get_megatron_optimizer (model ):
You can’t perform that action at this time.
0 commit comments