Skip to content

Commit 612e0c1

Browse files
committed
changes
1 parent a3537af commit 612e0c1

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -796,15 +796,19 @@ if(EXECUTORCH_BUILD_PYBIND)
796796

797797
set(_pybind_training_dep_libs
798798
${TORCH_PYTHON_LIBRARY}
799-
bundled_program
800799
etdump
801800
executorch
802-
extension_data_loader
803801
util
804802
torch
805803
extension_training
806804
)
807805

806+
if(EXECUTORCH_BUILD_XNNPACK)
807+
# need to explicitly specify XNNPACK and microkernels-prod
808+
# here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu
809+
list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod)
810+
endif()
811+
808812
# pybind training
809813
pybind11_add_module(_training_lib SHARED extension/training/pybindings/_training_lib.cpp)
810814

@@ -816,7 +820,7 @@ if(EXECUTORCH_BUILD_PYBIND)
816820
target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs})
817821

818822
install(TARGETS _training_lib
819-
LIBRARY DESTINATION extension/training/pybindings
823+
LIBRARY DESTINATION executorch/extension/training/pybindings
820824
)
821825
endif()
822826
endif()

install_executorch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def clean():
3232
print("Done cleaning build artifacts.")
3333

3434

35-
VALID_PYBINDS = ["coreml", "mps", "xnnpack"]
35+
VALID_PYBINDS = ["coreml", "mps", "xnnpack", "training"]
3636

3737

3838
def main(args):
@@ -78,8 +78,11 @@ def main(args):
7878
raise Exception(
7979
f"Unrecognized pybind argument {pybind_arg}; valid options are: {', '.join(VALID_PYBINDS)}"
8080
)
81-
EXECUTORCH_BUILD_PYBIND = "ON"
82-
CMAKE_ARGS += f" -DEXECUTORCH_BUILD_{pybind_arg.upper()}=ON"
81+
if pybind_arg == "training":
82+
CMAKE_ARGS += " -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON"
83+
else:
84+
EXECUTORCH_BUILD_PYBIND = "ON"
85+
CMAKE_ARGS += f" -DEXECUTORCH_BUILD_{pybind_arg.upper()}=ON"
8386

8487
if args.clean:
8588
clean()

setup.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def _is_env_enabled(env_var: str, default: bool = False) -> bool:
8686
def pybindings(cls) -> bool:
8787
return cls._is_env_enabled("EXECUTORCH_BUILD_PYBIND", default=False)
8888

89+
@classmethod
90+
def training(cls) -> bool:
91+
return cls._is_env_enabled("EXECUTORCH_BUILD_TRAINING", default=True)
92+
8993
@classmethod
9094
def llama_custom_ops(cls) -> bool:
9195
return cls._is_env_enabled("EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT", default=True)
@@ -575,6 +579,10 @@ def run(self):
575579
"-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON", # add quantized ops to pybindings.
576580
"-DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON",
577581
]
582+
if ShouldBuild.training():
583+
cmake_args += [
584+
"-DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON",
585+
]
578586
build_args += ["--target", "portable_lib"]
579587
# To link backends into the portable_lib target, callers should
580588
# add entries like `-DEXECUTORCH_BUILD_XNNPACK=ON` to the CMAKE_ARGS
@@ -677,6 +685,14 @@ def get_ext_modules() -> List[Extension]:
677685
"_portable_lib.*", "executorch.extension.pybindings._portable_lib"
678686
)
679687
)
688+
if ShouldBuild.training():
689+
690+
ext_modules.append(
691+
# Install the prebuilt pybindings extension wrapper for training
692+
BuiltExtension(
693+
"_training_lib.*", "executorch.extension.training.pybindings._training_lib"
694+
)
695+
)
680696
if ShouldBuild.llama_custom_ops():
681697
ext_modules.append(
682698
BuiltFile(

0 commit comments

Comments
 (0)