Description
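When using Reactant's `@compile` on a function that calls Enzyme's `autodiff`, compilation fails if the arguments of the differentiated function have different sizes. This happens even when the mismatched argument is marked `Const` and is never used.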
Environment
- Julia version info:
  ```
  Julia Version 1.11.7
  Commit f2b3dbda30a (2025-09-08 12:10 UTC)
  Build Info:
    Official https://julialang.org/ release
  Platform Info:
    OS: Linux (x86_64-linux-gnu)
    CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900K
    WORD_SIZE: 64
    LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
    Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
  ```
- Reactant version: [3c362404] Reactant v0.2.169
- Enzyme version: [7da242da] Enzyme v0.13.85
Minimal Reproducible Example
```julia
using Reactant
using Enzyme

# Create arrays with different sizes
x = rand(1000)
t = rand(1001)  # Different size from x
# Uncomment the following line to give x and t the same size; compilation then succeeds
# t = rand(1000)

function eval_F(x, t)
    F = x.^2  # Note: t is unused
    return F
end

function eval_jvp_x_F(v, x, t)
    # Forward-mode JVP of eval_F with respect to x in direction v; t is held constant
    jvp = Enzyme.autodiff(Forward, eval_F, Duplicated(x, v), Const(t))[1]
    return jvp
end

# Uncompiled execution works fine
v = ones(size(x))
eval_jvp_x_F(v, x, t)

# Compiling the same function with Reactant throws an error
v_r = Reactant.to_rarray(v)
x_r = Reactant.to_rarray(x)
t_r = Reactant.to_rarray(t)
eval_jvp_x_F_compiled = @compile eval_jvp_x_F(v_r, x_r, t_r)
```
Expected Behavior
Compilation should succeed regardless of the size of `t`.
Actual Behavior
Compilation fails with the error below when `size(t) != size(x)`.
Error Message
```
error: type of return operand 3 ('tensor<1000xf64>') doesn't match function result type ('tensor<1001xf64>') in function @eval_jvp_x_F
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_uMtOZs/module_001_4sWt_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/bR5J7/src/mlir/IR/Pass.jl:119
ERROR: "failed to run pass manager on module"
Stacktrace:
 [1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
   @ Reactant.MLIR.IR ~/.julia/packages/Reactant/bR5J7/src/mlir/IR/Pass.jl:163
 [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
   @ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1306
 [3] run_pass_pipeline!
   @ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1301 [inlined]
 [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@kwargs{})
   @ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1746
 [5] compile_mlir! (repeats 2 times)
   @ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1561 [inlined]
 [6] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@kwargs{…})
   @ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:3483
 [7] compile_xla
   @ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:3456 [inlined]
 [8] compile(f::Function, args::Tuple{…}; kwargs::@kwargs{…})
   @ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:3555
 [9] top-level scope
   @ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:2633
```