@@ -190,38 +190,6 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
190190 return outerDimsPerm;
191191}
192192
193- // / Returns a tuple for packed operand and indexing_map with the assumptions:
194- // / 1) The generic op is the producer of the pack op.
195- // / 2) The generic op has only one result.
196- // / If the operand is a scalar or packing dimensions are all irrelevant to the
197- // / operand, the operand and the updated indexing map will be returned.
198- // / Otherwise, it returns the packed operand and the updated indexing map. E.g.,
199- // /
200- // / #map0 = affine_map<(d0, d1) -> (d0, d1)>
201- // / #map1 = affine_map<(d0, d1) -> (d0)>
202- // / #map2 = affine_map<(d0, d1) -> (d1)>
203- // / %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
204- // / iterator_types = ["parallel", "parallel"]}
205- // / ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
206- // / outs(%init : tensor<?x?xf32>) {
207- // / ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
208- // / %4 = arith.addf %arg3, %arg4 : f32
209- // / linalg.yield %4 : f32
210- // / } -> tensor<?x?xf32>
211- // / %1 = linalg.pack %0
212- // / inner_dims_pos = [0, 1]
213- // / inner_tiles = [8, 2]
214- // / into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
215- // /
216- // / Taking the first input operand as an example, the inner tile size of d1 is
217- // / 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> ->
218- // / affine_map<(d1, d3)>` will be returned.
219- // /
220- // / %pack = linalg.pack %arg0
221- // / inner_dims_pos = [0]
222- // / inner_tiles = [8]
223- // / into %init : tensor<?xf32> -> tensor<?x8xf32>
224-
225193struct PackedOperandDetails {
226194 SmallVector<OpFoldResult> innerTileSizes;
227195 SmallVector<int64_t > innerDimsPos;
@@ -231,7 +199,7 @@ struct PackedOperandDetails {
231199
232200// / Helper function for getOrCreatePackedViewOfOperand that populates
233201// / the details of the packedOperand that needs to be formed and also
234- // returns if the packing would require padding.
202+ // / returns if the packing would require padding.
235203static bool getPackedOperandDetails (
236204 OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
237205 DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
@@ -323,23 +291,53 @@ static bool getPackedOperandDetails(
323291 currOperandDetails.outerDimsPerm = outerDimsPerm;
324292 packedOperandMap[opOperand] = currOperandDetails;
325293
326- if (requirePadding)
327- return true ;
328- return false ;
294+ return requirePadding;
329295}
330296
297+ // / Returns a tuple for packed operand and indexing_map with the assumptions:
298+ // / 1) The generic op is the producer of the pack op.
299+ // / 2) The generic op has only one result.
300+ // / If the operand is a scalar or packing dimensions are all irrelevant to the
301+ // / operand, the operand and the updated indexing map will be returned.
302+ // / Otherwise, it returns the packed operand and the updated indexing map. E.g.,
303+ // /
304+ // / #map0 = affine_map<(d0, d1) -> (d0, d1)>
305+ // / #map1 = affine_map<(d0, d1) -> (d0)>
306+ // / #map2 = affine_map<(d0, d1) -> (d1)>
307+ // / %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
308+ // / iterator_types = ["parallel", "parallel"]}
309+ // / ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
310+ // / outs(%init : tensor<?x?xf32>) {
311+ // / ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
312+ // / %4 = arith.addf %arg3, %arg4 : f32
313+ // / linalg.yield %4 : f32
314+ // / } -> tensor<?x?xf32>
315+ // / %1 = linalg.pack %0
316+ // / inner_dims_pos = [0, 1]
317+ // / inner_tiles = [8, 2]
318+ // / into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
319+ // /
320+ // / Taking the first input operand as an example, the inner tile size of d0 is
321+ // / 8. Thus, the below operation and
322+ // / `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
323+ // /
324+ // / %pack = linalg.pack %arg0
325+ // / inner_dims_pos = [0]
326+ // / inner_tiles = [8]
327+ // / into %init : tensor<?xf32> -> tensor<?x8xf32>
328+
331329static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand (
332330 OpBuilder &b, Location loc, OpOperand *opOperand,
333- DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap) {
331+ const DenseMap<OpOperand *, PackedOperandDetails> & packedOperandMap) {
334332 assert (packedOperandMap.contains (opOperand) &&
335333 " packed operand details expected to be populated" );
336- auto currOperandDetails = packedOperandMap[ opOperand] ;
334+ auto currOperandDetails = packedOperandMap. at ( opOperand) ;
337335 auto innerDimsPos = currOperandDetails.innerDimsPos ;
338336 auto outerDimsPerm = currOperandDetails.outerDimsPerm ;
339337 auto innerTileSizes = currOperandDetails.innerTileSizes ;
340- if (innerDimsPos.empty () && outerDimsPerm.empty ()) {
338+ if (innerDimsPos.empty () && outerDimsPerm.empty ())
341339 return std::make_tuple (opOperand->get (), currOperandDetails.indexingMap );
342- }
340+
343341 auto empty = linalg::PackOp::createDestinationTensor (
344342 b, loc, opOperand->get (), innerTileSizes, innerDimsPos, outerDimsPerm);
345343 auto poison = ub::PoisonOp::create (
@@ -375,9 +373,9 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
375373 requiresPadding |= getPackedOperandDetails (rewriter, packInfo, genericOp,
376374 inputOperand, packedOperandMap);
377375 }
378- if (requiresPadding && !poisonPaddingOk) {
376+ if (requiresPadding && !poisonPaddingOk)
379377 return failure ();
380- }
378+
381379 for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
382380 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand (
383381 rewriter, loc, inputOperand, packedOperandMap);
@@ -538,9 +536,9 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
538536 DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
539537 bool requiresPadding = getPackedOperandDetails (rewriter, *packInfo, genericOp,
540538 opOperand, packedOperandMap);
541- if (requiresPadding && !poisonPaddingOk) {
539+ if (requiresPadding && !poisonPaddingOk)
542540 return failure ();
543- }
541+
544542 auto [packedOutOperand, packedOutIndexingMap] =
545543 getOrCreatePackedViewOfOperand (rewriter, genericOp.getLoc (), opOperand,
546544 packedOperandMap);
@@ -1186,9 +1184,9 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11861184 bool requiresPadding =
11871185 getPackedOperandDetails (rewriter, *packInfo, genericOp,
11881186 genericOp.getDpsInitOperand (0 ), packedOperandMap);
1189- if (requiresPadding && !poisonPaddingOk) {
1187+ if (requiresPadding && !poisonPaddingOk)
11901188 return failure ();
1191- }
1189+
11921190 auto [packedOutOperand, packedOutIndexingMap] =
11931191 getOrCreatePackedViewOfOperand (rewriter, genericOp.getLoc (),
11941192 genericOp.getDpsInitOperand (0 ),
0 commit comments