Skip to content

Commit f725b3d

Browse files
authored
Refactor common tosa functions into tosaUtils (#2035)
rocMLIR has copies of many tosa related functions across a few source files. Mainly MIGraphXToTosa, TosaToRock and rocmlir-gen.cpp. This PR removes those copies and put them into header file to be shared across all source files.
1 parent f4abcea commit f725b3d

File tree

8 files changed

+858
-491
lines changed

8 files changed

+858
-491
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===- tosa Utility Functions -===//
2+
//
3+
// Copyright 2025 Advanced Micro Devices.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
// ============================================================
17+
#ifndef MLIR_DIALECT_ROCK_TOSA_UTILITY_H
18+
#define MLIR_DIALECT_ROCK_TOSA_UTILITY_H
19+
20+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
21+
#include "mlir/IR/Builders.h"
22+
#include "mlir/IR/Value.h"
23+
#include "mlir/Interfaces/InferTypeOpInterface.h"
24+
25+
namespace mlir {
26+
namespace rock {
27+
bool isSpecificValueAttribute(Attribute value, double target);
28+
bool isConstantValue(Value v, double target);
29+
bool isConstantZero(Value v);
30+
bool isConstantOne(Value v);
31+
bool isConstNegInf(Value v);
32+
bool isConstRange(Value v);
33+
34+
namespace tosa {
35+
template <typename TosaOp, typename... Args>
36+
TosaOp createOpAndInfer(OpBuilder &rewriter, Location loc, Type elemType,
37+
Args &&...args) {
38+
auto op =
39+
TosaOp::create(rewriter, loc, UnrankedTensorType::get(elemType), args...);
40+
InferShapedTypeOpInterface shapeInterface =
41+
cast<InferShapedTypeOpInterface>(op.getOperation());
42+
SmallVector<ShapedTypeComponents> returnShape;
43+
LogicalResult shapeInferenceStatus = shapeInterface.inferReturnTypeComponents(
44+
op.getContext(), op.getLoc(), op->getOperands(), op->getAttrDictionary(),
45+
op->getPropertiesStorage(), op->getRegions(), returnShape);
46+
assert(shapeInferenceStatus.succeeded());
47+
Type newOutTy = RankedTensorType::get({returnShape[0].getDims()}, elemType);
48+
auto result = op->getResult(0);
49+
result.setType(newOutTy);
50+
return op;
51+
}
52+
53+
Value getOneTensor(OpBuilder &builder, Location loc, RankedTensorType type);
54+
55+
Type getAccType(OpBuilder &builder, Type inputType);
56+
57+
Value getZeroTensor(OpBuilder &builder, Location loc, RankedTensorType type);
58+
59+
mlir::tosa::TransposeOp getTransposeOp(OpBuilder &builder, Location loc,
60+
Value input,
61+
ArrayRef<int32_t> permutation);
62+
63+
mlir::tosa::MulOp getMulOp(OpBuilder &builder, Location loc, Value input1,
64+
Value input2, Type elemType);
65+
} // namespace tosa
66+
} // namespace rock
67+
} // namespace mlir
68+
69+
#endif // MLIR_DIALECT_ROCK_TOSA_UTILITY_H

0 commit comments

Comments
 (0)