Metal PJRT backend via MPSGraph + pure-Julia plugin#2489

Open
Dale-Black wants to merge 7 commits into EnzymeAD:main from Dale-Black:metal-pjrt-backend

Conversation

@Dale-Black
Contributor

Summary

Pure-Julia Metal GPU backend for Reactant on Apple Silicon. Instead of depending on an external PJRT plugin shared library (the old jax-metal .dylib approach, which is no longer compatible with the current OpenXLA), this implements the full PJRT callback interface directly in Julia using @cfunction pointers, then walks the optimized StableHLO IR to build an equivalent MPSGraph that executes on the Metal GPU.

Target UX: using Reactant, Metal; @jit f(x) — transparent dispatch, no special API.

How it works

Julia code → Reactant tracing → XLA/MLIR optimization (fusion, CSE, layout opt)
→ Optimized StableHLO IR → PJRT compile callback → MLIR walker → MPSGraph → Metal GPU

The optimization pipeline has two layers: XLA/MLIR does high-level fusion and CSE on the IR, then MPSGraph does Metal-specific kernel fusion and scheduling on the GPU side.
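The pipeline above can be exercised end to end with a minimal sketch (this assumes an Apple Silicon machine with this branch checked out; `f` here is an arbitrary traceable function, not code from the PR):

```julia
using Reactant, Metal   # `using Metal` triggers the ReactantMetalExt extension

f(x) = sum(tanh.(x) .^ 2)                      # traced by Reactant to StableHLO
x = Reactant.to_rarray(rand(Float32, 1024))

compiled = @compile f(x)   # StableHLO → PJRT compile callback → MPSGraph
compiled(x)                # executes on the Metal GPU
```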

What's included

  • C++ bridge (MakeClientFromApi): Registers a Julia-allocated PJRT_Api struct directly with XLA — no dlopen needed
  • 30 PJRT callbacks (PJRTPlugin.jl): Full PJRT_Api implementation covering client lifecycle, device/memory discovery, buffer management, compilation, and execution
  • MLIR walker (MLIRWalker.jl): Translates StableHLO ops to MPSGraph nodes — supports element-wise ops, dot_general, broadcast_in_dim, reshape, transpose, reduce (sum/max), conv2d/conv3d, reduce_window (pooling 2D/3D), concatenate, slice, scatter, reverse, and constant
  • @objc bindings (XLACompiler.jl): MPSGraph operations not wrapped by Metal.jl
  • Thread-safety: METAL_XLA_LOCK serializes buffer operations to prevent heap corruption from concurrent GC finalizer and main thread access to PjRtCApiClient
  • MtlArray pool: Recycles GPU buffers across @jit calls to avoid per-call allocation
  • macOS build fix: Disables lld linker (unavailable on macOS) and enables platform-aware Bazel toolchain resolution
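The direct `PJRT_Api` registration above reduces to one reusable mechanism: Julia functions exposed as C function pointers via `@cfunction`, stored in a `Libc.malloc`'d struct that the C consumer (XLA, in the real plugin) can own. A toy, self-contained illustration of just that mechanism — `ToyApi` and its fields are hypothetical, not the real `PJRT_Api` layout:

```julia
# Toy stand-in for PJRT_Api: a C-compatible struct holding a callback pointer.
struct ToyApi
    version::Cint
    add_fn::Ptr{Cvoid}
end

toy_add(a::Cint, b::Cint)::Cint = a + b
const ADD_CB = @cfunction(toy_add, Cint, (Cint, Cint))

# Allocate with Libc.malloc so the consumer can hold the pointer
# beyond the lifetime of any Julia-managed object.
api_ptr = convert(Ptr{ToyApi}, Libc.malloc(sizeof(ToyApi)))
unsafe_store!(api_ptr, ToyApi(Cint(1), ADD_CB))

# A C caller would load the struct and invoke the callback; from Julia:
loaded = unsafe_load(api_ptr)
result = ccall(loaded.add_fn, Cint, (Cint, Cint), 2, 3)  # → 5
Libc.free(api_ptr)
```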

What works today

  • Element-wise math (sin, cos, exp, tanh, relu, etc.)
  • Dense layers and Chain models
  • Conv2D and Conv3D with arbitrary layouts
  • Max/avg pooling (2D and 3D)
  • Enzyme autodiff (forward and reverse mode)
  • Full Lux CNN (Conv → Pool → Dense pipeline)

Architecture decisions

  1. Package extension (ReactantMetalExt): Loaded automatically when `using Metal` brings Metal.jl into scope. No changes needed to user code.
  2. __precompile__(false): Required because the extension overrides Base.convert, XLA.free_buffer, and XLA.to_host for thread-safety. Julia disallows method overwrites during precompilation.
  3. Direct PJRT_Api registration (Option A): Rather than building a C shared library, all 30 PJRT callbacks are Julia @cfunction pointers stored in a Libc.malloc'd struct. This eliminates the need for any external binary beyond the existing libReactantExtra.
  4. IR convention for tensors: MPSGraph tensors use IR (row-major) convention internally because placeholderTensor auto-reverses Julia shapes. The walker uses IR shapes directly for all operations, with layout permutations only at conv/pool boundaries.
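Decisions 1 and 2 together imply an extension module shaped roughly like this (a minimal sketch; everything inside `__init__` beyond the `Metal.functional()` guard is assumed structure, not the actual code):

```julia
module ReactantMetalExt

__precompile__(false)  # the extension overwrites Base.convert, XLA.free_buffer,
                       # and XLA.to_host, which Julia forbids during precompilation

using Reactant, Metal

function __init__()
    Metal.functional() || return  # silently no-op without a usable Metal GPU
    # Register the Julia-side PJRT_Api (@cfunction table) with XLA and
    # install the Metal client here.
end

end # module
```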

Development process

This backend was developed over ~48 commits using an autonomous agent loop ("ralph loop") powered by Claude Code. The agent iteratively implemented and verified each component — from the initial PJRT callback prototype through conv layout bugs and thread-safety fixes. This PR is a clean 5-commit squash of that work onto origin/main, containing only the necessary production code. All development scaffolding (research files, debug tests, benchmark notebooks) has been removed.

Known limitations

  • Conv-after-concat with non-square spatial dims has a known shape mismatch (the "L7 problem" in UNet patterns) — under investigation
  • stablehlo.convert is identity-only (no actual dtype casting yet)
  • No reduce for min/prod
  • Float64 and Int64 are silently downcast to Float32/Int32 (MPSGraph limitation)

Files changed (15 files, +3,395 / -77)

File Change
deps/ReactantExtra/API.cpp +21: MakeClientFromApi()
deps/ReactantExtra/BUILD +1: export symbol
deps/build_local.jl ~8: macOS build fix
src/accelerators/Metal.jl rewrite: has_metal()/setup_metal!()
src/xla/Device.jl +1: @warn → @debug
src/xla/PJRT/Client.jl +26: MakeMetalClientFromApi, _metal_pjrt_api_ptr
src/xla/XLA.jl ~22: enable Metal client init
ext/ReactantMetalExt.jl +147: extension entry point
ext/ReactantMetalExt/MLIRWalker.jl +1,576: MLIR → MPSGraph
ext/ReactantMetalExt/PJRTPlugin.jl +1,197: 30 PJRT callbacks
ext/ReactantMetalExt/XLACompiler.jl +369: @objc MPSGraph bindings
Project.toml +3: Metal in weakdeps + compat + extension
test/Project.toml +1: Metal in test deps
test/plugins/metal.jl ~16: fixes for Metal backend
test/runtests.jl ~4: enable Metal tests on macOS

Test plan

  • julia test/plugins/metal.jl on macOS with Apple Silicon — sincos, autodiff, CNN all pass
  • julia -e 'using Reactant; println(Reactant.XLA.default_backend())' — basic Reactant still works on non-Mac
  • Verify CI passes on Linux/CUDA (no functional changes to non-Metal paths)
  • Verify Metal is NOT in [deps] (only [weakdeps]) — no new mandatory dependency

🤖 Generated with Claude Code

Dale-Black and others added 5 commits February 20, 2026 09:00
Allows Julia-allocated PJRT_Api structs (filled with @cfunction pointers)
to be registered directly with XLA without requiring dlopen of a shared
library. This is the entry point for the Metal PJRT backend.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace download-based Metal.jl with has_metal()/setup_metal!() API
- Add MakeMetalClientFromApi for Julia-side PJRT_Api registration
- Enable Metal client initialization in XLA.jl when Metal.jl is loaded
- Downgrade unimplemented platform properties log to @debug

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pure-Julia Metal GPU backend for Reactant via MLIR walking + MPSGraph.

- PJRTPlugin.jl: 30 @cfunction PJRT callbacks implementing the PJRT_Api
  struct for PjRtCApiClient initialization (no shared library needed)
- XLACompiler.jl: @objc MPSGraph bindings for ops not wrapped by Metal.jl
- MLIRWalker.jl: MLIR → MPSGraph translation supporting element-wise ops,
  dot_general, broadcast, reshape, transpose, reduce, conv2d/3d, pooling,
  concatenate, slice, scatter, and reverse
- Thread-safe buffer operations (METAL_XLA_LOCK) to prevent heap corruption
  from concurrent GC finalizer and main thread access

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add Metal to [weakdeps] and [compat] in Project.toml
- Register ReactantMetalExt in [extensions]
- Add Metal to test/Project.toml [deps]
- Fix test/plugins/metal.jl: add Metal.functional() guard, broadcasting,
  correct gradient indexing
- Enable Metal tests on macOS in runtests.jl

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
lld is not available on macOS. Enable --incompatible_enable_cc_toolchain_resolution
so Bazel uses platform-aware toolchain selection instead of legacy CPU-string
matching (which incorrectly maps "darwin" to x86).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@codecov

codecov bot commented Feb 20, 2026

Codecov Report

❌ Patch coverage is 0.27739% with 1438 lines in your changes missing coverage. Please review.
✅ Project coverage is 36.64%. Comparing base (b39a1fc) to head (e78236b).
⚠️ Report is 700 commits behind head on main.

Files with missing lines Patch % Lines
ext/ReactantMetalExt/MLIRWalker.jl 0.00% 748 Missing ⚠️
ext/ReactantMetalExt/PJRTPlugin.jl 0.00% 513 Missing ⚠️
ext/ReactantMetalExt/XLACompiler.jl 0.00% 142 Missing ⚠️
ext/ReactantMetalExt.jl 5.55% 17 Missing ⚠️
src/xla/PJRT/Client.jl 0.00% 10 Missing ⚠️
src/xla/XLA.jl 16.66% 5 Missing ⚠️
src/accelerators/Metal.jl 50.00% 2 Missing ⚠️
src/xla/Device.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2489       +/-   ##
===========================================
- Coverage   68.16%   36.64%   -31.52%     
===========================================
  Files         109      200       +91     
  Lines       11779    31298    +19519     
===========================================
+ Hits         8029    11469     +3440     
- Misses       3750    19829    +16079     


# ============================================================================

"""Extract contracting_dims from dot_general op text."""
function parse_contracting_dims(op_text::AbstractString)
Collaborator

We don't need to parse the string here; we should be able to query the operation to extract this info.

@Dale-Black
Contributor Author

Split out the C++/Bazel changes into #2490 per @avik-pal's request. That PR adds only MakeClientFromApi to API.cpp and the symbol export in BUILD.

Once the JLL is rebuilt with that symbol, the Julia changes in this PR will work against the new JLL (no more LocalPreferences.toml / local build requirement).

I'll rebase this PR to remove the deps/ReactantExtra/ commits once #2490 is merged.

Address review feedback: the walker was using regex on `string(op)` and
`string(IR.type(...))` to extract attributes and type information from
MLIR operations. This replaces all 12 `parse_*` functions with proper
MLIR C API calls through Reactant's IR module.

The main patterns used:
- IR.getattr(op, name) + DenseArray indexing for simple attributes
  (broadcast_dimensions, permutation, dimensions, window_strides, etc.)
- StableHLO C API for structured attributes
  (stablehloConvDimensionNumbers*, stablehloDotDimensionNumbers*)
- API.mlirDenseElementsAttrGetInt64Value for DenseElements attributes
  where the Julia wrapper has a known bug (padding attrs)
- IR.type/IR.ndims/IR.size/IR.eltype for type inspection, replacing
  regex parsing of "tensor<4x8xf32>" strings

This is a pure refactor — no behavioral changes. All existing tests
pass identically before and after. Net result is -163 lines, since the
API calls are more concise than the regex parsers they replace.

There is one remaining `string(op)` call used for error messages on
unrecognized ops, which is a legitimate diagnostic use rather than
attribute extraction.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@Dale-Black
Contributor Author

MLIR API Refactor — Addressing String Parsing Feedback

@avik-pal — thanks for the review. I interpreted your comment about not needing to parse strings as referring to the parse_* functions in MLIRWalker.jl that were using regex on string(op) and string(IR.type(...)) to extract attributes and type information. I've replaced all 12 of those functions with proper MLIR C API calls.

What changed (latest commit)

Replaced all string-based attribute/type extraction with API calls:

  • Simple attributes (broadcast_dimensions, permutation, dimensions, etc.): IR.getattr(op, name) + DenseArray indexing
  • Structured attributes (conv dimension_numbers, dot dimension_numbers): StableHLO C API functions (stablehloConvDimensionNumbersGet*, stablehloDotDimensionNumbersGet*)
  • Dense elements (padding): API.mlirDenseElementsAttrGetInt64Value (the Julia-level wrapper has a known bug for Int64/Float32 element types, so we call the C API directly)
  • Type inspection: IR.type / IR.ndims / IR.size / IR.eltype instead of regex-parsing "tensor<4x8xf32>" strings

Net result is -163 lines since the API calls are more concise than the regex parsers.
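As a sketch of the first two patterns — helper names follow the `IR.getattr` / `IR.type` calls listed above, but the exact signatures should be treated as illustrative, not verbatim from the PR:

```julia
using Reactant
const IR = Reactant.MLIR.IR

# Simple dense-array attribute: read broadcast_dimensions without string(op).
function broadcast_dims(op)
    attr = IR.getattr(op, "broadcast_dimensions")
    Int[Int(attr[i]) for i in 1:length(attr)]   # DenseArray indexing, 1-based
end

# Type inspection replaces regex parsing of "tensor<4x8xf32>" strings.
result_rank(op) = IR.ndims(IR.type(IR.result(op, 1)))
```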

What we tested

  • test/plugins/metal.jl all 3 testsets pass (sincos, autodiff with Enzyme, CNN) — verified after each individual function replacement
  • Quick conv test with non-square input (24x16) to catch layout regressions
  • One remaining string(op) call is used for error messages on unrecognized ops (diagnostic, not attribute extraction)

I'll share some screenshots from the local Pluto benchmark notebook in a follow-up comment. Please let me know if this is what you had in mind or if there are other areas that need attention — still learning my way around the MLIR infrastructure here.

@Dale-Black
Contributor Author

(Screenshots: benchmark result tables from the local Pluto notebook below.)

Since most of this is agent-coded, I have been verifying with this local notebook, which I THINK makes it hard to purely hallucinate these results (they look promising as far as I can tell!).

Local Pluto Notebook
### A Pluto.jl notebook ###
# v0.20.13

using Markdown
using InteractiveUtils

# ╔═╡ a1b2c3d4-0003-0001-0001-000000000001
# ╠═╡ show_logs = false
begin
	import Pkg
	env = mktempdir()
	Pkg.activate(env)
	# Write LocalPreferences BEFORE any Pkg ops (precompilation caches the JLL path)
	open(joinpath(env, "LocalPreferences.toml"), "w") do io
		println(io, "[Reactant_jll]")
		println(io, "libReactantExtra_path = \"", expanduser("~/Documents/dev/julia/Reactant.jl/deps/ReactantExtra/bazel-bin/libReactantExtra.so"), "\"")
	end
	# Use the local Reactant.jl checkout (metal-pjrt-backend branch)
	Pkg.develop(path=expanduser("~/Documents/dev/julia/Reactant.jl"))
	# Reactant_jll must be a DIRECT dep for Preferences.jl to find LocalPreferences.toml
	Pkg.add(["Lux", "Metal", "Reactant_jll", "Statistics"])
end

# ╔═╡ a1b2c3d4-0004-0001-0001-000000000001
begin
	using Random
	using Lux
	using Metal
	using Metal: MtlArray
	using Reactant
	using Statistics
end

# ╔═╡ a1b2c3d4-0001-0001-0001-000000000001
md"""
# Metal GPU Backend Benchmark

## CPU vs Metal.jl vs Reactant+Metal

This notebook compares three ways to run a Lux neural network on Apple Silicon:

1. **CPU only** — plain Julia arrays
2. **Metal.jl only** — GPU via `MtlArray`, no compiler optimization
3. **Reactant + Metal** — GPU via `@jit`, with XLA/MLIR optimization (op fusion, CSE, etc.)

The key insight: Reactant+Metal preserves **two** optimization layers — XLA/MLIR graph optimization AND Metal GPU execution — which neither CPU nor Metal.jl alone can match.
"""

# ╔═╡ a1b2c3d4-0002-0001-0001-000000000001
md"""
## Setup
"""

# ╔═╡ a1b2c3d4-0005-0001-0001-000000000001
md"""
## Model Definition

A larger `Dense` chain: 2048 → 1024 → 512 → 256 → 10. GPU benefits only appear when matrix operations are big enough that compute time dominates dispatch overhead. We test at batch sizes 256, 1024, and 4096.
"""

# ╔═╡ a1b2c3d4-0006-0001-0001-000000000001
begin
	const INPUT_DIM = 2048
	const BATCH_SIZES = [256, 1024, 4096]
	const N_WARMUP = 5
	const N_TRIALS = 20

	rng = Random.MersenneTwister(42)

	model = Chain(
		Dense(INPUT_DIM => 1024, relu),
		Dense(1024 => 512, relu),
		Dense(512 => 256, relu),
		Dense(256 => 10),
	)
	ps_cpu, st_cpu = Lux.setup(rng, model)
	nothing
end

# ╔═╡ a1b2c3d4-0007-0001-0001-000000000001
md"""
## Benchmark 1: CPU Only

Standard Julia arrays, no GPU. This is the baseline.
"""

# ╔═╡ a1b2c3d4-0008-0001-0001-000000000001
function bench_cpu(model, ps, st, input_dim, batch_size; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, input_dim, batch_size)
	f(m, x, ps, st) = first(m(x, ps, st))

	# Warmup
	for _ in 1:n_warmup
		f(model, x, ps, st)
	end

	# Timed runs
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(model, x, ps, st)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end

	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ a1b2c3d4-0009-0001-0001-000000000001
cpu_results = Dict(
	bs => bench_cpu(model, ps_cpu, st_cpu, INPUT_DIM, bs)
	for bs in BATCH_SIZES
)

# ╔═╡ a1b2c3d4-0010-0001-0001-000000000001
md"""
## Benchmark 2: Metal.jl Only

Move arrays to GPU via `MtlArray`. Lux supports this via `gpu_device()`. The model runs on Metal GPU but without any graph-level optimization — each op is dispatched individually.
"""

# ╔═╡ a1b2c3d4-0011-0001-0001-000000000001
begin
	gdev = gpu_device()
	cdev = cpu_device()
	ps_mtl, st_mtl = (ps_cpu, st_cpu) |> gdev
	md"Metal device: $(gdev)"
end

# ╔═╡ a1b2c3d4-0012-0001-0001-000000000001
function bench_metal(model, ps_gpu, st_gpu, input_dim, batch_size, gdev;
		n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x_gpu = MtlArray(randn(Float32, input_dim, batch_size))
	f(m, x, ps, st) = first(m(x, ps, st))

	# Warmup (+ sync)
	for _ in 1:n_warmup
		y = f(model, x_gpu, ps_gpu, st_gpu)
		Metal.synchronize()
	end

	# Timed runs — sync after each to get true GPU time
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		Metal.synchronize()
		stats = @timed begin
			y = f(model, x_gpu, ps_gpu, st_gpu)
			Metal.synchronize()
		end
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end

	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ a1b2c3d4-0013-0001-0001-000000000001
metal_results = Dict(
	bs => bench_metal(model, ps_mtl, st_mtl, INPUT_DIM, bs, gdev)
	for bs in BATCH_SIZES
)

# ╔═╡ a1b2c3d4-0014-0001-0001-000000000001
md"""
## Benchmark 3: Reactant + Metal

`@compile` traces the function, optimizes the MLIR graph (fusion, CSE, constant folding), and returns a compiled executable for Metal GPU via our PJRT plugin. Compilation happens once; the compiled function is called repeatedly for timing. (Note: `@jit` recompiles every call — always use `@compile` for benchmarks.)
"""

# ╔═╡ a1b2c3d4-0015-0001-0001-000000000001
function bench_reactant_metal(model, ps, st, input_dim, batch_size;
		n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, input_dim, batch_size)

	x_ra = Reactant.to_rarray(x)
	ps_ra = Reactant.to_rarray(ps)
	st_ra = Reactant.to_rarray(st)

	f(m, x, ps, st) = first(m(x, ps, st))

	# Compile ONCE — @compile returns a cached compiled function
	compile_stats = @timed begin
		compiled_f = @compile f(model, x_ra, ps_ra, st_ra)
	end
	compile_ms = compile_stats.time * 1000

	# Warmup the compiled function (no recompilation)
	for _ in 1:n_warmup
		compiled_f(model, x_ra, ps_ra, st_ra)
	end

	# Timed runs — execute only, no host transfer (apples-to-apples with Metal.jl)
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed begin
			compiled_f(model, x_ra, ps_ra, st_ra)
		end
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end

	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ a1b2c3d4-0016-0001-0001-000000000001
reactant_results = Dict(
	bs => bench_reactant_metal(model, ps_cpu, st_cpu, INPUT_DIM, bs)
	for bs in BATCH_SIZES
)

# ╔═╡ a1b2c3d4-0017-0001-0001-000000000001
md"""
## Correctness Check

Before trusting the benchmarks, verify all three backends produce the same result.
"""

# ╔═╡ a1b2c3d4-0018-0001-0001-000000000001
begin
	x_test = randn(Float32, INPUT_DIM, 8)
	f_test(m, x, ps, st) = first(m(x, ps, st))

	# CPU
	y_cpu = f_test(model, x_test, ps_cpu, st_cpu)

	# Metal.jl
	x_mtl_test = MtlArray(x_test)
	y_mtl = Array(f_test(model, x_mtl_test, ps_mtl, st_mtl))

	# Reactant + Metal (compile once, then execute)
	x_ra_test = Reactant.to_rarray(x_test)
	ps_ra_test = Reactant.to_rarray(ps_cpu)
	st_ra_test = Reactant.to_rarray(st_cpu)
	compiled_test = @compile f_test(model, x_ra_test, ps_ra_test, st_ra_test)
	y_reactant = Array(compiled_test(model, x_ra_test, ps_ra_test, st_ra_test))

	err_metal = maximum(abs.(y_cpu .- y_mtl))
	err_reactant = maximum(abs.(y_cpu .- y_reactant))

	md"""
	| Backend | Max Error vs CPU |
	|---------|-----------------|
	| Metal.jl | $(round(err_metal; sigdigits=3)) |
	| Reactant+Metal | $(round(err_reactant; sigdigits=3)) |

	Both should be < 1e-5 (float32 precision).
	"""
end

# ╔═╡ a1b2c3d4-0019-0001-0001-000000000001
md"""
## Results Comparison
"""

# ╔═╡ a1b2c3d4-0020-0001-0001-000000000001
begin
	header = "| Batch Size | CPU (ms) | Metal.jl (ms) | Reactant+Metal (ms) | Speedup vs CPU | Speedup vs Metal |"
	sep    = "|-----------|---------|--------------|--------------------|--------------:|----------------:|"
	rows = String[]
	for bs in BATCH_SIZES
		c = cpu_results[bs]
		m = metal_results[bs]
		r = reactant_results[bs]
		speedup_cpu = round(c.median_ms / r.median_ms; digits=1)
		speedup_metal = round(m.median_ms / r.median_ms; digits=1)
		push!(rows, "| $bs | $(round(c.median_ms; digits=3)) | $(round(m.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(speedup_cpu)x | $(speedup_metal)x |")
	end

	Markdown.parse(join([header, sep, rows...], "\n"))
end

# ╔═╡ a1b2c3d4-0021-0001-0001-000000000001
begin
	alloc_header = "| Batch Size | CPU (KB) | Metal.jl (KB) | Reactant+Metal (KB) |"
	alloc_sep    = "|-----------|---------|--------------|--------------------:|"
	alloc_rows = String[]
	for bs in BATCH_SIZES
		c = cpu_results[bs]
		m = metal_results[bs]
		r = reactant_results[bs]
		push!(alloc_rows, "| $bs | $(round(c.median_alloc_kb; digits=1)) | $(round(m.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### Allocations (median per call)", "", alloc_header, alloc_sep, alloc_rows...], "\n"))
end

# ╔═╡ a1b2c3d4-0022-0001-0001-000000000001
begin
	compile_header = "| Batch Size | First-call compile (ms) |"
	compile_sep    = "|-----------|----------------------:|"
	compile_rows = [
		"| $bs | $(round(reactant_results[bs].compile_ms; digits=1)) |"
		for bs in BATCH_SIZES
	]

	Markdown.parse(join(["### Reactant Compile Time (one-time cost)", "", compile_header, compile_sep, compile_rows...], "\n"))
end

# ╔═╡ b2c3d4e5-0001-0001-0001-000000000001
md"""
## Benchmark 4: Fusion Stress Test

This tests where Reactant+Metal should **dominate**: a long chain of element-wise ops on large tensors. Metal.jl dispatches each broadcast as a separate GPU kernel. Reactant fuses them ALL into a single kernel — one launch instead of many.
"""

# ╔═╡ b2c3d4e5-0002-0001-0001-000000000001
begin
	"""10 element-wise ops that XLA fuses into 1-2 kernels (vs 10 kernel launches in Metal.jl)."""
	function elementwise_chain(x)
		x = x .* 2.0f0
		x = x .+ 1.0f0
		x = tanh.(x)
		x = x .* x            # square
		x = x .- 0.5f0
		x = exp.(x)
		x = x ./ (x .+ 1.0f0) # sigmoid-like
		x = abs.(x)
		x = x .* 3.0f0
		x = x .- x .* 0.1f0
		return x
	end

	const FUSION_SIZES = [2048*256, 2048*1024, 2048*4096]
	md"Defined `elementwise_chain` — 10 chained broadcasts."
end

# ╔═╡ b2c3d4e5-0003-0001-0001-000000000001
function bench_cpu_fusion(f, n; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, n)
	for _ in 1:n_warmup; f(x); end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(x)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ b2c3d4e5-0004-0001-0001-000000000001
function bench_metal_fusion(f, n; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x_gpu = MtlArray(randn(Float32, n))
	for _ in 1:n_warmup
		f(x_gpu)
		Metal.synchronize()
	end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		Metal.synchronize()
		stats = @timed begin
			f(x_gpu)
			Metal.synchronize()
		end
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ b2c3d4e5-0005-0001-0001-000000000001
function bench_reactant_fusion(f, n; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, n)
	x_ra = Reactant.to_rarray(x)

	compile_stats = @timed begin
		compiled_f = @compile f(x_ra)
	end
	compile_ms = compile_stats.time * 1000

	for _ in 1:n_warmup
		compiled_f(x_ra)
	end

	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed compiled_f(x_ra)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ b2c3d4e5-0006-0001-0001-000000000001
begin
	fusion_cpu = Dict(n => bench_cpu_fusion(elementwise_chain, n) for n in FUSION_SIZES)
	fusion_metal = Dict(n => bench_metal_fusion(elementwise_chain, n) for n in FUSION_SIZES)
	fusion_reactant = Dict(n => bench_reactant_fusion(elementwise_chain, n) for n in FUSION_SIZES)
	md"Fusion benchmarks complete."
end

# ╔═╡ b2c3d4e5-0007-0001-0001-000000000001
begin
	fh = "| Elements | CPU (ms) | Metal.jl (ms) | Reactant+Metal (ms) | Speedup vs CPU | Speedup vs Metal |"
	fs = "|---------|---------|--------------|--------------------|--------------:|----------------:|"
	frows = String[]
	for n in FUSION_SIZES
		c = fusion_cpu[n]
		m = fusion_metal[n]
		r = fusion_reactant[n]
		sp_cpu = round(c.median_ms / r.median_ms; digits=1)
		sp_metal = round(m.median_ms / r.median_ms; digits=1)
		push!(frows, "| $(n) | $(round(c.median_ms; digits=3)) | $(round(m.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(sp_cpu)x | $(sp_metal)x |")
	end

	Markdown.parse(join(["### Fusion Stress Test Results", "", "10 chained element-wise ops on large 1D tensors. Metal.jl = 10 kernel launches. Reactant = 1-2 fused kernels.", "", fh, fs, frows...], "\n"))
end

# ╔═╡ b2c3d4e5-0008-0001-0001-000000000001
begin
	fa_header = "| Elements | CPU (KB) | Metal.jl (KB) | Reactant+Metal (KB) |"
	fa_sep    = "|---------|---------|--------------|--------------------:|"
	fa_rows = String[]
	for n in FUSION_SIZES
		c = fusion_cpu[n]
		m = fusion_metal[n]
		r = fusion_reactant[n]
		push!(fa_rows, "| $(n) | $(round(c.median_alloc_kb; digits=1)) | $(round(m.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### Fusion Allocations (median per call)", "", fa_header, fa_sep, fa_rows...], "\n"))
end

# ╔═╡ c3d4e5f6-0001-0001-0001-000000000001
md"""
## Benchmark 5: 2D UNet-like Model (Conv + Pool + Skip Connections)

A model with **convolution**, **max pooling**, **residual add**, and **concatenation skip connections** — the core building blocks of a UNet.

This exercises three new op handlers:
- `stablehlo.convolution` → `MPSGraph convolution2D`
- `stablehlo.reduce_window` → `MPSGraph maxPooling2D`
- `stablehlo.concatenate` → `MPSGraph concatTensors`

**No Metal.jl column:** Metal.jl has no native GPU kernels for conv/pool (NNlib falls back to CPU scalar indexing). This is exactly why Reactant+Metal exists — it compiles these ops directly to MPSGraph.
"""

# ╔═╡ c3d4e5f6-0002-0001-0001-000000000001
begin
	const IMG_SIZE = (64, 64)
	const IMG_CH = 1
	const UNET_BATCHES = [1, 4, 16]

	unet_model = Chain(
		# Initial projection
		Conv((3, 3), 1 => 16, relu; pad=1),                  # 64×64×16

		# Residual block (addition skip connection)
		SkipConnection(
			Chain(
				Conv((3, 3), 16 => 16, relu; pad=1),
				Conv((3, 3), 16 => 16, relu; pad=1),
			),
			+
		),                                                      # 64×64×16

		# Downsample
		MaxPool((2, 2)),                                        # 32×32×16

		# Concatenation skip block (channel-cat skip connection)
		SkipConnection(
			Chain(
				Conv((3, 3), 16 => 16, relu; pad=1),
				Conv((3, 3), 16 => 16, relu; pad=1),
			),
			(mx, x) -> cat(mx, x; dims=3)                      # concat along C dim
		),                                                      # 32×32×32

		# More convolution
		Conv((3, 3), 32 => 32, relu; pad=1),                   # 32×32×32

		# Downsample again
		MaxPool((2, 2)),                                        # 16×16×32

		# Bottleneck
		Conv((3, 3), 32 => 64, relu; pad=1),                   # 16×16×64

		# Output head
		Conv((1, 1), 64 => 1),                                  # 16×16×1
	)

	unet_ps_cpu, unet_st_cpu = Lux.setup(rng, unet_model)
	md"2D UNet model: 8 Conv layers, 2 MaxPool, 1 residual add, 1 concat skip."
end

# ╔═╡ c3d4e5f6-0003-0001-0001-000000000001
function bench_unet_cpu(model, ps, st, batch_size; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, IMG_SIZE..., IMG_CH, batch_size)
	f(m, x, ps, st) = first(m(x, ps, st))
	for _ in 1:n_warmup; f(model, x, ps, st); end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(model, x, ps, st)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ c3d4e5f6-0004-0001-0001-000000000001
unet_cpu_results = Dict(
	bs => bench_unet_cpu(unet_model, unet_ps_cpu, unet_st_cpu, bs)
	for bs in UNET_BATCHES
)

# ╔═╡ c3d4e5f6-0005-0001-0001-000000000001
function bench_unet_reactant(model, ps, st, batch_size;
		n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, IMG_SIZE..., IMG_CH, batch_size)
	x_ra = Reactant.to_rarray(x)
	ps_ra = Reactant.to_rarray(ps)
	st_ra = Reactant.to_rarray(st)
	f(m, x, ps, st) = first(m(x, ps, st))

	compile_stats = @timed begin
		compiled_f = @compile f(model, x_ra, ps_ra, st_ra)
	end
	compile_ms = compile_stats.time * 1000

	for _ in 1:n_warmup
		compiled_f(model, x_ra, ps_ra, st_ra)
	end

	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed compiled_f(model, x_ra, ps_ra, st_ra)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ c3d4e5f6-0006-0001-0001-000000000001
unet_reactant_results = Dict(
	bs => bench_unet_reactant(unet_model, unet_ps_cpu, unet_st_cpu, bs)
	for bs in UNET_BATCHES
)

# ╔═╡ c3d4e5f6-0007-0001-0001-000000000001
begin
	x_unet_test = randn(Float32, IMG_SIZE..., IMG_CH, 2)
	f_unet(m, x, ps, st) = first(m(x, ps, st))

	# CPU reference
	y_unet_cpu = f_unet(unet_model, x_unet_test, unet_ps_cpu, unet_st_cpu)

	# Reactant+Metal
	x_unet_ra = Reactant.to_rarray(x_unet_test)
	ps_unet_ra = Reactant.to_rarray(unet_ps_cpu)
	st_unet_ra = Reactant.to_rarray(unet_st_cpu)
	compiled_unet = @compile f_unet(unet_model, x_unet_ra, ps_unet_ra, st_unet_ra)
	y_unet_reactant = Array(compiled_unet(unet_model, x_unet_ra, ps_unet_ra, st_unet_ra))

	unet_err_reactant = maximum(abs.(y_unet_cpu .- y_unet_reactant))

	md"""
	### 2D UNet Correctness Check

	| Backend | Max Error vs CPU |
	|---------|-----------------|
	| Reactant+Metal | $(round(unet_err_reactant; sigdigits=3)) |

	Should be < 1e-4 (conv accumulation allows slightly more error than dense).
	"""
end

# ╔═╡ c3d4e5f6-0008-0001-0001-000000000001
begin
	uh = "| Batch | CPU (ms) | Reactant+Metal (ms) | Speedup | Compile (ms) |"
	us = "|------|---------|--------------------|---------:|------------:|"
	urows = String[]
	for bs in UNET_BATCHES
		c = unet_cpu_results[bs]
		r = unet_reactant_results[bs]
		sp = round(c.median_ms / r.median_ms; digits=1)
		push!(urows, "| $bs | $(round(c.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(sp)x | $(round(r.compile_ms; digits=1)) |")
	end

	Markdown.parse(join(["### 2D UNet Results (CPU vs Reactant+Metal)", "",
		"8 Conv layers, 2 MaxPool, residual add + concat skip. Metal.jl cannot run this on GPU.", "",
		uh, us, urows...], "\n"))
end

# ╔═╡ c3d4e5f6-0009-0001-0001-000000000001
begin
	ua_header = "| Batch | CPU alloc (KB) | Reactant+Metal alloc (KB) |"
	ua_sep    = "|------|---------------:|-------------------------:|"
	ua_rows = String[]
	for bs in UNET_BATCHES
		c = unet_cpu_results[bs]
		r = unet_reactant_results[bs]
		push!(ua_rows, "| $bs | $(round(c.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### 2D UNet Allocations", "", ua_header, ua_sep, ua_rows...], "\n"))
end

# ╔═╡ d4e5f6a7-0001-0001-0001-000000000001
md"""
## Benchmark 6: 3D UNet-like Model (3D Conv + Skip Connections)

A volumetric model with **3D convolution**, **residual add**, and **concatenation skip connections** — the building blocks of a 3D UNet for medical imaging (CT/MRI segmentation).

This exercises the 3D convolution handler:
- `stablehlo.convolution` (5D tensors) → `MPSGraph convolution3D`
- `stablehlo.concatenate` → `MPSGraph concatTensors`

**No pooling:** MPSGraph has no 5D pooling. Downsampling uses stride-2 convolutions instead (common in modern architectures like V-Net).

**No Metal.jl column:** Same as 2D — Metal.jl has no native GPU kernels for 3D conv.
"""
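Because downsampling here is done with stride-2 convolutions rather than pooling, the spatial sizes in the model below follow directly from the conv output-size formula. A quick sanity check of those shapes (plain Julia, no GPU required; `conv_out` is a helper defined here for illustration, not part of the notebook's API):

```julia
# Spatial output size of a convolution along one dimension:
# floor((n + 2*pad - k) / stride) + 1
conv_out(n; k, stride=1, pad=0) = fld(n + 2pad - k, stride) + 1

conv_out(32; k=3, pad=1)       # 32  (same-size 3×3×3 conv)
conv_out(32; k=2, stride=2)    # 16  (first stride-2 downsample)
conv_out(16; k=2, stride=2)    # 8   (second stride-2 downsample)
```

This matches the 32³ → 16³ → 8³ progression annotated in the model definition below.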

# ╔═╡ d4e5f6a7-0002-0001-0001-000000000001
begin
	const VOL_SIZE = (32, 32, 32)
	const VOL_CH = 1
	const UNET3D_BATCHES = [1, 2]

	unet3d_model = Chain(
		# Initial projection: 32³×1 → 32³×8
		Conv((3, 3, 3), 1 => 8, relu; pad=1),

		# Residual block (addition skip)
		SkipConnection(
			Chain(
				Conv((3, 3, 3), 8 => 8, relu; pad=1),
				Conv((3, 3, 3), 8 => 8, relu; pad=1),
			),
			+
		),                                              # 32³×8

		# Downsample via stride-2 conv: 32³×8 → 16³×16
		Conv((2, 2, 2), 8 => 16, relu; stride=2),

		# Concatenation skip block
		SkipConnection(
			Chain(
				Conv((3, 3, 3), 16 => 16, relu; pad=1),
				Conv((3, 3, 3), 16 => 16, relu; pad=1),
			),
			(mx, x) -> cat(mx, x; dims=4)              # concat along C dim (dim 4 for 5D)
		),                                              # 16³×32

		# More convolution
		Conv((3, 3, 3), 32 => 32, relu; pad=1),        # 16³×32

		# Downsample again: 16³×32 → 8³×64
		Conv((2, 2, 2), 32 => 64, relu; stride=2),

		# Bottleneck
		Conv((3, 3, 3), 64 => 64, relu; pad=1),        # 8³×64

		# Output head
		Conv((1, 1, 1), 64 => 1),                       # 8³×1
	)

	unet3d_ps_cpu, unet3d_st_cpu = Lux.setup(rng, unet3d_model)
	md"3D UNet model: 8 Conv3D layers, 2 stride-2 downsample, 1 residual add, 1 concat skip."
end

# ╔═╡ d4e5f6a7-0003-0001-0001-000000000001
function bench_unet3d_cpu(model, ps, st, batch_size; n_warmup=3, n_trials=10)
	x = randn(Float32, VOL_SIZE..., VOL_CH, batch_size)
	f(m, x, ps, st) = first(m(x, ps, st))
	for _ in 1:n_warmup; f(model, x, ps, st); end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(model, x, ps, st)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ d4e5f6a7-0004-0001-0001-000000000001
unet3d_cpu_results = Dict(
	bs => bench_unet3d_cpu(unet3d_model, unet3d_ps_cpu, unet3d_st_cpu, bs)
	for bs in UNET3D_BATCHES
)

# ╔═╡ d4e5f6a7-0005-0001-0001-000000000001
function bench_unet3d_reactant(model, ps, st, batch_size;
		n_warmup=3, n_trials=10)
	x = randn(Float32, VOL_SIZE..., VOL_CH, batch_size)
	x_ra = Reactant.to_rarray(x)
	ps_ra = Reactant.to_rarray(ps)
	st_ra = Reactant.to_rarray(st)
	f(m, x, ps, st) = first(m(x, ps, st))

	compile_stats = @timed begin
		compiled_f = @compile f(model, x_ra, ps_ra, st_ra)
	end
	compile_ms = compile_stats.time * 1000

	for _ in 1:n_warmup
		compiled_f(model, x_ra, ps_ra, st_ra)
	end

	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed compiled_f(model, x_ra, ps_ra, st_ra)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ d4e5f6a7-0006-0001-0001-000000000001
unet3d_reactant_results = Dict(
	bs => bench_unet3d_reactant(unet3d_model, unet3d_ps_cpu, unet3d_st_cpu, bs)
	for bs in UNET3D_BATCHES
)

# ╔═╡ d4e5f6a7-0007-0001-0001-000000000001
begin
	x_unet3d_test = randn(Float32, VOL_SIZE..., VOL_CH, 1)
	f_unet3d(m, x, ps, st) = first(m(x, ps, st))

	# CPU reference
	y_unet3d_cpu = f_unet3d(unet3d_model, x_unet3d_test, unet3d_ps_cpu, unet3d_st_cpu)

	# Reactant+Metal
	x_unet3d_ra = Reactant.to_rarray(x_unet3d_test)
	ps_unet3d_ra = Reactant.to_rarray(unet3d_ps_cpu)
	st_unet3d_ra = Reactant.to_rarray(unet3d_st_cpu)
	compiled_unet3d = @compile f_unet3d(unet3d_model, x_unet3d_ra, ps_unet3d_ra, st_unet3d_ra)
	y_unet3d_reactant = Array(compiled_unet3d(unet3d_model, x_unet3d_ra, ps_unet3d_ra, st_unet3d_ra))

	unet3d_err_reactant = maximum(abs.(y_unet3d_cpu .- y_unet3d_reactant))

	md"""
	### 3D UNet Correctness Check

	| Backend | Max Error vs CPU |
	|---------|-----------------|
	| Reactant+Metal | $(round(unet3d_err_reactant; sigdigits=3)) |

	Should be < 1e-4 (3D conv accumulation allows slightly more error).
	"""
end

# ╔═╡ d4e5f6a7-0008-0001-0001-000000000001
begin
	u3h = "| Batch | CPU (ms) | Reactant+Metal (ms) | Speedup | Compile (ms) |"
	u3s = "|------|---------|--------------------|---------:|------------:|"
	u3rows = String[]
	for bs in UNET3D_BATCHES
		c = unet3d_cpu_results[bs]
		r = unet3d_reactant_results[bs]
		sp = round(c.median_ms / r.median_ms; digits=1)
		push!(u3rows, "| $bs | $(round(c.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(sp)x | $(round(r.compile_ms; digits=1)) |")
	end

	Markdown.parse(join(["### 3D UNet Results (CPU vs Reactant+Metal)", "",
		"8 Conv3D layers, stride-2 downsample, residual add + concat skip. 32³ voxel input.", "",
		u3h, u3s, u3rows...], "\n"))
end

# ╔═╡ d4e5f6a7-0009-0001-0001-000000000001
begin
	u3a_header = "| Batch | CPU alloc (KB) | Reactant+Metal alloc (KB) |"
	u3a_sep    = "|------|---------------:|-------------------------:|"
	u3a_rows = String[]
	for bs in UNET3D_BATCHES
		c = unet3d_cpu_results[bs]
		r = unet3d_reactant_results[bs]
		push!(u3a_rows, "| $bs | $(round(c.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### 3D UNet Allocations", "", u3a_header, u3a_sep, u3a_rows...], "\n"))
end

# ╔═╡ a1b2c3d4-0023-0001-0001-000000000001
md"""
## What's Happening Under the Hood

Julia code
→ Reactant tracing → StableHLO MLIR
→ XLA/MLIR optimization (op fusion, CSE, constant folding)
→ Optimized MLIR → PJRT plugin → MLIR walker
→ MPSGraph builder → Metal GPU execution

**Why Reactant+Metal can beat Metal.jl alone:**
- XLA fuses multiple operations into single GPU kernels (fewer launches)
- Constant folding eliminates redundant computation at compile time
- CSE (common subexpression elimination) removes duplicate work
- The MPSGraph layer adds Apple's own Metal-specific optimizations on top

**Why GPU beats CPU (at sufficient scale):**
- GPU parallelism for large matrix operations
- GPU memory bandwidth advantage for data-heavy workloads

**Important caveats:**
- Small models (< ~100K params) are faster on CPU — GPU dispatch overhead dominates
- Reactant+Metal has per-call PJRT overhead (~ms) that only pays off with larger compute
- First `@jit` call includes compilation (seconds); subsequent calls reuse cached executable
- This is a prototype PJRT plugin — production performance would be better
"""
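# ╔═╡ a1b2c3d4-0024-0001-0001-000000000001
md"""
The "first call compiles, subsequent calls reuse" behavior in the caveats above can be sketched as a tiny executable cache. This is a plain-Julia illustration only — `jit_sketch` and `slow_compile` are hypothetical stand-ins for `@jit` and the real PJRT compile callback, not Reactant API:
"""

# ╔═╡ a1b2c3d4-0025-0001-0001-000000000001
begin
	# Minimal sketch of compile-once/call-many amortization.
	const EXEC_CACHE = Dict{Any,Function}()

	# Stand-in for the expensive trace + XLA + MPSGraph compile step.
	slow_compile(f, args...) = (a...) -> f(a...)

	function jit_sketch(f, args...)
		key = (f, typeof.(args))
		exec = get!(EXEC_CACHE, key) do
			# Only the first call for this (function, arg-types) pair pays this.
			slow_compile(f, args...)
		end
		exec(args...)
	end

	jit_sketch(+, 1, 2)   # compiles, then runs → 3
	jit_sketch(+, 3, 4)   # same signature: cache hit → 7
end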

# ╔═╡ Cell order:
# ╟─a1b2c3d4-0001-0001-0001-000000000001
# ╟─a1b2c3d4-0002-0001-0001-000000000001
# ╠═a1b2c3d4-0003-0001-0001-000000000001
# ╠═a1b2c3d4-0004-0001-0001-000000000001
# ╟─a1b2c3d4-0005-0001-0001-000000000001
# ╠═a1b2c3d4-0006-0001-0001-000000000001
# ╟─a1b2c3d4-0007-0001-0001-000000000001
# ╠═a1b2c3d4-0008-0001-0001-000000000001
# ╠═a1b2c3d4-0009-0001-0001-000000000001
# ╟─a1b2c3d4-0010-0001-0001-000000000001
# ╠═a1b2c3d4-0011-0001-0001-000000000001
# ╠═a1b2c3d4-0012-0001-0001-000000000001
# ╠═a1b2c3d4-0013-0001-0001-000000000001
# ╟─a1b2c3d4-0014-0001-0001-000000000001
# ╠═a1b2c3d4-0015-0001-0001-000000000001
# ╠═a1b2c3d4-0016-0001-0001-000000000001
# ╟─a1b2c3d4-0017-0001-0001-000000000001
# ╠═a1b2c3d4-0018-0001-0001-000000000001
# ╟─a1b2c3d4-0019-0001-0001-000000000001
# ╠═a1b2c3d4-0020-0001-0001-000000000001
# ╠═a1b2c3d4-0021-0001-0001-000000000001
# ╠═a1b2c3d4-0022-0001-0001-000000000001
# ╟─b2c3d4e5-0001-0001-0001-000000000001
# ╠═b2c3d4e5-0002-0001-0001-000000000001
# ╠═b2c3d4e5-0003-0001-0001-000000000001
# ╠═b2c3d4e5-0004-0001-0001-000000000001
# ╠═b2c3d4e5-0005-0001-0001-000000000001
# ╠═b2c3d4e5-0006-0001-0001-000000000001
# ╟─b2c3d4e5-0007-0001-0001-000000000001
# ╠═b2c3d4e5-0008-0001-0001-000000000001
# ╟─c3d4e5f6-0001-0001-0001-000000000001
# ╠═c3d4e5f6-0002-0001-0001-000000000001
# ╠═c3d4e5f6-0003-0001-0001-000000000001
# ╠═c3d4e5f6-0004-0001-0001-000000000001
# ╠═c3d4e5f6-0005-0001-0001-000000000001
# ╠═c3d4e5f6-0006-0001-0001-000000000001
# ╠═c3d4e5f6-0007-0001-0001-000000000001
# ╟─c3d4e5f6-0008-0001-0001-000000000001
# ╠═c3d4e5f6-0009-0001-0001-000000000001
# ╟─d4e5f6a7-0001-0001-0001-000000000001
# ╠═d4e5f6a7-0002-0001-0001-000000000001
# ╠═d4e5f6a7-0003-0001-0001-000000000001
# ╠═d4e5f6a7-0004-0001-0001-000000000001
# ╠═d4e5f6a7-0005-0001-0001-000000000001
# ╠═d4e5f6a7-0006-0001-0001-000000000001
# ╠═d4e5f6a7-0007-0001-0001-000000000001
# ╟─d4e5f6a7-0008-0001-0001-000000000001
# ╠═d4e5f6a7-0009-0001-0001-000000000001
# ╟─a1b2c3d4-0023-0001-0001-000000000001

Per review feedback: the PjRtCApiClient finalizer-vs-main-thread race
condition should be fixed generically in core Reactant, not per-backend.

Removed METAL_XLA_LOCK, Base.convert override, free_buffer override,
to_host override, and __precompile__(false).

See EnzymeAD#2493 for the generic fix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>