22 changes: 22 additions & 0 deletions include/gc/Transforms/Utils/ValueUtils.h
@@ -53,6 +53,28 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref);
// Return true if the memref has shared memory space.
bool hasSharedMemSpace(mlir::Value memref);

// Walk all parent 'memref.subview' ops of the given `memref` and return the
// folded offsets of all subviews together with the root memref.
void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
Value memref, SmallVector<Value> &resultOffsets,
Value &resultRootMemref);

// Return the strides of the given memref.
SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
Location loc, Value memref);

// Squeeze the leading unit dimensions of the given memref so that at most
// 'maxDims' dimensions remain.
FailureOr<Value> squeezeMemref(PatternRewriter &rewriter, Location loc,
Value memref, size_t maxDims = 2);

// Squeeze the leading unit dimensions of the memref operands of the given
// 'linalgOp' down to at most 'maxDims' dimensions.
LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
linalg::LinalgOp linalgOp, size_t maxDims = 2);

// Return true if a memref with the given shape can be squeezed down to
// 'maxDims' dimensions. Only leading unit dimensions are considered
// squeezable.
bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims = 2);

} // namespace utils
} // namespace mlir

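The declarations above add a small squeeze/offset-folding API on top of the existing memref utilities. A minimal sketch of how they are intended to compose inside a rewrite pattern is shown below; the helper name `lowerWith2DOperand` and the surrounding control flow are hypothetical and not part of this patch.

// Sketch only: squeeze a memref operand down to 2-D before lowering it.
LogicalResult lowerWith2DOperand(PatternRewriter &rewriter, Location loc,
                                 Value memref) {
  auto type = cast<MemRefType>(memref.getType());

  // Cheap static check: only leading unit dimensions can be squeezed away.
  if (!mlir::utils::canSqueezeDims(type.getShape(), /*maxDims=*/2))
    return failure();

  // Materialize a rank-reducing subview,
  // e.g. memref<1x1x8x16xf16> -> memref<8x16xf16>.
  FailureOr<Value> squeezed =
      mlir::utils::squeezeMemref(rewriter, loc, memref, /*maxDims=*/2);
  if (failed(squeezed))
    return failure();

  // ... create the lowered op using *squeezed as its 2-D operand ...
  return success();
}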
92 changes: 39 additions & 53 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -62,28 +62,6 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
return res.getResult();
}

// Extracts the offsets from a subview operation as values.
// The difference from mlir::getMixedOffsets is that this function
// returns the offsets as mlir::Value that can already be used as an argument
// for other mlir::Operations.
static SmallVector<Value> extractOffsetsAsValues(PatternRewriter &rewriter,
Location loc,
memref::SubViewOp subview) {
SmallVector<Value> offsetValues;
auto staticOffsets = subview.getStaticOffsets();
auto dynamicOffsets = subview.getOffsets();
size_t dynIdx = 0;
for (size_t i = 0; i < staticOffsets.size(); i++) {
if (staticOffsets[i] == ShapedType::kDynamic)
offsetValues.push_back(dynamicOffsets[dynIdx++]);
else
offsetValues.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, staticOffsets[i]));
}

return offsetValues;
}

// Max number of elements to load/store from SLM
constexpr int64_t maxSLMTileSize = 32;

Expand Down Expand Up @@ -214,7 +192,8 @@ static LogicalResult isValidMemrefOperand(linalg::LinalgOp linalgOp,
linalgOp, "Expect memref operand for XeGPU lowering");
}

if (type.getShape().size() > maxDims) {
if (type.getShape().size() > maxDims &&
!utils::canSqueezeDims(type.getShape(), maxDims)) {
return rewriter.notifyMatchFailure(
linalgOp, "Too high dimensionality for XeGPU operations");
}
Expand Down Expand Up @@ -856,43 +835,33 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
auto srcType = cast<MemRefType>(src.getType());
assert(srcType.getRank() == 2 && "Expected a 2D memref");

SmallVector<int64_t> memrefStrides;
Value blockOffset;

SmallVector<Value> offsets;
Value rootMemref;
// 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the
// GPU kernel. We have to merge the subview offsets into the descriptor
// offset.
if (auto subView = dyn_cast<memref::SubViewOp>(src.getDefiningOp())) {
auto offsets = extractOffsetsAsValues(rewriter, loc, subView);
assert(offsets.size() == 2 && "Expected 2D subview offsets");

auto xIntOffs = offsets[0];
auto yIntOffs = offsets[1];

// compute 'blockOffset' (beginning of the subview block in the original
// flat memref)
auto rowStride =
cast<MemRefType>(subView.getOperand(0).getType()).getShape()[1];
auto rowStrideValue =
rewriter.create<arith::ConstantIndexOp>(loc, rowStride);

auto rowBlockOffset =
rewriter.create<arith::MulIOp>(loc, xIntOffs, rowStrideValue)
.getResult();
blockOffset = rewriter.create<arith::AddIOp>(loc, rowBlockOffset, yIntOffs)
.getResult();
utils::computeSubviewOffsets(rewriter, loc, src, offsets, rootMemref);
auto rootStridesFold = utils::getMemrefStrides(rewriter, loc, rootMemref);
auto rootStrides =
getValueOrCreateConstantIndexOp(rewriter, loc, rootStridesFold);

memrefStrides = {rowStride, 1};
src = subView.getOperand(0);
} else {
// If the source is not a subview, then the blockOffset is 0
blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
memrefStrides = {srcType.getShape()[1], 1};
assert(rootStrides.size() == offsets.size() &&
"Expected same number of strides and offsets");

// blockOffset = sum(rootStrides[i] * offsets[i])
Value blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
for (size_t i = 0; i < rootStrides.size(); i++) {
auto mul = rewriter.create<arith::MulIOp>(loc, rootStrides[i], offsets[i]);
blockOffset = rewriter.create<arith::AddIOp>(loc, blockOffset, mul);
}

// Scatter descriptors only work with 1D memrefs
src = utils::flattenMemref(rewriter, loc, src);
auto memrefStridesFold = utils::getMemrefStrides(rewriter, loc, src);
auto [memrefStrides, memrefStridesDynamic] =
decomposeMixedValues(memrefStridesFold);
assert(memrefStridesDynamic.size() == 0 &&
"Expected all values to be resolved");

src = utils::flattenMemref(rewriter, loc, rootMemref);
return createScatterDescriptorTiles(
rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape,
/*tileSize2D=*/descTile, /*memrefStrides=*/memrefStrides,
Expand Down Expand Up @@ -1839,6 +1808,11 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
if (failed(isOutputValid))
return isOutputValid;

if (failed(mlir::utils::maybeSqueezeDims(rewriter, gemmLikeOp))) {
return rewriter.notifyMatchFailure(
gemmLikeOp, "Failed to squeeze dimensions of GEMM-like operation");
}

// Ensure that reduction dimension tiling also works for smaller
// workloads.
auto aType = cast<ShapedType>(gemmLikeOp.getDpsInputs()[0].getType());
Expand Down Expand Up @@ -1894,6 +1868,12 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
if (failed(isOutputValid))
return isOutputValid;

if (failed(utils::maybeSqueezeDims(rewriter, eltwiseOp))) {
return rewriter.notifyMatchFailure(
eltwiseOp,
"Could not squeeze dimensions of the elementwise operation");
}

return createEltwiseKernel(eltwiseOp, rewriter);
}

Expand Down Expand Up @@ -1988,6 +1968,12 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
if (failed(isOutputValid))
return isOutputValid;

if (failed(utils::maybeSqueezeDims(rewriter, linalgOp))) {
return rewriter.notifyMatchFailure(
linalgOp,
"Could not squeeze dimensions of the memory fill operation");
}

return createMemoryFillKernel(linalgOp, rewriter);
}

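The rewritten createSLMDescTiles hunk above replaces the old 2-D-only subview handling: offsets are folded through the whole subview chain and the descriptor offset becomes a stride-times-offset sum over every dimension. A standalone sketch of that arithmetic, using the shapes from the new test file, follows; it is illustrative only and not part of the patch.

#include <array>
#include <cassert>
#include <cstdint>

// Plain C++ analogue of the MulIOp/AddIOp chain that computes 'blockOffset'.
int64_t foldBlockOffset(std::array<int64_t, 4> rootStrides,
                        std::array<int64_t, 4> offsets) {
  int64_t blockOffset = 0;
  for (size_t i = 0; i < rootStrides.size(); ++i)
    blockOffset += rootStrides[i] * offsets[i];
  return blockOffset;
}

int main() {
  // The SLM buffer memref<8x1x8x16xf16, 3> has root strides {128, 128, 16, 1},
  // so a subview taken at offsets {5, 0, 0, 0} starts 5 * 128 = 640 elements
  // into the flattened buffer.
  assert(foldBlockOffset({128, 128, 16, 1}, {5, 0, 0, 0}) == 640);
  return 0;
}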
131 changes: 129 additions & 2 deletions lib/gc/Transforms/Utils/ValueUtils.cpp
@@ -6,10 +6,14 @@
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -155,9 +159,10 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) {
auto srcType = cast<MemRefType>(srcMemref.getType());

assert(srcType && "Expected a memref type");
assert(srcType.getRank() == 2 && "Expected a 2D memref");

int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1];
auto shapeNd = srcType.getShape();
int64_t flatSize =
std::accumulate(shapeNd.begin(), shapeNd.end(), 1, std::multiplies<>());

Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value size = rewriter.create<arith::ConstantIndexOp>(loc, flatSize);
Expand Down Expand Up @@ -193,5 +198,127 @@ bool hasSharedMemSpace(mlir::Value memref) {
return false;
}

void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
Value memref, SmallVector<Value> &resultOffsets,
Value &resultRootMemref) {
auto fillVal = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto origShape = dyn_cast<MemRefType>(memref.getType()).getShape();

resultOffsets.clear();
resultOffsets.append(origShape.size(), fillVal);
resultRootMemref = memref;

while (auto subViewOp = resultRootMemref.getDefiningOp<memref::SubViewOp>()) {
auto currentOffsets = getAsOpFoldResult(resultOffsets);
resultOffsets.clear();

affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, resultRootMemref.getLoc(), subViewOp.getMixedOffsets(),
subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), currentOffsets,
resultOffsets);
resultRootMemref = subViewOp.getOperand(0);
}
}

SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
Location loc, Value memref) {
auto type = dyn_cast<MemRefType>(memref.getType());

auto stridedLayout = dyn_cast<StridedLayoutAttr>(type.getLayout());
if (stridedLayout) {
auto strides = stridedLayout.getStrides();
return getMixedValues(strides, {}, rewriter);
}

auto sizes = getMixedValues(type.getShape(), {}, rewriter);
auto strides = memref::computeStridesIRBlock(loc, rewriter, sizes);
return strides;
}

FailureOr<Value> squeezeMemref(PatternRewriter &rewriter, Location loc,
Value memref, size_t maxDims) {
auto type = dyn_cast<MemRefType>(memref.getType());
auto shape = type.getShape();

if (shape.size() <= maxDims)
return memref;

for (size_t i = 0; i < shape.size() - maxDims; i++)
if (shape[i] != 1)
return failure();

auto offsets =
getMixedValues(SmallVector<int64_t>(shape.size(), 0), {}, rewriter);
auto sizes = getMixedValues(shape, {}, rewriter);
auto staticStrides = utils::getStaticStrides(memref).value();
auto strides =
getMixedValues(SmallVector<int64_t>(shape.size(), 1), {}, rewriter);

SmallVector<int64_t> newShape(shape.begin() + shape.size() - maxDims,
shape.end());
SmallVector<int64_t> newStrides(
staticStrides.begin() + shape.size() - maxDims, staticStrides.end());

int64_t newOffset = 0;
if (auto memrefLayout = dyn_cast<StridedLayoutAttr>(type.getLayout()))
newOffset = memrefLayout.getOffset();

auto newLayout = StridedLayoutAttr::get(
rewriter.getContext(), /*offset=*/newOffset, /*strides=*/newStrides);
MemRefType newMemRefType = MemRefType::get(newShape, type.getElementType(),
newLayout, type.getMemorySpace());

auto squeezedSubview =
rewriter
.create<memref::SubViewOp>(loc, newMemRefType, memref, offsets, sizes,
strides)
.getResult();
return squeezedSubview;
}

LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
linalg::LinalgOp linalgOp, size_t maxDims) {
SmallVector<std::pair<size_t, Value>> newOperands;
auto operands = linalgOp->getOperands();
auto loc = linalgOp.getLoc();

for (size_t i = 0; i < operands.size(); i++) {
auto operand = operands[i];
auto type = dyn_cast<MemRefType>(operand.getType());
if (!type) {
// Skip non-memref operands
continue;
}

if (type.getShape().size() <= maxDims)
continue;

auto res = squeezeMemref(rewriter, loc, operand, maxDims);
if (failed(res)) {
return rewriter.notifyMatchFailure(
linalgOp, "Can't squeeze memref to the desired number of dimensions");
}

auto flatSubview = res.value();
newOperands.emplace_back(i, flatSubview);
}

for (auto [i, operand] : newOperands)
linalgOp->setOperand(i, operand);

return success();
}

bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims) {
if (shape.size() <= maxDims)
return true;

for (size_t i = 0; i < shape.size() - maxDims; i++)
if (shape[i] != 1)
return false;

return true;
}

} // namespace utils
} // namespace mlir
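For the squeezability check itself, the rule is that only leading unit dimensions may be dropped. A small standalone sanity check is sketched below; it is illustrative only, and the include path is assumed from the repository layout rather than taken from this patch.

#include <cassert>
#include <cstdint>

#include "gc/Transforms/Utils/ValueUtils.h"
#include "llvm/ADT/ArrayRef.h"

int main() {
  // 1x1x8x16 can drop its two leading unit dims and become 8x16 ...
  int64_t squeezable[] = {1, 1, 8, 16};
  // ... while 2x4x8x16 cannot be reduced to 2-D.
  int64_t notSqueezable[] = {2, 4, 8, 16};

  assert(mlir::utils::canSqueezeDims(llvm::ArrayRef<int64_t>(squeezable),
                                     /*maxDims=*/2));
  assert(!mlir::utils::canSqueezeDims(llvm::ArrayRef<int64_t>(notSqueezable),
                                      /*maxDims=*/2));
  return 0;
}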
61 changes: 61 additions & 0 deletions test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir
@@ -0,0 +1,61 @@
// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file -cse | FileCheck %s

!input_type = memref<2x4x8x16xf16>
!chunk_type = memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>>
!slm_chunk = memref<1x1x8x16xf16, strided<[128, 128, 16, 1], offset: ?>, 3>

// The map that computes an offset for SLM
// CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1)>
#map = affine_map<(xi, yi) -> (xi * 4 + yi)>

func.func @entry(%arg0: !input_type, %arg1: !input_type, %arg2: !input_type) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index

gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c2, %arg13 = %c4, %arg14 = %c1) {
// CHECK: %[[ARG0_SB:.+]] = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
%arg0_sb = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
// CHECK: %[[ARG1_SB:.+]] = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
%arg1_sb = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
// CHECK: %[[ARG2_SB:.+]] = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
%arg2_sb = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type

// CHECK: %[[SLM_BUFF:.+]] = memref.alloc() : memref<8x1x8x16xf16, 3>
%slm_root = memref.alloc() : memref<8x1x8x16xf16, 3>

%slm_idx = affine.apply #map(%arg6, %arg7)
%slm = memref.subview %slm_root[%slm_idx, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<8x1x8x16xf16, 3> to !slm_chunk

// Squeezing the arguments of 'linalg.mul'
// CHECK: %[[ARG0_SQUEEZ:.+]] = memref.subview %[[ARG0_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
// CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>

// CHECK: %[[ARG1_SQUEEZ:.+]] = memref.subview %[[ARG1_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
// CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>

// Verify that tensor descriptors are created from the squeezed memrefs
// CHECK: xegpu.create_nd_tdesc %[[ARG0_SQUEEZ]]
// CHECK: xegpu.create_nd_tdesc %[[ARG1_SQUEEZ]]

// Verify that the SLM output of linalg.mul is squeezed correctly
// CHECK-NOT: .* = memref.subview %[[SLM_BUFF]] .*
// CHECK: %[[SLM_THREAD_OFF:.+]] = affine.apply #map(%arg6, %arg7)
// CHECK: %[[SLM_OFF:.+]] = arith.muli %[[SLM_THREAD_OFF]], %c128 : index
// CHECK: %[[FLAT_SLM:.+]] = memref.reinterpret_cast %[[SLM_BUFF]] to offset: [%c0], sizes: [%c1024], strides: [%c1] : memref<8x1x8x16xf16, 3> to memref<1024xf16, 3>
// CHECK: xegpu.create_tdesc %[[FLAT_SLM]]
linalg.mul ins(%arg0_sb, %arg1_sb : !chunk_type, !chunk_type) outs(%slm : !slm_chunk)

// Squeezing the result buffer of 'linalg.add'
// CHECK: %[[ARG2_SQUEEZ:.+]] = memref.subview %[[ARG2_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
// CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>

// Verify that tensor descriptors are created from the squeezed memrefs
// CHECK: xegpu.create_nd_tdesc %[[ARG2_SQUEEZ]]
linalg.add ins(%arg0_sb, %slm : !chunk_type, !slm_chunk) outs(%arg2_sb : !chunk_type)

gpu.terminator
} {SCFToGPU_visited}

return
}