diff --git a/src/accelerators/Accelerators.jl b/src/accelerators/Accelerators.jl index 88476922f6..278d0acade 100644 --- a/src/accelerators/Accelerators.jl +++ b/src/accelerators/Accelerators.jl @@ -2,5 +2,6 @@ module Accelerators include("TPU.jl") include("Metal.jl") +include("ROCm.jl") end diff --git a/src/accelerators/ROCm.jl b/src/accelerators/ROCm.jl new file mode 100644 index 0000000000..83236cdb92 --- /dev/null +++ b/src/accelerators/ROCm.jl @@ -0,0 +1,62 @@ +module ROCm + +using Reactant: Reactant +using Scratch: @get_scratch! +using Downloads + +const rocm_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing) + +function __init__() + @static if Sys.islinux() + Reactant.precompiling() || setup_rocm_pjrt_plugin!() + end +end + +has_rocm() = true + +function setup_rocm_pjrt_plugin!() + path_from_env = get(ENV, "ROCM_LIBRARY_PATH", nothing) + if path_from_env !== nothing && ispath(path_from_env) + rocm_pjrt_plugin_dir[] = path_from_env + else + rocm_pjrt_plugin_dir[] = @get_scratch!("pjrt_rocm_plugin") + end + # download_rocm_pjrt_plugin_if_needed(rocm_pjrt_plugin_dir[]) + return nothing +end + +get_rocm_pjrt_plugin_dir() = rocm_pjrt_plugin_dir[] + +function get_rocm_pjrt_plugin_path() + return joinpath(get_rocm_pjrt_plugin_dir(), "xla_rocm_plugin.so") +end + +# function download_rocm_pjrt_plugin_if_needed(path=nothing) +# path === nothing && (path = get_rocm_pjrt_plugin_dir()) +# @assert path !== nothing "rocm_pjrt_plugin_dir is not set!" + +# rocm_pjrt_plugin_path = joinpath(path, "pjrt_plugin_rocm_14.dylib") +# if !isfile(rocm_pjrt_plugin_path) +# zip_file_path = joinpath(path, "pjrt-plugin-rocm.zip") +# tmp_dir = joinpath(path, "tmp") +# Downloads.download( +# if Sys.ARCH === :aarch64 +# "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_rocm-0.1.1-py3-none-macosx_13_0_arm64.whl" +# elseif Sys.ARCH === :x86_64 +# "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_rocm-0.1.1-py3-none-macosx_10_14_x86_64.whl" +# else +# error("Unsupported architecture: $(Sys.ARCH)") +# end, +# zip_file_path, +# ) +# run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`) +# mv( +# joinpath(tmp_dir, "jax_plugins", "rocm_plugin", "pjrt_plugin_rocm_14.dylib"), +# rocm_pjrt_plugin_path, +# ) +# rm(tmp_dir; recursive=true) +# rm(zip_file_path; recursive=true) +# end +# end + +end # module ROCm diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl index e13a308893..0cc6391f42 100644 --- a/src/xla/IFRT/Client.jl +++ b/src/xla/IFRT/Client.jl @@ -115,12 +115,14 @@ const cpu_client_count = Ref(0) const cuda_client_count = Ref(0) const tpu_client_count = Ref(0) const metal_client_count = Ref(0) +const rocm_client_count = Ref(0) for (backend, counter) in ( (:CPUClient, :cpu_client_count), (:CUDAClient, :cuda_client_count), (:TPUClient, :tpu_client_count), (:MetalClient, :metal_client_count), + (:ROCmClient, :rocm_client_count), ) main_fn = Symbol(:MakeIFRTPJRT, backend) @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) @@ -219,6 +221,22 @@ function MakeIFRTPJRTMetalClient(; ) end +function MakeIFRTPJRTROCmClient(; + rocm_pjrt_plugin_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + return MakeIFRTPJRTClientViaPluginAPI( + rocm_pjrt_plugin_path, + "rocm", + "ROCM"; + node_id, + num_nodes, + distributed_runtime_client, + ) +end + function MakeIFRTPJRTClientViaPluginAPI( library_path::String, device_type::String, diff --git a/src/xla/PJRT/Client.jl b/src/xla/PJRT/Client.jl index c45aeac1a1..fb03dd7060 100644 --- a/src/xla/PJRT/Client.jl +++ b/src/xla/PJRT/Client.jl @@ -110,12 +110,14 @@ const cpu_client_count = Ref(0) const cuda_client_count = Ref(0) const tpu_client_count = Ref(0) const metal_client_count = Ref(0) +const rocm_client_count = Ref(0) for (backend, counter) in ( (:CPUClient, :cpu_client_count), (:CUDAClient, :cuda_client_count), (:TPUClient, :tpu_client_count), (:MetalClient, :metal_client_count), + (:ROCmClient, :rocm_client_count), ) main_fn = Symbol(:Make, backend) @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) @@ -207,6 +209,20 @@ function MakeMetalClient(; return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL") end +function MakeROCmClient(; + rocm_pjrt_plugin_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + @assert node_id == 0 "`PJRT.MakeROCmClient` does not support node_id" + @assert num_nodes == 1 "`PJRT.MakeROCmClient` does not support num_nodes > 1" + @assert distributed_runtime_client === nothing "`PJRT.MakeROCmClient` does not support \ + distributed_runtime_client" + + return MakeClientUsingPluginAPI(rocm_pjrt_plugin_path, "rocm", "ROCM") +end + function MakeClientUsingPluginAPI( library_path::String, device_type::String, client_name::String=uppercase(device_type) ) diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 058e060ee3..c5ce6cd49a 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -221,6 +221,22 @@ for runtime in (:PJRT, :IFRT) catch e println(stdout, e) end + elseif Accelerators.ROCm.has_rocm() + try + if was_initialized && haskey(state.clients, "rocm") + XLA.free_client(state.clients["rocm"]) + XLA.$(runtime).rocm_client_count[] -= 1 + end + gpu = $(runtime).ROCmClient( + ; + rocm_pjrt_plugin_path=Accelerators.ROCm.get_rocm_pjrt_plugin_path(), + common_kwargs... + ) + state.clients["rocm"] = gpu + state.default_client = gpu + catch e + println(stdout, e) + end else try if was_initialized && haskey(state.clients, "cuda")