@@ -63,5 +63,83 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
6363 return const_op.getResult ();
6464}
6565
66+ // Templated function to create a constant op for given type and shape.
67+ // T: storage C type.
68+ // Default template creates a constant tensor in T.
69+ template <typename T>
70+ llvm::Optional<Value> getConstTensor (PatternRewriter &rewriter, Operation *op,
71+ ArrayRef<T> vec, ArrayRef<int64_t > shape) {
72+ uint64_t num_total_elements = 1 ;
73+ for (int64_t a : shape) {
74+ num_total_elements *= a;
75+ }
76+
77+ if (vec.size () != num_total_elements) {
78+ op->emitOpError (" getConstTensor(): number of elements mismatch." );
79+ return llvm::None;
80+ }
81+
82+ auto const_type =
83+ RankedTensorType::get (shape, rewriter.getIntegerType (sizeof (T) * 8 ));
84+ auto const_attr = DenseElementsAttr::get (const_type, vec);
85+
86+ auto const_op =
87+ rewriter.create <tosa::ConstOp>(op->getLoc (), const_type, const_attr);
88+ return const_op.getResult ();
89+ }
90+
91+ // Template specialization for APInt
92+ template <>
93+ llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
94+ Operation *op, ArrayRef<APInt> vec,
95+ ArrayRef<int64_t > shape) {
96+ uint64_t num_total_elements = 1 ;
97+ for (int64_t a : shape) {
98+ num_total_elements *= a;
99+ }
100+
101+ if (vec.size () != num_total_elements) {
102+ op->emitOpError (" getConstTensor(): number of elements mismatch." );
103+ return llvm::None;
104+ }
105+
106+ auto const_type = RankedTensorType::get (
107+ shape, rewriter.getIntegerType (vec[0 ].getBitWidth ()));
108+ auto const_attr = DenseElementsAttr::get (const_type, vec);
109+
110+ auto const_op =
111+ rewriter.create <tosa::ConstOp>(op->getLoc (), const_type, const_attr);
112+ return const_op.getResult ();
113+ }
114+
115+ // Template specialization for float
116+ template <>
117+ llvm::Optional<Value> getConstTensor<float >(PatternRewriter &rewriter,
118+ Operation *op, ArrayRef<float > vec,
119+ ArrayRef<int64_t > shape) {
120+ uint64_t num_total_elements = 1 ;
121+ for (int64_t a : shape) {
122+ num_total_elements *= a;
123+ }
124+
125+ if (vec.size () != num_total_elements) {
126+ op->emitOpError (" getConstTensor(): number of elements mismatch." );
127+ return llvm::None;
128+ }
129+
130+ auto const_type = RankedTensorType::get (shape, rewriter.getF32Type ());
131+ auto const_attr = DenseElementsAttr::get (const_type, vec);
132+
133+ auto const_op =
134+ rewriter.create <tosa::ConstOp>(op->getLoc (), const_type, const_attr);
135+ return const_op.getResult ();
136+ }
137+
138+ // Template instantiation
139+ template llvm::Optional<Value> getConstTensor<int32_t >(PatternRewriter &,
140+ Operation *,
141+ ArrayRef<int32_t > vec,
142+ ArrayRef<int64_t > shape);
143+
66144} // namespace tosa
67145} // namespace mlir
0 commit comments