Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand All @@ -80,6 +81,7 @@ ReactantFloat8sExt = "Float8s"
ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantLogExpFunctionsExt = ["IrrationalConstants", "LogExpFunctions"]
ReactantMPIExt = "MPI"
ReactantMetalExt = "Metal"
ReactantNNlibExt = ["NNlib", "Statistics"]
ReactantNPZExt = "NPZ"
ReactantOffsetArraysExt = "OffsetArrays"
Expand Down Expand Up @@ -122,6 +124,7 @@ LinearAlgebra = "1.10"
LogExpFunctions = "0.3"
MCMCDiagnosticTools = "0.3.11"
MPI = "0.20"
Metal = "1.4"
NNlib = "0.9.26"
NPZ = "0.4"
OffsetArrays = "1"
Expand Down
21 changes: 21 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,27 @@ REACTANT_ABI PjRtClient *MakeClientUsingPluginAPI(const char *device_type,
return GetCApiClient(client_name);
}

// Register a Julia-allocated PJRT_Api struct directly (no dlopen needed).
// Julia fills the struct with @cfunction pointers, passes it here.
REACTANT_ABI PjRtClient *MakeClientFromApi(const PJRT_Api *api,
const char *device_type,
const char *client_name,
const char **error) {
absl::Status set_status = pjrt::SetPjrtApi(device_type, api);
if (!set_status.ok()) {
auto str = set_status.message();
char *err = (char *)malloc(str.size() + 1);
memcpy(err, str.data(), str.size() + 1);
*error = err;
return nullptr;
}
if (InitializePjrtPlugin(device_type, error) == 1)
return nullptr;

RegisterProfiler(api);
return GetCApiClient(client_name);
}

REACTANT_ABI PjRtClient *MakeTPUClient(const char *tpu_path,
const char **error) {
// Prefer $TPU_LIBRARY_PATH if set
Expand Down
1 change: 1 addition & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,7 @@ cc_library(
"-Wl,-exported_symbol,_LoadPjrtPlugin",
"-Wl,-exported_symbol,_InitializePjrtPlugin",
"-Wl,-exported_symbol,_MakeClientUsingPluginAPI",
"-Wl,-exported_symbol,_MakeClientFromApi",
"-Wl,-exported_symbol,_GetCApiClient",
"-Wl,-exported_symbol,_ClientNumDevices",
"-Wl,-exported_symbol,_ClientNumAddressableDevices",
Expand Down
8 changes: 7 additions & 1 deletion deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,13 @@ push!(build_cmd_list, "--jobs=$(parsed_args["jobs"])")
push!(build_cmd_list, "--experimental_ui_max_stdouterr_bytes=-1")
push!(build_cmd_list, "--sandbox_debug")

push!(build_cmd_list, "--linkopt=-fuse-ld=lld")
# push!(build_cmd_list, "--linkopt=-fuse-ld=lld") # lld not available on macOS

# On macOS, enable new toolchain resolution so Bazel uses platform-aware
# selection instead of legacy CPU-string matching (which maps "darwin" to x86)
if Sys.isapple()
push!(build_cmd_list, "--incompatible_enable_cc_toolchain_resolution")
end

for opt in parsed_args["copt"]
push!(build_cmd_list, "--copt=$(opt)")
Expand Down
147 changes: 147 additions & 0 deletions ext/ReactantMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Metal backend for Reactant — not precompiled because it overrides
# Base.convert, Reactant.XLA.free_buffer, and Reactant.XLA.to_host to add
# thread-safety for host buffer transfers.
# (Julia disallows method overwriting during precompilation.)
__precompile__(false)

module ReactantMetalExt

using Metal
using Metal.MPS
using Metal: MtlArray

# ObjectiveC primitives needed by @objc call sites in XLACompiler.jl
using Metal.MTL: @objc, id, nil, NSString, NSArray, NSDictionary, reinterpret

# Descriptor types needed by @objc [T alloc] calls (macro requires bare identifiers)
using Metal.MPSGraphs: MPSGraphConvolution2DOpDescriptor,
MPSGraphConvolution3DOpDescriptor,
MPSGraphPooling2DOpDescriptor,
MPSGraphPooling4DOpDescriptor

# Reactant's in-tree MLIR modules — no parameter injection needed
using Reactant: Reactant, MLIR
using Reactant.MLIR: IR, API

# Phase-1 PJRT plugin: 30 @cfunction callbacks + PJRT_Api struct + make_client()
include("ReactantMetalExt/PJRTPlugin.jl")

# @objc bindings for MPSGraph ops not wrapped by Metal.jl,
# plus julia_to_mps_dtype and mps_reshape helpers
include("ReactantMetalExt/XLACompiler.jl")

# MLIR walker: compile_mlir_module, MetalExecutable, execute!
include("ReactantMetalExt/MLIRWalker.jl")

export compile_mlir_module, MetalExecutable, execute!

# ─── Thread-safe Metal PJRT buffer operations ────────────────────────────────
#
# Julia 1.9+ runs GC finalizers in a dedicated finalizer thread. When the Julia
# GC triggers, old PJRT buffer wrappers from previous @jit calls are finalized:
#
# finalizer thread: free_buffer → PjRtBufferFree → delete PjRtCApiBuffer
# → ~PjRtCApiBuffer() → accesses PjRtCApiClient shared state
#
# main thread: BufferToHost → PjRtCApiBuffer::ToLiteralSync()
# → accesses PjRtCApiClient shared state
#
# XLA's PjRtCApiClient is NOT thread-safe. Concurrent access from the finalizer
# thread and the main thread causes heap corruption (std::bad_alloc).
#
# Fix: METAL_XLA_LOCK serializes every call that enters XLA's C++ wrapper layer.
# Both free_buffer (finalizer thread) and to_host (main thread) must hold the
# lock before calling PjRtBufferFree / BufferToHost respectively.
#
# Note: GC.enable_finalizers(false/true) alone is insufficient because it only
# prevents NEW finalizers from being dequeued — already-running finalizers
# continue concurrently. A proper mutex is needed.

const METAL_XLA_LOCK = ReentrantLock()

# Override Base.convert for ConcretePJRTArray.
# Disabling finalizers here reduces lock contention: no new finalizers can start
# between the output-buffer allocation and the BufferToHost call, so METAL_XLA_LOCK
# inside to_host is almost never contended.
function Base.convert(::Type{<:Array}, X::Reactant.ConcretePJRTArray{T,N}) where {T,N}
GC.enable_finalizers(false)
try
if Reactant.has_padding(X)
padding = Reactant.get_padding(X)
data = Array{T,N}(undef, (size(X) .+ padding)...)
Reactant.write_to_host_buffer!(data, X)
return view(data, [1:size(X, i) for i in 1:ndims(X)]...)
else
data = Array{T,N}(undef, size(X)...)
Reactant.write_to_host_buffer!(data, X)
return data
end
finally
GC.enable_finalizers(true)
end
end

# Override free_buffer so that PjRtBufferFree (called from the Julia GC finalizer
# thread) cannot overlap with BufferToHost on the main thread.
function Reactant.XLA.free_buffer(buffer::Reactant.XLA.PJRT.Buffer)
sbuffer = buffer.buffer
if sbuffer != C_NULL && Reactant.XLA.is_live[]
@lock METAL_XLA_LOCK begin
@ccall Reactant.MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid
end
end
end

# Override to_host so that BufferToHost cannot overlap with PjRtBufferFree.
function Reactant.XLA.to_host(
buffer::Reactant.XLA.PJRT.Buffer,
data,
sharding,
)
@assert buffer.buffer !== C_NULL
@lock METAL_XLA_LOCK begin
GC.@preserve buffer data begin
@ccall Reactant.MLIR.API.mlir_c.BufferToHost(
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
)::Cvoid
end
end
return data
end

function __init__()
@static if Sys.isapple()
if Metal.functional()
# Initialize @cfunction handles and register the PJRT_Api pointer
# so PJRT.MakeMetalClient() (no-args) can be called from XLA.jl.
try
init_pjrt_handles!()
# Expose the PJRT_Api struct pointer to Reactant's Client.jl
Reactant.XLA.PJRT._metal_pjrt_api_ptr[] = Ptr{Cvoid}(_PJRT_API_MEM)

# Create client via the shared PJRT.MetalClient() path (checkcount=false
# because initialize_default_clients! may not have run yet and the counter
# won't have been touched).
state = Reactant.XLA.global_backend_state
if haskey(state.clients, "metal")
# Already registered (e.g., XLA.jl's init block ran first).
state.default_client = state.clients["metal"]
else
metal = Reactant.XLA.PJRT.MetalClient(checkcount=false)
Reactant.XLA.PJRT.metal_client_count[] += 1
state.clients["metal"] = metal
state.default_client = metal
end
catch e
if e isa ErrorException && contains(e.msg, "MakeClientFromApi")
@warn "Metal PJRT backend requires rebuilt libReactantExtra. Run: julia --project=deps deps/build_local.jl"
else
@warn "Metal backend initialization failed" exception = e
end
end
end
end
return nothing
end

end # module ReactantMetalExt
Loading