Skip to content

Commit 0eef0fe

Browse files
majiddadashicopybara-github
authored andcommitted
Add legalization for mhlo.case to tfl.if.
Converts mhlo.case ops with two branches to tfl.if. The tfl.if predicate is true if the index is not 0, with the then_region corresponding to branch 1 and the else_region to branch 0. LiteRT-Converter-PiperOrigin-RevId: 828672857
1 parent ea98859 commit 0eef0fe

File tree

6 files changed

+180
-3
lines changed

6 files changed

+180
-3
lines changed

tflite/converter/stablehlo/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ cc_library(
544544
":passes_inc_gen",
545545
":unfold_splat_constant_pass",
546546
"//tflite/converter:tensorflow_lite",
547+
"//tflite/converter/stablehlo/transforms/legalize_hlo_conversions:case",
547548
"//tflite/converter/stablehlo/transforms/legalize_hlo_conversions:conv",
548549
"//tflite/converter/stablehlo/transforms/legalize_hlo_conversions:custom_call",
549550
"//tflite/converter/stablehlo/transforms/legalize_hlo_conversions:dot_general",

tflite/converter/stablehlo/tests/tfl_legalize_hlo.mlir

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3721,14 +3721,43 @@ func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor<
37213721
// CHECK: %2 = "tfl.broadcast_to"(%1, %arg1) : (tensor<?x3000x1x1xf32>, tensor<4xi32>) -> tensor<?x3000x2x4xf32>
37223722

37233723

3724+
// -----
3725+
3726+
//===----------------------------------------------------------------------===//
3727+
// mhlo.case
3728+
//===----------------------------------------------------------------------===//
3729+
3730+
// CHECK-LABEL: case_func
3731+
func.func @case_func(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (tensor<i32>) {
3732+
%0 = "mhlo.case"(%arg0) ({
3733+
%2 = mhlo.add %arg1, %arg2 : tensor<i32>
3734+
"mhlo.return"(%2) : (tensor<i32>) -> ()
3735+
}, {
3736+
%2 = mhlo.multiply %arg1, %arg1 : tensor<i32>
3737+
"mhlo.return"(%2) : (tensor<i32>) -> ()
3738+
}) : (tensor<i32>) -> tensor<i32>
3739+
func.return %0: tensor<i32>
3740+
}
3741+
3742+
// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<i32>
3743+
// CHECK: %[[PRED:.*]] = tfl.not_equal(%arg0, %[[CST]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
3744+
// CHECK: %[[IF:.*]] = "tfl.if"(%[[PRED]]) ({
3745+
// CHECK: %[[MUL:.*]] = tfl.mul %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<i32>
3746+
// CHECK: "tfl.yield"(%[[MUL]]) : (tensor<i32>) -> ()
3747+
// CHECK: }, {
3748+
// CHECK: %[[ADD:.*]] = tfl.add %arg1, %arg2 {fused_activation_function = "NONE"} : tensor<i32>
3749+
// CHECK: "tfl.yield"(%[[ADD]]) : (tensor<i32>) -> ()
3750+
// CHECK: }) : (tensor<i1>) -> tensor<i32>
3751+
// CHECK: return %[[IF]] : tensor<i32>
3752+
37243753
// -----
37253754

37263755
//===----------------------------------------------------------------------===//
37273756
// mhlo.if
37283757
//===----------------------------------------------------------------------===//
37293758

3730-
// CHECK-LABEL: if
3731-
func.func @if(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (tensor<i32>) {
3759+
// CHECK-LABEL: if_label
3760+
func.func @if_label(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (tensor<i32>) {
37323761
%0 = mhlo.add %arg1, %arg2 : tensor<i32>
37333762
%1 = "mhlo.if"(%arg0) ({
37343763
"mhlo.return"(%0) : (tensor<i32>) -> ()

tflite/converter/stablehlo/transforms/legalize_hlo_conversions/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,21 @@ cc_library(
320320
],
321321
)
322322

323+
cc_library(
324+
name = "case",
325+
srcs = ["case.cc"],
326+
hdrs = ["case.h"],
327+
deps = [
328+
":util",
329+
"//tflite/converter:tensorflow_lite",
330+
"@llvm-project//mlir:ArithDialect",
331+
"@llvm-project//mlir:IR",
332+
"@llvm-project//mlir:Support",
333+
"@llvm-project//mlir:TransformUtils",
334+
"@local_xla//xla/mlir_hlo",
335+
],
336+
)
337+
323338
cc_library(
324339
name = "if",
325340
srcs = ["if.cc"],
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tflite/converter/stablehlo/transforms/legalize_hlo_conversions/case.h"
17+
18+
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
19+
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
20+
#include "mlir/IR/PatternMatch.h" // from @llvm-project
21+
#include "mlir/Support/LLVM.h" // from @llvm-project
22+
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
23+
#include "tflite/converter/ir/tfl_ops.h"
24+
#include "tflite/converter/stablehlo/transforms/legalize_hlo_conversions/util.h"
25+
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
26+
27+
namespace mlir::odml {
28+
namespace {
29+
30+
// Legalizes mhlo.case op to tfl.if op.
31+
// This pattern only supports mhlo.case ops with exactly two branches.
32+
class LegalizeCaseOp : public OpConversionPattern<mhlo::CaseOp> {
33+
public:
34+
using OpConversionPattern<mhlo::CaseOp>::OpConversionPattern;
35+
36+
LogicalResult matchAndRewrite(
37+
mhlo::CaseOp case_op, OpAdaptor adaptor,
38+
ConversionPatternRewriter& rewriter) const final {
39+
// mhlo.case can have N branches, but tfl.if only supports two.
40+
if (case_op.getBranches().size() != 2) {
41+
return rewriter.notifyMatchFailure(
42+
case_op, "can only convert mhlo.case with 2 branches");
43+
}
44+
45+
// `mhlo.case` takes an index, `tfl.if` takes a boolean predicate.
46+
// For a 2-branch `mhlo.case` (branch 0 and branch 1), we need to map
47+
// the index to a boolean.
48+
// According to the mhlo.case spec, an out-of-bounds index defaults to the
49+
// index of the last branch, which is 1 in this case.
50+
// So, index 0 maps to branch 0, and any other index (1, or out of bounds)
51+
// maps to branch 1.
52+
// This can be expressed as a predicate `index != 0` for branch 1.
53+
54+
auto loc = case_op->getLoc();
55+
auto index = case_op.getIndex();
56+
auto index_type = mlir::cast<ShapedType>(index.getType());
57+
58+
// Create a constant tensor of the same shape as the index, filled with
59+
// zeros.
60+
auto const_zero = arith::ConstantOp::create(
61+
rewriter, loc, rewriter.getZeroAttr(index_type));
62+
63+
// Create the predicate `index != 0`.
64+
auto pred_type = index_type.clone(rewriter.getI1Type());
65+
auto pred = mhlo::CompareOp::create(
66+
rewriter, loc, pred_type, index, const_zero,
67+
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
68+
mhlo::ComparisonDirection::NE),
69+
mhlo::ComparisonTypeAttr{}); // Default comparison type is fine for
70+
// integers.
71+
72+
// Create the tfl.if op.
73+
auto tfl_if =
74+
TFL::IfOp::create(rewriter, loc, case_op.getResultTypes(), pred);
75+
76+
// Branch 1 of mhlo.case becomes the `then_region` of tfl.if.
77+
tfl_if.getThenRegion().takeBody(case_op.getBranches()[1]);
78+
ReplaceTerminatorWithYield(tfl_if.getThenRegion(), rewriter);
79+
80+
// Branch 0 of mhlo.case becomes the `else_region` of tfl.if.
81+
tfl_if.getElseRegion().takeBody(case_op.getBranches()[0]);
82+
ReplaceTerminatorWithYield(tfl_if.getElseRegion(), rewriter);
83+
84+
rewriter.replaceOp(case_op, tfl_if.getResults());
85+
return success();
86+
}
87+
};
88+
89+
} // namespace
90+
91+
void PopulateCasePatterns(MLIRContext* context, RewritePatternSet& patterns,
92+
ConversionTarget& target) {
93+
patterns.add<LegalizeCaseOp>(context);
94+
// Mark mhlo.case as dynamically legal: it's legal if it does NOT have
95+
// exactly 2 branches, as those are the ones we want to convert.
96+
target.addDynamicallyLegalOp<mhlo::CaseOp>(
97+
[](mhlo::CaseOp op) { return op.getBranches().size() != 2; });
98+
}
99+
100+
} // namespace mlir::odml
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_
17+
#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_
18+
19+
#include "mlir/IR/PatternMatch.h" // from @llvm-project
20+
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
21+
22+
namespace mlir {
23+
namespace odml {
24+
25+
void PopulateCasePatterns(MLIRContext* context, RewritePatternSet& patterns,
26+
ConversionTarget& target);
27+
28+
} // namespace odml
29+
} // namespace mlir
30+
31+
#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_

tflite/converter/stablehlo/transforms/tflite_legalize_hlo.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ limitations under the License.
3838
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
3939
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
4040
#include "tflite/converter/ir/tfl_ops.h" // IWYU pragma: keep
41+
#include "tflite/converter/stablehlo/transforms/legalize_hlo_conversions/case.h"
4142
#include "tflite/converter/stablehlo/transforms/legalize_hlo_conversions/conv.h" // IWYU pragma: keep
4243
#include "tflite/converter/stablehlo/transforms/legalize_hlo_conversions/custom_call.h"
4344
#include "tflite/converter/stablehlo/transforms/legalize_hlo_conversions/dot_general.h" // IWYU pragma: keep
@@ -479,6 +480,7 @@ void LegalizeHloToTfLitePass::runOnOperation() {
479480
PopulateWhilePatterns(context, patterns, target);
480481
PopulateGetDimensionSizePatterns(context, patterns, target);
481482
PopulateIfPatterns(context, patterns, target);
483+
PopulateCasePatterns(context, patterns, target);
482484
PopulateLegalizeFftPatterns(context, patterns, target);
483485
PopulateCustomCallPatterns(context, patterns, target);
484486

@@ -493,7 +495,6 @@ void LegalizeHloToTfLitePass::runOnOperation() {
493495

494496
} // namespace
495497

496-
497498
// Creates an instance of the pass.
498499
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeHloToTfLitePass() {
499500
return std::make_unique<LegalizeHloToTfLitePass>();

0 commit comments

Comments
 (0)