Skip to content

Failure after converting HLO to StableHLO: error: 'builtin.module' op -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape #39828

@housrepository

Description

@housrepository

Bug Report:

This bug is triggered by the HLO-to-StableHLO conversion: the HLO module runs successfully via run_hlo_module --input_format=hlo, and the conversion via hlo-translate --hlo-to-mlir also succeeds, but running the translated StableHLO via run_hlo_module --input_format=stablehlo fails with:

error: 'builtin.module' op -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (arg_tuple.2: (s32[...], s32[...], s32[], pred[])) -> pred[]; body: (arg_tuple: (s32[...], s32[...], s32[], pred[])) -> ((s32[...]), s32[...], s32[], pred[]); init: (s32[...], s32[...], s32[], pred[])..

Environment

  • XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
  • CPU: Intel(R) Core(TM) i9-14900HX
  • GPU: NVIDIA GeForce RTX 4060 Laptop GPU
  • CUDA Driver: 580.126.09

run_hlo_module (HLO) — Success

HLO:

HloModule module, entry_computation_layout={(s32[256]{0}, s32[], pred[])->s32[1024]{0}}

body {
  input_tuple.1 = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) parameter(0)
  input.1 = s32[1024]{0} get-tuple-element(input_tuple.1), index=0
  input.2 = s32[256]{0} get-tuple-element(input_tuple.1), index=1
  input.3 = s32[] get-tuple-element(input_tuple.1), index=2
  async-start = ((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[]) dynamic-update-slice-start(input.1, input.2, input.3)
  async-done = s32[1024]{0} dynamic-update-slice-done(async-start)
  input.4 = pred[] get-tuple-element(input_tuple.1), index=3
  ROOT tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) tuple(async-done, input.2, input.3, input.4)
}

condition {
  input_tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) parameter(0)
  ROOT cond = pred[] get-tuple-element(input_tuple), index=3
}

ENTRY main {
  input.5 = s32[] parameter(1)
  broadcast = s32[1024]{0} broadcast(input.5), dimensions={}
  input.0 = s32[256]{0} parameter(0)
  input.6 = pred[] parameter(2)
  while_tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) tuple(broadcast, input.0, input.5, input.6)
  while = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) while(while_tuple), condition=condition, body=body
  ROOT gte = s32[1024]{0} get-tuple-element(while), index=0
}

Execution Command:

run_hlo_module \
  --platform=CUDA \
  --reference_platform= \
  --input_format=hlo \
  module_9d759a35_34.hlo

Output:


 ** Running module_9d759a35_34.hlo** 
Running HLO module with runner CUDA...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1774418321.819708 2114360 cuda_dnn.cc:461] Loaded cuDNN version 91900
... compiled and ran in 0.0697718s.
Skipping reference runner

run_hlo_module (StableHLO) — FAIL

Translation Command:

hlo-translate \
  --hlo-to-mlir \
  module_9d759a35_34.hlo \
  -o \
  module_9d759a35_34.mlir

IR After Translation:

module @module attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @async_wrapped(%arg0: tensor<1024xi32>, %arg1: tensor<256xi32>, %arg2: tensor<i32>) -> tensor<1024xi32> attributes {execution_thread = "main"} {
    %0 = stablehlo.dynamic_update_slice %arg0, %arg1, %arg2 : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> tensor<1024xi32>
    return %0 : tensor<1024xi32>
  }
  func.func private @body(%arg0: tensor<1024xi32>, %arg1: tensor<256xi32>, %arg2: tensor<i32>, %arg3: tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) {
    %0 = "mhlo.async_start"(%arg0, %arg1, %arg2) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
    %1 = "mhlo.async_done"(%0) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
    return %1, %arg1, %arg2, %arg3 : tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>
  }
  func.func private @condition(%arg0: tensor<1024xi32>, %arg1: tensor<256xi32>, %arg2: tensor<i32>, %arg3: tensor<i1>) -> tensor<i1> {
    return %arg3 : tensor<i1>
  }
  func.func @main(%arg0: tensor<256xi32>, %arg1: tensor<i32>, %arg2: tensor<i1>) -> tensor<1024xi32> {
    %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<i32>) -> tensor<1024xi32>
    %1:4 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %arg1, %iterArg_2 = %arg2) : tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>
    cond {
      stablehlo.return %iterArg_2 : tensor<i1>
    } do {
      %2 = "mhlo.async_start"(%iterArg, %iterArg_0, %iterArg_1) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
      %3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
      stablehlo.return %3, %iterArg_0, %iterArg_1, %iterArg_2 : tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>
    }
    return %1#0 : tensor<1024xi32>
  }
}

Execution Command:

run_hlo_module \
  --platform=CUDA \
  --reference_platform= \
  --input_format=stablehlo \
  module_9d759a35_34.mlir

Output:

loc("-":1:1): error: 'builtin.module' op -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (arg_tuple.2: (s32[1024], s32[256], s32[], pred[])) -> pred[]; body: (arg_tuple: (s32[1024], s32[256], s32[], pred[])) -> ((s32[1024]), s32[256], s32[], pred[]); init: (s32[1024], s32[256], s32[], pred[])..
-:16:12: note: see current operation: 
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
  "stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
  %2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
  %3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
  "stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)
-:16:12: error: 'stablehlo.while' op can't be translated to XLA HLO
-:16:12: note: see current operation: 
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
  "stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
  %2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
  %3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
  "stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1774418322.175888 2114407 translate.cc:190] Conversion to HLO module failed: UNKNOWN: -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (arg_tuple.2: (s32[1024], s32[256], s32[], pred[])) -> pred[]; body: (arg_tuple: (s32[1024], s32[256], s32[], pred[])) -> ((s32[1024]), s32[256], s32[], pred[]); init: (s32[1024], s32[256], s32[], pred[])..
-:16:12: note: see current operation: 
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
  "stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
  %2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
  %3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
  "stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)
-:16:12: error: 'stablehlo.while' op can't be translated to XLA HLO
-:16:12: note: see current operation: 
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
  "stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
  %2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
  %3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
  "stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)

F0000 00:00:1774418322.176375 2114407 hlo_module_loader.cc:117] Failed to translate input stablehlo program to HLO text

Contact

  • Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com

Metadata

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions