diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 5695d5d515d7f..19a52317956d2 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { return getAttrs().contains(name); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { return getAttrs().getAs("stride"); } + ArrayAttr getBlockAttr() { + return getAttrs().getAs("block"); + } + }]; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 73f9061f5debe..426377fcf598f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, } def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, - AllElementTypesMatch<["mem_desc", "res"]>, - AllRanksMatch<["mem_desc", "res"]>]> { + AllElementTypesMatch<["mem_desc", "res"]>]> { let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); - let results = (outs XeGPU_ValueType:$res); + let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res); let assemblyFormat = [{ $mem_desc `` custom($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands) `->` type(results) @@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, Arguments: - `mem_desc`: the memory descriptor identifying the SLM region. - `offsets`: the coordinates within the matrix to read from. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block load. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1336,7 +1339,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } ArrayRef getDataShape() { - return getRes().getType().getShape(); + auto resTy = getRes().getType(); + if (auto vecTy = llvm::dyn_cast(resTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1344,13 +1350,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - AllElementTypesMatch<["mem_desc", "data"]>, - AllRanksMatch<["mem_desc", "data"]>]> { + AllElementTypesMatch<["mem_desc", "data"]>]> { let arguments = (ins - XeGPU_ValueType:$data, + AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data, XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` custom($offsets, $const_offsets) @@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - `mem_desc`: the memory descriptor specifying the SLM region. - `offsets`: the coordinates within the matrix where the data will be written. - `data`: the values to be stored in the matrix. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block store. When this attribute is present, + the offsets are subgroup-uniform across all lanes. 
- `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1378,7 +1387,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, } ArrayRef getDataShape() { - return getData().getType().getShape(); + auto DataTy = getData().getType(); + if (auto vecTy = llvm::dyn_cast(DataTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1386,41 +1398,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, let hasVerifier = 1; } -def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview", - [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> { - let description = [{ - Creates a subview of a memory descriptor. The resulting memory descriptor can have - a lower rank than the source; in this case, the result dimensions correspond to the - higher-order dimensions of the source memory descriptor. - - Arguments: - - `src` : a memory descriptor. - - `offsets` : the coordinates within the matrix the subview will be created from. - - Results: - - `res` : a memory descriptor with smaller size. - - }]; - let arguments = (ins XeGPU_MemDesc:$src, - Variadic:$offsets, - DenseI64ArrayAttr:$const_offsets); - let results = (outs XeGPU_MemDesc:$res); - let assemblyFormat = [{$src `` custom($offsets, $const_offsets) prop-dict - attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}]; - let builders = [ - OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef": $offsets)> - ]; - - let extraClassDeclaration = [{ - mlir::Value getViewSource() { return getSrc(); } - - SmallVector getMixedOffsets() { - return getMixedValues(getConstOffsets(), getOffsets(), getContext()); - } - }]; - - let hasVerifier = 1; -} - - #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 84902b2039643..b1196fbe9c66a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -237,12 +237,11 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { auto layout = getMemLayout(); if (layout && layout.hasAttr("stride")) { - return layout.getStrides(); + return layout.getStrideAttr(); } - // derive and return default strides SmallVector defaultStrides; llvm::append_range(defaultStrides, getShape().drop_front()); @@ -250,6 +249,63 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m Builder builder(getContext()); return builder.getI64ArrayAttr(defaultStrides); } + + ArrayAttr getBlockAttr() { + auto layout = getMemLayout(); + if (layout && layout.hasAttr("block")) { + return layout.getBlockAttr(); + } + Builder builder(getContext()); + return builder.getI64ArrayAttr({}); + } + + /// Heuristic to determine if the MemDesc uses column-major layout, + /// based on the rank and the value of the first stride dimension. + bool isColMajor() { + auto dim0 = dyn_cast(getStrideAttr()[0]); + return getRank() == 2 && dim0.getInt() == 1; + } + + // Get the Blocking shape for a MemDescType, Which is represented + // as an attribute in MemDescType. 
By default, it is the shape
+    // of the mdescTy.
+    SmallVector<int64_t> getBlockShape() {
+      SmallVector<int64_t> size(getShape());
+      ArrayAttr blockAttr = getBlockAttr();
+      if (!blockAttr.empty()) {
+        size.clear();
+        for (auto attr : blockAttr.getValue()) {
+          size.push_back(cast<IntegerAttr>(attr).getInt());
+        }
+      }
+      return size;
+    }
+
+    // Get the strides as a vector of integers.
+    // If the layout contains a block attribute, the strides are blocked strides.
+    //
+    // The blocking is applied to the base matrix shape derived from the
+    // memory descriptor's stride information. If the matrix described by
+    // the memory descriptor is not contiguous, it is assumed that the base
+    // matrix is contiguous and follows the same memory layout.
+    //
+    // It first computes the original matrix shape using the stride info,
+    // then the number of blocks in each dimension of the original shape,
+    // then the outer block shape and strides, and finally combines the
+    // inner and outer block shapes and strides.
+    // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
+    // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+    // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1]
+    // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+    SmallVector<int64_t> getStrideShape();
+
+    /// Generates instructions to compute the linearized offset.
+    /// If the memory descriptor is blocked, it returns the linearized offset
+    /// based on the blocked layout; the strides of the memory descriptor are
+    /// always taken into account, blocked or not.
+    Value getLinearOffsets(OpBuilder &builder,
+                           Location loc, ArrayRef<OpFoldResult> offsets);
+
   }];

   let hasCustomAssemblyFormat = true;
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b25809f1ed0..dd9edc43a1657 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
     MLIRIndexDialect
     MLIRSCFDialect
     MLIRXeGPUDialect
+    MLIRXeGPUUtils
     MLIRPass
     MLIRTransforms
     MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index ddcbc44f2652a..fcbf66dbe9e45 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
@@ -63,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   case xegpu::MemorySpace::SLM:
     return static_cast<int>(xevm::AddrSpace::SHARED);
   }
+  llvm_unreachable("Unknown XeGPU memory space");
 }

 // Get same bitwidth flat vector type of new element type.
@@ -186,6 +188,7 @@ class CreateNdDescToXeVMPattern
     int64_t rank = mixedSizes.size();
     if (rank != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -364,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { // Add a builder that creates // offset * elemByteSize + baseAddr -static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, - Value baseAddr, Value offset, int64_t elemByteSize) { +static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter, + Location loc, Value baseAddr, Value offset, + int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI64Type(), elemByteSize); + rewriter, loc, baseAddr.getType(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -443,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { // If offset is provided, we add them to the base pointer. // Offset is in number of elements, we need to multiply by // element byte size. - basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize); + basePtrI64 = + addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize); } // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = @@ -506,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { } }; +// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions +// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than +// 32 bits will be converted to 32 bits. +class CreateMemDescOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto resTy = op.getMemDesc(); + + // Create the result MemRefType with the same shape, element type, and + // memory space + auto newResTy = getTypeConverter()->convertType(resTy); + + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); + auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, + op.getSource(), zero, ValueRange()); + rewriter.replaceOp(op, viewOp); + return success(); + } +}; + +template ::value>> +class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector offsets = op.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); + + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + Value basePtrStruct = adaptor.getMemDesc(); + Value mdescVal = op.getMemDesc(); + // Load result or Store value Type can be vector or scalar. + Value data; + if constexpr (std::is_same_v) + data = op.getResult(); + else + data = adaptor.getData(); + VectorType valOrResVecTy = dyn_cast(data.getType()); + if (!valOrResVecTy) + valOrResVecTy = VectorType::get(1, data.getType()); + + int64_t elemBitWidth = + valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + // Element type must be multiple of 8 bits. + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + int64_t elemByteSize = elemBitWidth / 8; + + // Default memory space is SLM. 
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); + + auto mdescTy = cast(mdescVal.getType()); + + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, basePtrStruct); + + // Convert base pointer (ptr) to i32 + Value basePtrI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), basePtrLLVM); + + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + linearOffset = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), linearOffset); + basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, + elemByteSize); + + // convert base pointer (i32) to LLVM pointer type + basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); + + if (op.getSubgroupBlockIoAttr()) { + // if the attribute 'subgroup_block_io' is set to true, it lowers to + // xevm.blockload + + Type intElemTy = rewriter.getIntegerType(elemBitWidth); + VectorType intVecTy = + VectorType::get(valOrResVecTy.getShape(), intElemTy); + + if constexpr (std::is_same_v) { + Value loadOp = + xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); + if (intVecTy != valOrResVecTy) { + loadOp = + vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); + } + rewriter.replaceOp(op, loadOp); + } else { + Value dataToStore = adaptor.getData(); + if (valOrResVecTy != intVecTy) { + dataToStore = + vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); + } + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, + nullptr); + rewriter.eraseOp(op); + } + return success(); + } + + if (valOrResVecTy.getNumElements() >= 1) { + auto chipOpt = xegpu::getChipStr(op); + if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { + // the lowering for chunk load only works for pvc and bmg + return rewriter.notifyMatchFailure( + op, "The lowering is specific to pvc or bmg."); + } + } + + if constexpr (std::is_same_v) { + // if the size of valOrResVecTy is 1, it lowers to a scalar load/store + // operation. LLVM load/store does not support vector of size 1, so we + // need to handle this case separately. + auto scalarTy = valOrResVecTy.getElementType(); + LLVM::LoadOp loadOp; + if (valOrResVecTy.getNumElements() == 1) + loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); + else + loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } + return success(); + } +}; + class PrefetchToXeVMPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -548,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern { op, "Expected element type bit width to be multiple of 8."); elemByteSize = elemBitWidth / 8; } - basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets, + elemByteSize); } } // Default memory space is global. 
@@ -786,6 +932,13 @@ struct ConvertXeGPUToXeVMPass
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
+    // Convert MemDescType into a flattened MemRefType for SLM.
+    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+      Type elemTy = type.getElementType();
+      int numElems = type.getNumElements();
+      return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+    });
+
     typeConverter.addConversion([&](MemRefType type) -> Type {
       // Convert MemRefType to i64 type.
       return IntegerType::get(&getContext(), 64);
@@ -940,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
                LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
+  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
   patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                       patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d517473..1cfae28f31188 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+// A helper utility to perform a binary operation on two OpFoldResults. If both
+// a and b are attributes, it will simply return the result. Otherwise, the
+// corresponding arith op will be generated, and a constant op will be created
+// if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+                      OpBuilder &builder) {
+  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+  return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// A helper utility to divide an OpFoldResult by an int64_t.
+#define div(a, b)                                                              \
+  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to take the remainder of an OpFoldResult by an int64_t.
+#define rem(a, b)                                                              \
+  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to multiply an OpFoldResult by an int64_t.
+#define mul(a, b)                                                              \
+  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to add two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// Block the given offsets according to the block shape.
+// Say the original offset is [y, x] and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+                                            ArrayRef<OpFoldResult> offsets,
+                                            ArrayRef<int64_t> blockShape) {
+
+  assert(offsets.size() == blockShape.size() &&
+         "offsets and blockShape must have the same size");
+  SmallVector<OpFoldResult> blockedOffsets;
+  SmallVector<OpFoldResult> divs, rems;
+
+  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+    divs.push_back(div(offset, block));
+    rems.push_back(rem(offset, block));
+  }
+  blockedOffsets.append(divs.begin(), divs.end());
+  blockedOffsets.append(rems.begin(), rems.end());
+
+  return blockedOffsets;
+}
+
+// Get the strides of the MemDesc as a vector of integers.
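+// As an illustration, for the example documented in XeGPUTypes.td,
+// mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>:
+//   inner (within-block) strides for a [16, 8] block, col-major: [1, 16]
+//   blocks per dimension: [32/16, 256/8] = [2, 32]
+//   elements per block:   16 * 8 = 128
+//   outer (block-level) strides: [128, 128 * 2] = [128, 256]
+//   combined strides (outer ++ inner): [128, 256, 1, 16]
+// which matches the layout tuple ([2,32,16,8], [128,256,1,16]) given there.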
+SmallVector MemDescType::getStrideShape() { + + SmallVector matrixShape(getShape().begin(), getShape().end()); + + ArrayAttr strideAttr = getStrideAttr(); + SmallVector strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast(attr).getInt()); + } + + SmallVector innerBlkShape = getBlockShape(); + + // get perm from FCD to LCD + // perm[i] = the dim with i-th smallest stride + SmallVector perm = + llvm::to_vector<4>(llvm::seq(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + + SmallVector innerBlkStride(innerBlkShape.size()); + innerBlkStride[perm[0]] = 1; + for (size_t i = 1; i < perm.size(); ++i) + innerBlkStride[perm[i]] = + innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. + SmallVector matrixShapeOrig(matrixShape.size()); + SmallVector BlkShapeOrig(matrixShape.size()); + for (size_t i = 0; i < perm.size() - 1; ++i) { + matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; + BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; + } + + int64_t innerBlkSize = 1; + for (auto s : innerBlkShape) + innerBlkSize *= s; + + SmallVector outerBlkStride(matrixShape.size()); + outerBlkStride[perm[0]] = innerBlkSize; + for (size_t i = 0; i < perm.size() - 1; ++i) { + outerBlkStride[perm[i + 1]] = + outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; + } + + // combine the inner and outer strides + SmallVector blockedStrides; + blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); + blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + + return blockedStrides; +} + +// Calculate the linear offset using the blocked offsets and stride +Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, + ArrayRef offsets) { + + SmallVector matrixShape(getShape().begin(), getShape().end()); + SmallVector blockShape = getBlockShape(); + SmallVector strides = getStrideShape(); + + // blockshape equal to matrixshape means no blocking + if (llvm::equal(blockShape, matrixShape)) { + // remove the outer dims from strides + strides.erase(strides.begin(), strides.begin() + matrixShape.size()); + } else { + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + // say the original offset is [y, x], and the block shape is [By, Bx], + // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] + SmallVector blockedOffsets; + SmallVector divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + + offsets = blockedOffsets; + } + + // Start with initial value as matrix descriptor's base offset. 
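+  // For illustration, with mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+  // (blocked layout tuple ([2,4,16,16], [256,512,1,16]), see the lowering tests),
+  // an access at offsets [13, 35] (values chosen for illustration) is blocked
+  // above to [13/16, 35/16, 13%16, 35%16] = [0, 2, 13, 3], and the loop below
+  // then accumulates 0*256 + 2*512 + 13*1 + 3*16 = 1085 elements.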
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); + for (size_t i = 0; i < offsets.size(); ++i) { + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + } + + return linearOffset; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788d0b9b4..464a9e2d2a806 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } +LogicalResult +IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, + UnitAttr subgroup_block_io, + function_ref emitError) { + + if (!dataTy) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + else + return success(); + } + + if (mdescTy.getRank() != 2) + return emitError() << "mem_desc must be 2D."; + + ArrayRef dataShape = dataTy.getShape(); + ArrayRef mdescShape = mdescTy.getShape(); + + if (dataShape.size() == 2) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitError() << "data shape must not exceed mem_desc shape."; + } else { + SmallVector blockShape = mdescTy.getBlockShape(); + // if the subgroup_block_io attribute is set, mdescTy must have block + // attribute + if (subgroup_block_io && !blockShape.size()) + return emitError() << "mem_desc must have block attribute when " + "subgroup_block_io is set."; + // if the subgroup_block_io attribute is set, the memdesc should be row + // major + if (subgroup_block_io && mdescTy.isColMajor()) + return emitError() << "mem_desc should be row major when " + "subgroup_block_io is set."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + // Call the generated builder with all parameters (including optional ones as + // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { - VectorType resTy = getRes().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); + auto resTy = dyn_cast(getRes().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); - ArrayRef valueShape = resTy.getShape(); - ArrayRef mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); - return success(); + return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } 
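+
+// For reference, the verifier now accepts three result forms for load_matrix
+// (store_matrix is symmetric for its data operand); the attribute values in
+// these sketches are illustrative:
+//   %v = xegpu.load_matrix %m[8, 8]
+//          : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+//   %v = xegpu.load_matrix %m[8, 16] <{subgroup_block_io}>
+//          : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+//          -> vector<8xf16>
+//   %e = xegpu.load_matrix %m[%i, %j]
+//          : !xegpu.mem_desc<32x32xf32>, index, index -> f32
+// i.e. a 2D vector at workgroup/subgroup level, a 1D vector with the
+// subgroup_block_io attribute (lowered to xevm block load/store), and a single
+// element for plain SIMT accesses.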
//===----------------------------------------------------------------------===// @@ -1080,57 +1120,16 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult StoreMatrixOp::verify() { - VectorType dataTy = getData().getType(); - MemDescType mdescTy = getMemDesc().getType(); - - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - - ArrayRef dataShape = dataTy.getShape(); - ArrayRef mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - - return success(); -} - -//===----------------------------------------------------------------------===// -// XeGPU_MemDescSubviewOp -//===----------------------------------------------------------------------===// - -void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, - Type resTy, Value src, - llvm::ArrayRef offsets) { - llvm::SmallVector dynamicOffsets; - llvm::SmallVector staticOffsets; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); -} - -LogicalResult MemDescSubviewOp::verify() { - MemDescType srcTy = getSrc().getType(); - MemDescType resTy = getRes().getType(); - ArrayRef srcShape = srcTy.getShape(); - ArrayRef resShape = resTy.getShape(); - if (srcTy.getRank() < resTy.getRank()) - return emitOpError("result rank must not exceed source rank."); - - if (llvm::any_of( - llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed source shape."); - - if (srcTy.getStrides() != resTy.getStrides()) - return emitOpError("result must inherit the source strides."); - - return success(); + auto dataTy = dyn_cast(getData().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); + return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } } // namespace xegpu diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a178d0fe4b0b0..aafa1b7deb84b 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = op.getType(); + VectorType valueTy = llvm::dyn_cast(op.getType()); + assert(valueTy && "the value type must be vector type!"); + std::optional> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern { return failure(); Location loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); + VectorType valueTy = llvm::dyn_cast(op.getData().getType()); + assert(valueTy && "the value type must be vector 
type!"); ArrayRef shape = valueTy.getShape(); auto layout = dyn_cast(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index c28d2fc6c2b63..31a967dcd04c7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { return failure(); ArrayRef wgShape = op.getDataShape(); - VectorType valueTy = op.getRes().getType(); + VectorType valueTy = llvm::dyn_cast(op.getRes().getType()); + assert(valueTy && "the value type must be vector type!"); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index e6f22f0a9acbb..a9ab0be00722c 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -1,17 +1,13 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s -#sg_map_a_f16 = #xegpu.layout -#sg_map_b_f16 = #xegpu.layout -#sg_map_c_f32 = #xegpu.layout - -gpu.module @load_store_check { +gpu.module @test_kernel { // CHECK-LABEL: func.func @dpas( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { // Loads are checked in a separate test. // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = , types = } // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> return %d : vector<8xf32> } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir new file mode 100644 index 0000000000000..d4cb493271d0d --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -0,0 +1,201 @@ +// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s + +gpu.module @test_kernel [#xevm.target] { + + // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> + // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) + //CHECK-LABEL: load_store_matrix_1 + gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i32 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + + gpu.return %1: f32 + } + +// e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_2 + gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c13:.*]] = arith.constant 13 : index + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 + + + %tid_x = gpu.thread_id x + %c13 = arith.constant 13 : index + %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + gpu.return %1: f16 + } + + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_3 + gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c19:.*]] = arith.constant 19 : index + %tid_x = gpu.thread_id x + %c19 = arith.constant 19: index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 + %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + + //CHECK: gpu.return %[[loaded]] : f16 + gpu.return %1: f16 + } + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_4 + gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> + + %tid_x = gpu.thread_id x + %c16 = arith.constant 16 : index + %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + + //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + + gpu.return %1: vector<8xf16> + } + + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_5 + gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[c48:.*]] = arith.constant 48 : index + + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index + //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index + //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index + //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 + //CHECK: %[[c2:.*]] = arith.constant 2 : i32 + //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 + //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> + //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> + //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> + + %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + + //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> + //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) + + xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + + gpu.return %1: vector<8xf16> + } + +} diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 228ef69d9a478..ebbe3ce0ec0d0 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16> // ----- func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed mem_desc shape}} + // expected-error@+1 {{data shape must not exceed mem_desc shape}} %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16> return } @@ -870,6 +870,14 @@ func.func 
@load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { return } +// ----- +func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> + return +} + + // ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} @@ -892,30 +900,16 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve } // ----- -func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed source shape}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - // expected-error@+1 {{result must inherit the source strides}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{failed to verify that all of {src, res} have same element type}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout> +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> return } // ----- -func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result rank must not exceed source rank}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16> +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index bb379024a34d7..0a10f6814ae96 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -825,53 +825,73 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } -// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) { +// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> gpu.return } -// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { +// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func 
@load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> gpu.return } +// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) +gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + gpu.return +} -// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } -// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> +// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { +gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : 
vector<1xf16>, !xegpu.mem_desc<16x64xf16> + xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> +// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) +gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } -// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> +// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { +gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return }
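+
+// Note: with offsets now carried directly by load_matrix/store_matrix, a former
+// mem_desc_subview + access pair can be folded into a single access. A sketch
+// (the layout attribute shown is illustrative):
+//   %sv = xegpu.mem_desc_subview %m[8, 8] : !xegpu.mem_desc<16x64xf16>
+//           -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+//   %v  = xegpu.load_matrix %sv[0, 0]
+//           : !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+//           -> vector<8x16xf16>
+// now reduces to:
+//   %v = xegpu.load_matrix %m[8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>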