Skip to content

Commit 40bafbc

Browse files
committed
Implement BackwardGraphExtractor for torch.
1 parent 9819d29 commit 40bafbc

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/usr/bin/env bash
# Driver script: runs BackwardGraphExtractorPass over sample models via
# graph_net.model_path_handler.

# Repository root, derived from the installed graph_net package location
# (two dirname hops up from graph_net/__init__.py).
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
MODEL_PATH_PREFIX=$GRAPH_NET_ROOT
OUTPUT_DIR=/tmp/backward_graph_workspace
# NOTE(review): FRAMEWORK is never referenced below — dead variable or
# reserved for future use; confirm before removing.
FRAMEWORK="torch"
# The handler config is shipped to model_path_handler as one
# base64-encoded JSON argument.
# NOTE(review): `base64 -w 0` (disable line wrapping) is GNU coreutils
# syntax; BSD/macOS base64 does not accept -w.
HANDLER_CONFIG=$(base64 -w 0 <<EOF
{
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/backward_graph_extractor.py",
"handler_class_name": "BackwardGraphExtractorPass",
"handler_config": {
"model_path_prefix": "$MODEL_PATH_PREFIX",
"output_dir": "$OUTPUT_DIR",
"device": "cuda",
"resume": false
}
}
EOF
)

# Run the extractor pass for one sample, given its path relative to the
# repository root.
run_case() {
local rel_sample_path="$1"
python -m graph_net.model_path_handler \
--model-path "$rel_sample_path" \
--handler-config "$HANDLER_CONFIG"
}

run_case "samples/torchvision/resnet18"
#run_case "samples/transformers-auto-model/albert-base-v2"
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import os
2+
import inspect
3+
from pathlib import Path
4+
5+
import torch
6+
from torch._functorch.aot_autograd import aot_module
7+
8+
from graph_net.sample_pass.sample_pass import SamplePass
9+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
10+
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
11+
from graph_net.torch.fx_graph_module_util import (
12+
get_torch_module_and_inputs,
13+
_get_tensor_metas,
14+
)
15+
16+
17+
class BackwardGraphExtractor:
    """Capture the backward (gradient) fx graph of a sample model and persist
    it through the builtin GraphExtractor.

    The extractor loads the sample at ``model_path``, runs one
    forward+backward pass under AOT Autograd, records the backward
    ``torch.fx.GraphModule`` and the concrete tensors it receives at runtime,
    and hands both to the builtin extractor for serialization.
    """

    def __init__(self, model_name, model_path, output_dir, device):
        self.model_path = model_path
        self.output_dir = output_dir
        self.device = device
        # Reuse the stock extractor to serialize the captured backward graph;
        # the extracted sample is stored under `output_dir` with `model_name`.
        self.builtin_extractor = BuiltinGraphExtractor(
            name=model_name,
            dynamic=False,
            mut_graph_codes=[],
            placeholder_auto_rename=False,
            workspace_path=output_dir,
        )

    def __call__(self):
        """Load the model, trace one forward+backward pass, and dump the
        captured backward graph with its runtime inputs."""
        module, example_inputs = get_torch_module_and_inputs(
            self.model_path, use_dummy_inputs=False, device=self.device
        )
        # Training mode so autograd builds a backward graph.
        module.train()

        example_inputs = self.set_requires_grad_for_forward_inputs(
            self.model_path, module, example_inputs
        )
        bw_gm, backward_inputs = self.capture_backward_graph(module, example_inputs)
        print(bw_gm.graph)
        self.builtin_extractor(bw_gm, backward_inputs)

    def capture_backward_graph(self, module, example_inputs):
        """Trace ``module`` with AOT Autograd and return the backward graph.

        Returns:
            Tuple of (backward ``torch.fx.GraphModule``, list of detached
            tensor clones observed as that graph's runtime inputs).

        Raises:
            KeyError: if AOT Autograd never invoked the backward compiler
            (e.g. no traced output requires grad), so "gm" is absent from
            the holder at the final lookup.
        """
        backward_gm_holder = {}
        backward_inputs = []

        def forward_compiler(fx_gm, example_inputs):
            # The forward graph is returned untouched; only backward matters.
            return fx_gm

        def backward_compiler(fx_gm, example_inputs):
            # Save the backward fx.Graph
            backward_gm_holder["gm"] = fx_gm

            placeholders = [n for n in fx_gm.graph.nodes if n.op == "placeholder"]
            origin_forward = fx_gm.forward

            def wrapped_forward(*args):
                # Invoked when the backward graph actually executes: record
                # the concrete tensors it receives, then delegate to the
                # original forward. Detach+clone so the snapshot is not
                # entangled with the live autograd state.
                for node, arg in zip(placeholders, args):
                    if torch.is_tensor(arg):
                        backward_inputs.append(arg.detach().clone())
                    else:
                        print(f"- {node.name} is not a torch.Tensor.")
                return origin_forward(*args)

            # Patch forward so input capture happens at execution time,
            # not at compile time.
            fx_gm.forward = wrapped_forward
            return fx_gm

        compiled = aot_module(
            module,
            fw_compiler=forward_compiler,
            bw_compiler=backward_compiler,
        )
        outs = compiled(*example_inputs)
        if isinstance(outs, torch.Tensor):
            outs = [outs]

        # Seed gradients of ones for every output so the backward pass
        # (and therefore the wrapped forward above) runs once.
        outs_grad = [torch.ones_like(out) for out in outs]
        torch.autograd.backward(outs, outs_grad)
        return backward_gm_holder["gm"], backward_inputs

    def _requires_grad(self, name, tensor):
        """Decide whether a forward input tensor should require grad.

        Non-floating tensors never do; buffers whose names contain BatchNorm
        running-statistic keywords are excluded as well.
        """
        if not tensor.is_floating_point():
            return False

        nograd_parameter_keywords = ["running_mean", "running_var"]
        for keyword in nograd_parameter_keywords:
            if keyword in name:
                return False

        return True

    def set_requires_grad_for_forward_inputs(
        self, model_path, graph_module, example_inputs
    ):
        """Set ``requires_grad`` on each forward input based on its recorded
        tensor-meta name, and return the mutated ``example_inputs``.

        Assumes ``example_inputs`` is positionally aligned with the
        parameters of ``graph_module.forward`` — TODO confirm this holds for
        all sample formats.
        """
        tensor_metas = _get_tensor_metas(model_path)
        name2tensor_meta = {
            tensor_meta.name: tensor_meta for tensor_meta in tensor_metas
        }
        for input_idx, name in enumerate(
            inspect.signature(graph_module.forward).parameters
        ):
            tensor = example_inputs[input_idx]
            tensor_meta = name2tensor_meta[name]
            # Prefer the pre-extraction name when recorded; extraction may
            # have renamed placeholders, and the nograd keywords in
            # _requires_grad match against the original naming.
            original_name = (
                tensor_meta.original_name
                if hasattr(tensor_meta, "original_name") and tensor_meta.original_name
                else name
            )
            tensor.requires_grad = self._requires_grad(original_name, tensor)
            # print(f"{name}, {original_name}, requires_grad:{tensor.requires_grad}")
        return example_inputs
112+
113+
114+
class BackwardGraphExtractorPass(SamplePass, ResumableSamplePassMixin):
    """Resumable SamplePass that runs BackwardGraphExtractor on each sample.

    Intended to be driven by graph_net.model_path_handler; the accepted
    configuration keys are declared by :meth:`declare_config`.
    """

    def __init__(self, config=None):
        super().__init__(config)

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        device: str = "auto",
        resume: bool = False,
        limits_handled_models: int = None,
    ):
        # Signature-only declaration of the supported config keys;
        # the framework introspects it, so the body is intentionally empty.
        pass

    def __call__(self, rel_model_path: str):
        # Delegate to the mixin so already-handled samples may be skipped
        # when resuming.
        self.resumable_handle_sample(rel_model_path)

    def sample_handled(self, rel_model_path: str) -> bool:
        # A sample counts as handled once its extracted model.py exists.
        handled = self.naive_sample_handled(rel_model_path, search_file_name="model.py")
        return handled

    def resume(self, rel_model_path: str):
        """Extract the backward graph for one sample path (relative to the
        configured model_path_prefix)."""
        cfg = self.config
        sample_basename = os.path.basename(rel_model_path)
        extractor = BackwardGraphExtractor(
            f"{sample_basename}_backward",
            Path(cfg["model_path_prefix"]) / rel_model_path,
            Path(cfg["output_dir"]) / os.path.dirname(rel_model_path),
            self._choose_device(cfg["device"]),
        )
        extractor()

    def _choose_device(self, device) -> str:
        # Honor an explicit cpu/cuda request; anything else (e.g. "auto")
        # falls back to the best available device.
        if device not in ("cpu", "cuda"):
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

0 commit comments

Comments
 (0)