
Commit c7b6c98

wanchaol authored and pytorchmergebot committed
[tp] improve parallelize_module API to support more cases (pytorch#157182)
This PR improves the parallelize_module API to support more corner cases:

1. If a plan entry is specified as "", the style is applied to the current module itself.
2. If a plan entry has no corresponding submodule to apply to, a warning is raised and the entry is ignored.

While working on this PR, I also found that the inner while-loop was not actually necessary and could produce some nasty modify-while-iterating behavior, so I removed it.

Pull Request resolved: pytorch#157182
Approved by: https://github.com/tianyu-l
1 parent d5e6f42 commit c7b6c98
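
To illustrate the two new behaviors, here is a rough usage sketch. It is not taken from this PR: ToyMLP, the mesh setup, and the layouts are made-up stand-ins that mirror the new tests below, and it assumes a default process group has already been initialized (e.g. via torchrun).

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInputOutput,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    """Made-up stand-in for the MLPModule used in the tests."""

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 16)
        self.net2 = nn.Linear(16, 10)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))


mesh = init_device_mesh("cuda", (dist.get_world_size(),))
model = parallelize_module(
    ToyMLP().cuda(),
    mesh,
    {
        # 1. An "" entry now applies the style to the module itself.
        "": PrepareModuleInputOutput(
            input_layouts=Replicate(),
            desired_input_layouts=Shard(0),
            output_layouts=Shard(0),
            desired_output_layouts=Replicate(),
        ),
        "net1": ColwiseParallel(input_layouts=Shard(0)),
        "net2": RowwiseParallel(output_layouts=Shard(0)),
        # 2. Nothing matches this key, so it now emits a UserWarning
        #    and is skipped.
        "net0.hello.world": ColwiseParallel(),
    },
)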

File tree: 3 files changed (+92, -34 lines)


test/distributed/tensor/debug/test_comm_mode_features.py

Lines changed: 4 additions & 4 deletions
@@ -144,10 +144,10 @@ def test_MLPStacked_distributed_sharding_display(self):
         model2 = MLPStacked(self.device_type)
 
         parallelize_plan = {
-            "MLPStacked.layers.0.net1": ColwiseParallel(),
-            "MLPStacked.layers.0.net2": RowwiseParallel(),
-            "MLPStacked.layers.1.net1": ColwiseParallel(),
-            "MLPStacked.layers.1.net2": RowwiseParallel(),
+            "layers.0.net1": ColwiseParallel(),
+            "layers.0.net2": RowwiseParallel(),
+            "layers.1.net1": ColwiseParallel(),
+            "layers.1.net2": RowwiseParallel(),
         }
 
         model2 = parallelize_module(model2, device_mesh, parallelize_plan)
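
The plan keys change here because parallelize_module resolves each key against module.named_children() of the module it is given, so the first path token must be a direct child name; the class name "MLPStacked" is not a child, and with the no-match warning added in this PR such keys would now warn and be skipped. A quick way to see that (MLPStackedLike is a made-up stand-in whose only direct child is `layers`):

import torch.nn as nn


class MLPStackedLike(nn.Module):
    """Made-up stand-in with the same top-level child as the test's MLPStacked."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)) for _ in range(2)]
        )


# The only direct child is "layers", so a key starting with "MLPStacked"
# can never match and would now trigger the no-match warning.
print([name for name, _ in MLPStackedLike().named_children()])  # ['layers']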

test/distributed/tensor/parallel/test_parallelize_api.py

Lines changed: 45 additions & 1 deletion
@@ -332,6 +332,49 @@ def test_parallelize_module_multi_wildcard(self):
         )
         self._compare_module(model, model_tp, inp_size, rank0_only=False)
 
+    @with_comms
+    def test_parallelize_module_with_root_module(self):
+        inp_size = [16, 10]
+        model = MLPModule(self.device_type)
+        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        model_tp = deepcopy(model)
+        model_tp = parallelize_module(
+            model_tp,
+            device_mesh,
+            {
+                "": PrepareModuleInputOutput(
+                    input_layouts=Replicate(),
+                    desired_input_layouts=Shard(0),
+                    output_layouts=Shard(0),
+                    desired_output_layouts=Replicate(),
+                ),
+                "net1": ColwiseParallel(input_layouts=Shard(0)),
+                "net2": RowwiseParallel(output_layouts=Shard(0)),
+            },
+        )
+        self._compare_module(model, model_tp, inp_size, rank0_only=False)
+
+    @with_comms
+    def test_parallelize_module_with_no_match(self):
+        inp_size = [16, 10]
+        model = MLPModule(self.device_type)
+        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        model_tp = deepcopy(model)
+        with self.assertWarns(UserWarning):
+            model_tp = parallelize_module(
+                model_tp,
+                device_mesh,
+                {
+                    "net0.hello.world": ColwiseParallel(),
+                    "net1": ColwiseParallel(),
+                    "net2": RowwiseParallel(),
+                    "net3": ColwiseParallel(),
+                },
+            )
+        self._compare_module(model, model_tp, inp_size, rank0_only=False)
+
     @with_comms
     def test_under_devicemesh_context(self):
         # test ColwiseParallel
@@ -357,7 +400,8 @@ def test_empty_plan(self):
         # Call parallelize_module with empty plan.
         # Goal is not to crash.
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        parallelize_module(model, device_mesh)
+        with self.assertWarns(UserWarning):
+            parallelize_module(model, device_mesh)
 
 
 if __name__ == "__main__":

torch/distributed/tensor/parallel/api.py

Lines changed: 43 additions & 29 deletions
@@ -88,39 +88,53 @@ def parallelize_module(  # type: ignore[return]
         return parallelize_plan._apply(module, device_mesh)
     elif isinstance(parallelize_plan, dict):
         for module_path, parallelize_style in parallelize_plan.items():
+            if module_path == "":
+                # shortcut: empty string means to apply the plan to the current module
+                parallelize_module(module, device_mesh, parallelize_style)
+                continue
+
             path_splits = module_path.split(".")
-            if len(path_splits) == 0:
-                raise ValueError(
-                    "Expect module path to be non-empty, but got empty string!"
-                )
-            while path_splits:
-                atom = path_splits.pop(0)
-                matched_children = filter(
+            # Instead of blindly popping tokens, first check the match,
+            # we only consume/pop the token if we found a match.
+            token = path_splits[0]
+
+            matched_children = list(
+                filter(
                     # `t[0]` is child name
-                    lambda t: fnmatch(t[0], atom),
+                    lambda t: fnmatch(t[0], token),
                     module.named_children(),
                 )
-                # apply the plan to all matched submodules
-                for _, submodule in matched_children:
-                    if path_splits:
-                        # we haven't reached the leaf, apply in dict style
-                        leaf_path = ".".join(
-                            path_splits
-                        )  # rest of the path after `atom`
-                        parallelize_module(
-                            submodule,
-                            device_mesh,
-                            {leaf_path: parallelize_style},
-                            src_data_rank=src_data_rank,
-                        )
-                    else:
-                        # otherwise, directly apply style to this submodule
-                        parallelize_module(
-                            submodule,
-                            device_mesh,
-                            parallelize_style,
-                            src_data_rank=src_data_rank,
-                        )
+            )
+            if not matched_children:
+                # No match at this level. Log a warning and process next plan entry.
+                warnings.warn(
+                    f"Parallelize plan key '{module_path}' could not be resolved: "
+                    f"no submodule matching token '{token}' in module {module}, "
+                    f"skipping this plan entry."
+                )
+                continue
+
+            # Now that we have a match, we can consume the token.
+            path_splits.pop(0)
+            # apply the plan to all matched submodules
+            for _, submodule in matched_children:
+                if path_splits:
+                    # we haven't reached the leaf, apply in dict style
+                    leaf_path = ".".join(path_splits)  # rest of the path after `token`
+                    parallelize_module(
+                        submodule,
+                        device_mesh,
+                        {leaf_path: parallelize_style},
+                        src_data_rank=src_data_rank,
+                    )
+                else:
+                    # otherwise, directly apply style to this submodule
+                    parallelize_module(
+                        submodule,
+                        device_mesh,
+                        parallelize_style,
+                        src_data_rank=src_data_rank,
+                    )
         return module
     else:
         raise TypeError(  # pyre-ignore[7]
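
For readers skimming the hunk above, this is a simplified standalone rendering of the new traversal logic. It is plain Python, not the actual torch implementation: apply_style is a hypothetical stand-in for applying a ParallelStyle, and the device-mesh plumbing is omitted.

import warnings
from fnmatch import fnmatch

import torch.nn as nn


def apply_plan(module: nn.Module, plan: dict, apply_style) -> nn.Module:
    for module_path, style in plan.items():
        if module_path == "":
            # "" applies the style to the current module itself.
            apply_style(module, style)
            continue

        path_splits = module_path.split(".")
        token = path_splits[0]
        matched = [
            (name, child)
            for name, child in module.named_children()
            if fnmatch(name, token)
        ]
        if not matched:
            # New behavior: warn and skip instead of consuming tokens blindly.
            warnings.warn(f"plan key '{module_path}' matched no submodule; skipping")
            continue

        rest = ".".join(path_splits[1:])
        for _, child in matched:
            if rest:
                # Not at the leaf yet: recurse with the remaining path.
                apply_plan(child, {rest: style}, apply_style)
            else:
                # Reached the leaf: apply the style directly.
                apply_style(child, style)
    return module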
