Description
With automatic scheduling, I observed more communications being inserted in preseg than required.
The following observations are for b = 1. With b > 1, we hit another issue: #4523
#5667 uses some manual scheduling to close the performance gap.
MHA / MLP Up Projection:
out = linear (in, weight, bias)
In backprop:
out_grad = [DIDx(d), b, s, 3*e // d] # Sharded on hidden dimension (or TP)
in = [DIDx(d), b, s//d, e] # Sharded on sequence dimension (or SP)
weight = [DIDx(d), 3*e//d, e] # TP sharded
bias = [DIDx(d), 3*e//d]
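As a notational aid, here is a minimal sketch of how these logical layouts map to per-device shards. The concrete sizes (d, b, s, e) are made up for illustration, and numpy split stands in for the real sharding; this is not nvFuser code.

```python
import numpy as np

# Assumed sizes, for illustration only.
d, b, s, e = 2, 1, 8, 16

# Logical (unsharded) tensors. `in` is a Python keyword, so it is named in_ here.
out_grad = np.zeros((b, s, 3 * e))
in_ = np.zeros((b, s, e))
weight = np.zeros((3 * e, e))
bias = np.zeros((3 * e,))

# [DIDx(d), ...] means the dimension carrying the factor of d is split across
# the d devices; each device holds one local shard.
out_grad_shards = np.split(out_grad, d, axis=2)  # local: [b, s, 3*e // d] (TP)
in_shards = np.split(in_, d, axis=1)             # local: [b, s // d, e]   (SP)
weight_shards = np.split(weight, d, axis=0)      # local: [3*e // d, e]    (TP)
bias_shards = np.split(bias, d, axis=0)          # local: [3*e // d]
```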
- Computing weight gradient
weight_grad = matmul (out_grad.t(), in) # [DIDx(d), 3 * e // d, e, r{DIDx(d)}, r{s // d}]
This is decomposed into local matmul + sum:
local_matmul = [DIDx(d), 3 * e, e, r{s // d}]
sum = [DIDx(d), 3 * e // d, e, r{DIDx(d)}]
Observed communication: out_grad is resharded to [DIDx(d), s//d, 3*e] to match local_matmul. This is lowered to a SendRecv (also incorrect: see #4188 (comment)). The sum following the local matmul is a ReduceScatter.
Optimal: allgather in (the input) instead of sharding s in weight_grad or out_grad; see the sketch below.
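To make the gap concrete, here is a minimal single-process sketch comparing the observed schedule (reshard out_grad to SP, local matmul, ReduceScatter) with the optimal one (allgather in, purely local matmul). The sizes are made up, b = 1 is squeezed out, and lists of arrays plus np.split/np.concatenate stand in for the devices and collectives; this is not nvFuser code.

```python
import numpy as np

d, b, s, e = 2, 1, 8, 16                     # assumed sizes, illustration only
rng = np.random.default_rng(0)

out_grad = rng.standard_normal((s, 3 * e))   # b = 1 squeezed out
in_ = rng.standard_normal((s, e))

out_grad_tp = np.split(out_grad, d, axis=1)  # [s, 3*e // d] per device (TP)
in_sp = np.split(in_, d, axis=0)             # [s // d, e] per device (SP)

reference = out_grad.T @ in_                 # full weight_grad: [3*e, e]

# Observed schedule: reshard out_grad to SP (SendRecv), local matmul over the
# sharded s, then ReduceScatter the partial sums.
out_grad_sp = np.split(out_grad, d, axis=0)                 # reshard: [s//d, 3*e]
partials = [og.T @ x for og, x in zip(out_grad_sp, in_sp)]  # [3*e, e] partials
reduced = sum(partials)                                     # ReduceScatter = reduce ...
weight_grad_observed = np.split(reduced, d, axis=0)         # ... + scatter: [3*e//d, e]

# Optimal schedule: AllGather in along s; the matmul is then purely local and
# needs no cross-device reduction because out_grad stays TP-sharded.
in_full = np.concatenate(in_sp, axis=0)                       # AllGather: [s, e]
weight_grad_optimal = [og.T @ in_full for og in out_grad_tp]  # [3*e//d, e] per device

assert np.allclose(np.concatenate(weight_grad_observed, axis=0), reference)
assert np.allclose(np.concatenate(weight_grad_optimal, axis=0), reference)
```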
MHA / MLP Down Projection
out = linear (in, weight, bias)
In backprop:
out_grad = [DIDx(d), b, s //d , e] # SP sharded
in = [DIDx(d), b, s, 4 * e//d] # TP sharded
weight = [DIDx(d), e, 4 * e // d] # TP sharded
bias = [e]
- Computing bias grad
bias_grad = sum(out_grad, dim=[0,1]) # [e, r{DIDx(d)}, r{s//d}]
This is decomposed into local sum + allreduce
Note: Sum over batch dim is a squeeze since b = 1
Communication = AllReduce
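A minimal sketch of this decomposition, with assumed sizes and a list of local shards standing in for the d devices:

```python
import numpy as np

d, b, s, e = 2, 1, 8, 16                     # assumed sizes, illustration only
rng = np.random.default_rng(0)

out_grad = rng.standard_normal((b, s, e))
out_grad_sp = np.split(out_grad, d, axis=1)  # [b, s//d, e] per device (SP)

# Local sum over (b, s//d); with b = 1 the batch sum is effectively a squeeze.
partials = [og.sum(axis=(0, 1)) for og in out_grad_sp]   # [e] per device

# AllReduce: every device ends up with the full [e] bias gradient.
bias_grad = sum(partials)

assert np.allclose(bias_grad, out_grad.sum(axis=(0, 1)))
```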
- Computing weight_grad
out_grad_reshaped = [DIDx(d), s//d, e] (b is squeezed out by reshape)
weight_grad = matmul(out_grad_reshaped.t(), in) # [DIDx(d), e, 4*e//d, r{s}]
Note that weight_grad is not sharded on s. This happens because we propagate from in first, causing one non-reduction dimension to be sharded in weight_grad. When propagating from out_grad_reshaped.t(), DIDx is not chosen for propagation since it is already present on the target on a non-reduction ID. We currently do not check whether the ref sharding maps to a reduction dim in the target.
If the above issue were resolved, we would likely see communication similar to the up projection case above.
Communication = AllGather out_grad_reshaped.t()
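A sketch of this allgather-based schedule for weight_grad (assumed sizes, list-of-shards simulation, b = 1 already squeezed out by the reshape):

```python
import numpy as np

d, b, s, e = 2, 1, 8, 16                    # assumed sizes, illustration only
rng = np.random.default_rng(0)

out_grad = rng.standard_normal((s, e))      # out_grad_reshaped, b squeezed out
in_ = rng.standard_normal((s, 4 * e))

out_grad_sp = np.split(out_grad, d, axis=0)  # [s//d, e] per device (SP)
in_tp = np.split(in_, d, axis=1)             # [s, 4*e//d] per device (TP)

# AllGather out_grad along s, then a purely local matmul; no reduction is
# needed because the contraction dim s is fully materialized on every device.
out_grad_full = np.concatenate(out_grad_sp, axis=0)        # AllGather: [s, e]
weight_grad = [out_grad_full.T @ x for x in in_tp]         # [e, 4*e//d] per device

assert np.allclose(np.concatenate(weight_grad, axis=1), out_grad.T @ in_)
```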
- Computing in_grad
in_grad = matmul(out_grad, weight)
Communication = AllGather out_grad.
Optimal communication: allgather out_grad once and reuse it for all three gradient computations, as sketched below.
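Putting the three gradients together, a sketch of this optimal schedule in which out_grad is allgathered once and reused (assumed sizes, single-process simulation; not nvFuser code):

```python
import numpy as np

d, b, s, e = 2, 1, 8, 16                    # assumed sizes, illustration only
rng = np.random.default_rng(0)

out_grad = rng.standard_normal((s, e))      # b = 1 squeezed out
in_ = rng.standard_normal((s, 4 * e))
weight = rng.standard_normal((e, 4 * e))

out_grad_sp = np.split(out_grad, d, axis=0)  # [s//d, e] per device (SP)
in_tp = np.split(in_, d, axis=1)             # [s, 4*e//d] per device (TP)
weight_tp = np.split(weight, d, axis=1)      # [e, 4*e//d] per device (TP)

# One AllGather of out_grad along s serves all three gradient computations.
out_grad_full = np.concatenate(out_grad_sp, axis=0)        # [s, e] on every device

bias_grad = out_grad_full.sum(axis=0)                      # [e], no further comm
weight_grad = [out_grad_full.T @ x for x in in_tp]         # [e, 4*e//d] per device
in_grad = [out_grad_full @ w for w in weight_tp]           # [s, 4*e//d] per device (TP)

assert np.allclose(bias_grad, out_grad.sum(axis=0))
assert np.allclose(np.concatenate(weight_grad, axis=1), out_grad.T @ in_)
assert np.allclose(np.concatenate(in_grad, axis=1), out_grad @ weight)
```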
Layernorm
- Computing bias_grad: propagation looks like the following:
Here, layernorm_grad, bcast_out, and float_out are sharded on s in backpropagation. However, squeeze_out (coming from the sum across b) is not sharded on s. This results in an AllGather instead of an AllReduce; see the sketch at the end of this section.
- Computing weight_grad: incurs 1 AllReduce when s is reduced.
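A minimal sketch of why the missing s-sharding on squeeze_out matters for bias_grad: both lowerings produce the same result, but the AllGather moves the full [s, e] gradient while the AllReduce only moves an [e]-sized partial sum. Sizes and the single-process simulation are assumptions for illustration, not nvFuser code.

```python
import numpy as np

d, b, s, e = 2, 1, 8, 16                    # assumed sizes, illustration only
rng = np.random.default_rng(0)

grad = rng.standard_normal((s, e))          # layernorm grad, b = 1 squeezed out
grad_sp = np.split(grad, d, axis=0)         # [s//d, e] per device (SP)

# Expected lowering: local sum over the sharded s, then AllReduce of [e].
partials = [g.sum(axis=0) for g in grad_sp]  # [e] per device
bias_grad_allreduce = sum(partials)          # AllReduce moves O(e) per device

# Observed lowering: since squeeze_out is not sharded on s, the s-sharded
# producer is allgathered first and then summed locally.
grad_full = np.concatenate(grad_sp, axis=0)  # AllGather moves O(s * e) per device
bias_grad_allgather = grad_full.sum(axis=0)

assert np.allclose(bias_grad_allreduce, bias_grad_allgather)
```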