Skip to content

Commit d10dcc3

Browse files
committed
Fix several problems.
1 parent a3fb5ae commit d10dcc3

File tree

1 file changed

+69
-22
lines changed

1 file changed

+69
-22
lines changed

graph_net/torch/collect_stats.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def resolve_native_multi_head_attention(*args, **kwargs):
7979

8080

8181
def resolve_tensor_to(tensor, *args, **kwargs):
82-
if isinstance(args[0], torch.dtype):
82+
if len(args) > 0 and isinstance(args[0], torch.dtype):
8383
dtype = args[0]
8484
else:
8585
dtype = tensor.dtype
@@ -99,7 +99,40 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
9999
return out
100100

101101

102-
def collect_op_stats_manual(model, input_dict):
102+
def convert_real_to_meta(x):
    """Recursively replace real tensors inside ``x`` with meta tensors.

    Args:
        x: A tensor, a (possibly nested) list / tuple / dict of values, or
            any other object.

    Returns:
        A structure of the same layout as ``x``. Every non-meta
        ``torch.Tensor`` is replaced by ``torch.empty_like(..., device="meta")``;
        tensors that are already meta and all non-container, non-tensor
        values are returned unchanged. Lists and tuples keep their concrete
        type; dicts are rebuilt as plain ``dict`` objects.
    """
    if isinstance(x, torch.Tensor):
        # Tensors that already live on the meta device are passed through.
        return x if x.is_meta else torch.empty_like(x, device="meta")
    if isinstance(x, (list, tuple)):
        converted = [convert_real_to_meta(v) for v in x]
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # namedtuple constructors take one positional argument per
            # field, so a single iterable argument would raise TypeError.
            return type(x)(*converted)
        return type(x)(converted)
    if isinstance(x, dict):
        return {k: convert_real_to_meta(v) for k, v in x.items()}
    return x
111+
112+
113+
def convert_meta_to_real(x, device):
    """Recursively replace meta tensors inside ``x`` with real tensors.

    Args:
        x: A tensor, a (possibly nested) list / tuple / dict of values, or
            any other object.
        device: Device on which the real tensors are allocated; forwarded
            to ``torch.empty_like``.

    Returns:
        A structure of the same layout as ``x``. Every meta ``torch.Tensor``
        is replaced by ``torch.empty_like(..., device=device)``; non-meta
        tensors and all non-container, non-tensor values are returned
        unchanged. Lists and tuples keep their concrete type; dicts are
        rebuilt as plain ``dict`` objects.

    Note:
        The allocated tensors come from ``torch.empty_like`` and are
        therefore uninitialized; only their shape/dtype/device are
        meaningful.
    """
    if isinstance(x, torch.Tensor):
        # Tensors that are already real are passed through untouched.
        return torch.empty_like(x, device=device) if x.is_meta else x
    if isinstance(x, (list, tuple)):
        converted = [convert_meta_to_real(v, device) for v in x]
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # namedtuple constructors take one positional argument per
            # field, so a single iterable argument would raise TypeError.
            return type(x)(*converted)
        return type(x)(converted)
    if isinstance(x, dict):
        return {k: convert_meta_to_real(v, device) for k, v in x.items()}
    return x
122+
123+
124+
def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs):
    """Run ``op_func`` on real tensors and return its output as meta tensors.

    The meta tensors in ``meta_args`` / ``meta_kwargs`` are materialized on
    ``device`` via ``convert_meta_to_real``, ``op_func`` is called on them,
    and the result is mapped back to the meta device via
    ``convert_real_to_meta``.

    Args:
        op_func: Callable to execute.
        device: Device on which the temporary real tensors are allocated.
        meta_args: Positional arguments for ``op_func`` (may contain meta
            tensors, possibly nested in lists / tuples / dicts).
        meta_kwargs: Keyword arguments for ``op_func`` (same conventions).

    Returns:
        The output of ``op_func`` with real tensors replaced by meta
        tensors, or ``None`` if any step raised an exception.
    """
    try:
        real_args = convert_meta_to_real(meta_args, device)
        real_kwargs = convert_meta_to_real(meta_kwargs, device)

        # Only the output's metadata is kept, so autograd tracking is
        # unnecessary. Disabling it avoids building a graph and lets
        # in-place ops on leaf tensors that require grad run instead of
        # raising.
        with torch.no_grad():
            real_out = op_func(*real_args, **real_kwargs)
        return convert_real_to_meta(real_out)
    except Exception:
        # Deliberate best-effort: the caller treats None as "could not
        # resolve" and handles reporting itself.
        return None
133+
134+
135+
def collect_op_stats_manual(model, input_dict, device):
103136
try:
104137
# FX symbolic trace
105138
traced = torch.fx.symbolic_trace(model)
@@ -109,11 +142,19 @@ def collect_op_stats_manual(model, input_dict):
109142
return False, None
110143

111144
# Use meta tensors as input to avoid actually running the model
112-
meta_input_dict = {}
113-
for name, x in input_dict.items():
114-
meta_input_dict[name] = (
115-
torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
116-
)
145+
meta_input_dict = convert_real_to_meta(input_dict)
146+
147+
def get_output_dtype(out):
148+
if isinstance(out, torch.Tensor):
149+
return out.dtype
150+
if (
151+
isinstance(out, (list, tuple))
152+
and len(out) > 0
153+
and isinstance(out[0], torch.Tensor)
154+
):
155+
return out[0].dtype
156+
else:
157+
return None
117158

118159
is_complete = True
119160
op_stats = {}
@@ -157,6 +198,7 @@ def collect_op_stats_manual(model, input_dict):
157198
if op_name == "_native_multi_head_attention":
158199
out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
159200
elif op_name == "to":
201+
# print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
160202
out = resolve_tensor_to(
161203
node_outputs[node.args[0].name], *node_args, **node_kwargs
162204
)
@@ -165,26 +207,30 @@ def collect_op_stats_manual(model, input_dict):
165207
else:
166208
out = op_func(*node_args, **node_kwargs)
167209
node_outputs[node.name] = out
168-
dtype = out.dtype if isinstance(out, torch.Tensor) else None
210+
dtype = get_output_dtype(out)
169211
except Exception:
170-
print(f"dtype inference failed: node.op={node.op}, op_name={op_name}")
171-
node_outputs[node.name] = None
172-
is_complete = False
212+
out = resolve_with_real_tensor(op_func, device, node_args, node_kwargs)
213+
node_outputs[node.name] = out
214+
if out is not None:
215+
dtype = get_output_dtype(out)
216+
else:
217+
print(
218+
f"dtype inference failed: node.op={node.op}, op_name={op_name}"
219+
)
220+
is_complete = False
173221
elif node.op == "get_attr":
174222
op_name = node.op
175223
out = resolve_get_attr(traced, node)
176224
node_outputs[node.name] = out
177-
dtype = out.dtype if isinstance(out, torch.Tensor) else None
225+
dtype = get_output_dtype(out)
178226
elif node.op == "output":
179227
op_name = node.op
180228
node_args = torch.fx.map_arg(
181229
node.args,
182230
lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
183231
)
184232
node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args
185-
dtype = (
186-
node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None
187-
)
233+
dtype = get_output_dtype(node_args[0])
188234
else:
189235
assert False, f"node.op: {node.op}"
190236

@@ -205,10 +251,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
205251
meta_input_list = []
206252
for arg_name in arg_types.keys():
207253
x = input_dict[arg_name]
208-
meta_x = (
209-
torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
210-
)
211-
meta_input_list.append(meta_x)
254+
meta_input_list.append(convert_real_to_meta(x))
212255

213256
try:
214257
# Generate FX Graph, and automatically fill in meta information
@@ -262,8 +305,10 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
262305
return is_complete, op_stats
263306

264307

265-
def collect_op_stats(model, input_dict, arg_types):
266-
is_complete_manual, op_stats_manual = collect_op_stats_manual(model, input_dict)
308+
def collect_op_stats(model, input_dict, arg_types, device):
309+
is_complete_manual, op_stats_manual = collect_op_stats_manual(
310+
model, input_dict, device
311+
)
267312
if not is_complete_manual:
268313
is_complete_make_fx, op_stats_make_fx = collect_op_stats_with_make_fx(
269314
model, input_dict, arg_types
@@ -285,7 +330,9 @@ def collect_model_stats(model_path, device, log_prompt):
285330
num_outputs = 0
286331
ops_count_dict = {}
287332
op_dtypes = {}
288-
method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types)
333+
method, is_complete, op_stats = collect_op_stats(
334+
model, input_dict, arg_types, device
335+
)
289336
if op_stats is not None:
290337
for op_name, stat in sorted(op_stats.items()):
291338
if op_name == "placeholder":

0 commit comments

Comments
 (0)