-import copy
-
 import pytest
 import torch
 import torch.distributed as dist
@@ ... @@
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
     get_device_mesh,
-    get_layout,
     get_sharding_spec,
     is_distributed_tensor,
     shard_colwise,
@@ ... @@
 from colossalai.utils import set_seed
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
-from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states
+from tests.test_optimizer._utils import (
+    check_dist_optim_state,
+    check_dist_param,
+    check_optim_states,
+    set_master_param_to_shard_param,
+    setup_param_groups,
+)
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
     build_model_from_low_level_zero_plugin,
@@ ... @@
     unwrap_model,
 )

-HEIGHT = 4
-WIDTH = 4
+IN_DIM = 4
+HID_DIM = 4
 _TP_SPEC = DimSpec([0])

+Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
+TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))
+

 def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
     rtol = None
@@ -59,92 +65,11 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc
     assert_close(tensor1, tensor2, rtol=rtol, atol=atol)


-# setup param groups; (For zero test optim)
-def setup_param_groups_zero(model: nn.Module) -> list:
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": 0.1,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
-# setup param groups; (For base optim)
-def setup_param_groups(model: nn.Module) -> list:
-    optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
-    return optimizer_grouped_parameters
-
-
-# setup flatten param groups, sharding spec and shape; (For dist optim)
-def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
-    flatten_optimizer_grouped_parameters = []
-    sharding_spec = {}  # {id(flatten param): get_layout(p).global_shape}
-    param_shape = {}  # {id(flatten param): get_sharding_spec(p)}
-    for n, p in model.named_parameters():
-        # flatten_p = copy.deepcopy(p).flatten()
-        flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
-        flatten_optimizer_grouped_parameters.append(flatten_p)
-        if is_distributed_tensor(p):
-            sharding_spec[id(flatten_p)] = get_sharding_spec(p)
-            param_shape[id(flatten_p)] = get_layout(p).global_shape
-        else:
-            sharding_spec[id(flatten_p)] = None
-            param_shape[id(flatten_p)] = p.shape
-    return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
-
-
-def set_dist_grad(
-    dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
-) -> None:
-    """
-    Set split grads for Tensor Parallel or ZeRO DP.
-    We do not need a separate treatment for ZeRO,
-    as the wrapper takes care of reduce-scattering grads.
-    """
-    rank = dist.get_rank(group)
-    world_size = dist.get_world_size(group)
-
-    for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
-        if torch_p.grad is None:
-            torch_p.grad = torch.zeros_like(torch_p)
-
-        is_distributed = hasattr(p, "dist_layout")
-        if is_distributed:
-            sharding = p.dist_layout.sharding_spec.sharding_sequence
-            split_dim = sharding.index(_TP_SPEC)
-            shape = torch_p.split(world_size, dim=split_dim)[rank].shape
-
-            indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
-            # Generate grads only for the correctly split chunk
-            torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
-
-        else:
-            shape = torch_p.shape
-            torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
-
-        # avoid inconsistent grad and param dtype error
-        orig_p = p.data
-        p.data = torch_p.grad.clone().to(g_dtype)
-        p.grad = p.data
-        p.data = orig_p
-
-
-def set_master_param_to_shard_param(master_param_list) -> dict:
-    master_param_to_shard_param = {id(p): p for p in master_param_list}
-    return master_param_to_shard_param
-
-
 class MlpModel(nn.Module):
     def __init__(self):
         super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(HEIGHT, WIDTH)
-        self.linear2 = nn.Linear(WIDTH, HEIGHT)
+        self.linear1 = nn.Linear(IN_DIM, HID_DIM)
+        self.linear2 = nn.Linear(HID_DIM, IN_DIM)

     def forward(self, x):
         x = self.linear1(x)
@@ -182,7 +107,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     # ==============================
     # Base Case
     # ==============================
-    H, W = HEIGHT, WIDTH
+    H, W = IN_DIM, HID_DIM
     model_col = nn.Linear(H, W).to(local_rank)  # Col parallel weight
     weight, bias = model_col.weight, model_col.bias

@@ -284,8 +209,11 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     # ==============================
     # Model Init
     # ==============================
-    base_model = MlpModel().to(local_rank)
-    tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
+    # base_model = MlpModel().to(local_rank)
+    # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
+    base_model = Net(in_dim=IN_DIM, hid_dim=HID_DIM, dtype=dtype).to(local_rank)
+    # Must specify dtype; TPNet's __init__ seems to run outside the set_default_dtype scope
+    tp_model = TPNet(fc1=base_model.fc1, fc2=base_model.fc2, tp_group=tp_group, dtype=dtype)

     base_param_group = setup_param_groups(base_model)
     tp_param_group = setup_param_groups(tp_model)
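Note on the dtype comment in the hunk above: nn.Module parameters take whatever default dtype is active at construction time, so a wrapper whose layers are built after torch.set_default_dtype has been reset falls back to float32 unless a dtype is passed explicitly. A minimal sketch of that behaviour, using plain PyTorch only (none of the test's classes are involved):

    import torch
    import torch.nn as nn

    prev = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    inner = nn.Linear(4, 4)  # parameters created while bfloat16 is the default
    torch.set_default_dtype(prev)

    outer = nn.Linear(4, 4)  # created after the reset: float32 again
    print(inner.weight.dtype, outer.weight.dtype)  # torch.bfloat16 torch.float32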
@@ -335,7 +263,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     # ==============================
     # Correctness Verify
     # ==============================
-    x = torch.randn(HEIGHT, WIDTH, device=local_rank)
+    x = torch.randn(IN_DIM, HID_DIM, device=local_rank)

     out = base_model(x)
     out_tp = tp_model(x)
@@ -353,7 +281,9 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     base_optim.zero_grad()
     dist_optim.zero_grad()

-    for p, tp_p in zip(base_param_group, tp_param_group):
+    base_params = base_model.parameters()
+    tp_params = tp_model.parameters()
+    for p, tp_p in zip(base_params, tp_params):
         param_is_distributed = is_distributed_tensor(tp_p)
         if param_is_distributed:
             shard_spec = get_sharding_spec(tp_p)
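For reference, a hedged sketch of how a rank's TP shard could be checked against the full base parameter. This is not part of the diff (the test's own comparison continues below the hunk); it assumes the module-level _TP_SPEC and correctness_verify from this file, and that the spec returned by get_sharding_spec exposes the same sharding_sequence attribute the deleted set_dist_grad helper used:

    def check_tp_shard(full_p: torch.Tensor, tp_p: torch.Tensor, shard_spec, tp_group: dist.ProcessGroup) -> None:
        # Locate the TP-sharded dim via the DimSpec convention used above.
        split_dim = shard_spec.sharding_sequence.index(_TP_SPEC)
        rank = dist.get_rank(tp_group)
        world_size = dist.get_world_size(tp_group)
        # The rank's shard should equal the matching chunk of the full parameter.
        expected = torch.chunk(full_p, world_size, dim=split_dim)[rank]
        correctness_verify(tp_p.data, expected.data)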