Skip to content

Commit 1c65d0e

Browse files
[TOSA] Introduce arith.constant -> tosa.const normalization pass
Add a standalone pass that rewrites tensor-valued `arith.constant` ops into `tosa.const`, normalize the TOSA backend contract. Co-authored-by: Shubham <[email protected]> Signed-off-by: Vitalii Shutov <[email protected]> Change-Id: I4e71926107633007a71bd1fcc3311a5da6d38849
1 parent f1b5504 commit 1c65d0e

File tree

4 files changed

+211
-0
lines changed

4 files changed

+211
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
105105
}];
106106
}
107107

108+
def TosaArithConstantToTosaConstPass
109+
: Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> {
110+
let summary = "Convert tensor arith.constant operations into tosa.const";
111+
let description = [{
112+
Normalizes tensor-valued arith.constant operations into tosa.const so that
113+
subsequent TOSA passes operate on a consistent representation of constants.
114+
}];
115+
}
116+
108117
def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
109118
let summary = "Convert integer types to signless";
110119
let description = [{

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRTosaTransforms
22
TosaAttachTarget.cpp
3+
TosaArithConstantToConst.cpp
34
TosaConvertIntegerTypeToSignless.cpp
45
TosaDecomposeTransposeConv.cpp
56
TosaDecomposeDepthwise.cpp
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
//===- TosaArithConstantToConst.cpp ---------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass that converts tensor-valued arith.constant ops
10+
// into tosa.const so that TOSA pipelines operate on a uniform constant form.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
16+
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
19+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
20+
#include "mlir/IR/BuiltinAttributes.h"
21+
#include "mlir/IR/BuiltinTypes.h"
22+
#include "mlir/IR/PatternMatch.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
25+
namespace mlir {
26+
namespace tosa {
27+
#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
28+
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
29+
} // namespace tosa
30+
} // namespace mlir
31+
32+
using namespace mlir;
33+
using namespace mlir::tosa;
34+
35+
namespace {
36+
37+
/// Returns true when `elementType` is natively representable by tosa.const.
38+
static bool isSupportedElementType(Type elementType) {
39+
if (isa<FloatType>(elementType))
40+
return true;
41+
42+
if (auto intType = dyn_cast<IntegerType>(elementType))
43+
return intType.isSignless() || intType.isUnsigned();
44+
45+
if (isa<quant::QuantizedType>(elementType))
46+
return true;
47+
48+
return false;
49+
}
50+
51+
class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> {
52+
public:
53+
using OpRewritePattern::OpRewritePattern;
54+
55+
LogicalResult matchAndRewrite(arith::ConstantOp constOp,
56+
PatternRewriter &rewriter) const override {
57+
// TOSA constant verification requires a ranked, statically shaped tensor.
58+
auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType());
59+
if (!resultType || !resultType.hasStaticShape())
60+
return failure();
61+
62+
if (!isSupportedElementType(resultType.getElementType()))
63+
return failure();
64+
65+
Attribute attr = constOp.getValueAttr();
66+
auto elementsAttr = dyn_cast<ElementsAttr>(attr);
67+
if (!elementsAttr)
68+
return failure();
69+
70+
auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType());
71+
if (!attrType || !attrType.hasStaticShape())
72+
return failure();
73+
74+
if (attrType != resultType) {
75+
// Allow reshape when the payload can be reinterpreted without altering
76+
// the number of elements or element type. Dense resource attributes
77+
// cannot be reshaped losslessly, so bail out in that case.
78+
if (!isa<DenseElementsAttr>(elementsAttr))
79+
return failure();
80+
81+
if (attrType.getElementType() != resultType.getElementType())
82+
return failure();
83+
84+
auto denseAttr = cast<DenseElementsAttr>(elementsAttr);
85+
if (denseAttr.getNumElements() != resultType.getNumElements())
86+
return failure();
87+
88+
elementsAttr = denseAttr.reshape(resultType);
89+
}
90+
91+
auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(),
92+
resultType, elementsAttr);
93+
rewriter.replaceOp(constOp, newConst.getResult());
94+
return success();
95+
}
96+
};
97+
98+
struct TosaArithConstantToTosaConstPass
99+
: public tosa::impl::TosaArithConstantToTosaConstPassBase<
100+
TosaArithConstantToTosaConstPass> {
101+
using Base::Base;
102+
103+
void getDependentDialects(DialectRegistry &registry) const override {
104+
registry.insert<tosa::TosaDialect>();
105+
}
106+
107+
void runOnOperation() override {
108+
auto *ctx = &getContext();
109+
RewritePatternSet patterns(ctx);
110+
patterns.add<ArithConstantToTosaConst>(ctx);
111+
112+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
113+
signalPassFailure();
114+
}
115+
};
116+
117+
} // namespace
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @rewrite_f32_tensor
4+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
5+
// CHECK: return %[[CST]]
6+
func.func @rewrite_f32_tensor() -> tensor<2xf32> {
7+
%c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
8+
return %c : tensor<2xf32>
9+
}
10+
11+
// -----
12+
13+
// CHECK-LABEL: func.func @rewrite_i32_tensor
14+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
15+
// CHECK: return %[[CST]]
16+
func.func @rewrite_i32_tensor() -> tensor<3xi32> {
17+
%c = arith.constant dense<[1, 0, -1]> : tensor<3xi32>
18+
return %c : tensor<3xi32>
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: func.func @rewrite_i1_tensor
24+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1>
25+
func.func @rewrite_i1_tensor() -> tensor<2xi1> {
26+
%c = arith.constant dense<[true, false]> : tensor<2xi1>
27+
return %c : tensor<2xi1>
28+
}
29+
30+
// -----
31+
32+
// CHECK-LABEL: func.func @rewrite_rank0_tensor
33+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32>
34+
func.func @rewrite_rank0_tensor() -> tensor<f32> {
35+
%c = arith.constant dense<1.234500e+00> : tensor<f32>
36+
return %c : tensor<f32>
37+
}
38+
39+
// -----
40+
41+
// CHECK-LABEL: func.func @preserve_scalar_i32
42+
// CHECK: %[[CST:.*]] = arith.constant 42 : i32
43+
func.func @preserve_scalar_i32() -> i32 {
44+
%c = arith.constant 42 : i32
45+
return %c : i32
46+
}
47+
48+
// -----
49+
50+
// CHECK-LABEL: func.func @preserve_index_tensor
51+
// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex>
52+
func.func @preserve_index_tensor() -> tensor<2xindex> {
53+
%c = arith.constant dense<[0, 1]> : tensor<2xindex>
54+
return %c : tensor<2xindex>
55+
}
56+
57+
// -----
58+
59+
// CHECK-LABEL: func.func @rewrite_resource_tensor
60+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32>
61+
func.func @rewrite_resource_tensor() -> tensor<4xf32> {
62+
%c = arith.constant dense_resource<"blob1"> : tensor<4xf32>
63+
return %c : tensor<4xf32>
64+
}
65+
66+
// -----
67+
68+
// CHECK-LABEL: func.func @rewrite_quant_tensor
69+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8>
70+
func.func @rewrite_quant_tensor() -> tensor<2xui8> {
71+
%c = arith.constant dense<[10, 20]> : tensor<2xui8>
72+
return %c : tensor<2xui8>
73+
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor
78+
// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>
79+
func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform<i8:f32, 0.5:0>> {
80+
%c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
81+
return %c : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
82+
}
83+
84+
// -----

0 commit comments

Comments
 (0)