Skip to content

Commit 1df5d5f

Browse files
committed
Implement method using torch.compile with a customized backend.
1 parent f159a3d commit 1df5d5f

File tree

1 file changed

+171
-99
lines changed

1 file changed

+171
-99
lines changed

graph_net/torch/collect_stats.py

Lines changed: 171 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ class OpStat:
5555
op_dtypes: dict[str, int] = field(default_factory=dict)
5656
count: int = 0
5757

58+
def update(self, other):
    """Merge another OpStat for the same op into this one.

    Adds `other.count` to our count and folds its per-dtype tallies into
    `op_dtypes`. Silently ignores `other` when it is not an OpStat or
    names a different op.
    """
    if not (isinstance(other, OpStat) and self.op_name == other.op_name):
        return
    self.count += other.count
    for dtype_name, dtype_count in other.op_dtypes.items():
        merged = self.op_dtypes.get(dtype_name, 0) + dtype_count
        self.op_dtypes[dtype_name] = merged
63+
5864

5965
def resolve_native_multi_head_attention(*args, **kwargs):
6066
query, key, value = args[0], args[1], args[2]
@@ -132,19 +138,23 @@ def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs):
132138
return None
133139

134140

135-
def collect_op_stats_manual(model, input_dict, device):
136-
try:
137-
# FX symbolic trace
138-
traced = torch.fx.symbolic_trace(model)
139-
# print(traced.graph)
140-
except Exception:
141-
print("Failed to FX symbolic_trace")
142-
return False, None
141+
# Relax TorchDynamo capture restrictions so torch.compile can trace as many
# models as possible (scalar outputs, data-dependent output shapes, sparse
# compute, context managers, RNNs) instead of erroring or graph-breaking.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_sparse_compute = True
torch._dynamo.config.raise_on_ctx_manager_usage = False
torch._dynamo.config.allow_rnn = True
143146

144-
# Use meta tensors as input to avoid actually running the model
145-
meta_input_dict = convert_real_to_meta(input_dict)
146147

147-
def get_output_dtype(out):
148+
class GraphMetaExecutor:
149+
def __init__(self, device):
    """Stateful torch.compile backend that accumulates op stats across subgraphs."""
    # Device used for the real-tensor fallback (resolve_with_real_tensor).
    self.device = device
    # Aggregated op_name -> OpStat, merged over every compiled subgraph.
    self.op_stats = {}
    # False once any op's output dtype could not be inferred.
    self.is_complete = True
    # Total operator invocations seen so far.
    self.num_ops = 0
    # Operators whose dtype inference failed (meta and real-tensor paths).
    self.num_ops_misses_dtypes = 0
    # Number of subgraphs Dynamo has handed to __call__.
    self.subgraph_counter = 0
156+
157+
def get_output_dtype(self, out):
148158
if isinstance(out, torch.Tensor):
149159
return out.dtype
150160
if (
@@ -156,102 +166,160 @@ def get_output_dtype(out):
156166
else:
157167
return None
158168

159-
is_complete = True
160-
op_stats = {}
161-
node_outputs = {}
162-
for node in traced.graph.nodes:
169+
def get_op_name_and_func(self, node, node_outputs):
    """Resolve a human-readable op name and a callable for an FX node.

    Args:
        node: the torch.fx.Node to inspect.
        node_outputs: mapping of node name -> previously computed output,
            used to find the receiver object of a ``call_method`` node.

    Returns:
        (op_name, op_func): either may be None when resolution fails;
        callers treat a None op_name as "skip this node in the stats".
    """
    op_name = None
    op_func = None
    try:
        if node.op == "call_module":
            # Use the submodule's class name as the op name.
            # FIX: the original referenced an undefined `traced` here, so
            # every call_module node raised NameError (swallowed below) and
            # was silently dropped from the stats. The owning GraphModule
            # is reachable from the node's graph.
            submod = node.graph.owning_module.get_submodule(node.target)
            op_name = submod.__class__.__name__
            op_func = submod
        elif node.op == "call_function":
            op_name = node.target.__name__
            op_func = node.target
        elif node.op == "call_method":
            op_name = node.target
            # The receiver is the first positional argument; it may be a
            # node (look up its computed output) or a constant.
            self_obj = (
                node_outputs[node.args[0].name]
                if isinstance(node.args[0], torch.fx.Node)
                else node.args[0]
            )
            op_func = getattr(self_obj, node.target)
        elif node.op in ["get_attr", "placeholder", "output"]:
            op_name = node.op
    except Exception:
        # Resolution is best-effort; the caller handles (None, None).
        pass
    return op_name, op_func
194+
195+
def update_op_stats(self, op_stats, op_name, op_dtype):
    """Record one occurrence of `op_name` producing `op_dtype` in `op_stats`.

    A None op_name is ignored. The dtype is stored in its short string
    form (e.g. "float32", with the "torch." prefix stripped).
    """
    if op_name is None:
        return
    dtype_str = str(op_dtype).replace("torch.", "")
    stat = op_stats.get(op_name, None)
    if stat is None:
        # First sighting of this op: count starts at 1.
        op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
    else:
        stat.op_dtypes[dtype_str] = stat.op_dtypes.get(dtype_str, 0) + 1
        stat.count += 1
205+
206+
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
    """torch.compile backend entry: collect op stats for one FX subgraph.

    Interprets every node of `gm` on meta tensors (no real computation),
    records each op's name and output dtype, merges the per-subgraph stats
    into the executor-wide totals, and returns the unmodified eager
    `gm.forward` so the compiled model still produces real outputs.
    """
    # Use meta tensors as input to avoid actually running the model.
    meta_sample_inputs = convert_real_to_meta(sample_inputs)

    op_stats = {}
    num_ops_misses_dtypes = 0

    input_idx = 0
    node_outputs = {}
    for node in gm.graph.nodes:
        out = None
        op_dtype = None
        op_name, op_func = self.get_op_name_and_func(node, node_outputs)
        if node.op == "placeholder":
            # Placeholders consume the (meta) sample inputs in order.
            out = meta_sample_inputs[input_idx]
            input_idx += 1
        elif node.op in ["call_function", "call_module", "call_method"]:
            # FIX: pre-bind so the except-branch never sees unbound names
            # even when map_arg itself raises (e.g. a KeyError on a node
            # whose output was never stored).
            node_args, node_kwargs = (), {}
            try:
                node_args = torch.fx.map_arg(
                    node.args,
                    lambda n: node_outputs[n.name]
                    if isinstance(n, torch.fx.Node)
                    else n,
                )
                node_kwargs = torch.fx.map_arg(
                    node.kwargs,
                    lambda n: node_outputs[n.name]
                    if isinstance(n, torch.fx.Node)
                    else n,
                )
                if node.op == "call_method":
                    # First positional arg is the method receiver; op_func
                    # is already bound to it.
                    node_args = node_args[1:]

                # Ops that misbehave on meta tensors get dedicated resolvers.
                if op_name == "_native_multi_head_attention":
                    out = resolve_native_multi_head_attention(
                        *node_args, **node_kwargs
                    )
                elif op_name == "to":
                    out = resolve_tensor_to(
                        node_outputs[node.args[0].name], *node_args, **node_kwargs
                    )
                elif op_name == "item":
                    out = resolve_tensor_item(node_outputs[node.args[0].name])
                else:
                    assert op_func is not None, f"op_func of {node} is None."
                    out = op_func(*node_args, **node_kwargs)
            except Exception:
                # Meta execution failed; retry with real tensors on device.
                out = resolve_with_real_tensor(
                    op_func, self.device, node_args, node_kwargs
                )
                if out is None:
                    if num_ops_misses_dtypes == 0:
                        print(
                            f"dtype inference failed: node.op={node.op}, op_name={op_name}"
                        )
                    num_ops_misses_dtypes += 1
        elif node.op == "get_attr":
            # FIX: the original called resolve_get_attr(traced, node) with an
            # undefined `traced`, raising an uncaught NameError on any
            # get_attr node; the attribute lives on the compiled GraphModule.
            out = resolve_get_attr(gm, node)
        elif node.op == "output":
            pass
        else:
            assert False, f"node.op: {node.op}"

        if out is not None:
            node_outputs[node.name] = out
            op_dtype = self.get_output_dtype(out)

        # placeholder/output are bookkeeping nodes, not operators.
        if node.op not in ["placeholder", "output"]:
            self.update_op_stats(op_stats, op_name, op_dtype)

    if num_ops_misses_dtypes > 0:
        self.is_complete = False
    self.num_ops_misses_dtypes += num_ops_misses_dtypes
    # Merge this subgraph's stats into the executor-wide totals.
    num_ops = 0
    for name, stat in op_stats.items():
        num_ops += stat.count
        if name in self.op_stats.keys():
            self.op_stats[name].update(stat)
        else:
            self.op_stats[name] = stat
    self.num_ops += num_ops
    self.subgraph_counter += 1
    # Return the eager forward: stats collection is the only "compilation".
    return gm.forward
289+
290+
def summary(self):
    """Print totals accumulated over every subgraph seen by this backend."""
    print(
        f"Totally {self.subgraph_counter} subgraphs, {self.num_ops} operators, and {self.num_ops_misses_dtypes} operators failed to inference dtypes."
    )
247294

248295

249-
def collect_op_stats_with_make_fx(model, input_dict, arg_types):
296+
def collect_op_stats_with_compile(model, sample_inputs, device):
    """Collect per-op statistics by compiling `model` with a stats backend.

    Args:
        model: a torch.nn.Module to analyze.
        sample_inputs: positional inputs used to trigger Dynamo capture.
        device: device for the real-tensor fallback inside the backend.

    Returns:
        ("compile", is_complete, op_stats) — the method label, whether all
        output dtypes were inferred, and the op_name -> OpStat mapping.
    """
    assert isinstance(model, torch.nn.Module), f"{type(model)=}"
    meta_executor = GraphMetaExecutor(device)
    compiled_model = torch.compile(model, backend=meta_executor)
    # Running the compiled model triggers graph capture; the backend records
    # stats per subgraph and returns the eager forward for execution.
    compiled_model(*sample_inputs)
    meta_executor.summary()
    return "compile", meta_executor.is_complete, meta_executor.op_stats
303+
304+
305+
def collect_op_stats_manual(model, sample_inputs, device):
    """Collect op statistics via torch.fx symbolic tracing (no torch.compile).

    Returns (is_complete, op_stats), or (False, None) when the model cannot
    be symbolically traced.
    """
    try:
        # Trace the model into a single FX GraphModule.
        traced = torch.fx.symbolic_trace(model)
    except Exception:
        print("Failed to FX symbolic_trace")
        return False, None

    executor = GraphMetaExecutor(device)
    executor(traced, sample_inputs)
    executor.summary()
    return executor.is_complete, executor.op_stats
318+
319+
320+
def collect_op_stats_with_make_fx(model, sample_inputs):
250321
# Use meta tensors as input to avoid actually running the model
251-
meta_input_list = []
252-
for arg_name in arg_types.keys():
253-
x = input_dict[arg_name]
254-
meta_input_list.append(convert_real_to_meta(x))
322+
meta_input_list = convert_real_to_meta(sample_inputs)
255323

256324
try:
257325
# Generate FX Graph, and automatically fill in meta information
@@ -325,14 +393,19 @@ def collect_model_stats(model_path, device, log_prompt):
325393
model = model_class()
326394
arg_types = get_argument_types(model_class, "forward")
327395
input_dict = get_input_dict(model_path, device)
396+
ordered_input_list = [input_dict[arg_name] for arg_name in arg_types.keys()]
328397

329398
num_ops = 0
330399
num_outputs = 0
331400
ops_count_dict = {}
332401
op_dtypes = {}
333-
method, is_complete, op_stats = collect_op_stats(
334-
model, input_dict, arg_types, device
402+
method, is_complete, op_stats = collect_op_stats_with_compile(
403+
model, ordered_input_list, device
335404
)
405+
406+
# method, is_complete, op_stats = collect_op_stats(
407+
# model, input_dict, arg_types, device
408+
# )
336409
if op_stats is not None:
337410
for op_name, stat in sorted(op_stats.items()):
338411
if op_name == "placeholder":
@@ -474,5 +547,4 @@ def main(args):
474547
help="Log prompt for stats log filtering.",
475548
)
476549
args = parser.parse_args()
477-
print(f"[CollectStats Arguments] {args}")
478550
main(args=args)

0 commit comments

Comments
 (0)