1313from build_tools import setup_helpers
1414from setuptools import setup , find_packages
1515
16+ import glob
17+ from torch .utils .cpp_extension import (
18+ CppExtension ,
19+ BuildExtension ,
20+ )
21+
22+
1623
1724def _get_pytorch_version ():
1825 if "PYTORCH_VERSION" in os .environ :
@@ -60,6 +67,50 @@ def _run_cmd(cmd):
6067 return None
6168
6269
70+ def get_extensions ():
71+ extension = CppExtension
72+
73+ extra_link_args = []
74+ extra_compile_args = {"cxx" : [
75+ "-O3" ,
76+ "-std=c++14" ,
77+ "-fdiagnostics-color=always" ,
78+ ]}
79+ debug_mode = os .getenv ('DEBUG' , '0' ) == '1'
80+ if debug_mode :
81+ print ("Compiling in debug mode" )
82+ extra_compile_args = {
83+ "cxx" : [
84+ "-O0" ,
85+ "-fno-inline" ,
86+ "-g" ,
87+ "-std=c++14" ,
88+ "-fdiagnostics-color=always" ,
89+ ]}
90+ extra_link_args = ["-O0" , "-g" ]
91+
92+ this_dir = os .path .dirname (os .path .abspath (__file__ ))
93+ extensions_dir = os .path .join (this_dir , "torchrl" , "csrc" )
94+
95+ extension_sources = set (
96+ os .path .join (extensions_dir , p )
97+ for p in glob .glob (os .path .join (extensions_dir , "*.cpp" ))
98+ )
99+ sources = list (extension_sources )
100+
101+ ext_modules = [
102+ extension (
103+ "torchrl._torchrl" ,
104+ sources ,
105+ include_dirs = [this_dir ],
106+ extra_compile_args = extra_compile_args ,
107+ extra_link_args = extra_link_args ,
108+ )
109+ ]
110+
111+ return ext_modules
112+
113+
63114def _main ():
64115 pytorch_package_dep = _get_pytorch_version ()
65116 print ("-- PyTorch dependency:" , pytorch_package_dep )
@@ -71,10 +122,10 @@ def _main():
71122 version = "0.1" ,
72123 author = "torchrl contributors" ,
7312474- packages = _get_packages (),
75- ext_modules = setup_helpers . get_ext_modules (),
125+ packages = find_packages (),
126+ ext_modules = get_extensions (),
76127 cmdclass = {
77- "build_ext" : setup_helpers . CMakeBuild ,
128+ "build_ext" : BuildExtension . with_options ( no_python_abi_suffix = True ) ,
78129 "clean" : clean ,
79130 },
80131 install_requires = [pytorch_package_dep , "numpy" , "tensorboard" , "packaging" ],
0 commit comments