Skip to content

Commit 1b35352

Browse files
committed
feat: lower scan_impl! to chlo.scan
1 parent 09d2294 commit 1b35352

File tree

2 files changed

+186
-33
lines changed

2 files changed

+186
-33
lines changed

src/Ops.jl

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3419,6 +3419,182 @@ end
34193419
]
34203420
end
34213421

3422+
"""
    scan(
        inputs::Vector{<:TracedRArray},
        init_values::Vector{<:Union{<:TracedRNumber,<:TracedRArray}},
        fn::Function,
        dimension::Int;
        is_reverse::Bool=false,
        is_associative::Bool=false,
        location=mlir_stacktrace("scan", @__FILE__, @__LINE__)
    )

Applies a scan (prefix reduction) function `fn` along the specified `dimension` of `inputs`,
starting from `init_values`. Returns both the outputs and the final carry values.

# Arguments

- `inputs`: A vector of input arrays to scan over. All must have the same size.
- `init_values`: A vector of initial values (one per input). These define the carry shapes;
  each may be a scalar (`TracedRNumber`) or an array (`TracedRArray`).
- `fn`: A function that takes `(input_slices..., carries...)` where each input_slice has the
  scan dimension removed, and returns `(output_slices..., new_carries...)`.
- `dimension`: The dimension to scan along (1-indexed).

# Keyword Arguments

- `is_reverse`: If `true`, the scan is performed in reverse order. Default is `false`.
- `is_associative`: Indicates whether the reduction function is associative. Default is `false`.

# Returns

A tuple `(outputs, carries)` where:
- `outputs`: A vector of arrays with the scan dimension present, containing the scan results.
- `carries`: A vector of the final carry values (same shape as init_values).

# Notes

The body function receives:
- For each input: a slice with the scan dimension removed
- For each init: the carry value (same shape as init)

The body function must return:
- Output values: these will have the scan dimension added at position `dimension`
- Carry values: same shape as init_values

See: https://www.tensorflow.org/xla/operation_semantics#scan
"""
@noinline function scan(
    inputs::Vector{<:TracedRArray},
    init_values::Vector{<:Union{<:TracedRNumber,<:TracedRArray}},
    fn::F,
    dimension::Int;
    is_reverse::Bool=false,
    is_associative::Bool=false,
    location=mlir_stacktrace("scan", @__FILE__, @__LINE__),
) where {F}
    @assert length(inputs) == length(init_values) "Number of inputs must match number of \
                                                   init_values"
    @assert !isempty(inputs) "At least one input is required"
    @assert allequal(size.(inputs)) "All input arrays must have the same size."
    @assert 1 <= dimension <= ndims(inputs[1]) "Dimension $(dimension) is out of bounds \
                                                for input with $(ndims(inputs[1])) \
                                                dimensions"

    input_shape = size(inputs[1])
    # Shape of input slices (scan dimension removed)
    slice_shape = Tuple(deleteat!(collect(Int64, input_shape), dimension))
    scan_dim_size = input_shape[dimension]

    n_inputs = length(inputs)
    n_carries = length(init_values)

    # Create sample input slices (rank N-1, scan dimension removed) so that tracing
    # `fn` below sees arguments of the shapes the chlo.scan body will receive.
    # The init_values themselves are passed through unchanged as the carry samples.
    sample_input_slices = [
        Reactant.promote_to(
            TracedRArray{unwrapped_eltype(inputs[i]),length(slice_shape)},
            zeros(unwrapped_eltype(inputs[i]), slice_shape...),
        ) for i in 1:n_inputs
    ]

    # Trace `fn` into an MLIR function; its lone region becomes the scan body.
    func =
        Reactant.TracedUtils.make_mlir_fn(
            fn,
            (sample_input_slices..., init_values...),
            (),
            "scan_fn" * string(fn),
            false;
            args_in_result=:result,
            return_dialect=:stablehlo,
        ).f

    @assert MLIR.IR.nregions(func) == 1
    ftype_attr = MLIR.IR.getattr(func, "function_type")
    ftype = MLIR.IR.Type(ftype_attr)

    # The traced function returns outputs first, then carries; anything beyond
    # n_carries results is treated as per-step outputs.
    n_results = MLIR.IR.nresults(ftype)
    @assert n_results >= n_carries "Body function must return at least $(n_carries) \
                                    values (carries)"
    n_outputs = n_results - n_carries

    # Move the traced body region out of the temporary func, then discard the func
    # so no dangling stablehlo function is left in the module.
    body_region = MLIR.IR.Region()
    MLIR.API.mlirRegionTakeBody(body_region, MLIR.IR.region(func, 1))
    MLIR.IR.rmfromparent!(func)

    # Compute output types: body output types with scan dimension inserted
    output_types = MLIR.IR.Type[]
    for i in 1:n_outputs
        body_result_type = MLIR.IR.result(ftype, i)
        body_result_shape = collect(Int64, size(body_result_type))
        # Insert scan dimension at the correct position
        output_shape = insert!(body_result_shape, dimension, scan_dim_size)
        push!(output_types, MLIR.IR.TensorType(output_shape, eltype(body_result_type)))
    end

    # Compute carry types: same as body carry return types
    carry_types = MLIR.IR.Type[]
    for i in 1:n_carries
        push!(carry_types, MLIR.IR.result(ftype, n_outputs + i))
    end

    op = chlo.scan(
        [x.mlir_data for x in inputs],
        [init_value.mlir_data for init_value in init_values];
        outputs=output_types,
        carries=carry_types,
        dimension=dimension - 1, # Convert to 0-indexed
        is_reverse,
        is_associative,
        body=body_region,
        location,
    )

    # Extract outputs (with scan dimension); op results [1, n_outputs] are outputs.
    outputs = [
        TracedRArray{
            MLIR.IR.julia_type(eltype(output_types[i])),length(size(output_types[i]))
        }(
            (), MLIR.IR.result(op, i), Tuple(size(output_types[i]))
        ) for i in 1:n_outputs
    ]

    # Extract carries (op results after the outputs). Rank-0 carries are wrapped as
    # TracedRNumber, everything else as TracedRArray.
    carries = [
        if length(size(carry_types[i])) == 0
            TracedRNumber{MLIR.IR.julia_type(eltype(carry_types[i]))}(
                (), MLIR.IR.result(op, n_outputs + i)
            )
        else
            TracedRArray{
                MLIR.IR.julia_type(eltype(carry_types[i])),length(size(carry_types[i]))
            }(
                (), MLIR.IR.result(op, n_outputs + i), Tuple(size(carry_types[i]))
            )
        end for i in 1:n_carries
    ]

    return outputs, carries
end
3582+
3583+
"""
    scan(x::TracedRArray{T}, init_value, fn, dimension; kwargs...)

Single-input convenience wrapper around the vector form of [`scan`](@ref): wraps `x`
and `init_value` in one-element vectors, forwards all keyword arguments, and unwraps
the single output/carry pair from the returned vectors.
"""
@noinline function scan(
    x::TracedRArray{T},
    init_value::Union{TracedRNumber{T},TracedRArray{T}},
    fn::F,
    dimension::Int;
    is_reverse::Bool=false,
    is_associative::Bool=false,
    location=mlir_stacktrace("scan", @__FILE__, @__LINE__),
) where {T,F}
    # Delegate to the multi-input method with singleton vectors.
    outs, finals = scan(
        [x], [init_value], fn, dimension; is_reverse, is_associative, location
    )
    # `only` enforces the expected single output / single carry.
    return only(outs), only(finals)
end
3597+
34223598
function standardize_start_index(
34233599
sz::Int,
34243600
update_sz::Union{Int,Nothing},

src/TracedRArray.jl

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,44 +1015,21 @@ function scan_impl!(
10151015
op_in_T = riT
10161016
input = riT.(input)
10171017
end
1018-
else
1019-
# Note: fix this for TPUs
1020-
if contains(string(first(Reactant.devices())), "TPU")
1021-
initT = __default_init(T, op)
1022-
if initT != init && initT != something(init)
1023-
throw(
1024-
AssertionError(
1025-
"Currently, `init` is not supported on TPUs, provided value $init does not match identity $initT.",
1026-
),
1027-
)
1028-
end
1029-
end
10301018
end
10311019

1020+
input = materialize_traced_array(input)
1021+
10321022
init = something(init) # unwrap Some
10331023
init = Reactant.promote_to(TracedRNumber{unwrapped_eltype(init)}, init)
1024+
init = Reactant.broadcast_to_size(init, deleteat!(collect(Int64, size(input)), dims))
10341025

1035-
window_dimensions = ones(Int64, N)
1036-
window_dimensions[dims] = size(input, dims)
1037-
1038-
padding_low = zeros(Int64, N)
1039-
padding_low[dims] = size(input, dims) - 1
1040-
1041-
reduction_result = @opcall(
1042-
reduce_window(
1043-
op,
1044-
[materialize_traced_array(input)],
1045-
[init];
1046-
window_dimensions=window_dimensions,
1047-
window_strides=ones(Int64, N),
1048-
base_dilations=ones(Int64, N),
1049-
window_dilations=ones(Int64, N),
1050-
padding_low=padding_low,
1051-
padding_high=zeros(Int64, N),
1052-
output_shape=collect(Int64, size(output)),
1053-
)
1054-
)[1]
1055-
copyto!(output, reduction_result)
1026+
function scan_fn(a, c)
1027+
res = a .+ c
1028+
return res, copy(res) # force 2 returns
1029+
end
1030+
1031+
scan_result, _ = @opcall scan(input, init, scan_fn, dims)
1032+
copyto!(output, scan_result)
10561033

10571034
return output
10581035
end

0 commit comments

Comments (0)