Failure after converting HLO to StableHLO: error: 'builtin.module' op -:20:5: error: Expected array argument for rhs of binary operation add, but got (f32[...]).: #39831

@housrepository

Description

Bug Report:

This bug is triggered by the HLO-to-StableHLO conversion. The HLO module runs successfully with run_hlo_module --input_format=hlo, and converting it with hlo-translate --hlo-to-mlir also succeeds. However, running the translated StableHLO with run_hlo_module --input_format=stablehlo fails with:

error: 'builtin.module' op -:20:5: error: Expected array argument for rhs of binary operation add, but got (f32[...]).:
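The three steps above can be replayed with a small driver script. This is a minimal sketch. It assumes run_hlo_module and hlo-translate are on PATH and that the HLO below is saved as AsyncCall_3f46825b_4.hlo in the working directory. The helper names (run_hlo_module_cmd, hlo_translate_cmd, reproduce) are hypothetical, and the only flags used are the ones from the commands in this report.

```python
"""Hypothetical driver that replays the three steps of this report in order."""
import shutil
import subprocess
from pathlib import Path
from typing import List

HLO_FILE = Path("AsyncCall_3f46825b_4.hlo")
MLIR_FILE = HLO_FILE.with_suffix(".mlir")


def run_hlo_module_cmd(path: Path, input_format: str) -> List[str]:
    """run_hlo_module invocation from the report (reference runner disabled)."""
    return [
        "run_hlo_module",
        "--platform=CUDA",
        "--reference_platform=",
        f"--input_format={input_format}",
        str(path),
    ]


def hlo_translate_cmd(src: Path, dst: Path) -> List[str]:
    """hlo-translate invocation from the report that writes the MLIR file."""
    return ["hlo-translate", "--hlo-to-mlir", str(src), "-o", str(dst)]


def reproduce(hlo: Path = HLO_FILE, mlir: Path = MLIR_FILE) -> List[int]:
    """Run HLO execution, translation, then StableHLO execution; return exit codes."""
    steps = [
        run_hlo_module_cmd(hlo, "hlo"),
        hlo_translate_cmd(hlo, mlir),
        run_hlo_module_cmd(mlir, "stablehlo"),
    ]
    codes = []
    for cmd in steps:
        print("$", " ".join(cmd), flush=True)
        codes.append(subprocess.run(cmd).returncode)
    return codes


def main() -> None:
    missing = [t for t in ("run_hlo_module", "hlo-translate") if shutil.which(t) is None]
    if missing:
        # Report and return instead of exiting, so importing or running this is harmless.
        print("not on PATH:", ", ".join(missing))
        return
    print(reproduce())


if __name__ == "__main__":
    main()
```

Per this report, the first two exit codes should be 0 and the third non-zero. All three steps run regardless of earlier failures, so a failed translation would also surface as a failure of the last step.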

Environment

  • XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
  • CPU: Intel(R) Core(TM) i9-14900HX
  • GPU: NVIDIA GeForce RTX 4060 Laptop GPU
  • CUDA Driver: 580.126.09

run_hlo_module (HLO) — Success

HLO:

HloModule AsyncCall, entry_computation_layout={(f32[4096]{0}, f32[4096]{0})->f32[4096]{0}}

called_computation {
  param_0 = f32[4096]{0} parameter(0)
  param_1 = f32[4096]{0} parameter(1)
  ROOT result.1 = f32[4096]{0} add(param_0, param_1)
}

ENTRY main {
  a = f32[4096]{0} parameter(0)
  b = f32[4096]{0} parameter(1)
  call-start.0 = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-start(a, b), to_apply=called_computation
  call-done.0 = f32[4096]{0} call-done(call-start.0)
  call-start.1 = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-start(call-done.0, b), to_apply=called_computation
  call-done.1 = f32[4096]{0} call-done(call-start.1)
  ROOT add_1 = f32[4096]{0} add(a, call-done.1)
}

Execution Command:

run_hlo_module \
  --platform=CUDA \
  --reference_platform= \
  --input_format=hlo \
  AsyncCall_3f46825b_4.hlo

Output:


 ** Running AsyncCall_3f46825b_4.hlo** 
Running HLO module with runner CUDA...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1774418325.595760 2114943 cuda_dnn.cc:461] Loaded cuDNN version 91900
... compiled and ran in 0.0769556s.
Skipping reference runner

run_hlo_module (StableHLO) — FAIL

Translation Command:

hlo-translate \
  --hlo-to-mlir \
  AsyncCall_3f46825b_4.hlo \
  -o \
  AsyncCall_3f46825b_4.mlir

IR After Translation:

module @AsyncCall attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @called_computation(%arg0: tensor<4096xf32>, %arg1: tensor<4096xf32>) -> tensor<4096xf32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<4096xf32>
    return %0 : tensor<4096xf32>
  }
  func.func private @async_wrapped(%arg0: tensor<4096xf32>, %arg1: tensor<4096xf32>) -> tensor<4096xf32> attributes {execution_thread = "main"} {
    %0 = call @called_computation(%arg0, %arg1) : (tensor<4096xf32>, tensor<4096xf32>) -> tensor<4096xf32>
    return %0 : tensor<4096xf32>
  }
  func.func private @async_wrapped.1(%arg0: tensor<4096xf32>, %arg1: tensor<4096xf32>) -> tensor<4096xf32> attributes {execution_thread = "main"} {
    %0 = call @called_computation(%arg0, %arg1) : (tensor<4096xf32>, tensor<4096xf32>) -> tensor<4096xf32>
    return %0 : tensor<4096xf32>
  }
  func.func @main(%arg0: tensor<4096xf32>, %arg1: tensor<4096xf32>) -> tensor<4096xf32> {
    %0 = "mhlo.async_start"(%arg0, %arg1) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[])"} : (tensor<4096xf32>, tensor<4096xf32>) -> !mhlo.async_bundle<tuple<tensor<4096xf32>, tensor<4096xf32>>, tensor<4096xf32>, tensor<ui32>>
    %1 = "mhlo.async_done"(%0) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<4096xf32>, tensor<4096xf32>>, tensor<4096xf32>, tensor<ui32>>) -> tensor<4096xf32>
    %2 = "mhlo.async_start"(%1, %arg1) <{called_computation = @async_wrapped.1, execution_thread = "main"}> {xla_shape = "((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[])"} : (tensor<4096xf32>, tensor<4096xf32>) -> !mhlo.async_bundle<tuple<tensor<4096xf32>, tensor<4096xf32>>, tensor<4096xf32>, tensor<ui32>>
    %3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped.1, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<4096xf32>, tensor<4096xf32>>, tensor<4096xf32>, tensor<ui32>>) -> tensor<4096xf32>
    %4 = stablehlo.add %arg0, %3 : tensor<4096xf32>
    return %4 : tensor<4096xf32>
  }
}

Execution Command:

run_hlo_module \
  --platform=CUDA \
  --reference_platform= \
  --input_format=stablehlo \
  AsyncCall_3f46825b_4.mlir

Output:

loc("-":1:1): error: 'builtin.module' op -:20:5: error: Expected array argument for rhs of binary operation add, but got (f32[4096]).: 
-:20:5: note: see current operation: "func.return"(%4) : (tensor<4096xf32>) -> ()

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1774418325.963264 2114990 translate.cc:190] Conversion to HLO module failed: UNKNOWN: -:20:5: error: Expected array argument for rhs of binary operation add, but got (f32[4096]).: 
-:20:5: note: see current operation: "func.return"(%4) : (tensor<4096xf32>) -> ()

F0000 00:00:1774418325.963628 2114990 hlo_module_loader.cc:117] Failed to translate input stablehlo program to HLO text
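For triage, it helps to map the -:20:5 location back to the translated file. In the IR shown above, line 20, column 5 is the return %4 in @main, which matches the "see current operation" note. XLA prints tuple shapes in parentheses, so "(f32[4096])" denotes a one-element tuple rather than the array f32[4096]. One plausible reading, offered as an inference and not a confirmed root cause, is that the result of mhlo.async_done (%3, the rhs of the final add) is exported with a tuple shape. The sketch below extracts the location and the reported shape from the log. The helper names (locate, got_shape, is_tuple_shape) and the script name triage.py are hypothetical, and it assumes the "-" buffer naming seen in this log.

```python
import re
from typing import Optional, Tuple

# MLIR diagnostic location "-:LINE:COL: error:"; "-" is the name of the input buffer.
LOC_RE = re.compile(r"-:(\d+):(\d+): error:")
# Shape printed after "but got" and before the trailing ".:" in the XLA message.
GOT_RE = re.compile(r"but got (.+?)\.:")


def locate(error_text: str, mlir_text: str) -> Tuple[int, int, str]:
    """Return (line, column, source line) for the first '-:LINE:COL: error' in the log."""
    m = LOC_RE.search(error_text)
    if m is None:
        raise ValueError("no '-:LINE:COL: error' location in log")
    line, col = int(m.group(1)), int(m.group(2))
    return line, col, mlir_text.splitlines()[line - 1]


def got_shape(error_text: str) -> Optional[str]:
    """Return the shape string reported after 'but got', e.g. '(f32[4096])'."""
    m = GOT_RE.search(error_text)
    return m.group(1) if m else None


def is_tuple_shape(shape: str) -> bool:
    """XLA prints tuple shapes in parentheses and array shapes without them."""
    return shape.strip().startswith("(")


if __name__ == "__main__":
    import sys
    from pathlib import Path

    if len(sys.argv) != 3:
        print("usage: triage.py ERROR_LOG MLIR_FILE")
    else:
        log = Path(sys.argv[1]).read_text()
        line, col, src = locate(log, Path(sys.argv[2]).read_text())
        print(f"-:{line}:{col} -> {src.strip()}")
        shape = got_shape(log)
        if shape is not None:
            print(f"reported shape {shape} is a tuple: {is_tuple_shape(shape)}")
```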

Contact

  • Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com

Metadata

Labels

  • bug (Something isn't working)
