Skip to content

Commit fa50145

Browse files
authored
add backward (#616)
1 parent 2233f1a commit fa50145

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
#!/bin/bash
# Run the BackwardGraphExtractorPass over the small100 torch sample list,
# writing extracted backward graphs under $OUTPUT_DIR.
#
# Fail fast: abort on any command error, unset variable, or pipeline failure.
set -euo pipefail

# Locate the installed graph_net package; the repo root is one level above it.
GRAPH_NET_ROOT=$(python3 -c "import graph_net, os; print(os.path.dirname(graph_net.__file__))")
GRAPHNET_ROOT="$GRAPH_NET_ROOT/../"
OUTPUT_DIR="/tmp/backward_graph_samples"
mkdir -p "$OUTPUT_DIR"

# The pass config is JSON, base64-encoded so it survives argv as a single token.
# NOTE: `base64 -w 0` (no line wrapping) is GNU coreutils; on macOS use plain `base64`.
python3 -m graph_net.apply_sample_pass \
    --model-path-list "graph_net/config/small100_torch_samples_list.txt" \
    --sample-pass-file-path "graph_net/torch/sample_pass/backward_graph_extractor.py" \
    --sample-pass-class-name "BackwardGraphExtractorPass" \
    --sample-pass-config "$(base64 -w 0 <<EOF
{
    "model_path_prefix": "$GRAPHNET_ROOT",
    "output_dir": "$OUTPUT_DIR",
    "device": "cuda"
}
EOF
)"

echo "Backward graph extraction completed!"
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
import inspect
import os
from pathlib import Path
from typing import Optional

import torch
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_func

from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
from graph_net.sample_pass.sample_pass import SamplePass
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
from graph_net.torch.fx_graph_module_util import (
    get_torch_module_and_inputs,
    _get_tensor_metas,
)
15+
16+
17+
class BackwardGraphExtractor:
    """Capture a model's backward fx graph and dump it via the builtin extractor.

    Uses AOT autograd (``aot_module_simplified``) to split the model into
    forward/backward graphs, records the backward graph and the concrete
    tensors fed into it, then hands both to ``BuiltinGraphExtractor`` for
    serialization under ``output_dir``.
    """

    def __init__(self, model_name, model_path, output_dir, device):
        # model_path: path of the sample model to load.
        # output_dir: workspace where the extracted graph is written.
        # device: "cpu" or "cuda" — where the model and inputs are placed.
        self.model_path = model_path
        self.output_dir = output_dir
        self.device = device
        # Reuse the project's standard extractor to serialize the captured
        # backward GraphModule; static shapes, no placeholder renaming.
        self.builtin_extractor = BuiltinGraphExtractor(
            name=model_name,
            dynamic=False,
            mut_graph_codes=[],
            placeholder_auto_rename=False,
            workspace_path=output_dir,
        )

    def __call__(self):
        """Load the model, capture its backward graph, and serialize it."""
        module, example_inputs = get_torch_module_and_inputs(
            self.model_path, use_dummy_inputs=False, device=self.device
        )
        # Train mode so autograd produces a backward graph at all.
        module.train()

        # Mark which inputs should participate in the backward pass.
        example_inputs = self.set_requires_grad_for_forward_inputs(
            self.model_path, module, example_inputs
        )
        bw_gm, backward_inputs = self.capture_backward_graph(module, example_inputs)
        self.builtin_extractor(bw_gm, backward_inputs)

    def capture_backward_graph(self, module, example_inputs):
        """Run forward+backward under AOT autograd and return the backward graph.

        Returns:
            (bw_gm, backward_inputs): the backward ``fx.GraphModule`` (with
            ``None`` outputs pruned) and the list of concrete tensors that
            were fed to it when ``torch.autograd.backward`` actually ran.
        """
        backward_gm_holder = {}
        backward_inputs = []

        def forward_compiler(fx_gm, example_inputs):
            # No transformation on the forward graph; run it as-is.
            return fx_gm

        def backward_compiler(fx_gm, example_inputs):
            # Save the backward fx.Graph
            backward_gm_holder["gm"] = fx_gm

            placeholders = [n for n in fx_gm.graph.nodes if n.op == "placeholder"]
            origin_forward = fx_gm.forward

            # Intercept the backward call to record its real runtime inputs
            # (detached clones, so recording never disturbs autograd state).
            def wrapped_forward(*args):
                for node, arg in zip(placeholders, args):
                    backward_inputs.append(arg.detach().clone())
                return origin_forward(*args)

            fx_gm.forward = wrapped_forward
            # AOT autograd expects a boxed callable from the compiler hook.
            return make_boxed_func(fx_gm)

        compiled = aot_module_simplified(
            module,
            example_inputs,
            fw_compiler=forward_compiler,
            bw_compiler=backward_compiler,
        )
        outs = compiled(*example_inputs)
        # Normalize single-tensor output to a sequence.
        outs = [outs] if isinstance(outs, torch.Tensor) else outs
        # Only differentiable tensor outputs can seed backward; use unit grads.
        valid_pairs = [
            (out, torch.ones_like(out))
            for out in outs
            if isinstance(out, torch.Tensor) and out.requires_grad
        ]

        if valid_pairs:
            tensors, grads = zip(*valid_pairs)
            # Triggering backward is what actually invokes wrapped_forward
            # above and populates backward_inputs / backward_gm_holder.
            torch.autograd.backward(tensors, grads)

        bw_gm = self._remove_none_from_output(backward_gm_holder["gm"])
        return bw_gm, backward_inputs

    def _remove_none_from_output(self, gm):
        """Drop ``None`` entries from the graph's output node and clean up.

        The backward graph emits ``None`` for inputs that need no gradient;
        pruning them (plus dead-code elimination) keeps the serialized
        graph minimal.
        """
        output_node = next(
            (n for n in gm.graph.nodes if n.op == "output"),
            None,
        )
        # output args are conventionally a 1-tuple wrapping the outputs.
        outs = (
            output_node.args[0]
            if output_node and isinstance(output_node.args, (tuple, list))
            else output_node.args
        )
        if isinstance(outs, (tuple, list)):
            new_outs = tuple(out for out in outs if out is not None)
            if new_outs != outs:
                output_node.args = (new_outs,)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
        return gm

    def _requires_grad(self, name, tensor):
        """Decide whether an input tensor should require grad, by name/dtype."""
        # Only floating-point tensors can carry gradients.
        if not tensor.is_floating_point():
            return False

        # Buffers and index-like inputs that must stay out of the backward
        # pass, matched by substring of the (original) input name.
        nograd_parameter_keywords = [
            "running_mean",
            "running_var",
            "num_batches_tracked",
            "mask",
            "indices",
            "position_ids",
            "anchor",
        ]
        for keyword in nograd_parameter_keywords:
            if keyword in name:
                return False

        return True

    def set_requires_grad_for_forward_inputs(
        self, model_path, graph_module, example_inputs
    ):
        """Set ``requires_grad`` on each example input based on its metadata.

        Inputs are matched to tensor metas by the forward signature's
        parameter order; the meta's ``original_name`` (when present) is used
        for the keyword-based no-grad filtering. Mutates and returns
        ``example_inputs``.
        """
        tensor_metas = _get_tensor_metas(model_path)
        name2tensor_meta = {
            tensor_meta.name: tensor_meta for tensor_meta in tensor_metas
        }
        for input_idx, name in enumerate(
            inspect.signature(graph_module.forward).parameters
        ):
            tensor = example_inputs[input_idx]
            tensor_meta = name2tensor_meta[name]
            # Prefer the pre-extraction name when available — the no-grad
            # keywords refer to original model naming, not placeholder names.
            original_name = (
                tensor_meta.original_name
                if hasattr(tensor_meta, "original_name") and tensor_meta.original_name
                else name
            )
            tensor.requires_grad = self._requires_grad(original_name, tensor)
            # print(f"{name}, {original_name}, requires_grad:{tensor.requires_grad}")
        return example_inputs
144+
145+
146+
class BackwardGraphExtractorPass(SamplePass, ResumableSamplePassMixin):
    """SamplePass that extracts the backward fx graph of each sample model.

    For each relative model path it receives, the pass runs a
    ``BackwardGraphExtractor`` and writes the captured backward graph under
    ``output_dir``, mirroring the sample's relative directory layout.
    Resumable: samples whose extracted ``model.py`` already exists are skipped.
    """

    def __init__(self, config=None):
        super().__init__(config)

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        device: str = "auto",
        resume: bool = False,
        # Optional cap on how many models are processed; None means no limit.
        limits_handled_models: Optional[int] = None,
    ):
        # Declarative schema only — values are consumed through self.config.
        pass

    def __call__(self, rel_model_path: str):
        # Delegate to the resumable mixin so already-handled samples are skipped.
        self.resumable_handle_sample(rel_model_path)

    def sample_handled(self, rel_model_path: str) -> bool:
        # A sample counts as handled once its extracted model.py exists.
        return self.naive_sample_handled(rel_model_path, search_file_name="model.py")

    def resume(self, rel_model_path: str):
        """Extract the backward graph for a single sample.

        Output goes to ``<output_dir>/<dirname of rel_model_path>`` and the
        extracted graph is named ``<basename>_backward``.
        """
        model_path_prefix = Path(self.config["model_path_prefix"])
        model_name = f"{os.path.basename(rel_model_path)}_backward"
        model_path = model_path_prefix / rel_model_path
        output_dir = Path(self.config["output_dir"]) / os.path.dirname(rel_model_path)
        device = self._choose_device(self.config["device"])
        extractor = BackwardGraphExtractor(model_name, model_path, output_dir, device)
        extractor()

    def _choose_device(self, device) -> str:
        """Resolve the configured device: pass through cpu/cuda, otherwise
        (e.g. "auto") pick cuda when available, else cpu."""
        if device in ["cpu", "cuda"]:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

0 commit comments

Comments
 (0)