|
794 | 794 | # return TracedRArray{T,N}((), res, size(lhs))
|
795 | 795 | # end
|
796 | 796 |
|
797 |
| -@noinline function dot_general( |
798 |
| - lhs::TracedRArray{T1}, |
799 |
| - rhs::TracedRArray{T2}; |
| 797 | +Base.@nospecializeinfer @noinline function dot_general( |
| 798 | + @nospecialize(lhs::TracedRArray{T1}), |
| 799 | + @nospecialize(rhs::TracedRArray{T2}); |
800 | 800 | contracting_dimensions,
|
801 | 801 | batching_dimensions=(Int[], Int[]),
|
802 | 802 | precision_config=Reactant.DOT_GENERAL_PRECISION[],
|
|
920 | 920 | return TracedRArray{resT,length(ressize)}((), res, ressize)
|
921 | 921 | end
|
922 | 922 |
|
923 |
| -@noinline function einsum( |
924 |
| - lhs::TracedRArray{T}, |
925 |
| - rhs::TracedRArray{T}; |
926 |
| - equation::String, |
927 |
| - location=mlir_stacktrace("einsum", @__FILE__, @__LINE__), |
928 |
| -) where {T} |
929 |
| - Base.depwarn( |
930 |
| - "`stablehlo.einsum` is on deprecation process; use `dot_general` instead", :einsum |
931 |
| - ) |
932 |
| - ins, ic = split(equation, "->") |
933 |
| - ia, ib = split(ins, ",") |
934 |
| - |
935 |
| - sizea = Dict(c => d for (c, d) in zip(ia, size(lhs))) |
936 |
| - sizeb = Dict(c => d for (c, d) in zip(ib, size(rhs))) |
937 |
| - sizes = mergewith(sizea, sizeb) do da, db |
938 |
| - da == db ? da : error("Invalid dimensions in einsum equation") |
939 |
| - end |
940 |
| - |
941 |
| - rsize = Tuple(sizes[i] for i in ic) |
942 |
| - result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize) |
943 |
| - |
944 |
| - res = MLIR.IR.result( |
945 |
| - stablehlo.einsum( |
946 |
| - lhs.mlir_data, |
947 |
| - rhs.mlir_data; |
948 |
| - result_0, |
949 |
| - einsum_config=MLIR.IR.Attribute(equation), |
950 |
| - location, |
951 |
| - ), |
952 |
| - ) |
953 |
| - return TracedRArray{T,length(rsize)}((), res, rsize) |
954 |
| -end |
955 |
| - |
956 |
| -# function unary_einsum( |
957 |
| -# x::TracedRArray{T}; |
958 |
| -# equation::String, |
959 |
| -# location=mlir_stacktrace( |
960 |
| -# "unary_einsum", @__FILE__, @__LINE__ |
961 |
| -# ), |
962 |
| -# ) where {T} |
963 |
| -# ia, ic = split(equation, "->") |
964 |
| -# sizes = Dict(c => d for (c, d) in zip(ia, size(x))) |
965 |
| -# rsize = Tuple(sizes[i] for i in ic) |
966 |
| -# result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize) |
967 |
| - |
968 |
| -# res = MLIR.IR.result( |
969 |
| -# stablehlo.unary_einsum( |
970 |
| -# x.mlir_data; result_0, einsum_config=MLIR.IR.Attribute(equation), location |
971 |
| -# ), |
972 |
| -# ) |
973 |
| -# if length(rsize) == 0 |
974 |
| -# return TracedRNumber{T}((), res) |
975 |
| -# else |
976 |
| -# return TracedRArray{T,length(rsize)}((), res, rsize) |
977 |
| -# end |
978 |
| -# end |
979 |
| - |
980 | 923 | # parallel ops
|
981 | 924 | @noinline function partition_id(;
|
982 | 925 | location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)
|
|
0 commit comments