
Commit c06208e

Merge pull request #5737 from yuanheng-zhao/inference/sync/main
[sync] Sync feature/colossal-infer with main
2 parents: d8b1ea4 + 8633c15


61 files changed: +6978 −278 lines

.github/ISSUE_TEMPLATE/bug-report.yml

Lines changed: 7 additions & 0 deletions
@@ -8,6 +8,13 @@ body:
     attributes:
       value: >
         #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new).
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this bug?
+      description: Please search [here](https://github.com/hpcaitech/ColossalAI/issues) to see if an open or closed issue already exists for the bug you have encountered.
+      options:
+        - label: I have searched the existing issues
+          required: true
   - type: textarea
     attributes:
       label: 🐛 Describe the bug

.github/workflows/build_on_pr.yml

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ jobs:
 
       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install -v -e .
+          pip install -v -e .
           pip install -r requirements/requirements-test.txt
 
       - name: Store Colossal-AI Cache

README.md

Lines changed: 1 addition & 1 deletion
@@ -418,7 +418,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
 ## Installation
 
 Requirements:
-- PyTorch >= 1.11 and PyTorch <= 2.1
+- PyTorch >= 2.1
 - Python >= 3.7
 - CUDA >= 11.0
 - [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)

applications/Colossal-LLaMA/prepare_sft_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 import os
 from multiprocessing import cpu_count
 
-from colossal_llama.dataset.conversation import default_conversation
+from colossal_llama.dataset.conversation import LLaMA2_Conv
 from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
 from datasets import dataset_dict, load_dataset
 from transformers import AddedToken, AutoTokenizer
@@ -78,6 +78,7 @@ def main():
     # Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
     if args.llama_version == 2:
         tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
+        default_conversation = LLaMA2_Conv
 
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
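
For context, a minimal sketch of the selection logic this change implies; the `--llama_version` flag spelling and the argparse scaffolding are assumptions, only `LLaMA2_Conv` and `args.llama_version` appear in the diff:

    import argparse

    from colossal_llama.dataset.conversation import LLaMA2_Conv

    parser = argparse.ArgumentParser()
    parser.add_argument("--llama_version", type=int, default=2)
    args = parser.parse_args()

    default_conversation = None
    if args.llama_version == 2:
        # Bind the LLaMA-2 template explicitly instead of relying on the
        # removed module-level default_conversation import.
        default_conversation = LLaMA2_Conv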

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 31 additions & 9 deletions
@@ -1,7 +1,9 @@
 import ctypes
 import random
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from types import MethodType
 from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -24,6 +26,8 @@
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import DistGaloreAwamW
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -735,7 +739,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
-        if self.require_grad_sync and grads_to_sync is not None:
+        if self._grad_store.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
             SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
@@ -759,7 +763,7 @@ def backward(self, loss, retain_graph=False):
         # Call the superclass backward method to compute gradients.
         super().backward(loss, retain_graph)
 
-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
@@ -784,7 +788,7 @@ def backward_by_grad(self, tensor, grad):
         # Call the superclass backward_by_grad method to compute gradients.
         super().backward_by_grad(tensor, grad)
 
-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
@@ -1171,6 +1175,15 @@ def configure(
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
+
+        # TODO: Support Galore + ZeRO
+        zero_stage = self.zero_stage
+        zero_config = deepcopy(self.zero_config)
+        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
+            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            zero_config["partition_grad"] = False
+            zero_stage = 0
+
         if not isinstance(model, ModelWrapper):
             use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
                 self.dp_size == 1
@@ -1194,7 +1207,8 @@ def configure(
                 custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.zero_stage == 0:
+            if zero_stage == 0:
+                is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
@@ -1218,11 +1232,11 @@ def configure(
                         tp_process_group=self.tp_group,
                     )
             else:
-                zero_dp_size = dist.get_world_size(dp_group)
-                if zero_dp_size == 1:
+                is_zero = self.dp_size > 1
+                if self.dp_size == 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you are not intended to use cpu_offload, please consider set zero_stage=0."
+                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
                     )
 
                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1236,11 +1250,19 @@ def configure(
                     pp_process_group=self.pp_group,
                     verbose=True,
                     clip_grad_norm=self.max_norm,
-                    **self.zero_config,
+                    **zero_config,
                     **self.amp_config,
                 )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
+
+            # Setup optimizers that require global states
+            optim = optimizer.optim
+            if isinstance(optim, DistributedOptim):
+                shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
+                padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
+                optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def execute_pipeline(
@@ -1272,7 +1294,7 @@ def execute_pipeline(
 
         # run with gradients accumulation
         if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
         ):
             return outputs

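Taken together, configure() now refuses to combine a distributed Galore optimizer with ZeRO and instead hands the optimizer its tensor- and data-parallel groups through setup_distributed(). A rough usage sketch of that path, under assumed arguments (the model, parallel sizes, and the DistGaloreAwamW constructor signature below are placeholders, not taken from this diff):

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.nn.optimizer import DistGaloreAwamW

    colossalai.launch_from_torch()  # assumes torchrun; older releases need config={}

    model = nn.Linear(1024, 1024)
    optimizer = DistGaloreAwamW(model.parameters(), lr=1e-3)  # argument list assumed

    # Even if zero_stage > 0 is requested, configure() warns, disables gradient
    # partitioning, and falls back to zero_stage = 0 for Galore, then calls
    # optimizer.setup_distributed(...) with the plugin's tp_group and dp_group.
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1, precision="bf16")
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)
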
colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 24 additions & 1 deletion
@@ -8,7 +8,10 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import torch
+import torch.distributed
+import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.distributed_c10d import _get_default_group
 from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -28,6 +31,8 @@
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import DistGaloreAwamW
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer
 
@@ -428,13 +433,31 @@ def configure(
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
 
+        # TODO: Support Galore + ZeRO
+        zero_stage = self.stage
+        zero_optim_kwargs = {**self.zero_optim_kwargs}
+        dp_size = dist.get_world_size()
+        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
+            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            zero_optim_kwargs["partition_grad"] = False
+            zero_stage = 0
+
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
 
+        # Setup optimizers that require global states
+        optim = optimizer.optim
+        is_zero = dp_size > 1 and zero_stage > 0
+        dp_group = _get_default_group()  # Use the whole world
+        if isinstance(optim, DistributedOptim):
+            shard_to_param = optimizer.get_master_to_working_map()
+            padding_map = optimizer.get_param_padding_map()
+            optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def control_checkpoint_io(self) -> bool:
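
The same Galore guard appears here, with the difference that this plugin has no tensor-parallel group: the optimizer receives tp_group=None and the default (whole-world) process group as dp_group. A parallel usage sketch under the same assumptions as the hybrid-parallel example above:

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from colossalai.nn.optimizer import DistGaloreAwamW

    colossalai.launch_from_torch()  # assumes torchrun; older releases need config={}

    model = nn.Linear(512, 512)
    optimizer = DistGaloreAwamW(model.parameters(), lr=1e-3)  # argument list assumed

    # Requesting stage 2 with Galore triggers the warning above and drops back to
    # stage 0; setup_distributed(None, dp_group, ...) is still called so the
    # optimizer can coordinate its states across the whole world.
    plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)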

colossalai/cluster/process_group_mesh.py

Lines changed: 6 additions & 1 deletion
@@ -38,7 +38,12 @@ class ProcessGroupMesh:
 
     def __init__(self, *size: int) -> None:
         assert dist.is_initialized(), "Please initialize torch.distributed first."
-        assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
+        world_size = dist.get_world_size()
+        prod_size = prod(size)
+        assert (
+            prod_size == world_size
+        ), f"The product of the size({prod_size}) must be equal to the world size({world_size})."
+
         self._shape = size
         self._rank = dist.get_rank()
         self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
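
The reworked assertion now reports both numbers, which makes mesh-shape mistakes easy to spot. A minimal sketch of what it checks; the 8-process launch and the 2x2x2 mesh are assumptions for illustration:

    import torch.distributed as dist

    from colossalai.cluster import ProcessGroupMesh

    # Assumes torch.distributed is already initialized (e.g. via colossalai.launch_from_torch
    # under torchrun --nproc_per_node 8), so dist.get_world_size() == 8.
    mesh = ProcessGroupMesh(2, 2, 2)  # 2 * 2 * 2 == 8, assertion passes

    # A mismatched mesh such as ProcessGroupMesh(2, 3) would now fail with:
    # "The product of the size(6) must be equal to the world size(8)."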

colossalai/device/device_mesh.py

Lines changed: 3 additions & 3 deletions
@@ -306,9 +306,8 @@ def _init_global_to_logical_rank_mapping(
             # index means the local rank in the current axis
             # inner_tensor refers to the processes with the same local rank
 
-            if inner_tensor.numel() == 1:
-                # if the inner_tensor only has one element, it means that
-                # it already reaches the last axis
+            if inner_tensor.dim() == 0:
+                # if the inner_tensor already reaches the last axis,
                 # we append its local_rank in the last axis to the index_list
                 # and assign to the mapping
                 # the value of the mapping is the the local rank at the indexed axis of the device mesh
@@ -459,6 +458,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
 
             # replace the local rank in the given dimension with the
            # local rank of the current process iterated
+
            process_coordinates[dim] = _local_rank
            processes_in_the_same_process_group[dim].append(process_coordinates)

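The switch from numel() to dim() matters because recursion over the mesh should stop at a scalar (0-dim) element, while an axis of size 1 still has a dimension left to descend into. A small, runnable illustration (the 2x2 and 2x1 mesh shapes are just examples):

    import torch

    mesh = torch.arange(4).reshape(2, 2)  # logical device mesh of global ranks

    row = mesh[0]      # shape (2,): an axis that still needs recursion
    leaf = mesh[0][0]  # shape ():   the last axis has been reached

    print(row.numel(), row.dim())    # 2 1
    print(leaf.numel(), leaf.dim())  # 1 0

    # Why the old check could misfire: a degenerate axis of size 1 has
    # numel() == 1 but is not yet a scalar, so recursion would stop too early.
    thin = torch.arange(2).reshape(2, 1)
    print(thin[0].numel(), thin[0].dim())  # 1 1
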
colossalai/interface/optimizer.py

Lines changed: 24 additions & 1 deletion
@@ -1,6 +1,7 @@
-from typing import Union
+from typing import Dict, Optional, Union
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
@@ -133,3 +134,25 @@ def unwrap(self):
         Unwrap the optimizer for checkpoint saving/loading.
         """
         return self.optim
+
+
+class DistributedOptim(Optimizer):
+    def setup_distributed(
+        self,
+        tp_group: Optional[dist.ProcessGroup] = None,
+        dp_group: Optional[dist.ProcessGroup] = None,
+        shard_to_working_param: Optional[Dict] = {},
+        padding_map: Optional[Dict] = None,
+        is_zero: Optional[bool] = False,
+    ):
+        """Assign process groups for TP and ZeRO 2.
+        Arguments:
+            tp_group (dist.ProcessGroup): Tensor Parallel process group
+            dp_group (dist.ProcessGroup): ZeRO stage 2 process group
+            shard_to_working_param (Dict): ZeRO stage 2 feeds the optimizer a sharded param view to match grad shape.
+                This maps from id(view) to model params used in forward & backward.
+            padding_map (Dict): Per-param padding from ZeRO stage 2
+            is_zero (bool): Whether to use ZeRO stage 2.
+        """
+
+        raise NotImplementedError("setup_distributed for TP/DP isn't supported by this optimizer yet!")
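
For orientation, a minimal sketch of what a concrete subclass is expected to do with these arguments; the class name and the step() body are illustrative assumptions, only the setup_distributed signature comes from this diff:

    from collections import defaultdict
    from typing import Dict, Optional

    import torch
    import torch.distributed as dist

    from colossalai.interface.optimizer import DistributedOptim


    class MyDistSGD(DistributedOptim):  # hypothetical optimizer
        def __init__(self, params, lr: float = 1e-3):
            super().__init__(params, defaults=dict(lr=lr))

        def setup_distributed(
            self,
            tp_group: Optional[dist.ProcessGroup] = None,
            dp_group: Optional[dist.ProcessGroup] = None,
            shard_to_working_param: Optional[Dict] = {},
            padding_map: Optional[Dict] = None,
            is_zero: Optional[bool] = False,
        ):
            # Cache the process groups and ZeRO bookkeeping so that step()
            # can communicate across TP and DP ranks later on.
            self.tp_group = tp_group
            self.dp_group = dp_group
            self.shard_to_working_param = shard_to_working_param or {}
            self.padding_map = padding_map if padding_map is not None else defaultdict(int)
            self.is_zero = is_zero

        @torch.no_grad()
        def step(self, closure=None):
            # Placeholder update; a real distributed optimizer would use the
            # groups registered above when computing its update.
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-group["lr"])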

colossalai/lazy/pretrained.py

Lines changed: 23 additions & 0 deletions
@@ -1,3 +1,4 @@
+import copy
 import os
 from typing import Callable, Optional, Union
 
@@ -74,6 +75,24 @@ def new_from_pretrained(
     subfolder = kwargs.pop("subfolder", "")
     commit_hash = kwargs.pop("_commit_hash", None)
     variant = kwargs.pop("variant", None)
+
+    kwargs.pop("state_dict", None)
+    kwargs.pop("from_tf", False)
+    kwargs.pop("from_flax", False)
+    kwargs.pop("output_loading_info", False)
+    kwargs.pop("trust_remote_code", None)
+    kwargs.pop("low_cpu_mem_usage", None)
+    kwargs.pop("device_map", None)
+    kwargs.pop("max_memory", None)
+    kwargs.pop("offload_folder", None)
+    kwargs.pop("offload_state_dict", False)
+    kwargs.pop("load_in_8bit", False)
+    kwargs.pop("load_in_4bit", False)
+    kwargs.pop("quantization_config", None)
+    kwargs.pop("adapter_kwargs", {})
+    kwargs.pop("adapter_name", "default")
+    kwargs.pop("use_flash_attention_2", False)
+
     use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
 
     if len(kwargs) > 0:
@@ -108,6 +127,10 @@ def new_from_pretrained(
             **kwargs,
         )
     else:
+        config = copy.deepcopy(config)
+        kwarg_attn_imp = kwargs.pop("attn_implementation", None)
+        if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
+            config._attn_implementation = kwarg_attn_imp
         model_kwargs = kwargs
 
         if commit_hash is None:
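
The long run of kwargs.pop() calls exists because, as the surrounding context shows, any leftover kwargs are later reported and forwarded as model_kwargs, so loader-only options that lazy initialization cannot honor are stripped up front. A small sketch of the pattern; the function and option names below are illustrative, not from the source:

    def strip_loader_only_options(**kwargs):
        # Silently drop options that only make sense for the eager HF loader,
        # leaving genuine config overrides behind for later use.
        for name in ("device_map", "load_in_8bit", "load_in_4bit", "quantization_config"):
            kwargs.pop(name, None)
        return kwargs

    remaining = strip_loader_only_options(device_map="auto", hidden_dropout=0.1)
    print(remaining)  # {'hidden_dropout': 0.1}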
