Failure after converting HLO to StableHLO: Check failed: while_instr->shape().IsTuple() #39830

@housrepository

Description

@housrepository

Bug Report:

This bug is triggered by the HLO-to-StableHLO conversion: the HLO module runs successfully with run_hlo_module --input_format=hlo, and translation with hlo-translate --hlo-to-mlir also succeeds, but running the translated StableHLO with run_hlo_module --input_format=stablehlo fails with:

Check failed: while_instr->shape().IsTuple()

Environment

  • XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
  • CPU: Intel(R) Core(TM) i9-14900HX
  • GPU: NVIDIA GeForce RTX 4060 Laptop GPU
  • CUDA Driver: 580.126.09

run_hlo_module (HLO) — Success

HLO:

HloModule module, entry_computation_layout={()->s32[2]{0}}

body {
  stack = (s32[<=4,2]{1,0}) parameter(0)
  stack_buffer = s32[<=4,2]{1,0} get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  one = s32[] constant(1)
  new_stack_size = s32[] add(stack_size, one)
  new_stack_buffer = s32[<=4,2]{1,0} set-dimension-size(stack_buffer, new_stack_size), dimensions={0}
  new_data = s32[1,2]{1,0} broadcast(stack_size), dimensions={}
  zero = s32[] constant(0)
  new_stack = s32[<=4,2]{1,0} dynamic-update-slice(new_stack_buffer, new_data, stack_size, zero)
  ROOT new_stack_tuple = (s32[<=4,2]{1,0}) tuple(new_stack)
}

condition {
  stack.1 = (s32[<=4,2]{1,0}) parameter(0)
  stack_buffer.1 = s32[<=4,2]{1,0} get-tuple-element(stack.1), index=0
  stack_size.1 = s32[] get-dimension-size(stack_buffer.1), dimensions={0}
  three = s32[] constant(3)
  ROOT less-than = pred[] compare(stack_size.1, three), direction=LT
}

update_s32 {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY entry {
  pad = s32[] constant(-1)
  stack_buffer_input = s32[4,2]{1,0} broadcast(pad), dimensions={}
  zero.1 = s32[] constant(0)
  stack_buffer_input_dynamic = s32[<=4,2]{1,0} set-dimension-size(stack_buffer_input, zero.1), dimensions={0}
  input_tuple = (s32[<=4,2]{1,0}) tuple(stack_buffer_input_dynamic)
  while = (s32[<=4,2]{1,0}) while(input_tuple), condition=condition, body=body
  stack_buffer.2 = s32[<=4,2]{1,0} get-tuple-element(while), index=0
  ROOT reduce = s32[2]{0} reduce(stack_buffer.2, zero.1), dimensions={0}, to_apply=update_s32
}

Execution Command:

run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=hlo \
  module_9d759a35_8.hlo

Output:


 ** Running module_9d759a35_8.hlo** 
Running HLO module with runner Host...
... compiled and ran in 0.0246828s.
Skipping reference runner

run_hlo_module (StableHLO) — FAIL

Translation Command:

hlo-translate \
  --hlo-to-mlir \
  module_9d759a35_8.hlo \
  -o \
  module_9d759a35_8.mlir

IR After Translation:

module @module attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @body(%arg0: tensor<?x2xi32, #stablehlo.bounds<4, ?>>) -> tensor<?x2xi32, #stablehlo.bounds<4, ?>> {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>) -> tensor<i32>
    %c = stablehlo.constant dense<1> : tensor<i32>
    %1 = stablehlo.add %0, %c : tensor<i32>
    %2 = stablehlo.set_dimension_size %arg0, %1, dim = 0 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>, tensor<i32>) -> tensor<?x2xi32, #stablehlo.bounds<4, ?>>
    %3 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<i32>) -> tensor<1x2xi32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %4 = stablehlo.dynamic_update_slice %2, %3, %0, %c_0 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>, tensor<1x2xi32>, tensor<i32>, tensor<i32>) -> tensor<?x2xi32, #stablehlo.bounds<4, ?>>
    return %4 : tensor<?x2xi32, #stablehlo.bounds<4, ?>>
  }
  func.func private @condition(%arg0: tensor<?x2xi32, #stablehlo.bounds<4, ?>>) -> tensor<i1> {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>) -> tensor<i32>
    %c = stablehlo.constant dense<3> : tensor<i32>
    %1 = stablehlo.compare  LT, %0, %c : (tensor<i32>, tensor<i32>) -> tensor<i1>
    return %1 : tensor<i1>
  }
  func.func private @update_s32(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
  func.func @main() -> tensor<2xi32> {
    %c = stablehlo.constant dense<-1> : tensor<i32>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<4x2xi32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %1 = stablehlo.set_dimension_size %0, %c_0, dim = 0 : (tensor<4x2xi32>, tensor<i32>) -> tensor<?x2xi32, #stablehlo.bounds<4, ?>>
    %2 = stablehlo.while(%iterArg = %1) : tensor<?x2xi32, #stablehlo.bounds<4, ?>>
    cond {
      %4 = stablehlo.get_dimension_size %iterArg, dim = 0 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>) -> tensor<i32>
      %c_1 = stablehlo.constant dense<3> : tensor<i32>
      %5 = stablehlo.compare  LT, %4, %c_1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.get_dimension_size %iterArg, dim = 0 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>) -> tensor<i32>
      %c_1 = stablehlo.constant dense<1> : tensor<i32>
      %5 = stablehlo.add %4, %c_1 : tensor<i32>
      %6 = stablehlo.set_dimension_size %iterArg, %5, dim = 0 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>, tensor<i32>) -> tensor<?x2xi32, #stablehlo.bounds<4, ?>>
      %7 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<i32>) -> tensor<1x2xi32>
      %c_2 = stablehlo.constant dense<0> : tensor<i32>
      %8 = stablehlo.dynamic_update_slice %6, %7, %4, %c_2 : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>, tensor<1x2xi32>, tensor<i32>, tensor<i32>) -> tensor<?x2xi32, #stablehlo.bounds<4, ?>>
      stablehlo.return %8 : tensor<?x2xi32, #stablehlo.bounds<4, ?>>
    }
    %3 = stablehlo.reduce(%2 init: %c_0) applies stablehlo.add across dimensions = [0] : (tensor<?x2xi32, #stablehlo.bounds<4, ?>>, tensor<i32>) -> tensor<2xi32>
    return %3 : tensor<2xi32>
  }
}

Execution Command:

run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=stablehlo \
  module_9d759a35_8.mlir

Output:


 ** Running module_9d759a35_8.mlir** 
Running HLO module with runner Host...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1774418324.132360 2114867 while_util.cc:229] Check failed: while_instr->shape().IsTuple() 
*** Check failure stack trace: ***
    @     0x644dea1b04f4  absl::lts_20250814::log_internal::LogMessage::SendToLog()
    @     0x644dea1b0476  absl::lts_20250814::log_internal::LogMessage::Flush()
    @     0x644de97fae41  xla::WhileUtil::MakeInstructionsLiveIn()
    @     0x644de97e63b3  xla::DynamicDimensionInferenceVisitor::HandleWhile()
    @     0x644de9c67975  xla::HloInstruction::Visit<>()
    @     0x644de9c68d9b  xla::PostOrderDFS<>()
    @     0x644de9c66593  xla::HloInstruction::Accept<>()
    @     0x644de0337659  xla::HloComputation::Accept<>()
    @     0x644de97e5de2  xla::DynamicDimensionInferenceVisitor::Run()
    @     0x644de97e8404  xla::DynamicDimensionInference::AnalyzeDynamicDimensions()
    @     0x644de97e7ec9  xla::DynamicDimensionInference::Run()
    @     0x644de19652c7  xla::DynamicPadder::RunImpl()
    @     0x644de98c044f  xla::HloPassInterface::Run()
    @     0x644de8ad8ddc  xla::HloPassPipeline::RunHelper<>()
    @     0x644de8ad5da6  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x644de8ad56be  xla::HloPassPipeline::RunImpl()
    @     0x644de98c044f  xla::HloPassInterface::Run()
    @     0x644de11d65ff  xla::cpu::CpuCompiler::RunHloPassesThroughLayoutAssn()
    @     0x644de11d8ddc  xla::cpu::CpuCompiler::RunHloPasses()
    @     0x644de11d91f0  xla::cpu::CpuCompiler::RunHloPasses()
    @     0x644de1995a6d  xla::LLVMCompiler::Compile()
    @     0x644de11d3724  xla::cpu::CpuCompiler::Compile()
    @     0x644de036b29e  xla::Compiler::Compile()
    @     0x644de036acd0  xla::HloRunner::CreateExecutableWithBufferAssignment()
    @     0x644de03652a9  xla::HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment()
    @     0x644de0364dfc  xla::HloRunner::Execute()
    @     0x644de037167e  xla::HloRunnerInterface::Execute()
    @     0x644de0340697  xla::(anonymous namespace)::ExecuteWithRunner()
    @     0x644de033c78c  xla::(anonymous namespace)::RunAndCompareInternal()
    @     0x644de033a806  xla::RunAndCompare()
    @     0x644de033dad6  xla::RunAndCompare()
    @     0x644de033910b  main
    @     0x74801062a1ca  (unknown)
    @     0x74801062a28b  __libc_start_main
    @     0x644dea1f31fa  _start
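
Possible cause (my reading of the stack trace, not verified against the pass source): the original HLO wraps the loop state in a one-element tuple, so the while instruction has tuple shape, but the translated stablehlo.while carries the bounded tensor directly. When DynamicPadder analyzes the dynamic dimension inside the loop, DynamicDimensionInferenceVisitor::HandleWhile calls WhileUtil::MakeInstructionsLiveIn, which CHECK-fails because the round-tripped while shape is no longer a tuple. The two forms, copied from the modules above:

```
// HLO (runs): loop state is a 1-tuple, so while has tuple shape
while = (s32[<=4,2]{1,0}) while(input_tuple), condition=condition, body=body

// Round-tripped StableHLO (crashes): loop state is a bare bounded tensor
%2 = stablehlo.while(%iterArg = %1) : tensor<?x2xi32, #stablehlo.bounds<4, ?>>
```

If this reading is right, either hlo-translate should preserve the tuple wrapper for dynamic loop state, or the StableHLO-to-HLO import / MakeInstructionsLiveIn should tolerate non-tuple while shapes.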

Contact

  • Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com
