diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 032ce5bc18334..84c1dc1373ee5 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -295,11 +295,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { } LayoutAttr dropSgLayoutAndData() { + // Avoid creating a LayoutAttr whose fields are all nullptr, which may lead to a segfault. + if (!getInstData() && !getLaneLayout()) + return nullptr; return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(), getLaneLayout(), getLaneData(), getOrder()); } LayoutAttr dropInstData() { + // Avoid creating a LayoutAttr whose fields are all nullptr, which may lead to a segfault. + if (!getSgLayout() && !getLaneLayout()) + return nullptr; return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr, getLaneLayout(), getLaneData(), getOrder()); }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 238bb1567d301..e6c7efc47593f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -205,7 +205,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType()); assert(memrefType && "Incorrect use of getStaticStrides"); - auto [strides, offset] = memrefType.getStridesAndOffset(); + auto [strides, _] = memrefType.getStridesAndOffset(); // reuse the storage of ConstStridesAttr since strides from // memref is not persistant setConstStrides(strides);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index 6f585f9ceb29b..8bdf19ac0e47d 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -45,4 +45,17 @@ def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> { "gpu::GPUDialect", "index::IndexDialect"]; } +def XeGPUBlocking: Pass<"xegpu-blocking"> { + let summary = "Block XeGPU ops into smaller sizes."; + let description = [{ + This pass partitions operations that process large shapes into multiple + operations on smaller shapes, as specified by the inst_data in the layout + attribute. This enables each resulting operation to be efficiently mapped + to a hardware instruction. + }]; + let dependentDialects = [ + "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect" + ]; +} + #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index 3616fa614e7f9..f9327d63869c0 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -13,6 +13,12 @@ namespace mlir { class VectorType; +class OpOperand; +class OpResult; +class OpBuilder; +class ValueRange; +class TypeConverter; + namespace xegpu { class LayoutAttr; class TensorDescType; @@ -50,6 +56,59 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy); FailureOr<VectorType> getDistributedVectorType(VectorType originalType, LayoutAttr layout); +/// Return the name of the attribute used to attach the LayoutAttr to the given OpOperand. +std::string getLayoutName(const OpOperand &operand); + +/// Return the name of the attribute used to attach the LayoutAttr to the given OpResult. +std::string getLayoutName(const OpResult result); + +/// Retrieves the LayoutAttr associated with a given Value.
For TensorDescType +/// values, the LayoutAttr is extracted from the TensorDescType itself. For +/// other values, it is obtained from the attributes of the defining operation. +/// Returns nullptr if no LayoutAttr is found. +LayoutAttr getLayoutAttr(const Value value); + +/// Retrieves the LayoutAttr associated with a given OpOperand. It will +/// first check the layout_operand_{id} attribute of the owner operation. If not found, +/// it will check the operand itself and its defining op. +LayoutAttr getLayoutAttr(const OpOperand &opr); + +/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching +/// it to the owner's dictionary attributes. +template <typename T, typename = std::enable_if_t<std::is_same_v<T, OpOperand> || + std::is_same_v<T, OpResult>>> +void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout); + +/// Set the LayoutAttr for each OpOperand and OpResult of the given operation. +/// If the operation contains regions, it is also applied recursively to the +/// contained operations. +void setLayoutAttrs(Operation *op, + function_ref<LayoutAttr(Value)> getLayoutImpl); + +/// Extract a set of small vectors from a value with a given shape using +/// vector.extract_strided_slice. +SmallVector<Value> extractVectorsWithShapeFromValue(OpBuilder &builder, + Location loc, Value value, + ArrayRef<int64_t> shape); + +/// Create a vector of the given shape from a set of values using +/// vector.insert_strided_slice. +Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, + ValueRange values, + ArrayRef<int64_t> shape); + +/// Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural +/// type conversion patterns. Since VectorType cannot carry the layout +/// attribute, which is needed to guide the type conversion for XeGPU, vector +/// values are first converted into RankedTensorType, where the layout attribute +/// can be attached, and then the upstream SCF structural type conversion +/// patterns are applied with the provided converter. +/// TODO: This is a temporary solution. We should refactor it when context-aware +/// type conversion is available. +void doSCFStructuralTypeConversionWithTensorType(Operation *op, + TypeConverter converter); + } // namespace xegpu } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt index 7d9b5584b0b2b..af0d7f6bd9070 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRXeGPUTransforms + XeGPUBlocking.cpp XeGPUFoldAliasOps.cpp XeGPUSubgroupDistribute.cpp XeGPUUnroll.cpp
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp new file mode 100644 index 0000000000000..7cd998eed2e08 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -0,0 +1,337 @@ +//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace xegpu { +#define GEN_PASS_DEF_XEGPUBLOCKING +#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" +} // namespace xegpu +} // namespace mlir + +#define DEBUG_TYPE "xegpu-blocking" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; + +namespace { + +// reslove the unrealized conversion cast ops generated when doing SCF +// Structural Type Conversion. It will have two formats, N:1 vector +// cast and 1:N vector cast. vector::insert_strided_slice ops will be +// used for the first case, and vector::extract_strided_slice ops will be +// used for the second case. +static void +resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) { + ValueRange inputs = castOp.getInputs(); + ValueRange outputs = castOp.getOutputs(); + + auto hasIdenticalVectorTypes = [](ValueRange values) { + auto types = values.getTypes(); + return llvm::all_of(types, [&](Type type) { + return isa(type) && type == types.front(); + }); + }; + + // We only interest in the case where all inputs and outputs have the + // identical VectorTypes + if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) { + LDBG("skip unrealized conversion cast op not emulating pack/unpack."); + return; + } + + VectorType outputTy = dyn_cast(outputs[0].getType()); + OpBuilder builder(castOp); + if (inputs.size() > 1 && outputs.size() == 1) { + // the castOp is emulating an unpack op + ArrayRef shape = outputTy.getShape(); + Value result = xegpu::createVectorWithShapeFromValues( + builder, castOp.getLoc(), inputs, shape); + castOp->replaceAllUsesWith(ValueRange(result)); + castOp->erase(); + } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) { + // the castOp is emulating a pack op + ArrayRef tileShape = outputTy.getShape(); + SmallVector results = xegpu::extractVectorsWithShapeFromValue( + builder, castOp.getLoc(), inputs[0], tileShape); + castOp->replaceAllUsesWith(results); + castOp->erase(); + } +} + +//===------------------------------------------------------------------------===// +// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops +// to partition operations that process large shapes into multiple operations on +// smaller shapes, as specified by the inst_data in the layout attribute. This +// enables each resulting operation to be efficiently mapped to a hardware +// instruction. +//===------------------------------------------------------------------------===// + +class XeGPUBlockingPass final + : public xegpu::impl::XeGPUBlockingBase { +public: + void runOnOperation() override; + +private: + // Get the tile shape for a given OpOperand or OpResult by examining the + // corresponding layout attribute. 
If layout is not present or is not a + // subgroup level layout, it returns std::nullopt. + template || + std::is_same_v>> + std::optional> + getTileShape(const T &operandOrResult) const; + + // Get the tile shape for a given operation. + std::optional> getTileShape(Operation *op) const; + + // Determine if the operation requires unrolling. Return false if all operands + // and results have tile shapes identical to their original types. Otherwise, + // return true. + bool needsUnroll(Operation *op) const; +}; +} // namespace + +template +std::optional> +XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { + Value value; + if constexpr (std::is_same_v) + value = operandOrResult.get(); + else + value = (Value)operandOrResult; + + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult); + if (layout && layout.isSgLayout()) { + if (auto inst_data = layout.getInstData()) + return llvm::to_vector_of(inst_data.asArrayRef()); + + if (auto type = dyn_cast(value.getType())) + return llvm::to_vector(type.getShape()); + } + LDBG("failed to getTileShape for: " << value); + return std::nullopt; +} + +std::optional> +XeGPUBlockingPass::getTileShape(Operation *op) const { + if (isa(op)) + return getTileShape(op->getOpResult(0)); + if (isa(op)) + return getTileShape(op->getOpOperand(0)); + if (isa(op)) + return getTileShape(op->getOpOperand(1)); + + if (isa(op)) { + std::optional> aTile = + getTileShape(op->getOpOperand(0)); + std::optional> bTile = + getTileShape(op->getOpOperand(1)); + + if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2) + return std::nullopt; + + // semantic check for A and B + if ((*aTile)[1] != (*bTile)[0]) + return std::nullopt; + + // semantic check for C + if (op->getNumOperands() == 3) { + std::optional> cTile = + getTileShape(op->getOpOperand(2)); + int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]}; + if (!cTile || !llvm::equal(*cTile, expectedCTile)) + return std::nullopt; + } + + return SmallVector({(*aTile)[0], (*aTile)[1], (*bTile)[1]}); + } + + if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) + return getTileShape(op->getOpResult(0)); + + return std::nullopt; +} + +bool XeGPUBlockingPass::needsUnroll(Operation *op) const { + // skip the op if any of its operands or results has workgroup level layouts + bool hasWgLayoutOperands = + llvm::any_of(op->getOpOperands(), [](OpOperand &opr) { + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr); + return layout && layout.isWgLayout(); + }); + bool hasWgLayoutResults = + llvm::any_of(op->getOpResults(), [](OpResult result) { + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result); + return layout && layout.isWgLayout(); + }); + if (hasWgLayoutOperands || hasWgLayoutResults) { + LDBG("skip unrolling for op with workgroup level layout: " << *op); + return false; + } + + auto isUnrollable = [](Value value, ArrayRef tileShape) { + Type valTy = value.getType(); + if (auto tdescTy = dyn_cast(valTy)) { + xegpu::LayoutAttr layout = tdescTy.getLayoutAttr(); + return layout && layout.getInstData(); + } + auto shapedType = dyn_cast(valTy); + return shapedType && !llvm::equal(tileShape, shapedType.getShape()); + }; + + bool hasUnrollableOperands = + llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) { + std::optional> tileShape = getTileShape(opr); + return tileShape.has_value() && isUnrollable(opr.get(), *tileShape); + }); + bool hasUnrollableResults = + llvm::any_of(op->getOpResults(), [&](OpResult result) { + std::optional> tileShape = getTileShape(result); + return 
tileShape.has_value() && isUnrollable(result, *tileShape); + }); + return hasUnrollableOperands || hasUnrollableResults; +} + +void XeGPUBlockingPass::runOnOperation() { + MLIRContext *ctx = &getContext(); + Operation *op = getOperation(); + + // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr. + // This ensures that the LayoutAttr remains accessible even if the defining + // operation is replaced. + xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); }); + + auto getTileShapeAndCount = [](llvm::ArrayRef shape, + xegpu::LayoutAttr layout) { + int count = 1; + SmallVector tileShape(shape); + if (layout && layout.getInstData()) { + DenseI32ArrayAttr instData = layout.getInstData(); + tileShape = llvm::to_vector_of(instData.asArrayRef()); + count = computeProduct(shape) / computeProduct(tileShape); + } + return std::make_pair(tileShape, count); + }; + + // Perform type conversion for SCF control folow ops + TypeConverter converter; + converter.addConversion([](Type type) -> Type { return type; }); + converter.addConversion( + [&](RankedTensorType type, + SmallVectorImpl &result) -> std::optional { + Type elemTy = type.getElementType(); + ArrayRef shape = type.getShape(); + + auto layout = + llvm::dyn_cast_if_present(type.getEncoding()); + if (layout && layout.isWgLayout()) + return failure(); + + int count; + SmallVector subShape; + std::tie(subShape, count) = getTileShapeAndCount(shape, layout); + auto newTy = VectorType::get(subShape, elemTy); + result.append(count, newTy); + return success(); + }); + converter.addConversion( + [&](xegpu::TensorDescType type, + SmallVectorImpl &result) -> std::optional { + Type elemTy = type.getElementType(); + ArrayRef shape = type.getShape(); + + xegpu::LayoutAttr layout = type.getLayoutAttr(); + if (layout && layout.isWgLayout()) + return failure(); + + int count; + SmallVector subShape; + std::tie(subShape, count) = getTileShapeAndCount(shape, layout); + + if (layout) + layout = layout.dropInstData(); + + auto newTy = xegpu::TensorDescType::get( + type.getContext(), subShape, elemTy, type.getEncoding(), layout); + result.append(count, newTy); + return success(); + }); + + xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter); + + xegpu::UnrollOptions options; + options.setFilterConstraint( + [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); }); + + options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); }); + + options.setUnrolledTypesFn([&](ShapedType type, ArrayRef tileShape) { + Type elemTy = type.getElementType(); + Type newTy; + + if (auto tdescTy = dyn_cast(type)) + newTy = xegpu::TensorDescType::get( + ctx, tileShape, elemTy, tdescTy.getEncoding(), + tdescTy.getLayoutAttr().dropInstData()); + else + newTy = type.clone(tileShape, elemTy); + + std::optional> ratio = + computeShapeRatio(type.getShape(), tileShape); + assert(ratio && "The shape of the type must be a multiple of tileShape."); + return SmallVector(computeProduct(*ratio), newTy); + }); + + RewritePatternSet patterns(ctx); + + vector::UnrollVectorOptions vectorOptions; + vectorOptions.setNativeShapeFn(options.nativeShape); + + populateXeGPUUnrollPatterns(patterns, options); + vector::populateVectorUnrollPatterns(patterns, vectorOptions); + + (void)applyPatternsGreedily(op, std::move(patterns)); + + op->walk([](Operation *op) { + // Remove the layout attributes cached per operands. 
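+    // These are the discardable "layout_operand_{i}" attributes attached at the
+    // start of this pass so the layout information survives pattern rewrites;
+    // they are no longer needed once blocking is done.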
+ for (OpOperand &opr : op->getOpOperands()) { + std::string name = xegpu::getLayoutName(opr); + if (op->hasAttrOfType(name)) + op->removeAttr(name); + } + + // Update the layout attributes per result. + for (OpResult result : op->getOpResults()) { + std::string name = xegpu::getLayoutName(result); + if (auto layout = op->getAttrOfType(name)) { + op->removeAttr(name); + if (!isa(op)) + xegpu::setLayoutAttr(result, layout.dropInstData()); + } + } + + // Resolve unrealized conversion cast ops emulating pack/unpack + if (auto castOp = dyn_cast(op)) + resolveUnrealizedConversionCastOp(castOp); + }); +} diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 992700524146a..c84906cc45568 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -62,8 +62,6 @@ constexpr unsigned packedSizeInBitsForDefault = 16; // Minimum packing size per register for DPAS A. constexpr unsigned packedSizeInBitsForDpasB = 32; // Minimum packing size per register for DPAS B. -static const char *const operandLayoutNamePrefix = "layout_operand_"; -static const char *const resultLayoutNamePrefix = "layout_result_"; namespace { @@ -729,10 +727,7 @@ class LayoutAttrAssignment { void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) { for (OpOperand &user : v.getUses()) { Operation *owner = user.getOwner(); - unsigned operandNumber = user.getOperandNumber(); - // Use a generic name for ease of querying the layout attribute later. - std::string attrName = - operandLayoutNamePrefix + std::to_string(operandNumber); + std::string attrName = xegpu::getLayoutName(user); owner->setAttr(attrName, layout); } } @@ -806,10 +801,10 @@ LogicalResult LayoutAttrAssignment::assign(Operation *op) { return success(); } // Otherwise simply attach the layout to the op itself. - for (auto [i, r] : llvm::enumerate(op->getResults())) { + for (auto r : op->getOpResults()) { xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r); if (layoutInfo) { - std::string attrName = resultLayoutNamePrefix + std::to_string(i); + std::string attrName = xegpu::getLayoutName(r); op->setAttr(attrName, layoutInfo); // Attach the layout attribute to the users of the result. 
assignToUsers(r, layoutInfo); @@ -929,11 +924,8 @@ static SmallVector removeTemporaryLayoutAttributes(ArrayRef attrs) { SmallVector newAttrs; for (NamedAttribute attr : attrs) { - if (attr.getName().strref().contains(operandLayoutNamePrefix) || - attr.getName().strref().contains(resultLayoutNamePrefix)) { - continue; - } - newAttrs.push_back(attr); + if (!isa(attr.getValue())) + newAttrs.push_back(attr); } return newAttrs; } @@ -1336,11 +1328,10 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern { auto dpasOp = operand->get().getDefiningOp(); unsigned operandIdx = operand->getOperandNumber(); - std::string layoutAName = - llvm::formatv("{0}{1}", operandLayoutNamePrefix, 0).str(); - std::string layoutBName = - llvm::formatv("{0}{1}", operandLayoutNamePrefix, 1).str(); - auto layoutCName = llvm::formatv("{0}{1}", resultLayoutNamePrefix, 0).str(); + std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0)); + std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1)); + std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0)); + xegpu::LayoutAttr layoutA = dpasOp->getAttrOfType(layoutAName); xegpu::LayoutAttr layoutB = diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 44d45dd2eaec0..885477fe4cbd5 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" @@ -74,17 +75,7 @@ struct UnrollPattern : public OpRewritePattern { assert(vecTy.getRank() == static_cast(blockSize.size()) && "Expecting blockSize size to match the rank of destTy."); auto shape = vecTy.getShape(); - auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType()); - - Value result = rewriter.create( - loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr)); - for (auto [src, offsets] : - llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) { - SmallVector staticStrides(offsets.size(), 1); - result = rewriter.create( - loc, src, result, offsets, staticStrides); - } - return result; + return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape); } if (isa(destTy)) { @@ -109,16 +100,8 @@ struct UnrollPattern : public OpRewritePattern { if (auto vecTy = dyn_cast(src.getType())) { assert(vecTy.getRank() == static_cast(blockSize.size()) && "Expecting blockSize size to match the rank of src."); - auto shape = vecTy.getShape(); - SmallVector results; - for (SmallVector offsets : - StaticTileOffsetRange(shape, blockSize)) { - SmallVector staticStrides(offsets.size(), 1); - auto slice = rewriter.create( - loc, src, offsets, blockSize, staticStrides); - results.push_back(slice); - } - return results; + return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src, + blockSize); } if (isa(src.getType())) { @@ -153,7 +136,7 @@ struct UnrollCreateNdOp : public UnrollPattern { ArrayRef shape = tdescTy.getShape(); std::optional> targetShape = getTargetShape(op); - if (!targetShape || llvm::equal(*targetShape, shape)) + if (!targetShape) return failure(); auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0]; @@ -204,10 +187,9 @@ struct UnrollUpdateNdOffsetOp : public UnrollPattern { PatternRewriter 
&rewriter) const override { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - ArrayRef shape = tdescTy.getShape(); std::optional> targetShape = getTargetShape(op); - if (!targetShape || llvm::equal(*targetShape, shape)) + if (!targetShape) return failure(); SmallVector convertedTdescTypes = @@ -233,10 +215,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - ArrayRef shape = tdescTy.getShape(); std::optional> targetShape = getTargetShape(op); - if (!targetShape || llvm::equal(*targetShape, shape)) + if (!targetShape) return failure(); SmallVector convertedTdescTypes = @@ -260,10 +241,9 @@ struct UnrollLoadNdOp : public UnrollPattern { Location loc = op.getLoc(); VectorType valueTy = op.getType(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - ArrayRef shape = tdescTy.getShape(); std::optional> targetShape = getTargetShape(op); - if (!targetShape || llvm::equal(*targetShape, shape)) + if (!targetShape) return failure(); Type elemTy = tdescTy.getElementType(); @@ -295,10 +275,9 @@ struct UnrollStoreNdOp : public UnrollPattern { Location loc = op.getLoc(); VectorType valueTy = op.getValueType(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - ArrayRef shape = tdescTy.getShape(); std::optional> targetShape = getTargetShape(op); - if (!targetShape || llvm::equal(*targetShape, shape)) + if (!targetShape) return failure(); SmallVector convertedValTypes = diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index afd8e2d5c4df3..98e84a4420722 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -6,5 +6,6 @@ add_mlir_dialect_library(MLIRXeGPUUtils LINK_LIBS PUBLIC MLIRIR + MLIRSCFTransforms MLIRXeGPUDialect ) diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 6b45ed0ae4ced..dcaf4e85a82c5 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -11,12 +11,29 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" #include #include using namespace mlir; +/// convert ArrayRef into SmallVector +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const auto &vals : values) + llvm::append_range(result, vals); + return result; +} + FailureOr mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) { auto layout = llvm::dyn_cast_if_present(tdescTy.getLayout()); @@ -83,3 +100,278 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType, /*memory_space=*/xegpu::MemorySpace::Global, layout); return xegpu::getDistributedVectorType(helperTdescTy); } + +std::string xegpu::getLayoutName(const OpOperand &operand) { + const StringRef prefix("layout_operand_"); + unsigned idx = const_cast(operand).getOperandNumber(); + return llvm::formatv("{0}{1}", prefix, idx).str(); +} + +std::string 
xegpu::getLayoutName(const OpResult result) { + const StringRef prefix = "layout_result_"; + return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str(); +} + +xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { + if (!value) + return nullptr; + + if (auto tdescTy = + dyn_cast_if_present(value.getType())) + return tdescTy.getLayoutAttr(); + + if (auto result = dyn_cast(value)) { + Operation *defOp = result.getDefiningOp(); + assert(defOp && "result must have a defining op"); + + // for LoadNdOp, the layout is stored in the tensor descriptor + if (auto loadNd = dyn_cast(defOp)) + return getLayoutAttr(loadNd.getTensorDesc()); + + std::string layoutName = getLayoutName(result); + if (defOp->hasAttr(layoutName)) + return defOp->getAttrOfType(layoutName); + } + + if (auto arg = dyn_cast(value)) { + auto parentOp = arg.getOwner()->getParentOp(); + if (auto loop = dyn_cast(parentOp)) { + OpOperand *tiedInit = loop.getTiedLoopInit(arg); + return getLayoutAttr(tiedInit->get()); + } + } + + return nullptr; +} + +xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) { + Operation *op = opr.getOwner(); + std::string layoutName = xegpu::getLayoutName(opr); + if (op->hasAttr(layoutName)) + return op->getAttrOfType(layoutName); + return getLayoutAttr(opr.get()); +} + +template +void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) { + Operation *owner = operandOrResult.getOwner(); + std::string name = xegpu::getLayoutName(operandOrResult); + if (layout && !owner->hasAttrOfType(name)) + owner->setAttr(name, layout); +} + +// Explicit instantiation for OpResult +template void +xegpu::setLayoutAttr(const mlir::OpResult &result, + const mlir::xegpu::LayoutAttr layout); + +// Explicit instantiation for OpOperand +template void +xegpu::setLayoutAttr(const mlir::OpOperand &operand, + const mlir::xegpu::LayoutAttr layout); + +void xegpu::setLayoutAttrs(Operation *op, + function_ref getLayoutImpl) { + op->walk([&](Operation *nestOp) { + for (OpOperand &opr : nestOp->getOpOperands()) { + auto layout = getLayoutImpl(opr.get()); + setLayoutAttr(opr, layout); + } + for (OpResult result : nestOp->getOpResults()) { + auto layout = getLayoutImpl(result); + setLayoutAttr(result, layout); + } + }); +} + +SmallVector +xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, + Value value, ArrayRef shape) { + auto vecTy = dyn_cast(value.getType()); + if (!vecTy) + return {value}; + + ArrayRef srcShape = vecTy.getShape(); + if (!computeShapeRatio(srcShape, shape)) + return {value}; + + SmallVector result; + for (SmallVector offsets : StaticTileOffsetRange(srcShape, shape)) { + SmallVector staticStrides(offsets.size(), 1); + result.push_back(builder.create( + loc, value, offsets, shape, staticStrides)); + } + + return result; +} + +Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc, + ValueRange values, + ArrayRef shape) { + VectorType inputTy = dyn_cast(values[0].getType()); + assert(llvm::all_of(values.getTypes(), + [&](Type type) { return type == inputTy; }) && + "values must be of the same VectorType"); + + Type elemTy = inputTy.getElementType(); + ArrayRef tileShape = inputTy.getShape(); + + VectorType resultTy = VectorType::get(shape, elemTy); + auto zeroAttr = builder.getZeroAttr(elemTy); + Value result = builder.create( + loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr)); + + for (auto [src, offsets] : + llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) { + SmallVector staticStrides(offsets.size(), 1); 
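+    // Write each source tile into `result` at its computed offsets with unit
+    // strides (vector::InsertStridedSliceOp), accumulating the full-shape vector.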
+ result = builder.create( + loc, src, result, offsets, staticStrides); + } + return result; +} + +void xegpu::doSCFStructuralTypeConversionWithTensorType( + Operation *op, TypeConverter converter) { + MLIRContext *context = op->getContext(); + + auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Value { + return builder.create(loc, type, inputs) + .getResult(0); + }; + + { // convert VectorType to RankedTensorType for SCF Structural ops + TypeConverter converter; + converter.addConversion([](Type type) -> Type { return type; }); + converter.addConversion([](VectorType type) -> Type { + return RankedTensorType::get(type.getShape(), type.getElementType()); + }); + converter.addSourceMaterialization(materializeCast); + converter.addTargetMaterialization(materializeCast); + + mlir::ConversionTarget target(*context); + target.addLegalOp(); + + mlir::RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + (void)mlir::applyPartialConversion(op, target, std::move(patterns)); + } + + { // propagate the layout attribute to RankedTensorType by checking + // BuiltInUnrealizedCastOps + // for VectorType to RankedTensorType cast. + op->walk([](UnrealizedConversionCastOp castOp) { + if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1) + return WalkResult::skip(); + + Value input = castOp.getInputs()[0]; + Value result = castOp.getResults()[0]; + auto inputTy = dyn_cast(input.getType()); + auto resultTy = dyn_cast(result.getType()); + + // Only look at ops casting from VectorType to RankedTensorType + if (!isa(inputTy) || !isa(resultTy)) + return WalkResult::skip(); + + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input); + if (!layout) + return WalkResult::skip(); + + RankedTensorType newTy = resultTy.cloneWithEncoding(layout); + result.setType(newTy); + + // update the arguments if user is a LoopLike op. + for (OpOperand &use : result.getUses()) { + if (auto loop = dyn_cast(use.getOwner())) { + BlockArgument arg = loop.getTiedLoopRegionIterArg(&use); + arg.setType(newTy); + } + // whileOp has two regions, the BlockArgument of the after region + // is not exposed by LoopLikeOpInterface + if (auto whileOp = dyn_cast(use.getOwner())) { + unsigned idx = use.getOperandNumber(); + BlockArgument arg = whileOp.getAfterArguments()[idx]; + arg.setType(newTy); + } + } + return WalkResult::advance(); + }); + + // using yieldOp as anchor to update the result type of its ParentOp + op->walk([](scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpResult r : parentOp->getOpResults()) { + unsigned idx = r.getResultNumber(); + Type resultTy = r.getType(); + Type yieldTy = yieldOp.getResults()[idx].getType(); + if (isa(resultTy) && yieldTy != resultTy) + r.setType(yieldTy); + } + }); + } + + { // perform the conversion from RankedTensorType to VectorType based on the + // LayoutAttr + + // Handle the UnrealizedConversionCastOp introduced by the first step. + // For vector->RankedTensorType, it will simply forward the inputs. + // For RankedTensorType->vector, it will update the inputs with the + // one from the adaptor. 
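+    // E.g., a vector -> tensor cast is folded away by replacing its uses with the
+    // (possibly multiple) converted adaptor values via replaceOpWithMultiple,
+    // while a tensor -> vector cast is recreated over the flattened adaptor values.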
+ class UnrealizedConversionCastOpPattern + : public OpConversionPattern { + using OpConversionPattern< + mlir::UnrealizedConversionCastOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::UnrealizedConversionCastOp op, + OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto inputs = op.getOperands(); + auto outputs = op.getOutputs(); + + if (inputs.size() != 1 || outputs.size() != 1) + return failure(); + + auto inputTy = inputs[0].getType(); + auto outputTy = outputs[0].getType(); + + if (isa(inputTy) && isa(outputTy)) { + rewriter.replaceOpWithMultiple(op, adaptor.getInputs()); + return success(); + } + + if (isa(inputTy) && isa(outputTy)) { + SmallVector values = flattenValues(adaptor.getInputs()); + auto newOp = rewriter.create( + op.getLoc(), outputTy, values); + rewriter.replaceOp(op, newOp); + return success(); + } + return failure(); + } + }; + + converter.addSourceMaterialization(materializeCast); + converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type, + ValueRange inputs, Location loc) { + return builder.create(loc, type, inputs) + .getResults(); + }); + + mlir::ConversionTarget target(*context); + target.addDynamicallyLegalOp( + [](UnrealizedConversionCastOp op) { + auto isTensorTy = [](Type type) { + return isa(type); + }; + return llvm::none_of(op->getOperandTypes(), isTensorTy) && + llvm::none_of(op->getResultTypes(), isTensorTy); + }); + mlir::RewritePatternSet patterns(context); + patterns.insert(context); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + (void)mlir::applyPartialConversion(op, target, std::move(patterns)); + } +} diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir new file mode 100644 index 0000000000000..f9114988686c8 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt --xegpu-blocking -split-input-file %s | FileCheck %s + +#a = #xegpu.layout +#b = #xegpu.layout +#c = #xegpu.layout +gpu.module @test_kernel { + gpu.func @test_gemm_with_one_to_n_lowering(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c16 : index + %n = arith.muli %block_id_y, %c32 : index + + %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #c> + %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #c> -> vector<16x32xf32> + + %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #a> + %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #b> + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init) + -> (!xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>) { + //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16> + %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #a> -> vector<16x32xf16> + //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16> + %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16> + //CHECK-COUNT-8: xegpu.dpas {{.*}} {layout_result_0 = #xegpu.layout} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> 
-> vector<8x16xf32> + %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #c}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout> + %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #a> + //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout> + %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #b> + scf.yield %a_next_tdesc, %b_next_tdesc, %c + : !xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32> + } + //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout> + xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c> + gpu.return + } +} + +// ----- +#l1 = #xegpu.layout +#l2 = #xegpu.layout +gpu.module @test_kernel { + gpu.func @test_gemm_with_inst_data_only_attribute(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c16 : index + %n = arith.muli %block_id_y, %c32 : index + + %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #l1> + %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #l1> -> vector<16x32xf32> + + %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l1> + %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #l2> + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init) + -> (!xegpu.tensor_desc<16x32xf16, #l1>, !xegpu.tensor_desc<32x32xf16, #l2>, vector<16x32xf32>) { + //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16> + %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l1> -> vector<16x32xf16> + //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16> + %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #l2> -> vector<32x32xf16> + //CHECK-COUNT-8: xegpu.dpas {{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16> + %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l1> + //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16> + %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #l2> + scf.yield %a_next_tdesc, %b_next_tdesc, %c + : !xegpu.tensor_desc<16x32xf16, #l1>, !xegpu.tensor_desc<32x32xf16, #l2>, vector<16x32xf32> + } + //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #l1> + gpu.return + } +} + +// ----- +#l1 = #xegpu.layout +#l2 = #xegpu.layout +gpu.module @test_kernel { + gpu.func @test_gemm_with_one_to_one_lowering(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c16 = 
arith.constant 16 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c8 : index + %n = arith.muli %block_id_y, %c32 : index + + %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x32xf32, #l1> + + //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<8x32xf32, #l1> -> vector<8x32xf32> + + %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16, #l1> + %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l2> + %out:3 = scf.for %k = %c0 to %c1024 step %c16 + iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init) + -> (!xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>) { + //CHECK: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16, #l1> -> vector<8x16xf16> + //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l2> -> vector<16x32xf16> + %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1}: vector<8x16xf16>, vector<16x32xf16>, vector<8x32xf32> -> vector<8x32xf32> + //CHECK: xegpu.update_nd_offset {{.*}} [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16, #l1> + //CHECK-COUNT-2: xegpu.update_nd_offset {{.*}} [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<16x32xf16, #l2> + scf.yield %a_next_tdesc, %b_next_tdesc, %c + : !xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32> + } + //CHECK-COUNT-2: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %out#2, %c_tdesc: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #l1> + gpu.return + } +} + +// ----- +#a = #xegpu.layout +#b = #xegpu.layout +#c = #xegpu.layout +gpu.module @test_kernel { + gpu.func @test_gemm_with_elemwise_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c16 : index + %n = arith.muli %block_id_y, %c32 : index + + %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #c> + %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #c> -> vector<16x32xf32> + + %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #a> + %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #b> + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init) + -> (!xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>) { + //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16> + %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #a> -> vector<16x32xf16> + //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16> + %b = xegpu.load_nd %arg1 : 
!xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16> + //CHECK-COUNT-4: math.exp {{.*}} : vector<8x16xf16> + %e = math.exp %a {layout_result_0 = #a} : vector<16x32xf16> + //CHECK-COUNT-8: xegpu.dpas {{.*}} {layout_result_0 = #xegpu.layout} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %c = xegpu.dpas %e, %b, %arg2 {layout_result_0 = #c}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout> + %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #a> + //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout> + %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #b> + scf.yield %a_next_tdesc, %b_next_tdesc, %c + : !xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32> + } + //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout> + xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c> + gpu.return + } +} + +// ----- +#l = #xegpu.layout +gpu.module @test_kernel { + gpu.func @test_elementwise_with_inst_data_only(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c32 : index + + %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l> + %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l> + %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l> + + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) + -> (!xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>) { + //CHECK-COUNT-8: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16> + %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16> + + //CHECK-COUNT-4: arith.addf {{.*}} : vector<8x16xf16> + %c = arith.addf %a, %b {layout_result_0 = #l} : vector<16x32xf16> + + //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #l> + + //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16> + %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l> + %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l> + %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l> + scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc + : !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l> + } + gpu.return + } +} + +// ----- +#l = #xegpu.layout +gpu.module @test_kernel { + gpu.func @test_elementwise_1D(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + 
%block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c32 : index + + %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l> + %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l> + %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l> + + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) + -> (!xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>) { + //CHECK-COUNT-8: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8xf16> -> vector<8xf16> + %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16> + %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16> + + //CHECK-COUNT-4: arith.addf {{.*}} : vector<8xf16> + %c = arith.addf %a, %b {layout_result_0 = #l} : vector<32xf16> + + //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8xf16>, !xegpu.tensor_desc<8xf16> + xegpu.store_nd %c, %arg2: vector<32xf16>, !xegpu.tensor_desc<32xf16, #l> + + //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8xf16> + %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c32] : !xegpu.tensor_desc<32xf16, #l> + %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32] : !xegpu.tensor_desc<32xf16, #l> + %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c32] : !xegpu.tensor_desc<32xf16, #l> + scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc + : !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l> + } + gpu.return + } +}
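For reference, the sketch below shows how the new pass could be added to a pass pipeline programmatically. It is illustrative only: the creator mlir::xegpu::createXeGPUBlocking() is the name conventionally generated by TableGen from the Passes.td entry above (declared via GEN_PASS_DECL in Passes.h) and is assumed here rather than shown in this diff.

// Illustrative pipeline hookup; assumes the TableGen-generated creator
// mlir::xegpu::createXeGPUBlocking() declared in XeGPU/Transforms/Passes.h.
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

static void addXeGPUBlockingPass(mlir::OpPassManager &pm) {
  // Blocking honors inst_data-level layouts only, so it is expected to run after
  // xegpu-wg-to-sg-distribute has lowered workgroup layouts to subgroup ones.
  pm.addPass(mlir::xegpu::createXeGPUBlocking());
}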