Skip to content

Commit 3da638e

Browse files
wsmosesWilliam Moses
andauthored
CUDA build local (#94)
Co-authored-by: William Moses <[email protected]>
1 parent 950302f commit 3da638e

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

deps/ReactantExtra/.bazelrc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,17 @@ build --define=allow_oversize_protos=true
1414

1515
build -c opt
1616

17+
18+
build:cuda --repo_env TF_NEED_CUDA=1
19+
build:cuda --repo_env TF_NCCL_USE_STUB=1
20+
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
21+
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
22+
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
23+
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
24+
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
25+
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
26+
build:cuda --@local_config_cuda//:enable_cuda
27+
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
28+
# Default hermetic CUDA and CUDNN versions.
29+
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
30+
build:cuda --@local_config_cuda//:cuda_compiler=nvcc

deps/build_local.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,14 @@ run(Cmd(`$(Base.julia_cmd().exec[1]) --project=. -e "using Pkg; Pkg.instantiate(
3737
# --@local_config_cuda//:cuda_compiler=nvcc
3838
# --crosstool_top="@local_config_cuda//crosstool:toolchain"
3939

40-
run(Cmd(`bazel query --experimental_repo_remote_exec --repo_env HERMETIC_PYTHON_VERSION="3.10" "allpaths(:libReactantExtra.so, @llvm-project//mlir:CAPIIR)" --output graph`, dir=source_dir,
41-
))
42-
run(Cmd(`bazel build -c dbg --action_env=JULIA=$(Base.julia_cmd().exec[1])
40+
arg = try
41+
run(Cmd(`nvidia-smi`))
42+
"--config=cuda"
43+
catch
44+
""
45+
end
46+
47+
run(Cmd(`bazel build $(arg) -c dbg --action_env=JULIA=$(Base.julia_cmd().exec[1])
4348
--repo_env HERMETIC_PYTHON_VERSION="3.10"
4449
--check_visibility=false --verbose_failures :libReactantExtra.so :Builtin.inc.jl :Arith.inc.jl :Affine.inc.jl :Func.inc.jl :Enzyme.inc.jl :StableHLO.inc.jl :CHLO.inc.jl :VHLO.inc.jl`, dir=source_dir,
4550
))

src/XLA.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ function from_row_major(x::Matrix{T}) where {T}
3838
return transpose(x)
3939
end
4040

41+
SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid
42+
4143
const cpuclientcount = Ref(0)
4244
# TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing`
4345
function CPUClient(asynchronous=false, node_id=0, num_nodes=1)
@@ -93,6 +95,8 @@ using Scratch, Downloads
9395
function __init__()
9496
initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs")
9597
ccall(initLogs, Cvoid, ())
98+
# Add most log level
99+
SetLogLevel(0)
96100
cpu = CPUClient()
97101
backends["cpu"] = cpu
98102
default_backend[] = cpu

0 commit comments

Comments
 (0)