Commit 4ce7bbf

Add LayerSpec Support to ORTPipelineModule (#20410)
### Description

In DeepSpeed's pipeline parallel implementation, `LayerSpec` is a class that defers building a layer until it has been assigned to a pipeline stage and moved to that stage's device, so each rank only instantiates the layers it owns. This approach helps reduce peak memory usage. In this PR, we add support in ORT for wrapping these `LayerSpec`-built layers with ORTModule inside `ORTPipelineModule`.
1 parent 5055dc0 commit 4ce7bbf
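
To illustrate what this enables, here is a minimal sketch (not part of the commit) contrasting an eagerly built nn.Sequential model with a deferred LayerSpec description, both of which ORTPipelineModule can now accept. The SimpleNetPipeBlock class and the config dictionary mirror the test script added below; like that script, it is assumed to run under the deepspeed launcher so that distributed initialization succeeds.

import deepspeed
import torch
from deepspeed.pipe import LayerSpec
from torch import nn

from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule


class SimpleNetPipeBlock(nn.Module):
    """One pipeline stage: a linear layer followed by ReLU (mirrors the test script)."""

    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(config["hidden_size"], config["hidden_size"])
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))


config = {"hidden_size": 8}
deepspeed.init_distributed(dist_backend="nccl")  # requires a deepspeed/torchrun launch

# Eager construction: every layer is materialized on every rank before partitioning.
eager_layers = nn.Sequential(SimpleNetPipeBlock(config), SimpleNetPipeBlock(config))

# Deferred construction: LayerSpec records the class and its args; layer.build() is
# called only for the layers assigned to the local pipeline stage, lowering peak memory.
deferred_layers = [LayerSpec(SimpleNetPipeBlock, config) for _ in range(2)]

# With this PR, ORTPipelineModule wraps LayerSpec-built layers with ORTModule as well.
model = ORTPipelineModule(
    layers=deferred_layers,
    loss_fn=torch.nn.CrossEntropyLoss(),
    num_stages=2,
    partition_method="uniform",
    activation_checkpoint_interval=0,
)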

File tree

3 files changed: +112 −48 lines changed

orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py

Lines changed: 12 additions & 0 deletions

@@ -142,6 +142,18 @@ def _build(self):
             elif isinstance(layer, LayerSpec):
                 module = layer.build()
                 name = str(layer_idx)
+
+                if "debug_options" in self.ort_kwargs:
+                    new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
+                    parallel_debug_options = DebugOptions(
+                        self.ort_kwargs["debug_options"].log_level,
+                        self.ort_kwargs["debug_options"].save_onnx,
+                        new_onnx_prefix,
+                    )
+                    module = ORTModule(module, parallel_debug_options)
+                else:
+                    module = ORTModule(module)
+
                 self.forward_funcs.append(module)
                 self.fwd_map.update({name: len(self.forward_funcs) - 1})
                 self.add_module(name, module)
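
The new branch builds a separate DebugOptions per layer, prefixing the ONNX file name with the layer index so graphs saved from different LayerSpec-built layers don't overwrite each other. A rough usage sketch follows; the Block class is illustrative, and it assumes ORTPipelineModule forwards extra ORTModule keyword arguments such as debug_options into self.ort_kwargs, as the code above implies. Distributed initialization via the deepspeed launcher is assumed, as in the test script below.

import torch
from deepspeed.pipe import LayerSpec
from torch import nn

from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule


class Block(nn.Module):
    """Hypothetical stand-in for one pipeline stage."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


# Assumption: debug_options passed here ends up in self.ort_kwargs above. With
# save_onnx=True, _build() rewrites the prefix per layer, so the exported graphs
# land under "0_pp_debug", "1_pp_debug", ... instead of colliding on one name.
model = ORTPipelineModule(
    layers=[LayerSpec(Block) for _ in range(4)],
    loss_fn=nn.CrossEntropyLoss(),
    num_stages=2,
    partition_method="uniform",
    activation_checkpoint_interval=0,
    debug_options=DebugOptions(LogLevel.INFO, save_onnx=True, onnx_prefix="pp_debug"),
)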

orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py

Lines changed: 8 additions & 1 deletion

@@ -56,7 +56,7 @@ def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log):
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
-def run_ort_pipeline_module_tests(cwd, log):
+def run_ort_pipeline_module_tests(cwd, log, layer_spec_flag=False):
     log.debug("Running: ORTPipelineModule tests")
 
     command = [
@@ -66,6 +66,9 @@ def run_ort_pipeline_module_tests(cwd, log):
         "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
     ]
 
+    if layer_spec_flag:
+        command.append("--layer_spec=True")
+
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
@@ -107,7 +110,11 @@ def main():
     run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, args.mnist)
 
     run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log)
+
+    # Deepspeed ORTPipelineModule Tests
     run_ort_pipeline_module_tests(cwd, log)
+    run_ort_pipeline_module_tests(cwd, log, layer_spec_flag=True)
+
     run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, args.mnist)
 
     run_distributed_cache_test(cwd, log)
orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py

Lines changed: 92 additions & 47 deletions
@@ -1,90 +1,135 @@
 import argparse
+from typing import Dict, Tuple
 
 import deepspeed
 import torch
-from torch import nn
+from deepspeed.pipe import LayerSpec
+from torch import nn, utils
 
 from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
 
+# This script demonstrates how to set up a pipeline parallel training session
+# using DeepSpeed's ORTPipelineModule for a simple neural network model.
+
+
 # USAGE:
 # pip install deepspeed
 # deepspeed orttraining_test_ort_pipeline_module.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100
-# expected output : steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694
+def get_args() -> argparse.Namespace:
+    """Parse and return command line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank passed from distributed launcher")
+    parser.add_argument("--steps", type=int, default=100, help="Number of training steps to run")
+    parser.add_argument("--pipeline-parallel-size", type=int, default=2, help="Number of pipeline stages")
+    parser.add_argument("--backend", type=str, default="nccl", help="Distributed backend")
+    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
+    parser.add_argument("--layer_spec", type=bool, default=False, help="Use LayerSpec for layer specification")
+
+    parser = deepspeed.add_config_arguments(parser)
+    return parser.parse_args()
+
 
+class SampleData(utils.data.Dataset):
+    """Custom dataset to facilitate loading and batching of data."""
 
-class SampleData(torch.utils.data.Dataset):
-    def __init__(self, x, y):
+    def __init__(self, x: torch.Tensor, y: torch.Tensor):
         self.x = x
         self.y = y
 
-    def __len__(self):
-        return x.size()[0]
+    def __len__(self) -> int:
+        return self.x.size(0)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         return self.x[idx], self.y[idx]
 
 
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher")
-    parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps")
-    parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism")
-    parser.add_argument("--backend", type=str, default="nccl", help="distributed backend")
-    parser.add_argument("--seed", type=int, default=0, help="PRNG seed")
-    parser.add_argument("--fp16", type=bool, default=False, help="fp16 run")
+class SimpleNetPipeInput(nn.Module):
+    """First stage of the pipeline, responsible for initial processing."""
 
-    parser = deepspeed.add_config_arguments(parser)
-    args = parser.parse_args()
-    return args
+    def __init__(self, config: Dict[str, int]):
+        super().__init__()
+        self.linear = nn.Linear(config["input_size"], config["hidden_size"])
+        self.activation = nn.ReLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear(x)
+        x = self.activation(x)
+        return x
+
+
+class SimpleNetPipeBlock(nn.Module):
+    """Intermediate stage of the pipeline, can be duplicated to deepen the network."""
+
+    def __init__(self, config: Dict[str, int]):
+        super().__init__()
+        self.linear = nn.Linear(config["hidden_size"], config["hidden_size"])
+        self.activation = nn.ReLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear(x)
+        x = self.activation(x)
+        return x
+
+
+class SimpleNetPipeOutput(nn.Module):
+    """Final stage of the pipeline, producing the output."""
+
+    def __init__(self, config: Dict[str, int]):
+        super().__init__()
+        self.linear = nn.Linear(config["hidden_size"], config["output_size"])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear(x)
+        return x
+
+
+def build_model(config: Dict[str, int], n: int, layer_spec: bool) -> nn.Module:
+    """Constructs and returns the model either using LayerSpec or nn.Sequential."""
+    if layer_spec:
+        print("Wrapping layers with LayerSpec")
+        model = (
+            [LayerSpec(SimpleNetPipeInput, config)]
+            + [LayerSpec(SimpleNetPipeBlock, config) for _ in range(n)]
+            + [LayerSpec(SimpleNetPipeOutput, config)]
+        )
+    else:
+        print("Wrapping layers with nn.Sequential")
+        model = nn.Sequential(
+            SimpleNetPipeInput(config),
+            SimpleNetPipeBlock(config),
+            SimpleNetPipeBlock(config),
+            SimpleNetPipeOutput(config),
+        )
+    return model
 
 
-n = 10
-d_in = 4
-d_hidden = 8
-d_out = 3
 args = get_args()
 torch.cuda.set_device(args.local_rank)
 device = torch.device("cuda", args.local_rank)
-
-# dist.init_process_group(backend=args.backend)
 deepspeed.init_distributed(dist_backend=args.backend)
 torch.manual_seed(args.seed)
-# Model.
-
-model = nn.Sequential(
-    nn.Linear(d_in, d_hidden),  # Stage 1
-    nn.ReLU(),  # Stage 1
-    nn.Linear(d_hidden, d_hidden),  # Stage 1
-    nn.ReLU(),  # Stage 1
-    nn.Linear(d_hidden, d_hidden),  # Stage 2
-    nn.ReLU(),  # Stage 2
-    nn.Linear(d_hidden, d_out),  # Stage 2
-)
+
+model = build_model({"input_size": 4, "hidden_size": 8, "output_size": 3}, n=10, layer_spec=args.layer_spec)
 
 model = ORTPipelineModule(
     layers=model,
     loss_fn=torch.nn.CrossEntropyLoss(),
     num_stages=args.pipeline_parallel_size,
-    partition_method="uniform",  #'parameters',
+    partition_method="uniform",
     activation_checkpoint_interval=0,
 )
 
-params = [p for p in model.parameters() if p.requires_grad]
-
-# Input.
-x = torch.rand((n, d_in))
-if args.fp16:
-    x = x.half()
-# Output.
-y = torch.randint(0, d_out, (n,))
-ds = SampleData(x, y)
+# Setup input data
+x = torch.rand((10, 4))
+y = torch.randint(0, 3, (10,))
+dataset = SampleData(x, y)
 
 print("Initialize deepspeed")
 model_engine, optimizer, _, _ = deepspeed.initialize(
-    args=args, model=model, model_parameters=params, training_data=ds  # (x,y)#
+    args=args, model=model, model_parameters=model.parameters(), training_data=dataset
 )
 
 for step in range(args.steps):
     loss = model_engine.train_batch()
     if step % 10 == 0:
-        print("step = ", step, ", loss = ", loss)
+        print(f"step = {step}, loss = {loss}")
