Skip to content

Commit 4f69429

Browse files
authored
update compile example imports (#834)
1 parent 21b2394 commit 4f69429

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

examples/compilation/fuse_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import timeit
2-
from functorch import compiled_module, tvm_compile
2+
from functorch.compile import compiled_module, tvm_compile
33
import torch.nn as nn
44
import torch
55
from functools import partial
@@ -9,8 +9,8 @@ def nop(f, _):
99
return f
1010

1111

12-
fw_compiler = partial(tvm_compile, name='fw_keops')
13-
bw_compiler = partial(tvm_compile, name='bw_keops')
12+
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
13+
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')
1414
fw_compiler = nop
1515
bw_compiler = nop
1616

examples/compilation/linear_train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from functorch import nnc_jit, make_functional
7+
from functorch import make_functional
8+
from functorch.compile import nnc_jit
89
import torch
910
import torch.nn as nn
1011
import time

examples/compilation/simple_function.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from functorch import grad, nnc_jit, make_fx, make_nnc
7+
from functorch import grad, make_fx
8+
from functorch.compile import nnc_jit
89
import torch
910
import time
1011

@@ -16,9 +17,7 @@ def f(x):
1617
inp = torch.randn(100)
1718
grad_pt = grad(f)
1819
grad_fx = make_fx(grad_pt)(inp)
19-
grad_nnc = nnc_jit(grad_pt, skip_specialization=True)
20-
loopnest = make_nnc(grad_pt)(inp)
21-
print(loopnest)
20+
grad_nnc = nnc_jit(grad_pt)
2221

2322

2423
def bench(name, f, iters=10000, warmup=3):

0 commit comments

Comments
 (0)