From 4c58d3d6a627f23425528668ffff92bcca8f1461 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 7 Oct 2025 22:08:30 +0000 Subject: [PATCH 01/12] pass basic lowering test --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 22 ++ .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 +- .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 50 +++- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 124 ++++++++++ mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 216 ++++++++++++++++++ mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 20 +- mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 6 +- .../XeGPUToXeVM/loadstore_matrix.mlir | 40 ++++ 8 files changed, 466 insertions(+), 18 deletions(-) create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 5695d5d515d7f..601e966b49890 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { return getAttrs().getAs("stride"); } + ArrayAttr getBlockAttr() { + return getAttrs().getAs("block"); + } + }]; } +def RowOriented : I32EnumAttrCase<"ROW", 0, "row">; +def ColOriented : I32EnumAttrCase<"COL", 1, "col">; +def MatrixAccessDirection : + I32EnumAttr<"MatrixAccessDirection", + "Matrix elements/vectors can have row or column direction", [ + RowOriented, ColOriented +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::xegpu"; +} +def MatrixAccessDirectionAttr : + EnumAttr{ + let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}]; + let assemblyFormat = "`<` $value `>`"; +} + #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 73f9061f5debe..32d21bae8cd34 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1298,8 +1298,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, } def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, - AllElementTypesMatch<["mem_desc", "res"]>, - AllRanksMatch<["mem_desc", "res"]>]> { + AllElementTypesMatch<["mem_desc", "res"]>]> { let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, @@ -1344,8 +1343,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - AllElementTypesMatch<["mem_desc", "data"]>, - AllRanksMatch<["mem_desc", "data"]>]> { + AllElementTypesMatch<["mem_desc", "data"]>]> { let arguments = (ins XeGPU_ValueType:$data, XeGPU_MemDesc:$mem_desc, diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 84902b2039643..c261fbb576642 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); } - ArrayAttr getStrides() { + ArrayAttr getStridesAttr() { auto layout = getMemLayout(); if (layout && layout.hasAttr("stride")) { return layout.getStrides(); @@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m Builder 
builder(getContext()); return builder.getI64ArrayAttr(defaultStrides); } + + /// Heuristic to determine if the MemDesc uses column-major layout, + /// based on the rank and the value of the first stride dimension. + bool isColMajor() { + auto dim0 = dyn_cast(getStridesAttr()[0]); + return getRank() == 2 && dim0 && dim0.getInt() == 1; + } + + // get the Blocking shape for a MemDescType, Which is represented + // as an attribute in MemDescType. By default it is the shape + // of the mdescTy + SmallVector getBlockSize() { + SmallVector size(getShape()); + MemLayoutAttr layout = getMemLayout(); + if (layout && layout.hasAttr("block")) { + ArrayAttr attr = layout.getBlockAttr(); + size.clear(); + llvm::for_each(attr, [&](Attribute elem) { + if (auto intElem = dyn_cast(elem)) + size.push_back(intElem.getInt()); + }); + } + return size; + } + + // Get strides as vector of integer. + // If it contains block attribute, the strides are blocked strides. + // + // The blocking is applied against the original matrix shape + // so that the linear offset is not impacted by the subview. + // + // It first computes the original matrix shape using the stride info, + // then computes the number of blocks in each dimension of original shape, + // then compute the outer block shape and stride, + // then combines the inner and outer block shape and stride + // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]> + // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) + // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1] + // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) + SmallVector getStrides(); + + /// Generates instructions to compute the linearize offset + // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout + // the strides of memory descriptor is always considered regardless of blocked or not + Value getLinearOffsets(OpBuilder &builder, + Location loc, ArrayRef offsets); + + }]; let hasCustomAssemblyFormat = true; diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 9ead1d89069d6..666df293bb8be 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -32,6 +32,8 @@ #include +#define DEBUG_TYPE "xegpu-to-xevm" + namespace mlir { #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -60,6 +62,9 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { return static_cast(xevm::AddrSpace::GLOBAL); case xegpu::MemorySpace::SLM: return static_cast(xevm::AddrSpace::SHARED); + default: + llvm_unreachable("Unknown XeGPU memory space"); + return static_cast(xevm::AddrSpace::GLOBAL); } } @@ -366,6 +371,7 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, Value baseAddr, Value offset, int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( rewriter, loc, rewriter.getI64Type(), elemByteSize); + offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -503,6 +509,113 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { } }; +// Lower xegpu::CreateMemDescOp to memref::ViewOp. 
Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+    : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // DEBUG: Print operation and types
+    LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
+    TypedValue<MemRefType> src = op.getSource();
+    auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+
+    // Create the result MemRefType with the same shape, element type, and memory space
+    auto newResTy = getTypeConverter()->convertType(resTy);
+
+    LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
+    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero,
+                             ValueRange());
+    rewriter.replaceOp(op, viewOp);
+    LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
+    return success();
+  }
+};
+
+class MemDescSubviewOpPattern final
+    : public OpConversionPattern<xegpu::MemDescSubviewOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return rewriter.notifyMatchFailure(
+        op, "MemDescSubviewOp is not supported on Xe2/Xe3 architecture.");
+  }
+};
+
+
+template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    Value basePtrStruct = adaptor.getMemDesc();
+    Value mdescVal = op.getMemDesc();
+    // Load result or store value type can be vector or scalar.
+    Value data;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+      data = op.getResult();
+    else
+      data = adaptor.getData();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+
+    int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    int64_t elemByteSize = elemBitWidth / 8;
+
+    // Default memory space is SLM.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); + + auto mdescTy = cast(mdescVal.getType()); + + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct); + + // Convert base pointer (ptr) to i64 + Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM); + + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize); + + // convert base pointer (i64) to LLVM pointer type + basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + + if constexpr (std::is_same_v) { + + Value loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + auto storeOp = + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } + return success(); + } +}; + class PrefetchToXeVMPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -785,6 +898,13 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); + // Convert MemDescType into flattened MemRefType for SLM + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { + Type elemTy = type.getElementType(); + int numElems = type.getNumElements(); + return MemRefType::get(numElems, elemTy, AffineMap(), 3); + }); + typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. return IntegerType::get(&getContext(), 64); @@ -919,6 +1039,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns( LoadStoreToXeVMPattern, LoadStoreToXeVMPattern>( typeConverter, patterns.getContext()); + patterns.add, + LoadStoreMatrixToXeVMPattern, + CreateMemDescOpPattern, MemDescSubviewOpPattern>( + typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 94c5509fd7c29..c64699c12cf4a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -37,6 +37,8 @@ void XeGPUDialect::initialize() { >(); } +#define DEBUG_TYPE "xegpu" + /// Generates instructions to compute offsets for a subgroup identified by /// its multidimensional indices (sgId), using the specified subgroup layout /// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data @@ -726,6 +728,220 @@ void MemLayoutAttr::print(AsmPrinter &printer) const { } printer << ">"; } +// a helper utility to perform binary operation on OpFoldResult. +// If both a and b are attributes, it will simply return the result. +// Otherwise, the corresponding arith op will be generated, and an +// contant op will be created if one of them is an attribute. +template +OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, + OpBuilder &builder) { + auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); + auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); + return builder.create(loc, aVal, bVal).getResult(); +} + +// a helper utility to perform division operation on OpFoldResult and int64_t. +#define div(a, b) \ + genBinOp(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform reminder operation on OpFoldResult and int64_t. 
+#define rem(a, b) \ + genBinOp(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform multiply operation on OpFoldResult and int64_t. +#define mul(a, b) \ + genBinOp(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform addition operation on two OpFoldResult. +#define add(a, b) genBinOp(a, b, loc, builder) + +// block the given offsets according to the block shape +// say the original offset is [y, x], and the block shape is [By, Bx], +// then the blocked offset is [y/By, x/Bx, y%By, x%Bx] +SmallVector getBlockedOffsets(OpBuilder &builder, Location loc, + ArrayRef offsets, + ArrayRef blockShape) { + + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + SmallVector blockedOffsets; + SmallVector divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + + return blockedOffsets; +} + +// Get strides as vector of integer for MemDesc. +SmallVector MemDescType::getStrides() { + + SmallVector matrixShape(getShape().begin(), + getShape().end()); + + ArrayAttr strideAttr = getStridesAttr(); + SmallVector strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast(attr).getInt()); + } + + llvm::dbgs() << "DEBUG: matrixShape = ["; + for (size_t i = 0; i < matrixShape.size(); ++i) { + llvm::dbgs() << matrixShape[i]; + if (i < matrixShape.size() - 1) llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + llvm::dbgs() << "DEBUG: strides = ["; + for (size_t i = 0; i < strides.size(); ++i) { + llvm::dbgs() << strides[i]; + if (i < strides.size() - 1) llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + SmallVector innerBlkShape = getBlockSize(); + llvm::dbgs() << "DEBUG: innerBlkShape = ["; + for (size_t i = 0; i < innerBlkShape.size(); ++i) { + llvm::dbgs() << innerBlkShape[i]; + if (i < innerBlkShape.size() - 1) llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + if (innerBlkShape.empty()) + return strides; + + SmallVector perm = llvm::to_vector<4>( + llvm::seq(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + + llvm::dbgs() << "DEBUG: perm = ["; + for (size_t i = 0; i < perm.size(); ++i) { + llvm::dbgs() << perm[i]; + if (i < perm.size() - 1) llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + + SmallVector innerBlkStride = computeStrides(innerBlkShape); + + llvm::dbgs() << "DEBUG: innerBlkStride = ["; + for (size_t i = 0; i < innerBlkStride.size(); ++i) { + llvm::dbgs() << innerBlkStride[i]; + if (i < innerBlkStride.size() - 1) llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. 
+  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+  for (size_t i = 0; i < perm.size() - 1; ++i) {
+    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+  }
+
+  llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
+  for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
+    llvm::dbgs() << matrixShapeOrig[i];
+    if (i < matrixShapeOrig.size() - 1) llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
+
+  llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
+  for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
+    llvm::dbgs() << BlkShapeOrig[i];
+    if (i < BlkShapeOrig.size() - 1) llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
+
+  int64_t innerBlkSize = 1;
+  for (auto s : innerBlkShape)
+    innerBlkSize *= s;
+
+  llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
+
+  SmallVector<int64_t> outerBlkStride(matrixShape.size());
+  outerBlkStride[perm[0]] = innerBlkSize;
+  for (size_t i = 0; i < perm.size() - 1; ++i) {
+    outerBlkStride[perm[i + 1]] =
+        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+  }
+
+  llvm::dbgs() << "DEBUG: outerBlkStride = [";
+  for (size_t i = 0; i < outerBlkStride.size(); ++i) {
+    llvm::dbgs() << outerBlkStride[i];
+    if (i < outerBlkStride.size() - 1) llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
+
+  // combine the inner and outer strides
+  SmallVector<int64_t> blockedStrides;
+  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+  llvm::dbgs() << "DEBUG: blockedStrides = [";
+  for (size_t i = 0; i < blockedStrides.size(); ++i) {
+    llvm::dbgs() << blockedStrides[i];
+    if (i < blockedStrides.size() - 1) llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
+
+  return blockedStrides;
+ }
+
+// Calculate the linear offset using the blocked offsets and stride
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+                                    ArrayRef<OpFoldResult> offsets) {
+
+  SmallVector<int64_t> blockShape = getBlockSize();
+  SmallVector<int64_t> strides = getStrides();
+
+  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
+             llvm::interleaveComma(blockShape, llvm::dbgs());
+             llvm::dbgs() << "], strides=[";
+             llvm::interleaveComma(strides, llvm::dbgs());
+             llvm::dbgs() << "]\n");
+
+  if (!blockShape.empty()) {
+    assert(offsets.size() == blockShape.size() &&
+           "offsets and blockShape must have the same size");
+    // say the original offset is [y, x], and the block shape is [By, Bx],
+    // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+    SmallVector<OpFoldResult> blockedOffsets;
+    SmallVector<OpFoldResult> divs, rems;
+
+    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+      divs.push_back(div(offset, block));
+      rems.push_back(rem(offset, block));
+    }
+    blockedOffsets.append(divs.begin(), divs.end());
+    blockedOffsets.append(rems.begin(), rems.end());
+
+    offsets = blockedOffsets;
+    LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
+                            << offsets.size() << "\n");
+  }
+
+  // Start with initial value as matrix descriptor's base offset.
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); + for (size_t i = 0; i < offsets.size(); ++i) { + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + } + + LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset=" + << linearOffset << "\n"); + + return linearOffset; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788d0b9b4..23e487787652d 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -1062,9 +1062,12 @@ LogicalResult LoadMatrixOp::verify() { ArrayRef valueShape = resTy.getShape(); ArrayRef mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); + + if (valueShape.size() != 1) { + if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed mem_desc shape."); + } return success(); } @@ -1092,10 +1095,11 @@ LogicalResult StoreMatrixOp::verify() { ArrayRef dataShape = dataTy.getShape(); ArrayRef mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - + if (dataShape.size() != 1) { + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("data shape must not exceed mem_desc shape."); + } return success(); } @@ -1127,7 +1131,7 @@ LogicalResult MemDescSubviewOp::verify() { [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitOpError("result shape must not exceed source shape."); - if (srcTy.getStrides() != resTy.getStrides()) + if (srcTy.getStridesAttr() != resTy.getStridesAttr()) return emitOpError("result must inherit the source strides."); return success(); diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index e6f22f0a9acbb..bbf313bf4fb60 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -1,10 +1,6 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s -#sg_map_a_f16 = #xegpu.layout -#sg_map_b_f16 = #xegpu.layout -#sg_map_c_f32 = #xegpu.layout - -gpu.module @load_store_check { +gpu.module @test_kernel { // CHECK-LABEL: func.func @dpas( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir new file mode 100644 index 0000000000000..30d6274c9dccf --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s + +gpu.module @test_kernel { + //CHECK-LABEL: load_store_matrix_1 + gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<8xf32> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> 
!xegpu.mem_desc<32x32xf32> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32> + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32> + gpu.return %1: vector<8xf32> + } + + // e.g. for mem_desc<32x32xf16, @block=[16, 16], @strides=[1, 16]> + // its memory layout tuple is ([2,2,16,16],[256,512,1,16]) + + //CHECK-LABEL: load_store_matrix_2 + gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<8xf32> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32> + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32> + gpu.return %1: vector<8xf32> + } + + // e.g. for mem_desc<32x32xf16, @block=[16, 16]> + // its memory layout tuple is ([2,2,16,16],[512,256,16,1]) + //CHECK-LABEL: load_store_matrix_3 + gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<8xf32> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout> + //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32> + //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32> + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32> + gpu.return %1: vector<8xf32> + } + +} From 554b95edf3079fee2ac91ccd22078886244724f0 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 7 Oct 2025 23:17:20 +0000 Subject: [PATCH 02/12] add attributes --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 3 + .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 63 +++-- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 262 +++++++++--------- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 7 +- .../XeGPUToXeVM/loadstore_matrix.mlir | 47 ++-- mlir/test/Dialect/XeGPU/ops.mlir | 57 +++- 6 files changed, 253 insertions(+), 186 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 32d21bae8cd34..a0a8669baf90d 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1302,6 +1302,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr:$vec_length, + OptionalAttr:$vec_direction, + OptionalAttr:$subgroupBlockIO, OptionalAttr:$layout ); let results = (outs XeGPU_ValueType:$res); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 666df293bb8be..97deca167204a 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -371,7 +371,8 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, Value baseAddr, Value offset, int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( rewriter, loc, rewriter.getI64Type(), elemByteSize); - offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset); + offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), + offset); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = 
arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -513,29 +514,36 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { // on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than // 32 bits will be converted to 32 bits. class CreateMemDescOpPattern final - : public OpConversionPattern { + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // DEBUG: Print operation and types - LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n"); - TypedValue src = op.getSource(); - auto resTy = cast(op.getResult().getType()); - - // Create the result MemRefType with the same shape, element type, and memory space - auto newResTy = getTypeConverter()->convertType(resTy); - - LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n"); - Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); - auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero, - ValueRange()); - rewriter.replaceOp(op, viewOp); - LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n"); - return success(); + ConversionPatternRewriter &rewriter) const override { + // DEBUG: Print operation and types + LLVM_DEBUG(llvm::dbgs() + << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n"); + TypedValue src = op.getSource(); + auto resTy = cast(op.getResult().getType()); + + // Create the result MemRefType with the same shape, element type, and + // memory space + auto newResTy = getTypeConverter()->convertType(resTy); + + LLVM_DEBUG(llvm::dbgs() + << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n"); + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); + auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, + Value(src), zero, ValueRange()); + rewriter.replaceOp(op, viewOp); + LLVM_DEBUG( + llvm::dbgs() + << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n"); + return success(); } }; @@ -551,7 +559,6 @@ class MemDescSubviewOpPattern final } }; - template ::value>> @@ -577,7 +584,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { data = adaptor.getData(); VectorType valOrResVecTy = dyn_cast(data.getType()); - int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + int64_t elemBitWidth = + valOrResVecTy.getElementType().getIntOrFloatBitWidth(); // Element type must be multiple of 8 bits. 
if (elemBitWidth % 8 != 0) return rewriter.notifyMatchFailure( @@ -589,14 +597,17 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); auto mdescTy = cast(mdescVal.getType()); - - Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct); + + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, basePtrStruct); // Convert base pointer (ptr) to i64 - Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM); + Value basePtrI64 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI64Type(), basePtrLLVM); Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); - basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize); + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize); // convert base pointer (i64) to LLVM pointer type basePtrLLVM = diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index c64699c12cf4a..3cbb39ee9b144 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -777,168 +777,176 @@ SmallVector getBlockedOffsets(OpBuilder &builder, Location loc, return blockedOffsets; } -// Get strides as vector of integer for MemDesc. +// Get strides as vector of integer for MemDesc. SmallVector MemDescType::getStrides() { - SmallVector matrixShape(getShape().begin(), - getShape().end()); + SmallVector matrixShape(getShape().begin(), getShape().end()); - ArrayAttr strideAttr = getStridesAttr(); - SmallVector strides; - for (Attribute attr : strideAttr.getValue()) { - strides.push_back(cast(attr).getInt()); - } + ArrayAttr strideAttr = getStridesAttr(); + SmallVector strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast(attr).getInt()); + } - llvm::dbgs() << "DEBUG: matrixShape = ["; - for (size_t i = 0; i < matrixShape.size(); ++i) { - llvm::dbgs() << matrixShape[i]; - if (i < matrixShape.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + llvm::dbgs() << "DEBUG: matrixShape = ["; + for (size_t i = 0; i < matrixShape.size(); ++i) { + llvm::dbgs() << matrixShape[i]; + if (i < matrixShape.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; - llvm::dbgs() << "DEBUG: strides = ["; - for (size_t i = 0; i < strides.size(); ++i) { - llvm::dbgs() << strides[i]; - if (i < strides.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + llvm::dbgs() << "DEBUG: strides = ["; + for (size_t i = 0; i < strides.size(); ++i) { + llvm::dbgs() << strides[i]; + if (i < strides.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + SmallVector innerBlkShape = getBlockSize(); + llvm::dbgs() << "DEBUG: innerBlkShape = ["; + for (size_t i = 0; i < innerBlkShape.size(); ++i) { + llvm::dbgs() << innerBlkShape[i]; + if (i < innerBlkShape.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; - SmallVector innerBlkShape = getBlockSize(); - llvm::dbgs() << "DEBUG: innerBlkShape = ["; - for (size_t i = 0; i < innerBlkShape.size(); ++i) { - llvm::dbgs() << innerBlkShape[i]; - if (i < innerBlkShape.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + if (innerBlkShape.empty()) + return strides; - if (innerBlkShape.empty()) - return strides; + SmallVector perm = + llvm::to_vector<4>(llvm::seq(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < 
strides[b]; }); - SmallVector perm = llvm::to_vector<4>( - llvm::seq(0, strides.size())); - llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + llvm::dbgs() << "DEBUG: perm = ["; + for (size_t i = 0; i < perm.size(); ++i) { + llvm::dbgs() << perm[i]; + if (i < perm.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; - llvm::dbgs() << "DEBUG: perm = ["; - for (size_t i = 0; i < perm.size(); ++i) { - llvm::dbgs() << perm[i]; - if (i < perm.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); - assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + SmallVector innerBlkStride = computeStrides(innerBlkShape); - SmallVector innerBlkStride = computeStrides(innerBlkShape); - - llvm::dbgs() << "DEBUG: innerBlkStride = ["; - for (size_t i = 0; i < innerBlkStride.size(); ++i) { - llvm::dbgs() << innerBlkStride[i]; - if (i < innerBlkStride.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - - // compute the original matrix shape using the stride info - // and compute the number of blocks in each dimension - // The shape of highest dim can't be derived from stride info, - // but doesn't impact the stride computation for blocked layout. - SmallVector matrixShapeOrig(matrixShape.size()); - SmallVector BlkShapeOrig(matrixShape.size()); - for (size_t i = 0; i < perm.size() - 1; ++i) { - matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; - BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; - } + llvm::dbgs() << "DEBUG: innerBlkStride = ["; + for (size_t i = 0; i < innerBlkStride.size(); ++i) { + llvm::dbgs() << innerBlkStride[i]; + if (i < innerBlkStride.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. 
+ SmallVector matrixShapeOrig(matrixShape.size()); + SmallVector BlkShapeOrig(matrixShape.size()); + for (size_t i = 0; i < perm.size() - 1; ++i) { + matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; + BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; + } - llvm::dbgs() << "DEBUG: matrixShapeOrig = ["; - for (size_t i = 0; i < matrixShapeOrig.size(); ++i) { - llvm::dbgs() << matrixShapeOrig[i]; - if (i < matrixShapeOrig.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + llvm::dbgs() << "DEBUG: matrixShapeOrig = ["; + for (size_t i = 0; i < matrixShapeOrig.size(); ++i) { + llvm::dbgs() << matrixShapeOrig[i]; + if (i < matrixShapeOrig.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; - llvm::dbgs() << "DEBUG: BlkShapeOrig = ["; - for (size_t i = 0; i < BlkShapeOrig.size(); ++i) { - llvm::dbgs() << BlkShapeOrig[i]; - if (i < BlkShapeOrig.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + llvm::dbgs() << "DEBUG: BlkShapeOrig = ["; + for (size_t i = 0; i < BlkShapeOrig.size(); ++i) { + llvm::dbgs() << BlkShapeOrig[i]; + if (i < BlkShapeOrig.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; - int64_t innerBlkSize = 1; - for (auto s : innerBlkShape) - innerBlkSize *= s; + int64_t innerBlkSize = 1; + for (auto s : innerBlkShape) + innerBlkSize *= s; - llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n"; + llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n"; - SmallVector outerBlkStride(matrixShape.size()); - outerBlkStride[perm[0]] = innerBlkSize; - for (size_t i = 0; i < perm.size() - 1; ++i) { - outerBlkStride[perm[i + 1]] = + SmallVector outerBlkStride(matrixShape.size()); + outerBlkStride[perm[0]] = innerBlkSize; + for (size_t i = 0; i < perm.size() - 1; ++i) { + outerBlkStride[perm[i + 1]] = outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; - } - - llvm::dbgs() << "DEBUG: outerBlkStride = ["; - for (size_t i = 0; i < outerBlkStride.size(); ++i) { - llvm::dbgs() << outerBlkStride[i]; - if (i < outerBlkStride.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - - // combine the inner and outer strides - SmallVector blockedStrides; - blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); - blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + } - llvm::dbgs() << "DEBUG: blockedStrides = ["; - for (size_t i = 0; i < blockedStrides.size(); ++i) { - llvm::dbgs() << blockedStrides[i]; - if (i < blockedStrides.size() - 1) llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + llvm::dbgs() << "DEBUG: outerBlkStride = ["; + for (size_t i = 0; i < outerBlkStride.size(); ++i) { + llvm::dbgs() << outerBlkStride[i]; + if (i < outerBlkStride.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; + + // combine the inner and outer strides + SmallVector blockedStrides; + blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); + blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + + llvm::dbgs() << "DEBUG: blockedStrides = ["; + for (size_t i = 0; i < blockedStrides.size(); ++i) { + llvm::dbgs() << blockedStrides[i]; + if (i < blockedStrides.size() - 1) + llvm::dbgs() << ", "; + } + llvm::dbgs() << "]\n"; - return blockedStrides; - } + return blockedStrides; +} // Calculate the linear offset using the blocked offsets and stride Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, - ArrayRef offsets) { + ArrayRef offsets) { SmallVector blockShape = getBlockSize(); SmallVector strides = 
getStrides(); - + LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=["; - llvm::interleaveComma(blockShape, llvm::dbgs()); - llvm::dbgs() << "], strides=["; - llvm::interleaveComma(strides, llvm::dbgs()); - llvm::dbgs() << "]\n"); - - if (!blockShape.empty()) { - assert(offsets.size() == blockShape.size() && - "offsets and blockShape must have the same size"); - // say the original offset is [y, x], and the block shape is [By, Bx], - // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] - SmallVector blockedOffsets; - SmallVector divs, rems; + llvm::interleaveComma(blockShape, llvm::dbgs()); + llvm::dbgs() << "], strides=["; + llvm::interleaveComma(strides, llvm::dbgs()); + llvm::dbgs() << "]\n"); - for (auto [offset, block] : llvm::zip(offsets, blockShape)) { - divs.push_back(div(offset, block)); - rems.push_back(rem(offset, block)); - } - blockedOffsets.append(divs.begin(), divs.end()); - blockedOffsets.append(rems.begin(), rems.end()); + if (!blockShape.empty()) { + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + // say the original offset is [y, x], and the block shape is [By, Bx], + // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] + SmallVector blockedOffsets; + SmallVector divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); - offsets = blockedOffsets; - LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size=" - << offsets.size() << "\n"); + offsets = blockedOffsets; + LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size=" + << offsets.size() << "\n"); } // Start with initial value as matrix descriptor's base offset. 
Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); for (size_t i = 0; i < offsets.size(); ++i) { - OpFoldResult mulResult = mul(offsets[i], strides[i]); - Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); - linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); } - LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset=" - << linearOffset << "\n"); + LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset=" + << linearOffset << "\n"); return linearOffset; } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 23e487787652d..c40d5a42fb6e5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -1049,8 +1049,11 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + // Call the generated builder with all parameters (including optional ones as + // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*vec_length=*/nullptr, /*vec_direction=*/nullptr, + /*subgroupBlockIO=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { @@ -1097,7 +1100,7 @@ LogicalResult StoreMatrixOp::verify() { ArrayRef mdescShape = mdescTy.getShape(); if (dataShape.size() != 1) { if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitOpError("data shape must not exceed mem_desc shape."); } return success(); diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index 30d6274c9dccf..7b87f32b876fe 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -1,40 +1,41 @@ // RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s gpu.module @test_kernel { + + // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> + // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) //CHECK-LABEL: load_store_matrix_1 - gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<8xf32> { + gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<1xf32> { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf32> %tid_x = gpu.thread_id x %c0 = arith.constant 0 : index - %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32> - gpu.return %1: vector<8xf32> + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<1xf32> + gpu.return %1: vector<1xf32> } - // e.g. for mem_desc<32x32xf16, @block=[16, 16], @strides=[1, 16]> - // its memory layout tuple is ([2,2,16,16],[256,512,1,16]) - + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) //CHECK-LABEL: load_store_matrix_2 - gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<8xf32> { - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32> + gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<1xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16> %tid_x = gpu.thread_id x - %c0 = arith.constant 0 : index - %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32> - gpu.return %1: vector<8xf32> + %c13 = arith.constant 13 : index + %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<1xf16> + gpu.return %1: vector<1xf16> } - // e.g. for mem_desc<32x32xf16, @block=[16, 16]> - // its memory layout tuple is ([2,2,16,16],[512,256,16,1]) + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) //CHECK-LABEL: load_store_matrix_3 - gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<8xf32> { - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout> - //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32> - //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32> + gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<1xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16> %tid_x = gpu.thread_id x - %c0 = arith.constant 0 : index - %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32> - gpu.return %1: vector<8xf32> + %c17 = arith.constant 17 : index + %1 = xegpu.load_matrix %0[%c17, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<1xf16> + gpu.return %1: vector<1xf16> } -} +} \ No newline at end of file diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index bb379024a34d7..47aa05763ee99 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -825,35 +825,76 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } -// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) { +// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> gpu.return } -// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { +// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> 
vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> gpu.return } +// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) +gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroupBlockIO}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 16] {subgroupBlockIO}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + gpu.return +} -// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 8]{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + gpu.return +} + +// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } +// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { +gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16> + xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16> + gpu.return +} + +// CHECK: gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) +gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], 
[[ARG0]][8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return +} + +// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { +gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] {vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 8] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return +} + // CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> From 446b951f2ed0bffd8be64955b7c4e5a94d5e2eb7 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Wed, 8 Oct 2025 22:49:43 +0000 Subject: [PATCH 03/12] add tests and refactoring --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 19 +++- .../lib/Conversion/XeGPUToXeVM/CMakeLists.txt | 1 + .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 80 ++++++++++++++-- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 12 ++- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 92 +++++++++++++------ .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 4 +- .../Transforms/XeGPUWgToSgDistribute.cpp | 2 +- mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 2 +- .../XeGPUToXeVM/loadstore_matrix.mlir | 54 ++++++++--- mlir/test/Dialect/XeGPU/ops.mlir | 22 ++--- 10 files changed, 211 insertions(+), 77 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index a0a8669baf90d..044a8ef22d891 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1304,10 +1304,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, DenseI64ArrayAttr: $const_offsets, OptionalAttr:$vec_length, OptionalAttr:$vec_direction, - OptionalAttr:$subgroupBlockIO, + OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); - let results = (outs XeGPU_ValueType:$res); + let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res); let assemblyFormat = [{ $mem_desc `` custom($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands) `->` type(results) @@ -1338,7 +1338,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } ArrayRef getDataShape() { - return getRes().getType().getShape(); + auto resTy = getRes().getType(); + if (auto vecTy = llvm::dyn_cast(resTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1348,10 +1351,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, AllElementTypesMatch<["mem_desc", "data"]>]> { let arguments = (ins - XeGPU_ValueType:$data, + AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data, XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr:$vec_length, + OptionalAttr:$vec_direction, + OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` 
custom($offsets, $const_offsets) @@ -1379,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, } ArrayRef getDataShape() { - return getData().getType().getShape(); + auto DataTy = getData().getType(); + if (auto vecTy = llvm::dyn_cast(DataTy)) + return vecTy.getShape(); + return {}; } }]; diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt index 84b25809f1ed0..dd9edc43a1657 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM MLIRIndexDialect MLIRSCFDialect MLIRXeGPUDialect + MLIRXeGPUUtils MLIRPass MLIRTransforms MLIRSCFTransforms diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 97deca167204a..f4f0a46c54089 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/FormatVariadic.h" @@ -371,8 +372,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, Value baseAddr, Value offset, int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( rewriter, loc, rewriter.getI64Type(), elemByteSize); - offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), - offset); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -583,6 +582,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { else data = adaptor.getData(); VectorType valOrResVecTy = dyn_cast(data.getType()); + if (!valOrResVecTy) + valOrResVecTy = VectorType::get(1, data.getType()); int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth(); @@ -606,6 +607,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { rewriter, loc, rewriter.getI64Type(), basePtrLLVM); Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + linearOffset = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI64Type(), linearOffset); basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize); @@ -613,15 +616,72 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); - if constexpr (std::is_same_v) { - - Value loadOp = - LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); - rewriter.replaceOp(op, loadOp); + // if the size of valOrResVecTy is 1, it lowers to a scalar load/store + // operation. LLVM load/store does not support vector of size 1, so we need + // to handle this case separately. 
+ if (valOrResVecTy.getNumElements() == 1) { + Type scalarTy = valOrResVecTy.getElementType(); + if constexpr (std::is_same_v) { + Value loadOp = + LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), + basePtrLLVM); + rewriter.eraseOp(op); + } + return success(); } else { - auto storeOp = - LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); - rewriter.eraseOp(op); + // if the attribute 'subgroup_block_io' is set to true, it lowers to + // xevm.blockload + auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr(); + bool subgroup_block_io = + subgroupBlockIoAttr && cast(subgroupBlockIoAttr).getValue(); + if (subgroup_block_io) { + if constexpr (std::is_same_v) { + Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy, + basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, + adaptor.getData(), nullptr); + rewriter.eraseOp(op); + } + } else { + // if the result is 1D vector, if the vector direction is Column, then + // the + // memory descriptor should be treated as column major + auto chipOpt = xegpu::getChipStr(op); + if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { + // the lowering only works for pvc and bmg + return rewriter.notifyMatchFailure( + op, "The lowering is specific to pvc or bmg."); + } + xegpu::MatrixAccessDirectionAttr vecDirection = + op.getVecDirectionAttr(); + if (vecDirection && + vecDirection.getValue() == xegpu::MatrixAccessDirection::COL && + !mdescTy.isColMajor()) + return rewriter.notifyMatchFailure( + op, "mem_desc should be column major when " + "vec_direction is COLUMN for 1D result."); + if (vecDirection && + vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW && + mdescTy.isColMajor()) + return rewriter.notifyMatchFailure( + op, "mem_desc should be row major when " + "vec_direction is ROW for 1D result."); + + if constexpr (std::is_same_v) { + Value loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), + basePtrLLVM); + rewriter.eraseOp(op); + } + } } return success(); } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 3cbb39ee9b144..26f2f691ab860 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -813,9 +813,8 @@ SmallVector MemDescType::getStrides() { } llvm::dbgs() << "]\n"; - if (innerBlkShape.empty()) - return strides; - + // get perm from FCD to LCD + // perm[i] = the dim with i-th smallest stride SmallVector perm = llvm::to_vector<4>(llvm::seq(0, strides.size())); llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); @@ -908,6 +907,7 @@ SmallVector MemDescType::getStrides() { Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, ArrayRef offsets) { + SmallVector matrixShape(getShape().begin(), getShape().end()); SmallVector blockShape = getBlockSize(); SmallVector strides = getStrides(); @@ -917,7 +917,11 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, llvm::interleaveComma(strides, llvm::dbgs()); llvm::dbgs() << "]\n"); - if (!blockShape.empty()) { + // blockshape equal to matrixshape means no blocking + if (llvm::equal(blockShape, matrixShape)) { + // remove the outer dims from strides + 
strides.erase(strides.begin(), strides.begin() + matrixShape.size()); + } else { assert(offsets.size() == blockShape.size() && "offsets and blockShape must have the same size"); // say the original offset is [y, x], and the block shape is [By, Bx], diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index c40d5a42fb6e5..0bc7b3f06ec53 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -173,6 +173,51 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } +LogicalResult IsValidStoreMatrixParams( + VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, + MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength, + function_ref emitError) { + + if (!dataTy) + if (subgroup_block_io || vecDirection || vecLength) + return emitError() << "vec_length, vec_direction and subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + else + return success(); + + if (mdescTy.getRank() != 2) + return emitError() << "mem_desc must be 2D."; + + ArrayRef dataShape = dataTy.getShape(); + ArrayRef mdescShape = mdescTy.getShape(); + + if (dataShape.size() == 2) { + if (subgroup_block_io || vecDirection || vecLength) + return emitError() << "vec_length, vec_direction and subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitError() << "data shape must not exceed mem_desc shape."; + } else if (dataShape.size() == 1) { + + SmallVector blockSize = mdescTy.getBlockSize(); + // if the subgroup_block_io attribute is set, mdescTy must have block + // attribute + if (subgroup_block_io && !blockSize.size()) + return emitError() << "mem_desc must have block attribute when " + "subgroup_block_io is set."; + // if the subgroup_block_io attribute is set, the memdesc should be row + // major + if (subgroup_block_io && mdescTy.isColMajor()) + return emitError() << "mem_desc should be row major when " + "subgroup_block_io is set."; + } else if (dataShape.size() == 0) { + return emitError() << "result shape must not be empty."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1053,25 +1098,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, /*vec_length=*/nullptr, /*vec_direction=*/nullptr, - /*subgroupBlockIO=*/nullptr, layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { - VectorType resTy = getRes().getType(); - MemDescType mdescTy = getMemDesc().getType(); - - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - ArrayRef valueShape = resTy.getShape(); - ArrayRef mdescShape = mdescTy.getShape(); + auto resTy = dyn_cast(getRes().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr(); + IntegerAttr vecLength = getVecLengthAttr(); + MemDescType mdescTy = getMemDesc().getType(); - if (valueShape.size() != 1) { - if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); - } - return 
success(); + return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io, + vecDirection, vecLength, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1086,24 +1126,20 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*vec_length=*/nullptr, /*vec_direction=*/nullptr, + /*subgroup_block_io=*/nullptr, layout); } LogicalResult StoreMatrixOp::verify() { - VectorType dataTy = getData().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - - ArrayRef dataShape = dataTy.getShape(); - ArrayRef mdescShape = mdescTy.getShape(); - if (dataShape.size() != 1) { - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - } - return success(); + auto dataTy = dyn_cast(getData().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr(); + IntegerAttr vecLength = getVecLengthAttr(); + MemDescType mdescTy = getMemDesc().getType(); + return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io, + vecDirection, vecLength, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a178d0fe4b0b0..6d17b27849a43 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,7 +941,7 @@ struct UnrollLoadMatrixOp : public UnrollPattern { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = op.getType(); + VectorType valueTy = llvm::dyn_cast(op.getType()); std::optional> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -984,7 +984,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern { return failure(); Location loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); + VectorType valueTy = llvm::dyn_cast(op.getData().getType()); ArrayRef shape = valueTy.getShape(); auto layout = dyn_cast(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a9296b184..d57289a6b21e9 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -867,7 +867,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { return failure(); ArrayRef wgShape = op.getDataShape(); - VectorType valueTy = op.getRes().getType(); + VectorType valueTy = llvm::dyn_cast(op.getRes().getType()); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index bbf313bf4fb60..a9ab0be00722c 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ 
-7,7 +7,7 @@ gpu.module @test_kernel { // Loads are checked in a separate test. // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = , types = } // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> return %d : vector<8xf32> } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index 7b87f32b876fe..372f477219817 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -1,41 +1,65 @@ // RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s -gpu.module @test_kernel { +gpu.module @test_kernel [#xevm.target] { // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) //CHECK-LABEL: load_store_matrix_1 - gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<1xf32> { + gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf32> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 %tid_x = gpu.thread_id x %c0 = arith.constant 0 : index - %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<1xf32> - gpu.return %1: vector<1xf32> + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + gpu.return %1: f32 } - // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) //CHECK-LABEL: load_store_matrix_2 - gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<1xf16> { + gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16 %tid_x = gpu.thread_id x %c13 = arith.constant 13 : index - %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<1xf16> - gpu.return %1: vector<1xf16> + %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + gpu.return %1: f16 } // e.g. for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) //CHECK-LABEL: load_store_matrix_3 - gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<1xf16> { + gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16 + %tid_x = gpu.thread_id x + %c19 = arith.constant 19: index + %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + gpu.return %1: f16 + } + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_4 + gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16> %tid_x = gpu.thread_id x - %c17 = arith.constant 17 : index - %1 = xegpu.load_matrix %0[%c17, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<1xf16> - gpu.return %1: vector<1xf16> + %c16 = arith.constant 16 : index + %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + gpu.return %1: vector<8xf16> + } + + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_5 + gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16> + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + gpu.return %1: vector<8xf16> } } \ No newline at end of file diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 47aa05763ee99..eb5d653be8b9c 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -846,17 +846,17 @@ gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { gpu.return } -// CHECK: gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroupBlockIO}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> - %data = xegpu.load_matrix %arg0[8, 16] {subgroupBlockIO}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> +// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> gpu.return } // CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> - %data = xegpu.load_matrix %arg0[8, 8]{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> gpu.return } @@ -881,17 +881,17 @@ gpu.func 
@simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf gpu.return } -// CHECK: gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) -gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { - // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> - xegpu.store_matrix %arg1, %arg0[8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> +// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) +gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } // CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { - // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] {vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> - xegpu.store_matrix %arg1, %arg0[8, 8] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } From 9f9744cecbd30fea7b63c47768b323879222d105 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Thu, 9 Oct 2025 02:00:49 +0000 Subject: [PATCH 04/12] bug fixes --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 43 +++++---- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 92 +------------------ .../XeGPUToXeVM/loadstore_matrix.mlir | 2 +- mlir/test/Dialect/XeGPU/invalid.mlir | 2 +- 4 files changed, 30 insertions(+), 109 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index f4f0a46c54089..67e8246e5536a 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -33,8 +33,6 @@ #include -#define DEBUG_TYPE "xegpu-to-xevm" - namespace mlir { #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -519,9 +517,6 @@ class CreateMemDescOpPattern final LogicalResult matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // DEBUG: Print operation and types - LLVM_DEBUG(llvm::dbgs() - << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n"); TypedValue src = op.getSource(); auto resTy = cast(op.getResult().getType()); @@ -529,19 +524,10 @@ class CreateMemDescOpPattern final // memory space auto newResTy = getTypeConverter()->convertType(resTy); - LLVM_DEBUG(llvm::dbgs() - << "[XeGPUToXeVM] Source MemRefType: " 
<< src.getType() << "\n"); - LLVM_DEBUG(llvm::dbgs() - << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n"); - LLVM_DEBUG(llvm::dbgs() - << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n"); Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero, ValueRange()); rewriter.replaceOp(op, viewOp); - LLVM_DEBUG( - llvm::dbgs() - << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n"); return success(); } }; @@ -635,16 +621,33 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { // if the attribute 'subgroup_block_io' is set to true, it lowers to // xevm.blockload auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr(); - bool subgroup_block_io = - subgroupBlockIoAttr && cast(subgroupBlockIoAttr).getValue(); + bool subgroup_block_io = static_cast(subgroupBlockIoAttr); + + // BlockLoadOp only supports integer types, so we need to bitcast + // Get integer type with matching bit width + Type elemTy = valOrResVecTy.getElementType(); + int64_t bitWidth = elemTy.getIntOrFloatBitWidth(); + Type intElemTy = rewriter.getIntegerType(bitWidth); + VectorType intVecTy = + VectorType::get(valOrResVecTy.getShape(), intElemTy); + if (subgroup_block_io) { if constexpr (std::is_same_v) { - Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy, - basePtrLLVM); + Value loadOp = + xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); + if (intVecTy != valOrResVecTy) { + loadOp = + vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); + } rewriter.replaceOp(op, loadOp); } else { - xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, - adaptor.getData(), nullptr); + Value dataToStore = adaptor.getData(); + if (valOrResVecTy != intVecTy) { + dataToStore = + vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); + } + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, + nullptr); rewriter.eraseOp(op); } } else { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 26f2f691ab860..cccc8fab4adbc 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -37,8 +37,6 @@ void XeGPUDialect::initialize() { >(); } -#define DEBUG_TYPE "xegpu" - /// Generates instructions to compute offsets for a subgroup identified by /// its multidimensional indices (sgId), using the specified subgroup layout /// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data @@ -788,30 +786,7 @@ SmallVector MemDescType::getStrides() { strides.push_back(cast(attr).getInt()); } - llvm::dbgs() << "DEBUG: matrixShape = ["; - for (size_t i = 0; i < matrixShape.size(); ++i) { - llvm::dbgs() << matrixShape[i]; - if (i < matrixShape.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - - llvm::dbgs() << "DEBUG: strides = ["; - for (size_t i = 0; i < strides.size(); ++i) { - llvm::dbgs() << strides[i]; - if (i < strides.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - SmallVector innerBlkShape = getBlockSize(); - llvm::dbgs() << "DEBUG: innerBlkShape = ["; - for (size_t i = 0; i < innerBlkShape.size(); ++i) { - llvm::dbgs() << innerBlkShape[i]; - if (i < innerBlkShape.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; // get perm from FCD to LCD // perm[i] = the dim with i-th smallest stride @@ -819,25 +794,13 @@ SmallVector MemDescType::getStrides() { llvm::to_vector<4>(llvm::seq(0, 
strides.size())); llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); - llvm::dbgs() << "DEBUG: perm = ["; - for (size_t i = 0; i < perm.size(); ++i) { - llvm::dbgs() << perm[i]; - if (i < perm.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); - SmallVector innerBlkStride = computeStrides(innerBlkShape); - - llvm::dbgs() << "DEBUG: innerBlkStride = ["; - for (size_t i = 0; i < innerBlkStride.size(); ++i) { - llvm::dbgs() << innerBlkStride[i]; - if (i < innerBlkStride.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; + SmallVector innerBlkStride(innerBlkShape.size()); + innerBlkStride[perm[0]] = 1; + for (size_t i = 1; i < perm.size(); ++i) + innerBlkStride[perm[i]] = + innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; // compute the original matrix shape using the stride info // and compute the number of blocks in each dimension @@ -850,28 +813,10 @@ SmallVector MemDescType::getStrides() { BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; } - llvm::dbgs() << "DEBUG: matrixShapeOrig = ["; - for (size_t i = 0; i < matrixShapeOrig.size(); ++i) { - llvm::dbgs() << matrixShapeOrig[i]; - if (i < matrixShapeOrig.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - - llvm::dbgs() << "DEBUG: BlkShapeOrig = ["; - for (size_t i = 0; i < BlkShapeOrig.size(); ++i) { - llvm::dbgs() << BlkShapeOrig[i]; - if (i < BlkShapeOrig.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - int64_t innerBlkSize = 1; for (auto s : innerBlkShape) innerBlkSize *= s; - llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n"; - SmallVector outerBlkStride(matrixShape.size()); outerBlkStride[perm[0]] = innerBlkSize; for (size_t i = 0; i < perm.size() - 1; ++i) { @@ -879,27 +824,11 @@ SmallVector MemDescType::getStrides() { outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; } - llvm::dbgs() << "DEBUG: outerBlkStride = ["; - for (size_t i = 0; i < outerBlkStride.size(); ++i) { - llvm::dbgs() << outerBlkStride[i]; - if (i < outerBlkStride.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - // combine the inner and outer strides SmallVector blockedStrides; blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); - llvm::dbgs() << "DEBUG: blockedStrides = ["; - for (size_t i = 0; i < blockedStrides.size(); ++i) { - llvm::dbgs() << blockedStrides[i]; - if (i < blockedStrides.size() - 1) - llvm::dbgs() << ", "; - } - llvm::dbgs() << "]\n"; - return blockedStrides; } @@ -911,12 +840,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, SmallVector blockShape = getBlockSize(); SmallVector strides = getStrides(); - LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=["; - llvm::interleaveComma(blockShape, llvm::dbgs()); - llvm::dbgs() << "], strides=["; - llvm::interleaveComma(strides, llvm::dbgs()); - llvm::dbgs() << "]\n"); - // blockshape equal to matrixshape means no blocking if (llvm::equal(blockShape, matrixShape)) { // remove the outer dims from strides @@ -937,8 +860,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, blockedOffsets.append(rems.begin(), rems.end()); offsets = blockedOffsets; - LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size=" - << offsets.size() << "\n"); } // Start with initial value as matrix descriptor's base offset. 
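For reference, the layout tuples quoted in the test comments below can be reproduced with a small standalone sketch of this blocking arithmetic (plain C++, independent of the dialect implementation; it assumes a 2-D mem_desc whose shape is the full matrix and whose block shape evenly divides it):

#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Returns {blockedShape, blockedStrides}: ([Ny, Nx, By, Bx], [outer.., inner..]).
static std::pair<Shape, Shape> blockedLayout(const Shape &shape,
                                             const Shape &strides,
                                             const Shape &block) {
  Shape nBlocks = {shape[0] / block[0], shape[1] / block[1]};
  // The innermost dimension is the one with stride 1.
  int inner = (strides[0] == 1) ? 0 : 1;
  int outer = 1 - inner;
  // Strides within one block: contiguous along the innermost dimension.
  Shape innerStrides(2);
  innerStrides[inner] = 1;
  innerStrides[outer] = block[inner];
  // Strides between blocks: each block holds block[0] * block[1] elements.
  int64_t blockElems = block[0] * block[1];
  Shape outerStrides(2);
  outerStrides[inner] = blockElems;
  outerStrides[outer] = blockElems * nBlocks[inner];
  return {{nBlocks[0], nBlocks[1], block[0], block[1]},
          {outerStrides[0], outerStrides[1], innerStrides[0], innerStrides[1]}};
}

// Linearizes [y, x] as [y/By, x/Bx, y%By, x%Bx] dotted with blockedStrides.
static int64_t linearOffset(const Shape &offsets, const Shape &block,
                            const Shape &blockedStrides) {
  Shape blocked = {offsets[0] / block[0], offsets[1] / block[1],
                   offsets[0] % block[0], offsets[1] % block[1]};
  return std::inner_product(blocked.begin(), blocked.end(),
                            blockedStrides.begin(), int64_t{0});
}

int main() {
  // mem_desc<32x64xf16, @block=[16, 16]> with default strides [64, 1].
  auto [shapeA, stridesA] = blockedLayout({32, 64}, {64, 1}, {16, 16});
  assert((shapeA == Shape{2, 4, 16, 16}));
  assert((stridesA == Shape{1024, 256, 16, 1}));
  // Offsets [19, 48]: (19/16)*1024 + (48/16)*256 + (19%16)*16 + (48%16)*1.
  assert(linearOffset({19, 48}, {16, 16}, stridesA) == 1840);

  // mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> (column-major).
  auto [shapeB, stridesB] = blockedLayout({32, 64}, {1, 32}, {16, 16});
  assert((shapeB == Shape{2, 4, 16, 16}));
  assert((stridesB == Shape{256, 512, 1, 16}));
  return 0;
}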
@@ -949,9 +870,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); } - LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset=" - << linearOffset << "\n"); - return linearOffset; } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index 372f477219817..3713635a1cc71 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -55,7 +55,7 @@ gpu.module @test_kernel [#xevm.target] { //CHECK-LABEL: load_store_matrix_5 gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16> + //CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16> %c16 = arith.constant 16 : index %c48 = arith.constant 48 : index %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 228ef69d9a478..bef45438c944e 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16> // ----- func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed mem_desc shape}} + // expected-error@+1 {{data shape must not exceed mem_desc shape}} %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16> return } From bbd43d089096c8c66507c59c7df0c42d2806bcc0 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Fri, 10 Oct 2025 00:20:47 +0000 Subject: [PATCH 05/12] polish tests --- .../XeGPUToXeVM/loadstore_matrix.mlir | 154 +++++++++++++++++- mlir/test/Dialect/XeGPU/invalid.mlir | 29 ++++ 2 files changed, 174 insertions(+), 9 deletions(-) diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index 3713635a1cc71..6302758195e51 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -1,64 +1,200 @@ -// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s gpu.module @test_kernel [#xevm.target] { - // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> + // e.g. 
for mem_desc<32x32xf16, @strides=[1, 16]> // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) //CHECK-LABEL: load_store_matrix_1 gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i64 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64 //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + %tid_x = gpu.thread_id x %c0 = arith.constant 0 : index %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + gpu.return %1: f32 } - // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> +// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) //CHECK-LABEL: load_store_matrix_2 gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16 + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c13:.*]] = arith.constant 13 : index + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 + + %tid_x = gpu.thread_id x %c13 = arith.constant 13 : index %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index gpu.return %1: f16 } + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) //CHECK-LABEL: load_store_matrix_3 gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16 + + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c19:.*]] = arith.constant 19 : index %tid_x = gpu.thread_id x %c19 = arith.constant 19: index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64 + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + + //CHECK: gpu.return %[[loaded]] : f16 gpu.return %1: f16 } - - // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) //CHECK-LABEL: load_store_matrix_4 gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16> + + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> + %tid_x = gpu.thread_id x %c16 = arith.constant 16 : index %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + + //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + gpu.return %1: vector<8xf16> } + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) //CHECK-LABEL: load_store_matrix_5 gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16> + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[c48:.*]] = arith.constant 48 : index + %c16 = arith.constant 16 : index %c48 = arith.constant 48 : index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64 + //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index + //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index + //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index + //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64 + //CHECK: %[[c2:.*]] = arith.constant 2 : i64 + //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64 + //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3> + //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> + //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> + %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + + //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> + //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) + + xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + gpu.return %1: vector<8xf16> } diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index bef45438c944e..fee3136195e1d 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -870,6 +870,21 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { return } +// ----- +func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + %data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}>: 
!xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> + return +} + +// ----- +func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> + return +} + + // ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} @@ -891,6 +906,20 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve return } +// ----- +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> + return +} + +// ----- +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> + return +} + // ----- func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { // expected-error@+1 {{result shape must not exceed source shape}} From 034476186425dde826929584ae36f95fa7263fd8 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Fri, 10 Oct 2025 06:19:21 +0000 Subject: [PATCH 06/12] fix minor issues --- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 67e8246e5536a..05f26354e5a2a 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -61,10 +61,8 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { return static_cast(xevm::AddrSpace::GLOBAL); case xegpu::MemorySpace::SLM: return static_cast(xevm::AddrSpace::SHARED); - default: - llvm_unreachable("Unknown XeGPU memory space"); - return static_cast(xevm::AddrSpace::GLOBAL); } + llvm_unreachable("Unknown XeGPU memory space"); } // Get same bitwidth flat vector type of new element type. @@ -186,8 +184,9 @@ class CreateNdDescToXeVMPattern SmallVector mixedSizes = op.getMixedSizes(); // Descriptor shape is expected to be 2D. int64_t rank = mixedSizes.size(); - if (rank != 2) + if (rank != 2) { return rewriter.notifyMatchFailure(op, "Expected 2D shape."); + } auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. 
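    // For illustration, with a shared-memory source such as memref<2048xf16, 3>
    // (the element-typed view used in the matrix load/store tests), this step
    // amounts to:
    //   %p   = memref.extract_aligned_pointer_as_index %src : memref<2048xf16, 3> -> index
    //   %i64 = arith.index_castui %p : index to i64
    // with byte offsets then added to %i64 before it is cast back to a pointer.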
@@ -612,8 +611,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); rewriter.replaceOp(op, loadOp); } else { - auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), - basePtrLLVM); + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); rewriter.eraseOp(op); } return success(); @@ -680,8 +678,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); rewriter.replaceOp(op, loadOp); } else { - auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), - basePtrLLVM); + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); rewriter.eraseOp(op); } } From 966525b19652cb75c20722dfa9c22bb74d43a87b Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Mon, 13 Oct 2025 18:04:56 +0000 Subject: [PATCH 07/12] remove vector direction and lenght attirbutes --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 18 ------------ .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 4 --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 ++---------- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 29 +++++++------------ .../XeGPUToXeVM/loadstore_matrix.mlir | 4 +-- mlir/test/Dialect/XeGPU/invalid.mlir | 13 ++------- mlir/test/Dialect/XeGPU/ops.mlir | 8 ++--- 7 files changed, 22 insertions(+), 72 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 601e966b49890..2efd575a652db 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -724,22 +724,4 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { } -def RowOriented : I32EnumAttrCase<"ROW", 0, "row">; -def ColOriented : I32EnumAttrCase<"COL", 1, "col">; -def MatrixAccessDirection : - I32EnumAttr<"MatrixAccessDirection", - "Matrix elements/vectors can have row or column direction", [ - RowOriented, ColOriented -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::xegpu"; -} -def MatrixAccessDirectionAttr : - EnumAttr{ - let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}]; - let assemblyFormat = "`<` $value `>`"; -} - #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 044a8ef22d891..f41f9e887cff7 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1302,8 +1302,6 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$vec_length, - OptionalAttr:$vec_direction, OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); @@ -1355,8 +1353,6 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$vec_length, - OptionalAttr:$vec_direction, OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 05f26354e5a2a..2ff2c98d291d2 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -184,9 +184,9 @@ class CreateNdDescToXeVMPattern SmallVector mixedSizes = op.getMixedSizes(); // 
Descriptor shape is expected to be 2D. int64_t rank = mixedSizes.size(); - if (rank != 2) { + if (rank != 2) return rewriter.notifyMatchFailure(op, "Expected 2D shape."); - } + auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. @@ -658,20 +658,6 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "The lowering is specific to pvc or bmg."); } - xegpu::MatrixAccessDirectionAttr vecDirection = - op.getVecDirectionAttr(); - if (vecDirection && - vecDirection.getValue() == xegpu::MatrixAccessDirection::COL && - !mdescTy.isColMajor()) - return rewriter.notifyMatchFailure( - op, "mem_desc should be column major when " - "vec_direction is COLUMN for 1D result."); - if (vecDirection && - vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW && - mdescTy.isColMajor()) - return rewriter.notifyMatchFailure( - op, "mem_desc should be row major when " - "vec_direction is ROW for 1D result."); if constexpr (std::is_same_v) { Value loadOp = diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 0bc7b3f06ec53..8d86e78fcbf4f 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -173,17 +173,18 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } -LogicalResult IsValidStoreMatrixParams( - VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, - MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength, - function_ref emitError) { - - if (!dataTy) - if (subgroup_block_io || vecDirection || vecLength) - return emitError() << "vec_length, vec_direction and subgroup_block_io " +LogicalResult +IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy, + UnitAttr subgroup_block_io, + function_ref emitError) { + + if (!dataTy) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " "are only allowed when result is a 1D VectorType."; else return success(); + } if (mdescTy.getRank() != 2) return emitError() << "mem_desc must be 2D."; @@ -192,8 +193,8 @@ LogicalResult IsValidStoreMatrixParams( ArrayRef mdescShape = mdescTy.getShape(); if (dataShape.size() == 2) { - if (subgroup_block_io || vecDirection || vecLength) - return emitError() << "vec_length, vec_direction and subgroup_block_io " + if (subgroup_block_io) + return emitError() << "subgroup_block_io " "are only allowed when result is a 1D VectorType."; if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), [](auto p) { return std::get<0>(p) > std::get<1>(p); })) @@ -1097,7 +1098,6 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, // Call the generated builder with all parameters (including optional ones as // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - /*vec_length=*/nullptr, /*vec_direction=*/nullptr, /*subgroup_block_io=*/nullptr, layout); } @@ -1105,12 +1105,9 @@ LogicalResult LoadMatrixOp::verify() { auto resTy = dyn_cast(getRes().getType()); UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); - MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr(); - IntegerAttr vecLength = getVecLengthAttr(); MemDescType mdescTy = getMemDesc().getType(); return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io, - vecDirection, vecLength, [&]() { return emitError(); }); } @@ -1126,7 +1123,6 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState 
&state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - /*vec_length=*/nullptr, /*vec_direction=*/nullptr, /*subgroup_block_io=*/nullptr, layout); } @@ -1134,11 +1130,8 @@ LogicalResult StoreMatrixOp::verify() { auto dataTy = dyn_cast(getData().getType()); UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); - MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr(); - IntegerAttr vecLength = getVecLengthAttr(); MemDescType mdescTy = getMemDesc().getType(); return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io, - vecDirection, vecLength, [&]() { return emitError(); }); } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index 6302758195e51..ebb3c2b2b6a83 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -139,10 +139,10 @@ gpu.module @test_kernel [#xevm.target] { %tid_x = gpu.thread_id x %c16 = arith.constant 16 : index - %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3> - xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index gpu.return %1: vector<8xf16> } diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index fee3136195e1d..6062eba709b88 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -870,16 +870,9 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { return } -// ----- -func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} - %data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> - return -} - // ----- func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> return } @@ -908,14 +901,14 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve // ----- func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, 
!xegpu.mem_desc<16x64xf16> return } // ----- func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index eb5d653be8b9c..f1f5f86d33bc0 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -855,8 +855,8 @@ gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, # // CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> - %data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> gpu.return } @@ -890,8 +890,8 @@ gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, // CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { - // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> - xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } From 272f51213290a1784ac8a44124fbb38b67c9b1c3 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Mon, 13 Oct 2025 23:15:41 +0000 Subject: [PATCH 08/12] address comments --- .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 26 ++++++++++++------- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 8 +++--- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 +-- .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 3 +++ .../Transforms/XeGPUWgToSgDistribute.cpp | 1 + .../XeGPUToXeVM/loadstore_matrix.mlir | 2 +- 6 files changed, 27 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index c261fbb576642..99526159cd2e7 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -242,7 +242,6 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m if (layout && layout.hasAttr("stride")) { return layout.getStrides(); } - // derive and return default strides SmallVector defaultStrides; 
llvm::append_range(defaultStrides, getShape().drop_front()); @@ -251,6 +250,15 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return builder.getI64ArrayAttr(defaultStrides); } + ArrayAttr getBlockAttr() { + auto layout = getMemLayout(); + if (layout && layout.hasAttr("block")) { + return layout.getBlockAttr(); + } + Builder builder(getContext()); + return builder.getI64ArrayAttr({}); + } + /// Heuristic to determine if the MemDesc uses column-major layout, /// based on the rank and the value of the first stride dimension. bool isColMajor() { @@ -261,16 +269,14 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m // get the Blocking shape for a MemDescType, Which is represented // as an attribute in MemDescType. By default it is the shape // of the mdescTy - SmallVector getBlockSize() { + SmallVector getBlockShape() { SmallVector size(getShape()); - MemLayoutAttr layout = getMemLayout(); - if (layout && layout.hasAttr("block")) { - ArrayAttr attr = layout.getBlockAttr(); + ArrayAttr blockAttr = getBlockAttr(); + if (!blockAttr.empty()) { size.clear(); - llvm::for_each(attr, [&](Attribute elem) { - if (auto intElem = dyn_cast(elem)) - size.push_back(intElem.getInt()); - }); + for (auto attr : blockAttr.getValue()) { + size.push_back(cast(attr).getInt()); + } } return size; } @@ -289,7 +295,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1] // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) - SmallVector getStrides(); + SmallVector getStrideShape(); /// Generates instructions to compute the linearize offset // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index cccc8fab4adbc..78eee0102ba85 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -776,7 +776,7 @@ SmallVector getBlockedOffsets(OpBuilder &builder, Location loc, } // Get strides as vector of integer for MemDesc. 
-SmallVector MemDescType::getStrides() { +SmallVector MemDescType::getStrideShape() { SmallVector matrixShape(getShape().begin(), getShape().end()); @@ -786,7 +786,7 @@ SmallVector MemDescType::getStrides() { strides.push_back(cast(attr).getInt()); } - SmallVector innerBlkShape = getBlockSize(); + SmallVector innerBlkShape = getBlockShape(); // get perm from FCD to LCD // perm[i] = the dim with i-th smallest stride @@ -837,8 +837,8 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, ArrayRef offsets) { SmallVector matrixShape(getShape().begin(), getShape().end()); - SmallVector blockShape = getBlockSize(); - SmallVector strides = getStrides(); + SmallVector blockShape = getBlockShape(); + SmallVector strides = getStrideShape(); // blockshape equal to matrixshape means no blocking if (llvm::equal(blockShape, matrixShape)) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 8d86e78fcbf4f..8c7a686b8ce0d 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -201,10 +201,10 @@ IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy, return emitError() << "data shape must not exceed mem_desc shape."; } else if (dataShape.size() == 1) { - SmallVector blockSize = mdescTy.getBlockSize(); + SmallVector blockShape = mdescTy.getBlockShape(); // if the subgroup_block_io attribute is set, mdescTy must have block // attribute - if (subgroup_block_io && !blockSize.size()) + if (subgroup_block_io && !blockShape.size()) return emitError() << "mem_desc must have block attribute when " "subgroup_block_io is set."; // if the subgroup_block_io attribute is set, the memdesc should be row diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 6d17b27849a43..aafa1b7deb84b 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -942,6 +942,8 @@ struct UnrollLoadMatrixOp : public UnrollPattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); VectorType valueTy = llvm::dyn_cast(op.getType()); + assert(valueTy && "the value type must be vector type!"); + std::optional> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -985,6 +987,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern { Location loc = op.getLoc(); VectorType valueTy = llvm::dyn_cast(op.getData().getType()); + assert(valueTy && "the value type must be vector type!"); ArrayRef shape = valueTy.getShape(); auto layout = dyn_cast(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index baee57c512ddf..31a967dcd04c7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -992,6 +992,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { ArrayRef wgShape = op.getDataShape(); VectorType valueTy = llvm::dyn_cast(op.getRes().getType()); + assert(valueTy && "the value type must be vector type!"); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index ebb3c2b2b6a83..df1433e7b98ae 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ 
b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -198,4 +198,4 @@ gpu.module @test_kernel [#xevm.target] { gpu.return %1: vector<8xf16> } -} \ No newline at end of file +} From b1857a275d7e30a55ac9b17b335f61f556b2e695 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 14 Oct 2025 01:05:09 +0000 Subject: [PATCH 09/12] address more comments --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 134 ++++++++---------- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 14 +- .../XeGPUToXeVM/loadstore_matrix.mlir | 18 +-- 3 files changed, 77 insertions(+), 89 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 2ff2c98d291d2..e5e797a1fa1c3 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -365,10 +365,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { // Add a builder that creates // offset * elemByteSize + baseAddr -static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, - Value baseAddr, Value offset, int64_t elemByteSize) { +static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter, + Location loc, Value baseAddr, Value offset, + int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI64Type(), elemByteSize); + rewriter, loc, baseAddr.getType(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -443,7 +444,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { // If offset is provided, we add them to the base pointer. // Offset is in number of elements, we need to multiply by // element byte size. - basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize); + basePtrI64 = + addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize); } // Convert base pointer (i64) to LLVM pointer type. 
Value basePtrLLVM = @@ -516,7 +518,7 @@ class CreateMemDescOpPattern final LogicalResult matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TypedValue src = op.getSource(); + auto resTy = cast(op.getResult().getType()); // Create the result MemRefType with the same shape, element type, and @@ -525,7 +527,7 @@ class CreateMemDescOpPattern final Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, - Value(src), zero, ValueRange()); + op.getSource(), zero, ValueRange()); rewriter.replaceOp(op, viewOp); return success(); } @@ -587,88 +589,74 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( rewriter, loc, basePtrStruct); - // Convert base pointer (ptr) to i64 - Value basePtrI64 = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI64Type(), basePtrLLVM); + // Convert base pointer (ptr) to i32 + Value basePtrI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), basePtrLLVM); Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); linearOffset = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI64Type(), linearOffset); - basePtrI64 = - addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize); + rewriter, loc, rewriter.getI32Type(), linearOffset); + basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, + elemByteSize); - // convert base pointer (i64) to LLVM pointer type + // convert base pointer (i32) to LLVM pointer type basePtrLLVM = - LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); - // if the size of valOrResVecTy is 1, it lowers to a scalar load/store - // operation. LLVM load/store does not support vector of size 1, so we need - // to handle this case separately. 
- if (valOrResVecTy.getNumElements() == 1) { - Type scalarTy = valOrResVecTy.getElementType(); - if constexpr (std::is_same_v) { - Value loadOp = - LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); - rewriter.replaceOp(op, loadOp); - } else { - LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); - rewriter.eraseOp(op); - } - return success(); - } else { + if (op.getSubgroupBlockIoAttr()) { // if the attribute 'subgroup_block_io' is set to true, it lowers to // xevm.blockload - auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr(); - bool subgroup_block_io = static_cast(subgroupBlockIoAttr); - - // BlockLoadOp only supports integer types, so we need to bitcast - // Get integer type with matching bit width - Type elemTy = valOrResVecTy.getElementType(); - int64_t bitWidth = elemTy.getIntOrFloatBitWidth(); - Type intElemTy = rewriter.getIntegerType(bitWidth); + + Type intElemTy = rewriter.getIntegerType(elemBitWidth); VectorType intVecTy = VectorType::get(valOrResVecTy.getShape(), intElemTy); - if (subgroup_block_io) { - if constexpr (std::is_same_v) { - Value loadOp = - xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); - if (intVecTy != valOrResVecTy) { - loadOp = - vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); - } - rewriter.replaceOp(op, loadOp); - } else { - Value dataToStore = adaptor.getData(); - if (valOrResVecTy != intVecTy) { - dataToStore = - vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); - } - xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, - nullptr); - rewriter.eraseOp(op); + if constexpr (std::is_same_v) { + Value loadOp = + xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); + if (intVecTy != valOrResVecTy) { + loadOp = + vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); } + rewriter.replaceOp(op, loadOp); } else { - // if the result is 1D vector, if the vector direction is Column, then - // the - // memory descriptor should be treated as column major - auto chipOpt = xegpu::getChipStr(op); - if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { - // the lowering only works for pvc and bmg - return rewriter.notifyMatchFailure( - op, "The lowering is specific to pvc or bmg."); + Value dataToStore = adaptor.getData(); + if (valOrResVecTy != intVecTy) { + dataToStore = + vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); } + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, + nullptr); + rewriter.eraseOp(op); + } + return success(); + } - if constexpr (std::is_same_v) { - Value loadOp = - LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); - rewriter.replaceOp(op, loadOp); - } else { - LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); - rewriter.eraseOp(op); - } + if (valOrResVecTy.getNumElements() >= 1) { + auto chipOpt = xegpu::getChipStr(op); + if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { + // the lowering for chunk load only works for pvc and bmg + return rewriter.notifyMatchFailure( + op, "The lowering is specific to pvc or bmg."); } } + + if constexpr (std::is_same_v) { + // if the size of valOrResVecTy is 1, it lowers to a scalar load/store + // operation. LLVM load/store does not support vector of size 1, so we + // need to handle this case separately. 
+ auto scalarTy = valOrResVecTy.getElementType(); + LLVM::LoadOp loadOp; + if (valOrResVecTy.getNumElements() == 1) + loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); + else + loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } return success(); } }; @@ -715,8 +703,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern { op, "Expected element type bit width to be multiple of 8."); elemByteSize = elemBitWidth / 8; } - basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets, + elemByteSize); } } // Default memory space is global. diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 8c7a686b8ce0d..7108afffe99d5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -174,9 +174,9 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, } LogicalResult -IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy, - UnitAttr subgroup_block_io, - function_ref emitError) { +IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, + UnitAttr subgroup_block_io, + function_ref emitError) { if (!dataTy) { if (subgroup_block_io) @@ -1107,8 +1107,8 @@ LogicalResult LoadMatrixOp::verify() { UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); MemDescType mdescTy = getMemDesc().getType(); - return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1131,8 +1131,8 @@ LogicalResult StoreMatrixOp::verify() { auto dataTy = dyn_cast(getData().getType()); UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); MemDescType mdescTy = getMemDesc().getType(); - return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index df1433e7b98ae..d4cb493271d0d 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -11,8 +11,8 @@ gpu.module @test_kernel [#xevm.target] { //CHECK: %[[TID:.*]] = gpu.thread_id x //CHECK: %[[C1:.*]] = arith.constant 1 : index //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index - //CHECK: %[[C4:.*]] = arith.constant 4 : i64 - //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64 + //CHECK: %[[C4:.*]] = arith.constant 4 : i32 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 %tid_x = gpu.thread_id x @@ -80,7 +80,7 @@ gpu.module @test_kernel [#xevm.target] { %c19 = arith.constant 19: index //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64 + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 //CHECK: %[[c16:.*]] = arith.constant 16 : index 
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index @@ -164,7 +164,7 @@ gpu.module @test_kernel [#xevm.target] { %c48 = arith.constant 48 : index //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64 + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index @@ -180,11 +180,11 @@ gpu.module @test_kernel [#xevm.target] { //CHECK: %[[c1:.*]] = arith.constant 1 : index //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index - //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64 - //CHECK: %[[c2:.*]] = arith.constant 2 : i64 - //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64 - //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64 - //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3> + //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 + //CHECK: %[[c2:.*]] = arith.constant 2 : i32 + //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 + //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> From 7a63d93d076d8b90ff27e3e4f88b008780078f75 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 14 Oct 2025 21:24:07 +0000 Subject: [PATCH 10/12] address more feedback --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 2 +- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 37 ------------------- .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 6 +-- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 15 +------- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 34 ----------------- mlir/test/Dialect/XeGPU/invalid.mlir | 28 -------------- mlir/test/Dialect/XeGPU/ops.mlir | 21 ----------- 8 files changed, 6 insertions(+), 139 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 2efd575a652db..19a52317956d2 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -712,7 +712,7 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { return getAttrs().contains(name); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { return getAttrs().getAs("stride"); } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index f41f9e887cff7..73b70da9642e4 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1392,41 +1392,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, let hasVerifier = 1; } -def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview", - [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> { - let description = [{ - Creates a subview 
of a memory descriptor. The resulting memory descriptor can have - a lower rank than the source; in this case, the result dimensions correspond to the - higher-order dimensions of the source memory descriptor. - - Arguments: - - `src` : a memory descriptor. - - `offsets` : the coordinates within the matrix the subview will be created from. - - Results: - - `res` : a memory descriptor with smaller size. - - }]; - let arguments = (ins XeGPU_MemDesc:$src, - Variadic:$offsets, - DenseI64ArrayAttr:$const_offsets); - let results = (outs XeGPU_MemDesc:$res); - let assemblyFormat = [{$src `` custom($offsets, $const_offsets) prop-dict - attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}]; - let builders = [ - OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef": $offsets)> - ]; - - let extraClassDeclaration = [{ - mlir::Value getViewSource() { return getSrc(); } - - SmallVector getMixedOffsets() { - return getMixedValues(getConstOffsets(), getOffsets(), getContext()); - } - }]; - - let hasVerifier = 1; -} - - #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 99526159cd2e7..024ca2023c811 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -237,10 +237,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); } - ArrayAttr getStridesAttr() { + ArrayAttr getStrideAttr() { auto layout = getMemLayout(); if (layout && layout.hasAttr("stride")) { - return layout.getStrides(); + return layout.getStrideAttr(); } // derive and return default strides SmallVector defaultStrides; @@ -262,7 +262,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m /// Heuristic to determine if the MemDesc uses column-major layout, /// based on the rank and the value of the first stride dimension. 
bool isColMajor() { - auto dim0 = dyn_cast(getStridesAttr()[0]); + auto dim0 = dyn_cast(getStrideAttr()[0]); return getRank() == 2 && dim0 && dim0.getInt() == 1; } diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index fdd29dd96cd55..9cf963e101816 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -534,18 +534,6 @@ class CreateMemDescOpPattern final } }; -class MemDescSubviewOpPattern final - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - return rewriter.notifyMatchFailure( - op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture."); - } -}; - template ::value>> @@ -1085,8 +1073,7 @@ void mlir::populateXeGPUToXeVMConversionPatterns( typeConverter, patterns.getContext()); patterns.add, LoadStoreMatrixToXeVMPattern, - CreateMemDescOpPattern, MemDescSubviewOpPattern>( - typeConverter, patterns.getContext()); + CreateMemDescOpPattern>(typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index dc880308e6b3a..1cfae28f31188 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -781,7 +781,7 @@ SmallVector MemDescType::getStrideShape() { SmallVector matrixShape(getShape().begin(), getShape().end()); - ArrayAttr strideAttr = getStridesAttr(); + ArrayAttr strideAttr = getStrideAttr(); SmallVector strides; for (Attribute attr : strideAttr.getValue()) { strides.push_back(cast(attr).getInt()); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 7108afffe99d5..f2d1b85fa1737 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -1135,40 +1135,6 @@ LogicalResult StoreMatrixOp::verify() { [&]() { return emitError(); }); } -//===----------------------------------------------------------------------===// -// XeGPU_MemDescSubviewOp -//===----------------------------------------------------------------------===// - -void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, - Type resTy, Value src, - llvm::ArrayRef offsets) { - llvm::SmallVector dynamicOffsets; - llvm::SmallVector staticOffsets; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); -} - -LogicalResult MemDescSubviewOp::verify() { - MemDescType srcTy = getSrc().getType(); - MemDescType resTy = getRes().getType(); - ArrayRef srcShape = srcTy.getShape(); - ArrayRef resShape = resTy.getShape(); - - if (srcTy.getRank() < resTy.getRank()) - return emitOpError("result rank must not exceed source rank."); - - if (llvm::any_of( - llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed source shape."); - - if (srcTy.getStridesAttr() != resTy.getStridesAttr()) - return emitOpError("result must inherit the source strides."); - - return success(); -} - } // namespace xegpu } // namespace mlir diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir 
b/mlir/test/Dialect/XeGPU/invalid.mlir index 6062eba709b88..ebbe3ce0ec0d0 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -913,31 +913,3 @@ func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data return } -// ----- -func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed source shape}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - // expected-error@+1 {{result must inherit the source strides}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{failed to verify that all of {src, res} have same element type}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout> - return -} - -// ----- -func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result rank must not exceed source rank}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16> - return -} - diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index f1f5f86d33bc0..0a10f6814ae96 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -895,25 +895,4 @@ gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_ gpu.return } -// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> - gpu.return -} - -// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> - gpu.return -} - -// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> - gpu.return -} - } From de87d094a9a309b66138d2a357d1ac73b8270c2b Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 14 Oct 2025 22:05:14 +0000 Subject: [PATCH 11/12] address minor comments --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index f2d1b85fa1737..464a9e2d2a806 100644 --- 
a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -199,8 +199,7 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitError() << "data shape must not exceed mem_desc shape."; - } else if (dataShape.size() == 1) { - + } else { SmallVector blockShape = mdescTy.getBlockShape(); // if the subgroup_block_io attribute is set, mdescTy must have block // attribute @@ -212,8 +211,6 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, if (subgroup_block_io && mdescTy.isColMajor()) return emitError() << "mem_desc should be row major when " "subgroup_block_io is set."; - } else if (dataShape.size() == 0) { - return emitError() << "result shape must not be empty."; } return success(); From faa0bfb3eb6004dcaf33269b9e161051a96baa79 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Wed, 15 Oct 2025 23:35:57 +0000 Subject: [PATCH 12/12] address comments --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 ++++++ mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 14 ++++++++------ mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 73b70da9642e4..426377fcf598f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, Arguments: - `mem_desc`: the memory descriptor identifying the SLM region. - `offsets`: the coordinates within the matrix to read from. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block load. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1367,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - `mem_desc`: the memory descriptor specifying the SLM region. - `offsets`: the coordinates within the matrix where the data will be written. - `data`: the values to be stored in the matrix. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block store. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 024ca2023c811..b1196fbe9c66a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -263,10 +263,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m /// based on the rank and the value of the first stride dimension. bool isColMajor() { auto dim0 = dyn_cast(getStrideAttr()[0]); - return getRank() == 2 && dim0 && dim0.getInt() == 1; + return getRank() == 2 && dim0.getInt() == 1; } - // get the Blocking shape for a MemDescType, Which is represented + // Get the Blocking shape for a MemDescType, Which is represented // as an attribute in MemDescType. 
By default it is the shape // of the mdescTy SmallVector getBlockShape() { @@ -284,16 +284,18 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m // Get strides as vector of integer. // If it contains block attribute, the strides are blocked strides. // - // The blocking is applied against the original matrix shape - // so that the linear offset is not impacted by the subview. + // The blocking is applied to the base matrix shape derived from the + // memory descriptor's stride information. If the matrix described by + // the memory descriptor is not contiguous, it is assumed that the base + // matrix is contiguous and follows the same memory layout. // // It first computes the original matrix shape using the stride info, // then computes the number of blocks in each dimension of original shape, // then compute the outer block shape and stride, // then combines the inner and outer block shape and stride - // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]> + // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>` // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) - // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1] + // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1] // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) SmallVector getStrideShape(); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 9cf963e101816..9ee384e46ef33 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -520,7 +520,7 @@ class CreateMemDescOpPattern final matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resTy = cast(op.getResult().getType()); + auto resTy = op.getMemDesc(); // Create the result MemRefType with the same shape, element type, and // memory space
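
Editor's note (illustrative, not part of the patch): the blocked layout documented above for getBlockShape()/getStrideShape() can be reproduced with the standalone sketch below. The helper name blockedLayout is made up for illustration, and it assumes the block shape evenly divides the base matrix shape, as in the two documented mem_desc examples.

// Illustrative only -- not part of this patch. Minimal sketch of the blocked
// layout tuple (shape, strides) described in the getStrideShape() comment,
// assuming the matrix shape is evenly divisible by the block shape.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Returns {blockedShape, blockedStrides} for a mem_desc described by
// (shape, strides, block). Without a block attribute, callers would use
// (shape, strides) unchanged.
static std::pair<std::vector<int64_t>, std::vector<int64_t>>
blockedLayout(std::vector<int64_t> shape, std::vector<int64_t> strides,
              std::vector<int64_t> block) {
  size_t rank = shape.size();
  assert(strides.size() == rank && block.size() == rank);

  // perm[i] = the dim with the i-th smallest stride (fastest-varying first).
  std::vector<size_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(),
            [&](size_t a, size_t b) { return strides[a] < strides[b]; });

  // Inner (within-block) strides follow the base layout order.
  std::vector<int64_t> inner(rank), outer(rank), numBlocks(rank);
  int64_t acc = 1;
  for (size_t d : perm) {
    inner[d] = acc;
    acc *= block[d];
  }
  // acc now holds the elements per block; blocks are laid out contiguously
  // in the same order as the base layout.
  for (size_t i = 0; i < rank; ++i)
    numBlocks[i] = shape[i] / block[i];
  for (size_t d : perm) {
    outer[d] = acc;
    acc *= numBlocks[d];
  }

  // Combined layout: outer (block-count) dims first, then inner block dims.
  std::vector<int64_t> blkShape(numBlocks), blkStrides(outer);
  blkShape.insert(blkShape.end(), block.begin(), block.end());
  blkStrides.insert(blkStrides.end(), inner.begin(), inner.end());
  return {blkShape, blkStrides};
}

// blockedLayout({32, 256}, {1, 32}, {16, 8}) -> ({2,32,16,8},  {128,256,1,16})
// blockedLayout({256, 32}, {32, 1}, {8, 16}) -> ({32,2,8,16},  {256,128,16,1})

These two calls reproduce the memory layout tuples given in the type's documentation for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>` and `mem_desc<256x32xf16, @block=[8, 16]>` with default strides.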