|
2 | 2 | # If you want to add some check or test, the StableHLO spec should be taken as the source of truth, not the Julia or Reactant semantics.
|
3 | 3 | # Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
|
4 | 4 | module Ops
|
5 |
| -using Base: @nospecializeinfer |
6 | 5 | using ..MLIR: MLIR
|
7 | 6 | using ..MLIR.Dialects: stablehlo, chlo, enzyme
|
8 | 7 | using ..Reactant:
|
|
795 | 794 | # return TracedRArray{T,N}((), res, size(lhs))
|
796 | 795 | # end
|
797 | 796 |
|
798 |
| -Base.@nospecializeinfer @noinline function dot_general( |
799 |
| - @nospecialize(lhs::TracedRArray{T1}), |
800 |
| - @nospecialize(rhs::TracedRArray{T2}); |
| 797 | +@noinline function dot_general( |
| 798 | + lhs::TracedRArray{T1}, |
| 799 | + rhs::TracedRArray{T2}; |
801 | 800 | contracting_dimensions,
|
802 | 801 | batching_dimensions=(Int[], Int[]),
|
803 | 802 | precision_config=Reactant.DOT_GENERAL_PRECISION[],
|
@@ -921,6 +920,63 @@ Base.@nospecializeinfer @noinline function dot_general(
|
921 | 920 | return TracedRArray{resT,length(ressize)}((), res, ressize)
|
922 | 921 | end
|
923 | 922 |
|
| 923 | +@noinline function einsum( |
| 924 | + lhs::TracedRArray{T}, |
| 925 | + rhs::TracedRArray{T}; |
| 926 | + equation::String, |
| 927 | + location=mlir_stacktrace("einsum", @__FILE__, @__LINE__), |
| 928 | +) where {T} |
| 929 | + Base.depwarn( |
| 930 | + "`stablehlo.einsum` is on deprecation process; use `dot_general` instead", :einsum |
| 931 | + ) |
| 932 | + ins, ic = split(equation, "->") |
| 933 | + ia, ib = split(ins, ",") |
| 934 | + |
| 935 | + sizea = Dict(c => d for (c, d) in zip(ia, size(lhs))) |
| 936 | + sizeb = Dict(c => d for (c, d) in zip(ib, size(rhs))) |
| 937 | + sizes = mergewith(sizea, sizeb) do da, db |
| 938 | + da == db ? da : error("Invalid dimensions in einsum equation") |
| 939 | + end |
| 940 | + |
| 941 | + rsize = Tuple(sizes[i] for i in ic) |
| 942 | + result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize) |
| 943 | + |
| 944 | + res = MLIR.IR.result( |
| 945 | + stablehlo.einsum( |
| 946 | + lhs.mlir_data, |
| 947 | + rhs.mlir_data; |
| 948 | + result_0, |
| 949 | + einsum_config=MLIR.IR.Attribute(equation), |
| 950 | + location, |
| 951 | + ), |
| 952 | + ) |
| 953 | + return TracedRArray{T,length(rsize)}((), res, rsize) |
| 954 | +end |
| 955 | + |
| 956 | +# function unary_einsum( |
| 957 | +# x::TracedRArray{T}; |
| 958 | +# equation::String, |
| 959 | +# location=mlir_stacktrace( |
| 960 | +# "unary_einsum", @__FILE__, @__LINE__ |
| 961 | +# ), |
| 962 | +# ) where {T} |
| 963 | +# ia, ic = split(equation, "->") |
| 964 | +# sizes = Dict(c => d for (c, d) in zip(ia, size(x))) |
| 965 | +# rsize = Tuple(sizes[i] for i in ic) |
| 966 | +# result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize) |
| 967 | + |
| 968 | +# res = MLIR.IR.result( |
| 969 | +# stablehlo.unary_einsum( |
| 970 | +# x.mlir_data; result_0, einsum_config=MLIR.IR.Attribute(equation), location |
| 971 | +# ), |
| 972 | +# ) |
| 973 | +# if length(rsize) == 0 |
| 974 | +# return TracedRNumber{T}((), res) |
| 975 | +# else |
| 976 | +# return TracedRArray{T,length(rsize)}((), res, rsize) |
| 977 | +# end |
| 978 | +# end |
| 979 | + |
924 | 980 | # parallel ops
|
925 | 981 | @noinline function partition_id(;
|
926 | 982 | location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)
|
|
0 commit comments