Skip to content

Commit da9682f

Browse files
committed
Add Accelerate framework blas__ldflags tests
1 parent 6132203 commit da9682f

File tree

3 files changed

+80
-16
lines changed

3 files changed

+80
-16
lines changed

pytensor/link/c/cmodule.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2873,9 +2873,21 @@ def check_libs(
28732873
)
28742874
except Exception as e:
28752875
_logger.debug(e)
2876+
try:
2877+
# 3. Mac Accelerate framework
2878+
_logger.debug("Checking Accelerate framework")
2879+
flags = ["-framework", "Accelerate"]
2880+
if rpath:
2881+
flags = [*flags, "-rpath", rpath]
2882+
validated_flags = try_blas_flag(flags)
2883+
if validated_flags == "":
2884+
raise Exception("Accelerate framework flag failed ")
2885+
return validated_flags
2886+
except Exception as e:
2887+
_logger.debug(e)
28762888
try:
28772889
_logger.debug("Checking Lapack + blas")
2878-
# 3. Try to use LAPACK + BLAS
2890+
# 4. Try to use LAPACK + BLAS
28792891
return check_libs(
28802892
all_libs,
28812893
required_libs=["lapack", "blas", "cblas", "m"],
@@ -2885,7 +2897,7 @@ def check_libs(
28852897
except Exception as e:
28862898
_logger.debug(e)
28872899
try:
2888-
# 4. Try to use BLAS alone
2900+
# 5. Try to use BLAS alone
28892901
_logger.debug("Checking blas alone")
28902902
return check_libs(
28912903
all_libs,
@@ -2896,7 +2908,7 @@ def check_libs(
28962908
except Exception as e:
28972909
_logger.debug(e)
28982910
try:
2899-
# 5. Try to use openblas
2911+
# 6. Try to use openblas
29002912
_logger.debug("Checking openblas")
29012913
return check_libs(
29022914
all_libs,

pytensor/tensor/blas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import logging
8080
import os
8181
import time
82+
from pathlib import Path
8283

8384
import numpy as np
8485

@@ -425,7 +426,7 @@ def _ldflags(
425426

426427
try:
427428
t0, t1 = t[0], t[1]
428-
assert t0 == "-"
429+
assert t0 == "-" or t == "Accelerate" or Path(t).exists()
429430
except Exception:
430431
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
431432
if libs_dir and t1 == "L":

tests/link/c/test_cmodule.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,22 @@ def test_flag_detection():
165165

166166
@pytest.fixture(
167167
scope="module",
168-
params=["mkl_intel", "mkl_gnu", "openblas", "lapack", "blas", "no_blas"],
168+
params=[
169+
"mkl_intel",
170+
"mkl_gnu",
171+
"accelerate",
172+
"openblas",
173+
"lapack",
174+
"blas",
175+
"no_blas",
176+
],
169177
)
170178
def blas_libs(request):
171179
key = request.param
172180
libs = {
173181
"mkl_intel": ["mkl_core", "mkl_rt", "mkl_intel_thread", "iomp5", "pthread"],
174182
"mkl_gnu": ["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
183+
"accelerate": ["vecLib_placeholder"],
175184
"openblas": ["openblas", "gfortran", "gomp", "m"],
176185
"lapack": ["lapack", "blas", "cblas", "m"],
177186
"blas": ["blas", "cblas"],
@@ -190,25 +199,37 @@ def mock_system(request):
190199
def cxx_search_dirs(blas_libs, mock_system):
191200
libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}
192201
libraries = []
202+
enabled_accelerate_framework = False
193203
with tempfile.TemporaryDirectory() as d:
194204
flags = None
195205
for lib in blas_libs:
196-
lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
197-
lib_path.write_bytes(b"1")
198-
libraries.append(lib_path)
199-
if flags is None:
200-
flags = f"-l{lib}"
206+
if lib == "vecLib_placeholder":
207+
if mock_system != "Darwin":
208+
flags = ""
209+
else:
210+
flags = "-framework Accelerate"
211+
enabled_accelerate_framework = True
201212
else:
202-
flags += f" -l{lib}"
213+
lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
214+
lib_path.write_bytes(b"1")
215+
libraries.append(lib_path)
216+
if flags is None:
217+
flags = f"-l{lib}"
218+
else:
219+
flags += f" -l{lib}"
203220
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs:
204221
flags += " -fopenmp"
205222
if len(blas_libs) == 0:
206223
flags = ""
207-
yield f"libraries: ={d}".encode(sys.stdout.encoding), flags
224+
yield (
225+
f"libraries: ={d}".encode(sys.stdout.encoding),
226+
flags,
227+
enabled_accelerate_framework,
228+
)
208229

209230

210231
@pytest.fixture(
211-
scope="function", params=[False, True], ids=["Working_CXX", "Broken_CXX"]
232+
scope="function", params=[True, False], ids=["Working_CXX", "Broken_CXX"]
212233
)
213234
def cxx_search_dirs_status(request):
214235
return request.param
@@ -219,22 +240,39 @@ def cxx_search_dirs_status(request):
219240
def test_default_blas_ldflags(
220241
mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs, cxx_search_dirs_status
221242
):
222-
cxx_search_dirs, expected_blas_ldflags = cxx_search_dirs
243+
cxx_search_dirs, expected_blas_ldflags, enabled_accelerate_framework = (
244+
cxx_search_dirs
245+
)
223246
mock_process = MagicMock()
224247
if cxx_search_dirs_status:
225248
error_message = ""
226249
mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, b"")
227250
mock_process.returncode = 0
228251
else:
252+
enabled_accelerate_framework = False
229253
error_message = "Unsupported argument -print-search-dirs"
230254
error_message_bytes = error_message.encode(sys.stderr.encoding)
231255
mock_process.communicate = lambda *args, **kwargs: (b"", error_message_bytes)
232256
mock_process.returncode = 1
257+
258+
def patched_compile_tmp(*args, **kwargs):
259+
def wrapped(test_code, tmp_prefix, flags, try_run, output):
260+
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
261+
print(enabled_accelerate_framework)
262+
if enabled_accelerate_framework:
263+
return (True, True)
264+
else:
265+
return (False, False, "", "Invalid flags -framework Accelerate")
266+
else:
267+
return (True, True)
268+
269+
return wrapped
270+
233271
with patch("pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process):
234272
with patch.object(
235273
pytensor.link.c.cmodule.GCC_compiler,
236274
"try_compile_tmp",
237-
return_value=(True, True),
275+
new_callable=patched_compile_tmp,
238276
):
239277
if cxx_search_dirs_status:
240278
assert set(default_blas_ldflags().split(" ")) == set(
@@ -267,6 +305,9 @@ def windows_conda_libs(blas_libs):
267305
subdir.mkdir(exist_ok=True, parents=True)
268306
flags = f'-L"{subdir}"'
269307
for lib in blas_libs:
308+
if lib == "vecLib_placeholder":
309+
flags = ""
310+
break
270311
lib_path = subdir / f"{lib}.dll"
271312
lib_path.write_bytes(b"1")
272313
libraries.append(lib_path)
@@ -287,6 +328,16 @@ def test_default_blas_ldflags_conda_windows(
287328
mock_process = MagicMock()
288329
mock_process.communicate = lambda *args, **kwargs: (b"", b"")
289330
mock_process.returncode = 0
331+
332+
def patched_compile_tmp(*args, **kwargs):
333+
def wrapped(test_code, tmp_prefix, flags, try_run, output):
334+
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
335+
return (False, False, "", "Invalid flags -framework Accelerate")
336+
else:
337+
return (True, True)
338+
339+
return wrapped
340+
290341
with patch("sys.platform", "win32"):
291342
with patch("sys.prefix", mock_sys_prefix):
292343
with patch(
@@ -295,7 +346,7 @@ def test_default_blas_ldflags_conda_windows(
295346
with patch.object(
296347
pytensor.link.c.cmodule.GCC_compiler,
297348
"try_compile_tmp",
298-
return_value=(True, True),
349+
new_callable=patched_compile_tmp,
299350
):
300351
assert set(default_blas_ldflags().split(" ")) == set(
301352
expected_blas_ldflags.split(" ")

0 commit comments

Comments
 (0)