Skip to content

Commit 07558f2

Browse files
committed
Support another method with make_fx.
1 parent 8161df7 commit 07558f2

File tree

1 file changed

+79
-10
lines changed

1 file changed

+79
-10
lines changed

graph_net/torch/collect_stats.py

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections import defaultdict
1111

1212
import torch
13-
from torch.fx.passes.shape_prop import ShapeProp
13+
from functorch import make_fx
1414
from graph_net.torch import utils
1515

1616

@@ -99,13 +99,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
9999
return out
100100

101101

102-
def collect_op_stats(model, input_dict):
102+
def collect_op_stats_manual(model, input_dict):
103103
try:
104104
# FX symbolic trace
105105
traced = torch.fx.symbolic_trace(model)
106106
# print(traced.graph)
107107
except Exception:
108-
print("Failed to FX symbolic trace")
108+
print("Failed to FX symbolic_trace")
109109
return False, None
110110

111111
# Use meta tensors as input to avoid actually running the model
@@ -136,7 +136,6 @@ def collect_op_stats(model, input_dict):
136136
)
137137

138138
try:
139-
# if True:
140139
if node.op == "call_module":
141140
# classname of module
142141
submod = traced.get_submodule(node.target)
@@ -200,23 +199,94 @@ def collect_op_stats(model, input_dict):
200199
return is_complete, op_stats
201200

202201

202+
def collect_op_stats_with_make_fx(model, input_dict, arg_types):
    """Collect per-op statistics by tracing ``model`` with functorch's make_fx.

    Args:
        model: Callable / ``torch.nn.Module`` to trace.
        input_dict: Maps forward-argument names to example input values.
        arg_types: Ordered mapping of forward-argument names; its key order
            defines the positional order used when calling the traced model.

    Returns:
        Tuple ``(is_complete, op_stats)``. ``is_complete`` is ``False`` when
        tracing failed or any non-output node lacked ``tensor_meta``.
        ``op_stats`` maps op name -> ``OpStat`` (``None`` when tracing failed).
    """
    # Use meta tensors as input to avoid actually running the model.
    meta_input_list = []
    for arg_name in arg_types.keys():
        x = input_dict[arg_name]
        meta_x = (
            torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
        )
        meta_input_list.append(meta_x)

    try:
        # Generate the FX graph; make_fx automatically fills in meta info.
        fx_model = make_fx(model)(*meta_input_list)
    except Exception:
        print("Failed to execute make_fx")
        return False, None

    is_complete = True
    op_stats = {}
    for node in fx_model.graph.nodes:
        op_name = None
        if node.op == "call_module":
            # classname of module.
            # Fix: look the submodule up on ``fx_model`` — the original code
            # referenced ``traced``, which is undefined in this function and
            # raised NameError on any graph containing a call_module node.
            submod = fx_model.get_submodule(node.target)
            op_name = submod.__class__.__name__
        elif node.op == "call_function":
            op_name = node.target.__name__
        elif node.op == "call_method":
            op_name = node.target
        elif node.op in ["placeholder", "output", "get_attr"]:
            op_name = node.op
        else:
            assert False, f"node.op: {node.op}"

        dtype = None
        if node.op != "output":
            if "tensor_meta" in node.meta:
                tensor_meta = node.meta["tensor_meta"]
                dtype = tensor_meta.dtype
                # print(f"node.op={node.op}, node.target={node.target}, dtype={tensor_meta.dtype}")
            else:
                print(
                    f"node.op={node.op}, node.target={node.target} has no tensor_meta!"
                )
                is_complete = False

        # Strip aten overload suffixes so e.g. "add.Tensor" counts as "add".
        op_name = (
            op_name.replace(".default", "")
            .replace(".Tensor", "")
            .replace(".Scalar", "")
        )
        dtype_str = str(dtype).replace("torch.", "")
        if op_stats.get(op_name, None) is None:
            op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
        else:
            op_stats[op_name].dtype.add(dtype_str)
            op_stats[op_name].count = op_stats[op_name].count + 1
    return is_complete, op_stats
260+
261+
262+
def collect_op_stats(model, input_dict, arg_types):
    """Collect op stats, trying symbolic_trace first, then falling back to make_fx.

    Returns ``(method, is_complete, op_stats)`` where ``method`` is
    ``"manual"`` or ``"make_fx"`` depending on which tracer produced
    the returned statistics.
    """
    manual_complete, manual_stats = collect_op_stats_manual(model, input_dict)
    if manual_complete:
        # The manual FX trace covered every node; no fallback needed.
        return "manual", manual_complete, manual_stats

    # Manual tracing was incomplete (or failed outright) — retry via make_fx.
    fx_complete, fx_stats = collect_op_stats_with_make_fx(
        model, input_dict, arg_types
    )
    # Prefer the make_fx result when it is complete, or when the manual
    # path produced nothing at all.
    if fx_complete or manual_stats is None:
        return "make_fx", fx_complete, fx_stats
    return "manual", manual_complete, manual_stats
271+
272+
203273
def collect_model_stats(model_path, device, log_prompt):
204274
model_class = load_class_from_file(
205275
os.path.join(model_path, "model.py"), "GraphModule"
206276
)
207277
model = model_class()
278+
arg_types = get_argument_types(model_class, "forward")
208279
input_dict = get_input_dict(model_path, device)
209280

210281
num_ops = 0
211-
num_inputs = 0
212282
num_outputs = 0
213283
ops_count_info = []
214284
dtypes = set()
215-
is_complete, op_stats = collect_op_stats(model, input_dict)
285+
method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types)
216286
if op_stats is not None:
217287
for op_name, stat in sorted(op_stats.items()):
218288
if op_name == "placeholder":
219-
num_inputs += stat.count
289+
pass
220290
elif op_name == "output":
221291
num_outputs += stat.count
222292
else:
@@ -226,8 +296,7 @@ def collect_model_stats(model_path, device, log_prompt):
226296
if v is not None:
227297
dtypes.add(v)
228298

229-
arg_types = get_argument_types(model_class, "forward")
230-
num_inputs = len(arg_types) if op_stats is None else num_inputs
299+
num_inputs = len(arg_types)
231300
num_params = 0
232301
param_dtypes = set()
233302
for name, arg_type in arg_types.items():
@@ -242,7 +311,7 @@ def collect_model_stats(model_path, device, log_prompt):
242311
dtypes_str = "[" + ",".join(dtypes) + "]"
243312
param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
244313
print(
245-
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete} ops:{ops_str}",
314+
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} method:{method} is_complete:{is_complete} ops:{ops_str}",
246315
flush=True,
247316
)
248317

0 commit comments

Comments
 (0)