-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path setup.py
More file actions
31 lines (27 loc) · 974 Bytes
/
setup.py
File metadata and controls
31 lines (27 loc) · 974 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""Build script for the ``spix_paper`` package and its ``spix_paper_cuda`` extension.

The CUDA/C++ sources under ``spix_paper/csrc`` are compiled with PyTorch's
``CUDAExtension`` and built through ``BuildExtension``.
"""
import os
import shutil

# The environment below is configured before importing torch's cpp_extension
# helpers so it is in place before any extension build starts.
#
# Route nvcc through ccache only when ccache is actually on PATH, and never
# clobber a value the user already exported: unconditionally forcing
# "ccache nvcc" makes the build fail outright on machines without ccache.
if shutil.which("ccache") is not None:
    os.environ.setdefault("PYTORCH_NVCC", "ccache nvcc")
# Ask torch.utils.cpp_extension to skip nvcc dependency-file generation
# (presumably to keep ccache effective — confirm); a user-set value wins.
os.environ.setdefault("TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES", "1")

from setuptools import setup, find_packages  # noqa: E402  (must follow env setup)
from torch.utils.cpp_extension import BuildExtension, CUDAExtension  # noqa: E402

setup(
    name="spix_paper",
    # NOTE(review): ``py_modules`` expects a top-level ./spix_paper.py while
    # ``packages`` also discovers a spix_paper/ package; confirm both are intended.
    py_modules=["spix_paper"],
    install_requires=[],
    package_dir={"": "."},
    packages=find_packages("."),
    # Ship any prebuilt shared objects found inside the packages.
    package_data={"": ["*.so"]},
    include_package_data=True,
    ext_modules=[
        CUDAExtension(
            "spix_paper_cuda",
            [
                # -- pairwise distance --
                "spix_paper/csrc/pwd/pair_wise_distance_cuda_source.cu",
                # -- apis --
                "spix_paper/csrc/sna/attn_reweight.cu",
                "spix_paper/csrc/sna/gather_sims.cu",
                # -- pybind --
                "spix_paper/csrc/pybind.cpp",
            ],
            # Host code: debug symbols, warnings suppressed; device code: -O3.
            extra_compile_args={"cxx": ["-g", "-w"], "nvcc": ["-O3", "-w"]},
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)