@@ -86,6 +86,10 @@ def _is_env_enabled(env_var: str, default: bool = False) -> bool:
8686 def pybindings (cls ) -> bool :
8787 return cls ._is_env_enabled ("EXECUTORCH_BUILD_PYBIND" , default = False )
8888
89+ @classmethod
90+ def training (cls ) -> bool :
91+ return cls ._is_env_enabled ("EXECUTORCH_BUILD_TRAINING" , default = True )
92+
8993 @classmethod
9094 def llama_custom_ops (cls ) -> bool :
9195 return cls ._is_env_enabled ("EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT" , default = True )
@@ -575,6 +579,10 @@ def run(self):
575579 "-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" , # add quantized ops to pybindings.
576580 "-DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON" ,
577581 ]
582+ if ShouldBuild .training ():
583+ cmake_args += [
584+ "-DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON" ,
585+ ]
578586 build_args += ["--target" , "portable_lib" ]
579587 # To link backends into the portable_lib target, callers should
580588 # add entries like `-DEXECUTORCH_BUILD_XNNPACK=ON` to the CMAKE_ARGS
@@ -677,6 +685,14 @@ def get_ext_modules() -> List[Extension]:
677685 "_portable_lib.*" , "executorch.extension.pybindings._portable_lib"
678686 )
679687 )
688+ if ShouldBuild .training ():
689+
690+ ext_modules .append (
691+ # Install the prebuilt pybindings extension wrapper for training
692+ BuiltExtension (
693+ "_training_lib.*" , "executorch.extension.training.pybindings._training_lib"
694+ )
695+ )
680696 if ShouldBuild .llama_custom_ops ():
681697 ext_modules .append (
682698 BuiltFile (
0 commit comments