Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand All @@ -80,6 +81,7 @@ ReactantFloat8sExt = "Float8s"
ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantLogExpFunctionsExt = ["IrrationalConstants", "LogExpFunctions"]
ReactantMPIExt = "MPI"
ReactantMetalExt = "Metal"
ReactantNNlibExt = ["NNlib", "Statistics"]
ReactantNPZExt = "NPZ"
ReactantOffsetArraysExt = "OffsetArrays"
Expand Down Expand Up @@ -122,6 +124,7 @@ LinearAlgebra = "1.10"
LogExpFunctions = "0.3"
MCMCDiagnosticTools = "0.3.11"
MPI = "0.20"
Metal = "1.4"
NNlib = "0.9.26"
NPZ = "0.4"
OffsetArrays = "1"
Expand Down
21 changes: 21 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,27 @@ REACTANT_ABI PjRtClient *MakeClientUsingPluginAPI(const char *device_type,
return GetCApiClient(client_name);
}

// Register a Julia-allocated PJRT_Api struct directly (no dlopen needed).
// Julia fills the struct with @cfunction pointers, passes it here.
//
// Parameters:
//   api         - PJRT_Api table to register under `device_type`.
//   device_type - plugin/device name passed to pjrt::SetPjrtApi and
//                 InitializePjrtPlugin.
//   client_name - name forwarded to GetCApiClient to build the client.
//   error       - out-parameter; on registration failure it receives a
//                 malloc'ed, NUL-terminated copy of the status message
//                 (presumably released by the caller -- confirm against the
//                 Julia side).
//
// Returns the client from GetCApiClient, or nullptr when registration or
// plugin initialization fails.
REACTANT_ABI PjRtClient *MakeClientFromApi(const PJRT_Api *api,
                                           const char *device_type,
                                           const char *client_name,
                                           const char **error) {
  absl::Status set_status = pjrt::SetPjrtApi(device_type, api);
  if (!set_status.ok()) {
    auto msg = set_status.message();
    char *err = (char *)malloc(msg.size() + 1);
    if (err != nullptr) {
      // The message is a string view: it is not guaranteed to be
      // NUL-terminated, so copy exactly size() bytes and terminate explicitly
      // instead of reading size() + 1 bytes from the view.
      memcpy(err, msg.data(), msg.size());
      err[msg.size()] = '\0';
    }
    // On allocation failure *error is set to nullptr rather than left stale.
    *error = err;
    return nullptr;
  }
  // InitializePjrtPlugin reports failure by returning 1; it is handed `error`
  // so it can report its own failure message.
  if (InitializePjrtPlugin(device_type, error) == 1)
    return nullptr;

  RegisterProfiler(api);
  return GetCApiClient(client_name);
}

REACTANT_ABI PjRtClient *MakeTPUClient(const char *tpu_path,
const char **error) {
// Prefer $TPU_LIBRARY_PATH if set
Expand Down
1 change: 1 addition & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,7 @@ cc_library(
"-Wl,-exported_symbol,_LoadPjrtPlugin",
"-Wl,-exported_symbol,_InitializePjrtPlugin",
"-Wl,-exported_symbol,_MakeClientUsingPluginAPI",
"-Wl,-exported_symbol,_MakeClientFromApi",
"-Wl,-exported_symbol,_GetCApiClient",
"-Wl,-exported_symbol,_ClientNumDevices",
"-Wl,-exported_symbol,_ClientNumAddressableDevices",
Expand Down
8 changes: 7 additions & 1 deletion deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,13 @@ push!(build_cmd_list, "--jobs=$(parsed_args["jobs"])")
push!(build_cmd_list, "--experimental_ui_max_stdouterr_bytes=-1")
push!(build_cmd_list, "--sandbox_debug")

if Sys.isapple()
    # On macOS, enable new toolchain resolution so Bazel uses platform-aware
    # selection instead of legacy CPU-string matching (which maps "darwin" to x86).
    push!(build_cmd_list, "--incompatible_enable_cc_toolchain_resolution")
else
    # lld is not available on macOS, so only request it on other platforms
    # instead of disabling it for everyone.
    push!(build_cmd_list, "--linkopt=-fuse-ld=lld")
end

for opt in parsed_args["copt"]
push!(build_cmd_list, "--copt=$(opt)")
Expand Down
67 changes: 67 additions & 0 deletions ext/ReactantMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
module ReactantMetalExt

using Metal
using Metal.MPS
using Metal: MtlArray

# ObjectiveC primitives needed by @objc call sites in XLACompiler.jl
using Metal.MTL: @objc, id, nil, NSString, NSArray, NSDictionary, reinterpret

# Descriptor types needed by @objc [T alloc] calls (macro requires bare identifiers)
using Metal.MPSGraphs: MPSGraphConvolution2DOpDescriptor,
    MPSGraphConvolution3DOpDescriptor,
    MPSGraphPooling2DOpDescriptor,
    MPSGraphPooling4DOpDescriptor

# Reactant's in-tree MLIR modules — no parameter injection needed
using Reactant: Reactant, MLIR
using Reactant.MLIR: IR, API

# Phase-1 PJRT plugin: 30 @cfunction callbacks + PJRT_Api struct + make_client()
include("ReactantMetalExt/PJRTPlugin.jl")

# @objc bindings for MPSGraph ops not wrapped by Metal.jl,
# plus julia_to_mps_dtype and mps_reshape helpers
include("ReactantMetalExt/XLACompiler.jl")

# MLIR walker: compile_mlir_module, MetalExecutable, execute!
include("ReactantMetalExt/MLIRWalker.jl")

export compile_mlir_module, MetalExecutable, execute!

"""
    _register_metal_client!()

Initialize the PJRT callback handles, publish the `PJRT_Api` pointer to
`Reactant.XLA.PJRT`, and make the `"metal"` client the default client of
`Reactant.XLA.global_backend_state`.

If a `"metal"` client is already registered it is reused; otherwise a new one is
created with `checkcount=false` and `metal_client_count` is incremented by hand.
Any exception raised along the way propagates to the caller.
"""
function _register_metal_client!()
    # Initialize @cfunction handles and expose the PJRT_Api struct pointer to
    # Reactant's Client.jl so the Metal client can be constructed without arguments.
    # (`init_pjrt_handles!` and `_PJRT_API_MEM` are presumably provided by the
    # included PJRTPlugin.jl.)
    init_pjrt_handles!()
    Reactant.XLA.PJRT._metal_pjrt_api_ptr[] = Ptr{Cvoid}(_PJRT_API_MEM)

    state = Reactant.XLA.global_backend_state
    if haskey(state.clients, "metal")
        # Already registered (e.g., XLA.jl's init block ran first).
        state.default_client = state.clients["metal"]
    else
        # checkcount=false because initialize_default_clients! may not have run
        # yet and the counter won't have been touched; bump it manually instead.
        metal = Reactant.XLA.PJRT.MetalClient(; checkcount=false)
        Reactant.XLA.PJRT.metal_client_count[] += 1
        state.clients["metal"] = metal
        state.default_client = metal
    end
    return nothing
end

"""
    __init__()

Module initializer. On Apple platforms with a functional Metal installation, try
to register the Metal PJRT client as Reactant's default client. Failures are
reported with `@warn` instead of being rethrown, so loading the extension never
aborts; the generic warning carries the backtrace to make the failure
diagnosable.
"""
function __init__()
    @static if Sys.isapple()
        Metal.functional() || return nothing
        try
            _register_metal_client!()
        catch e
            if e isa ErrorException && contains(e.msg, "MakeClientFromApi")
                # The error message names `MakeClientFromApi`; the likely cause is a
                # libReactantExtra that predates that entry point.
                @warn "Metal PJRT backend requires rebuilt libReactantExtra. Run: julia --project=deps deps/build_local.jl"
            else
                # Attach the backtrace so the log shows where initialization failed.
                @warn "Metal backend initialization failed" exception = (
                    e, catch_backtrace()
                )
            end
        end
    end
    return nothing
end

end # module ReactantMetalExt
Loading
Loading