From 5058a72cdeefc3ba042db39143046ceba7eca1c8 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Mon, 1 Sep 2025 14:48:40 -0700 Subject: [PATCH 1/3] Add Metal.jl support for lu_instance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ArrayInterfaceMetalExt extension with proper lu_instance implementation - Fix issue where MtlArray lu_instance failed due to 0x0 matrix creation - Add Metal.jl to weakdeps and extensions in Project.toml - Add Metal.jl test to verify lu_instance works properly - Follow same pattern as CUDA extension for GPU array support Fixes #467 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 5 +++++ ext/ArrayInterfaceMetalExt.jl | 15 +++++++++++++++ test/gpu/Project.toml | 3 ++- test/gpu/metal.jl | 7 +++++++ test/runtests.jl | 1 + 5 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 ext/ArrayInterfaceMetalExt.jl create mode 100644 test/gpu/metal.jl diff --git a/Project.toml b/Project.toml index c4a079e5b..af8d379c1 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -27,6 +28,7 @@ ArrayInterfaceCUDSSExt = "CUDSS" ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" ArrayInterfaceChainRulesExt = "ChainRules" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" +ArrayInterfaceMetalExt = "Metal" ArrayInterfaceReverseDiffExt = "ReverseDiff" ArrayInterfaceSparseArraysExt = "SparseArrays" ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" @@ -43,6 +45,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1" GPUArraysCore = "0.1, 0.2" LinearAlgebra = "1.10" +Metal = "1" ReverseDiff = "1" SparseArrays = "1.10" StaticArraysCore = "1" @@ -59,6 +62,8 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/ext/ArrayInterfaceMetalExt.jl b/ext/ArrayInterfaceMetalExt.jl new file mode 100644 index 000000000..c27f67f24 --- /dev/null +++ b/ext/ArrayInterfaceMetalExt.jl @@ -0,0 +1,15 @@ +module ArrayInterfaceMetalExt + +using ArrayInterface +using Metal +using LinearAlgebra + +function ArrayInterface.lu_instance(A::MtlMatrix{T}) where {T} + ipiv = MtlVector{Int32}(undef, 0) + info = zero(Int) + return LinearAlgebra.LU(similar(A, 0, 0), ipiv, info) +end + +ArrayInterface.device(::Type{<:Metal.MtlArray}) = ArrayInterface.GPU() + +end # module \ No newline at end of file diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index 0853fd7bb..353250f1b 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -1,2 +1,3 @@ [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" \ No newline at end of file +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" \ No newline at end of file diff --git a/test/gpu/metal.jl b/test/gpu/metal.jl new file mode 100644 index 000000000..b1bd33694 --- /dev/null +++ b/test/gpu/metal.jl @@ -0,0 +1,7 @@ +using Metal +using ArrayInterface + +using Test + +# Test whether lu_instance throws an error when invoked with a Metal.jl gpu array +@test !isa(try ArrayInterface.lu_instance(MtlArray([1.f0 1.f0; 1.f0 1.f0])) catch ex ex end, Exception) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 8a5d7b363..ecd329fc1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,5 +21,6 @@ end if GROUP == "GPU" activate_gpu_env() @time @safetestset "CUDA" begin include("gpu/cuda.jl") end + @time @safetestset "Metal" begin include("gpu/metal.jl") end end end From 72bb8e1d42965987a208e0b3bbc0ea1c4f42887e Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Mon, 1 Sep 2025 18:14:29 -0700 Subject: [PATCH 2/3] Improve Metal.jl tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Simplify test to directly check that lu_instance returns an LU type - Make Metal tests only run on macOS using Sys.isapple() check - Add LinearAlgebra import for LU type check 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- test/gpu/metal.jl | 5 +++-- test/runtests.jl | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/gpu/metal.jl b/test/gpu/metal.jl index b1bd33694..c819a05ee 100644 --- a/test/gpu/metal.jl +++ b/test/gpu/metal.jl @@ -1,7 +1,8 @@ using Metal using ArrayInterface +using LinearAlgebra using Test -# Test whether lu_instance throws an error when invoked with a Metal.jl gpu array -@test !isa(try ArrayInterface.lu_instance(MtlArray([1.f0 1.f0; 1.f0 1.f0])) catch ex ex end, Exception) \ No newline at end of file +# Test that lu_instance works with Metal.jl gpu arrays +@test isa(ArrayInterface.lu_instance(MtlArray([1.f0 1.f0; 1.f0 1.f0])), LU) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index ecd329fc1..cd47c19fb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,6 +21,8 @@ end if GROUP == "GPU" activate_gpu_env() @time @safetestset "CUDA" begin include("gpu/cuda.jl") end - @time @safetestset "Metal" begin include("gpu/metal.jl") end + if Sys.isapple() + @time @safetestset "Metal" begin include("gpu/metal.jl") end + end end end From d7d7b7b93112a70950c7b4525362e5aec846f4bd Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 1 Sep 2025 21:17:27 -0400 Subject: [PATCH 3/3] Update ci.yml --- .github/workflows/ci.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c2aa6ab6..8d933f0f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,10 @@ jobs: - Core version: - '1' + os: + - ubuntu-latest + - macos-latest + - windows-latest steps: - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@v1 @@ -37,4 +41,4 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v5 with: - file: lcov.info \ No newline at end of file + file: lcov.info