Skip to content

Commit 9496f27

Browse files
[TOSA] Introduce arith.constant -> tosa.const normalization pass
Add a standalone pass that rewrites tensor-valued `arith.constant` ops into `tosa.const`, normalize the TOSA backend contract. Co-authored-by: Shubham <[email protected]> Signed-off-by: Vitalii Shutov <[email protected]> Change-Id: I4e71926107633007a71bd1fcc3311a5da6d38849
1 parent f1b5504 commit 9496f27

File tree

4 files changed

+221
-0
lines changed

4 files changed

+221
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
105105
}];
106106
}
107107

108+
def TosaArithConstantToTosaConstPass
109+
: Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> {
110+
let summary = "Convert tensor arith.constant operations into tosa.const";
111+
let description = [{
112+
Normalizes tensor-valued arith.constant operations into tosa.const so that
113+
subsequent TOSA passes operate on a consistent representation of constants.
114+
}];
115+
}
116+
108117
def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
109118
let summary = "Convert integer types to signless";
110119
let description = [{

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRTosaTransforms
22
TosaAttachTarget.cpp
3+
TosaArithConstantToConst.cpp
34
TosaConvertIntegerTypeToSignless.cpp
45
TosaDecomposeTransposeConv.cpp
56
TosaDecomposeDepthwise.cpp
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//===- TosaArithConstantToConst.cpp ---------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass that converts tensor-valued arith.constant ops
10+
// into tosa.const so that TOSA pipelines operate on a uniform constant form.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
16+
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
19+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
20+
#include "mlir/IR/BuiltinAttributes.h"
21+
#include "mlir/IR/BuiltinTypes.h"
22+
#include "mlir/IR/PatternMatch.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
25+
namespace mlir {
26+
namespace tosa {
27+
#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
28+
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
29+
} // namespace tosa
30+
} // namespace mlir
31+
32+
using namespace mlir;
33+
using namespace mlir::tosa;
34+
35+
namespace {
36+
37+
// NOTE: TOSA pipelines already lower their constants through shared Arith
38+
// folding passes, so tensor literals often come back as `arith.constant` even
39+
// after the IR is otherwise TOSA-only. Keep this normalization with the rest of
40+
// the TOSA transforms so any client can re-establish a canonical `tosa.const`
41+
// representation without needing a full Arith->TOSA conversion library.
42+
43+
/// Returns true when `elementType` is natively representable by tosa.const.
44+
static bool isSupportedElementType(Type elementType) {
45+
if (isa<FloatType>(elementType))
46+
return true;
47+
48+
if (auto intType = dyn_cast<IntegerType>(elementType))
49+
return intType.isSignless() || intType.isUnsigned();
50+
51+
if (isa<quant::QuantizedType>(elementType))
52+
return true;
53+
54+
if (isa<tosa::mxint8Type>(elementType))
55+
return true;
56+
57+
return false;
58+
}
59+
60+
class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> {
61+
public:
62+
using OpRewritePattern::OpRewritePattern;
63+
64+
LogicalResult matchAndRewrite(arith::ConstantOp constOp,
65+
PatternRewriter &rewriter) const override {
66+
// TOSA constant verification requires a ranked, statically shaped tensor.
67+
auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType());
68+
if (!resultType || !resultType.hasStaticShape())
69+
return failure();
70+
71+
if (!isSupportedElementType(resultType.getElementType()))
72+
return failure();
73+
74+
Attribute attr = constOp.getValueAttr();
75+
auto elementsAttr = dyn_cast<ElementsAttr>(attr);
76+
if (!elementsAttr)
77+
return failure();
78+
79+
auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType());
80+
if (!attrType || !attrType.hasStaticShape())
81+
return failure();
82+
if (attrType != resultType)
83+
return failure();
84+
85+
auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(),
86+
resultType, elementsAttr);
87+
rewriter.replaceOp(constOp, newConst.getResult());
88+
return success();
89+
}
90+
};
91+
92+
struct TosaArithConstantToTosaConstPass
93+
: public tosa::impl::TosaArithConstantToTosaConstPassBase<
94+
TosaArithConstantToTosaConstPass> {
95+
using Base::Base;
96+
97+
void getDependentDialects(DialectRegistry &registry) const override {
98+
registry.insert<arith::ArithDialect, tosa::TosaDialect>();
99+
}
100+
101+
void runOnOperation() override {
102+
auto *ctx = &getContext();
103+
RewritePatternSet patterns(ctx);
104+
patterns.add<ArithConstantToTosaConst>(ctx);
105+
106+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
107+
signalPassFailure();
108+
}
109+
};
110+
111+
} // namespace
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @rewrite_f32_tensor
4+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
5+
// CHECK: return %[[CST]]
6+
func.func @rewrite_f32_tensor() -> tensor<2xf32> {
7+
%c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
8+
return %c : tensor<2xf32>
9+
}
10+
11+
// -----
12+
13+
// CHECK-LABEL: func.func @rewrite_i32_tensor
14+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
15+
// CHECK: return %[[CST]]
16+
func.func @rewrite_i32_tensor() -> tensor<3xi32> {
17+
%c = arith.constant dense<[1, 0, -1]> : tensor<3xi32>
18+
return %c : tensor<3xi32>
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: func.func @rewrite_i1_tensor
24+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1>
25+
func.func @rewrite_i1_tensor() -> tensor<2xi1> {
26+
%c = arith.constant dense<[true, false]> : tensor<2xi1>
27+
return %c : tensor<2xi1>
28+
}
29+
30+
// -----
31+
32+
// CHECK-LABEL: func.func @rewrite_rank0_tensor
33+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32>
34+
func.func @rewrite_rank0_tensor() -> tensor<f32> {
35+
%c = arith.constant dense<1.234500e+00> : tensor<f32>
36+
return %c : tensor<f32>
37+
}
38+
39+
// -----
40+
41+
// CHECK-LABEL: func.func @preserve_scalar_i32
42+
// CHECK: %[[CST:.*]] = arith.constant 42 : i32
43+
func.func @preserve_scalar_i32() -> i32 {
44+
%c = arith.constant 42 : i32
45+
return %c : i32
46+
}
47+
48+
// -----
49+
50+
// CHECK-LABEL: func.func @preserve_index_tensor
51+
// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex>
52+
func.func @preserve_index_tensor() -> tensor<2xindex> {
53+
%c = arith.constant dense<[0, 1]> : tensor<2xindex>
54+
return %c : tensor<2xindex>
55+
}
56+
57+
// -----
58+
59+
// CHECK-LABEL: func.func @rewrite_resource_tensor
60+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32>
61+
func.func @rewrite_resource_tensor() -> tensor<4xf32> {
62+
%c = arith.constant dense_resource<"blob1"> : tensor<4xf32>
63+
return %c : tensor<4xf32>
64+
}
65+
66+
// -----
67+
68+
// CHECK-LABEL: func.func @rewrite_quant_tensor
69+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8>
70+
func.func @rewrite_quant_tensor() -> tensor<2xui8> {
71+
%c = arith.constant dense<[10, 20]> : tensor<2xui8>
72+
return %c : tensor<2xui8>
73+
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor
78+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>
79+
func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform<i8:f32, 0.5:0>> {
80+
%c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
81+
return %c : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
82+
}
83+
84+
// -----
85+
86+
// CHECK-LABEL: func.func @rewrite_fp8_tensor
87+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN>
88+
func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> {
89+
%c = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN>
90+
return %c : tensor<2xf8E4M3FN>
91+
}
92+
93+
// -----
94+
95+
// CHECK-LABEL: func.func @rewrite_mxint8_tensor
96+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8>
97+
func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> {
98+
%c = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>
99+
return %c : tensor<2x!tosa.mxint8>
100+
}

0 commit comments

Comments
 (0)