diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h index 1481859e94a92..0c059967bb898 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h @@ -30,9 +30,11 @@ class SliceAttr; } // namespace xegpu } // namespace mlir +// clang-format off +#include #include #include -#include +// clang-format on #define GET_ATTRDEF_CLASSES #include diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 19a52317956d2..6f03ec9dbed69 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -223,17 +223,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { InterfaceMethod<"Derive a new layout by dropping InstData", "xegpu::DistributeLayoutAttr", "dropInstData">, - InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional - indices based on the effective subgroup layout.}], + InterfaceMethod<[{Delinearizes a linear ID into its multidimensional + indices based on the effective layout level.}], "FailureOr>", - "delinearizeSubgroupId", + "delinearizeId", (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>, - InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks - assigned to a subgroup identified by linearId. The shape parameter - represents the workgroup-level problem size. Each subgroup may access + InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units + assigned to a level identified by linearId. The shape parameter + represents the higher-level problem size. Each level may access multiple blocks according to round-robin distribution rules.}], "FailureOr>>", - "getOffsets", + "computeDistributedCoords", (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)>, InterfaceMethod { return {}; } - /// Delinearizes a linear subgroup ID into its multidimensional indices - /// based on the effective subgroup layout. + /// Delinearizes a linear ID into its multidimensional indices + /// based on the effective level of the layout. FailureOr> - delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId); + delinearizeId(OpBuilder &builder, Location loc, Value linearId); - /// Generates instructions to compute multidimensional offsets for blocks - /// assigned to a subgroup identified by linearId. The shape parameter - /// represents the workgroup-level problem size. Each subgroup may access + /// Generates instructions to compute multidimensional coordinates for dist units + /// assigned to a level identified by linearId. The shape parameter + /// represents the higher-level problem size. Each `level` may access /// multiple blocks according to round-robin distribution rules. FailureOr>> - getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); /// Check if this is slice of some other layout. bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; } @@ -643,14 +643,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Delinearizes a linear subgroup ID into its multidimensional indices /// based on the effective subgroup layout. 
FailureOr> - delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId); + delinearizeId(OpBuilder &builder, Location loc, Value linearId); - /// Generates instructions to compute multidimensional offsets for blocks + /// Generates instructions to compute multidimensional coordinates for blocks /// assigned to a subgroup identified by linearId. The shape parameter /// represents the workgroup-level problem size. Each subgroup may access /// multiple blocks according to round-robin distribution rules. + FailureOr>> - getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); /// Check if this is slice of some other layout. bool isSliceOf(const xegpu::DistributeLayoutAttr &other); diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index 564d9c4d5422b..5f803233041ab 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> { The pass distributes subgroup level (SIMD) XeGPU ops to work items. }]; let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect", - "vector::VectorDialect"]; + "vector::VectorDialect", "index::IndexDialect"]; } def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> { diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index fcbf66dbe9e45..53b8c4f0bbd59 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { VectorType valOrResVecTy = dyn_cast(data.getType()); if (!valOrResVecTy) valOrResVecTy = VectorType::get(1, data.getType()); + if (valOrResVecTy.getShape().size() != 1) + return rewriter.notifyMatchFailure(op, "Expected 1D data vector."); int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth(); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 24e909548fe0b..5a9b15e73002d 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -38,55 +38,61 @@ void XeGPUDialect::initialize() { >(); } -/// Generates instructions to compute offsets for a subgroup identified by -/// its multidimensional indices (sgId), using the specified subgroup layout -/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data -/// dimensions (sizePerWg). +// A `srcShape` consists of N distribution units, each being `subShapesLayout` x +// `subShape`. A `delinearizedId` is used to identify a particular `subShape` +// within each distribution unit. +// Example: +// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a +// distribution unit of shape 64x64, we have 2x4 such distribution units. +// `delinearizedId` is used to identify a 16x32 of a subgroup in each +// distribution unit. 
static SmallVector> -genOffsetsComputingInsts(OpBuilder &builder, Location loc, - SmallVector sgId, ArrayRef sgLayout, - ArrayRef sizePerSg, - ArrayRef sizePerWg) { - - SmallVector> offsets; +genCoordinates(OpBuilder &builder, Location loc, + SmallVector delinearizedId, + ArrayRef subShapesLayout, ArrayRef subShape, + ArrayRef srcShape) { + SmallVector> coordinates; + + // A distribution unit must be less than or equal to `srcShape` + SmallVector distUnitShape = llvm::map_to_vector( + llvm::zip_equal(srcShape, + computeElementwiseMul(subShapesLayout, subShape)), + [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] - SmallVector localOffsets = llvm::map_to_vector( - llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { + // Get the offset of `subShape` within a distribution unit. + SmallVector distUnitLocalOffset = llvm::map_to_vector( + llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value { return builder.createOrFold( loc, std::get<0>(t), builder.createOrFold(loc, std::get<1>(t))); }); - // distUnit[i] is the minimum value between sizePerWg[i] and - // sgLayout[i] * sizePerSg[i] - SmallVector distUnit = llvm::map_to_vector( - llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), - [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - + // For each dist unit for (SmallVector unitOffs : - StaticTileOffsetRange(sizePerWg, distUnit)) { + StaticTileOffsetRange(srcShape, distUnitShape)) { + // Get dist unit offset within `srcShape`. SmallVector base = llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { return arith::ConstantIndexOp::create(builder, loc, d); }); - - SmallVector adds = llvm::map_to_vector( - llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { - return builder.createOrFold(loc, std::get<0>(t), - std::get<1>(t)); - }); - + // Calculate `subShape` offset within `srcShape`. + SmallVector adds = + llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset), + [&](const auto &t) -> Value { + return builder.createOrFold( + loc, std::get<0>(t), std::get<1>(t)); + }); + // Do not go beyond `srcShape` bounds. 
SmallVector mods = llvm::map_to_vector( - llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { + llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value { return builder.createOrFold( loc, std::get<0>(t), arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); }); - offsets.push_back(mods); + coordinates.push_back(mods); } - return offsets; + return coordinates; } // Checks if the given shape can be evenly distributed based on the layout @@ -268,12 +274,7 @@ LayoutAttr::verify(llvm::function_ref emitError, } FailureOr> -LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { - // delinearizeSubgroupId is only available for - // workgroup-level layout attribute - if (!isForWorkgroup()) - return failure(); +LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { // TODO: handle order attribute auto hasDefaultOrder = [&]() { @@ -283,41 +284,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, }; if (!hasDefaultOrder()) return mlir::emitError(loc, "order attribute is currently not supported."); - - auto dims = - llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value { - return builder.createOrFold(loc, d); - }); + SmallVector layout; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + } else { + return failure(); + } + auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value { + return builder.createOrFold(loc, d); + }); return affine::delinearizeIndex(builder, loc, linearId, dims); } -/// Implements DistributeLayoutAttr::getOffsets to generate +/// Implements DistributeLayoutAttr::computeDistributedCoords to generate /// instructions for computing multi-dimensional offsets when distributed by /// LayoutAttr. 
FailureOr>> -LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef shape) { - if (!isForWorkgroup()) +LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef shape) { + SmallVector layout; + SmallVector subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { return failure(); - - SmallVector sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + } + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); - SmallVector sgIds = *maybeIds; + SmallVector ids = *maybeIds; - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + return genCoordinates(builder, loc, ids, layout, subShape, shape); } //===----------------------------------------------------------------------===// @@ -371,34 +383,43 @@ SliceAttr SliceAttr::flatten() const { } FailureOr> -SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { +SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - return parent.delinearizeSubgroupId(builder, loc, linearId); + return parent.delinearizeId(builder, loc, linearId); } -/// Implements DistributeLayoutAttr::getOffsets to generate -/// instructions for computing multi-dimensional offsets when distributed by -/// SliceAttr. +/// Implements DistributeLayoutAttr::computeDistributedCoords to generate +/// instructions for computing multi-dimensional offsets when distributed by +/// SliceAttr.
FailureOr>> -SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef shape) { +SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef shape) { assert(getRank() == static_cast(shape.size()) && "invalid shape."); if (!isForWorkgroup()) return failure(); - SmallVector sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + SmallVector layout; + SmallVector subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { + return failure(); + } + + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); @@ -408,8 +429,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, SmallVector sgIds = XeGPUDialect::slice(ArrayRef(*maybeIds), dims); - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + return genCoordinates(builder, loc, sgIds, layout, subShape, shape); } bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index abd12e2e69ac0..7b6c4b6c2c813 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, - UnitAttr subgroup_block_io, + UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref emitError) { if (!dataTy) { if (subgroup_block_io) return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; + "are only allowed when result is a VectorType."; else return success(); } @@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, ArrayRef dataShape = dataTy.getShape(); ArrayRef mdescShape = mdescTy.getShape(); + SmallVector blockShape = mdescTy.getBlockShape(); + ArrayAttr strideAttr = mdescTy.getStrideAttr(); + SmallVector strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast(attr).getInt()); + } + if (subgroup_block_io && layout) { + auto laneData = layout.getEffectiveLaneDataAsInt(); + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + if (!laneData.empty()) { + bool isLaneDataContiguous = + std::all_of(laneData.begin(), std::prev(laneData.end()), + [](int x) { return x == 1; }); + if (!isLaneDataContiguous) + return emitError() << "With subgroup_block_io, accessed data must be " + "contiguous and coalesced."; + for (size_t i = 0; i < laneData.size(); ++i) { + if (laneLayout[i] != blockShape[i]) + return emitError() << "With subgroup_block_io, the block shape must " + "match the lane layout."; + if (laneLayout[i] != 1 && strides[i] != 1) + return emitError() << "With subgroup_block_io, the distributed " + "dimensions must be contiguous."; + } + } + } if (dataShape.size() == 2) { - if (subgroup_block_io) - return emitError() << "subgroup_block_io " - "are only 
allowed when result is a 1D VectorType."; if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitError() << "data shape must not exceed mem_desc shape."; } else { - SmallVector blockShape = mdescTy.getBlockShape(); // if the subgroup_block_io attribute is set, mdescTy must have block // attribute if (subgroup_block_io && !blockShape.size()) @@ -1105,7 +1127,7 @@ LogicalResult LoadMatrixOp::verify() { MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1129,7 +1151,7 @@ LogicalResult StoreMatrixOp::verify() { UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } namespace mlir { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index d09dc196c0bf7..29ccc0a48786b 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Utils/DistributionUtils.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" @@ -906,6 +907,186 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { } }; +static SmallVector computeDistributedCoordinatesForMatrixOp( + PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout, + Value laneId, ArrayRef payloadShape, ValueRange origOffsets) { + SmallVector newCoods; + auto maybeCoords = + layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape); + if (failed(maybeCoords)) + return {}; + assert(maybeCoords.value().size() == 1 && + "Expected one set of distributed offsets"); + SmallVector ofrVec = xegpu::addWithRightAligned( + rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]), + getAsOpFoldResult(origOffsets)); + newCoods = llvm::to_vector(llvm::map_range( + ofrVec, [&](OpFoldResult ofr) -> Value { return cast(ofr); })); + return newCoods; +} + +/// Pattern for distributing xegpu::LoadMatrixOp. 
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + gpu::YieldOp yield = warpOp.getTerminator(); + Operation *lastNode = yield->getPrevNode(); + auto matrixOp = dyn_cast_or_null(lastNode); + if (!matrixOp) + return failure(); + + OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) { + return isa(op) && matrixOp == op; + }); + if (!producedByLastLoad) + return rewriter.notifyMatchFailure( + warpOp, "The last op is not xegpu::LoadMatrixOp"); + const int operandIdx = producedByLastLoad->getOperandNumber(); + + VectorType sgPayloadTy = + dyn_cast(matrixOp.getResult().getType()); + VectorType warpResultTy = + cast(warpOp.getResult(operandIdx).getType()); + if (!sgPayloadTy) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix op payload must be a vector type"); + + auto loc = matrixOp.getLoc(); + auto offsets = matrixOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(matrixOp, + "the load op must have offsets"); + SmallVector offsetsAsValues = + vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + + auto layout = matrixOp.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix operation lacks layout attribute"); + + FailureOr distPayloadByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); + if (failed(distPayloadByWarpOpOrFailure)) + return rewriter.notifyMatchFailure( + matrixOp, "Failed to distribute matrix op payload based on layout."); + + SmallVector operands = {matrixOp.getMemDesc()}; + const unsigned offsetsStartIdx = operands.size(); + operands.append(offsetsAsValues); + + SmallVector operandTypes = llvm::to_vector( + llvm::map_range(operands, [](Value v) { return v.getType(); })); + + SmallVector newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + SmallVector newOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + SmallVector newConstOffsets{matrixOp.getConstOffsets()}; + std::fill(newConstOffsets.begin(), newConstOffsets.end(), + ShapedType::kDynamic); + DenseI64ArrayAttr newConstOffsetsAttr = + rewriter.getDenseI64ArrayAttr(newConstOffsets); + ValueRange currentOffsets = + ValueRange(newOperands).drop_front(offsetsStartIdx); + + SmallVector newCoords = currentOffsets; + rewriter.setInsertionPointAfter(newWarpOp); + + if (!matrixOp.getSubgroupBlockIoAttr()) { + newCoords = computeDistributedCoordinatesForMatrixOp( + rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), + currentOffsets); + } + xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create( + rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure, + newOperands[0], ValueRange(newCoords), newConstOffsetsAttr, + matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); + // Resolve the output type and replace all uses. + rewriter.replaceAllUsesWith( + newWarpOp.getResult(operandIdx), + resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter)); + return success(); + } +}; + +/// Pattern for distributing xegpu::StoreMatrixOp. 
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + gpu::YieldOp yield = warpOp.getTerminator(); + Operation *lastNode = yield->getPrevNode(); + auto matrixOp = dyn_cast_or_null(lastNode); + if (!matrixOp) + return failure(); + + VectorType sgPayloadTy = dyn_cast(matrixOp.getData().getType()); + if (!sgPayloadTy) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix op payload must be a vector type"); + + auto loc = matrixOp.getLoc(); + auto offsets = matrixOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(matrixOp, + "the store op must have offsets"); + SmallVector offsetsAsValues = + vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + + auto layout = matrixOp.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix operation lacks layout attribute"); + + FailureOr distPayloadByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); + if (failed(distPayloadByWarpOpOrFailure)) + return rewriter.notifyMatchFailure( + matrixOp, "Failed to distribute matrix op payload based on layout."); + + SmallVector operands = {matrixOp.getData(), matrixOp.getMemDesc()}; + const unsigned offsetsStartIdx = operands.size(); + operands.append(offsetsAsValues); + + SmallVector operandTypes = llvm::to_vector( + llvm::map_range(operands, [](Value v) { return v.getType(); })); + operandTypes[0] = *distPayloadByWarpOpOrFailure; + + SmallVector newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + SmallVector newOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + SmallVector newConstOffsets{matrixOp.getConstOffsets()}; + std::fill(newConstOffsets.begin(), newConstOffsets.end(), + ShapedType::kDynamic); + DenseI64ArrayAttr newConstOffsetsAttr = + rewriter.getDenseI64ArrayAttr(newConstOffsets); + ValueRange currentOffsets = + ValueRange(newOperands).drop_front(offsetsStartIdx); + + SmallVector newCoords = currentOffsets; + rewriter.setInsertionPointAfter(newWarpOp); + + if (!matrixOp.getSubgroupBlockIoAttr()) { + newCoords = computeDistributedCoordinatesForMatrixOp( + rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), + currentOffsets); + } + + xegpu::StoreMatrixOp::create( + rewriter, loc, TypeRange{}, newOperands[0], newOperands[1], + ValueRange(newCoords), newConstOffsetsAttr, + matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); + rewriter.eraseOp(matrixOp); + return success(); + } +}; + /// Distribute a scattered load op. The logic and requirements are the same as /// for the scattered store distribution. The warpOp's payload vector is /// expected to be distributed by the load's result consumer. 
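To make the round-robin rule used by `genCoordinates` (and exposed as `computeDistributedCoords`) easier to follow, here is a small standalone C++ sketch of the arithmetic that the generated index ops perform. It is an illustrative model only, not the builder code in the patch: `modelDistributedCoords` and the driver in `main` are invented for the example, which reuses the 128x256 / 16x32 / 4x2 case from the comment on `genCoordinates`.

```cpp
// Standalone model of the arithmetic that genCoordinates emits as index ops.
// Illustrative only; the real code builds MLIR ops instead of integers.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using Coord = std::vector<int64_t>;

// For one delinearized id, return the coordinates of every subShape tile it
// owns inside srcShape, following the round-robin distribution rule.
std::vector<Coord> modelDistributedCoords(const Coord &id, const Coord &layout,
                                          const Coord &subShape,
                                          const Coord &srcShape) {
  size_t rank = srcShape.size();
  Coord distUnit(rank), local(rank);
  for (size_t i = 0; i < rank; ++i) {
    // A distribution unit is layout * subShape, clamped to srcShape.
    distUnit[i] = std::min(srcShape[i], layout[i] * subShape[i]);
    // Offset of this id's tile inside a single distribution unit.
    local[i] = id[i] * subShape[i];
  }

  std::vector<Coord> coords;
  Coord base(rank, 0); // offset of the current distribution unit in srcShape
  while (true) {
    Coord c(rank);
    for (size_t i = 0; i < rank; ++i)
      c[i] = (base[i] + local[i]) % srcShape[i]; // never leave srcShape
    coords.push_back(c);
    // Advance to the next distribution unit, last dimension fastest.
    bool advanced = false;
    for (size_t i = rank; i-- > 0;) {
      base[i] += distUnit[i];
      if (base[i] < srcShape[i]) {
        advanced = true;
        break;
      }
      base[i] = 0;
    }
    if (!advanced)
      break; // every dimension wrapped around: all units visited
  }
  return coords;
}

int main() {
  // Example from the genCoordinates comment: 128x256 workgroup data, 16x32
  // per subgroup in a 4x2 subgroup layout -> 64x64 distribution units, 2x4 of
  // them, so the subgroup with delinearized id (1, 1) owns 8 tiles.
  for (const Coord &c :
       modelDistributedCoords({1, 1}, {4, 2}, {16, 32}, {128, 256}))
    std::cout << c[0] << ", " << c[1] << "\n";
  return 0;
}
```

The same arithmetic is reused at lane level by the new LoadMatrixDistribution/StoreMatrixDistribution patterns: the lane layout and lane data play the roles of `layout` and `subShape`, and the lane id is the delinearized id.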
@@ -1437,7 +1618,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( LoadNdDistribution, DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution, VectorMultiReductionDistribution, LoadDistribution, StoreDistribution, VectorTransposeDistribution, - VectorBitcastDistribution, + VectorBitcastDistribution, LoadMatrixDistribution, + StoreMatrixDistribution, MemrefExtractAlignedPointerAsIndexDistribution>( patterns.getContext(), /*pattern benefit=*/regularPatternBenefit); @@ -1462,6 +1644,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // Layouts are needed for vector type only. if (!isa(operand.get().getType())) continue; + if (isa(op)) + continue; auto layout = xegpu::getDistributeLayoutAttr(operand.get()); if (!layout) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9fc5ad9af5c7b..79eea55c8b78a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -114,7 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory // descriptors to be accessed, based on the layout information. ArrayRef wgShape = op.getDataShape(); - auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto maybeDescOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(maybeDescOffsets)) return failure(); @@ -830,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern { // Get subgroup id Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - - auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); @@ -1052,7 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern { Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index ebbe3ce0ec0d0..92f353717ac59 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref) { %offsets = arith.constant dense<[0]> : vector<1xindex> %mask = arith.constant dense<1>: vector<1xi1> // expected-error@+1 {{Mask should match value except the chunk size dim}} - xegpu.store %val, %src[%offsets], %mask + xegpu.store %val, %src[%offsets], %mask : vector<4xf16>, memref, vector<1xindex>, vector<1xi1> return } @@ -870,14 +870,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { return } -// ----- -func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> - return -} - - // ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} @@ -900,16 +892,25 @@ func.func 
@store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve } // ----- -func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> +func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, %arg1: vector<2x16xf32>) { + // expected-error@+1 {{With subgroup_block_io, accessed data must be contiguous and coalesced}} + xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout} : + vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout> return } // ----- -func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> +func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, %arg1: vector<16x2xf32>) { + // expected-error@+1 {{With subgroup_block_io, the distributed dimensions must be contiguous}} + xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout} : + vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout> return } +// ----- +func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, %arg1: vector<16x2xf32>) { + // expected-error@+1 {{With subgroup_block_io, the block shape must match the lane layout}} + xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout} : + vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout> + return +} diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 27a3dc373c739..8946d14e80b72 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -265,3 +265,66 @@ gpu.module @xevm_module{ gpu.return } } + +// ----- +// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) { +// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index +// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index +// CHECK: %[[LANE_ID:.*]] = gpu.lane_id +// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]] +// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]] +// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]] +// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]] +// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32> +// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index +gpu.module @xevm_module{ + gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) { + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32> + xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) { +// CHECK: 
%[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index +// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index +// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index +// CHECK: %[[USER_OFFSET_X:.*]] = arith.constant 1 : index +// CHECK: %[[LANE_ID:.*]] = gpu.lane_id +// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]] +// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]] +// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]] +// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]] +// CHECK: %[[LANE_X_OFFSET_1:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]] +// CHECK: %[[LANE_X_OFFSET:.*]] = index.add %[[LANE_X_OFFSET_1]], %[[USER_OFFSET_X]] +// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32> +// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index +gpu.module @xevm_module{ + gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32> + xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) { +// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>: +// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, index, index -> vector<1x2xf32> +// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>: +// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, index, index +gpu.module @xevm_module{ + gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout} : + !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, index, index -> vector<16x2xf32> + xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout} : + vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, index, index + gpu.return + } +} diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 76d461108b296..93d51441f5b81 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern { Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape); + auto maybeOffsets = + sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(maybeOffsets)) return failure();
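For quick reference, the lane-layout rules that the extended `IsValidMatrixOpParams` enforces when `subgroup_block_io` is set can be restated as a small standalone predicate. This is a sketch of the rules only, not the dialect's actual verifier; `checkSubgroupBlockIo`, its signature, and the values in `main` are invented for illustration and assume all four arrays share the mem_desc rank.

```cpp
// Standalone restatement of the subgroup_block_io checks added to
// IsValidMatrixOpParams. Illustrative sketch only.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Returns an empty string when the lane layout is compatible with
// subgroup_block_io, otherwise (roughly) the diagnostic the verifier emits.
std::string checkSubgroupBlockIo(const std::vector<int64_t> &laneLayout,
                                 const std::vector<int64_t> &laneData,
                                 const std::vector<int64_t> &blockShape,
                                 const std::vector<int64_t> &strides) {
  if (laneData.empty())
    return ""; // nothing to check without lane-level layout information
  // Every lane_data entry except the innermost one must be 1.
  for (size_t i = 0; i + 1 < laneData.size(); ++i)
    if (laneData[i] != 1)
      return "accessed data must be contiguous and coalesced";
  for (size_t i = 0; i < laneData.size(); ++i) {
    // The mem_desc block shape must line up with the lane layout.
    if (laneLayout[i] != blockShape[i])
      return "the block shape must match the lane layout";
    // A dimension distributed across lanes must have unit stride.
    if (laneLayout[i] != 1 && strides[i] != 1)
      return "the distributed dimensions must be contiguous";
  }
  return "";
}

int main() {
  // Synthetic example (not taken from the tests): 16 lanes along dim 0,
  // 16x1 blocking, unit stride along the distributed dimension -> accepted.
  std::string diag = checkSubgroupBlockIo({16, 1}, {1, 1}, {16, 1}, {1, 16});
  std::cout << (diag.empty() ? "ok" : diag) << "\n";
  return 0;
}
```

The three new `invalid.mlir` cases above each exercise one of these diagnostics.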