Commit 0d0d84c

Add bias as a weight we need to sync as well (#307)
1 parent c3be5d3 commit 0d0d84c

4 files changed: +9 -4 lines

megatron/arguments.py

Lines changed: 2 additions & 2 deletions
@@ -375,8 +375,8 @@ def _add_network_size_args(parser):
                        ', needs to be divisible by TP size and `make-vocab-size-divisible-by`.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
-    group.add_argument('--layernorm-tp-auto-sync', action='store_true',
-                       help='Force syncing layernorm params across TP ranks in forward. '
+    group.add_argument('--sync-tp-duplicated-parameters', action='store_true',
+                       help='Force syncing duplicated params across TP ranks in forward. '
                             'This is a workaround for an unresolved bug leading to TP ranks '
                             'getting out of sync with each other.')
     group.add_argument('--apply-residual-connection-post-layernorm',
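
Since argparse converts dashes in option names to underscores, the renamed flag is exposed as args.sync_tp_duplicated_parameters, which is why the call sites in the files below change along with it. A minimal standalone sketch of that mapping (plain argparse, not Megatron's full parser):

import argparse

# Standalone illustration only: Megatron registers this flag inside
# _add_network_size_args(); this reproduces just the dash-to-underscore mapping.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('network size')
group.add_argument('--sync-tp-duplicated-parameters', action='store_true',
                   help='Force syncing duplicated params across TP ranks in forward.')

args = parser.parse_args(['--sync-tp-duplicated-parameters'])
assert args.sync_tp_duplicated_parameters is True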

megatron/model/fused_layer_norm.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def __init__(self, normalized_shape, eps=1e-5):
         self.reset_parameters()
 
         args = get_args()
-        self.layernorm_tp_auto_sync = args.layernorm_tp_auto_sync
+        self.layernorm_tp_auto_sync = args.sync_tp_duplicated_parameters
 
         self.use_meg_ds_fused_layer_norm = (
             args.bf16 # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm

megatron/mpu/layers.py

Lines changed: 5 additions & 0 deletions
@@ -423,6 +423,7 @@ def __init__(self, input_size, output_size, bias=True,
         else:
             self.register_parameter('bias', None)
 
+        self.bias_tp_auto_sync = args.sync_tp_duplicated_parameters
 
 
     def forward(self, input_):
@@ -435,6 +436,10 @@ def forward(self, input_):
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
         output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+
+        if self.bias_tp_auto_sync:
+            torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
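
The added lines average the bias over the tensor-model-parallel group on every forward pass: in RowParallelLinear the bias is replicated on each TP rank rather than sharded, so any drift between the copies can be corrected with an all-reduce. A minimal sketch of the same pattern in isolation (the helper name and the SUM-then-divide form are illustrative and not part of this commit, which calls ReduceOp.AVG directly; AVG requires the NCCL backend):

import torch
import torch.distributed as dist

def sync_duplicated_param(param: torch.Tensor, group=None) -> None:
    # Average a tensor that every rank in `group` holds a full copy of.
    # Assumes torch.distributed is already initialized; illustrative only.
    world_size = dist.get_world_size(group=group)
    with torch.no_grad():
        # SUM followed by a divide is the backend-agnostic equivalent of
        # the ReduceOp.AVG used in the diff above.
        dist.all_reduce(param, op=dist.ReduceOp.SUM, group=group)
        param.div_(world_size)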

tests/test_training.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def get_variation_config(self, variation, output_dir, n_samples=None):
             --clip-grad 1.0
             --weight-decay 1e-1
             --embed-layernorm
-            --layernorm-tp-auto-sync
+            --sync-tp-duplicated-parameters
             --fp16
 
             --log-level debug

0 commit comments
