Skip to content

Commit c78ca04

Browse files
Add experimental support for building JAX CPU and GPU wheels with GCC.
The `build.py` script uses Clang compiler by default, and JAX doesn't support building with GCC officially. However, experimental GCC support is still present. Command examples: ``` python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false --gcc_path=/use/bin/gcc ``` This change addresses the request in #25488. PiperOrigin-RevId: 707930913
1 parent 49c9246 commit c78ca04

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

build/build.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,15 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
277277
""",
278278
)
279279

280+
compile_group.add_argument(
281+
"--gcc_path",
282+
type=str,
283+
default="",
284+
help="""
285+
Path to the GCC binary to use.
286+
""",
287+
)
288+
280289
compile_group.add_argument(
281290
"--disable_mkl_dnn",
282291
action="store_true",
@@ -481,7 +490,13 @@ async def main():
481490
# versions of Clang.
482491
wheel_build_command.append("--config=clang")
483492
else:
484-
logging.debug("Use Clang: False")
493+
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
494+
logging.debug(
495+
"Using GCC as the compiler, gcc path: %s",
496+
gcc_path,
497+
)
498+
wheel_build_command.append(f"--repo_env=CC=\"{gcc_path}\"")
499+
wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
485500

486501
if not args.disable_mkl_dnn:
487502
logging.debug("Enabling MKL DNN")
@@ -515,12 +530,16 @@ async def main():
515530

516531
if "cuda" in wheel:
517532
wheel_build_command.append("--config=cuda")
518-
wheel_build_command.append(
533+
if args.use_clang:
534+
wheel_build_command.append(
519535
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
520536
)
521-
if args.build_cuda_with_clang:
522-
logging.debug("Building CUDA with Clang")
523-
wheel_build_command.append("--config=build_cuda_with_clang")
537+
if args.build_cuda_with_clang:
538+
logging.debug("Building CUDA with Clang")
539+
wheel_build_command.append("--config=build_cuda_with_clang")
540+
else:
541+
logging.debug("Building CUDA with NVCC")
542+
wheel_build_command.append("--config=build_cuda_with_nvcc")
524543
else:
525544
logging.debug("Building CUDA with NVCC")
526545
wheel_build_command.append("--config=build_cuda_with_nvcc")

build/tools/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,19 +170,25 @@ def get_bazel_version(bazel_path):
170170
return None
171171
return tuple(int(x) for x in match.group(1).split("."))
172172

173-
def get_clang_path_or_exit():
174-
which_clang_output = shutil.which("clang")
175-
if which_clang_output:
176-
# If we've found a clang on the path, need to get the fully resolved path
173+
def get_compiler_path_or_exit(compiler_path_flag, compiler_name):
174+
which_compiler_output = shutil.which(compiler_name)
175+
if which_compiler_output:
176+
# If we've found a compiler on the path, need to get the fully resolved path
177177
# to ensure that system headers are found.
178-
return str(pathlib.Path(which_clang_output).resolve())
178+
return str(pathlib.Path(which_compiler_output).resolve())
179179
else:
180180
print(
181-
"--clang_path is unset and clang cannot be found"
182-
" on the PATH. Please pass --clang_path directly."
181+
f"--{compiler_path_flag} is unset and {compiler_name} cannot be found"
182+
" on the PATH. Please pass --{compiler_path_flag} directly."
183183
)
184184
sys.exit(-1)
185185

186+
def get_gcc_path_or_exit():
187+
return get_compiler_path_or_exit("gcc_path", "gcc")
188+
189+
def get_clang_path_or_exit():
190+
return get_compiler_path_or_exit("clang_path", "clang")
191+
186192
def get_clang_major_version(clang_path):
187193
clang_version_proc = subprocess.run(
188194
[clang_path, "-E", "-P", "-"],

0 commit comments

Comments
 (0)