diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 84c1dc1373ee5..42b5b7a0d4e3f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -42,9 +42,18 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td }]; let parameters = (ins - OptionalParameter<"MemorySpaceAttr">: $memory_space, - OptionalParameter<"IntegerAttr", "1">: $array_length, - OptionalParameter<"BoolAttr", "true">: $boundary_check + DefaultValuedParameter< + "MemorySpaceAttr", + "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)", + "Data memory location">: $memory_space, + DefaultValuedParameter< + "IntegerAttr", + "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)", + "Number of continuous blocks to load">: $array_length, + DefaultValuedParameter< + "BoolAttr", + "BoolAttr::get($_ctxt, true)", + "Checking the out of boundary access">: $boundary_check ); let builders = [ @@ -67,8 +76,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat TensorDesc is located, `Global` device memory or `Shared` local memory. It is default to `Global`. - 2. `chunk_size`: indicates number of contiguous elements accessed for each - offset, default is 1. It is used with `scattered` attr only. + 2. `chunk_size`: Specifies the number of contiguous elements accessed per offset. + The default value is 1. }]; let parameters = (ins @@ -91,6 +100,12 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat )> ]; + let extraClassDeclaration = [{ + int64_t getChunkSizeAsInt() { + return getChunkSize().getInt(); + } + }]; + let genVerifyDecl = 1; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index daab65ec893b8..bd5ea9fd83781 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -287,7 +287,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ transpose is another Intel hardware feature, which will do transpose operation when loading the data if the bit width of the data type is fp32 or fp64. It implies that vnni and transpose cannot exit at the - same time. + same time. It is only available to 1D or 2D blocked tensor_desc. In SIMT mode, result vector represents the data to be loaded by each work-item. @@ -343,6 +343,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ by the TensorDesc. It takes a set of optional cache hints for each level of cache, L1, L2 and L3. If hardware does not have a correspoding cache, Corresponding cache hint attribute will be masked. + It is only available to 1D or 2D blocked tensor_desc. In SIMT mode, the input vector represents the data to be stored by each work-item. @@ -757,6 +758,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", let assemblyFormat = [{ $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets) }]; + + let hasVerifier = 1; } def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> { diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 84314875c2ae5..277158ac85409 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>; def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>; def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>; -def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>; -def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>; -def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>; -def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>; -def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>; -def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>; +def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>; +def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>; +def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>; +def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>; +def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>; +def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>; // common base class for types in XeGPU dialect class XeGPUTypeDef traits = [], @@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", ]; let extraClassDeclaration = [{ - using TensorType::clone; using mlir::ShapedType::Trait::getElementTypeBitWidth; using mlir::ShapedType::Trait::getRank; using mlir::ShapedType::Trait::getNumElements; @@ -157,6 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", return MemorySpace::Global; } + // get the ArrayLength for blocked TensorDesc int getArrayLength() { auto attr = getEncoding(); auto block_attr = mlir::dyn_cast_if_present(attr); @@ -181,13 +181,12 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", return bool(getEncodingAsScatterTensorDescAttr()); } - int getChunkSize() { + // get the ChunkSize for scattered TensorDesc + int getChunkSizeAsInt() { auto attr = getEncoding(); auto scatter_attr = mlir::dyn_cast_if_present(attr); - assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr."); - if (scatter_attr) - return scatter_attr.getChunkSize().getInt(); - return 1; + assert(scatter_attr && "invalid on non ScatterTensorDescAttr."); + return scatter_attr.getChunkSizeAsInt(); } /// Helper to drop all layout information from the TensorDesc type. diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 649e0d453015f..0f9cd95cf63ca 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -129,9 +129,7 @@ LogicalResult ScatterTensorDescAttr::verify( llvm::function_ref emitError, MemorySpaceAttr memory_space, IntegerAttr chunk_size) { int64_t chunkSize = chunk_size.getInt(); - SmallVector supportedChunkSizes = {1, 2, 3, 4, 8, - 16, 32, 64, 128, 256}; - if (!llvm::is_contained(supportedChunkSizes, chunkSize)) + if (chunkSize <= 0) return emitError() << "invalid chunk size"; return success(); @@ -310,15 +308,16 @@ LogicalResult TensorDescType::verify( llvm::ArrayRef shape, mlir::Type elementType, mlir::Attribute encoding, mlir::Attribute layout) { size_t rank = shape.size(); - if (rank != 1 && rank != 2) - return emitError() << "expected 1D or 2D tensor"; + + if (rank == 0) + return emitError() << "expected non-zero rank tensor"; auto blockAttr = mlir::dyn_cast_if_present(encoding); if (blockAttr) { MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace(); - if (rank == 2 && memorySpaceAttr && + if (rank > 1 && memorySpaceAttr && memorySpaceAttr.getValue() == MemorySpace::SLM) - return emitError() << "SLM is not supported for 2D block tensor"; + return emitError() << "SLM is only supported for 1D block tensor"; } // for gather and scatter ops, Low-precision types are packed in 32-bit units. @@ -329,22 +328,18 @@ LogicalResult TensorDescType::verify( : 1; auto scatterAttr = mlir::dyn_cast_if_present(encoding); if (scatterAttr) { - // Expected tensor ranks for scattered data: - // - 1D tensor for fully non-contiguous elements (chunk size == 1) - // - 2D tensor for scattered blocks (chunk size > 1) - unsigned chunkSize = scatterAttr.getChunkSize().getInt(); + int64_t chunkSize = scatterAttr.getChunkSizeAsInt(); if (rank == 1 && chunkSize != 1) return emitError() << "expected non-contiguous elements for 1D tensor"; - if (rank == 2 && chunkSize < 2) - return emitError() << "expected chunk blocks for 2D tensor"; + // If chunk size > 1, the second dimension of the tensor shape must be - // equal to chunk size and it must be a multiple of the packing factor. + // equal to chunk size and it must be a multiple of the + // chunkAlignmentFactor. if (chunkSize > 1) { if (shape.back() != chunkSize) - return emitError() << "expected tensor shape[1] to match chunk size"; + return emitError() << "expected last dim of tensor to match chunk size"; if (shape.back() % chunkAlignmentFactor != 0) - return emitError() << "expected tensor shape[1] to be a multiple of " - "chunk alignment factor " + return emitError() << "expected last dim of tensor to be a multiple of " << chunkAlignmentFactor; } } @@ -357,17 +352,13 @@ LogicalResult TensorDescType::verify( auto laneData = layoutAttr.getLaneData(); if (scatterAttr && laneData) { // Validate subgroup mapping rules for scattered tensors. - // A work-item's slice of the tensor with shape [sg_size] or - // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width] - // respectively, the mapping should reflect that. This is because each - // work item access data in 32 bit granularity. - - if (rank > 1 && laneData[0] != 1) + // if chunkSize > 1, the last dimension of the tensor should + // be distributed in the units divisible by chunkAlignmentFactor. + int64_t chunkSize = scatterAttr.getChunkSizeAsInt(); + if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor) return emitError() - << "cannot map over non-contiguous scattered row elements"; - if (laneData[rank - 1] != chunkAlignmentFactor) - return emitError() << "work item data mapping must match the number of " - "contiguous elements"; + << "expected last dim of lane_data to be a multiple of: " + << chunkAlignmentFactor; } if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 2793c7a35bc97..15f06a99ced09 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -81,15 +81,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, auto maskShape = getShapeOf(maskTy); auto valueShape = getShapeOf(valueTy); auto tdescShape = getShapeOf(tdescTy); - auto chunkSize = tdescTy.getChunkSize(); + auto chunkSize = tdescTy.getChunkSizeAsInt(); if (valueTy.getElementType() != tdescTy.getElementType()) return emitError() << "Value should have the same element type as TensorDesc."; - if (tdescShape[0] != maskShape[0]) + llvm::SmallVector expectedMaskShape(tdescShape); + if (chunkSize > 1) + expectedMaskShape.pop_back(); + if (expectedMaskShape != maskShape) return emitError() - << "dim-0 of the Mask and TensorDesc should be the same."; + << "Mask should match TensorDesc except the chunk size dim."; // a valid shape for SIMT case if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { @@ -203,11 +206,9 @@ LogicalResult CreateNdDescOp::verify() { "is a memref) should match with each other."); // check result TensorDesc rank - invalidRank = (getType().getRank() > 2 || getType().getRank() > rank); - - if (invalidRank) + if (getType().getRank() > rank) return emitOpError( - "Expecting the TensorDesc rank is up to 2 and not greater than the " + "Expecting the TensorDesc rank is not greater than the " "ranks of shape, strides, offsets or the memref source."); if (invalidElemTy) @@ -247,12 +248,12 @@ LogicalResult LoadNdOp::verify() { auto tdescTy = getTensorDescType(); auto valueTy = getType(); - if (tdescTy.getRank() > 2) - return emitOpError("Expecting a 1D/2D TensorDesc.\n"); - if (tdescTy.isScattered()) return emitOpError("Expects a non-scattered TensorDesc.\n"); + if (tdescTy.getRank() > 2) + return emitOpError("Expects a 1D or 2D TensorDesc.\n"); + if (!valueTy) return emitOpError("Invalid result, it should be a VectorType.\n"); @@ -316,15 +317,13 @@ LogicalResult LoadNdOp::verify() { } auto array_len = tdescTy.getArrayLength(); - if (array_len > 1) { + if (array_len > 1) tdescShape.insert(tdescShape.begin(), array_len); - } - if (tdescShape != valueShape) { + if (tdescShape != valueShape) return emitOpError() << "Result shape " << makeString(valueShape) << " is not consistent with tensor descriptor " << tdescTy; - } return success(); } @@ -336,12 +335,12 @@ LogicalResult StoreNdOp::verify() { auto dstTy = getTensorDescType(); // Tile auto valTy = getValueType(); // Vector - if (dstTy.getRank() > 2) - return emitOpError("Expecting a 1D/2D TensorDesc.\n"); - if (dstTy.isScattered()) return emitOpError("Expects a non-scattered TensorDesc.\n"); + if (dstTy.getRank() > 2) + return emitOpError("Expects a 1D or 2D TensorDesc.\n"); + if (!valTy) return emitOpError("Expecting a VectorType result.\n"); @@ -370,22 +369,21 @@ LogicalResult StoreNdOp::verify() { return emitOpError() << "TensorDesc doesn't need LayoutAttr for SIMT code"; - if (tdescElems % valueElems) { + if (tdescElems % valueElems) return emitOpError() << "Value shape " << makeString(getShapeOf(valTy)) << " is not a valid distribution for tensor descriptor " << dstTy; - } + return success(); } // SIMD code should have the same shape as the tensor descriptor. auto tdescShape = getShapeOf(dstTy); auto valueShape = getShapeOf(valTy); - if (tdescShape != valueShape) { + if (tdescShape != valueShape) return emitOpError() << "Value shape " << makeString(valueShape) << " is not consistent with tensor descriptor " << dstTy; - } return success(); } @@ -449,25 +447,8 @@ LogicalResult CreateDescOp::verify() { << ", TensorDesc: " << tdescMemorySpace; // check total size - auto chunkSize = tdescTy.getChunkSize(); - auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth(); - auto bitsPerLane = elemBits * chunkSize; - if (chunkSize > 1 && bitsPerLane % 32) { - // For 8-bit and 16-bit data, the hardware only supports chunk size of 1. - // For 32-bit data, the hardware can support larger larger chunk size. So - // we can bitcast 8-bit/16-bit data to 32-bit data for better performance. - // But this requires the total size is 32 bit aligned to make the - // optimization work. - return emitOpError( - "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned."); - } - - auto lscConstraints = 512 * 8; // each access is upto 512 bytes. - if (elemBits * tdescTy.getNumElements() > lscConstraints) - return emitOpError("total access size (simd_lanes * chunk_size * " - "sizeof(elemTy)) is upto 512 bytes."); - - SmallVector shape({(int64_t)getNumOffsets()}); + auto chunkSize = tdescTy.getChunkSizeAsInt(); + SmallVector shape(getOffsetsType().getShape()); if (chunkSize != 1) shape.push_back(chunkSize); @@ -563,6 +544,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, build(builder, state, tensorDesc, ofrs); } +LogicalResult UpdateOffsetOp::verify() { + auto tdescTy = getTensorDescType(); + if (!tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + SmallVector expectedOffsetShape = getShapeOf(tdescTy); + SmallVector offsetShape = getShapeOf(getOffsetsType()); + if (tdescTy.getChunkSizeAsInt() > 1) + expectedOffsetShape.pop_back(); + + if (expectedOffsetShape != offsetShape) + return emitOpError( + "Offsets should match TensorDesc except the chunk size dim."); + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_DpasOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 3950e8f70d1ca..ddc9e0eb908ac 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() { // If the encoding is a ScatterTensorDescAttr, we need to // potentially adjust the chunk size based on the inst_data. if (tdescTy.isScattered()) { - auto scatterAttr = - llvm::dyn_cast_if_present(encoding); - int64_t chunkSize = scatterAttr.getChunkSize().getInt(); + int64_t chunkSize = tdescTy.getChunkSizeAsInt(); if (chunkSize > 1) { int64_t blockedChunkSize = chunkSize; @@ -315,7 +313,7 @@ void XeGPUBlockingPass::runOnOperation() { // To create a new attribute with a different chunk_size: auto newEncoding = xegpu::ScatterTensorDescAttr::get( - ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize); + ctx, tdescTy.getMemorySpace(), blockedChunkSize); encoding = newEncoding; } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 2c48a735bf956..13d49cb0e9d82 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -413,7 +413,7 @@ struct UnrollCreateDescOp : public UnrollPattern { return failure(); SmallVector targetIndiceShape(*targetShape); - int64_t originalChunkSize = tdescTy.getChunkSize(); + int64_t originalChunkSize = tdescTy.getChunkSizeAsInt(); // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1. if (originalChunkSize > 1) targetIndiceShape.pop_back(); @@ -480,7 +480,7 @@ struct UnrollLoadGatherOp : public UnrollPattern { return failure(); SmallVector targetMaskShape(*targetShape); - int64_t originalChunkSize = tdescTy.getChunkSize(); + int64_t originalChunkSize = tdescTy.getChunkSizeAsInt(); VectorType maskTy = llvm::dyn_cast(op.getMask().getType()); @@ -571,7 +571,7 @@ struct UnrollStoreScatterOp : public UnrollPattern { return failure(); SmallVector targetMaskShape(*targetShape); - int64_t originalChunkSize = tdescTy.getChunkSize(); + int64_t originalChunkSize = tdescTy.getChunkSizeAsInt(); VectorType maskTy = llvm::dyn_cast(op.getMask().getType()); @@ -625,9 +625,6 @@ struct UnrollUpdateOffsetOp : public UnrollPattern { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (tdescTy.getRank() > 2) - return failure(); - if (!tdescTy.isScattered()) return failure(); @@ -645,7 +642,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern { SmallVector convertedOffsetTypes; SmallVector convertedOffsetVec; SmallVector newOps; - int64_t originalChunkSize = tdescTy.getChunkSize(); + int64_t originalChunkSize = tdescTy.getChunkSizeAsInt(); if (originalChunkSize > 1) { auto targetOffsetShape = ArrayRef(*targetShape).drop_back(); convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape); diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir index 7cef17df79dd2..4af7061a4f8a3 100644 --- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir @@ -31,7 +31,6 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>, // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] @@ -57,7 +56,6 @@ func.func @load_dynamic_source(%source: memref, // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1] // CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] @@ -76,7 +74,6 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>, // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir index 4f069ebc39db3..d68a02b54e967 100644 --- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir @@ -33,7 +33,6 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>, // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // ----- @@ -59,7 +58,6 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>, // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1] // CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // ----- @@ -78,7 +76,6 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>, // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // ----- diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir index 497eb86cea835..c2f760b29afc4 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir @@ -52,7 +52,6 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir index 91e3fb3841f6e..8de6c2283b37c 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir @@ -81,7 +81,6 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>, // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32, -// CHECK-SAME: boundary_check = true // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // ----- diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index a2778cd94d963..83a98ab0622b7 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -2,7 +2,7 @@ // ----- func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) { - // expected-error@+1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}} + // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}} %1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32> return } @@ -17,7 +17,7 @@ func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) { // ----- func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) { - // expected-error@+1 {{SLM is not supported for 2D block tensor}} + // expected-error@+1 {{SLM is only supported for 1D block tensor}} %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> return } @@ -109,6 +109,14 @@ func.func @load_nd_vc_4(%src: memref<24x32xf32>) { return } +// ----- +func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) { + %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16> + // expected-error@+1 {{Expects a 1D or 2D TensorDesc}} + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16> + return +} + // ----- func.func @load_nd_layout(%src: memref<24x32xf32>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32> @@ -156,6 +164,15 @@ func.func @store_nd_vc_3(%dst: memref<24x32xf16>) { return } +// ----- +func.func @store_nd_vc_4(%dst: memref<8x24x32xf16>) { + %1 = arith.constant dense<1.0>: vector<8x24x32xf16> + %2 = xegpu.create_nd_tdesc %dst[0, 0, 0] : memref<8x24x32xf16> -> !xegpu.tensor_desc<8x24x32xf16> + // expected-error@+1 {{Expects a 1D or 2D TensorDesc}} + xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<8x24x32xf16>, !xegpu.tensor_desc<8x24x32xf16> + return +} + // ----- func.func @store_nd_simt(%dst: memref<24x32xf32>, %data: vector<3xf32>) { %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32> @@ -200,11 +217,11 @@ func.func @create_tdesc_vc_1(%src: ui64) { } // ----- -func.func @create_tdesc_vc_2(%src: ui64) { - %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex> - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex> - // expected-error@+1 {{expected chunk blocks for 2D tensor}} - -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>> +func.func @create_tdesc_vc_2(%src: memref) { + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> + // expected-error@+1 {{invalid chunk size}} + -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr> return } @@ -221,25 +238,16 @@ func.func @create_tdesc_vc_3(%src: memref) { func.func @create_tdesc_vc_4(%src: memref) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> - // expected-error@+1 {{invalid chunk size}} - -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr> - return -} - -// ----- -func.func @create_tdesc_vc_5(%src: memref) { - %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> - // expected-error@+1 {{expected tensor shape[1] to match chunk size}} + // expected-error@+1 {{expected last dim of tensor to match chunk size}} -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr> return } // ----- -func.func @create_tdesc_vc_6(%src: memref) { +func.func @create_tdesc_vc_5(%src: memref) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> - // expected-error@+1 {{tensor shape[1] to be a multiple of chunk alignment factor 2}} + // expected-error@+1 {{last dim of tensor to be a multiple of 2}} -> !xegpu.tensor_desc<4x3xf16, #xegpu.scatter_tdesc_attr> return } @@ -267,23 +275,15 @@ func.func @prefetch_vc_2(%src: ui64) { func.func @create_tdesc_layout_1(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> // expected-error@+1 {{expected layout rank to match tensor rank}} - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> return } // ----- func.func @create_tdesc_layout_2(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - // expected-error@+1 {{cannot map over non-contiguous scattered row elements}} - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - return -} - -// ----- -func.func @create_tdesc_layout_3(%src: ui64) { - %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - // expected-error@+1 {{work item data mapping must match the number of contiguous elements}} - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + // expected-error@+1 {{expected last dim of lane_data to be a multiple of: 2}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout> return } @@ -331,6 +331,19 @@ func.func @load_gather_vc_2(%src: ui64) { return } +// ----- +func.func @load_gather_vc_3(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<1>: vector<8xi1> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> + -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + // expected-error@+1 {{Mask should match TensorDesc except the chunk size dim}} + %2 = xegpu.load %1, %0 + : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<8xi1> + -> vector<4x2xf32> + return +} + // ----- func.func @store_scatter_vc_1(%src: memref<24x32xf32>) { %0 = arith.constant dense<1>: vector<4xi1> @@ -355,6 +368,19 @@ func.func @store_scatter_vc_2(%src: ui64) { return } +// ----- +func.func @store_scatter_vc_3(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]>: vector<4xindex> + %0 = arith.constant dense<1>: vector<8xi1> + %1 = arith.constant dense<2.9>: vector<4x2xf32> + %2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> + -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + // expected-error@+1 {{Mask should match TensorDesc except the chunk size dim}} + xegpu.store %1, %2, %0 : vector<4x2xf32>, + !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<8xi1> + return +} + // ----- func.func @dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) { // expected-error@+1 {{K-dimension mismatch}} @@ -406,18 +432,10 @@ func.func @atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi return } -// ----- -func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) { - %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{expected 1D or 2D tensor}} - !xegpu.tensor_desc<16x2x2xf32> - return -} - // ----- func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{expected 1D or 2D tensor}} + // expected-error@+1 {{expected non-zero rank tensor}} !xegpu.tensor_desc return } @@ -470,27 +488,6 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) { return } -// ----- -func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) { - %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> - // expected-error@+1 {{cannot map over non-contiguous scattered row elements}} - !xegpu.tensor_desc<4x2xf32, - #xegpu.scatter_tdesc_attr, - #xegpu.layout> - return -} - -// ----- -func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) { - %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> - // expected-error@+1 {{work item data mapping must match the number of contiguous elements}} - !xegpu.tensor_desc<16xf32, - #xegpu.scatter_tdesc_attr, - #xegpu.layout> - return -} - // ----- func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) { %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> @@ -504,9 +501,9 @@ func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vecto // ----- func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) { %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> - // expected-error@+1 {{expected chunk blocks for 2D tensor}} + // expected-error@+1 {{expected last dim of tensor to match chunk size}} !xegpu.tensor_desc<16x2xf32, - #xegpu.scatter_tdesc_attr, + #xegpu.scatter_tdesc_attr, #xegpu.layout> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index aff8f63adc05b..3bfe1fa81aa6e 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -54,6 +54,13 @@ gpu.func @create_nd_tdesc_6(%src: memref<24x32xf32>) { gpu.return } +// CHECK: gpu.func @create_nd_tdesc_7(%[[arg0:.*]]: memref<8x24x32x48x64xf32>) { +gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) { + // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf32> -> !xegpu.tensor_desc<8x8x8x24x32xf32> + %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf32> -> !xegpu.tensor_desc<8x8x8x24x32xf32> + gpu.return +} + // CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @prefetch_nd(%src: memref<24x32xf16>) { @@ -64,6 +71,14 @@ gpu.func @prefetch_nd(%src: memref<24x32xf16>) { gpu.return } +// CHECK: gpu.func @prefetch_nd_2(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) { +gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) { + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16> + %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16> + // CHECK: xegpu.prefetch_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<1x2x4x8x16xf16> + xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<1x2x4x8x16xf16> + gpu.return +} // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) { gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) { @@ -266,6 +281,14 @@ gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) { gpu.return } +// CHECK: gpu.func @update_nd_tdesc_2(%[[arg0:.*]]: memref<8x24x32xf32>) { +gpu.func @update_nd_tdesc_2(%src: memref<8x24x32xf32>) { + // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<8x24x32xf32> -> !xegpu.tensor_desc<2x8x16xf32> + %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<8x24x32xf32> -> !xegpu.tensor_desc<2x8x16xf32> + // CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 0, 16] : !xegpu.tensor_desc<2x8x16xf32> + %2 = xegpu.update_nd_offset %1, [0, 0, 16]: !xegpu.tensor_desc<2x8x16xf32> + gpu.return +} // CHECK: gpu.func @create_tdesc(%[[arg0:.*]]: ui64) { gpu.func @create_tdesc(%src: ui64) { @@ -291,8 +314,8 @@ gpu.func @create_tdesc_1(%src: memref) { gpu.func @create_tdesc_2(%src: memref) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<> - %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> gpu.return } @@ -306,6 +329,15 @@ gpu.func @create_tdesc_3(%src: ui64) { gpu.return } +// CHECK: gpu.func @create_tdesc_4(%[[arg0:.*]]: ui64) { +gpu.func @create_tdesc_4(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex> + %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x2xf16, #xegpu.scatter_tdesc_attr> + %1 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x2xf16, #xegpu.scatter_tdesc_attr> + gpu.return +} + // CHECK: gpu.func @subgroup_load(%[[arg0:.*]]: ui64) { gpu.func @subgroup_load(%src: ui64) { @@ -385,6 +417,19 @@ gpu.func @simt_load_3(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_load_4(%[[arg0:.*]]: ui64) { +gpu.func @subgroup_load_4(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex> + %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<2x4xi1> + %1 = arith.constant dense<1>: vector<2x4xi1> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr>, vector<2x4xi1> -> vector<2x4x8xf16> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr>, vector<2x4xi1> -> vector<2x4x8xf16> + gpu.return +} + // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) { gpu.func @subgroup_store(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -475,6 +520,21 @@ gpu.func @simt_store_3(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_store_4(%[[arg0:.*]]: ui64) { +gpu.func @subgroup_store_4(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex> + %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<2x4xi1> + %1 = arith.constant dense<1>: vector<2x4xi1> + //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2x4xf32> + %2 = arith.constant dense<2.9>: vector<2x4xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2x4xf32>, !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>, vector<2x4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2x4xf32>, !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>, vector<2x4xi1> + gpu.return +} + // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) { gpu.func @prefetch(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir index ac5fe89a67f9a..e820e13f09f64 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir @@ -358,8 +358,8 @@ gpu.module @test_kernel { // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> - // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> + // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.func @test_prefetch_load_store_update(%src: ui64) { @@ -406,8 +406,8 @@ gpu.module @test_kernel { // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xindex> - // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> + // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xindex> + // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> // CHECK-COUNT-4: xegpu.store {{.*}} : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) { @@ -446,4 +446,57 @@ gpu.module @test_kernel { } } +// ----- +#l = #xegpu.layout + +// test the blocking pass on a 3D scattered tensor descriptor, +// Ops working 4x8x4xf32 scattered tensor_descs will be unrolled +// into 4 ops working 2x8x2xf32 scattered tensor_descs based on +// the given layout. +gpu.module @test_kernel { + // CHECK-LABEL: test_3d_scattered_tensor_desc + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK: [[cst_1:%.+]] = arith.constant dense<{{.*}}[130, 138, 146, 154, 162, 170, 178, 186], [194, 202, 210, 218, 226, 234, 242, 250]]> : vector<2x8xindex> + // CHECK: [[cst_2:%.+]] = arith.constant dense<{{.*}}[2, 10, 18, 26, 34, 42, 50, 58], [66, 74, 82, 90, 98, 106, 114, 122]]> : vector<2x8xindex> + // CHECK: [[cst_3:%.+]] = arith.constant dense<{{.*}}[0, 8, 16, 24, 32, 40, 48, 56], [64, 72, 80, 88, 96, 104, 112, 120]]> : vector<2x8xindex> + // CHECK: [[cst_4:%.+]] = arith.constant dense<{{.*}}[128, 136, 144, 152, 160, 168, 176, 184], [192, 200, 208, 216, 224, 232, 240, 248]]> : vector<2x8xindex> + // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<2x8xindex> -> !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr>, vector<2x8xindex> + // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr>, vector<2x8xi1> -> vector<2x8x2xf32> + // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x8x2xf32>, !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr>, vector<2x8xi1> + + + gpu.func @test_3d_scattered_tensor_desc(%src: ui64) { + %cst = arith.constant dense<[ + [0, 8, 16, 24, 32, 40, 48, 56], + [64, 72, 80, 88, 96, 104, 112, 120], + [128, 136, 144, 152, 160, 168, 176, 184], + [192, 200, 208, 216, 224, 232, 240, 248] + ]> : vector<4x8xindex> + + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<4x8xindex> -> !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l> + xegpu.prefetch %tdesc: !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l> + %delta = arith.constant dense<[ + [32, 32, 32, 32, 32, 32, 32, 32], + [32, 32, 32, 32, 32, 32, 32, 64], + [128, 128, 128, 128, 128, 128, 128, 128], + [128, 128, 128, 128, 128, 128, 128, 256] + ]> : vector<4x8xindex> + %new_tdesc = xegpu.update_offset %tdesc, %delta + : !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l>, vector<4x8xindex> + + %c4 = arith.constant 4: index + %mask = vector.create_mask %c4, %c4: vector<4x8xi1> + + %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l>, vector<4x8xi1> -> vector<4x8x4xf32> + + %st_vec = arith.addf %ld_vec, %ld_vec {layout_result_0 = #l} : vector<4x8x4xf32> + xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: + vector<4x8x4xf32>, + !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l>, + vector<4x8xi1> + gpu.return + } +} diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index c84eb74198544..f71fcf7ca297b 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -107,10 +107,7 @@ struct TestXeGPUUnrollingPatterns // If the encoding is a ScatterTensorDescAttr, we need to // potentially adjust the chunk size based on the inst_data. if (tdescTy.isScattered()) { - auto scatterAttr = - llvm::dyn_cast_if_present( - encoding); - int64_t chunkSize = scatterAttr.getChunkSize().getInt(); + int64_t chunkSize = tdescTy.getChunkSizeAsInt(); if (chunkSize > 1) { int64_t blockedChunkSize = chunkSize; @@ -120,8 +117,7 @@ struct TestXeGPUUnrollingPatterns // To create a new attribute with a different chunk_size: auto newEncoding = xegpu::ScatterTensorDescAttr::get( - ctx, scatterAttr.getMemorySpace().getValue(), - blockedChunkSize); + ctx, tdescTy.getMemorySpace(), blockedChunkSize); encoding = newEncoding; }