Skip to content

Commit 0eb0f0c

Browse files
authored
Merge pull request #1 from boegel/jax_update
fix logic in jaxlib w.r.t. disabling use of MKL DNN
2 parents 76c8fb0 + 15dd83b commit 0eb0f0c

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

easybuild/easyblocks/j/jaxlib.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,16 @@ def configure_step(self):
133133
elif LooseVersion(self.version) <= LooseVersion('0.6.0'):
134134
options.append('--noenable_cuda')
135135

136-
if self.cfg["use_mkl_dnn"] and LooseVersion(self.version) <= LooseVersion("0.6.0"):
137-
options.append("--enable_mkl_dnn")
138-
elif LooseVersion(self.version) <= LooseVersion("0.6.0"):
139-
options.append("--noenable_mkl_dnn")
136+
if self.cfg['use_mkl_dnn']:
137+
# --enable_mkl_dnn option was removed in jax(lib) v0.4.36,
138+
# see https://github.com/jax-ml/jax/commit/676151265859f8b0dd8baf6f6ae50c3367ed0509
139+
if LooseVersion(self.version) < LooseVersion('0.4.36'):
140+
options.append('--enable_mkl_dnn')
141+
# if use_mkl_dnn is not enabled, use correct flag to disable use of MKL DNN
142+
elif LooseVersion(self.version) < LooseVersion('0.4.36'):
143+
options.append('--noenable_mkl_dnn')
140144
else:
141-
options.append("--disable_mkl_dnn")
145+
options.append('--disable_mkl_dnn')
142146

143147
# Prepend to buildopts so users can overwrite this
144148
self.cfg['buildopts'] = ' '.join(

0 commit comments

Comments
 (0)