2020import textwrap
2121import time
2222import warnings
23- from collections .abc import Callable
23+ from collections .abc import Callable , Collection , Sequence
2424from contextlib import AbstractContextManager , nullcontext
2525from io import BytesIO , StringIO
2626from pathlib import Path
@@ -2736,6 +2736,96 @@ def check_mkl_openmp():
27362736 )
27372737
27382738
2739+ def _check_required_file (
2740+ paths : Collection [Path ],
2741+ required_regexs : Collection [str | re .Pattern [str ]],
2742+ ) -> list [tuple [str , str ]]:
2743+ """Select path parents for each required pattern."""
2744+ libs : list [tuple [str , str ]] = []
2745+ for req in required_regexs :
2746+ found = False
2747+ for path in paths :
2748+ m = re .search (req , path .name )
2749+ if m :
2750+ libs .append ((str (path .parent ), m .string [slice (* m .span ())]))
2751+ found = True
2752+ break
2753+ if not found :
2754+ _logger .debug ("Required file '%s' not found" , req )
2755+ raise RuntimeError (f"Required file { req } not found" )
2756+ return libs
2757+
2758+
2759+ def _get_cxx_library_dirs () -> list [str ]:
2760+ """Query C++ search dirs and return those the existing ones."""
2761+ cmd = [config .cxx , "-print-search-dirs" ]
2762+ p = subprocess_Popen (
2763+ cmd ,
2764+ stdout = subprocess .PIPE ,
2765+ stderr = subprocess .PIPE ,
2766+ stdin = subprocess .PIPE ,
2767+ )
2768+ (stdout , stderr ) = p .communicate (input = b"" )
2769+ if p .returncode != 0 :
2770+ warnings .warn (
2771+ "Pytensor cxx failed to communicate its search dirs. As a consequence, "
2772+ "it might not be possible to automatically determine the blas link flags to use.\n "
2773+ f"Command that was run: { config .cxx } -print-search-dirs\n "
2774+ f"Output printed to stderr: { stderr .decode (sys .stderr .encoding )} "
2775+ )
2776+ return []
2777+
2778+ maybe_lib_dirs = [
2779+ [Path (p ).resolve () for p in line [len ("libraries: =" ) :].split (":" )]
2780+ for line in stdout .decode (sys .getdefaultencoding ()).splitlines ()
2781+ if line .startswith ("libraries: =" )
2782+ ]
2783+ if not maybe_lib_dirs :
2784+ return []
2785+ return [str (d ) for d in maybe_lib_dirs [0 ] if d .exists () and d .is_dir ()]
2786+
2787+
2788+ def _check_libs (
2789+ all_libs : Collection [Path ],
2790+ required_libs : Collection [str | re .Pattern ],
2791+ extra_compile_flags : Sequence [str ] = (),
2792+ cxx_library_dirs : Sequence [str ] = (),
2793+ ) -> str :
2794+ """Assembly library paths and try BLAS flags, returning the flags on success."""
2795+ found_libs = _check_required_file (
2796+ all_libs ,
2797+ required_libs ,
2798+ )
2799+ path_quote = '"' if sys .platform == "win32" else ""
2800+ libdir_ldflags = list (
2801+ dict .fromkeys (
2802+ [
2803+ f"-L{ path_quote } { lib_path } { path_quote } "
2804+ for lib_path , _ in found_libs
2805+ if lib_path not in cxx_library_dirs
2806+ ]
2807+ )
2808+ )
2809+
2810+ flags = (
2811+ libdir_ldflags
2812+ + [f"-l{ lib_name } " for _ , lib_name in found_libs ]
2813+ + list (extra_compile_flags )
2814+ )
2815+ res = try_blas_flag (flags )
2816+ if not res :
2817+ _logger .debug ("Supplied flags '%s' failed to compile" , res )
2818+ raise RuntimeError (f"Supplied flags { flags } failed to compile" )
2819+
2820+ if any ("mkl" in flag for flag in flags ):
2821+ try :
2822+ check_mkl_openmp ()
2823+ except Exception as e :
2824+ _logger .debug (e )
2825+ _logger .debug ("The following blas flags will be used: '%s'" , res )
2826+ return res
2827+
2828+
27392829def default_blas_ldflags () -> str :
27402830 """Look for an available BLAS implementation in the system.
27412831
@@ -2763,88 +2853,6 @@ def default_blas_ldflags() -> str:
27632853
27642854 """
27652855
2766- def check_required_file (paths , required_regexs ):
2767- libs = []
2768- for req in required_regexs :
2769- found = False
2770- for path in paths :
2771- m = re .search (req , path .name )
2772- if m :
2773- libs .append ((str (path .parent ), m .string [slice (* m .span ())]))
2774- found = True
2775- break
2776- if not found :
2777- _logger .debug ("Required file '%s' not found" , req )
2778- raise RuntimeError (f"Required file { req } not found" )
2779- return libs
2780-
2781- def get_cxx_library_dirs ():
2782- cmd = [config .cxx , "-print-search-dirs" ]
2783- p = subprocess_Popen (
2784- cmd ,
2785- stdout = subprocess .PIPE ,
2786- stderr = subprocess .PIPE ,
2787- stdin = subprocess .PIPE ,
2788- )
2789- (stdout , stderr ) = p .communicate (input = b"" )
2790- if p .returncode != 0 :
2791- warnings .warn (
2792- "Pytensor cxx failed to communicate its search dirs. As a consequence, "
2793- "it might not be possible to automatically determine the blas link flags to use.\n "
2794- f"Command that was run: { config .cxx } -print-search-dirs\n "
2795- f"Output printed to stderr: { stderr .decode (sys .stderr .encoding )} "
2796- )
2797- return []
2798-
2799- maybe_lib_dirs = [
2800- [Path (p ).resolve () for p in line [len ("libraries: =" ) :].split (":" )]
2801- for line in stdout .decode (sys .getdefaultencoding ()).splitlines ()
2802- if line .startswith ("libraries: =" )
2803- ]
2804- if len (maybe_lib_dirs ) > 0 :
2805- maybe_lib_dirs = maybe_lib_dirs [0 ]
2806- return [str (d ) for d in maybe_lib_dirs if d .exists () and d .is_dir ()]
2807-
2808- def check_libs (
2809- all_libs , required_libs , extra_compile_flags = None , cxx_library_dirs = None
2810- ) -> str :
2811- if cxx_library_dirs is None :
2812- cxx_library_dirs = []
2813- if extra_compile_flags is None :
2814- extra_compile_flags = []
2815- found_libs = check_required_file (
2816- all_libs ,
2817- required_libs ,
2818- )
2819- path_quote = '"' if sys .platform == "win32" else ""
2820- libdir_ldflags = list (
2821- dict .fromkeys (
2822- [
2823- f"-L{ path_quote } { lib_path } { path_quote } "
2824- for lib_path , _ in found_libs
2825- if lib_path not in cxx_library_dirs
2826- ]
2827- )
2828- )
2829-
2830- flags = (
2831- libdir_ldflags
2832- + [f"-l{ lib_name } " for _ , lib_name in found_libs ]
2833- + extra_compile_flags
2834- )
2835- res = try_blas_flag (flags )
2836- if res :
2837- if any ("mkl" in flag for flag in flags ):
2838- try :
2839- check_mkl_openmp ()
2840- except Exception as e :
2841- _logger .debug (e )
2842- _logger .debug ("The following blas flags will be used: '%s'" , res )
2843- return res
2844- else :
2845- _logger .debug ("Supplied flags '%s' failed to compile" , res )
2846- raise RuntimeError (f"Supplied flags { flags } failed to compile" )
2847-
28482856 # If no compiler is available we default to empty ldflags
28492857 if not config .cxx :
28502858 return ""
@@ -2854,7 +2862,7 @@ def check_libs(
28542862 else :
28552863 rpath = None
28562864
2857- cxx_library_dirs = get_cxx_library_dirs ()
2865+ cxx_library_dirs = _get_cxx_library_dirs ()
28582866 searched_library_dirs = cxx_library_dirs + _std_lib_dirs
28592867 if sys .platform == "win32" :
28602868 # Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
@@ -2884,7 +2892,7 @@ def check_libs(
28842892 try :
28852893 # 1. Try to use MKL with INTEL OpenMP threading
28862894 _logger .debug ("Checking MKL flags with intel threading" )
2887- return check_libs (
2895+ return _check_libs (
28882896 all_libs ,
28892897 required_libs = [
28902898 "mkl_core" ,
@@ -2901,7 +2909,7 @@ def check_libs(
29012909 try :
29022910 # 2. Try to use MKL with GNU OpenMP threading
29032911 _logger .debug ("Checking MKL flags with GNU OpenMP threading" )
2904- return check_libs (
2912+ return _check_libs (
29052913 all_libs ,
29062914 required_libs = ["mkl_core" , "mkl_rt" , "mkl_gnu_thread" , "gomp" , "pthread" ],
29072915 extra_compile_flags = [f"-Wl,-rpath,{ rpath } " ] if rpath is not None else [],
@@ -2924,7 +2932,7 @@ def check_libs(
29242932 try :
29252933 _logger .debug ("Checking Lapack + blas" )
29262934 # 4. Try to use LAPACK + BLAS
2927- return check_libs (
2935+ return _check_libs (
29282936 all_libs ,
29292937 required_libs = ["lapack" , "blas" , "cblas" , "m" ],
29302938 extra_compile_flags = [f"-Wl,-rpath,{ rpath } " ] if rpath is not None else [],
@@ -2935,7 +2943,7 @@ def check_libs(
29352943 try :
29362944 # 5. Try to use BLAS alone
29372945 _logger .debug ("Checking blas alone" )
2938- return check_libs (
2946+ return _check_libs (
29392947 all_libs ,
29402948 required_libs = ["blas" , "cblas" ],
29412949 extra_compile_flags = [f"-Wl,-rpath,{ rpath } " ] if rpath is not None else [],
@@ -2946,7 +2954,7 @@ def check_libs(
29462954 try :
29472955 # 6. Try to use openblas
29482956 _logger .debug ("Checking openblas" )
2949- return check_libs (
2957+ return _check_libs (
29502958 all_libs ,
29512959 required_libs = ["openblas" , "gfortran" , "gomp" , "m" ],
29522960 extra_compile_flags = ["-fopenmp" , f"-Wl,-rpath,{ rpath } " ]
0 commit comments