-
Notifications
You must be signed in to change notification settings - Fork 54
Metal PJRT backend via MPSGraph + pure-Julia plugin #2489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Dale-Black
wants to merge
7
commits into
EnzymeAD:main
Choose a base branch
from
Dale-Black:metal-pjrt-backend
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+3,152
−77
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
a1d2349
Add MakeClientFromApi to C++ PJRT bridge
Dale-Black 4571fcb
Add Metal accelerator hooks and PJRT client registration
Dale-Black 1b24d8c
Add ReactantMetalExt package extension
Dale-Black f416d50
Wire Metal as weak dependency with test support
Dale-Black 78c1f20
Fix macOS local build: disable lld, add toolchain resolution
Dale-Black 7d2a87a
Replace string parsing in MLIRWalker with MLIR C API calls
Dale-Black e78236b
Remove per-extension thread-safety overrides from ReactantMetalExt
Dale-Black File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| module ReactantMetalExt | ||
|
|
||
| using Metal | ||
| using Metal.MPS | ||
| using Metal: MtlArray | ||
|
|
||
| # ObjectiveC primitives needed by @objc call sites in XLACompiler.jl | ||
| using Metal.MTL: @objc, id, nil, NSString, NSArray, NSDictionary, reinterpret | ||
|
|
||
| # Descriptor types needed by @objc [T alloc] calls (macro requires bare identifiers) | ||
| using Metal.MPSGraphs: MPSGraphConvolution2DOpDescriptor, | ||
| MPSGraphConvolution3DOpDescriptor, | ||
| MPSGraphPooling2DOpDescriptor, | ||
| MPSGraphPooling4DOpDescriptor | ||
|
|
||
| # Reactant's in-tree MLIR modules — no parameter injection needed | ||
| using Reactant: Reactant, MLIR | ||
| using Reactant.MLIR: IR, API | ||
|
|
||
| # Phase-1 PJRT plugin: 30 @cfunction callbacks + PJRT_Api struct + make_client() | ||
| include("ReactantMetalExt/PJRTPlugin.jl") | ||
|
|
||
| # @objc bindings for MPSGraph ops not wrapped by Metal.jl, | ||
| # plus julia_to_mps_dtype and mps_reshape helpers | ||
| include("ReactantMetalExt/XLACompiler.jl") | ||
|
|
||
| # MLIR walker: compile_mlir_module, MetalExecutable, execute! | ||
| include("ReactantMetalExt/MLIRWalker.jl") | ||
|
|
||
| export compile_mlir_module, MetalExecutable, execute! | ||
|
|
||
| function __init__() | ||
| @static if Sys.isapple() | ||
| if Metal.functional() | ||
| # Initialize @cfunction handles and register the PJRT_Api pointer | ||
| # so PJRT.MakeMetalClient() (no-args) can be called from XLA.jl. | ||
| try | ||
| init_pjrt_handles!() | ||
| # Expose the PJRT_Api struct pointer to Reactant's Client.jl | ||
| Reactant.XLA.PJRT._metal_pjrt_api_ptr[] = Ptr{Cvoid}(_PJRT_API_MEM) | ||
|
|
||
| # Create client via the shared PJRT.MetalClient() path (checkcount=false | ||
| # because initialize_default_clients! may not have run yet and the counter | ||
| # won't have been touched). | ||
| state = Reactant.XLA.global_backend_state | ||
| if haskey(state.clients, "metal") | ||
| # Already registered (e.g., XLA.jl's init block ran first). | ||
| state.default_client = state.clients["metal"] | ||
| else | ||
| metal = Reactant.XLA.PJRT.MetalClient(checkcount=false) | ||
| Reactant.XLA.PJRT.metal_client_count[] += 1 | ||
| state.clients["metal"] = metal | ||
| state.default_client = metal | ||
| end | ||
| catch e | ||
| if e isa ErrorException && contains(e.msg, "MakeClientFromApi") | ||
| @warn "Metal PJRT backend requires rebuilt libReactantExtra. Run: julia --project=deps deps/build_local.jl" | ||
| else | ||
| @warn "Metal backend initialization failed" exception = e | ||
| end | ||
| end | ||
| end | ||
| end | ||
| return nothing | ||
| end | ||
|
|
||
| end # module ReactantMetalExt |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.