@@ -77,9 +77,14 @@ def configure_step(self):
7777
7878 # Collect options for the build script
7979 # Used only by the build script
80+ options = []
81+
82+ # update build command for jaxlib-0.6 to build.py build
83+ if LooseVersion (self .version ) >= LooseVersion ('0.6.0' ):
84+ options .append ('build' )
8085
8186 # C++ flags are set through copt below
82- options = [ '--target_cpu_features=default' ]
87+ options . append ( '--target_cpu_features=default' )
8388
8489 # Passed directly to bazel
8590 bazel_startup_options = [
@@ -125,13 +130,19 @@ def configure_step(self):
125130 options .append ('--noenable_nccl' )
126131
127132 config_env_vars ['GCC_HOST_COMPILER_PATH' ] = which (os .getenv ('CC' ))
128- else :
133+ elif LooseVersion ( self . version ) <= LooseVersion ( '0.6.0' ) :
129134 options .append ('--noenable_cuda' )
130135
131136 if self .cfg ['use_mkl_dnn' ]:
132- options .append ('--enable_mkl_dnn' )
133- else :
137+ # --enable_mkl_dnn option was removed in jax(lib) v0.4.36,
138+ # see https://github.com/jax-ml/jax/commit/676151265859f8b0dd8baf6f6ae50c3367ed0509
139+ if LooseVersion (self .version ) < LooseVersion ('0.4.36' ):
140+ options .append ('--enable_mkl_dnn' )
141+ # if use_mkl_dnn is not enabled, use correct flag to disable use of MKL DNN
142+ elif LooseVersion (self .version ) < LooseVersion ('0.4.36' ):
134143 options .append ('--noenable_mkl_dnn' )
144+ else :
145+ options .append ('--disable_mkl_dnn' )
135146
136147 # Prepend to buildopts so users can overwrite this
137148 self .cfg ['buildopts' ] = ' ' .join (
0 commit comments