
Commit ed6787a

add option for meta tensor tracing (#349)

* add option for meta tensor tracing
* Added meta tensor flag

1 parent 3fa6ea9 commit ed6787a
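
The change adds a pythonkey_meta() context manager (exported from functorch.compile below) that makes Python-key tracing hold its wrapped tensors on the 'meta' device, so graphs can be captured without running real kernels. A minimal usage sketch, assuming compiled_function(fn, fw_compiler, bw_compiler) traces on first call; the toy function and input shape are illustrative, not from the commit:

import torch
from functorch.compile import compiled_function, nop, pythonkey_meta

def f(x):
    return torch.sin(x).cos()

# While the flag is set, each PythonTensor stores its backing .elem on the
# 'meta' device, so tracing records ops without touching real data.
with pythonkey_meta():
    g = compiled_function(f, nop, nop)
    out = g(torch.randn(4, requires_grad=True))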

File tree

4 files changed, +153 −62 lines changed:

functorch/_src/aot_autograd.py
functorch/_src/compilers.py
functorch/_src/python_key.py
functorch/compile/__init__.py


functorch/_src/aot_autograd.py

Lines changed: 0 additions & 48 deletions

@@ -385,54 +385,6 @@ def clear_compile_cache():
     compile_cache.clear()
     compile_cache = None
 
-def tvm_compile(fx_module, example_inputs, name=None):
-    import tvm
-    from tvm import relay, auto_scheduler
-    from tvm.contrib import graph_executor
-    import os
-
-    jit_mod = torch.jit.script(fx_module)
-    # jit_mod = torch.jit.trace(fx_module, example_inputs)
-
-    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
-    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
-    target = tvm.target.Target("llvm -mcpu=core-avx2")
-    tasks, task_weights = auto_scheduler.extract_tasks(mod['main'], params, target)
-    for task in tasks:
-        print(task.compute_dag)
-    if name is None:
-        log_file = f'{time.time()}.json'
-    else:
-        log_file = f'{name}.json'
-    if len(tasks) != 0:
-        tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
-        if not os.path.exists(log_file):
-            tune_option = auto_scheduler.TuningOptions(
-                num_measure_trials=10000,  # change this to 20000 to achieve the best performance
-                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
-                # early_stopping=1000,
-                # verbose=2,
-            )
-            tuner.tune(tune_option)
-
-    dev = tvm.cpu(0)
-    with auto_scheduler.ApplyHistoryBest(log_file):
-        with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
-            lib = relay.build(mod, target=target, params=params)
-    dtype = "float32"
-    m = graph_executor.GraphModule(lib["default"](dev))
-    def exec_tvm(*args):
-        for idx, arg in enumerate(args, 0):
-            if arg.dim() != 0:
-
-                m.set_input(f"inp_{idx}", tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(arg)))
-        m.run()
-        outs = [torch.utils.dlpack.from_dlpack(m.get_output(i).to_dlpack()) for i in range(m.get_num_outputs())]
-        return outs
-    return exec_tvm
-
-def tvm_function(fn, name):
-    return compiled_function(fn, partial(tvm_compile, name=f'fw_{name}'), partial(tvm_compile, name=f'bw_{name}'))
 
 def compiled_module(mod, *args, **kwargs):
     func_mod, params, buffers = make_functional_with_buffers(mod)

functorch/_src/compilers.py

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
+import torch
+from functools import partial
+from .aot_autograd import draw_graph
+import time
+
+def ts_compile(fx_g, _):
+    for node in fx_g.graph.nodes:
+        if node.target == torch.ops.aten.new_zeros:
+            if node.args[1] == []:
+                args = list(node.args)
+                args[1] = [1]
+                node.args = tuple(args)
+    fx_g.graph.lint()
+    # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
+    for i in range(1000):
+        attr = f'_tensor_constant{i}'
+        if hasattr(fx_g, attr):
+            setattr(fx_g, attr, getattr(fx_g, attr).cuda())
+        else:
+            break
+
+    fx_g.recompile()
+    f = torch.jit.script(fx_g)
+
+    # Works around alias analysis issues in TS
+    graph = f.graph
+    outputs = list(graph.outputs())
+    output = outputs[0]
+    graph.eraseOutput(0)
+    outputs = list(output.node().inputs())
+    for inp in output.node().inputs():
+        graph.registerOutput(inp)
+    output.node().destroy()
+    torch._C._jit_pass_remove_mutation(graph)
+    for i in range(len(list(graph.outputs()))):
+        graph.eraseOutput(0)
+    node = graph.create("prim::ListConstruct", outputs)
+    graph.appendNode(node)
+    node.output().setType(torch._C.ListType.ofTensors())
+    graph.registerOutput(node.output())
+    torch._C._jit_pass_remove_mutation(f.graph)
+
+    f = torch.jit.freeze(f.eval())
+    f = torch.jit.optimize_for_inference(f)
+    return f
+
+def _draw_graph_compile(fx_g, _, name):
+    draw_graph(fx_g, name)
+    return fx_g
+
+def draw_graph_compile(name):
+    return partial(_draw_graph_compile, name=name)
+
+def _tvm_compile(fx_module, example_inputs, name=None):
+    import tvm
+    from tvm import relay, auto_scheduler
+    from tvm.contrib import graph_executor
+    import os
+
+    jit_mod = torch.jit.script(fx_module)
+    # jit_mod = torch.jit.trace(fx_module, example_inputs)
+
+    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
+    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
+    target = tvm.target.Target("llvm -mcpu=core-avx2")
+    tasks, task_weights = auto_scheduler.extract_tasks(mod['main'], params, target)
+    for task in tasks:
+        print(task.compute_dag)
+    if name is None:
+        log_file = f'{time.time()}.json'
+    else:
+        log_file = f'{name}.json'
+    if len(tasks) != 0:
+        tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+        if not os.path.exists(log_file):
+            tune_option = auto_scheduler.TuningOptions(
+                num_measure_trials=10000,  # change this to 20000 to achieve the best performance
+                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+                # early_stopping=1000,
+                # verbose=2,
+            )
+            tuner.tune(tune_option)
+
+    dev = tvm.cpu(0)
+    with auto_scheduler.ApplyHistoryBest(log_file):
+        with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
+            lib = relay.build(mod, target=target, params=params)
+    dtype = "float32"
+    m = graph_executor.GraphModule(lib["default"](dev))
+    def exec_tvm(*args):
+        for idx, arg in enumerate(args, 0):
+            if arg.dim() != 0:
+
+                m.set_input(f"inp_{idx}", tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(arg)))
+        m.run()
+        outs = [torch.utils.dlpack.from_dlpack(m.get_output(i).to_dlpack()) for i in range(m.get_num_outputs())]
+        return outs
+    return exec_tvm
+
+def tvm_compile(name):
+    return partial(_tvm_compile, name=name)
+
+def nop(f, _):
+    print(f.code)
+    return f

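The new module collects reusable backends for the AOT pipeline: ts_compile (TorchScript with NVFuser workarounds), draw_graph_compile (dump the FX graph), _tvm_compile/tvm_compile (moved here from aot_autograd.py), and nop. A sketch of how they plug into compiled_function, whose (fn, fw_compiler, bw_compiler) signature is taken from the removed tvm_function helper; note that ts_compile moves tensor constants to .cuda(), so it assumes a CUDA build:

import torch
from functorch.compile import compiled_function, ts_compile, draw_graph_compile

def f(x):
    return (x * x).sum()

# TorchScript-based backend for both the forward and backward graphs.
g = compiled_function(f, ts_compile, ts_compile)
g(torch.randn(8, device='cuda', requires_grad=True)).backward()

# Or dump the captured forward/backward graphs under the given names
# instead of compiling them.
h = compiled_function(f, draw_graph_compile('fw_f'), draw_graph_compile('bw_f'))
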
functorch/_src/python_key.py

Lines changed: 46 additions & 12 deletions

@@ -22,6 +22,7 @@
 
 
 USE_DECOMPOSE = False
+USE_META = False
 
 @contextmanager
 def pythonkey_decompose():
@@ -32,22 +33,38 @@ def pythonkey_decompose():
     finally:
         USE_DECOMPOSE = False
 
+
+@contextmanager
+def pythonkey_meta():
+    global USE_META
+    USE_META = True
+    try:
+        yield USE_META
+    finally:
+        USE_META = False
+
 class PythonTensor(torch.Tensor):
     elem: torch.Tensor
 
     __slots__ = ['elem', 'proxy']
 
     @staticmethod
-    def __new__(cls, elem, proxy):
+    def __new__(cls, elem, proxy, device=None):
         # The wrapping tensor (PythonTensor) is just a meta tensor, so it
         # doesn't hold any memory (meta tensor is generally the preferred type
         # of tensor you want to make a subclass from)...
-        meta = elem.new_empty((0,))
-        meta.set_(meta.storage(), 0, elem.size(), elem.stride())
-        r = torch.Tensor._make_subclass(cls, meta, elem.requires_grad)
+
+        r = torch.Tensor._make_wrapper_subclass(
+            cls, elem.size(),
+            strides=elem.stride(), storage_offset=elem.storage_offset(),
+            dtype=elem.dtype, layout=elem.layout, requires_grad=elem.requires_grad,
+            device=(elem.device if device is None else device),
+        )
 
         # ...the real tensor is held as an element on the tensor.
         r.elem = elem
+        if USE_META:
+            r.elem = r.elem.to('meta')
         r.proxy = proxy
         return r
 
@@ -59,28 +76,45 @@ def __repr__(self):
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         if func in decomposition_table and USE_DECOMPOSE:
             return decomposition_table[func](*args, **kwargs)
+
         def unwrap_proxy(e):
             return e.proxy if isinstance(e, PythonTensor) else e
 
         def unwrap_tensor(e):
             return e.elem if isinstance(e, PythonTensor) else e
+
+        # Used to infer the output device
+        input_devices = list(set([i.device for i in pytree.tree_flatten(args)[0] + pytree.tree_flatten(kwargs)[0] if isinstance(i, PythonTensor)]))
+        assert len(input_devices) == 1
+        output_device = input_devices[0]
         proxy_args = pytree.tree_map(unwrap_proxy, args)
         proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
         proxy_out = func(*proxy_args, **proxy_kwargs)
-        real_out = func(*pytree.tree_map(unwrap_tensor, args), **pytree.tree_map(unwrap_tensor, kwargs))
-
-        def wrap_with_proxy(e, idx):
+        args = pytree.tree_map(unwrap_tensor, args)
+        kwargs = pytree.tree_map(unwrap_tensor, kwargs)
+        try:
+            real_out = func(*args, **kwargs)
+        except NotImplementedError as e:
+            args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device) if isinstance(x, torch.Tensor) else x, args)
+            kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device) if isinstance(x, torch.Tensor) else x, kwargs)
+            real_out = func(*args, **kwargs)
+
+        def wrap_with_proxy(e, proxy):
             # Some ops (like native_batch_norm_backward) return undefined tensors that get converted into None in python.
             # As the function signature expects tensors, if we directly return these None tensors back to C++, we'll error.
             if e is None:
-                return PythonTensor(torch.empty(()), proxy_out[idx])
-            return PythonTensor(e, proxy_out[idx]) if type(e) == torch.Tensor else e
+                e = torch.empty(())
+            # Currently assuming that all inputs to an op are the same device - not totally sure that's true
+            if type(e) == torch.Tensor:
+                return PythonTensor(e, proxy, output_device)
+            else:
+                return e
         if isinstance(real_out, tuple):
-            return tuple([wrap_with_proxy(e, idx) for idx, e in enumerate(real_out)])
+            return tuple([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
         elif isinstance(real_out, list):
-            return list([wrap_with_proxy(e, idx) for idx, e in enumerate(real_out)])
+            return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
         elif isinstance(real_out, torch.Tensor):
-            return PythonTensor(real_out, proxy_out)
+            return PythonTensor(real_out, proxy_out, output_device)
         else:
             return real_out
 
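The device bookkeeping above is needed because, under pythonkey_meta(), the "real" computation runs on meta tensors, which carry shape, dtype, and stride metadata but no storage; when an op has no meta implementation, the NotImplementedError fallback re-runs it on torch.ones_like dummies placed on the inferred output device. A standalone illustration of the meta-device behavior (plain PyTorch, not commit code):

import torch

# Ops on meta tensors only propagate metadata; no arithmetic happens and
# no memory is allocated for the data.
a = torch.ones(4, 8, device='meta')
b = (a + a).reshape(8, 4)
print(b.shape, b.device)  # torch.Size([8, 4]) meta
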
functorch/compile/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 from .._src.operator_authoring import pointwise_operator
 from .._src.memory_efficient_op_authoring import memory_efficient_pointwise_fusion, torchscript_nvfuser_compile
-from .._src.python_key import nnc_jit, make_nnc, pythonkey_decompose
+from .._src.python_key import nnc_jit, make_nnc, pythonkey_decompose, pythonkey_meta
 from .._src.decompositions import register_decomposition, decomposition_table
 from .._src.nnc_compile import nnc_compile, get_ops
 from .._src.fx_minifier import minimizer
@@ -9,11 +9,11 @@
     aot_module,
     compiled_function,
     compiled_module,
-    tvm_compile,
     draw_joint_graph,
     default_partition,
     partition_with_recompute_fwd_in_bwd,
     num_of_recompilations,
     clear_compile_cache,
     draw_graph,
 )
+from .._src.compilers import ts_compile, tvm_compile, draw_graph_compile, nop
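
With these exports, tvm_compile is now a factory that binds the tuning-log name rather than a compiler itself. A hedged sketch of the updated call pattern; running it requires TVM to be installed:

import torch
from functorch.compile import compiled_function, tvm_compile

def f(x):
    return x.tanh()

# tvm_compile(name) returns a partial over _tvm_compile with the log-file
# name bound, so it is called with the name rather than passed bare.
g = compiled_function(f, tvm_compile('fw_f'), tvm_compile('bw_f'))
out = g(torch.randn(4))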
