24 changes: 8 additions & 16 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -33,22 +33,14 @@ namespace linalg {
//===----------------------------------------------------------------------===//
// Utilities for inferring various semantics properties of Linalg ops.
//===----------------------------------------------------------------------===//
/// Shell function to compute the Destination Permutation of PackOp
/// This function uses the helper function `computePackUnPackPerm` to get
/// the permutation vector. Only major difference between UnPack and Pack is
/// that packOp uses destination rank whereas unpack Uses source rank.
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp);

/// Shell function to compute the Source Permutation of unPackOp.
/// This function, like the getPackInverseDestPerm uses the helper function
/// computePackUnPackPerm` to get the permutation vector.
/// Only major difference between UnPack and Pack is that packOp uses
/// destination rank whereas unpack Uses source rank.
SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp);

/// Shell function to compute the Source rank permutation for unpackOp
/// Unpack requires some packing metadata data information, so created
/// another function where this value is passed by reference.

Review comment (Contributor): picky nit: I'd add a blank line before this comment. It looks easier to me; it is a new chunk of the declaration.

/// Compute inverse permutation for the destination tensor (i.e. in the packed
/// domain).
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
PackingMetadata &metadata);

/// Compute inverse permutation for the source tensor (i.e. in the packed
/// domain).
SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp,
PackingMetadata &metadata);

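For intuition, the permutation returned by these helpers is simply the inverse of the transpose that pack/unpack performs in the packed domain. Below is a minimal, self-contained C++ sketch (standard library only; `invertPerm` is a hypothetical stand-in that mirrors, but does not call, `mlir::invertPermutationVector`), using the permutation pair implied by the pack-vectorization example later in this patch:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Invert a permutation: if `perm` maps result dim i to source dim perm[i],
// the inverse maps source dim perm[i] back to i.
// Hypothetical stand-in mirroring mlir::invertPermutationVector.
std::vector<int64_t> invertPerm(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  // Transpose permutation from the pack-vectorization doc-comment example.
  std::vector<int64_t> perm = {0, 1, 3, 4, 2};
  std::vector<int64_t> inv = invertPerm(perm);
  assert((inv == std::vector<int64_t>{0, 1, 4, 2, 3}));
  // Inverting twice round-trips to the original permutation.
  assert(invertPerm(inv) == perm);
  return 0;
}
```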
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -232,10 +232,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,

// 2. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied.
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
PackingMetadata packingMetadata;
SmallVector<int64_t> packedToStripMinedShapePerm =
getPackInverseDestPerm(packOp);
getPackInverseDestPerm(packOp, packingMetadata);

// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
218 changes: 114 additions & 104 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}

/// Given a linalg::PackOp, return the `dest` shape before any packing
/// permutations.
static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
ArrayRef<int64_t> destShape) {
return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
}

/// Determines whether a mask for xfer_write is trivially "all true"
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
@@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return mlir::vector::maskOperation(builder, write, maskForWrite);
}

/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
/// padding value and (3) input vector sizes into:
///
/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
///
/// As in the following example:
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
///
/// This pack would be vectorized to:
///
/// %load = vector.mask %mask {
/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
/// {in_bounds = [true, true, true]} :
/// tensor<32x7x16xf32>, vector<32x8x16xf32>
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
/// to vector<32x4x2x1x16xf32>
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %transpose,
/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
/// {in_bounds = [true, true, true, true, true]}
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
///
/// If the (3) input vector sizes are not provided, the vector sizes are
/// determined by the result tensor shape and the `in_bounds`
/// attribute is used instead of masking to mark out-of-bounds accesses.
///
/// NOTE: The input vector sizes specify the dimensions corresponding to the
/// outer dimensions of the output tensor. The remaining dimensions are
/// computed based on, e.g., the static inner tiles.
/// Supporting dynamic inner tiles will require the user to specify the
/// missing vector sizes. This is left as a TODO.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);

Location loc = packOp.getLoc();
std::optional<Value> padValue = packOp.getPaddingValue()
? std::optional(packOp.getPaddingValue())
: std::nullopt;

// If the input vector sizes are not provided, then the vector sizes are
// determined by the result tensor shape. In case the vector sizes aren't
// provided, we update the inBounds attribute instead of masking.
bool useInBoundsInsteadOfMasking = false;
if (inputVectorSizes.empty()) {
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
useInBoundsInsteadOfMasking = true;
}

// Create masked TransferReadOp.
SmallVector<int64_t> inputShape(inputVectorSizes);
auto innerTiles = packOp.getStaticInnerTiles();
auto innerDimsPos = packOp.getInnerDimsPos();
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty())
applyPermutationToVector(inputShape,
invertPermutationVector(outerDimsPerm));
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), inputShape, padValue,
useInBoundsInsteadOfMasking,
/*inputScalableVecSizes=*/{});

// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
destShape.append(innerTiles.begin(), innerTiles.end());
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
auto shapeCastOp =
vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);

// Create TransposeOp.
auto destPermutation =
invertPermutationVector(getPackInverseDestPerm(packOp));
auto transposeOp = vector::TransposeOp::create(
rewriter, loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, transposeOp.getResult(), packOp.getDest());
newResults.push_back(write->getResult(0));
return success();
}

/// Given the re-associations, "collapses" the input Vector type
///
/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1801,120 @@ static VectorType getCollapsedVecType(VectorType type,
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
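For intuition about the collapse rule above, here is a hedged standalone sketch (plain C++; `collapseShape` is a hypothetical stand-in for the static-size part of `getCollapsedVecType`, ignoring the scalable-flags bookkeeping): each reassociation group folds into a single result dim whose size is the product of the group's sizes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Collapse a static shape according to reassociation groups, in the spirit
// of getCollapsedVecType / CollapseShapeOp::inferCollapsedType (simplified
// illustration: static sizes only, no scalable-dims handling).
std::vector<int64_t>
collapseShape(const std::vector<int64_t> &shape,
              const std::vector<std::vector<int64_t>> &reassociation) {
  std::vector<int64_t> collapsed;
  collapsed.reserve(reassociation.size());
  for (const auto &group : reassociation) {
    int64_t size = 1;
    for (int64_t dim : group)
      size *= shape[dim];
    collapsed.push_back(size);
  }
  return collapsed;
}

int main() {
  // Un-permuted packed shape from the example below: 32x4x2x1x16.
  // Groups [[0], [1, 2], [3, 4]] collapse it back to the source shape.
  std::vector<int64_t> shape = {32, 4, 2, 1, 16};
  std::vector<std::vector<int64_t>> groups = {{0}, {1, 2}, {3, 4}};
  assert((collapseShape(shape, groups) == std::vector<int64_t>{32, 8, 16}));
  return 0;
}
```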

/// Vectorize `linalg.pack` as:
/// * xfer_read -> shape_cast -> transpose -> xfer_write
///
/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
/// sizes for the xfer_write operation). This is sufficient to infer the other
/// vector sizes required here.
///
/// If the vector sizes are not provided:
/// * the vector sizes are determined from the destination tensor static shape.
/// * the inBounds attribute is used instead of masking.
///
/// EXAMPLE (no vector sizes):
/// ```
/// %pack = tensor.pack %src
/// inner_dims_pos = [2, 1]
/// inner_tiles = [16, 2]
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
/// ```
/// is vectorized as:
/// ```
/// %read = vector.transfer_read %src
/// : tensor<32x8x16xf32>, vector<32x8x16xf32>
/// %sc = vector.shape_cast %read
/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %tr into %dst
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
/// ```
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
if (!inputVectorSizes.empty()) {
assert(inputVectorSizes.size() == packOp.getDestRank() &&
"Invalid number of input vector sizes!");
}

// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);

Location loc = packOp.getLoc();
std::optional<Value> padValue = packOp.getPaddingValue()
? std::optional(packOp.getPaddingValue())
: std::nullopt;

SmallVector<int64_t> destShape =
SmallVector<int64_t>(packOp.getDestType().getShape());

// This is just a convenience alias to clearly communicate that the input
// vector sizes determine the _write_ sizes.
ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;

// In the absence of input-vector-sizes, use the _static_ destination tensor
// shape.
// In addition, use the inBounds attribute instead of masking.
bool useInBoundsInsteadOfMasking = false;
if (writeVectorSizes.empty()) {
if (ShapedType::isDynamicShape(destShape))
return rewriter.notifyMatchFailure(packOp,
"unable to infer vector sizes");

writeVectorSizes = destShape;
useInBoundsInsteadOfMasking = true;
}

// Compute vector type for the _read_ operation. The required dims are
// determined based on the _write_ vector sizes. This is done in two
// steps:
// 1) Invert the permutation/transposition that's part of the Pack
// operation.
// 2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
PackingMetadata packMetadata;
auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);

SmallVector<int64_t> writeVecSizesUnpermuted(writeVectorSizes);
applyPermutationToVector(writeVecSizesUnpermuted, destInvPermutation);

Review comment (Contributor): nit: this is a little hard to follow at first glance IMO

VectorType readVecType = getCollapsedVecType(
VectorType::get(writeVecSizesUnpermuted,
packOp.getType().getElementType()),
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
rewriter.getContext(), packMetadata.reassociations)));

// Create masked TransferReadOp.
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
useInBoundsInsteadOfMasking,
/*inputScalableVecSizes=*/{});

// Create ShapeCastOp.
auto expandedVecType = VectorType::get(writeVecSizesUnpermuted,
packOp.getType().getElementType());
auto shapeCastOp =
vector::ShapeCastOp::create(rewriter, loc, expandedVecType, maskedRead);

// Create TransposeOp.
auto destPermutation = invertPermutationVector(destInvPermutation);
auto transposeOp = vector::TransposeOp::create(
rewriter, loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, transposeOp.getResult(), packOp.getDest());
newResults.push_back(write->getResult(0));
return success();
}
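The two-step read-shape computation above can be hard to follow at first glance, so here is a hedged numeric trace for the doc-comment example (standard C++ only; the inverse permutation [0, 1, 4, 2, 3] and the reassociation groups [[0], [1, 2], [3, 4]] are the values that example implies, not values queried from MLIR):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Write vector sizes = static dest shape: tensor<32x4x1x16x2xf32>.
  std::vector<int64_t> write = {32, 4, 1, 16, 2};

  // Step 1: un-permute with the inverse destination permutation
  // (assumed: [0, 1, 4, 2, 3], i.e. the inverse of the example's
  // transpose permutation [0, 1, 3, 4, 2]).
  std::vector<int64_t> invPerm = {0, 1, 4, 2, 3};
  std::vector<int64_t> unpermuted(write.size());
  for (std::size_t i = 0; i < write.size(); ++i)
    unpermuted[i] = write[invPerm[i]];
  assert((unpermuted == std::vector<int64_t>{32, 4, 2, 1, 16}));

  // Step 2: collapse the tiled dims using the packing-metadata
  // reassociation (assumed groups: [[0], [1, 2], [3, 4]]). The result is
  // the read vector shape, i.e. the source shape tensor<32x8x16xf32>.
  std::vector<std::vector<int64_t>> groups = {{0}, {1, 2}, {3, 4}};
  std::vector<int64_t> read;
  for (const auto &g : groups) {
    int64_t size = 1;
    for (int64_t d : g)
      size *= unpermuted[d];
    read.push_back(size);
  }
  assert((read == std::vector<int64_t>{32, 8, 16}));
  return 0;
}
```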

/// Vectorize `linalg.unpack` as:
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
///
/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
/// for the xfer_read operation). This is sufficient to infer the other vector
/// sizes required here.
/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
/// sizes for the xfer_read operation). This is sufficient to infer the other
/// vector sizes required here.
///
/// If the vector sizes are not provided:
/// * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1968,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
// In the absence of input-vector-sizes, use the _static_ input tensor shape.
if (inputVectorSizes.empty()) {
if (ShapedType::isDynamicShape(sourceShape))
return failure();
return rewriter.notifyMatchFailure(unpackOp,
"unable to infer vector sizes");

readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
useInBoundsInsteadOfMasking = true;
@@ -2443,6 +2452,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
Attribute cstAttr;
// TODO: Relax this condition
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
LDBG() << "pad value is not constant: " << packOp;
return failure();
15 changes: 5 additions & 10 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -171,29 +171,24 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
namespace mlir {
namespace linalg {

SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
PackingMetadata &metadata) {

PackingMetadata pMetadata;
int64_t packedRank = packOp.getDestType().getRank();
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
SmallVector<int64_t> packInvDestPerm =
computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
return packInvDestPerm;
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
PackingMetadata metadata;
return getUnPackInverseSrcPerm(unpackOp, metadata);
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
PackingMetadata &metadata) {
int64_t unpackRank = unpackOp.getSourceType().getRank();
int64_t packedRank = unpackOp.getSourceType().getRank();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
SmallVector<int64_t> unpackInvSrcPerm =
computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
return unpackInvSrcPerm;
}

@@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} {

///----------------------------------------------------------------------------------------
/// Tests for linalg.pack
///
/// TODO: Add similar tests for linalg.unpack
///----------------------------------------------------------------------------------------

// Note, see a similar test in:
Expand Down