Skip to content

Commit e66fffa

Browse files
authored
Despecialize Ops.dot_general (#1318)
* Despecialize `Ops.dot_general` * Remove import * Remove einsum tests * Format code
1 parent 8a289c9 commit e66fffa

File tree

2 files changed

+3
-91
lines changed

2 files changed

+3
-91
lines changed

src/Ops.jl

Lines changed: 3 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -794,9 +794,9 @@ end
794794
# return TracedRArray{T,N}((), res, size(lhs))
795795
# end
796796

797-
@noinline function dot_general(
798-
lhs::TracedRArray{T1},
799-
rhs::TracedRArray{T2};
797+
Base.@nospecializeinfer @noinline function dot_general(
798+
@nospecialize(lhs::TracedRArray{T1}),
799+
@nospecialize(rhs::TracedRArray{T2});
800800
contracting_dimensions,
801801
batching_dimensions=(Int[], Int[]),
802802
precision_config=Reactant.DOT_GENERAL_PRECISION[],
@@ -920,63 +920,6 @@ end
920920
return TracedRArray{resT,length(ressize)}((), res, ressize)
921921
end
922922

923-
@noinline function einsum(
924-
lhs::TracedRArray{T},
925-
rhs::TracedRArray{T};
926-
equation::String,
927-
location=mlir_stacktrace("einsum", @__FILE__, @__LINE__),
928-
) where {T}
929-
Base.depwarn(
930-
"`stablehlo.einsum` is on deprecation process; use `dot_general` instead", :einsum
931-
)
932-
ins, ic = split(equation, "->")
933-
ia, ib = split(ins, ",")
934-
935-
sizea = Dict(c => d for (c, d) in zip(ia, size(lhs)))
936-
sizeb = Dict(c => d for (c, d) in zip(ib, size(rhs)))
937-
sizes = mergewith(sizea, sizeb) do da, db
938-
da == db ? da : error("Invalid dimensions in einsum equation")
939-
end
940-
941-
rsize = Tuple(sizes[i] for i in ic)
942-
result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize)
943-
944-
res = MLIR.IR.result(
945-
stablehlo.einsum(
946-
lhs.mlir_data,
947-
rhs.mlir_data;
948-
result_0,
949-
einsum_config=MLIR.IR.Attribute(equation),
950-
location,
951-
),
952-
)
953-
return TracedRArray{T,length(rsize)}((), res, rsize)
954-
end
955-
956-
# function unary_einsum(
957-
# x::TracedRArray{T};
958-
# equation::String,
959-
# location=mlir_stacktrace(
960-
# "unary_einsum", @__FILE__, @__LINE__
961-
# ),
962-
# ) where {T}
963-
# ia, ic = split(equation, "->")
964-
# sizes = Dict(c => d for (c, d) in zip(ia, size(x)))
965-
# rsize = Tuple(sizes[i] for i in ic)
966-
# result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize)
967-
968-
# res = MLIR.IR.result(
969-
# stablehlo.unary_einsum(
970-
# x.mlir_data; result_0, einsum_config=MLIR.IR.Attribute(equation), location
971-
# ),
972-
# )
973-
# if length(rsize) == 0
974-
# return TracedRNumber{T}((), res)
975-
# else
976-
# return TracedRArray{T,length(rsize)}((), res, rsize)
977-
# end
978-
# end
979-
980923
# parallel ops
981924
@noinline function partition_id(;
982925
location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)

test/ops.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -237,37 +237,6 @@ end
237237
@test Array(a)' * Array(b) == @jit f1(a, b)
238238
end
239239

240-
@testset "einsum" begin
241-
f1(a, b) = Ops.einsum(a, b; equation="i,i->i")
242-
f2(a, b) = Ops.einsum(a, b; equation="i,j->ij")
243-
f3(a, b) = Ops.einsum(a, b; equation="ij,ij->ij")
244-
f4(a, b) = Ops.einsum(a, b; equation="ik,kj->ij")
245-
246-
for (a, b) in [
247-
(Reactant.to_rarray([1, 2, 3, 4]), Reactant.to_rarray([5, 6, -7, -8])),
248-
(
249-
Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]),
250-
Reactant.to_rarray([5.0, 6.0, -7.0, -8.0]),
251-
),
252-
(
253-
Reactant.to_rarray([1.0 + 1im, 2.0 + 2im, 3.0 - 3im, 4.0 - 4im]),
254-
Reactant.to_rarray([5.0 + 5im, 6.0 + 6im, -7.0 - 7im, -8.0 - 8im]),
255-
),
256-
]
257-
@test a .* b
258-
@test_warn r"`stablehlo.einsum` is on deprecation process" @jit f1(a, b)
259-
@test reshape(kron(Array(b), Array(a)), 4, 4)
260-
@test_warn r"`stablehlo.einsum` is on deprecation process" @jit f2(a, b)
261-
262-
x = ConcreteRArray(reshape(a, (2, 2)))
263-
y = ConcreteRArray(reshape(b, (2, 2)))
264-
@test x .* y
265-
@test_warn r"`stablehlo.einsum` is on deprecation process" @jit f3(x, y)
266-
@test Array(x) * Array(y)
267-
@test_warn r"`stablehlo.einsum` is on deprecation process" @jit f4(x, y)
268-
end
269-
end
270-
271240
@testset "exponential" begin
272241
x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0])
273242
@test exp.(Array(x)) @jit Ops.exponential(x)

0 commit comments

Comments
 (0)