Skip to content

Commit bcb6034

Browse files
authored
refactor: use Ops instead of direct stablehlo calls (EnzymeAD#347)
* refactor: use Ops instead of direct stablehlo calls * revert: restore Base.conj * fix: minor fixes to Ops * feat: add convert dispatches * revert: keep original `transpose_val` impl * revert: keep control flow needs Operation
1 parent c4a9ae3 commit bcb6034

File tree

5 files changed

+124
-285
lines changed

5 files changed

+124
-285
lines changed

ext/ReactantAbstractFFTsExt.jl

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ReactantAbstractFFTsExt
22

33
using AbstractFFTs: AbstractFFTs
4-
using Reactant: Reactant, MLIR, TracedRArray
4+
using Reactant: Reactant, MLIR, Ops, TracedRArray
55

66
function check_contiguous_innermost_dims(dims, N)
77
@assert sort([dims...]) == [dims...] "un-sorted dims are not supported"
@@ -32,6 +32,7 @@ function compute_correct_pdims(x::AbstractArray, dims)
3232
end
3333

3434
for op in (:rfft, :fft, :ifft)
35+
mode = uppercase(string(op))
3536
@eval function AbstractFFTs.$(op)(x::TracedRArray, dims)
3637
@assert maximum(dims) ≤ ndims(x) "dims out of range"
3738
if dims isa Integer
@@ -41,19 +42,20 @@ for op in (:rfft, :fft, :ifft)
4142
AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims)
4243
)
4344
end
44-
return generalized_fft(x, $(Meta.quot(op)), nothing, 1)
45+
return generalized_fft(x, $(mode), nothing, length(dims))
4546
end
4647
if !check_contiguous_innermost_dims(dims, ndims(x))
4748
pdims = compute_correct_pdims(x, dims)
4849
return permutedims(
4950
AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims)
5051
)
5152
end
52-
return generalized_fft(x, $(Meta.quot(op)), nothing, length(dims))
53+
return generalized_fft(x, $(mode), nothing, length(dims))
5354
end
5455
end
5556

5657
for op in (:irfft,)
58+
mode = uppercase(string(op))
5759
@eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims)
5860
@assert maximum(dims) ≤ ndims(x) "dims out of range"
5961
if dims isa Integer
@@ -63,49 +65,30 @@ for op in (:irfft,)
6365
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims)
6466
)
6567
end
66-
return generalized_fft(x, $(Meta.quot(op)), d, 1)
68+
return generalized_fft(x, $(mode), d, length(dims))
6769
end
6870
if !check_contiguous_innermost_dims(dims, ndims(x))
6971
pdims = compute_correct_pdims(x, dims)
7072
return permutedims(
7173
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims)
7274
)
7375
end
74-
return generalized_fft(x, $(Meta.quot(op)), d, length(dims))
76+
return generalized_fft(x, $(mode), d, length(dims))
7577
end
7678
end
7779

78-
function generalized_fft(x::TracedRArray{T,N}, mode::Symbol, d, first_n::Int) where {T,N}
79-
@assert mode ∈ (:rfft, :irfft, :fft, :ifft)
80-
81-
x = permutedims(x, reverse(1:N))
82-
fft_type_str = uppercase(string(mode))
83-
fft_type = MLIR.API.stablehloFftTypeAttrGet(MLIR.IR.context(), fft_type_str)
84-
80+
function generalized_fft(x::TracedRArray{T,N}, mode::String, d, first_n::Int) where {T,N}
8581
if d === nothing
86-
@assert mode ∈ (:rfft, :fft, :ifft)
87-
if mode == :rfft
88-
@assert T <: Real
89-
rT = Complex{T}
90-
res_size = [size(x)[1:(end - 1)]..., size(x, N) ÷ 2 + 1]
91-
else
92-
@assert T <: Complex
93-
rT = T
94-
res_size = [size(x)...]
95-
end
96-
fft_length = [size(x, i) for i in (ndims(x) - first_n + 1):ndims(x)]
82+
@assert mode ∈ ("RFFT", "FFT", "IFFT")
83+
fft_length = [size(x, i) for i in 1:first_n]
9784
else
98-
@assert mode == :irfft
99-
@assert T <: Complex
100-
rT = real(T)
101-
res_size = [size(x)[1:(end - 1)]..., d]
102-
fft_length = [res_size[i] for i in (ndims(x) - first_n + 1):ndims(x)]
85+
@assert mode == "IRFFT"
86+
fft_length = [i == 1 ? d : size(x, i) for i in 1:first_n]
10387
end
10488

105-
@assert 1 ≤ length(fft_length) ≤ 3 "stablehlo.fft only supports up to rank 3"
106-
mlir_type = MLIR.IR.TensorType(res_size, Reactant.MLIR.IR.Type(rT))
107-
op = MLIR.Dialects.stablehlo.fft(x.mlir_data; fft_type, fft_length, result_0=mlir_type)
108-
x = TracedRArray{rT,N}((), MLIR.IR.result(op, 1), Tuple(res_size))
89+
x = permutedims(x, reverse(1:N))
90+
reverse!(fft_length)
91+
x = Ops.fft(x; type=mode, length=fft_length)
10992
return permutedims(x, reverse(1:N))
11093
end
11194

src/Interpreter.jl

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -233,17 +233,7 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
233233
predims = size(x.val)
234234
cval = MLIR.IR.result(
235235
MLIR.Dialects.stablehlo.concatenate(
236-
[
237-
MLIR.IR.result(
238-
MLIR.Dialects.stablehlo.reshape(
239-
v.mlir_data;
240-
result_0=MLIR.IR.TensorType(
241-
Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
242-
),
243-
),
244-
) for v in x.dval
245-
];
246-
dimension=Int64(0),
236+
[Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
247237
),
248238
)
249239
tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
@@ -258,17 +248,7 @@ function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse)
258248
predims = size(x.val)
259249
cval = MLIR.IR.result(
260250
MLIR.Dialects.stablehlo.concatenate(
261-
[
262-
MLIR.IR.result(
263-
MLIR.Dialects.stablehlo.reshape(
264-
v.mlir_data;
265-
result_0=MLIR.IR.TensorType(
266-
Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
267-
),
268-
),
269-
) for v in x.dval
270-
];
271-
dimension=Int64(0),
251+
[Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
272252
),
273253
)
274254
tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
@@ -502,22 +482,12 @@ function overload_autodiff(
502482
for i in 1:width
503483
sz = size(a)
504484
starts = Int64[i]
505-
strides = Int64[1]
506485
limits = Int64[i]
507486
for v in sz
508487
push!(starts, 0)
509488
push!(limits, v)
510-
push!(strides, 1)
511489
end
512-
sval = MLIR.IR.result(
513-
MLIR.Dialects.stablehlo.slice(
514-
sval;
515-
start_indices=starts,
516-
limit_indices=limits,
517-
stride_indices=strides,
518-
),
519-
1,
520-
)
490+
sval = Ops.slice(sval, starts, limits)
521491
set!(dresult[i], path[2:end], sval)
522492
end
523493
end

src/Ops.jl

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function constant(
2828
end
2929

3030
function constant(x::ConcreteRArray; kwargs...)
31-
return stablehlo.constant(convert(Array, x); kwargs...)
31+
return stablehlo.constant(Base.convert(Array, x); kwargs...)
3232
end
3333

3434
function constant(
@@ -42,7 +42,9 @@ function constant(
4242
x::ConcreteRNumber{T}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
4343
) where {T}
4444
output = mlir_type(TracedRArray{T,0}, ())
45-
value = MLIR.IR.DenseElementsAttribute(fill(MLIR.IR.Attribute(convert(T, x)), output))
45+
value = MLIR.IR.DenseElementsAttribute(
46+
fill(MLIR.IR.Attribute(Base.convert(T, x)), output)
47+
)
4648
res = MLIR.IR.result(stablehlo.constant(; output, value, location))
4749
return TracedRNumber{T,N}((), res)
4850
end
@@ -458,10 +460,11 @@ function fft(
458460
Tout = Complex{T}
459461
rsize = let rsize = collect(size(x))
460462
rsize[end] = rsize[end] == 0 ? 0 : rsize[end] ÷ 2 + 1
463+
Tuple(rsize)
461464
end
462465
elseif type == "IRFFT"
463466
@assert T <: Complex
464-
Tout = real(T)
467+
Tout = Base.real(T)
465468
rsize = let rsize = collect(size(x))
466469
rsize[(end - Base.length(length) + 1):end] = length
467470
Tuple(rsize)
@@ -514,7 +517,25 @@ function clamp(
514517
return TracedRArray{T,N}((), res, size(x))
515518
end
516519

517-
function clamp(min::T, x::TracedRArray{T,N}, max::T) where {T,N}
520+
function clamp(
521+
min::TracedRNumber{T},
522+
x::TracedRNumber{T},
523+
max::TracedRNumber{T};
524+
location=mlir_stacktrace("clamp", @__FILE__, @__LINE__),
525+
) where {T}
526+
res = MLIR.IR.result(
527+
stablehlo.clamp(
528+
min.mlir_data,
529+
x.mlir_data,
530+
max.mlir_data;
531+
result=mlir_type(TracedRArray{T,0}, ()),
532+
location,
533+
),
534+
)
535+
return TracedRNumber{T}((), res)
536+
end
537+
538+
function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N}
518539
return clamp(constant(min), x, constant(max))
519540
end
520541

@@ -1033,7 +1054,7 @@ function compare(
10331054
end
10341055

10351056
res = MLIR.IR.result(
1036-
MLIR.Dialects.stablehlo.compare(
1057+
stablehlo.compare(
10371058
lhs.mlir_data,
10381059
rhs.mlir_data;
10391060
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
@@ -1048,6 +1069,37 @@ function compare(
10481069
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
10491070
end
10501071

1072+
# eltype conversion
1073+
function convert(
1074+
::Type{TracedRArray{T,N}},
1075+
x::TracedRArray;
1076+
location=mlir_stacktrace("convert", @__FILE__, @__LINE__),
1077+
) where {T,N}
1078+
@assert N == ndims(x)
1079+
return TracedRArray{T,N}(
1080+
(),
1081+
MLIR.IR.result(
1082+
stablehlo.convert(
1083+
x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location
1084+
),
1085+
),
1086+
size(x),
1087+
)
1088+
end
1089+
1090+
function convert(
1091+
::Type{TracedRNumber{T}},
1092+
x::TracedRNumber;
1093+
location=mlir_stacktrace("convert", @__FILE__, @__LINE__),
1094+
) where {T}
1095+
return TracedRNumber{T}(
1096+
(),
1097+
MLIR.IR.result(
1098+
stablehlo.convert(x.mlir_data; result=mlir_type(TracedRNumber{T}), location)
1099+
),
1100+
)
1101+
end
1102+
10511103
# Generate a unique name given a module hash and a function name.
10521104
function _hlo_call_name(orig_name, module_suffix)
10531105
return orig_name * "_hlo_call_" * module_suffix

0 commit comments

Comments
 (0)