Skip to content

Commit 08f1a65

Browse files
authored
Namespace cleanup (#229)
Two main things happened: - I removed {wrap_key, PythonTensor, pythonkey_trace} from being public APIs - I moved all compilation related things to the functorch.compile namespace. This includes nnc_jit which is now in functorch.compile.nnc_jit Concerns: - nnc_jit was in the functorch namespace for a long time. Should we leave it there? Are there stakeholders to notify?
1 parent eeb80e5 commit 08f1a65

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

benchmarks/operator_authoring.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
import timeit
55
import torch
6-
from functorch import pointwise_operator
6+
from functorch.compile import pointwise_operator
77

88
WRITE_CSV = False
99
CUDA = False

functorch/__init__.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,24 @@
99
import textwrap
1010
from . import _C
1111

12+
# Top-level APIs. Please think carefully before adding something to the
13+
# top-level namespace:
14+
# - private helper functions should go into functorch._src
15+
# - very experimental things should go into functorch.experimental
16+
# - compilation related things should go into functorch.compile
17+
18+
# functorch transforms
1219
from ._src.vmap import vmap
1320
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
21+
from ._src.python_key import make_fx
22+
23+
# utilities. Maybe these should go in their own namespace in the future?
1424
from ._src.make_functional import (
1525
make_functional_with_buffers,
1626
make_functional,
1727
combine_state_for_ensemble,
1828
FunctionalModule,
1929
)
20-
from ._src.python_key import wrap_key, PythonTensor, pythonkey_trace, make_fx, nnc_jit, make_nnc
21-
from ._src.nnc_compile import nnc_compile, get_ops
22-
from ._src.eager_compilation import compiled_function, compiled_module, tvm_compile, draw_joint_graph, default_partition, partition_with_recompute_fwd_in_bwd
23-
from ._src.operator_authoring import pointwise_operator
24-
2530

2631
# Monkeypatching lol
2732
_old_cross_entropy = torch.nn.functional.cross_entropy

functorch/compile/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
1-
from .._src.operator_authoring import pointwise_operator
1+
from .._src.operator_authoring import pointwise_operator
2+
from .._src.python_key import nnc_jit, make_nnc
3+
from .._src.nnc_compile import nnc_compile, get_ops
4+
from .._src.eager_compilation import (
5+
compiled_function,
6+
compiled_module,
7+
tvm_compile,
8+
draw_joint_graph,
9+
default_partition,
10+
partition_with_recompute_fwd_in_bwd,
11+
)

test/test_pythonkey.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
import functorch
2424
from functorch import (
2525
grad, vjp, vmap, jacrev, grad_and_value,
26-
make_fx, nnc_jit, compiled_function, compiled_module,
26+
make_fx,
27+
)
28+
from functorch.compile import (
29+
nnc_jit, compiled_function, compiled_module,
2730
partition_with_recompute_fwd_in_bwd
2831
)
2932

0 commit comments

Comments
 (0)