
Compilation error when function arguments have different sizes #1729

@NoFixedPoint

Description

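When a function that calls Enzyme's autodiff is compiled with Reactant's @compile, compilation fails if the function arguments have different sizes, even when the differently sized argument is passed as Const and never used.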

Environment

  • Julia Version 1.11.7
    Commit f2b3dbda30a (2025-09-08 12:10 UTC)
    Build Info:
      Official https://julialang.org/ release
    Platform Info:
      OS: Linux (x86_64-linux-gnu)
      CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900K
      WORD_SIZE: 64
      LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
      Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
  • Reactant version: [3c362404] Reactant v0.2.169
  • Enzyme version: [7da242da] Enzyme v0.13.85

Minimal Reproducible Example

using Reactant
using Enzyme

# Create arrays with different sizes
x = rand(1000)
t = rand(1001)  # Different size from x
# Compilation succeeds if the following line is uncommented so that x and t have the same size
# t = rand(1000)

function eval_F(x, t)
    F = x.^2  # Note: t is unused
    return F
end

function eval_jvp_x_F(v, x, t)
    jvp = Enzyme.autodiff(Forward, eval_F, Duplicated(x, v), Const(t))[1]
    return jvp
end

# Without Reactant, this works fine
v = ones(size(x))
eval_jvp_x_F(v, x, t)

# With Reactant, @compile throws a compilation error
v_r = Reactant.to_rarray(v)
x_r = Reactant.to_rarray(x)
t_r = Reactant.to_rarray(t)
eval_jvp_x_F_compiled = @compile eval_jvp_x_F(v_r, x_r, t_r)

Expected Behavior

Compilation should succeed regardless of the size of t.
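For reference, the forward-mode derivative of eval_F with respect to x in direction v is 2 .* x .* v elementwise, so the result has the shape of x (1000 elements), not of t. The sketch below is plain Julia with no Reactant involved. It reuses the same Enzyme v0.13 autodiff call from the reproducer and checks the uncompiled result against that analytic value. It is not a fix for the bug, only a statement of the value a successful compile would be expected to reproduce.

```julia
using Enzyme

# Same setup as the reproducer: x and t have different sizes.
x = rand(1000)
t = rand(1001)
v = ones(size(x))

eval_F(x, t) = x .^ 2  # t is unused

# Forward-mode JVP of eval_F with respect to x in direction v; t is held constant.
jvp = Enzyme.autodiff(Forward, eval_F, Duplicated(x, v), Const(t))[1]

# Analytic JVP: d(x .^ 2)/dx applied to v is 2 .* x .* v, with the shape of x (not t).
@assert size(jvp) == size(x)
@assert jvp ≈ 2 .* x .* v
```

If compilation succeeded, eval_jvp_x_F_compiled(v_r, x_r, t_r) would be expected to return a 1000-element array matching this value.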

Actual Behavior

Compilation fails with an error when size(t) != size(x).

Error Message

error: type of return operand 3 ('tensor<1000xf64>') doesn't match function result type ('tensor<1001xf64>') in function @eval_jvp_x_F
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_uMtOZs/module_001_4sWt_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/bR5J7/src/mlir/IR/Pass.jl:119
ERROR: "failed to run pass manager on module"
Stacktrace:
[1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/bR5J7/src/mlir/IR/Pass.jl:163
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1306
[3] run_pass_pipeline!
@ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1301 [inlined]
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1746
[5] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:1561 [inlined]
[6] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:3483
[7] compile_xla
@ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:3456 [inlined]
[8] compile(f::Function, args::Tuple{…}; kwargs::@kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:3555
[9] top-level scope
@ ~/.julia/packages/Reactant/bR5J7/src/Compiler.jl:2633
