Metal PJRT backend via MPSGraph + pure-Julia plugin#2489
Metal PJRT backend via MPSGraph + pure-Julia plugin#2489Dale-Black wants to merge 7 commits intoEnzymeAD:mainfrom
Conversation
Allows Julia-allocated PJRT_Api structs (filled with @cfunction pointers) to be registered directly with XLA without requiring dlopen of a shared library. This is the entry point for the Metal PJRT backend. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace download-based Metal.jl with has_metal()/setup_metal!() API - Add MakeMetalClientFromApi for Julia-side PJRT_Api registration - Enable Metal client initialization in XLA.jl when Metal.jl is loaded - Downgrade unimplemented platform properties log to @debug Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pure-Julia Metal GPU backend for Reactant via MLIR walking + MPSGraph. - PJRTPlugin.jl: 30 @cfunction PJRT callbacks implementing the PJRT_Api struct for PjRtCApiClient initialization (no shared library needed) - XLACompiler.jl: @objc MPSGraph bindings for ops not wrapped by Metal.jl - MLIRWalker.jl: MLIR → MPSGraph translation supporting element-wise ops, dot_general, broadcast, reshape, transpose, reduce, conv2d/3d, pooling, concatenate, slice, scatter, and reverse - Thread-safe buffer operations (METAL_XLA_LOCK) to prevent heap corruption from concurrent GC finalizer and main thread access Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add Metal to [weakdeps] and [compat] in Project.toml - Register ReactantMetalExt in [extensions] - Add Metal to test/Project.toml [deps] - Fix test/plugins/metal.jl: add Metal.functional() guard, broadcasting, correct gradient indexing - Enable Metal tests on macOS in runtests.jl Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
lld is not available on macOS. Enable --incompatible_enable_cc_toolchain_resolution so Bazel uses platform-aware toolchain selection instead of legacy CPU-string matching (which incorrectly maps "darwin" to x86). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #2489 +/- ##
===========================================
- Coverage 68.16% 36.64% -31.52%
===========================================
Files 109 200 +91
Lines 11779 31298 +19519
===========================================
+ Hits 8029 11469 +3440
- Misses 3750 19829 +16079 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
| # ============================================================================ | ||
|
|
||
| """Extract contracting_dims from dot_general op text.""" | ||
| function parse_contracting_dims(op_text::AbstractString) |
There was a problem hiding this comment.
We dont need to parse the string here, we should be able to query the operation to extract these info
|
Split out the C++/Bazel changes into #2490 per @avik-pal's request. That PR adds only Once the JLL is rebuilt with that symbol, the Julia changes in this PR will work against the new JLL (no more I'll rebase this PR to remove the |
Address review feedback: the walker was using regex on `string(op)` and `string(IR.type(...))` to extract attributes and type information from MLIR operations. This replaces all 12 `parse_*` functions with proper MLIR C API calls through Reactant's IR module. The main patterns used: - IR.getattr(op, name) + DenseArray indexing for simple attributes (broadcast_dimensions, permutation, dimensions, window_strides, etc.) - StableHLO C API for structured attributes (stablehloConvDimensionNumbers*, stablehloDotDimensionNumbers*) - API.mlirDenseElementsAttrGetInt64Value for DenseElements attributes where the Julia wrapper has a known bug (padding attrs) - IR.type/IR.ndims/IR.size/IR.eltype for type inspection, replacing regex parsing of "tensor<4x8xf32>" strings This is a pure refactor — no behavioral changes. All existing tests pass identically before and after. Net result is -163 lines, since the API calls are more concise than the regex parsers they replace. There is one remaining `string(op)` call used for error messages on unrecognized ops, which is a legitimate diagnostic use rather than attribute extraction. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
MLIR API Refactor — Addressing String Parsing Feedback@avik-pal — thanks for the review. I interpreted your comment about not needing to parse strings as referring to the What changed (latest commit)Replaced all string-based attribute/type extraction with API calls:
Net result is -163 lines since the API calls are more concise than the regex parsers. What we tested
I'll share some screenshots from the local Pluto benchmark notebook in a follow-up comment. Please let me know if this is what you had in mind or if there are other areas that need attention — still learning my way around the MLIR infrastructure here. |
Per review feedback: the PjRtCApiClient finalizer-vs-main-thread race condition should be fixed generically in core Reactant, not per-backend. Removed METAL_XLA_LOCK, Base.convert override, free_buffer override, to_host override, and __precompile__(false). See EnzymeAD#2493 for the generic fix. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>



Summary
Pure-Julia Metal GPU backend for Reactant on Apple Silicon. Instead of depending on an external PJRT plugin shared library (the old jax-metal
.dylibapproach, which is no longer compatible with the current OpenXLA), this implements the full PJRT callback interface directly in Julia using@cfunctionpointers, then walks the optimized StableHLO IR to build an equivalent MPSGraph that executes on the Metal GPU.Target UX:
using Reactant, Metal; @jit f(x)— transparent dispatch, no special API.How it works
The optimization pipeline has two layers: XLA/MLIR does high-level fusion and CSE on the IR, then MPSGraph does Metal-specific kernel fusion and scheduling on the GPU side.
What's included
MakeClientFromApi): Registers a Julia-allocatedPJRT_Apistruct directly with XLA — nodlopenneededPJRTPlugin.jl): FullPJRT_Apiimplementation covering client lifecycle, device/memory discovery, buffer management, compilation, and executionMLIRWalker.jl): Translates StableHLO ops to MPSGraph nodes — supports element-wise ops,dot_general,broadcast_in_dim,reshape,transpose,reduce(sum/max),conv2d/conv3d,reduce_window(pooling 2D/3D),concatenate,slice,scatter,reverse, andconstantXLACompiler.jl): MPSGraph operations not wrapped by Metal.jlMETAL_XLA_LOCKserializes buffer operations to prevent heap corruption from concurrent GC finalizer and main thread access toPjRtCApiClient@jitcalls to avoid per-call allocationlldlinker (unavailable on macOS) and enables platform-aware Bazel toolchain resolutionWhat works today
sin,cos,exp,tanh,relu, etc.)ChainmodelsArchitecture decisions
ReactantMetalExt): Loaded automatically whenusing Metalbrings Metal.jl into scope. No changes needed to user code.__precompile__(false): Required because the extension overridesBase.convert,XLA.free_buffer, andXLA.to_hostfor thread-safety. Julia disallows method overwrites during precompilation.@cfunctionpointers stored in aLibc.malloc'd struct. This eliminates the need for any external binary beyond the existinglibReactantExtra.placeholderTensorauto-reverses Julia shapes. The walker uses IR shapes directly for all operations, with layout permutations only at conv/pool boundaries.Development process
This backend was developed over ~48 commits using an autonomous agent loop ("ralph loop") powered by Claude Code. The agent iteratively implemented and verified each component — from the initial PJRT callback prototype through conv layout bugs and thread-safety fixes. This PR is a clean 5-commit squash of that work onto
origin/main, containing only the necessary production code. All development scaffolding (research files, debug tests, benchmark notebooks) has been removed.Known limitations
stablehlo.convertis identity-only (no actual dtype casting yet)reducefor min/prodFiles changed (15 files, +3,395 / -77)
deps/ReactantExtra/API.cppMakeClientFromApi()deps/ReactantExtra/BUILDdeps/build_local.jlsrc/accelerators/Metal.jlhas_metal()/setup_metal!()src/xla/Device.jl@warn→@debugsrc/xla/PJRT/Client.jlMakeMetalClientFromApi,_metal_pjrt_api_ptrsrc/xla/XLA.jlext/ReactantMetalExt.jlext/ReactantMetalExt/MLIRWalker.jlext/ReactantMetalExt/PJRTPlugin.jlext/ReactantMetalExt/XLACompiler.jlProject.tomltest/Project.tomltest/plugins/metal.jltest/runtests.jlTest plan
julia test/plugins/metal.jlon macOS with Apple Silicon — sincos, autodiff, CNN all passjulia -e 'using Reactant; println(Reactant.XLA.default_backend())'— basic Reactant still works on non-MacMetalis NOT in[deps](only[weakdeps]) — no new mandatory dependency🤖 Generated with Claude Code