Skip to content

Commit 2c0689a

Browse files
committed
Fix symbol errors on Windows
1 parent 2483d9e commit 2c0689a

File tree

3 files changed

+81
-56
lines changed

3 files changed

+81
-56
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
run: |
3333
python ci_check.py
3434
35-
build-test:
35+
package-test:
3636
runs-on: ubuntu-latest
3737
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
3838

@@ -44,16 +44,22 @@ jobs:
4444
with:
4545
python-version: 3.9
4646

47-
- name: Install build dependencies
47+
- name: Install basic dependencies
4848
run: |
4949
python -m pip install --upgrade pip
50-
pip install build wheel setuptools
50+
pip install setuptools wheel
5151
52-
- name: Build package
52+
- name: Validate package metadata
5353
run: |
54-
python -m build
54+
python setup.py check --metadata --strict
5555
56-
- name: Check package
56+
- name: Test package structure
5757
run: |
58-
pip install twine
59-
twine check dist/*
58+
python -c "
59+
import os
60+
assert os.path.exists('emd/'), 'emd directory missing'
61+
assert os.path.exists('emd/__init__.py'), 'emd/__init__.py missing'
62+
assert os.path.exists('emd/emd.py'), 'emd/emd.py missing'
63+
assert os.path.exists('emd/cuda/emd.cpp'), 'CUDA sources missing'
64+
print('Package structure validated!')
65+
"

ci_check.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ def check_python_version():
1414
print(f" Current: {version.major}.{version.minor}.{version.micro}")
1515

1616
if version >= (3, 7):
17-
print(" Compatible")
17+
print(" [OK] Compatible")
1818
return True
1919
else:
20-
print(" Requires Python 3.7+")
20+
print(" [FAIL] Requires Python 3.7+")
2121
return False
2222

2323
def check_platform():
@@ -28,10 +28,10 @@ def check_platform():
2828
print(f" Architecture: {platform.machine()}")
2929

3030
if system in ['Windows', 'Linux', 'Darwin']:
31-
print(" Supported platform")
31+
print(" [OK] Supported platform")
3232
return True
3333
else:
34-
print(" ⚠️ Untested platform")
34+
print(" [WARN] Untested platform")
3535
return False
3636

3737
def check_basic_imports():
@@ -41,17 +41,17 @@ def check_basic_imports():
4141
try:
4242
import torch
4343
print(f" PyTorch version: {torch.__version__}")
44-
print(" PyTorch import successful")
44+
print(" [OK] PyTorch import successful")
4545
except ImportError:
46-
print(" PyTorch not available")
46+
print(" [FAIL] PyTorch not available")
4747
return False
4848

4949
try:
5050
import numpy
5151
print(f" NumPy version: {numpy.__version__}")
52-
print(" NumPy import successful")
52+
print(" [OK] NumPy import successful")
5353
except ImportError:
54-
print(" NumPy not available")
54+
print(" [FAIL] NumPy not available")
5555
return False
5656

5757
return True
@@ -62,27 +62,27 @@ def check_package_structure():
6262

6363
# Check if emd directory exists
6464
if not os.path.exists('emd'):
65-
print(" emd directory not found")
65+
print(" [FAIL] emd directory not found")
6666
return False
67-
print(" emd directory exists")
67+
print(" [OK] emd directory exists")
6868

6969
# Check if __init__.py exists
7070
if not os.path.exists('emd/__init__.py'):
71-
print(" emd/__init__.py not found")
71+
print(" [FAIL] emd/__init__.py not found")
7272
return False
73-
print(" emd/__init__.py exists")
73+
print(" [OK] emd/__init__.py exists")
7474

7575
# Check if emd.py exists
7676
if not os.path.exists('emd/emd.py'):
77-
print(" emd/emd.py not found")
77+
print(" [FAIL] emd/emd.py not found")
7878
return False
79-
print(" emd/emd.py exists")
79+
print(" [OK] emd/emd.py exists")
8080

8181
# Check if CUDA directory exists
8282
if not os.path.exists('emd/cuda'):
83-
print(" emd/cuda directory not found")
83+
print(" [FAIL] emd/cuda directory not found")
8484
return False
85-
print(" emd/cuda directory exists")
85+
print(" [OK] emd/cuda directory exists")
8686

8787
return True
8888

@@ -101,10 +101,10 @@ def main():
101101

102102
print("\n" + "=" * 50)
103103
if all(checks):
104-
print("🎉 All compatibility checks passed!")
104+
print("SUCCESS: All compatibility checks passed!")
105105
sys.exit(0)
106106
else:
107-
print(" Some checks failed!")
107+
print("FAILED: Some checks failed!")
108108
sys.exit(1)
109109

110110
if __name__ == "__main__":

setup.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
# setup.py
22
from setuptools import setup, find_packages
3-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4-
import torch
53
import os
64
import platform
75

8-
# Check if CUDA is available
9-
if not torch.cuda.is_available():
10-
raise RuntimeError("CUDA is not available. This package requires CUDA.")
6+
# Try to import torch and CUDA extension, but don't fail if not available
7+
try:
8+
import torch
9+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
10+
TORCH_AVAILABLE = True
11+
CUDA_AVAILABLE = torch.cuda.is_available()
12+
except ImportError:
13+
TORCH_AVAILABLE = False
14+
CUDA_AVAILABLE = False
15+
# Dummy classes for when torch is not available
16+
class BuildExtension:
17+
pass
18+
class CUDAExtension:
19+
def __init__(self, *args, **kwargs):
20+
pass
1121

1222
# Platform-specific compiler arguments
1323
def get_compiler_args():
@@ -32,33 +42,26 @@ def get_compiler_args():
3242
# Optionally set TORCH_CUDA_ARCH_LIST to restrict architectures, e.g. "8.6;8.0"
3343
# os.environ['TORCH_CUDA_ARCH_LIST'] = "8.6" # adjust for your GPU if desired
3444

35-
setup(
36-
name='emd_ext',
37-
version='1.0.0',
38-
description='Earth Mover Distance (EMD) CUDA extension for PyTorch',
39-
long_description=open('README.md', 'r', encoding='utf-8').read(),
40-
long_description_content_type='text/markdown',
41-
author='Haoqiang Fan, Kaichun Mo, Jiayuan Gu',
42-
maintainer='hieulhaiwork',
43-
url='https://github.com/hieulhaiwork/EMD-Pytorch',
44-
packages=find_packages(),
45-
install_requires=[
45+
# Determine if we should build CUDA extensions
46+
BUILD_CUDA = TORCH_AVAILABLE and CUDA_AVAILABLE and os.environ.get('SKIP_CUDA_BUILD', '0') != '1'
47+
48+
# Setup arguments
49+
setup_args = {
50+
'name': 'emd_ext',
51+
'version': '1.0.0',
52+
'description': 'Earth Mover Distance (EMD) CUDA extension for PyTorch',
53+
'long_description': open('README.md', 'r', encoding='utf-8').read(),
54+
'long_description_content_type': 'text/markdown',
55+
'author': 'Haoqiang Fan, Kaichun Mo, Jiayuan Gu',
56+
'maintainer': 'hieulhaiwork',
57+
'url': 'https://github.com/hieulhaiwork/EMD-Pytorch',
58+
'packages': find_packages(),
59+
'install_requires': [
4660
'torch>=1.8.0',
4761
'numpy',
4862
],
49-
python_requires='>=3.7',
50-
ext_modules=[
51-
CUDAExtension(
52-
name='emd_cuda',
53-
sources=[
54-
'emd/cuda/emd.cpp',
55-
'emd/cuda/emd_kernel.cu',
56-
],
57-
extra_compile_args=get_compiler_args()
58-
),
59-
],
60-
cmdclass={'build_ext': BuildExtension},
61-
classifiers=[
63+
'python_requires': '>=3.7',
64+
'classifiers': [
6265
'Development Status :: 4 - Beta',
6366
'Intended Audience :: Science/Research',
6467
'License :: OSI Approved :: MIT License',
@@ -74,4 +77,20 @@ def get_compiler_args():
7477
'Operating System :: POSIX :: Linux',
7578
'Operating System :: Microsoft :: Windows',
7679
],
77-
)
80+
}
81+
82+
# Only add CUDA extensions if available
83+
if BUILD_CUDA:
84+
setup_args['ext_modules'] = [
85+
CUDAExtension(
86+
name='emd_cuda',
87+
sources=[
88+
'emd/cuda/emd.cpp',
89+
'emd/cuda/emd_kernel.cu',
90+
],
91+
extra_compile_args=get_compiler_args()
92+
),
93+
]
94+
setup_args['cmdclass'] = {'build_ext': BuildExtension}
95+
96+
setup(**setup_args)

0 commit comments

Comments
 (0)