36 changes: 36 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,56 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
  return std::make_tuple(packedOperand, indexingMap);
}

static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
  Block *block = genericOp.getBody();
  int numBlockArgs = block->getNumArguments();
  int numDpsOuts = genericOp.getNumDpsInits();
  for (int i = 0; i < numDpsOuts; ++i) {
    int matchingInitArgIndex = numBlockArgs - numDpsOuts + i;
    if (!block->getArgument(matchingInitArgIndex).use_empty())
      return false;
  }
  return true;
}

/// Pack a genericOp and return it.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<Value> inputOperandsFromUnpackedSource;
  SmallVector<AffineMap> indexingMaps;

  // Note: canUnpackPackFold also requires that the generic body has no
  // gather semantics. Since such scenarios have already been rejected by
  // both BubbleUpPackOpThroughGenericOp and PushDownUnPackOpThroughGenericOp,
  // we can safely assume the unpack/pack pair folds as long as the inits
  // are not used.
  bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
Member:
It can be moved down to where it's used.

@hanhanW (Contributor), May 3, 2025:
This is a local function, and we should not put comments about how it is used into the implementation comments; they could easily become outdated.

I think what you're looking for is isElementwise(genericOp) && !hasGatherSemantics(genericOp), and you can put that statement directly in the if-condition. Like @pashu123 mentioned, it is not used until l.345.

@jerryyin (Member, author), May 5, 2025:
Thanks for the reminder. I'll adopt && !hasGatherSemantics(genericOp) instead of relying on the comment to stay up to date. I was attempting to save a redundant check if possible :-p

I don't think isElementwise(genericOp) is necessary here. Per a past offline discussion with @Max191, I think we are good as long as the outs are unused. For example, the IR below isn't elementwise, but we'd be okay to push the unpack down or bubble the pack up.

  ^bb0(%in_0: f32, %in_1: f32, %out: f32):
    %21 = arith.addf %in_0, %in_1 : f32
    linalg.yield %21 : f32
  }

@hanhanW (Contributor), May 5, 2025:
> Thanks for the reminder. I'll adopt && !hasGatherSemantics(genericOp) instead of relying on the comment to stay up to date. I was attempting to save a redundant check if possible :-p

In this case, maybe we can put hasGatherSemantics in an assertion and document the requirement on the method. It is not trivial to support ops that have gather semantics, but I think it is doable. We need either an assertion or the actual check, so future contributors won't miss updating the method, IMO.

> I don't think isElementwise(genericOp) is necessary here. Per a past offline discussion with @Max191, I think we are good as long as the outs are unused. For example, the IR below isn't elementwise, but we'd be okay to push the unpack down or bubble the pack up.

After reviewing the isElementwise implementation and definition, I realized I had the wrong understanding of it: I thought it requires that the outs are unused, but the implementation says no. I can see the reason, but I'm still not fully convinced. Anyway, my initial idea is to handle only the cases you understand, and my assumption is that you only want to support elementwise operations where none of the outs are used. I'm being conservative here because people have different uses for the linalg dialect. They could have a creative generic op that does not use the outs but accidentally meets the requirement, and that would open up a can of worms. Being conservative prevents the expectations for the pass from diverging between users and authors.

> ^bb0(%in_0: f32, %in_1: f32, %out: f32):
>   %21 = arith.addf %in_0, %in_1 : f32
>   linalg.yield %21 : f32
> }

I don't follow the example; the computation body looks like an elementwise operation to me. Did you miss indexing maps or something else? My understanding is that it sums up in_0 and in_1 and yields the result. It is a generic-op form of arith.addf in0, in1 : tensor<...xf32>, IIUC.

EDIT: I did not put the concrete action item here, sorry about that. I'd be more comfortable if you had both conditions (i.e., isElementwise() and isGenericOutsNotUsed()) in the check.

@jerryyin (Member, author):
> We need either an assertion or the actual check

Sounds good, I will just go with an actual check, as I currently do.

> my assumption is that you only want to support elementwise operations when all the outs are not used

Yes, exactly. I misunderstood what isElementwise() does, probably mistaking it for isUnaryOp(), hence the irrelevant counter-example in my last response. On second review, I realized it is based on the element-mappable traits, which all arith ops carry. I'll make sure to add isElementwise() to the condition check. Thanks for raising the concern.
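
For illustration, a minimal sketch of the distinction being discussed (the bodies are made up, not taken from the patch): in the first body the out block argument is dead, so isGenericOutsNotUsed returns true and the pack/unpack pair around the generic may fold; in the second body the out feeds an accumulation, so the init's contents are live and the fold must be rejected.

  // outs unused: %out never appears in the body.
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %in : f32
    linalg.yield %0 : f32

  // outs used: %out participates in the computation, so the (possibly
  // padded) init contents matter.
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %out : f32
    linalg.yield %0 : f32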

Contributor:
After reviewing the current requirements for this propagation, I think we actually don't need any more checks than what is already there. I need to think a little more to fully convince myself, but I'll explain my current thinking.

Some current requirements for the pattern are (mostly from getPackingInfoFromOperand):

  • No gather semantics.
  • No packed dimension of the iteration space may appear as part of a composite AffineExpr in any indexing-map result; every packed dimension must appear as a simple AffineDimExpr in each indexing-map result that contains it.
  • All packed dimensions of the iteration space must be parallel.

I think these conditions are enough to ensure the padding value does not matter in the generic op, because they mean the set of padded dimensions is fully parallel and independent of the op's other dimensions. Any padding elements of the generic op will only be used within the padded part of the iteration space, and the result tensor will then be unpacked, which removes the part of the tensor that resulted from the padded part of the iteration space. It does not matter what happens to the padding value in the body of the generic op, because the elements ultimately written there are removed by the unpack.
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);

    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
    } else {
      inputOperandsFromUnpackedSource.push_back(packedOperand);
    }

    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  // If the pack and unpack ops can be folded:
  // 1) use the unpack op's source as the operand, folding away the
  //    unpack -> pack sequence;
  // 2) replace the init tensor of the generic op with a new tensor.empty
  //    as the generic out.
  if (canUnpackPackFold) {
    inputOperands = inputOperandsFromUnpackedSource;
    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
      dest = destPack.getDest();
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
68 changes: 33 additions & 35 deletions mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
-// CHECK-SAME: outs(%[[DEST]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
}

// CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: ins(%[[ARG0]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []

@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME: into %[[PACK_EMPTY]]
// CHECK: %[[POOL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK-SAME: outs(%[[INIT]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
  %empty = tensor.empty() : tensor<32x64xf32>
  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = arith.addf %in, %in : f32
    linalg.yield %2 : f32
  } -> tensor<32x64xf32>
  %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
  %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
  return %pack : tensor<8x8x4x8xf32>
}

// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
// CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32>