Skip to content

Commit d6e7b21

Browse files
committed
Despecialize Ops.dot_general
1 parent 1113a92 commit d6e7b21

File tree

1 file changed

+4
-60
lines changed

1 file changed

+4
-60
lines changed

src/Ops.jl

Lines changed: 4 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# If you want to add some check or test, the StableHLO spec should be taken as the source of truth, not the Julia or Reactant semantics.
33
# Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
44
module Ops
5+
using Base: @nospecializeinfer
56
using ..MLIR: MLIR
67
using ..MLIR.Dialects: stablehlo, chlo, enzyme
78
using ..Reactant:
@@ -794,9 +795,9 @@ end
794795
# return TracedRArray{T,N}((), res, size(lhs))
795796
# end
796797

797-
@noinline function dot_general(
798-
lhs::TracedRArray{T1},
799-
rhs::TracedRArray{T2};
798+
Base.@nospecializeinfer @noinline function dot_general(
799+
@nospecialize(lhs::TracedRArray{T1}),
800+
@nospecialize(rhs::TracedRArray{T2});
800801
contracting_dimensions,
801802
batching_dimensions=(Int[], Int[]),
802803
precision_config=Reactant.DOT_GENERAL_PRECISION[],
@@ -920,63 +921,6 @@ end
920921
return TracedRArray{resT,length(ressize)}((), res, ressize)
921922
end
922923

923-
@noinline function einsum(
924-
lhs::TracedRArray{T},
925-
rhs::TracedRArray{T};
926-
equation::String,
927-
location=mlir_stacktrace("einsum", @__FILE__, @__LINE__),
928-
) where {T}
929-
Base.depwarn(
930-
"`stablehlo.einsum` is on deprecation process; use `dot_general` instead", :einsum
931-
)
932-
ins, ic = split(equation, "->")
933-
ia, ib = split(ins, ",")
934-
935-
sizea = Dict(c => d for (c, d) in zip(ia, size(lhs)))
936-
sizeb = Dict(c => d for (c, d) in zip(ib, size(rhs)))
937-
sizes = mergewith(sizea, sizeb) do da, db
938-
da == db ? da : error("Invalid dimensions in einsum equation")
939-
end
940-
941-
rsize = Tuple(sizes[i] for i in ic)
942-
result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize)
943-
944-
res = MLIR.IR.result(
945-
stablehlo.einsum(
946-
lhs.mlir_data,
947-
rhs.mlir_data;
948-
result_0,
949-
einsum_config=MLIR.IR.Attribute(equation),
950-
location,
951-
),
952-
)
953-
return TracedRArray{T,length(rsize)}((), res, rsize)
954-
end
955-
956-
# function unary_einsum(
957-
# x::TracedRArray{T};
958-
# equation::String,
959-
# location=mlir_stacktrace(
960-
# "unary_einsum", @__FILE__, @__LINE__
961-
# ),
962-
# ) where {T}
963-
# ia, ic = split(equation, "->")
964-
# sizes = Dict(c => d for (c, d) in zip(ia, size(x)))
965-
# rsize = Tuple(sizes[i] for i in ic)
966-
# result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize)
967-
968-
# res = MLIR.IR.result(
969-
# stablehlo.unary_einsum(
970-
# x.mlir_data; result_0, einsum_config=MLIR.IR.Attribute(equation), location
971-
# ),
972-
# )
973-
# if length(rsize) == 0
974-
# return TracedRNumber{T}((), res)
975-
# else
976-
# return TracedRArray{T,length(rsize)}((), res, rsize)
977-
# end
978-
# end
979-
980924
# parallel ops
981925
@noinline function partition_id(;
982926
location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)

0 commit comments

Comments
 (0)