diff --git a/Project.toml b/Project.toml
index 05c9d134da..7214266826 100644
--- a/Project.toml
+++ b/Project.toml
@@ -48,6 +48,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
+OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -73,6 +74,7 @@ ReactantKernelAbstractionsExt = "KernelAbstractions"
 ReactantMPIExt = "MPI"
 ReactantNNlibExt = ["NNlib", "Statistics"]
 ReactantNPZExt = "NPZ"
+ReactantOMEinsumExt = "OMEinsum"
 ReactantOffsetArraysExt = "OffsetArrays"
 ReactantOneHotArraysExt = "OneHotArrays"
 ReactantPythonCallExt = "PythonCall"
@@ -111,6 +113,7 @@ LinearAlgebra = "1.10"
 MPI = "0.20"
 NNlib = "0.9.26"
 NPZ = "0.4"
+OMEinsum = "0.9"
 OffsetArrays = "1"
 OneHotArrays = "0.2.10"
 OrderedCollections = "1.1"
diff --git a/ext/ReactantOMEinsumExt.jl b/ext/ReactantOMEinsumExt.jl
new file mode 100644
index 0000000000..bf0ef8a5d7
--- /dev/null
+++ b/ext/ReactantOMEinsumExt.jl
@@ -0,0 +1,98 @@
+module ReactantOMEinsumExt
+
+using Reactant
+using Reactant: @reactant_overlay, looped_any, use_overlayed_version, @opcall
+using Reactant.TracedUtils: get_mlir_data, set_mlir_data!
+using OMEinsum
+using OMEinsum: _analyze_binary_input
+
+@reactant_overlay @noinline function OMEinsum.get_output_array(xs, size, fillzero)
+    # we ignore fillzero here, as it's easier for us to zero-initialize arrays
+    if looped_any(use_overlayed_version, xs)
+        T = promote_type(map(eltype, xs)...)
+        return @opcall fill(zero(T), size)
+    else
+        return Reactant.call_with_native(OMEinsum.get_output_array, xs, size, fillzero)
+    end
+end
+
+@reactant_overlay @noinline function OMEinsum.unary_einsum!(
+    ::OMEinsum.Diag, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy
+)
+    if use_overlayed_version(x)
+        @assert use_overlayed_version(y)
+        # TODO we would probably prefer a more efficient implementation here, e.g. a reduction or a specialized op
+        @debug "Diag" ix => iy size.(x)
+        return @allowscalar OMEinsum.compactify!(y, x, ix, iy, sx, sy)
+    else
+        return Reactant.call_with_native(
+            OMEinsum.unary_einsum!, OMEinsum.Diag(), ix, iy, x, y, sx, sy
+        )
+    end
+end
+
+@reactant_overlay @noinline function OMEinsum.tensorpermute!(
+    C::AbstractArray{T,N}, A::AbstractArray{T,N}, perm, sx, sy
+) where {T,N}
+    if use_overlayed_version(A)
+        @assert use_overlayed_version(C)
+        permv = collect(perm)
+        sx´ = Reactant.promote_to(T, sx)
+        sy´ = Reactant.promote_to(T, sy)
+        res = sy´ * C + sx´ * @opcall transpose(A, permv)
+        set_mlir_data!(C, get_mlir_data(res))
+        return C
+    else
+        return Reactant.call_with_native(OMEinsum.tensorpermute!, C, A, perm, sx, sy)
+    end
+end
+
+@reactant_overlay @noinline function OMEinsum.einsum!(
+    ixs, iy, @nospecialize(xs::NTuple{2,Any}), @nospecialize(y), sx, sy, size_dict
+)
+    if looped_any(use_overlayed_version, xs)
+        @assert use_overlayed_version(y)
+
+        # shortcut for scalar multiplication
+        if looped_any(x -> x isa Number, xs)
+            c = sy * y + sx * xs[1] * xs[2]
+            set_mlir_data!(y, get_mlir_data(c))
+            return y
+        end
+
+        LT = keytype(size_dict)
+        a, b = xs
+        ia, ib = collect.(LT, ixs)
+        iyv = collect(LT, iy)
+        inner, a_outer, b_outer, batch = _analyze_binary_input(ia, ib, iyv)
+
+        contracting_dimensions = (
+            Int[findfirst(==(i), ia) for i in inner],
+            Int[findfirst(==(i), ib) for i in inner],
+        )
+        batching_dimensions = (
+            Int[findfirst(==(i), ia) for i in batch],
+            Int[findfirst(==(i), ib) for i in batch],
+        )
+
+        c = @opcall dot_general(a, b; contracting_dimensions, batching_dimensions)
+
+        # permute dims to match iy
+        ic = vcat(batch, a_outer, b_outer)
+        perm = Int[findfirst(==(i), ic) for i in iyv]
+        c = @opcall transpose(c, perm)
+        @assert size(c) == size(y)
+        @assert eltype(c) == eltype(y)
+
+        # just like GEMM, we do: y = sy * y + sx * c
+        c = sy * y + sx * c
+        set_mlir_data!(y, get_mlir_data(c))
+        return y
+    else
+        return Reactant.call_with_native(
+            OMEinsum.einsum!, ixs, iy, xs, y, sx, sy, size_dict
+        )
+    end
+end
+
+end
diff --git a/src/Indexing.jl b/src/Indexing.jl
index 1af2a85420..3680d94374 100644
--- a/src/Indexing.jl
+++ b/src/Indexing.jl
@@ -640,6 +640,28 @@ function _setindex_scalar_cartesian!(
     return a
 end
 
+function _setindex_scalar_cartesian!(
+    a::TracedRArray{T,N}, v::TracedRNumber, index::CartesianIndex{N}
+) where {T,N}
+    assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})")
+    res = @opcall(
+        reshape(
+            @opcall(
+                dynamic_update_slice(
+                    a,
+                    Reactant.broadcast_to_size(
+                        Reactant.promote_to(TracedRNumber{T}, v), ntuple(Returns(1), N)
+                    ),
+                    collect(Int64, index.I),
+                )
+            ),
+            collect(size(a)),
+        )
+    )
+    TracedUtils.set_mlir_data!(a, TracedUtils.get_mlir_data(res))
+    return a
+end
+
 function _setindex_linear!(a::TracedRArray{T,N}, v, indices::AbstractArray) where {T,N}
     if !(indices isa Reactant.TracedType) && TracedUtils.__contiguous_indices(vec(indices))
         res = @opcall(
diff --git a/test/Project.toml b/test/Project.toml
index dc0d5c9b46..ccd5f296e0 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -23,6 +23,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
+OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
diff --git a/test/integration/omeinsum.jl b/test/integration/omeinsum.jl
new file mode 100644
index 0000000000..9898e2b8d5
--- /dev/null
+++ b/test/integration/omeinsum.jl
@@ -0,0 +1,216 @@
+using Test
+using Reactant
+using OMEinsum
+
+@testset "sum" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+    a_re = Reactant.to_rarray(a)
+
+    f = ein"ij->"
+    c = f(a)
+    c_re = @jit f(a_re)
+    @test c ≈ c_re
+
+    f = ein"ij->i"
+    c = f(a)
+    c_re = @jit f(a_re)
+    @test c ≈ c_re
+
+    f = ein"ij->j"
+    c = f(a)
+    c_re = @jit f(a_re)
+    @test c ≈ c_re
+end
+
+@testset "diagonal" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 3)
+    a_re = Reactant.to_rarray(a)
+
+    f = ein"ii->i"
+    c = f(a)
+    c_re = @jit f(a_re)
+    @test c ≈ c_re
+
+    # NOTE currently broken
+    # hyper-diagonal
+    # a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 2, 2)
+    # a_re = Reactant.to_rarray(a)
+    # f = ein"iii->i"
+    # c = f(a)
+    # c_re = @jit f(a_re)
+    # @test c ≈ c_re
+end
+
+# NOTE currently broken
+# @testset "trace" begin
+#     a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 2)
+#     a_re = Reactant.to_rarray(a)
+
+#     f = ein"ii->"
+#     c = f(a)
+#     c_re = @jit f(a_re)
+#     @test c ≈ c_re
+
+#     # partial trace
+#     a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 2, 2, 2)
+#     a_re = Reactant.to_rarray(a)
+#     f = ein"iijk->jk"
+#     c = f(a)
+#     c_re = @jit f(a_re)
+#     @test c ≈ c_re
+# end
+
+@testset "transpose" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+    a_re = Reactant.to_rarray(a)
+
+    f = ein"ij->ji"
+    c = f(a)
+    c_re = @jit f(a_re)
+    @test c ≈ c_re
+
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3, 4, 5)
+    a_re = Reactant.to_rarray(a)
+
+    f = ein"ijkl->jilk"
+    c = f(a)
+    c_re = @jit f(a_re)
+    @test c ≈ c_re
+end
+
+@testset "matmul" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+    b = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4)
+    a_re = Reactant.to_rarray(a)
+    b_re = Reactant.to_rarray(b)
+
+    f = ein"ij,jk->ik"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+
+    @testset "with different eltype" begin
+        b = Reactant.TestUtils.construct_test_array(Float32, 3, 4)
+        b_re = Reactant.to_rarray(b)
+
+        f = ein"ij,jk->ik"
+        c = f(a, b)
+        c_re = @jit f(a_re, b_re)
+        @test c ≈ c_re
+    end
+end
+
+@testset "hadamard product" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+    b = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+    a_re = Reactant.to_rarray(a)
+    b_re = Reactant.to_rarray(b)
+
+    f = ein"ij,ij->ij"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+
+    b = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 2)
+    b_re = Reactant.to_rarray(b)
+    f = ein"ij,ji->ij"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+end
+
+@testset "inner product" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 3)
+    b = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4)
+    a_re = Reactant.to_rarray(a)
+    b_re = Reactant.to_rarray(b)
+
+    f = ein"ij,ji->"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+end
+
+@testset "outer product" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+    b = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 5)
+    a_re = Reactant.to_rarray(a)
+    b_re = Reactant.to_rarray(b)
+
+    f = ein"ij,kl->ijkl"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+
+    f = ein"ij,kl->klij"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+
+    f = ein"ij,kl->ikjl"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+end
+
+@testset "scale" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+    b = fill(ComplexF32(2.0))
+    a_re = Reactant.to_rarray(a)
+    b_re = Reactant.to_rarray(b)
+
+    f = ein"ij,->ij"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+end
+
+@testset "batch matmul" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3, 6)
+    b = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4, 6)
+    a_re = Reactant.to_rarray(a)
+    b_re = Reactant.to_rarray(b)
+
+    f = ein"ijb,jkb->ikb"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+
+    f = ein"ijb,jkb->bik"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+end
+
+# NOTE currently broken
+# @testset "star contraction" begin
+#     a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
+#     b = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 3)
+#     c = Reactant.TestUtils.construct_test_array(ComplexF32, 5, 3)
+#     a_re = Reactant.to_rarray(a)
+#     b_re = Reactant.to_rarray(b)
+#     c_re = Reactant.to_rarray(c)

+#     f = ein"ai,bi,ci->abc"
+#     d = f(a, b, c)
+#     d_re = @jit f(a_re, b_re, c_re)
+#     @test d ≈ d_re
+# end
+
+@testset "tensor contraction" begin
+    a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 4, 3)
+    b = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 5, 4)
+    a_re = Reactant.to_rarray(a)
+    b_re = Reactant.to_rarray(b)
+
+    f = ein"ijk,klj->il"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+
+    # contraction of NOT all common indices
+    f = ein"ijk,klj->ikl"
+    c = f(a, b)
+    c_re = @jit f(a_re, b_re)
+    @test c ≈ c_re
+end
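
# --- Illustrative sketch, not part of the diff above -----------------------------------
# A minimal, self-contained walkthrough of how the overlaid OMEinsum.einsum! turns a
# binary contraction into dot_general dimension numbers plus a final transpose. The
# inner/batch/outer split below mirrors what the extension obtains from
# OMEinsum._analyze_binary_input; the concrete labels are a hypothetical example for
# ein"ijb,jkb->ikb" and everything here is plain base Julia.
ia, ib, iy = ['i', 'j', 'b'], ['j', 'k', 'b'], ['i', 'k', 'b']

inner = ['j']                    # shared by both inputs, absent from the output -> contracted
batch = ['b']                    # shared by both inputs, kept in the output     -> batched
a_outer, b_outer = ['i'], ['k']  # unique to one input                           -> outer

# dot_general dimension numbers, built the same way as in ReactantOMEinsumExt.einsum!
contracting_dimensions = (
    Int[findfirst(==(l), ia) for l in inner],  # [2] in a
    Int[findfirst(==(l), ib) for l in inner],  # [1] in b
)
batching_dimensions = (
    Int[findfirst(==(l), ia) for l in batch],  # [3] in a
    Int[findfirst(==(l), ib) for l in batch],  # [3] in b
)

# dot_general orders its result as (batch..., a_outer..., b_outer...), so the overlay
# computes a permutation against the requested output labels iy before scaling.
ic = vcat(batch, a_outer, b_outer)            # ['b', 'i', 'k']
perm = Int[findfirst(==(l), ic) for l in iy]  # [2, 3, 1]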