Forward-mode AD through @trace loop fails with shape mismatch in stablehlo.while #2361

@KookiesNKareem

Description

Forward-mode Enzyme.gradient through a @trace loop fails in the MLIR DifferentiatePass. Enzyme.gradient(Forward, ...) seeds all basis directions in one batch, so the tangent of a loop-carried accumulator (tensor<10xf32>) is widened to a full Jacobian (tensor<10x10xf32>); this widened shadow is threaded through the stablehlo.while operands, but the condition block signature is not updated and still expects the primal shape:

expect operands to be compatible with condition block arguments but got
'tensor<i64>', 'tensor<10xf32>', 'tensor<10x10xf32>'
vs
'tensor<i64>', 'tensor<10xf32>', 'tensor<10xf32>'
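
For reference, both gradients can be derived by hand, so the expected output of the MWE below is known ahead of time. A plain-Julia sanity check (no Reactant or Enzyme involved):

x = ones(Float32, 10)
c = Float32.(1:4)
# Test 1: d/dx sum(x .* x) = 2x
@assert 2 .* x == fill(2f0, 10)
# Test 2: the loop yields acc = sum(c) .* x = 10x,
# so d/dx sum(acc .* acc) = 2 * sum(c)^2 .* x
@assert 2 * sum(c)^2 .* x == fill(200f0, 10)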
# Reactant CPU forward-mode AD through @trace loop fails in DifferentiatePass

using Reactant
using Reactant: @jit
using ReactantCore: @trace
using Enzyme

Reactant.set_default_backend("cpu")
Reactant.allowscalar(true)

# Recover the compile-time step count from a Val
@inline valof(::Val{N}) where N = N

# Works: no @trace
function loss_no_trace(x)
    return sum(x .* x)
end

# Crashes: @trace loop accumulating over timesteps
function loss_with_trace(x, coeffs, n_val)
    N = valof(n_val)
    acc = zero(x)
    @trace for i in 1:N
        c = coeffs[i]
        acc = acc .+ c .* x
    end
    return sum(acc .* acc)
end

N_elem = 10
N_steps = 4
x_ra = Reactant.to_rarray(ones(Float32, N_elem))
c_ra = Reactant.to_rarray(Float32.(1:N_steps))

# Test 1: Forward-mode without @trace
println("Test 1: Forward-mode, no @trace")
try
    grad = @jit Enzyme.gradient(Forward, loss_no_trace, x_ra)
    println("grad = $(Array(grad[1]))")
catch e
    print(e)
end

# Test 2: Forward-mode with @trace
println("\nTest 2: Forward-mode, with @trace")
try
    grad = @jit Enzyme.gradient(Forward, loss_with_trace, x_ra, Const(c_ra), Const(Val(N_steps)))
    println("grad = $(Array(grad[1]))")
catch e
    print(e)
end
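
If the batched shadow is indeed what desynchronizes the while signature, seeding one tangent direction at a time should keep every shadow at the primal shape tensor<10xf32>. An untested sketch of that idea (fwd_grad_per_dir is a hypothetical helper of mine, not a Reactant or Enzyme API, and it recompiles per direction, so it is a diagnostic rather than a fix):

function fwd_grad_per_dir(f, x_ra, rest...)
    n = length(x_ra)
    g = zeros(Float32, n)
    for j in 1:n
        seed = zeros(Float32, n)
        seed[j] = 1f0
        dx = Reactant.to_rarray(seed)
        # Single-direction forward mode: the shadow keeps the primal shape.
        (dval,) = @jit Enzyme.autodiff(Forward, f, Duplicated(x_ra, dx), rest...)
        g[j] = Float32(dval)  # assumes the traced scalar converts back to Float32
    end
    return g
end

# g = fwd_grad_per_dir(loss_with_trace, x_ra, Const(c_ra), Const(Val(N_steps)))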
