Skip to content

Commit 280efc3

Browse files
Copilotjpfeuffer
andcommitted
Make ArrayWrappers automatically compiled with every autowrap module
Modified compile_and_import in Utils.py to automatically compile ArrayWrappers module when numpy is detected. This makes ArrayWrappers available to all generated modules. Fixed .pxd declarations to avoid conflicts during compilation. Still debugging buffer protocol implementation - shape/strides pointer lifetime issue needs resolution. Co-authored-by: jpfeuffer <8102638+jpfeuffer@users.noreply.github.com>
1 parent e63190d commit 280efc3

File tree

4 files changed

+397
-190
lines changed

4 files changed

+397
-190
lines changed

autowrap/Utils.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,102 @@
6262
"""
6363

6464

65+
def _compile_array_wrappers_if_needed(tempdir, include_dirs, debug=False):
66+
"""
67+
Compile ArrayWrappers module if it's needed (i.e., if numpy is being used).
68+
This makes ArrayWrappers available to the module being compiled.
69+
"""
70+
import os
71+
import os.path
72+
import shutil
73+
import subprocess
74+
import sys
75+
76+
# Check if ArrayWrappers source files exist
77+
autowrap_dir = os.path.dirname(os.path.abspath(__file__))
78+
array_wrappers_dir = os.path.join(autowrap_dir, "data_files", "autowrap")
79+
array_wrappers_pyx = os.path.join(array_wrappers_dir, "ArrayWrappers.pyx")
80+
array_wrappers_pxd = os.path.join(array_wrappers_dir, "ArrayWrappers.pxd")
81+
82+
if not os.path.exists(array_wrappers_pyx):
83+
# ArrayWrappers not available, skip
84+
return
85+
86+
if debug:
87+
print("Compiling ArrayWrappers module...")
88+
89+
# Copy only ArrayWrappers.pyx to tempdir
90+
# Don't copy .pxd - Cython will auto-generate it from .pyx during compilation
91+
# The .pxd is only needed by OTHER modules that import ArrayWrappers
92+
shutil.copy(array_wrappers_pyx, tempdir)
93+
94+
# Create a simple setup.py for ArrayWrappers
95+
compile_args = []
96+
link_args = []
97+
98+
if sys.platform == "darwin":
99+
compile_args += ["-stdlib=libc++", "-std=c++17"]
100+
link_args += ["-stdlib=libc++"]
101+
102+
if sys.platform == "linux" or sys.platform == "linux2":
103+
compile_args += ["-std=c++17"]
104+
105+
if sys.platform != "win32":
106+
compile_args += ["-Wno-unused-but-set-variable"]
107+
108+
# Get numpy include directory if available
109+
try:
110+
import numpy
111+
numpy_include = numpy.get_include()
112+
except ImportError:
113+
numpy_include = None
114+
115+
# Filter include_dirs to exclude the autowrap data_files directory
116+
# to prevent Cython from finding ArrayWrappers.pxd during its own compilation
117+
filtered_include_dirs = [d for d in include_dirs if array_wrappers_dir not in os.path.abspath(d)]
118+
if numpy_include and numpy_include not in filtered_include_dirs:
119+
filtered_include_dirs.append(numpy_include)
120+
121+
include_dirs_abs = [os.path.abspath(d) for d in filtered_include_dirs]
122+
123+
setup_code = """
124+
from distutils.core import setup, Extension
125+
from Cython.Distutils import build_ext
126+
127+
ext = Extension("ArrayWrappers",
128+
sources=["ArrayWrappers.pyx"],
129+
language="c++",
130+
include_dirs=%r,
131+
extra_compile_args=%r,
132+
extra_link_args=%r)
133+
134+
setup(cmdclass={'build_ext': build_ext},
135+
name="ArrayWrappers",
136+
ext_modules=[ext])
137+
""" % (include_dirs_abs, compile_args, link_args)
138+
139+
# Write and build ArrayWrappers
140+
setup_file = os.path.join(tempdir, "setup_arraywrappers.py")
141+
with open(setup_file, "w") as fp:
142+
fp.write(setup_code)
143+
144+
# Build ArrayWrappers in the tempdir
145+
result = subprocess.Popen(
146+
"%s %s build_ext --inplace" % (sys.executable, setup_file),
147+
shell=True,
148+
cwd=tempdir
149+
).wait()
150+
151+
if result != 0:
152+
print("Warning: Failed to compile ArrayWrappers module")
153+
elif debug:
154+
print("ArrayWrappers compiled successfully")
155+
156+
# After building, copy the .pxd file so other modules can cimport from it
157+
if result == 0 and os.path.exists(array_wrappers_pxd):
158+
shutil.copy(array_wrappers_pxd, tempdir)
159+
160+
65161
def compile_and_import(name, source_files, include_dirs=None, **kws):
66162
if include_dirs is None:
67163
include_dirs = []
@@ -78,6 +174,24 @@ def compile_and_import(name, source_files, include_dirs=None, **kws):
78174
print("\n")
79175
print("tempdir=", tempdir)
80176
print("\n")
177+
178+
# Check if any source file imports ArrayWrappers (indicates numpy usage)
179+
needs_array_wrappers = False
180+
for source_file in source_files:
181+
if source_file.endswith('.pyx'):
182+
try:
183+
with open(source_file, 'r') as f:
184+
content = f.read()
185+
if 'ArrayWrappers' in content or 'ArrayWrapper' in content or 'ArrayView' in content:
186+
needs_array_wrappers = True
187+
break
188+
except:
189+
pass
190+
191+
# Compile ArrayWrappers first if needed
192+
if needs_array_wrappers:
193+
_compile_array_wrappers_if_needed(tempdir, include_dirs, debug)
194+
81195
for source_file in source_files:
82196
if source_file[-4:] != ".pyx" and source_file[-4:] != ".cpp":
83197
raise NameError("Expected pyx and/or cpp files as source files for compilation.")

autowrap/data_files/autowrap/ArrayWrappers.pxd

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,94 +11,134 @@ from libc.stdint cimport int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, u
1111
# Owning wrapper classes (hold libcpp_vector directly)
1212
cdef class ArrayWrapperFloat:
1313
cdef libcpp_vector[float] vec
14+
cdef Py_ssize_t _shape_val
15+
cdef Py_ssize_t _strides_val
1416

1517
cdef class ArrayWrapperDouble:
1618
cdef libcpp_vector[double] vec
19+
cdef Py_ssize_t _shape_val
20+
cdef Py_ssize_t _strides_val
1721

1822
cdef class ArrayWrapperInt8:
1923
cdef libcpp_vector[int8_t] vec
24+
cdef Py_ssize_t _shape_val
25+
cdef Py_ssize_t _strides_val
2026

2127
cdef class ArrayWrapperInt16:
2228
cdef libcpp_vector[int16_t] vec
29+
cdef Py_ssize_t _shape_val
30+
cdef Py_ssize_t _strides_val
2331

2432
cdef class ArrayWrapperInt32:
2533
cdef libcpp_vector[int32_t] vec
34+
cdef Py_ssize_t _shape_val
35+
cdef Py_ssize_t _strides_val
2636

2737
cdef class ArrayWrapperInt64:
2838
cdef libcpp_vector[int64_t] vec
39+
cdef Py_ssize_t _shape_val
40+
cdef Py_ssize_t _strides_val
2941

3042
cdef class ArrayWrapperUInt8:
3143
cdef libcpp_vector[uint8_t] vec
44+
cdef Py_ssize_t _shape_val
45+
cdef Py_ssize_t _strides_val
3246

3347
cdef class ArrayWrapperUInt16:
3448
cdef libcpp_vector[uint16_t] vec
49+
cdef Py_ssize_t _shape_val
50+
cdef Py_ssize_t _strides_val
3551

3652
cdef class ArrayWrapperUInt32:
3753
cdef libcpp_vector[uint32_t] vec
54+
cdef Py_ssize_t _shape_val
55+
cdef Py_ssize_t _strides_val
3856

3957
cdef class ArrayWrapperUInt64:
4058
cdef libcpp_vector[uint64_t] vec
59+
cdef Py_ssize_t _shape_val
60+
cdef Py_ssize_t _strides_val
4161

4262
# Non-owning view classes (hold raw pointer + size + owner)
4363
cdef class ArrayViewFloat:
4464
cdef float* ptr
4565
cdef size_t _size
4666
cdef object owner
4767
cdef cbool readonly
68+
cdef Py_ssize_t _shape_val
69+
cdef Py_ssize_t _strides_val
4870

4971
cdef class ArrayViewDouble:
5072
cdef double* ptr
5173
cdef size_t _size
5274
cdef object owner
5375
cdef cbool readonly
76+
cdef Py_ssize_t _shape_val
77+
cdef Py_ssize_t _strides_val
5478

5579
cdef class ArrayViewInt8:
5680
cdef int8_t* ptr
5781
cdef size_t _size
5882
cdef object owner
5983
cdef cbool readonly
84+
cdef Py_ssize_t _shape_val
85+
cdef Py_ssize_t _strides_val
6086

6187
cdef class ArrayViewInt16:
6288
cdef int16_t* ptr
6389
cdef size_t _size
6490
cdef object owner
6591
cdef cbool readonly
92+
cdef Py_ssize_t _shape_val
93+
cdef Py_ssize_t _strides_val
6694

6795
cdef class ArrayViewInt32:
6896
cdef int32_t* ptr
6997
cdef size_t _size
7098
cdef object owner
7199
cdef cbool readonly
100+
cdef Py_ssize_t _shape_val
101+
cdef Py_ssize_t _strides_val
72102

73103
cdef class ArrayViewInt64:
74104
cdef int64_t* ptr
75105
cdef size_t _size
76106
cdef object owner
77107
cdef cbool readonly
108+
cdef Py_ssize_t _shape_val
109+
cdef Py_ssize_t _strides_val
78110

79111
cdef class ArrayViewUInt8:
80112
cdef uint8_t* ptr
81113
cdef size_t _size
82114
cdef object owner
83115
cdef cbool readonly
116+
cdef Py_ssize_t _shape_val
117+
cdef Py_ssize_t _strides_val
84118

85119
cdef class ArrayViewUInt16:
86120
cdef uint16_t* ptr
87121
cdef size_t _size
88122
cdef object owner
89123
cdef cbool readonly
124+
cdef Py_ssize_t _shape_val
125+
cdef Py_ssize_t _strides_val
90126

91127
cdef class ArrayViewUInt32:
92128
cdef uint32_t* ptr
93129
cdef size_t _size
94130
cdef object owner
95131
cdef cbool readonly
132+
cdef Py_ssize_t _shape_val
133+
cdef Py_ssize_t _strides_val
96134

97135
cdef class ArrayViewUInt64:
98136
cdef uint64_t* ptr
99137
cdef size_t _size
100138
cdef object owner
101139
cdef cbool readonly
140+
cdef Py_ssize_t _shape_val
141+
cdef Py_ssize_t _strides_val
102142

103143
# Factory functions for creating views from C level
104144
cdef ArrayViewFloat _create_view_float(float* ptr, size_t size, object owner, cbool readonly)

0 commit comments

Comments
 (0)