Reactant v0.2.172 and above fails to compile Lux model, while v0.2.171 works just fine #1804

@Sleort

Description

This works:

(jl_DYkUR1) pkg> st
Status `/tmp/jl_DYkUR1/Project.toml`
  [b2108857] Lux v1.24.0
⌃ [3c362404] Reactant v0.2.171
  [9a3f8284] Random v1.11.0
Info Packages marked with ⌃ have new versions available and may be upgradable.

julia> using Lux, Reactant, Random

julia> const xdev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing, Missing}) (generic function with 1 method)

julia> model = Conv((3,3), 1=>1)
Conv((3, 3), 1 => 1)         # 10 parameters

julia> ps, st = Lux.setup(Random.default_rng(), model) |> xdev;
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1761814970.893636   36469 service.cc:158] XLA service 0x22b16d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1761814970.893707   36469 service.cc:166]   StreamExecutor device (0): NVIDIA GeForce RTX 4070 Laptop GPU, Compute Capability 8.9
I0000 00:00:1761814970.894194   36469 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1761814970.894238   36469 gpu_helpers.cc:136] XLA backend allocating 6146408448 bytes on device 0 for BFCAllocator.
I0000 00:00:1761814970.894277   36469 gpu_helpers.cc:177] XLA backend will use up to 2048802816 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1761814970.907333   36469 cuda_dnn.cc:463] Loaded cuDNN version 90800

julia> x = rand(Float32, 10, 10, 1, 1) |> xdev;

julia> @jit model(x, ps, st)

(ConcretePJRTArray{Float32, 4, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.8066233 -0.7050564 … -0.9303775 -0.5120818; -0.50369656 -0.5531037 … -0.22096159 -0.49018833; … ; -0.36883602 -0.36507434 … -0.6171071 -0.87650114; -0.094239205 -0.70827335 … -0.7866802 -1.2901264;;;;]), NamedTuple())
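(As expected, the 10×10 input convolved with a 3×3 kernel and no padding gives an 8×8 output, matching the tensor<1x1x8x8xf32> shape in the MLIR dump below.)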

... while doing the same thing with v0.2.172 fails:

(jl_LkoBMu) pkg> st
Status `/tmp/jl_LkoBMu/Project.toml`
  [b2108857] Lux v1.24.0
⌃ [3c362404] Reactant v0.2.172
  [9a3f8284] Random v1.11.0
Info Packages marked with ⌃ have new versions available and may be upgradable.

julia> using Lux, Reactant, Random

julia> const xdev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing, Missing}) (generic function with 1 method)

julia> model = Conv((3,3), 1=>1)
Conv((3, 3), 1 => 1)         # 10 parameters

julia> ps, st = Lux.setup(Random.default_rng(), model) |> xdev;
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1761815169.294402   37037 service.cc:158] XLA service 0x1622a4b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1761815169.294458   37037 service.cc:166]   StreamExecutor device (0): NVIDIA GeForce RTX 4070 Laptop GPU, Compute Capability 8.9
I0000 00:00:1761815169.294830   37037 se_gpu_pjrt_client.cc:770] Using BFC allocator.
I0000 00:00:1761815169.294860   37037 gpu_helpers.cc:136] XLA backend allocating 6146408448 bytes on device 0 for BFCAllocator.
I0000 00:00:1761815169.294895   37037 gpu_helpers.cc:177] XLA backend will use up to 2048802816 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1761815169.317817   37037 cuda_dnn.cc:463] Loaded cuDNN version 91400

julia> x = rand(Float32, 10, 10, 1, 1) |> xdev;

julia> @jit model(x, ps, st)
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_QfnHPl/module_000_reactant_Conv((3..._post_xla_compile.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/Nc0uv/src/mlir/IR/Pass.jl:119
ERROR: module @"reactant_Conv((3..." attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<1x1x10x10xf32>, %arg1: tensor<1x1x3x3xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x8x8xf32> {
    %0 = "mhlo.reverse"(%arg1) <{dimensions = dense<[3, 2]> : tensor<2xi64>}> : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32>
    %1 = mhlo.convolution(%arg0, %0) dim_numbers = [b, f, 1, 0]x[o, i, 1, 0]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x1x10x10xf32>, tensor<1x1x3x3xf32>) -> tensor<1x1x8x8xf32>
    %2 = "mhlo.broadcast_in_dim"(%arg2) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<1x1x8x8xf32>
    %3 = mhlo.add %1, %2 : tensor<1x1x8x8xf32>
    return %3 : tensor<1x1x8x8xf32>
  }
}
INTERNAL: Autotuner could not find any supported configs for HLO: %cudnn-conv-bw-input.1 = (f32[1,1,8,8]{3,2,1,0}, u8[0]{0}) custom-call(%bitcast.15, %bitcast.17), window={size=3x3 pad=2_2x2_2}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBackwardInput", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID"}

Stacktrace:
 [1] reactant_err(msg::Cstring)
   @ Reactant.XLA ~/.julia/packages/Reactant/Nc0uv/src/xla/Utils.jl:12
 [2] (::Reactant.XLA.PJRT.var"#23#24"{Bool, Vector{Int64}, Int64, Int64, Bool, Reactant.XLA.PJRT.Client, Reactant.MLIR.IR.Module, Int64})()
   @ Reactant.XLA.PJRT ~/.julia/packages/Reactant/Nc0uv/src/xla/PJRT/LoadedExecutable.jl:83
 [3] try_compile_dump_mlir(f::Reactant.XLA.PJRT.var"#23#24"{Bool, Vector{Int64}, Int64, Int64, Bool, Reactant.XLA.PJRT.Client, Reactant.MLIR.IR.Module, Int64}, mod::Reactant.MLIR.IR.Module, pm::Nothing)
   @ Reactant.MLIR.IR ~/.julia/packages/Reactant/Nc0uv/src/mlir/IR/Pass.jl:134
 [4] try_compile_dump_mlir
   @ ~/.julia/packages/Reactant/Nc0uv/src/mlir/IR/Pass.jl:129 [inlined]
 [5] #compile#22
   @ ~/.julia/packages/Reactant/Nc0uv/src/xla/PJRT/LoadedExecutable.jl:82 [inlined]
 [6] compile_xla(f::Conv{typeof(identity), Int64, Int64, Tuple{Int64, Int64}, Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64, Nothing, Nothing, Static.True, Static.False}, args::Tuple{ConcretePJRTArray{Float32, 4, 1, Reactant.Sharding.ShardInfo{…}}, @NamedTuple{weight::ConcretePJRTArray{…}, bias::ConcretePJRTArray{…}}, @NamedTuple{}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
   @ Reactant.Compiler ~/.julia/packages/Reactant/Nc0uv/src/Compiler.jl:3547
 [7] compile_xla
   @ ~/.julia/packages/Reactant/Nc0uv/src/Compiler.jl:3475 [inlined]
 [8] compile(f::Conv{…}, args::Tuple{…}; kwargs::@Kwargs{…})
   @ Reactant.Compiler ~/.julia/packages/Reactant/Nc0uv/src/Compiler.jl:3579
 [9] top-level scope
   @ ~/.julia/packages/Reactant/Nc0uv/src/Compiler.jl:2652
Some type information was truncated. Use `show(err)` to see complete types.
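For reference, the failing case above condenses to the following self-contained script (assuming a CUDA-capable GPU and the package versions shown in the status output):

using Lux, Reactant, Random

const xdev = reactant_device()            # Reactant (XLA) device

model = Conv((3, 3), 1 => 1)              # single 3x3 convolution, 1 input => 1 output channel
ps, st = Lux.setup(Random.default_rng(), model) |> xdev

x = rand(Float32, 10, 10, 1, 1) |> xdev   # 10x10 single-channel input, batch size 1

@jit model(x, ps, st)                     # succeeds on v0.2.171, fails with the cuDNN autotuner error on v0.2.172+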

FYI:

julia> versioninfo()
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 22 × Intel(R) Core(TM) Ultra 7 155H
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 22 virtual cores)
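Until the regression is fixed, pinning Reactant to the last known-good release appears to be a usable workaround (assuming nothing else in the environment requires ≥ 0.2.172):

julia> import Pkg

julia> Pkg.add(name = "Reactant", version = "0.2.171")   # install the last version that compiles the model

julia> Pkg.pin("Reactant")                                # prevent implicit upgrades past 0.2.171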
