@@ -36,10 +36,10 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
   return input.getType();
 
   // The input type must be cast into a tensor with the same rank and all static
-  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
-  // op that converts a dynamically shaped tensor into a 0D tensor. While such
-  // construct is not incorrect on its own, bufferization cannot properly handle
-  // it at the moment, so we avoid it.
+  // dimensions set to 1. This prevents the generation of a
+  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
+  // 0D tensor. While such construct is not incorrect on its own, bufferization
+  // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
   return input.getType().clone(shape);
 }
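Note on the hunk above: for a dynamically shaped input, the function keeps the rank but clones the type with every dimension set to a static 1. A minimal standalone sketch of that shape computation in plain C++, with no MLIR dependencies (inferCastShape is an illustrative name, not part of the patch):

#include <cstdint>
#include <vector>

// Mirrors the SmallVector<int64_t> shape(rank, 1) line above: same rank, all
// dimensions static 1, so e.g. a tensor<?x?xf32> input would be cast to
// tensor<1x1xf32> and a later tensor.collapse_shape never has to fold a
// dynamically shaped tensor into a 0D one.
std::vector<int64_t> inferCastShape(const std::vector<int64_t> &inputShape) {
  return std::vector<int64_t>(inputShape.size(), 1);
}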
@@ -58,29 +58,31 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
 
   // Compute result shape
-  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
-    // If this is not a placeholder, do not change it
-    if (size >= 0)
-      return size;
-
-    // If we do not know the total size of the tensor, keep this dimension
-    // dynamic in the result shape.
-    if (!inputIsStatic)
-      return ShapedType::kDynamic;
-
-    // Calculate the product of all elements in 'newShape' except for the -1
-    // placeholder, which we discard by negating the result.
-    int64_t totalSizeNoPlaceholder = -std::accumulate(
-        newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
-
-    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
-    if (totalSizeNoPlaceholder == 0)
-      return 0;
-
-    // Resolve the placeholder as the quotient between the total tensor size and
-    // the product of all other sizes.
-    return totalSize / totalSizeNoPlaceholder;
-  });
+  auto resultShape =
+      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+        // If this is not a placeholder, do not change it
+        if (size >= 0)
+          return size;
+
+        // If we do not know the total size of the tensor, keep this dimension
+        // dynamic in the result shape.
+        if (!inputIsStatic)
+          return ShapedType::kDynamic;
+
+        // Calculate the product of all elements in 'newShape' except for the -1
+        // placeholder, which we discard by negating the result.
+        int64_t totalSizeNoPlaceholder = -std::accumulate(
+            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
+
+        // If there is a 0 component in 'newShape', resolve the placeholder as
+        // 0.
+        if (totalSizeNoPlaceholder == 0)
+          return 0;
+
+        // Resolve the placeholder as the quotient between the total tensor size
+        // and the product of all other sizes.
+        return totalSize / totalSizeNoPlaceholder;
+      });
 
   bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
 
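The reflowed lambda above resolves the -1 placeholder that a TOSA reshape may carry in its new shape. A self-contained plain-C++ sketch of the same arithmetic (resolveNewShape and kDynamic are illustrative stand-ins, and it assumes at most one -1 in newShape, as the negation trick requires):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

std::vector<int64_t> resolveNewShape(const std::vector<int64_t> &newShape,
                                     int64_t totalSize) { // -1 if unknown
  std::vector<int64_t> result;
  for (int64_t size : newShape) {
    if (size >= 0) { // explicit dimension: keep it
      result.push_back(size);
    } else if (totalSize < 0) { // dynamic input: placeholder stays dynamic
      result.push_back(kDynamic);
    } else {
      // The product over newShape includes the single -1 placeholder, so
      // negating it gives the product of all the other sizes.
      int64_t others = -std::accumulate(newShape.begin(), newShape.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
      result.push_back(others == 0 ? 0 : totalSize / others);
    }
  }
  return result;
}

// resolveNewShape({2, -1, 4}, 24) yields {2, 3, 4};
// resolveNewShape({2, -1, 4}, -1) yields {2, kDynamic, 4}.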
@@ -108,7 +110,8 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
   if (lhsShape.empty() || rhsShape.empty())
     return lhsType.clone(ArrayRef<int64_t>{});
 
-  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
+  if (ShapedType::isDynamicShape(lhsShape) ||
+      ShapedType::isDynamicShape(rhsShape))
     return lhsType.clone({ShapedType::kDynamic});
 
   SmallVector<int64_t> intermediateShape;
@@ -150,14 +153,16 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
 }
 
 SmallVector<ReassociationExprs>
-createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
+                                  Type dstType) {
   auto srcShape = cast<TensorType>(srcType).getShape();
   auto dstShape = cast<TensorType>(dstType).getShape();
 
   if (srcShape.empty() || dstShape.empty())
     return {};
 
-  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
+  if (ShapedType::isDynamicShape(srcShape) ||
+      ShapedType::isDynamicShape(dstShape)) {
     assert(dstShape.size() == 1);
     SmallVector<AffineExpr, 2> exprs;
     for (auto i : llvm::seq<int64_t>(srcShape.size()))
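The dynamic-shape branch visible in this hunk collapses every source dimension into the single dimension of a rank-1 destination (hence the assert). A standalone sketch of the grouping it builds, with plain index vectors standing in for AffineExpr (collapseAllIntoOne is a hypothetical name):

#include <cstdint>
#include <numeric>
#include <vector>

using Group = std::vector<int64_t>;

// One reassociation group covering every source dimension, e.g.
// collapseAllIntoOne(3) yields {{0, 1, 2}}, i.e. tensor<?x?x?xf32> collapsed
// to tensor<?xf32>.
std::vector<Group> collapseAllIntoOne(int64_t srcRank) {
  Group all(static_cast<size_t>(srcRank));
  std::iota(all.begin(), all.end(), int64_t{0});
  return {all};
}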
@@ -249,14 +254,16 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
     auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
 
     // Cast input if needed
-    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
+    auto castInput =
+        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
 
     // Emit collapse-expand pair
     auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
     auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
 
     // Cast to final result type if needed
-    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+    auto result =
+        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
     rewriter.replaceOp(reshape, result);
     return success();
   }
@@ -350,29 +357,12 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     }
 
     ShapedType inputTy = cast<ShapedType>(input.getType());
-    Type elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
     // Setup the default constantAttr.
 
-    Value padConstant;
-
-    if (padOp.getPadConst()) {
-      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padOp.getPadConst(), ValueRange({}));
-    } else {
-      TypedAttr constantAttr;
-      if (isa<FloatType>(elementTy)) {
-        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
-        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
-        int64_t value = padOp.getInputZpAttr().getInt();
-        constantAttr = rewriter.getIntegerAttr(elementTy, value);
-      }
-      if (constantAttr)
-        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
-    }
+    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
+        loc, padOp.getPadConst(), ValueRange({}));
 
     if (!padConstant) {
       return rewriter.notifyMatchFailure(
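The hunk above drops the fallback that synthesized a default pad value (floating-point zero, integer zero, or the input zero point) and instead always extracts the scalar from the op's pad_const operand; presumably pad_const is required by this point in the TOSA dialect, though the diff does not show that change. A plain-C++ analogue of the extraction itself (Rank0Tensor and extractScalar are illustrative names):

#include <cassert>
#include <vector>

// Models a rank-0 tensor: exactly one element, addressed with an empty index
// list, which is what ValueRange({}) expresses in the replacement lines above.
struct Rank0Tensor {
  std::vector<float> data;
};

float extractScalar(const Rank0Tensor &t) {
  assert(t.data.size() == 1 && "a rank-0 tensor holds exactly one element");
  return t.data.front();
}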