Sub-optimal communication in transformer TP + SP backward block #5673

Description

@Priya2698

With automatic scheduling, I observed more communications being inserted in preseg than required.
The following observations are for b = 1; with b > 1 we run into another issue: #4523.
#5667 uses some manual scheduling to close the performance gap.

MHA / MLP Up Projection:

out = linear (in, weight, bias)

In backprop:
out_grad = [DIDx(d), b, s, 3*e // d] # Sharded on hidden dimension (or TP)
in = [DIDx(d), b, s//d, e] # Sharded on sequence dimension (or SP)
weight = [DIDx(d), 3*e//d, e] # TP sharded
bias = [DIDx(d), 3*e//d]
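
For concreteness, below is a minimal PyTorch sketch of the local (per-device) tensors implied by these annotations. The sizes d, b, s, e are illustrative assumptions, not values from a real run.

```python
import torch

# Illustrative sizes (assumptions, not from a real run): d devices, b = 1.
d, b, s, e = 2, 1, 8, 16

# Local (per-device) shards matching the DIDx annotations above.
out_grad_local = torch.randn(b, s, 3 * e // d)   # sharded on the hidden dim (TP)
in_local       = torch.randn(b, s // d, e)       # sharded on the sequence dim (SP)
weight_local   = torch.randn(3 * e // d, e)      # TP sharded
bias_local     = torch.randn(3 * e // d)         # TP sharded
```
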
  • Computing weight gradient
weight_grad = matmul (out_grad.t(), in) # [DIDx(d), 3 * e // d, e, r{DIDx(d)}, r{s // d}]

This is decomposed into a local matmul + sum:
local_matmul = [DIDx(d), 3 * e, e, r{s // d}]
sum = [DIDx(d), 3 * e // d, e, r{DIDx(d)}]

Observed communication: out_grad is resharded to [DIDx(d), s // d, 3 * e] to match local_matmul. This is lowered to SendRecv (also incorrect: see #4188 (comment)). The sum following the local matmul is a ReduceScatter.
Optimal: Allgather in instead of sharding weight_grad or out_grad on s (see the sketch below).
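
Below is a single-process sketch contrasting the observed and optimal schedules for weight_grad. Shards of plain torch tensors stand in for devices, and the reshard / allgather / reduce-scatter are emulated with chunk, cat, and sum; the sizes are illustrative assumptions.

```python
import torch

# Illustrative sizes (assumptions): d devices, b = 1.
d, b, s, e = 2, 1, 8, 16
torch.manual_seed(0)

# Unsharded reference tensors for the up-projection backward.
out_grad = torch.randn(b, s, 3 * e)   # grad of the up-projection output
inp      = torch.randn(b, s, e)       # up-projection input ("in" above)

out_grad_shards = out_grad.chunk(d, dim=-1)   # TP sharding on 3*e
inp_shards      = inp.chunk(d, dim=1)         # SP sharding on s

# Unsharded reference: weight_grad = out_grad^T @ in, reducing b and s.
ref = out_grad.reshape(-1, 3 * e).t() @ inp.reshape(-1, e)   # [3*e, e]
ref_shards = ref.chunk(d, dim=0)                             # TP-sharded [3*e//d, e]

# Observed schedule: reshard out_grad onto s (SendRecv), local matmul over s//d,
# then ReduceScatter (emulated as a sum over shards followed by a chunk on 3*e).
out_grad_s = out_grad.chunk(d, dim=1)
partials = [og.reshape(-1, 3 * e).t() @ x.reshape(-1, e)
            for og, x in zip(out_grad_s, inp_shards)]
observed = torch.stack(partials).sum(0).chunk(d, dim=0)

# Optimal schedule: allgather "in" (emulated with cat); the local matmul then
# directly produces the TP-sharded weight_grad shard with no further reduction.
inp_full = torch.cat(inp_shards, dim=1)
optimal = [og.reshape(-1, 3 * e // d).t() @ inp_full.reshape(-1, e)
           for og in out_grad_shards]

for i in range(d):
    assert torch.allclose(ref_shards[i], observed[i], atol=1e-5)
    assert torch.allclose(ref_shards[i], optimal[i], atol=1e-5)
```
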

MHA / MLP Down Projection

out = linear (in, weight, bias)

In backprop:
out_grad = [DIDx(d), b, s //d , e] # SP sharded
in = [DIDx(d), b, s, 4 * e//d] # TP sharded
weight = [DIDx(d), e, 4 * e // d] # TP sharded
bias = [e] 
  • Computing bias grad
bias_grad = sum (out_grad, dim = [0,1]) [e, r{DIDx(d)}, r{s//d}]
This is decomposed into local sum + allreduce
Note: Sum over batch dim is a squeeze since b = 1

Communication = AllReduce

  • Computing weight_grad
out_grad_reshaped = [DIDx(d), s//d, e] (b is squeezed out by reshape)
weight_grad = matmul(out_grad_reshaped.t(), in) [DIDx(d), e, 4*e//d, r{s}]

Note that weight_grad is not sharded on s. This happens because we propagate from in first, causing one non-reduction dimension in weight_grad to be sharded. When propagating from out_grad_reshaped.t(), DIDx is not chosen for propagation since it is already present on the target on a non-reduction ID. We currently do not check whether the ref sharding maps to a reduction dim in the target.

If the above issue were resolved, we would likely see communication similar to the up-projection case above.

Communication = AllGather out_grad_reshaped.t()

  • Computing in_grad
in_grad = matmul(out_grad, weight)

Communication = AllGather out_grad.

Optimal communication: Allgather out_grad once and reuse it for all three grad computations, as in the sketch below.
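
A sketch of this optimal pattern for the down projection, again emulating collectives on local shards of plain tensors (sizes are illustrative assumptions): one allgather of out_grad is reused for bias_grad, weight_grad, and in_grad, and each result is either replicated or TP-sharded with no further reduction.

```python
import torch

# Illustrative sizes (assumptions): d devices, b = 1.
d, b, s, e = 2, 1, 8, 16
torch.manual_seed(0)

out_grad = torch.randn(b, s, e)        # SP sharded on s across devices
inp      = torch.randn(b, s, 4 * e)    # TP sharded on 4*e ("in" above)
weight   = torch.randn(e, 4 * e)       # TP sharded on 4*e

out_grad_shards = out_grad.chunk(d, dim=1)
inp_shards      = inp.chunk(d, dim=-1)
weight_shards   = weight.chunk(d, dim=-1)

# One allgather of out_grad (emulated with cat), reused by all three gradients.
out_grad_full = torch.cat(out_grad_shards, dim=1)

for i in range(d):
    # bias_grad: full local sum over b and s; replicated, no AllReduce needed.
    bias_grad = out_grad_full.sum(dim=(0, 1))                                   # [e]
    # weight_grad: local matmul against the TP shard of "in"; result is TP-sharded.
    weight_grad = out_grad_full.reshape(-1, e).t() @ inp_shards[i].reshape(-1, 4 * e // d)
    # in_grad: local matmul against the TP shard of weight; result is TP-sharded.
    in_grad = out_grad_full @ weight_shards[i]                                  # [b, s, 4*e//d]

    # Check against the unsharded reference computation.
    assert torch.allclose(out_grad.sum(dim=(0, 1)), bias_grad)
    assert torch.allclose(
        (out_grad.reshape(-1, e).t() @ inp.reshape(-1, 4 * e)).chunk(d, dim=-1)[i],
        weight_grad, atol=1e-5)
    assert torch.allclose((out_grad @ weight).chunk(d, dim=-1)[i], in_grad, atol=1e-5)
```
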

Layernorm

  • Computing bias_grad: propagation looks like the following:

Here, layernorm_grad, bcast_out, and float_out are sharded on s in backpropagation. However, squeeze_out (coming from the sum across b) is not sharded on s. This results in an Allgather instead of an AllReduce (see the sketch after this list).

  • Computing weight_grad: Incurs 1 Allreduce when s is reduced.
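
A minimal sketch (sizes assumed) of the intended lowering for the layernorm bias_grad: with the incoming gradient sharded on s, the sum over (b, s) should become a local partial sum plus an AllReduce of e elements, rather than an AllGather of the s-sharded tensor followed by a full local sum.

```python
import torch

# Illustrative sizes (assumptions): d devices, b = 1.
d, b, s, e = 2, 1, 8, 16
torch.manual_seed(0)

layernorm_grad = torch.randn(b, s, e)            # sharded on s in backprop
shards = layernorm_grad.chunk(d, dim=1)

ref = layernorm_grad.sum(dim=(0, 1))             # [e] bias grad

# Intended: each device sums its b x s//d slice locally, then an AllReduce
# combines d partial sums of e elements each.
bias_grad_allreduce = torch.stack([g.sum(dim=(0, 1)) for g in shards]).sum(0)

# Observed (squeeze_out loses the s sharding): AllGather the s//d shards
# (s//d * e elements per device) and sum the full tensor locally.
bias_grad_allgather = torch.cat(shards, dim=1).sum(dim=(0, 1))

assert torch.allclose(ref, bias_grad_allreduce, atol=1e-6)
assert torch.allclose(ref, bias_grad_allgather, atol=1e-6)
```
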
