Failure after converting HLO to StableHLO: INTERNAL: during context [Unknown]: Mismatched tuple structure in original_value for instruction %while.1 = s32[...]{0} while(%constant.9), condition=%region_1.4, body=%region_0.3, origin={({"w"})}... #39836
Status: Open
Labels: bug (Something isn't working)
Description
Bug Report:
This bug is triggered by the HLO-to-StableHLO conversion: the HLO module runs successfully via run_hlo_module --input_format=hlo, and the conversion via hlo-translate --hlo-to-mlir also succeeds, but running the translated StableHLO via run_hlo_module --input_format=stablehlo fails with:
INTERNAL: during context [Unknown]: Mismatched tuple structure in original_value for instruction %while.1 = s32[...]{0} while(%constant.9), condition=%region_1.4, body=%region_0.3, origin={({"w"})}, metadata={source_file="-" source_line=12 source_end_line=12 source_column=10 source_end_column=10}. Leaf indices in shape and original_value do not match.
Environment
- XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
- CPU: Intel(R) Core(TM) i9-14900HX
- GPU: NVIDIA GeForce RTX 4060 Laptop GPU
- CUDA Driver: 580.126.09
run_hlo_module (HLO) — Success
HLO:
HloModule Test, entry_computation_layout={()->(s32[1]{0})}
Body {
  param = (s32[1]{0}) parameter(0)
  a = s32[1]{0} constant({0})
  ROOT tuple = (s32[1]{0}) tuple(a)
}
Cond {
  param.1 = (s32[1]{0}) parameter(0)
  ROOT cond = pred[] constant(true)
}
ENTRY Loop {
  a.1 = s32[1]{0} constant({0})
  init = (s32[1]{0}) tuple(a.1), origin={({"a"})}
  ROOT while = (s32[1]{0}) while(init), condition=Cond, body=Body, origin={({"w"})}
}
Execution Command:
run_hlo_module \
--platform=CPU \
--reference_platform= \
--input_format=hlo \
  Test_59043631_30.hlo
Output:
** Running Test_59043631_30.hlo**
Running HLO module with runner Host...
... compiled and ran in 0.00753302s.
Skipping reference runner
run_hlo_module (StableHLO) — FAIL
Translation Command:
hlo-translate \
--hlo-to-mlir \
Test_59043631_30.hlo \
-o \
  Test_59043631_30.mlir
IR After Translation:
module @Test attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @Body(%arg0: tensor<1xi32>) -> tensor<1xi32> {
    %c = stablehlo.constant dense<0> : tensor<1xi32>
    return %c : tensor<1xi32>
  }
  func.func private @Cond(%arg0: tensor<1xi32>) -> tensor<i1> {
    %c = stablehlo.constant dense<true> : tensor<i1>
    return %c : tensor<i1>
  }
  func.func @main() -> tensor<1xi32> {
    %c = stablehlo.constant dense<0> : tensor<1xi32>
    %0 = stablehlo.while(%iterArg = %c) : tensor<1xi32> attributes {mhlo.original_value = "{({\22w\22})}"}
    cond {
      %c_0 = stablehlo.constant dense<true> : tensor<i1>
      stablehlo.return %c_0 : tensor<i1>
    } do {
      %c_0 = stablehlo.constant dense<0> : tensor<1xi32>
      stablehlo.return %c_0 : tensor<1xi32>
    }
    return %0 : tensor<1xi32>
  }
}
Execution Command:
run_hlo_module \
--platform=CPU \
--reference_platform= \
--input_format=stablehlo \
  Test_59043631_30.mlir
Output:
** Running Test_59043631_30.mlir**
INTERNAL: during context [Unknown]: Mismatched tuple structure in original_value for instruction %while.1 = s32[1]{0} while(%constant.9), condition=%region_1.4, body=%region_0.3, origin={({"w"})}, metadata={source_file="-" source_line=12 source_end_line=12 source_column=10 source_end_column=10}. Leaf indices in shape and original_value do not match.
In shape only: {{}}
In original_value only: {{0}}
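To make the last two lines of the diagnostic easier to read, here is a minimal Python sketch of the comparison they describe. It is not XLA code: it assumes a hypothetical model in which a shape or an original value is a nested Python tuple, and a leaf index is the path of tuple positions leading to a leaf (the role a ShapeIndex plays in XLA). Under that model, a bare array such as s32[1] has one leaf at the empty index {}, while a one-element tuple such as ({"w"}) has one leaf at {0}. That matches what the translated module shows: the stablehlo.while carries a single tensor<1xi32> instead of a one-element tuple, yet its mhlo.original_value still encodes {({"w"})}.

```python
from typing import List, Tuple, Union

# Hypothetical model: a leaf is any non-tuple value (e.g. "s32[1]" or "w");
# a Python tuple of trees stands for a (possibly nested) tuple shape.
Tree = Union[str, Tuple["Tree", ...]]
Index = Tuple[int, ...]


def leaf_indices(tree: Tree, prefix: Index = ()) -> List[Index]:
    """Return the index path of every leaf, in depth-first order."""
    if not isinstance(tree, tuple):
        return [prefix]
    out: List[Index] = []
    for i, child in enumerate(tree):
        out.extend(leaf_indices(child, prefix + (i,)))
    return out


def compare(shape: Tree, original_value: Tree) -> Tuple[List[Index], List[Index]]:
    """Return (indices only in shape, indices only in original_value)."""
    s = set(leaf_indices(shape))
    o = set(leaf_indices(original_value))
    return sorted(s - o), sorted(o - s)


# After the round trip: the while result is a bare array, but the
# original_value still describes a one-element tuple ({"w"}).
only_shape, only_origin = compare("s32[1]", ("w",))
print("In shape only:", only_shape)
print("In original_value only:", only_origin)

# In the source HLO the while result is the tuple (s32[1]{0}), so the
# same comparison finds no mismatch.
print(compare(("s32[1]",), ("w",)))
```

With this model, the first two prints report [()] and [(0,)], mirroring {{}} and {{0}} in the error, and the source-HLO case reports no differences.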
Contact
- Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com