@@ -1,7 +1,9 @@
 import ctypes
 import random
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from types import MethodType
 from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -24,6 +26,8 @@
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import DistGaloreAwamW
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -735,7 +739,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
-        if self.require_grad_sync and grads_to_sync is not None:
+        if self._grad_store.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
             SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
@@ -759,7 +763,7 @@ def backward(self, loss, retain_graph=False):
         # Call the superclass backward method to compute gradients.
         super().backward(loss, retain_graph)

-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
@@ -784,7 +788,7 @@ def backward_by_grad(self, tensor, grad):
         # Call the superclass backward_by_grad method to compute gradients.
         super().backward_by_grad(tensor, grad)

-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
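
Note: the three hunks above move the `require_grad_sync` flag from the ZeRO optimizer wrapper onto its gradient store. A minimal illustrative sketch of the resulting pattern follows; it is not part of this diff, and the helper name is an assumption, shown only to make the attribute move concrete.

from contextlib import contextmanager

@contextmanager
def no_sync_sketch(optimizer):
    # Hypothetical helper: during gradient accumulation the flag is flipped on
    # the grad store, not on the optimizer wrapper itself.
    old_flag = optimizer._grad_store.require_grad_sync
    optimizer._grad_store.require_grad_sync = False  # skip grad sync while accumulating
    try:
        yield
    finally:
        optimizer._grad_store.require_grad_sync = old_flag  # restore before the real step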
@@ -1171,6 +1175,15 @@ def configure(
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
+
+        # TODO: Support Galore + ZeRO
+        zero_stage = self.zero_stage
+        zero_config = deepcopy(self.zero_config)
+        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
+            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            zero_config["partition_grad"] = False
+            zero_stage = 0
+
         if not isinstance(model, ModelWrapper):
             use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
                 self.dp_size == 1
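
The added block above deep-copies `zero_config` and downgrades `zero_stage` before any optimizer wrapper is built, so the plugin's own configuration is left untouched. A standalone sketch of that fallback logic is below; the function name and arguments are illustrative, not part of the PR.

from copy import deepcopy

def resolve_zero_settings(plugin_zero_stage: int, plugin_zero_config: dict,
                          optimizer_is_galore: bool, dp_size: int):
    # Work on a copy so the caller's config dict is never mutated.
    zero_stage = plugin_zero_stage
    zero_config = deepcopy(plugin_zero_config)
    if optimizer_is_galore and zero_stage > 0 and dp_size > 0:
        # Fall back to vanilla data parallelism: keep gradients unpartitioned
        # and run without ZeRO for this optimizer.
        zero_config["partition_grad"] = False
        zero_stage = 0
    return zero_stage, zero_config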
@@ -1194,7 +1207,8 @@ def configure(
                 custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.zero_stage == 0:
+            if zero_stage == 0:
+                is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
1218
1232
tp_process_group = self .tp_group ,
1219
1233
)
1220
1234
else :
1221
- zero_dp_size = dist . get_world_size ( dp_group )
1222
- if zero_dp_size == 1 :
1235
+ is_zero = self . dp_size > 1
1236
+ if self . dp_size == 1 :
1223
1237
warnings .warn (
1224
1238
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
1225
- "If you are not intended to use cpu_offload, please consider set zero_stage=0."
1239
+ "If you do not intend to use cpu_offload, please consider set zero_stage=0."
1226
1240
)
1227
1241
1228
1242
assert self .precision != "fp32" , "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1236,11 +1250,19 @@ def configure(
                     pp_process_group=self.pp_group,
                     verbose=True,
                     clip_grad_norm=self.max_norm,
-                    **self.zero_config,
+                    **zero_config,
                     **self.amp_config,
                 )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
+
+            # Setup optimizers that require global states
+            optim = optimizer.optim
+            if isinstance(optim, DistributedOptim):
+                shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
+                padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
+                optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
+
         return model, optimizer, criterion, dataloader, lr_scheduler

     def execute_pipeline(
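
For context on the new `setup_distributed` call in the hunk above: a minimal sketch of what a `DistributedOptim` implementation might do with the arguments handed over by `configure`. Only the call signature comes from this diff; the class body and name are assumptions for illustration.

class DistributedOptimSketch:
    """Hypothetical stand-in for a DistributedOptim implementation."""

    def setup_distributed(self, tp_group, dp_group, shard_to_working_param, padding_map, is_zero):
        # Process groups used for any collectives the optimizer runs at step time.
        self.tp_group = tp_group
        self.dp_group = dp_group
        # Maps master shards back to working params; an empty dict when ZeRO is off.
        self.shard_to_working_param = shard_to_working_param
        # Per-parameter padding from ZeRO flattening; defaults to 0 per param when ZeRO is off.
        self.padding_map = padding_map
        self.is_zero = is_zero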
@@ -1272,7 +1294,7 @@ def execute_pipeline(

         # run with gradients accumulation
         if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
         ):
             return outputs
