Commit ec73f1b

[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refactor and clean up using common helper funcs; tests passed
* Update comments
* Fix relative import
* Fix param fetching bug
1 parent: 5c09d72

8 files changed: +142 additions, -298 deletions

colossalai/shardformer/layer/linear.py

Lines changed: 3 additions & 3 deletions

@@ -384,7 +384,7 @@ class Linear1D_Row(ParallelModule):
         out_features (int): size of each output sample.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
-        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+        parallel_input (bool): If set to ``True``, it's assumed that the input is already split/copied across each rank, defaults to False.
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
         seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
@@ -544,14 +544,14 @@ def forward(self, input_: Tensor) -> Tensor:
         if self.parallel_input:
             assert (
                 input_.shape[-1] == self.weight.shape[-1]
-            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
                 input_.shape, self.weight.shape, self.weight.shape[-1]
             )
             input_ = input_
         else:
             assert (
                 divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
-            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
                 input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
             )
             input_ = split_forward_gather_backward(
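
For orientation, a minimal standalone sketch (not part of the commit) of the shape contract those assertion messages describe. The sizes and the row-parallel weight layout of (out_features, in_features / num_partitions) are assumptions inferred from the checks above:

import torch

num_partitions = 2                      # hypothetical tensor-parallel degree
in_features, out_features = 8, 4
# Row-parallel shard of the weight: each rank holds a slice of the in-feature dim.
local_weight = torch.randn(out_features, in_features // num_partitions)

# parallel_input=True: the input is already split, so its feature dim matches the local weight.
x_split = torch.randn(3, in_features // num_partitions)
assert x_split.shape[-1] == local_weight.shape[-1]

# parallel_input=False: the full input arrives and is split inside forward().
x_full = torch.randn(3, in_features)
assert x_full.shape[-1] // num_partitions == local_weight.shape[-1]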

tests/kit/model_zoo/custom/simple_mlp.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@
 
 
 class Net(nn.Module):
-    def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32):
+    def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=torch.float32):
         super().__init__()
         if identity:
             self.fc0 = nn.Identity()
@@ -30,7 +30,7 @@ def forward(self, x):
 class TPNet(nn.Module):
     def __init__(
         self,
-        fc0=nn.Linear(_IN_DIM, _IN_DIM),
+        fc0=nn.Identity(),
         fc1=nn.Linear(_IN_DIM, _HID_DIM),
         fc2=nn.Linear(_HID_DIM, _IN_DIM),
         tp_group=None,
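
One plausible reason for aligning these defaults (an inference, not stated in the commit): with fc0 an identity layer on both Net and TPNet, neither side contributes fc0 parameters, so parameter-wise comparisons such as zip(net.parameters(), tp_net.parameters()) pair fc1/fc2 one-to-one. A tiny sketch with hypothetical stand-in modules:

import torch.nn as nn

# Stand-ins only; the real Net/TPNet live in tests/kit/model_zoo/custom/simple_mlp.py.
net = nn.Sequential(nn.Identity(), nn.Linear(4, 4), nn.Linear(4, 4))
tp_net = nn.Sequential(nn.Identity(), nn.Linear(4, 4), nn.Linear(4, 4))
assert len(list(net.parameters())) == len(list(tp_net.parameters()))  # four tensors each: two weights, two biases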

tests/test_optimizer/_utils.py

Lines changed: 85 additions & 0 deletions

@@ -1,10 +1,13 @@
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor
 from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.tensor.d_tensor.sharding_spec import DimSpec
 from colossalai.testing import parameterize, spawn
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
@@ -15,6 +18,88 @@
 )
 
 
+def force_assign_grad(p, g_dtype, grad=None):
+    """Bypass inconsistent grad and param dtype error when assigning grad"""
+    orig_p = p.data
+    p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad is None else grad.clone().to(g_dtype)
+    p.grad = p.data
+    p.data = orig_p
+
+
+def setup_param_groups(model: nn.Module) -> list:
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": 0.1,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    return optimizer_grouped_parameters
+
+
+# setup flatten param groups, sharding spec and shape; (For dist Adafactor and CAME)
+def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
+    flatten_optimizer_grouped_parameters = []
+    sharding_spec = {}  # {id(flatten param): get_sharding_spec(p)}
+    param_shape = {}  # {id(flatten param): get_layout(p).global_shape}
+    for n, p in model.named_parameters():
+        # flatten_p = copy.deepcopy(p).flatten()
+        flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
+        flatten_optimizer_grouped_parameters.append(flatten_p)
+        if is_distributed_tensor(p):
+            sharding_spec[id(flatten_p)] = get_sharding_spec(p)
+            param_shape[id(flatten_p)] = get_layout(p).global_shape
+        else:
+            sharding_spec[id(flatten_p)] = None
+            param_shape[id(flatten_p)] = p.shape
+    return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
+
+
+def set_master_param_to_shard_param(master_param_list) -> dict:
+    master_param_to_shard_param = {id(p): p for p in master_param_list}
+    return master_param_to_shard_param
+
+
+def set_dist_grad(
+    dist_module: nn.Module,
+    torch_model: nn.Module,
+    g_dtype: torch.dtype,
+    group: dist.ProcessGroup,
+    tp_spec: DimSpec,
+) -> None:
+    """
+    Set split grads for Tensor Parallel or ZeRO DP.
+    We do not need a separate treatment for ZeRO,
+    as the wrapper takes care of reduce-scattering grads.
+    """
+    rank = dist.get_rank(group)
+    world_size = dist.get_world_size(group)
+
+    for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
+        if torch_p.grad is None:
+            torch_p.grad = torch.zeros_like(torch_p)
+
+        is_distributed = hasattr(p, "dist_layout")
+        if is_distributed:
+            sharding = p.dist_layout.sharding_spec.sharding_sequence
+            split_dim = sharding.index(tp_spec)
+            shape = torch_p.split(world_size, dim=split_dim)[rank].shape
+
+            indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
+            # Generate grads only for the correctly split chunk
+            torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
+
+        else:
+            shape = torch_p.shape
+            torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
+
+        force_assign_grad(p, g_dtype, grad=torch_p.grad)
+
+
 def check_optim_states(org_optim, sharded_optim):
     for group in org_optim.param_groups:
         for p in group["params"]:
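
For reference, a minimal standalone sketch (not part of the commit) of how a test might drive two of the new shared helpers; the toy model and the bfloat16 grad dtype are arbitrary choices for illustration:

import torch
import torch.nn as nn

from tests.test_optimizer._utils import force_assign_grad, setup_param_groups

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))  # hypothetical toy model

groups = setup_param_groups(model)
assert [g["weight_decay"] for g in groups] == [0.1, 0.0]  # params split by name into decay / no-decay groups

p = model[0].weight
force_assign_grad(p, torch.bfloat16, grad=torch.ones_like(p))
assert p.grad.dtype == torch.bfloat16 and p.data.dtype == torch.float32  # grad dtype differs, param untouched

set_dist_grad follows the same pattern but additionally needs an initialized process group and a DimSpec for the tensor-parallel dimension, so it is not exercised in this sketch.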

tests/test_optimizer/test_adam_optim.py

Lines changed: 2 additions & 19 deletions

@@ -8,6 +8,7 @@
 
 from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
 from tests.kit.model_zoo import model_zoo
+from tests.test_optimizer._utils import force_assign_grad, setup_param_groups
 
 _ALLOWED_OPTIM_DEVICES = [
     (FusedAdam, torch.device("cuda:0")),
@@ -26,29 +27,11 @@
 N_STEPS = 3
 
 
-def setup_param_groups(bert_model: nn.Module) -> list:
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": 0.1,
-        },
-        {
-            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
 def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         torch_p.grad = torch.rand_like(torch_p)
         # avoid inconsistent grad and param dtype error
-        orig_p = p.data
-        p.data = torch_p.grad.clone().to(g_dtype)
-        p.grad = p.data
-        p.data = orig_p
+        force_assign_grad(p, g_dtype, torch_p.grad)
 
 
 @pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)

tests/test_optimizer/test_dist_adafactor.py

Lines changed: 24 additions & 94 deletions

@@ -1,5 +1,3 @@
-import copy
-
 import pytest
 import torch
 import torch.distributed as dist
@@ -16,7 +14,6 @@
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
     get_device_mesh,
-    get_layout,
     get_sharding_spec,
     is_distributed_tensor,
     shard_colwise,
@@ -28,7 +25,13 @@
 from colossalai.utils import set_seed
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
-from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states
+from tests.test_optimizer._utils import (
+    check_dist_optim_state,
+    check_dist_param,
+    check_optim_states,
+    set_master_param_to_shard_param,
+    setup_param_groups,
+)
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
     build_model_from_low_level_zero_plugin,
@@ -38,10 +41,13 @@
     unwrap_model,
 )
 
-HEIGHT = 4
-WIDTH = 4
+IN_DIM = 4
+HID_DIM = 4
 _TP_SPEC = DimSpec([0])
 
+Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
+TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))
+
 
 def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
     rtol = None
@@ -59,92 +65,11 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
     assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
 
 
-# setup param groups; (For zero test optim)
-def setup_param_groups_zero(model: nn.Module) -> list:
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": 0.1,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
-# setup param groups; (For base optim)
-def setup_param_groups(model: nn.Module) -> list:
-    optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
-    return optimizer_grouped_parameters
-
-
-# setup flatten param groups, sharding spec and shape; (For dist optim)
-def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
-    flatten_optimizer_grouped_parameters = []
-    sharding_spec = {}  # {id(flatten param): get_layout(p).global_shape}
-    param_shape = {}  # {id(flatten param): get_sharding_spec(p)}
-    for n, p in model.named_parameters():
-        # flatten_p = copy.deepcopy(p).flatten()
-        flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
-        flatten_optimizer_grouped_parameters.append(flatten_p)
-        if is_distributed_tensor(p):
-            sharding_spec[id(flatten_p)] = get_sharding_spec(p)
-            param_shape[id(flatten_p)] = get_layout(p).global_shape
-        else:
-            sharding_spec[id(flatten_p)] = None
-            param_shape[id(flatten_p)] = p.shape
-    return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
-
-
-def set_dist_grad(
-    dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
-) -> None:
-    """
-    Set split grads for Tensor Parallel or ZeRO DP.
-    We do not need a separate treatment for ZeRO,
-    as the wrapper takes care of reduce-scattering grads.
-    """
-    rank = dist.get_rank(group)
-    world_size = dist.get_world_size(group)
-
-    for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
-        if torch_p.grad is None:
-            torch_p.grad = torch.zeros_like(torch_p)
-
-        is_distributed = hasattr(p, "dist_layout")
-        if is_distributed:
-            sharding = p.dist_layout.sharding_spec.sharding_sequence
-            split_dim = sharding.index(_TP_SPEC)
-            shape = torch_p.split(world_size, dim=split_dim)[rank].shape
-
-            indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
-            # Generate grads only for the correctly split chunk
-            torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
-
-        else:
-            shape = torch_p.shape
-            torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
-
-        # avoid inconsistent grad and param dtype error
-        orig_p = p.data
-        p.data = torch_p.grad.clone().to(g_dtype)
-        p.grad = p.data
-        p.data = orig_p
-
-
-def set_master_param_to_shard_param(master_param_list) -> dict:
-    master_param_to_shard_param = {id(p): p for p in master_param_list}
-    return master_param_to_shard_param
-
-
 class MlpModel(nn.Module):
     def __init__(self):
         super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(HEIGHT, WIDTH)
-        self.linear2 = nn.Linear(WIDTH, HEIGHT)
+        self.linear1 = nn.Linear(IN_DIM, HID_DIM)
+        self.linear2 = nn.Linear(HID_DIM, IN_DIM)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -182,7 +107,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     # ==============================
     # Base Case
     # ==============================
-    H, W = HEIGHT, WIDTH
+    H, W = IN_DIM, HID_DIM
     model_col = nn.Linear(H, W).to(local_rank)  # Col parallel weight
     weight, bias = model_col.weight, model_col.bias
 
@@ -284,8 +209,11 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     # ==============================
     # Model Init
     # ==============================
-    base_model = MlpModel().to(local_rank)
-    tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
+    # base_model = MlpModel().to(local_rank)
+    # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
+    base_model = Net(in_dim=IN_DIM, hid_dim=HID_DIM, dtype=dtype).to(local_rank)
+    # Must specify dtype; TPNet init seems to run outside the set_default_dtype scope
+    tp_model = TPNet(fc1=base_model.fc1, fc2=base_model.fc2, tp_group=tp_group, dtype=dtype)
 
     base_param_group = setup_param_groups(base_model)
     tp_param_group = setup_param_groups(tp_model)
@@ -335,7 +263,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     # ==============================
     # Correctness Verify
     # ==============================
-    x = torch.randn(HEIGHT, WIDTH, device=local_rank)
+    x = torch.randn(IN_DIM, HID_DIM, device=local_rank)
 
     out = base_model(x)
     out_tp = tp_model(x)
@@ -353,7 +281,9 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     base_optim.zero_grad()
     dist_optim.zero_grad()
 
-    for p, tp_p in zip(base_param_group, tp_param_group):
+    base_params = base_model.parameters()
+    tp_params = tp_model.parameters()
+    for p, tp_p in zip(base_params, tp_params):
         param_is_distributed = is_distributed_tensor(tp_p)
         if param_is_distributed:
             shard_spec = get_sharding_spec(tp_p)
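
One reading of the "Fix param fetching bug" item (an inference from this hunk, not stated in the commit message): the shared setup_param_groups returns weight-decay groups, i.e. a list of dicts, rather than the flat parameter list the deleted local helper produced, so zipping the two group lists no longer pairs tensors; iterating the models' parameters() does. A tiny hypothetical illustration:

import torch.nn as nn

from tests.test_optimizer._utils import setup_param_groups

model = nn.Linear(4, 4)             # hypothetical stand-in
groups = setup_param_groups(model)
assert isinstance(groups[0], dict)  # elements are {"params": [...], "weight_decay": ...} dicts, not tensors
params = list(model.parameters())
assert all(isinstance(p, nn.Parameter) for p in params)  # what the correctness loop actually compares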
