
Failure after converting HLO to StableHLO: UNIMPLEMENTED: Computation compare.impl.1 is called in both a parallel (eg, kMap) and sequential (eg, kCall) context #39832

@housrepository


Bug Report:

This bug is triggered by the HLO-to-StableHLO conversion. The original HLO module runs successfully with run_hlo_module --input_format=hlo, and hlo-translate --hlo-to-mlir converts it to StableHLO without error. However, running the translated module with run_hlo_module --input_format=stablehlo fails with:

UNIMPLEMENTED: Computation compare.impl.1 is called in both a parallel (eg, kMap) and sequential (eg, kCall) context

Environment

  • XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
  • CPU: Intel(R) Core(TM) i9-14900HX
  • GPU: NVIDIA GeForce RTX 4060 Laptop GPU
  • CUDA Driver: 580.126.09

run_hlo_module (HLO) — Success

HLO:

HloModule CallInSortInCall, entry_computation_layout={(f32[4096]{0}, f32[4096]{0})->f32[4096]{0}}

compare.impl {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

compare {
  p.0.lhs.1 = f32[] parameter(0)
  p.0.rhs.1 = f32[] parameter(1)
  ROOT lt.1 = pred[] call(p.0.lhs.1, p.0.rhs.1), to_apply=compare.impl
}

called_computation {
  x = f32[4096]{0} parameter(0)
  ROOT sort = f32[4096]{0} sort(x), dimensions={0}, to_apply=compare
}

ENTRY main {
  a = f32[4096]{0} parameter(0)
  call0 = f32[4096]{0} call(a), to_apply=called_computation
  b = f32[4096]{0} parameter(1)
  call1 = f32[4096]{0} call(b), to_apply=called_computation
  ROOT multiply = f32[4096]{0} multiply(call0, call1)
}

Execution Command:

run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=hlo \
  CallInSortInCall_a7179098_9.hlo

Output:


 ** Running CallInSortInCall_a7179098_9.hlo** 
Running HLO module with runner Host...
... compiled and ran in 0.0325732s.
Skipping reference runner

run_hlo_module (StableHLO) — FAIL

Translation Command:

hlo-translate \
  --hlo-to-mlir \
  CallInSortInCall_a7179098_9.hlo \
  -o \
  CallInSortInCall_a7179098_9.mlir

IR After Translation:

module @CallInSortInCall attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @compare.impl(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
    %0 = stablehlo.compare  LT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
  func.func private @compare(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
    %0 = call @compare.impl(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
  func.func private @called_computation(%arg0: tensor<4096xf32>) -> tensor<4096xf32> {
    %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = false}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      %1 = func.call @compare.impl(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    }) : (tensor<4096xf32>) -> tensor<4096xf32>
    return %0 : tensor<4096xf32>
  }
  func.func @main(%arg0: tensor<4096xf32>, %arg1: tensor<4096xf32>) -> tensor<4096xf32> {
    %0 = call @called_computation(%arg0) : (tensor<4096xf32>) -> tensor<4096xf32>
    %1 = call @called_computation(%arg1) : (tensor<4096xf32>) -> tensor<4096xf32>
    %2 = stablehlo.multiply %0, %1 : tensor<4096xf32>
    return %2 : tensor<4096xf32>
  }
}
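Compared with the original HLO, the translated IR has a different shape. The sort comparator region in @called_computation now calls @compare.impl directly, and nothing references @compare any more. The minimal Python sketch below lists the call sites in the IR above and marks whether each one sits inside a stablehlo.sort comparator region. It is an illustrative helper and not part of XLA; the names call_edges, FUNC_RE and CALL_RE are hypothetical. The regular expressions only handle this small, pretty-printed module. A robust tool would walk the IR with the MLIR Python bindings instead.

```python
import re

# Translated IR from this report (module wrapper omitted).
IR = """
  func.func private @compare.impl(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
    %0 = stablehlo.compare  LT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
  func.func private @compare(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
    %0 = call @compare.impl(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
  func.func private @called_computation(%arg0: tensor<4096xf32>) -> tensor<4096xf32> {
    %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = false}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      %1 = func.call @compare.impl(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    }) : (tensor<4096xf32>) -> tensor<4096xf32>
    return %0 : tensor<4096xf32>
  }
  func.func @main(%arg0: tensor<4096xf32>, %arg1: tensor<4096xf32>) -> tensor<4096xf32> {
    %0 = call @called_computation(%arg0) : (tensor<4096xf32>) -> tensor<4096xf32>
    %1 = call @called_computation(%arg1) : (tensor<4096xf32>) -> tensor<4096xf32>
    %2 = stablehlo.multiply %0, %1 : tensor<4096xf32>
    return %2 : tensor<4096xf32>
  }
"""

# Hypothetical patterns that only assume the pretty-printed form shown above.
FUNC_RE = re.compile(r"func\.func (?:private )?@([\w.]+)\(")
CALL_RE = re.compile(r"\bcall @([\w.]+)\(")  # matches both "call" and "func.call"


def call_edges(ir: str):
    """Return (caller, callee, inside_sort_region) tuples, one per call site."""
    edges = []
    current = None
    in_sort = False
    for line in ir.splitlines():
        m = FUNC_RE.search(line)
        if m:
            current, in_sort = m.group(1), False  # entering a new function
            continue
        if '"stablehlo.sort"' in line:
            in_sort = True  # the comparator region opens on this line
        for callee in CALL_RE.findall(line):
            edges.append((current, callee, in_sort))
        if in_sort and line.strip().startswith("})"):
            in_sort = False  # the comparator region closes here


    return edges


edges = call_edges(IR)
for caller, callee, in_sort in edges:
    kind = "sort comparator (parallel)" if in_sort else "plain call (sequential)"
    print(f"@{caller} -> @{callee}: {kind}")

defined = set(FUNC_RE.findall(IR))
called = {callee for _, callee, _ in edges}
print("never called:", sorted(defined - called - {"main"}))
# Prints:
# @compare -> @compare.impl: plain call (sequential)
# @called_computation -> @compare.impl: sort comparator (parallel)
# @main -> @called_computation: plain call (sequential)
# @main -> @called_computation: plain call (sequential)
# never called: ['compare']
```

In the original HLO, compare.impl is reachable only through compare, which is the sort comparator. In the translated IR it is reached both from the sort region and from a plain call in @compare, and @compare itself has no callers.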

Execution Command:

run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=stablehlo \
  CallInSortInCall_a7179098_9.mlir

Output:


 ** Running CallInSortInCall_a7179098_9.mlir** 
UNIMPLEMENTED: Computation compare.impl.1 is called in both a parallel (eg, kMap) and sequential (eg, kCall) context
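This report has not confirmed the cause, but one possible explanation follows from the error message itself. XLA distinguishes parallel call sites (e.g. kMap, and here kSort) from sequential ones (e.g. kCall). Assume that a computation reached through a sequential call inherits its caller's context, while one reached through a parallel call is always parallel. If the unreferenced @compare is exported back to HLO as a computation with no callers, it would presumably start as a sequential root, and compare.impl would then be reached in both contexts.

The sketch below is a simplified, hypothetical Python model of that propagation, not XLA's actual call-graph implementation. The context names, the comparator name sort_comparator, and the rule that caller-less computations start out sequential are all assumptions.

```python
from collections import defaultdict, deque

SEQUENTIAL, PARALLEL = "sequential", "parallel"


def call_contexts(edges):
    """Simplified, hypothetical model of call-context propagation.

    edges: (caller, callee, site_kind) tuples, site_kind in {SEQUENTIAL, PARALLEL}.
    Returns {computation: set of contexts it is reached in}.
    Assumption: computations with no callers are roots in a SEQUENTIAL context.
    """
    callees = defaultdict(list)
    nodes, has_caller = set(), set()
    for caller, callee, kind in edges:
        callees[caller].append((callee, kind))
        nodes.update((caller, callee))
        has_caller.add(callee)

    contexts = defaultdict(set)
    work = deque()
    for root in sorted(nodes - has_caller):
        contexts[root].add(SEQUENTIAL)
        work.append((root, SEQUENTIAL))
    while work:
        node, ctx = work.popleft()
        for callee, kind in callees[node]:
            # A parallel call site forces PARALLEL; a sequential one
            # passes the caller's current context through unchanged.
            new_ctx = PARALLEL if kind == PARALLEL else ctx
            if new_ctx not in contexts[callee]:
                contexts[callee].add(new_ctx)
                work.append((callee, new_ctx))
    return dict(contexts)


# Original HLO: compare.impl is only reached through the sort comparator.
original = [
    ("main", "called_computation", SEQUENTIAL),   # call0
    ("main", "called_computation", SEQUENTIAL),   # call1
    ("called_computation", "compare", PARALLEL),  # sort(...), to_apply=compare
    ("compare", "compare.impl", SEQUENTIAL),      # call inside the comparator
]

# Hypothetical round-tripped module: the comparator region calls compare.impl
# directly, and the unreferenced @compare survives as a caller-less computation.
roundtrip = [
    ("main", "called_computation", SEQUENTIAL),
    ("main", "called_computation", SEQUENTIAL),
    ("called_computation", "sort_comparator", PARALLEL),  # hypothetical name
    ("sort_comparator", "compare.impl", SEQUENTIAL),
    ("compare", "compare.impl", SEQUENTIAL),
]

for name, graph in (("original", original), ("round-tripped", roundtrip)):
    print(name, "compare.impl:", sorted(call_contexts(graph)["compare.impl"]))
# Prints:
# original compare.impl: ['parallel']
# round-tripped compare.impl: ['parallel', 'sequential']
```

Under this model, the original graph never reaches compare.impl sequentially, which is consistent with the HLO run succeeding. Whether XLA actually exports the orphaned @compare in this way would need to be checked against the StableHLO-to-HLO export path.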

Contact

  • Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com

Labels: bug (Something isn't working)
