Skip to content

Commit 8032427

Browse files
authored
feat: overloads for accumulate-style function (#1322)
* feat: generalize the reduce_window op * feat: overload accumulate functions * test: accumulate
1 parent 592e74e commit 8032427

File tree

5 files changed

+307
-99
lines changed

5 files changed

+307
-99
lines changed

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ function overloaded_maxpool!(
337337
y::AnyTracedRArray{T,N}, x::AnyTracedRArray{T2,N}, pdims::NNlib.PoolDims;
338338
) where {T,T2,N}
339339
res = reduce_window(
340-
Reactant.MLIR.Dialects.stablehlo.maximum,
340+
max,
341341
T.(x);
342342
init=typemin(T),
343343
dilation=NNlib.dilation(pdims),
@@ -353,7 +353,7 @@ function overloaded_meanpool!(
353353
y::AnyTracedRArray{T,N}, x::AnyTracedRArray{T2,N}, pdims::NNlib.PoolDims;
354354
) where {T,T2,N}
355355
res = reduce_window(
356-
Reactant.MLIR.Dialects.stablehlo.add,
356+
+,
357357
T.(x);
358358
init=zero(T),
359359
dilation=NNlib.dilation(pdims),

ext/ReactantNNlibExt/Ops.jl

Lines changed: 16 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,27 @@
11
function reduce_window(
22
f, x::AnyTracedRArray{T,N}; init, dilation, kernel_size, padding, stride
33
) where {T,N}
4-
x = materialize_traced_array(x)
5-
6-
num_spatial_dims = N - 2
7-
input_spatial_dims = 1:num_spatial_dims
8-
9-
window_dimensions = [kernel_size..., 1, 1]
10-
window_strides = [stride..., 1, 1]
11-
window_dilations = [dilation..., 1, 1]
12-
13-
output_spatial_shapes = map(input_spatial_dims) do i
4+
output_spatial_shapes = map(1:(N - 2)) do i
145
K = kernel_size[i]
156
pl, pr = padding[2i - 1], padding[2i]
167
d = dilation[i]
178
s = stride[i]
189

19-
(size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
10+
return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
2011
end
2112

22-
padding = Reactant.MLIR.IR.DenseElementsAttribute(
23-
reshape([padding..., 0, 0, 0, 0], (2, N))'
24-
)
25-
26-
output_shape = Int[output_spatial_shapes..., size(x, N - 1), size(x, N)]
27-
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
28-
29-
unranked = Reactant.MLIR.IR.TensorType(
30-
Int[], eltype(Reactant.MLIR.IR.type(get_mlir_data(x)))
31-
)
32-
body =
33-
let body = Reactant.MLIR.IR.Region(),
34-
loc = Reactant.MLIR.IR.Location(),
35-
block = Reactant.MLIR.IR.Block([unranked, unranked], [loc, loc])
36-
37-
Reactant.MLIR.IR.block!(block) do
38-
red = f(
39-
Reactant.MLIR.IR.argument(block, 1),
40-
Reactant.MLIR.IR.argument(block, 2);
41-
result=nothing,
42-
)
43-
Reactant.MLIR.Dialects.stablehlo.return_([Reactant.MLIR.IR.result(red)])
44-
end
45-
push!(body, block)
46-
47-
body
48-
end
49-
50-
attr = fill(Reactant.MLIR.IR.Attribute(init), unranked)
51-
init_value = Reactant.MLIR.IR.result(
52-
Reactant.MLIR.Dialects.stablehlo.constant(; value=attr)
53-
)
54-
reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window(
55-
[get_mlir_data(x)],
56-
[init_value];
57-
result_0=[result_type],
58-
window_dimensions,
59-
window_strides,
60-
window_dilations,
61-
padding,
62-
body,
63-
)
64-
65-
return TracedRArray{T,N}((), Reactant.MLIR.IR.result(reduction), size(result_type))
13+
padding = collect(Int64, reshape([padding..., 0, 0, 0, 0], (2, N))')
14+
15+
return Ops.reduce_window(
16+
f,
17+
[materialize_traced_array(x)],
18+
[Ops.constant(T(init))];
19+
window_dimensions=[kernel_size..., 1, 1],
20+
window_strides=[stride..., 1, 1],
21+
window_dilations=[dilation..., 1, 1],
22+
padding_low=padding[:, 1],
23+
padding_high=padding[:, 2],
24+
output_shape=Int[output_spatial_shapes..., size(x, N - 1), size(x, N)],
25+
base_dilations=ones(Int, N),
26+
)[1]
6627
end

src/Ops.jl

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2647,6 +2647,32 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
26472647
end
26482648
end
26492649

2650+
function _construct_reduce_function(f::F, ::Type{T}) where {F,T}
2651+
func =
2652+
Reactant.TracedUtils.make_mlir_fn(
2653+
f,
2654+
(
2655+
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2656+
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2657+
),
2658+
(),
2659+
"reduce_fn" * string(f),
2660+
false;
2661+
args_in_result=:none,
2662+
return_dialect=:stablehlo,
2663+
).f
2664+
@assert MLIR.IR.nregions(func) == 1
2665+
ftype_attr = MLIR.IR.attr(func, "function_type")
2666+
ftype = MLIR.IR.Type(ftype_attr)
2667+
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType(Int[], MLIR.IR.Type(T)) "$(fn) return type is not of tensor<$(T)>"
2668+
2669+
fn = MLIR.IR.Region()
2670+
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
2671+
MLIR.IR.rmfromparent!(func)
2672+
2673+
return fn
2674+
end
2675+
26502676
"""
26512677
reduce(
26522678
x::TracedRArray{T},
@@ -2698,45 +2724,13 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
26982724
) where {T}
26992725
reduced_shape = Tuple(deleteat!(collect(Int64, size(x)), dimensions))
27002726

2701-
result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
2702-
2703-
sample_inputs = [
2704-
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2705-
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2706-
]
2707-
2708-
func =
2709-
Reactant.TracedUtils.make_mlir_fn(
2710-
fn,
2711-
(sample_inputs),
2712-
(),
2713-
"reduce_fn",
2714-
false;
2715-
args_in_result=:none,
2716-
return_dialect=:stablehlo,
2717-
).f
2718-
@assert MLIR.IR.nregions(func) == 1
2719-
fn_name = String(
2720-
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
2721-
)
2722-
ftype_attr = MLIR.IR.attr(func, "function_type")
2723-
ftype = MLIR.IR.Type(ftype_attr)
2724-
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType(Int[], MLIR.IR.Type(T)) error (
2725-
"$fn return type is not tensor<i1>"
2726-
)
2727-
fn = MLIR.IR.Region()
2728-
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
2729-
MLIR.IR.rmfromparent!(func)
2730-
2731-
dimensions = MLIR.IR.Attribute(dimensions .- 1)
2732-
27332727
res = MLIR.IR.result(
27342728
stablehlo.reduce(
27352729
[x.mlir_data],
27362730
[init_values.mlir_data];
2737-
result_0=[result_type],
2738-
dimensions=dimensions,
2739-
body=fn,
2731+
result_0=[mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)],
2732+
dimensions=MLIR.IR.Attribute(dimensions .- 1),
2733+
body=_construct_reduce_function(fn, T),
27402734
location=location,
27412735
),
27422736
)
@@ -2986,4 +2980,49 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
29862980
return (res, ipiv, perm, info)
29872981
end
29882982

2983+
@noinline function reduce_window(
2984+
f::F,
2985+
inputs::Vector{TracedRArray{T,N}},
2986+
init_values::Vector{TracedRNumber{T}};
2987+
window_dimensions::Vector{Int},
2988+
window_strides::Vector{Int},
2989+
base_dilations::Vector{Int},
2990+
window_dilations::Vector{Int},
2991+
padding_low::Vector{Int},
2992+
padding_high::Vector{Int},
2993+
output_shape::Vector{Int},
2994+
location=mlir_stacktrace("reduce_window", @__FILE__, @__LINE__),
2995+
) where {F,T,N}
2996+
@assert length(inputs) == length(init_values)
2997+
@assert length(window_dimensions) ==
2998+
length(window_strides) ==
2999+
length(base_dilations) ==
3000+
length(window_dilations) ==
3001+
length(padding_low) ==
3002+
length(padding_high) ==
3003+
N
3004+
3005+
reduction = stablehlo.reduce_window(
3006+
[inp.mlir_data for inp in inputs],
3007+
[init.mlir_data for init in init_values];
3008+
result_0=[
3009+
mlir_type(TracedRArray{T,length(output_shape)}, output_shape) for
3010+
_ in 1:length(inputs)
3011+
],
3012+
window_dimensions,
3013+
window_strides,
3014+
base_dilations,
3015+
window_dilations,
3016+
padding=MLIR.IR.DenseElementsAttribute(hcat(padding_low, padding_high)),
3017+
body=_construct_reduce_function(f, T),
3018+
location,
3019+
)
3020+
3021+
return [
3022+
TracedRArray{T,length(output_shape)}(
3023+
(), MLIR.IR.result(reduction, i), output_shape
3024+
) for i in 1:length(inputs)
3025+
]
3026+
end
3027+
29893028
end # module Ops

src/TracedRArray.jl

Lines changed: 112 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,12 @@ for (jlop, hloop, hlocomp, merge) in
531531
end
532532
end
533533

534+
__default_init(::Type{T}, ::typeof(Base.min)) where {T} = typemax(T)
535+
__default_init(::Type{T}, ::typeof(Base.max)) where {T} = typemin(T)
536+
function __default_init(::Type{T}, op::F) where {T,F}
537+
return Base.reduce_empty(Base.BottomRF(op), T)
538+
end
539+
534540
function overloaded_mapreduce(
535541
@nospecialize(f),
536542
@nospecialize(op),
@@ -547,13 +553,7 @@ function overloaded_mapreduce(
547553
op_in_T = Core.Compiler.return_type(f, Tuple{T})
548554

549555
if init === nothing
550-
if op === min
551-
init = typemax(op_in_T)
552-
elseif op === max
553-
init = typemin(op_in_T)
554-
else
555-
init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
556-
end
556+
init = __default_init(op_in_T, op)
557557

558558
if typeof(init) != op_in_T
559559
op_in_T = typeof(init)
@@ -1241,4 +1241,109 @@ function Base.mapslices(f::F, A::TracedRArray; dims) where {F}
12411241
return Ops.batch(f, A, dims)
12421242
end
12431243

1244+
# accumulate interface
1245+
## Taken from https://github.com/JuliaGPU/CUDA.jl/blob/a4a7af45f54f0e57f5912bb52db48e2d27cf7b4f/src/accumulate.jl#L201
1246+
function Base.accumulate(
1247+
op, A::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing, kwargs...
1248+
)
1249+
if dims === nothing && ndims(A) != 1
1250+
return reshape(accumulate(op, A[:]), size(A)...)
1251+
end
1252+
1253+
nt = values(kwargs)
1254+
# Base.promote_op was having issues
1255+
if isempty(kwargs)
1256+
zA = zero(unwrapped_eltype(A))
1257+
out = similar(A, TracedRNumber{unwrapped_eltype(op(zA, zA))})
1258+
elseif keys(nt) === (:init,)
1259+
zA = zero(unwrapped_eltype(A))
1260+
zI = zero(unwrapped_eltype(nt.init))
1261+
out = similar(A, TracedRNumber{unwrapped_eltype(op(zA, zI))})
1262+
else
1263+
throw(
1264+
ArgumentError(
1265+
"accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))",
1266+
),
1267+
)
1268+
end
1269+
1270+
return accumulate!(op, out, A; dims, kwargs...)
1271+
end
1272+
1273+
function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
1274+
return accumulate!(op, A, B; dims=1)
1275+
end
1276+
1277+
function Base._accumulate!(
1278+
op, output::AnyTracedRArray, input::AnyTracedRVector, ::Nothing, ::Nothing
1279+
)
1280+
return scan_impl!(op, output, input; dims=1)
1281+
end
1282+
1283+
function Base._accumulate!(
1284+
op, output::AnyTracedRArray, input::AnyTracedRArray, dims::Integer, ::Nothing
1285+
)
1286+
return scan_impl!(op, output, input; dims=dims)
1287+
end
1288+
1289+
function Base._accumulate!(
1290+
op, output::AnyTracedRArray, input::AnyTracedRVector, ::Nothing, init::Some
1291+
)
1292+
return scan_impl!(op, output, input; dims=1, init=init)
1293+
end
1294+
1295+
function Base._accumulate!(
1296+
op, output::AnyTracedRArray, input::AnyTracedRArray, dims::Integer, init::Some
1297+
)
1298+
return scan_impl!(op, output, input; dims=dims, init=init)
1299+
end
1300+
1301+
function scan_impl!(
1302+
op,
1303+
output::AnyTracedRArray{T,N},
1304+
input::AnyTracedRArray{T,N};
1305+
dims::Integer,
1306+
init=nothing,
1307+
) where {T,N}
1308+
@assert dims > 0 "dims must be a positive integer"
1309+
@assert axes(output) == axes(input) "output and input must have the same shape"
1310+
1311+
dims > ndims(input) && return copyto!(output, input)
1312+
1313+
if init === nothing
1314+
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
1315+
op_in_T === Union{} && (op_in_T = T)
1316+
1317+
init = __default_init(T, op)
1318+
if typeof(init) != op_in_T
1319+
op_in_T = typeof(init)
1320+
input = typeof(init).(input)
1321+
end
1322+
end
1323+
init = something(init) # unwrap Some
1324+
init = TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(init)}, init)
1325+
1326+
window_dimensions = ones(Int64, N)
1327+
window_dimensions[dims] = size(input, dims)
1328+
1329+
padding_low = zeros(Int64, N)
1330+
padding_low[dims] = size(input, dims) - 1
1331+
1332+
reduction_result = Ops.reduce_window(
1333+
op,
1334+
[materialize_traced_array(input)],
1335+
[init];
1336+
window_dimensions=window_dimensions,
1337+
window_strides=ones(Int64, N),
1338+
base_dilations=ones(Int64, N),
1339+
window_dilations=ones(Int64, N),
1340+
padding_low=padding_low,
1341+
padding_high=zeros(Int64, N),
1342+
output_shape=collect(Int64, size(output)),
1343+
)[1]
1344+
copyto!(output, reduction_result)
1345+
1346+
return output
1347+
end
1348+
12441349
end

0 commit comments

Comments
 (0)