
Commit 2fc85ab

[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)
* [gemini] async grad chunk reduce (all-reduce&reduce-scatter)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* [gemini] add test
* [gemini] rename func
* [gemini] update llama benchmark
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* [gemini] use tensor counter
* [gemini] change default config in GeminiPlugin and GeminiDDP
* [chore] typo
* [gemini] fix sync issue & add test cases
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 85946d4 commit 2fc85ab

File tree

11 files changed, +130 -45 lines changed
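The change, in a nutshell: each gradient chunk's reduction can now be launched with async_op=True so communication overlaps with the rest of the backward pass, and the returned work handle is waited on before the gradients are consumed. A minimal sketch of that pattern with plain torch.distributed (illustrative only, not ColossalAI code):

# Illustrative sketch: launch gradient reductions asynchronously and wait on the
# returned Work handles before the gradients are used.
# Assumes torch.distributed has already been initialized by the caller.
import torch.distributed as dist


def launch_grad_reductions(grad_buffers):
    """Start one non-blocking all-reduce per gradient buffer and collect the handles."""
    return [dist.all_reduce(buf, async_op=True) for buf in grad_buffers]


def finish_grad_reductions(works):
    """Block until every pending reduction is done (e.g. right before optimizer.step())."""
    for work in works:
        work.wait()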

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 2 additions & 0 deletions
@@ -361,6 +361,7 @@ def __init__(
         enable_sequence_parallelism: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_overlap: bool = False,
+        enable_async_reduce: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -386,6 +387,7 @@ def __init__(
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            enable_async_reduce=enable_async_reduce,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
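For reference, a hypothetical usage sketch of the new plugin keyword; the Booster wiring below is standard ColossalAI boilerplate and the model/optimizer objects are placeholders, not part of this commit:

# Hypothetical usage of the new keyword; `model` and `optimizer` are placeholders.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()
plugin = GeminiPlugin(enable_async_reduce=True)  # set False to fall back to synchronous reduction
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)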

colossalai/zero/gemini/chunk/chunk.py

Lines changed: 23 additions & 9 deletions
@@ -164,6 +164,8 @@ def __init__(
         self.l2_norm = None

         self.grad_chunk = None
+        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
+        self.grad_reduce_work = None

     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -244,7 +246,7 @@ def has_inf_or_nan(self) -> bool:
             assert self.cuda_shard is not None  # only check on CUDA
             valid_tensor = self.cuda_shard[: self.valid_end]

-        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+        return torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any()

     def set_l2_norm(self) -> None:
         """Record l2 norm of this chunks on CUDA."""
@@ -374,37 +376,49 @@ def release_chunk(self):
         if self.is_gathered:
             self.__scatter()

-    def reduce(self):
+    def reduce(self, async_op: bool = False):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
-
+        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
         elif self.keep_gathered:
             # we use all-reduce here
-            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
-            if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
+            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            if self.extra_dp_group is not None:  # cannot guranatee the order of multiple all-reduce
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(
+                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
+                )
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )

             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            self.grad_reduce_work = dist.reduce_scatter(
+                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            )
+
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)

         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)

+    def wait_async_reduce(self) -> None:
+        if self.grad_reduce_work is not None:
+            self.grad_reduce_work.wait()
+            self.grad_reduce_work = None
+
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
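Chunk.reduce can now chain up to two collectives on the same buffer (the inner reduce-scatter or all-reduce, then an all-reduce over extra_dp_group), and two async collectives on one tensor are not ordered relative to each other, so the first handle is waited on before the second is launched. A standalone sketch of that rule, assuming torch.distributed is initialized and inner_pg / extra_pg are pre-created process groups:

# Standalone sketch of the ordering rule: finish the first collective on a buffer
# before launching the next one on the same buffer.
import torch.distributed as dist


def chained_all_reduce(buf, inner_pg, extra_pg, async_op=True):
    work = dist.all_reduce(buf, group=inner_pg, async_op=async_op)
    if work is not None:  # async_op=False returns None
        work.wait()  # the first collective must finish before `buf` is reduced again
    # now it is safe to launch the second collective on the same buffer
    return dist.all_reduce(buf, group=extra_pg, async_op=async_op)

The caller is expected to wait on the returned handle (if any) before reading the buffer, which is what wait_async_reduce does for the chunk above.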

colossalai/zero/gemini/chunk/manager.py

Lines changed: 4 additions & 4 deletions
@@ -41,7 +41,7 @@ def __init__(
         self.reuse_fp16_chunk = reuse_fp16_chunk
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
-        self.overflow_counter = 0
+        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())

     def register_tensor(
         self,
@@ -143,12 +143,12 @@ def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)

-    def reduce_chunk(self, chunk: Chunk) -> bool:
+    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce()
+        chunk.reduce(async_op=async_op)
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
@@ -272,7 +272,7 @@ def init_grad_chunk(self, chunk: Chunk) -> Chunk:
         return grad_chunk

     def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
-        """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
+        """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction."""

         assert chunk.grad_chunk is not None

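Making overflow_counter a device tensor pairs with the has_inf_or_nan change above: the flag stays on the GPU, so accumulating it during backward no longer forces a host sync; the only .item() sync happens when the optimizer checks for overflow. A simplified sketch of the idea (plain CUDA device handling, not the ChunkManager API):

# Simplified sketch: keep the overflow flag on the device and sync only when it is read.
import torch

overflow_counter = torch.zeros(1, dtype=torch.int, device="cuda")


def record_overflow(grad: torch.Tensor) -> None:
    # stays on the device; no host synchronization during backward
    overflow_counter.add_(torch.isinf(grad).any() | torch.isnan(grad).any())


def check_local_overflow() -> bool:
    # the single host<->device sync happens here, at optimizer-step time
    return overflow_counter.item() > 0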

colossalai/zero/gemini/gemini_ddp.py

Lines changed: 57 additions & 22 deletions
@@ -96,6 +96,7 @@ def __init__(
         master_weights: bool = True,
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
+        enable_async_reduce: bool = True,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -178,6 +179,7 @@ def __init__(
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
+                assert not hasattr(p, "_grad_handle")
                 p._grad_handle = p.register_hook(
                     partial(
                         GeminiDDP.grad_handle,
@@ -187,6 +189,7 @@ def __init__(
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
+                        async_reduce=enable_async_reduce,
                     )
                 )

@@ -334,6 +337,11 @@ def _pre_backward(self):
             setattr(param, "_gemini_reduced", False)

     def _post_backward(self):
+        for param in self.param2name:
+            if hasattr(param, "_release_grad_chunk_cb"):
+                param._release_grad_chunk_cb()
+                delattr(param, "_release_grad_chunk_cb")
+
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
             for param in self.param2name:
@@ -371,6 +379,7 @@ def grad_handle(
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
+        async_reduce: bool,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -406,31 +415,57 @@ def grad_handle(
                 grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
             else:
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)
-            reduced = chunk_manager.reduce_chunk(grad_chunk)
-            if reduced:
-                if not chunk_manager.reuse_fp16_chunk:
-                    if chunk.keep_gathered:
-                        chunk_manager.fake_release_chunk(chunk)
-                    else:
-                        chunk_manager.release_chunk(chunk)
-                if grad_chunk.is_gathered:
-                    grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                    if chunk.extra_dp_group is not None:
-                        grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
+            if reduced:  # if not async, can release immediately, else release in when work finished
+                if async_reduce:
+                    # dirty fix by installing callback
+                    assert not hasattr(p, "_release_grad_chunk_cb")
+
+                    def _release_grad_chunk_cb():
+                        grad_chunk.wait_async_reduce()
+                        GeminiDDP.release_grad_chunk_handle(
+                            chunk_manager,
+                            grads_device,
+                            master_weights,
+                            enable_gradient_accumulation,
+                            p,
+                            chunk,
+                            grad_chunk,
+                        )
+
+                    p._release_grad_chunk_cb = _release_grad_chunk_cb
                 else:
-                    grad_chunk.cuda_shard.div_(chunk.pg_size)
-                    if chunk.extra_dp_group is not None:
-                        grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-                # check overflow elements
-                chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-                # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-                if chunk.l2_norm_flag:
-                    grad_chunk.set_l2_norm()
-                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-                if not (master_weights) or (enable_gradient_accumulation):
-                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                    GeminiDDP.release_grad_chunk_handle(
+                        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+                    )
         return empty_grad

+    @staticmethod
+    def release_grad_chunk_handle(
+        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+    ):
+        if not chunk_manager.reuse_fp16_chunk:
+            if chunk.keep_gathered:
+                chunk_manager.fake_release_chunk(chunk)
+            else:
+                chunk_manager.release_chunk(chunk)
+        if grad_chunk.is_gathered:
+            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+        else:
+            grad_chunk.cuda_shard.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+        # check overflow elements
+        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+        if chunk.l2_norm_flag:
+            grad_chunk.set_l2_norm()
+        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+        if not (master_weights) or (enable_gradient_accumulation):
+            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)

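Because the reduction may still be in flight when grad_handle returns, the scaling, overflow check, and chunk moves are factored into release_grad_chunk_handle and deferred through a per-parameter callback that _post_backward flushes. A self-contained sketch of that deferral pattern with placeholder names (not the GeminiDDP API):

# Placeholder-name sketch of the deferral pattern: each finished-reduce cleanup is
# parked as a callback and flushed once backward has completed.
from typing import Callable, List


class DeferredGradCleanup:
    def __init__(self) -> None:
        self._callbacks: List[Callable[[], None]] = []

    def defer(self, wait_fn: Callable[[], None], release_fn: Callable[[], None]) -> None:
        # wait_fn blocks on the pending collective; release_fn scales/moves the chunk
        def _cb() -> None:
            wait_fn()
            release_fn()

        self._callbacks.append(_cb)

    def flush_post_backward(self) -> None:
        # mirror of the loop added to _post_backward above
        for cb in self._callbacks:
            cb()
        self._callbacks.clear()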

colossalai/zero/gemini/gemini_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -62,10 +62,10 @@ def __init__(
         self.module = module

     def check_local_overflow(self) -> bool:
-        return self.module.chunk_manager.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter.item() > 0

     def pre_zero_grad(self) -> None:
-        self.module.chunk_manager.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter.zero_()


 class GeminiOptimizer(OptimizerWrapper):

examples/language/llama/benchmark.py

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,8 @@ def main():
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
+    parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
+
     args = parser.parse_args()

     colossalai.launch_from_torch()
@@ -110,6 +112,7 @@ def empty_init():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
+            enable_async_reduce=not args.disable_async_reduce,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
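Async reduce is now on by default for the gemini plugin branch shown above; passing --disable-async-reduce restores the previous synchronous behaviour. An illustrative invocation, assuming the script's usual torchrun launch and plugin flag (verify against your local copy): torchrun --standalone --nproc_per_node 8 benchmark.py -p gemini --disable-async-reduce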

tests/test_zero/test_gemini/test_chunkv2.py

Lines changed: 6 additions & 2 deletions
@@ -34,7 +34,8 @@ def check_equal(param, param_cp):
 @parameterize("init_device", [None, torch.device("cpu")])
 @parameterize("keep_gathered", [True, False])
 @parameterize("pin_memory", [True, False])
-def exam_chunk_basic(init_device, keep_gathered, pin_memory):
+@parameterize("async_op", [True, False])
+def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
     world_size = torch.distributed.get_world_size()
     pg = _get_default_group()
     my_chunk = Chunk(
@@ -94,9 +95,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):

     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
-    my_chunk.reduce()
+    my_chunk.reduce(async_op)
     assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4

+    if async_op:
+        my_chunk.wait_async_reduce()
+
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cuda"

tests/test_zero/test_gemini/test_fwd_bwd.py

Lines changed: 9 additions & 1 deletion
@@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
+@parameterize("enable_async_reduce", [False, True])
 def exam_gpt_fwd_bwd(
     placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
+    enable_async_reduce=True,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd(
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gather
     model = GeminiDDP(
-        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+        model,
+        config_dict,
+        init_device,
+        pin_memory=True,
+        **placement_config,
+        master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)

tests/test_zero/test_gemini/test_grad_accum.py

Lines changed: 11 additions & 2 deletions
@@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("enable_async_reduce", [False, True])
 def exam_gemini_grad_acc(
-    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+    placement_config,
+    keep_gathered: bool,
+    model_name: str,
+    master_weights: bool,
+    use_grad_checkpoint: bool,
+    enable_async_reduce: bool,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -81,10 +87,13 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
-    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
+    gemini_optim = GeminiOptimizer(
+        optimizer, gemini_model, initial_scale=1, max_norm=1.0, enable_async_reduce=enable_async_reduce
+    )

     rank = dist.get_rank()

tests/test_zero/test_gemini/test_grad_clip.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
5252
@parameterize("placement_config", PLACEMENT_CONFIGS)
5353
@parameterize("model_name", ["transformers_gpt_lm"])
5454
@parameterize("master_weights", [True, False])
55-
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
55+
@parameterize("enable_async_reduce", [False, True])
56+
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool):
5657
set_seed(1912)
5758
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
5859
iter(model_zoo.get_sub_registry(model_name).values())
@@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
8485
chunk_init_device=init_device,
8586
pin_memory=True,
8687
master_weights=master_weights,
88+
enable_async_reduce=enable_async_reduce,
8789
**placement_config,
8890
)
8991
