diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c2aa6ab6..8d933f0f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,10 @@ jobs: - Core version: - '1' + os: + - ubuntu-latest + - macos-latest + - windows-latest steps: - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@v1 @@ -37,4 +41,4 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v5 with: - file: lcov.info \ No newline at end of file + file: lcov.info diff --git a/Project.toml b/Project.toml index c4a079e5b..af8d379c1 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -27,6 +28,7 @@ ArrayInterfaceCUDSSExt = "CUDSS" ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" ArrayInterfaceChainRulesExt = "ChainRules" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" +ArrayInterfaceMetalExt = "Metal" ArrayInterfaceReverseDiffExt = "ReverseDiff" ArrayInterfaceSparseArraysExt = "SparseArrays" ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" @@ -43,6 +45,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1" GPUArraysCore = "0.1, 0.2" LinearAlgebra = "1.10" +Metal = "1" ReverseDiff = "1" SparseArrays = "1.10" StaticArraysCore = "1" @@ -59,6 +62,8 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/ext/ArrayInterfaceMetalExt.jl b/ext/ArrayInterfaceMetalExt.jl new file mode 100644 index 000000000..c27f67f24 --- /dev/null +++ b/ext/ArrayInterfaceMetalExt.jl @@ -0,0 +1,15 @@ +module ArrayInterfaceMetalExt + +using ArrayInterface +using Metal +using LinearAlgebra + +function ArrayInterface.lu_instance(A::MtlMatrix{T}) where {T} + ipiv = MtlVector{Int32}(undef, 0) + info = zero(Int) + return LinearAlgebra.LU(similar(A, 0, 0), ipiv, info) +end + +ArrayInterface.device(::Type{<:Metal.MtlArray}) = ArrayInterface.GPU() + +end # module \ No newline at end of file diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index 0853fd7bb..353250f1b 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -1,2 +1,3 @@ [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" \ No newline at end of file +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" \ No newline at end of file diff --git a/test/gpu/metal.jl b/test/gpu/metal.jl new file mode 100644 index 000000000..c819a05ee --- /dev/null +++ b/test/gpu/metal.jl @@ -0,0 +1,8 @@ +using Metal +using ArrayInterface +using LinearAlgebra + +using Test + +# Test that lu_instance works with Metal.jl gpu arrays +@test isa(ArrayInterface.lu_instance(MtlArray([1.f0 1.f0; 1.f0 1.f0])), LU) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 8a5d7b363..cd47c19fb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,5 +21,8 @@ end if GROUP == "GPU" activate_gpu_env() @time @safetestset "CUDA" begin include("gpu/cuda.jl") end + if Sys.isapple() + @time @safetestset "Metal" begin include("gpu/metal.jl") end + end end end