Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 44 additions & 46 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,38 +190,6 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
return outerDimsPerm;
}

/// Returns a tuple for packed operand and indexing_map with the assumptions:
/// 1) The generic op is the producer of the pack op.
/// 2) The generic op has only one result.
/// If the operand is a scalar or packing dimensions are all irrelevant to the
/// operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
/// #map1 = affine_map<(d0, d1) -> (d0)>
/// #map2 = affine_map<(d0, d1) -> (d1)>
/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
/// iterator_types = ["parallel", "parallel"]}
/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
/// outs(%init : tensor<?x?xf32>) {
/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
/// %4 = arith.addf %arg3, %arg4 : f32
/// linalg.yield %4 : f32
/// } -> tensor<?x?xf32>
/// %1 = linalg.pack %0
/// inner_dims_pos = [0, 1]
/// inner_tiles = [8, 2]
/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and
/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
///
/// %pack = linalg.pack %arg0
/// inner_dims_pos = [0]
/// inner_tiles = [8]
/// into %init : tensor<?xf32> -> tensor<?x8xf32>

struct PackedOperandDetails {
SmallVector<OpFoldResult> innerTileSizes;
SmallVector<int64_t> innerDimsPos;
Expand All @@ -231,7 +199,7 @@ struct PackedOperandDetails {

/// Helper function for getOrCreatePackedViewOfOperand that populates
/// the details of the packedOperand that needs to be formed and also
// returns if the packing would require padding.
/// returns if the packing would require padding.
static bool getPackedOperandDetails(
OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
Expand Down Expand Up @@ -323,23 +291,53 @@ static bool getPackedOperandDetails(
currOperandDetails.outerDimsPerm = outerDimsPerm;
packedOperandMap[opOperand] = currOperandDetails;

if (requirePadding)
return true;
return false;
return requirePadding;
}

/// Returns a tuple for packed operand and indexing_map with the assumptions:
/// 1) The generic op is the producer of the pack op.
/// 2) The generic op has only one result.
/// If the operand is a scalar or packing dimensions are all irrelevant to the
/// operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
/// #map1 = affine_map<(d0, d1) -> (d0)>
/// #map2 = affine_map<(d0, d1) -> (d1)>
/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
/// iterator_types = ["parallel", "parallel"]}
/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
/// outs(%init : tensor<?x?xf32>) {
/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
/// %4 = arith.addf %arg3, %arg4 : f32
/// linalg.yield %4 : f32
/// } -> tensor<?x?xf32>
/// %1 = linalg.pack %0
/// inner_dims_pos = [0, 1]
/// inner_tiles = [8, 2]
/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and
/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
///
/// %pack = linalg.pack %arg0
/// inner_dims_pos = [0]
/// inner_tiles = [8]
/// into %init : tensor<?xf32> -> tensor<?x8xf32>

static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
OpBuilder &b, Location loc, OpOperand *opOperand,
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap) {
const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
assert(packedOperandMap.contains(opOperand) &&
"packed operand details expected to be populated");
auto currOperandDetails = packedOperandMap[opOperand];
auto currOperandDetails = packedOperandMap.at(opOperand);
auto innerDimsPos = currOperandDetails.innerDimsPos;
auto outerDimsPerm = currOperandDetails.outerDimsPerm;
auto innerTileSizes = currOperandDetails.innerTileSizes;
if (innerDimsPos.empty() && outerDimsPerm.empty()) {
if (innerDimsPos.empty() && outerDimsPerm.empty())
return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
}

auto empty = linalg::PackOp::createDestinationTensor(
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
auto poison = ub::PoisonOp::create(
Expand Down Expand Up @@ -375,9 +373,9 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
inputOperand, packedOperandMap);
}
if (requiresPadding && !poisonPaddingOk) {
if (requiresPadding && !poisonPaddingOk)
return failure();
}

for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, inputOperand, packedOperandMap);
Expand Down Expand Up @@ -538,9 +536,9 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
opOperand, packedOperandMap);
if (requiresPadding && !poisonPaddingOk) {
if (requiresPadding && !poisonPaddingOk)
return failure();
}

auto [packedOutOperand, packedOutIndexingMap] =
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
packedOperandMap);
Expand Down Expand Up @@ -1186,9 +1184,9 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
bool requiresPadding =
getPackedOperandDetails(rewriter, *packInfo, genericOp,
genericOp.getDpsInitOperand(0), packedOperandMap);
if (requiresPadding && !poisonPaddingOk) {
if (requiresPadding && !poisonPaddingOk)
return failure();
}

auto [packedOutOperand, packedOutIndexingMap] =
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
genericOp.getDpsInitOperand(0),
Expand Down
Loading