Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Expand All @@ -73,6 +74,7 @@ ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantMPIExt = "MPI"
ReactantNNlibExt = ["NNlib", "Statistics"]
ReactantNPZExt = "NPZ"
ReactantOMEinsumExt = "OMEinsum"
ReactantOffsetArraysExt = "OffsetArrays"
ReactantOneHotArraysExt = "OneHotArrays"
ReactantPythonCallExt = "PythonCall"
Expand Down Expand Up @@ -111,6 +113,7 @@ LinearAlgebra = "1.10"
MPI = "0.20"
NNlib = "0.9.26"
NPZ = "0.4"
OMEinsum = "0.9"
OffsetArrays = "1"
OneHotArrays = "0.2.10"
OrderedCollections = "1.1"
Expand Down
72 changes: 72 additions & 0 deletions ext/ReactantOMEinsumExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
module ReactantOMEinsumExt

using Reactant
using Reactant: @reactant_overlay, looped_any, use_overlayed_version, @opcall
using OMEinsum
using OMEinsum: _analyze_binary_input

# Overlay for `OMEinsum.get_output_array`.
#
# If any operand in `xs` takes the overlayed path, the output is created with
# `@opcall fill`, using the promoted element type of the operands and the
# requested `size`. Otherwise the native OMEinsum method is invoked unchanged.
@reactant_overlay @noinline function OMEinsum.get_output_array(xs, size, fillzero)
    if !looped_any(use_overlayed_version, xs)
        return Reactant.call_with_native(OMEinsum.get_output_array, xs, size, fillzero)
    end

    # `fillzero` is intentionally ignored on this path: always producing a
    # zero-initialized array is the easier thing to emit.
    out_eltype = promote_type(map(eltype, xs)...)
    return @opcall fill(zero(out_eltype), size)
end

# Overlay for unary `OMEinsum.einsum!` (exactly one operand).
#
# Operands that do not take the overlayed path are forwarded to the native
# OMEinsum method. The overlayed path is a placeholder that always raises.
@reactant_overlay @noinline function OMEinsum.einsum!(
    ixs, iy, @nospecialize(xs::NTuple{1,Any}), @nospecialize(y), sx, sy, size_dict
)
    if !looped_any(use_overlayed_version, xs)
        return Reactant.call_with_native(
            OMEinsum.einsum!, ixs, iy, xs, y, sx, sy, size_dict
        )
    end

    # An overlayed input is expected to come with an overlayed output.
    @assert use_overlayed_version(y)
    # TODO
    error("unary einsum support not implemented yet")
end

# Overlay for binary `OMEinsum.einsum!` (exactly two operands).
#
# When any operand takes the overlayed path, the contraction is lowered to a
# single `dot_general`, followed by a `transpose` into the label order requested
# by `iy`, and the result is accumulated GEMM-style: `y = sy * y + sx * result`.
# `y` is updated by rebinding its `mlir_data` and then returned. Operands that
# do not take the overlayed path are forwarded to the native OMEinsum method.
#
# Arguments:
# - `ixs`: index labels of the two operands; `iy`: index labels of the output.
# - `xs`: the two operands; `y`: the output array.
# - `sx`, `sy`: scale factors for the new result and for the previous `y`.
# - `size_dict`: label => size mapping; only its key type is used here.
#
# Supported patterns: every label of an operand must either appear on the other
# operand or in the output, no label may repeat within one operand or within the
# output, and every output label must come from an operand. Anything else would
# make the `dot_general` result dimensions disagree with the label bookkeeping
# below, so it is rejected up front with an explicit error.
@reactant_overlay @noinline function OMEinsum.einsum!(
    ixs, iy, @nospecialize(xs::NTuple{2,Any}), @nospecialize(y), sx, sy, size_dict
)
    if !looped_any(use_overlayed_version, xs)
        return Reactant.call_with_native(
            OMEinsum.einsum!, ixs, iy, xs, y, sx, sy, size_dict
        )
    end
    @assert use_overlayed_version(y)

    # shortcut for scalar multiplication
    if looped_any(x -> x isa Number, xs)
        c = sy * y + sx * xs[1] * xs[2]
        y.mlir_data = c.mlir_data
        return y
    end

    LT = keytype(size_dict)
    a, b = xs
    ia, ib = collect.(LT, ixs)
    iyv = collect(LT, iy)

    # Reject patterns the single dot_general + transpose lowering cannot express
    # (diagonals, sums over a label owned by one operand, broadcast outputs).
    supported =
        allunique(ia) &&
        allunique(ib) &&
        allunique(iyv) &&
        all(i -> (i in ib) || (i in iyv), ia) &&
        all(i -> (i in ia) || (i in iyv), ib) &&
        all(i -> (i in ia) || (i in ib), iyv)
    if !supported
        error(
            "binary einsum with repeated, dangling or broadcast indices is not " *
            "implemented yet (got $(ixs) -> $(iy))"
        )
    end

    inner, a_outer, b_outer, batch = _analyze_binary_input(ia, ib, iyv)

    # Positions (within each operand) of the labels that are summed over ...
    contracting_dimensions = (
        Int[findfirst(==(i), ia) for i in inner],
        Int[findfirst(==(i), ib) for i in inner],
    )
    # ... and of the labels shared by both operands and kept in the output.
    batching_dimensions = (
        Int[findfirst(==(i), ia) for i in batch],
        Int[findfirst(==(i), ib) for i in batch],
    )

    c = @opcall dot_general(a, b; contracting_dimensions, batching_dimensions)

    # permute dims to match iy: the result is laid out as batch labels, then the
    # remaining labels of `a`, then the remaining labels of `b`
    ic = vcat(batch, a_outer, b_outer)
    perm = Int[findfirst(==(i), ic) for i in iyv]
    c = @opcall transpose(c, perm)
    @assert size(c) == size(y)
    @assert eltype(c) == eltype(y)

    # just like GEMM, we do: y = sy * y + sx * c
    c = sy * y + sx * c
    y.mlir_data = c.mlir_data
    return y
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down
173 changes: 173 additions & 0 deletions test/integration/omeinsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
using Test
using Reactant
using OMEinsum

# Matrix reductions: sum over everything, over the second label, over the first.
# NOTE(review): these are single-operand einsums, and the Reactant overlay for
# unary `einsum!` currently raises "not implemented yet" — confirm that these
# cases are expected to pass.
@testset "sum" begin
    x = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    x_ra = Reactant.to_rarray(x)

    for code in (ein"ij->", ein"ij->i", ein"ij->j")
        expected = code(x)
        actual = @jit code(x_ra)
        @test expected ≈ actual
    end
end

# Trace of a square matrix via a repeated input label.
# NOTE(review): single-operand einsum; the Reactant overlay for unary `einsum!`
# currently raises "not implemented yet" — confirm this is expected to pass.
@testset "trace" begin
    mat = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 3)
    mat_ra = Reactant.to_rarray(mat)

    code = ein"ii->"
    expected = code(mat)
    actual = @jit code(mat_ra)
    @test expected ≈ actual
end

# Pure label permutations: a matrix transpose and a pairwise swap on a 4-D array.
# NOTE(review): single-operand einsums; the Reactant overlay for unary `einsum!`
# currently raises "not implemented yet" — confirm these are expected to pass.
@testset "transpose" begin
    cases = (((2, 3), ein"ij->ji"), ((2, 3, 4, 5), ein"ijkl->jilk"))

    for (dims, code) in cases
        x = Reactant.TestUtils.construct_test_array(ComplexF32, dims...)
        x_ra = Reactant.to_rarray(x)

        expected = code(x)
        actual = @jit code(x_ra)
        @test expected ≈ actual
    end
end

# Plain matrix product, first with matching element types and then with a
# complex left operand and a real right operand.
@testset "matmul" begin
    code = ein"ij,jk->ik"

    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4)
    lhs_ra = Reactant.to_rarray(lhs)
    rhs_ra = Reactant.to_rarray(rhs)

    expected = code(lhs, rhs)
    actual = @jit code(lhs_ra, rhs_ra)
    @test expected ≈ actual

    @testset "with different eltype" begin
        rhs_real = Reactant.TestUtils.construct_test_array(Float32, 3, 4)
        rhs_real_ra = Reactant.to_rarray(rhs_real)

        expected_mixed = code(lhs, rhs_real)
        actual_mixed = @jit code(lhs_ra, rhs_real_ra)
        @test expected_mixed ≈ actual_mixed
    end
end

# Element-wise product where both labels are batch labels, once with matching
# layouts and once with the second operand stored transposed.
@testset "hadamard product" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    lhs_ra = Reactant.to_rarray(lhs)

    cases = (((2, 3), ein"ij,ij->ij"), ((3, 2), ein"ij,ji->ij"))
    for (rhs_dims, code) in cases
        rhs = Reactant.TestUtils.construct_test_array(ComplexF32, rhs_dims...)
        rhs_ra = Reactant.to_rarray(rhs)

        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end

# Full contraction of two matrices down to a scalar result.
@testset "inner product" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4)
    lhs_ra = Reactant.to_rarray(lhs)
    rhs_ra = Reactant.to_rarray(rhs)

    code = ein"ij,ji->"
    expected = code(lhs, rhs)
    actual = @jit code(lhs_ra, rhs_ra)
    @test expected ≈ actual
end

# Outer product with no shared labels, checked for three output label orders.
@testset "outer product" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 4, 5)
    lhs_ra = Reactant.to_rarray(lhs)
    rhs_ra = Reactant.to_rarray(rhs)

    for code in (ein"ij,kl->ijkl", ein"ij,kl->klij", ein"ij,kl->ikjl")
        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end

# Scaling a matrix by a zero-dimensional array operand.
@testset "scale" begin
    mat = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3)
    factor = fill(ComplexF32(2.0))
    mat_ra = Reactant.to_rarray(mat)
    factor_ra = Reactant.to_rarray(factor)

    code = ein"ij,->ij"
    expected = code(mat, factor)
    actual = @jit code(mat_ra, factor_ra)
    @test expected ≈ actual
end

# Batched matrix product, with the batch label placed last and then first in
# the output.
@testset "batch matmul" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 3, 6)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 4, 6)
    lhs_ra = Reactant.to_rarray(lhs)
    rhs_ra = Reactant.to_rarray(rhs)

    for code in (ein"ijb,jkb->ikb", ein"ijb,jkb->bik")
        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end

# Contractions of two rank-3 tensors that share the labels `j` and `k`.
@testset "tensor contraction" begin
    lhs = Reactant.TestUtils.construct_test_array(ComplexF32, 2, 4, 3)
    rhs = Reactant.TestUtils.construct_test_array(ComplexF32, 3, 5, 4)
    lhs_ra = Reactant.to_rarray(lhs)
    rhs_ra = Reactant.to_rarray(rhs)

    codes = (
        # both shared labels are contracted
        ein"ijk,klj->il",
        # contraction of NOT all common indices
        ein"ijk,klj->ikl",
    )
    for code in codes
        expected = code(lhs, rhs)
        actual = @jit code(lhs_ra, rhs_ra)
        @test expected ≈ actual
    end
end
Loading