11# setup.py
22from setuptools import setup , find_packages
3- from torch .utils .cpp_extension import BuildExtension , CUDAExtension
4- import torch
53import os
64import platform
75
8- # Check if CUDA is available
9- if not torch .cuda .is_available ():
10- raise RuntimeError ("CUDA is not available. This package requires CUDA." )
6+ # Try to import torch and CUDA extension, but don't fail if not available
7+ try :
8+ import torch
9+ from torch .utils .cpp_extension import BuildExtension , CUDAExtension
10+ TORCH_AVAILABLE = True
11+ CUDA_AVAILABLE = torch .cuda .is_available ()
12+ except ImportError :
13+ TORCH_AVAILABLE = False
14+ CUDA_AVAILABLE = False
15+ # Dummy classes for when torch is not available
16+ class BuildExtension :
17+ pass
18+ class CUDAExtension :
19+ def __init__ (self , * args , ** kwargs ):
20+ pass
1121
1222# Platform-specific compiler arguments
1323def get_compiler_args ():
@@ -32,33 +42,26 @@ def get_compiler_args():
3242# Optionally set TORCH_CUDA_ARCH_LIST to restrict architectures, e.g. "8.6;8.0"
3343# os.environ['TORCH_CUDA_ARCH_LIST'] = "8.6" # adjust for your GPU if desired
3444
35- setup (
36- name = 'emd_ext' ,
37- version = '1.0.0' ,
38- description = 'Earth Mover Distance (EMD) CUDA extension for PyTorch' ,
39- long_description = open ('README.md' , 'r' , encoding = 'utf-8' ).read (),
40- long_description_content_type = 'text/markdown' ,
41- author = 'Haoqiang Fan, Kaichun Mo, Jiayuan Gu' ,
42- maintainer = 'hieulhaiwork' ,
43- url = 'https://github.com/hieulhaiwork/EMD-Pytorch' ,
44- packages = find_packages (),
45- install_requires = [
45+ # Determine if we should build CUDA extensions
46+ BUILD_CUDA = TORCH_AVAILABLE and CUDA_AVAILABLE and os .environ .get ('SKIP_CUDA_BUILD' , '0' ) != '1'
47+
48+ # Setup arguments
49+ setup_args = {
50+ 'name' : 'emd_ext' ,
51+ 'version' : '1.0.0' ,
52+ 'description' : 'Earth Mover Distance (EMD) CUDA extension for PyTorch' ,
53+ 'long_description' : open ('README.md' , 'r' , encoding = 'utf-8' ).read (),
54+ 'long_description_content_type' : 'text/markdown' ,
55+ 'author' : 'Haoqiang Fan, Kaichun Mo, Jiayuan Gu' ,
56+ 'maintainer' : 'hieulhaiwork' ,
57+ 'url' : 'https://github.com/hieulhaiwork/EMD-Pytorch' ,
58+ 'packages' : find_packages (),
59+ 'install_requires' : [
4660 'torch>=1.8.0' ,
4761 'numpy' ,
4862 ],
49- python_requires = '>=3.7' ,
50- ext_modules = [
51- CUDAExtension (
52- name = 'emd_cuda' ,
53- sources = [
54- 'emd/cuda/emd.cpp' ,
55- 'emd/cuda/emd_kernel.cu' ,
56- ],
57- extra_compile_args = get_compiler_args ()
58- ),
59- ],
60- cmdclass = {'build_ext' : BuildExtension },
61- classifiers = [
63+ 'python_requires' : '>=3.7' ,
64+ 'classifiers' : [
6265 'Development Status :: 4 - Beta' ,
6366 'Intended Audience :: Science/Research' ,
6467 'License :: OSI Approved :: MIT License' ,
@@ -74,4 +77,20 @@ def get_compiler_args():
7477 'Operating System :: POSIX :: Linux' ,
7578 'Operating System :: Microsoft :: Windows' ,
7679 ],
77- )
80+ }
81+
82+ # Only add CUDA extensions if available
83+ if BUILD_CUDA :
84+ setup_args ['ext_modules' ] = [
85+ CUDAExtension (
86+ name = 'emd_cuda' ,
87+ sources = [
88+ 'emd/cuda/emd.cpp' ,
89+ 'emd/cuda/emd_kernel.cu' ,
90+ ],
91+ extra_compile_args = get_compiler_args ()
92+ ),
93+ ]
94+ setup_args ['cmdclass' ] = {'build_ext' : BuildExtension }
95+
96+ setup (** setup_args )
0 commit comments