Skip to content

Commit 45cec8d

Browse files
committed
Revert "Despecialize Ops.dot_general"
This reverts commit d6e7b21.
1 parent d6e7b21 commit 45cec8d

File tree

1 file changed

+60
-4
lines changed

1 file changed

+60
-4
lines changed

src/Ops.jl

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# If you want to add some check or test, the StableHLO spec should be taken as the source of truth, not the Julia or Reactant semantics.
33
# Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
44
module Ops
5-
using Base: @nospecializeinfer
65
using ..MLIR: MLIR
76
using ..MLIR.Dialects: stablehlo, chlo, enzyme
87
using ..Reactant:
@@ -795,9 +794,9 @@ end
795794
# return TracedRArray{T,N}((), res, size(lhs))
796795
# end
797796

798-
Base.@nospecializeinfer @noinline function dot_general(
799-
@nospecialize(lhs::TracedRArray{T1}),
800-
@nospecialize(rhs::TracedRArray{T2});
797+
@noinline function dot_general(
798+
lhs::TracedRArray{T1},
799+
rhs::TracedRArray{T2};
801800
contracting_dimensions,
802801
batching_dimensions=(Int[], Int[]),
803802
precision_config=Reactant.DOT_GENERAL_PRECISION[],
@@ -921,6 +920,63 @@ Base.@nospecializeinfer @noinline function dot_general(
921920
return TracedRArray{resT,length(ressize)}((), res, ressize)
922921
end
923922

923+
@noinline function einsum(
924+
lhs::TracedRArray{T},
925+
rhs::TracedRArray{T};
926+
equation::String,
927+
location=mlir_stacktrace("einsum", @__FILE__, @__LINE__),
928+
) where {T}
929+
Base.depwarn(
930+
"`stablehlo.einsum` is on deprecation process; use `dot_general` instead", :einsum
931+
)
932+
ins, ic = split(equation, "->")
933+
ia, ib = split(ins, ",")
934+
935+
sizea = Dict(c => d for (c, d) in zip(ia, size(lhs)))
936+
sizeb = Dict(c => d for (c, d) in zip(ib, size(rhs)))
937+
sizes = mergewith(sizea, sizeb) do da, db
938+
da == db ? da : error("Invalid dimensions in einsum equation")
939+
end
940+
941+
rsize = Tuple(sizes[i] for i in ic)
942+
result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize)
943+
944+
res = MLIR.IR.result(
945+
stablehlo.einsum(
946+
lhs.mlir_data,
947+
rhs.mlir_data;
948+
result_0,
949+
einsum_config=MLIR.IR.Attribute(equation),
950+
location,
951+
),
952+
)
953+
return TracedRArray{T,length(rsize)}((), res, rsize)
954+
end
955+
956+
# function unary_einsum(
957+
# x::TracedRArray{T};
958+
# equation::String,
959+
# location=mlir_stacktrace(
960+
# "unary_einsum", @__FILE__, @__LINE__
961+
# ),
962+
# ) where {T}
963+
# ia, ic = split(equation, "->")
964+
# sizes = Dict(c => d for (c, d) in zip(ia, size(x)))
965+
# rsize = Tuple(sizes[i] for i in ic)
966+
# result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize)
967+
968+
# res = MLIR.IR.result(
969+
# stablehlo.unary_einsum(
970+
# x.mlir_data; result_0, einsum_config=MLIR.IR.Attribute(equation), location
971+
# ),
972+
# )
973+
# if length(rsize) == 0
974+
# return TracedRNumber{T}((), res)
975+
# else
976+
# return TracedRArray{T,length(rsize)}((), res, rsize)
977+
# end
978+
# end
979+
924980
# parallel ops
925981
@noinline function partition_id(;
926982
location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)

0 commit comments

Comments
 (0)