Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 80 additions & 28 deletions megatron/core/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import functools
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

Expand All @@ -18,12 +18,15 @@
get_tensor_model_parallel_group,
get_global_memory_buffer,
)
from megatron.core.weight_grad_store import WeightGradStore
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
_reduce_scatter_along_first_dim,
_gather_along_first_dim,
)

from .random import get_cuda_rng_tracker
Expand Down Expand Up @@ -215,12 +218,13 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):

@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel):
async_grad_allreduce, sequence_parallel, reshard_for_sequence_parallel):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel
ctx.reshard_for_sequence_parallel = reshard_for_sequence_parallel

if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
Expand All @@ -240,40 +244,68 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
output = torch.matmul(total_input, weight.t())
if bias is not None:
output = output + bias

if reshard_for_sequence_parallel:
assert not sequence_parallel
output = _reduce_scatter_along_first_dim(output)
return output

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output_):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias

if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
if ctx.reshard_for_sequence_parallel:
grad_output = _gather_along_first_dim(grad_output_)
else:
grad_output = grad_output_

all_gather_buffer = \
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group(), async_op=True)
def pre_process(_input_):
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(_input_.size())
dim_size[0] = dim_size[0] * world_size

all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, _input_.dtype, "mpu")
_handle_ = torch.distributed._all_gather_base(
all_gather_buffer, _input_, group=get_tensor_model_parallel_group(), async_op=True
)

# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
_total_input_ = all_gather_buffer
return _total_input_, _handle_
else:
_total_input_ = _input_
return _total_input_, None

not_split_bw = not WeightGradStore.split_bw or not WeightGradStore.is_supported()
if not_split_bw:
pre_processed_results = pre_process(input)

# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)

if ctx.sequence_parallel:
handle.wait()
def post_process(_pre_processed_results_):
_total_input_, _handle_ = _pre_processed_results_
if ctx.sequence_parallel:
_handle_.wait()
_total_input_ = _total_input_.view(
_total_input_.shape[0] * _total_input_.shape[1], _total_input_.shape[2]
)
return _total_input_

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
if ctx.reshard_for_sequence_parallel:
_grad_output_ = _gather_along_first_dim(_grad_output_)
_total_input_ = post_process(pre_process(_input_))
wgrad_gemm_accum_func(_total_input_, _grad_output_, _weight_main_grad_)

if not_split_bw:
total_input = post_process(pre_processed_results)

# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2])

if ctx.async_grad_allreduce:
# Asynchronous all-reduce
Expand All @@ -297,9 +329,25 @@ def backward(ctx, grad_output):

if ctx.gradient_accumulation_fusion:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
elif weight.main_grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
# capture the weight gradient computation of linear layers
if not_split_bw:
WeightGradStore.put(
total_input, grad_output, weight, fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
)
else:
WeightGradStore.put(
input, grad_output_, weight, functools.partial(execute_w_pass, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32)
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
# capture the weight gradient computation of linear layers
if not_split_bw:
WeightGradStore.put(
total_input, grad_output, weight, fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
)
else:
WeightGradStore.put(
input, grad_output_, weight, functools.partial(execute_w_pass, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16)
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
grad_weight = None
Expand All @@ -309,12 +357,12 @@ def backward(ctx, grad_output):

if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None
return sub_grad_input, grad_weight, grad_bias, None, None, None, None

if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None


def linear_with_grad_accumulation_and_async_allreduce(
Expand All @@ -324,6 +372,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
reshard_for_sequence_parallel: bool = False,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
Expand Down Expand Up @@ -384,6 +433,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
reshard_for_sequence_parallel,
]

if not linear_with_grad_accumulation_and_async_allreduce.warned:
Expand Down Expand Up @@ -685,11 +735,13 @@ def forward(self, input_):
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel_enabled=False,
reshard_for_sequence_parallel=self.sequence_parallel_enabled,
)

# All-reduce across all the partitions.
if self.sequence_parallel_enabled:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
# output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
output_ = output_parallel
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
Expand Down
55 changes: 55 additions & 0 deletions megatron/core/weight_grad_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import queue
from megatron import get_args


class WeightGradStore:
    """Defers the weight-gradient (W) GEMMs of linear layers so the pipeline
    schedule can run them decoupled from the input-gradient (B) pass.

    State is class-level: ``cache`` accumulates deferred computations within
    one backward step, ``flush`` moves the cache into ``weight_grad_queue`` as
    one unit of W work, and ``pop`` executes one such unit. ``split_bw`` is
    toggled externally by the pipeline schedule to enable/disable deferral.
    """

    cache = []
    weight_grad_queue = queue.Queue()
    split_bw = True

    @classmethod
    def is_supported(cls):
        """If not supported, fallback to original schedule."""
        args = get_args()
        unsupported = (
            # Splitting W from B only pays off with a real pipeline.
            args.pipeline_model_parallel_size <= 1
            or args.virtual_pipeline_model_parallel_size is not None
            # hard to capture weight gradient computation for transformer_engine
            or args.transformer_impl == 'transformer_engine'
        )
        return not unsupported

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        if cls.split_bw and cls.is_supported():
            # Store the weight gradient computation of linear layers.
            cls.cache.append((total_input, grad_output, weight, func))
        else:
            # Deferral disabled: run the weight-gradient GEMM immediately.
            func(total_input, grad_output, weight.main_grad)

    @classmethod
    def flush(cls):
        if not cls.is_supported():
            return
        # Collect all stored computations during backward as a W.
        cls.weight_grad_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        if not cls.is_supported():
            return
        # Execute a single W.
        assert cls.weight_grad_queue.qsize() > 0
        for stored_input, stored_grad, stored_weight, run_wgrad in cls.weight_grad_queue.get():
            run_wgrad(stored_input, stored_grad, stored_weight.main_grad)

    @classmethod
    def pop_all(cls):
        # Execute all remaining W.
        for _ in range(cls.weight_grad_queue.qsize()):
            cls.pop()
18 changes: 18 additions & 0 deletions megatron/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from megatron import get_num_microbatches
from megatron import p2p_communication
from megatron.core import mpu
from megatron.core.weight_grad_store import WeightGradStore
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
Expand Down Expand Up @@ -693,13 +694,24 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)

# For BWF pattern or in rank 0, we don't split W and B for reasons below.
# 1. to leverage batched p2p op (send_backward_recv_forward)
# 2. to overlap grad all-reduce for tensor parallel
# 3. to avoid redoing grad all-gather for sequence parallel
# Note that the order of grad accumulation is changed by this behavior,
# thus causing a minor precision error compared to 1F1B even though it is mathematically correct.
WeightGradStore.split_bw = (i < rank or last_iteration) and rank > 0
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
if WeightGradStore.split_bw:
WeightGradStore.flush()

if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
if i >= rank > 0: # delay W by rank
WeightGradStore.pop() # W
else:
input_tensor = \
send_backward_recv_forward(
Expand All @@ -713,10 +725,16 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,

output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)

WeightGradStore.split_bw = rank > 0
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)

send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
if WeightGradStore.split_bw:
WeightGradStore.flush()
if num_microbatches_remaining + i >= rank:
WeightGradStore.pop() # W
WeightGradStore.pop_all() # W

return forward_data_store