Skip to content

Commit edf4a0e

Browse files
authored
[tosa] Add more common utility functions (#525)
- Common code as TF repository, being moved to MLIR core. - Will support further legalizations to be published. Signed-off-by: Suraj Sudhir <[email protected]>
1 parent 5ded7d0 commit edf4a0e

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
3737
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
3838
float val);
3939

40+
// Templated function to create a constant op for given type and shape.
41+
// T: storage C type.
42+
// Default template creates a constant tensor in T.
43+
// To create INT48 TOSA constant, need to pass in llvm::APInt instead.
44+
template <typename T>
45+
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
46+
ArrayRef<T> vec, ArrayRef<int64_t> shape);
47+
4048
// Creates a TOSA operation and performs shape inference on the individual
4149
// op. This allows shape inference during the framework to TOSA lowering.
4250
template <typename TosaOp, typename... Args>

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,83 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
6363
return const_op.getResult();
6464
}
6565

66+
// Templated function to create a constant op for given type and shape.
67+
// T: storage C type.
68+
// Default template creates a constant tensor in T.
69+
template <typename T>
70+
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
71+
ArrayRef<T> vec, ArrayRef<int64_t> shape) {
72+
uint64_t num_total_elements = 1;
73+
for (int64_t a : shape) {
74+
num_total_elements *= a;
75+
}
76+
77+
if (vec.size() != num_total_elements) {
78+
op->emitOpError("getConstTensor(): number of elements mismatch.");
79+
return llvm::None;
80+
}
81+
82+
auto const_type =
83+
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
84+
auto const_attr = DenseElementsAttr::get(const_type, vec);
85+
86+
auto const_op =
87+
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
88+
return const_op.getResult();
89+
}
90+
91+
// Template specialization for APInt
92+
template <>
93+
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
94+
Operation *op, ArrayRef<APInt> vec,
95+
ArrayRef<int64_t> shape) {
96+
uint64_t num_total_elements = 1;
97+
for (int64_t a : shape) {
98+
num_total_elements *= a;
99+
}
100+
101+
if (vec.size() != num_total_elements) {
102+
op->emitOpError("getConstTensor(): number of elements mismatch.");
103+
return llvm::None;
104+
}
105+
106+
auto const_type = RankedTensorType::get(
107+
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
108+
auto const_attr = DenseElementsAttr::get(const_type, vec);
109+
110+
auto const_op =
111+
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
112+
return const_op.getResult();
113+
}
114+
115+
// Template specialization for float
116+
template <>
117+
llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
118+
Operation *op, ArrayRef<float> vec,
119+
ArrayRef<int64_t> shape) {
120+
uint64_t num_total_elements = 1;
121+
for (int64_t a : shape) {
122+
num_total_elements *= a;
123+
}
124+
125+
if (vec.size() != num_total_elements) {
126+
op->emitOpError("getConstTensor(): number of elements mismatch.");
127+
return llvm::None;
128+
}
129+
130+
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
131+
auto const_attr = DenseElementsAttr::get(const_type, vec);
132+
133+
auto const_op =
134+
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
135+
return const_op.getResult();
136+
}
137+
138+
// Template instantiation
139+
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
140+
Operation *,
141+
ArrayRef<int32_t> vec,
142+
ArrayRef<int64_t> shape);
143+
66144
} // namespace tosa
67145
} // namespace mlir

0 commit comments

Comments
 (0)