Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Expand All @@ -73,6 +74,7 @@ ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantMPIExt = "MPI"
ReactantNNlibExt = ["NNlib", "Statistics"]
ReactantNPZExt = "NPZ"
ReactantOMEinsumExt = "OMEinsum"
ReactantOffsetArraysExt = "OffsetArrays"
ReactantOneHotArraysExt = "OneHotArrays"
ReactantPythonCallExt = "PythonCall"
Expand Down Expand Up @@ -111,6 +113,7 @@ LinearAlgebra = "1.10"
MPI = "0.20"
NNlib = "0.9.26"
NPZ = "0.4"
OMEinsum = "0.9"
OffsetArrays = "1"
OneHotArrays = "0.2.10"
OrderedCollections = "1.1"
Expand Down
98 changes: 98 additions & 0 deletions ext/ReactantOMEinsumExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
module ReactantOMEinsumExt

using Reactant
using Reactant: @reactant_overlay, looped_any, use_overlayed_version, @opcall
using Reactant.TracedUtils: get_mlir_data, set_mlir_data!
using OMEinsum
using OMEinsum: _analyze_binary_input

# Overlay for `OMEinsum.get_output_array`.
#
# If any element of `xs` satisfies `use_overlayed_version`, the output is created with
# `@opcall fill` using the zero of the promoted element type of `xs` and the requested
# `size`. Otherwise the call is forwarded unchanged to the native method.
@reactant_overlay @noinline function OMEinsum.get_output_array(xs, size, fillzero)
    if !looped_any(use_overlayed_version, xs)
        return Reactant.call_with_native(OMEinsum.get_output_array, xs, size, fillzero)
    end

    # `fillzero` is intentionally ignored on this path: zero-initializing the output
    # is the easier option here, so it is always done.
    elt = promote_type(map(eltype, xs)...)
    return @opcall fill(zero(elt), size)
end

# Overlay for the `Diag` rule of `OMEinsum.unary_einsum!`.
#
# When `x` satisfies `use_overlayed_version` (and `y` is asserted to as well), the
# result is produced by `OMEinsum.compactify!(y, x, ix, iy, sx, sy)` executed under
# `@allowscalar`. Otherwise the call is forwarded to the native method with a fresh
# `OMEinsum.Diag()` rule instance.
@reactant_overlay @noinline function OMEinsum.unary_einsum!(
    ::OMEinsum.Diag, ix, iy, x::AbstractArray, y::AbstractArray, sx, sy
)
    if use_overlayed_version(x)
        @assert use_overlayed_version(y)
        # TODO we probably would prefer a more efficient implementation here... like a reduction or a specialized op
        # Log the shape of `x` itself. `x` is a single array, so `size.(x)` would
        # broadcast `size` over its elements rather than report the array's size.
        @debug "Diag" ix => iy size(x)
        return @allowscalar OMEinsum.compactify!(y, x, ix, iy, sx, sy)
    else
        return Reactant.call_with_native(
            OMEinsum.unary_einsum!, OMEinsum.Diag(), ix, iy, x, y, sx, sy
        )
    end
end

# Overlay for `OMEinsum.tensorpermute!`.
#
# On the overlayed path this computes `sy * C + sx * transpose(A, perm)` with both
# scale factors promoted to the element type `T`, stores the result's MLIR data into
# `C`, and returns `C`. When `A` does not satisfy `use_overlayed_version`, the call is
# forwarded to the native method instead.
@reactant_overlay @noinline function OMEinsum.tensorpermute!(
    C::AbstractArray{T,N}, A::AbstractArray{T,N}, perm, sx, sy
) where {T,N}
    if !use_overlayed_version(A)
        return Reactant.call_with_native(OMEinsum.tensorpermute!, C, A, perm, sx, sy)
    end
    @assert use_overlayed_version(C)

    perm_vec = collect(perm)
    src_scale = Reactant.promote_to(T, sx)
    dst_scale = Reactant.promote_to(T, sy)

    # Same evaluation order as a single `sy * C + sx * transpose(A)` expression:
    # scale the destination first, then permute and scale the source.
    scaled_dst = dst_scale * C
    permuted = @opcall transpose(A, perm_vec)
    updated = scaled_dst + src_scale * permuted

    set_mlir_data!(C, get_mlir_data(updated))
    return C
end

# Overlay for the two-operand method of `OMEinsum.einsum!`.
#
# On the overlayed path (any operand in `xs` satisfies `use_overlayed_version`) the
# contraction is expressed as one `dot_general` op followed by a `transpose` that
# reorders the result to match `iy`, and then accumulated GEMM-style into `y`:
# `y = sy * y + sx * result`. Otherwise the call is forwarded to the native method.
@reactant_overlay @noinline function OMEinsum.einsum!(
    ixs, iy, @nospecialize(xs::NTuple{2,Any}), @nospecialize(y), sx, sy, size_dict
)
    if !looped_any(use_overlayed_version, xs)
        return Reactant.call_with_native(
            OMEinsum.einsum!, ixs, iy, xs, y, sx, sy, size_dict
        )
    end
    @assert use_overlayed_version(y)

    # Shortcut: if either operand is a plain `Number`, this is just a scaling.
    if looped_any(x -> x isa Number, xs)
        scaled = sy * y + sx * xs[1] * xs[2]
        set_mlir_data!(y, get_mlir_data(scaled))
        return y
    end

    Label = keytype(size_dict)
    lhs, rhs = xs
    lhs_labels, rhs_labels = collect.(Label, ixs)
    out_labels = collect(Label, iy)
    inner, lhs_outer, rhs_outer, batch = _analyze_binary_input(
        lhs_labels, rhs_labels, out_labels
    )

    # Position of the first occurrence of each label of `labels` inside `within`.
    positions(labels, within) = Int[findfirst(==(l), within) for l in labels]

    contracting_dimensions = (positions(inner, lhs_labels), positions(inner, rhs_labels))
    batching_dimensions = (positions(batch, lhs_labels), positions(batch, rhs_labels))

    contracted = @opcall dot_general(lhs, rhs; contracting_dimensions, batching_dimensions)

    # The contraction result is labelled batch, then lhs-outer, then rhs-outer;
    # permute it so that the labels appear in the order requested by `iy`.
    result_labels = vcat(batch, lhs_outer, rhs_outer)
    perm = positions(out_labels, result_labels)
    permuted = @opcall transpose(contracted, perm)
    @assert size(permuted) == size(y)
    @assert eltype(permuted) == eltype(y)

    # just like GEMM, we do: y = sy * y + sx * c
    accumulated = sy * y + sx * permuted
    set_mlir_data!(y, get_mlir_data(accumulated))
    return y
end

end
22 changes: 22 additions & 0 deletions src/Indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,28 @@ function _setindex_scalar_cartesian!(
return a
end

# Scalar `setindex!` of a traced number `v` into the traced array `a` at a
# full-rank `CartesianIndex`. Goes through `assertscalar`, writes the value with a
# `dynamic_update_slice` op, stores the resulting MLIR data back into `a`, and
# returns `a`.
function _setindex_scalar_cartesian!(
    a::TracedRArray{T,N}, v::TracedRNumber, index::CartesianIndex{N}
) where {T,N}
    assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})")

    # Promote `v` to the array's element type and broadcast it to the size
    # `(1, ..., 1)` (N ones) so it can be used as the update operand.
    patch = Reactant.broadcast_to_size(
        Reactant.promote_to(TracedRNumber{T}, v), ntuple(Returns(1), N)
    )
    start = collect(Int64, index.I)

    updated = @opcall dynamic_update_slice(a, patch, start)
    res = @opcall reshape(updated, collect(size(a)))

    TracedUtils.set_mlir_data!(a, TracedUtils.get_mlir_data(res))
    return a
end

function _setindex_linear!(a::TracedRArray{T,N}, v, indices::AbstractArray) where {T,N}
if !(indices isa Reactant.TracedType) && TracedUtils.__contiguous_indices(vec(indices))
res = @opcall(
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down
216 changes: 216 additions & 0 deletions test/integration/omeinsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
using Test
using Reactant
using OMEinsum

# Full and partial index sums of a 2x3 matrix: the `@jit`-compiled result on the
# Reactant array must match the plain-array result.
@testset "sum" begin
    x = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    x_ra = Reactant.to_rarray(x)

    for code in (ein"ij->", ein"ij->i", ein"ij->j")
        expected = code(x)
        actual = @jit code(x_ra)
        @test expected ≈ actual
    end
end

# Diagonal extraction from a 3x3 matrix, compared between the plain array and the
# `@jit`-compiled Reactant array.
@testset "diagonal" begin
    x = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 3)
    x_ra = Reactant.to_rarray(x)

    code = ein"ii->i"
    expected = code(x)
    actual = @jit code(x_ra)
    @test expected ≈ actual

    # NOTE currently broken
    # hyper-diagonal
    # a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 2, 2)
    # a_re = Reactant.to_rarray(a)
    # f = ein"iii->i"
    # c = f(a)
    # c_re = @jit f(a_re)
    # @test c ≈ c_re
end

# NOTE currently broken
# @testset "trace" begin
# a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 2)
# a_re = Reactant.to_rarray(a)

# f = ein"ii->"
# c = f(a)
# c_re = @jit f(a_re)
# @test c ≈ c_re

# # partial trace
# a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 2, 2, 2)
# a_re = Reactant.to_rarray(a)
# f = ein"iijk->jk"
# c = f(a)
# c_re = @jit f(a_re)
# @test c ≈ c_re
# end

# Index permutations: a matrix transpose and a pairwise swap on a 4-dimensional array.
@testset "transpose" begin
    cases = (((2, 3), ein"ij->ji"), ((2, 3, 4, 5), ein"ijkl->jilk"))

    for (dims, code) in cases
        x = Reactant.TestUtils.construct_test_array(ComplexF32, dims...)
        x_ra = Reactant.to_rarray(x)

        expected = code(x)
        actual = @jit code(x_ra)
        @test expected ≈ actual
    end
end

# Matrix product `ij,jk->ik`, first with matching element types and then with a
# `Float32` right-hand operand against the `ComplexF32` left-hand operand.
@testset "matmul" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4)
    lhs_ra = Reactant.to_rarray(lhs)
    rhs_ra = Reactant.to_rarray(rhs)

    code = ein"ij,jk->ik"
    expected = code(lhs, rhs)
    actual = @jit code(lhs_ra, rhs_ra)
    @test expected ≈ actual

    @testset "with different eltype" begin
        rhs_real = Reactant.TestUtils.construct_test_array(Float32, 3, 4)
        rhs_real_ra = Reactant.to_rarray(rhs_real)

        mixed_code = ein"ij,jk->ik"
        mixed_expected = mixed_code(lhs, rhs_real)
        mixed_actual = @jit mixed_code(lhs_ra, rhs_real_ra)
        @test mixed_expected ≈ mixed_actual
    end
end

# Element-wise products: once with identically laid-out operands, once with the
# second operand supplied in transposed layout.
@testset "hadamard product" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    lhs_ra = Reactant.to_rarray(lhs)

    cases = (((2, 3), ein"ij,ij->ij"), ((3, 2), ein"ij,ji->ij"))

    for (rhs_dims, code) in cases
        rhs = Reactant.TestUtils.construct_test_array(ComplexF32, rhs_dims...)
        rhs_ra = Reactant.to_rarray(rhs)

        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end

# Contraction of every index of both operands down to a scalar result.
@testset "inner product" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4)
    lhs_ra, rhs_ra = Reactant.to_rarray(lhs), Reactant.to_rarray(rhs)

    code = ein"ij,ji->"
    expected = code(lhs, rhs)
    actual = @jit code(lhs_ra, rhs_ra)
    @test expected ≈ actual
end

# Outer products with no shared indices, checked for three different output
# index orders.
@testset "outer product" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 5)
    lhs_ra, rhs_ra = Reactant.to_rarray(lhs), Reactant.to_rarray(rhs)

    for code in (ein"ij,kl->ijkl", ein"ij,kl->klij", ein"ij,kl->ikjl")
        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end

# Scaling a matrix by a zero-dimensional array holding a single value.
@testset "scale" begin
    mat = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    factor = fill(ComplexF32(2.0))
    mat_ra, factor_ra = Reactant.to_rarray(mat), Reactant.to_rarray(factor)

    code = ein"ij,->ij"
    expected = code(mat, factor)
    actual = @jit code(mat_ra, factor_ra)
    @test expected ≈ actual
end

# Batched matrix product over the shared index `b`, with the batch index placed
# last and then first in the output.
@testset "batch matmul" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3, 6)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4, 6)
    lhs_ra, rhs_ra = Reactant.to_rarray(lhs), Reactant.to_rarray(rhs)

    for code in (ein"ijb,jkb->ikb", ein"ijb,jkb->bik")
        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end

# NOTE currently broken
# @testset "star contraction" begin
# a = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
# b = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 3)
# c = Reactant.TestUtils.construct_test_array(ComplexF32, 5, 3)
# a_re = Reactant.to_rarray(a)
# b_re = Reactant.to_rarray(b)
# c_re = Reactant.to_rarray(c)

# f = ein"ai,bi,ci->abc"
# d = f(a, b, c)
# d_re = @jit f(a_re, b_re, c_re)
# @test d ≈ d_re
# end

# Contractions of two 3-dimensional arrays sharing the indices `j` and `k`.
@testset "tensor contraction" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 4, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 5, 4)
    lhs_ra, rhs_ra = Reactant.to_rarray(lhs), Reactant.to_rarray(rhs)

    # The first code contracts both shared indices; the second is a contraction
    # of NOT all common indices (`k` is kept in the output).
    for code in (ein"ijk,klj->il", ein"ijk,klj->ikl")
        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end
Loading