@@ -3419,6 +3419,182 @@ end
34193419 ]
34203420end
34213421
3422+ """
3423+ scan(
3424+ inputs::Vector{<:TracedRArray},
3425+ init_values::Vector{<:TracedRNumber},
3426+ fn::Function,
3427+ dimension::Int;
3428+ is_reverse::Bool=false,
3429+ is_associative::Bool=false,
3430+ location=mlir_stacktrace("scan", @__FILE__, @__LINE__)
3431+ )
3432+
3433+ Applies a scan (prefix reduction) function `fn` along the specified `dimension` of `inputs`,
3434+ starting from `init_values`. Returns both the outputs and the final carry values.
3435+
3436+ # Arguments
3437+
3438+ - `inputs`: A vector of input arrays to scan over.
3439+ - `init_values`: A vector of initial values (one per input). These define the carry shape.
3440+ - `fn`: A function that takes `(input_slices..., carries...)` where each input_slice has the
3441+ scan dimension removed, and returns `(output_slices..., new_carries...)`.
3442+ - `dimension`: The dimension to scan along (1-indexed).
3443+
3444+ # Keyword Arguments
3445+
3446+ - `is_reverse`: If `true`, the scan is performed in reverse order. Default is `false`.
3447+ - `is_associative`: Indicates whether the reduction function is associative. Default is `false`.
3448+
3449+ # Returns
3450+
3451+ A tuple `(outputs, carries)` where:
3452+ - `outputs`: A vector of arrays with the scan dimension present, containing the scan results.
3453+ - `carries`: A vector of the final carry values (same shape as init_values).
3454+
3455+ # Notes
3456+
3457+ The body function receives:
3458+ - For each input: a slice with the scan dimension removed
3459+ - For each init: the carry value (same shape as init)
3460+
3461+ The body function must return:
3462+ - Output values: these will have the scan dimension added at position `dimension`
3463+ - Carry values: same shape as init_values
3464+
3465+ See: https://www.tensorflow.org/xla/operation_semantics#scan
3466+ """
3467+ @noinline function scan (
3468+ inputs:: Vector{<:TracedRArray} ,
3469+ init_values:: Vector{<:Union{<:TracedRNumber,<:TracedRArray}} ,
3470+ fn:: F ,
3471+ dimension:: Int ;
3472+ is_reverse:: Bool = false ,
3473+ is_associative:: Bool = false ,
3474+ location= mlir_stacktrace (" scan" , @__FILE__ , @__LINE__ ),
3475+ ) where {F}
3476+ @assert length (inputs) == length (init_values) " Number of inputs must match number of \
3477+ init_values"
3478+ @assert ! isempty (inputs) " At least one input is required"
3479+ @assert allequal (size .(inputs)) " All input arrays must have the same size."
3480+ @assert 1 <= dimension <= ndims (inputs[1 ]) " Dimension $(dimension) is out of bounds \
3481+ for input with $(ndims (inputs[1 ])) \
3482+ dimensions"
3483+
3484+ input_shape = size (inputs[1 ])
3485+ # Shape of input slices (scan dimension removed)
3486+ slice_shape = Tuple (deleteat! (collect (Int64, input_shape), dimension))
3487+ scan_dim_size = input_shape[dimension]
3488+
3489+ n_inputs = length (inputs)
3490+ n_carries = length (init_values)
3491+
3492+ # Create sample inputs for constructing the body function:
3493+ # - For inputs: slices with scan dimension removed (rank N-1)
3494+ # - For inits: scalar TracedRNumbers (rank 0)
3495+ sample_input_slices = [
3496+ Reactant. promote_to (
3497+ TracedRArray{unwrapped_eltype (inputs[i]),length (slice_shape)},
3498+ zeros (unwrapped_eltype (inputs[i]), slice_shape... ),
3499+ ) for i in 1 : n_inputs
3500+ ]
3501+
3502+ # Construct the body region
3503+ func =
3504+ Reactant. TracedUtils. make_mlir_fn (
3505+ fn,
3506+ (sample_input_slices... , init_values... ),
3507+ (),
3508+ " scan_fn" * string (fn),
3509+ false ;
3510+ args_in_result= :result ,
3511+ return_dialect= :stablehlo ,
3512+ ). f
3513+
3514+ @assert MLIR. IR. nregions (func) == 1
3515+ ftype_attr = MLIR. IR. getattr (func, " function_type" )
3516+ ftype = MLIR. IR. Type (ftype_attr)
3517+
3518+ n_results = MLIR. IR. nresults (ftype)
3519+ @assert n_results >= n_carries " Body function must return at least $(n_carries) \
3520+ values (carries)"
3521+ n_outputs = n_results - n_carries
3522+
3523+ # Extract the body region
3524+ body_region = MLIR. IR. Region ()
3525+ MLIR. API. mlirRegionTakeBody (body_region, MLIR. IR. region (func, 1 ))
3526+ MLIR. IR. rmfromparent! (func)
3527+
3528+ # Compute output types: body output types with scan dimension inserted
3529+ output_types = MLIR. IR. Type[]
3530+ for i in 1 : n_outputs
3531+ body_result_type = MLIR. IR. result (ftype, i)
3532+ body_result_shape = collect (Int64, size (body_result_type))
3533+ # Insert scan dimension at the correct position
3534+ output_shape = insert! (body_result_shape, dimension, scan_dim_size)
3535+ push! (output_types, MLIR. IR. TensorType (output_shape, eltype (body_result_type)))
3536+ end
3537+
3538+ # Compute carry types: same as body carry return types
3539+ carry_types = MLIR. IR. Type[]
3540+ for i in 1 : n_carries
3541+ push! (carry_types, MLIR. IR. result (ftype, n_outputs + i))
3542+ end
3543+
3544+ op = chlo. scan (
3545+ [x. mlir_data for x in inputs],
3546+ [init_value. mlir_data for init_value in init_values];
3547+ outputs= output_types,
3548+ carries= carry_types,
3549+ dimension= dimension - 1 , # Convert to 0-indexed
3550+ is_reverse,
3551+ is_associative,
3552+ body= body_region,
3553+ location,
3554+ )
3555+
3556+ # Extract outputs (with scan dimension)
3557+ outputs = [
3558+ TracedRArray{
3559+ MLIR. IR. julia_type (eltype (output_types[i])),length (size (output_types[i]))
3560+ }(
3561+ (), MLIR. IR. result (op, i), Tuple (size (output_types[i]))
3562+ ) for i in 1 : n_outputs
3563+ ]
3564+
3565+ # Extract carries
3566+ carries = [
3567+ if length (size (carry_types[i])) == 0
3568+ TracedRNumber {MLIR.IR.julia_type(eltype(carry_types[i]))} (
3569+ (), MLIR. IR. result (op, n_outputs + i)
3570+ )
3571+ else
3572+ TracedRArray{
3573+ MLIR. IR. julia_type (eltype (carry_types[i])),length (size (carry_types[i]))
3574+ }(
3575+ (), MLIR. IR. result (op, n_outputs + i), Tuple (size (carry_types[i]))
3576+ )
3577+ end for i in 1 : n_carries
3578+ ]
3579+
3580+ return outputs, carries
3581+ end
3582+
3583+ @noinline function scan (
3584+ x:: TracedRArray{T} ,
3585+ init_value:: Union{TracedRNumber{T},TracedRArray{T}} ,
3586+ fn:: F ,
3587+ dimension:: Int ;
3588+ is_reverse:: Bool = false ,
3589+ is_associative:: Bool = false ,
3590+ location= mlir_stacktrace (" scan" , @__FILE__ , @__LINE__ ),
3591+ ) where {T,F}
3592+ outputs, carries = scan (
3593+ [x], [init_value], fn, dimension; is_reverse, is_associative, location
3594+ )
3595+ return only (outputs), only (carries)
3596+ end
3597+
34223598function standardize_start_index (
34233599 sz:: Int ,
34243600 update_sz:: Union{Int,Nothing} ,
0 commit comments