Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/MuscleDaggerExt/binary_einsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Muscle
using Dagger: Dagger, ArrayOp, Context, ArrayDomain, EagerThunk, DArray
using LinearAlgebra

function Muscle.binary_einsum(::Muscle.BackendDagger, inds_c, a, b)
function Muscle.binary_einsum(::Muscle.BackendDagger, inds_c, a, b; kwargs...)
op = BinaryEinsum(inds_c, a, b)
darray = Dagger._to_darray(op)
return Tensor(darray, inds_c)
Expand Down
2 changes: 1 addition & 1 deletion ext/MuscleOMEinsumExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function Muscle.unary_einsum!(::BackendOMEinsum, y, x)
return y
end

function Muscle.binary_einsum(::Muscle.BackendOMEinsum, inds_c, a, b)
function Muscle.binary_einsum(::Muscle.BackendOMEinsum, inds_c, a, b; kwargs...)
size_dict = Dict{Index,Int}()
for (ind, ind_size) in Iterators.flatten([inds(a) .=> size(a), inds(b) .=> size(b)])
size_dict[ind] = ind_size
Expand Down
2 changes: 1 addition & 1 deletion ext/MuscleReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function Muscle.unary_einsum(
end

Base.@nospecializeinfer @noinline function Muscle.binary_einsum(
::BackendReactant, inds_c, @nospecialize(a::Tensor{TracedRNumber{Ta}}), @nospecialize(b::Tensor{TracedRNumber{Tb}})
::BackendReactant, inds_c, @nospecialize(a::Tensor{TracedRNumber{Ta}}), @nospecialize(b::Tensor{TracedRNumber{Tb}}); kwargs...
) where {Ta,Tb}
out = inds_c
dims = setdiff(inds(a) ∩ inds(b), out)
Expand Down
4 changes: 2 additions & 2 deletions ext/MuscleStridedExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function Muscle.choose_backend_rule(
BackendStrided()
end

function Muscle.binary_einsum(::BackendStrided, inds_c, a::Tensor, b::Tensor)
function Muscle.binary_einsum(::BackendStrided, inds_c, a::Tensor, b::Tensor; kwargs...)
binary_einsum(
BackendStrided(),
inds_c,
Expand All @@ -34,7 +34,7 @@ function Muscle.binary_einsum(::BackendStrided, inds_c, a::Tensor, b::Tensor)
end

function Muscle.binary_einsum(
::BackendStrided, inds_c, a::Tensor{Ta,Na,<:StridedView}, b::Tensor{Tb,Nb,<:StridedView}
::BackendStrided, inds_c, a::Tensor{Ta,Na,<:StridedView}, b::Tensor{Tb,Nb,<:StridedView}; kwargs...
) where {Ta,Tb,Na,Nb}
inds_contract = inds(a) ∩ inds(b)
inds_left = setdiff(inds(a), inds_contract)
Expand Down
18 changes: 13 additions & 5 deletions src/Operations/binary_einsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ choose_backend_rule(::typeof(binary_einsum!), ::DomainCUDA, ::DomainCUDA, ::Doma
choose_backend_rule(::typeof(binary_einsum!), ::DomainReactant, ::DomainReactant, ::DomainReactant) = BackendReactant()

function binary_einsum(a::Tensor, b::Tensor; dims=(∩(inds(a), inds(b))), out=nothing)

reorder_inds = false
inds_sum = ∩(dims, inds(a), inds(b))

inds_c = if isnothing(out)
setdiff(inds(a) ∪ inds(b), inds_sum isa Base.AbstractVecOrTuple ? inds_sum : [inds_sum])
else
reorder_inds = true
out
end

backend = choose_backend(binary_einsum, parent(a), parent(b))
# if ismissing(backend)
# @warn "No backend found for binary_einsum(::$(typeof(a)), ::$(typeof(b))), so unwrapping data"
Expand All @@ -47,10 +50,12 @@ function binary_einsum(a::Tensor, b::Tensor; dims=(∩(inds(a), inds(b))), out=n
# backend = choose_backend(binary_einsum, data_a, data_b)
# end

return binary_einsum(backend, inds_c, a, b)
return binary_einsum(backend, inds_c, a, b; reorder_inds)


Comment on lines +54 to +55
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

end

function binary_einsum(::Backend, inds_c, a, b)
function binary_einsum(::Backend, inds_c, a, b; kwargs...)
throw(ArgumentError("`binary_einsum` not implemented or not loaded for backend $(typeof(a))"))
end

Expand All @@ -73,7 +78,7 @@ function binary_einsum!(::Backend, c, a, b)
throw(ArgumentError("`binary_einsum!` not implemented or not loaded for backend $(typeof(a))"))
end

function binary_einsum(::BackendBase, inds_c, a::Tensor, b::Tensor)
function binary_einsum(::BackendBase, inds_c, a::Tensor, b::Tensor; reorder_inds=true)
inds_contract = inds(a) ∩ inds(b)
inds_left = setdiff(inds(a), inds_contract)
inds_right = setdiff(inds(b), inds_contract)
Expand All @@ -92,7 +97,10 @@ function binary_einsum(::BackendBase, inds_c, a::Tensor, b::Tensor)
c_mat = a_mat * b_mat

c = Tensor(reshape(c_mat, sizes_left..., sizes_right...), [inds_left; inds_right])
return permutedims(c, inds_c)
if reorder_inds
c = permutedims(c, inds_c)
end
return c
end

function binary_einsum!(::BackendBase, c::Tensor, a::Tensor, b::Tensor)
Expand Down
Loading