From 2a1d373a61ca10bca9064a2afa7ac1fb88a87fc8 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 10 Apr 2025 18:45:30 +0000 Subject: [PATCH 1/7] Switch to 1D representation for SIMT --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 17 +- .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 3 +- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 26 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 227 +++++++++++------- mlir/test/Dialect/XeGPU/invalid.mlir | 100 ++------ mlir/test/Dialect/XeGPU/ops.mlir | 162 ++++++------- 6 files changed, 250 insertions(+), 285 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 16a7f63d60c82..9af6eaf69aec3 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -833,16 +833,14 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>] data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS also requires A and B to be loaded with the required data layout. Specially, - VNNI layout is required for B operand. It is achieved via adding `packed` attribute to the `load_nd` operator. Due to the VNNI transformation, B operands can be represented as a 3D vector, with the last dimension representing the VNNI factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>` can be represented as `B: vector<8x16x2xf16>`. - In SIMT mode, DpasOp expects layout attributes `a`, `b`, and `c` (only if acc is used) - which describe the data fragment owned by each work-item w.r.t. the tensor descriptor - these data are loaded from. + In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result, + which are represented as 1D vectors. 
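+    For example, assuming a subgroup size of 16 (illustrative only; the exact
+    fragment distribution is target dependent), the SIMT form of an `8x16x16`
+    dpas operates on per-lane 1D fragments:
+    `%d = xegpu.dpas %a, %b : vector<8xf16>, vector<16xf16> -> vector<8xf32>`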
Note: on PVC, the hardware can perform load with VNNI transformation when data element type is 16-bit or lower precision, taking 2 or 4 elements from @@ -850,13 +848,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>] }]; let arguments = (ins - XeGPU_DpasOpType : $lhs, - XeGPU_DpasOpType : $rhs, - Optional: $acc, - OptionalAttr:$a_layout, - OptionalAttr:$b_layout, - OptionalAttr:$c_layout); - let results = (outs XeGPU_Vector2DType: $result); + XeGPU_DpasOprType : $lhs, + XeGPU_DpasOprType : $rhs, + Optional: $acc); + let results = (outs XeGPU_DpasResType: $result); let extraClassDeclaration = [{ VectorType getLhsType() { diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 173f1462fdd73..3cb71788a15ef 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -17,7 +17,8 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>; def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>; def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>; -def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>; +def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>; +def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>; def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>; def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>; def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 171a15ce27b59..269e445c3790c 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" +#include namespace mlir { namespace xegpu { @@ -336,19 +337,20 @@ LogicalResult TensorDescType::verify( // [n_distribution_units, lane_data_size] FailureOr TensorDescType::getDistributedVectorType() { auto layout = llvm::dyn_cast_if_present(getLayout()); - // If no layout is provided, tensor desc is not used in SIMT mode. - if (!layout) + // It only works for subgroup level layout, which only has lane_layout + // and lane_data, and is to distribute a SIMD code into SIMT code. + if (!layout || !layout.isSgLayout()) return failure(); SmallVector laneData(layout.getLaneData().asArrayRef()); SmallVector laneLayout(layout.getLaneLayout().asArrayRef()); auto tdescShape = getShape(); - auto laneDataSize = 1, sgSize = 1; - for (auto [laneDim, laneDataDim] : llvm::zip_equal(laneLayout, laneData)) { - laneDataSize *= laneDataDim; - sgSize *= laneDim; - } + // compute sgSize by multiply elements of laneLayout + // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1] + // e.g. for 1D layout, sgSize = laneLayout[0] + auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1, + std::multiplies()); // Case 1: regular loads/stores auto scatterAttr = getEncodingAsScatterTensorDescAttr(); @@ -356,12 +358,9 @@ FailureOr TensorDescType::getDistributedVectorType() { auto chunkSize = scatterAttr.getChunkSize().getInt(); // Verify if the first dimension of the tensor descriptor shape is // distributable. 
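+    // E.g., a scattered tensor_desc<16x2xf32> with chunk_size = 2 and
+    // lane_layout = [16, 1] distributes to a vector<2xf32> fragment per lane.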
- assert(tdescShape[0] % (laneLayout[0]) == 0 && + assert(tdescShape[0] == laneLayout[0] && "tensor descriptor shape is not distributable"); - if (chunkSize > 1) - return VectorType::get({chunkSize / laneDataSize, laneDataSize}, - getElementType()); - return VectorType::get({laneDataSize}, getElementType()); + return VectorType::get({chunkSize}, getElementType()); } // Case 2: block loads/stores @@ -376,8 +375,7 @@ FailureOr TensorDescType::getDistributedVectorType() { // tensorSize must be adjusted for array_length. tensorSize *= getArrayLength(); - return VectorType::get({tensorSize / (sgSize * laneDataSize), laneDataSize}, - getElementType()); + return VectorType::get({tensorSize / sgSize}, getElementType()); } } // namespace xegpu diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 0d67e3d70f945..fef39508c3bfe 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -73,38 +73,6 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) { kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; } -// Helper to validate value shape of LoadNd and StoreNd ops. -static LogicalResult -isArgShapesValid(TensorDescType tdescTy, VectorType valueTy, - ArrayRef adjustedTdescShape, - function_ref emitError) { - auto layout = tdescTy.getLayoutAttr(); - auto valueShape = valueTy.getShape(); - // layout not present means IR is in SIMD mode. In this case value shape must - // match adjusted tensor descriptor shape. - if (!layout) - return valueShape == adjustedTdescShape - ? success() - : emitError() - << "Value shape " << makeString(valueShape) - << " is not consistent with tensor descriptor " << tdescTy; - - // layout present means IR is in SIMT mode. In this case layout determines the - // value shape. - auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType(); - assert(succeeded(expectedValueShapeOrFailure) && - "Failed to compute distributed vector shape for " - "tensor descriptor "); - - return valueTy == expectedValueShapeOrFailure.value() - ? success() - : emitError() - << "Result shape " << makeString(valueShape) - << " is not consistent with distributed vector shape " - << makeString(expectedValueShapeOrFailure.value().getShape()) - << " for tensor descriptor " << tdescTy; -} - // Checks if the given shape is evenly distributed based on the layout // and data factors provided by the LayoutAttr. The function ensures that // each dimension of the shape can be evenly divided by the corresponding @@ -302,9 +270,35 @@ LogicalResult LoadNdOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); + // Handling a 1D vector as the result can be complex. It may represent the + // outcome of a 1D block load in SIMD mode or a fragment of a block load + // result in SIMT mode. In the latter case, the tensor descriptor must be + // evenly distributed, with each lane holding an equally sized fragment of + // the result. Only subgroup size 8 or 16 is supported. + if (valueTy.getRank() == 1 && + valueTy.getNumElements() < tdescTy.getNumElements()) { + // SIMT mode doesn't need LayoutAttr. + if (tdescTy.getLayoutAttr()) + return emitOpError() + << "TensorDesc doesn't need LayoutAttr for SIMT code"; + + int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength(); + int valueElems = valueTy.getNumElements(); + + int lanes = tdescElems % valueElems == 0 ? 
tdescElems / valueElems : -1; + if (lanes != 16 && lanes != 8) { + return emitOpError() + << "Result shape " << makeString(getShapeOf(valueTy)) + << " is not a valid distribution for tensor descriptor " + << tdescTy; + } + return success(); + } + + // Check SIMD mode. auto array_len = tdescTy.getArrayLength(); // adjusted tensor descriptor shape tracks the expected shape of the result. - auto adjustedTdescShape = getShapeOf(tdescTy); + auto tdescShape = getShapeOf(tdescTy); auto valueShape = getShapeOf(valueTy); if (getTranspose()) { @@ -316,7 +310,7 @@ LogicalResult LoadNdOp::verify() { }); if (valid) - transpose(trans, adjustedTdescShape); + transpose(trans, tdescShape); else mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored."; } @@ -325,8 +319,8 @@ LogicalResult LoadNdOp::verify() { if (tdescTy.getRank() == 2) { const int axis = 0; auto vnni_factor = valueShape.back(); - adjustedTdescShape[axis] /= vnni_factor; - adjustedTdescShape.push_back(vnni_factor); + tdescShape[axis] /= vnni_factor; + tdescShape.push_back(vnni_factor); } else { mlir::emitWarning(getLoc()) << "Invalid Packed Attr. It is ignored (available for 2D " @@ -335,12 +329,16 @@ LogicalResult LoadNdOp::verify() { } if (array_len > 1) { - auto it = adjustedTdescShape.begin(); - adjustedTdescShape.insert(it, array_len); + tdescShape.insert(tdescShape.begin(), array_len); + } + + if (tdescShape != valueShape) { + return emitOpError() << "Result shape " << makeString(valueShape) + << " is not consistent with tensor descriptor " + << tdescTy; } - return isArgShapesValid(tdescTy, valueTy, adjustedTdescShape, - [&]() { return emitOpError(); }); + return success(); } //===----------------------------------------------------------------------===// @@ -371,8 +369,37 @@ LogicalResult StoreNdOp::verify() { auto tdescShape = getShapeOf(dstTy); auto valueShape = getShapeOf(valTy); - return isArgShapesValid(dstTy, valTy, tdescShape, - [&]() { return emitOpError(); }); + // Similar to LoadNdOp, handling a 1D vector as the value can be complex. It + // may represent the input of a 1D block store in SIMD mode or a fragment of + // a block store input in SIMT mode. In the latter case, the tensor descriptor + // must be evenly distributed, with each lane holding an equally sized + // fragment of the input. Only subgroup size 8 or 16 is supported. + if (valTy.getRank() == 1 && valTy.getNumElements() < dstTy.getNumElements()) { + // SIMT mode doesn't need LayoutAttr. + if (dstTy.getLayoutAttr()) + return emitOpError() + << "TensorDesc doesn't need LayoutAttr for SIMT code"; + + int tdescElems = dstTy.getNumElements() * dstTy.getArrayLength(); + int valueElems = valueShape[0]; + + int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1; + if (lanes != 16 && lanes != 8) { + return emitOpError() + << "Value shape " << makeString(getShapeOf(valTy)) + << " is not a valid distribution for tensor descriptor " << dstTy; + } + return success(); + } + + // SIMD code should have the same shape as the tensor descriptor. 
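+  // E.g., storing a vector<24x32xf16> value requires a matching
+  // tensor_desc<24x32xf16>.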
+ if (tdescShape != valueShape) { + return emitOpError() << "Value shape " << makeString(valueShape) + << " is not consistent with tensor descriptor " + << dstTy; + } + + return success(); } //===----------------------------------------------------------------------===// @@ -520,14 +547,41 @@ LogicalResult LoadGatherOp::verify() { if (tdescShape[0] != maskShape[0]) return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); + auto chunkSize = tdescTy.getChunkSize(); + // for SIMT code, the value should be 1D vector with size of chunkSize. + if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) { + if (valueTy.getNumElements() != chunkSize) { + return emitOpError() + << "Result shape " << makeString(valueShape) + << " is not a valid distribution for tensor descriptor " + << tdescTy; + } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr. + if (tdescTy.getLayoutAttr()) + return emitOpError() + << "TensorDesc doesn't need LayoutAttr for SIMT code"; + if (getTransposeAttr()) + return emitOpError() << "doesn't need TransposeAttr for SIMT code"; + } + return success(); + } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) { + // for 1D vector and valueTy.getNumElements() == tdescShape[0] case, + // it is a valid SIMT code if chunkSize happens to be the same as + // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16> + return success(); + } + + // For SIMD code verification. if (tdescTy.getRank() == 2) { if (!getTransposeAttr()) return emitOpError("load of rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); } - return isArgShapesValid(tdescTy, valueTy, tdescShape, - [&]() { return emitOpError(); }); + if (tdescShape != valueShape) + return emitOpError() << "Result shape " << makeString(valueShape) + << " is not consistent with tensor descriptor " + << tdescTy; + return success(); } //===----------------------------------------------------------------------===// @@ -559,14 +613,42 @@ LogicalResult StoreScatterOp::verify() { if (tdescShape[0] != maskShape[0]) return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); + auto chunkSize = tdescTy.getChunkSize(); + // for SIMT code, the value should be 1D vector with size of chunkSize. + if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) { + if (valueTy.getNumElements() != chunkSize) { + return emitOpError() + << "Value shape " << makeString(valueShape) + << " is not a valid distribution for tensor descriptor " + << tdescTy; + } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr. + if (tdescTy.getLayoutAttr()) + return emitOpError() + << "TensorDesc doesn't need LayoutAttr for SIMT code"; + if (getTransposeAttr()) + return emitOpError() << "doesn't need TransposeAttr for SIMT code"; + } + return success(); + } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) { + // for 1D vector and valueTy.getNumElements() == tdescShape[0] case, + // it is a valid SIMT code if chunkSize happens to be the same as + // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16> + return success(); + } + + // for SIMD code verification. 
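+  // E.g., with the transpose attribute, a vector<2x4xf32> value matches a
+  // tensor_desc<4x2xf32> with chunk_size = 2, since the tdesc shape is
+  // transposed before the comparison below.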
if (tdescTy.getRank() == 2) { if (!getTransposeAttr()) return emitOpError("Store of a rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); } - return isArgShapesValid(tdescTy, valueTy, tdescShape, - [&]() { return emitOpError(); }); + if (tdescShape != valueShape) + return emitOpError() << "Value shape " << makeString(valueShape) + << " is not consistent with tensor descriptor " + << tdescTy; + + return success(); } //===----------------------------------------------------------------------===// @@ -602,51 +684,16 @@ LogicalResult DpasOp::verify() { auto rhsShape = getRhsType().getShape(); auto resShape = getResultType().getShape(); - auto aLayout = getALayoutAttr(); - auto bLayout = getBLayoutAttr(); - auto cLayout = getCLayoutAttr(); - - // make sure the layout attribute is either set for every available - // operand or simply not set at all. C is special, since ACC is optional. - auto hasValidLayoutAttrs = [&]() { - bool result = (aLayout != nullptr) ^ (bLayout != nullptr); - if (hasAcc()) { - result |= (aLayout != nullptr) ^ (cLayout != nullptr); - } - return !result; - }; + if (getAcc()) { + if (getAcc().getType() != getResultType()) + return emitOpError("Expecting the acc type to be the same as result."); + } - if (!hasValidLayoutAttrs()) - return emitOpError( - "layout attributes should be either set for all operands (for SIMT " - "code) or not set at all (for SIMD code)."); - - // query the scope from aLayout (a valid setting). - if (aLayout) { - // In SIMT mode, All data fragments must be 2D - if (lhsRank != 2 || rhsRank != 2 || resRank != 2) - return emitOpError("expecting lhs, rhs, and result to be a 2D vector."); - - auto laneLayoutA = aLayout.getLaneLayout(); - auto laneLayoutB = bLayout.getLaneLayout(); - auto laneLayoutC = cLayout.getLaneLayout(); - // Obtain the expanded shapes of the operands and result using lane_layout. - // NOTE: For B, get rid of the packed dimension for the expanded shape. - SmallVector expandedShapeA = {lhsShape[0] * laneLayoutA[0], - lhsShape[1] * laneLayoutA[1]}; - SmallVector expandedShapeB = { - rhsShape[0] * rhsShape[1] * laneLayoutB[0], 1 * laneLayoutB[1]}; - SmallVector expandedShapeC = {resShape[0] * laneLayoutC[0], - resShape[1] * laneLayoutC[1]}; - auto bK = expandedShapeB[0]; - if (bK != expandedShapeA[1]) - return emitOpError("K-dimension mismatch."); - if (expandedShapeA[0] != expandedShapeC[0]) - return emitOpError("M-dimension mismatch."); - if (expandedShapeB[1] != expandedShapeC[1]) - return emitOpError("N-dimension mismatch."); - } else { // For other scopes, operands' shape should match the mxkxn - // semantics. + // SIMT code: skip the check since lack of semantic info at this level. + // Users need to ensure the correctness. 
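+  // E.g., with a subgroup size of 16, the per-lane fragments of an 8x16x16
+  // dpas are lhs: vector<8xf16>, rhs: vector<16xf16>, and result:
+  // vector<8xf32>; such 1D shapes are accepted without further checking.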
+ if (lhsRank == 1 && rhsRank == 1 && resRank == 1) { + return success(); + } else { // SIMD code if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2) return emitOpError( "expecting lhs and result to be a 2D vector, and rhs to be either " diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 48df33a591908..c0739d735dfec 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -79,25 +79,10 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) { // ----- func.func @test_load_nd_layout(%src: memref<24x32xf32>) { - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - // expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1]}} - %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - -> vector<8x2xf32> - return -} - -// ----- -func.func @test_load_nd_layout(%src: memref<24x32xf32>) { - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - !xegpu.tensor_desc<16xf32, #xegpu.layout> - // expected-error@+1 {{Result shape [8] is not consistent with distributed vector shape [1, 1]}} + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32> + // expected-error@+1 {{Result shape [8] is not a valid distribution for tensor descriptor}} %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<16xf32, #xegpu.layout> - -> vector<8xf32> + l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32> -> vector<8xf32> return } @@ -105,7 +90,7 @@ func.func @test_load_nd_layout(%src: memref<24x32xf32>) { func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - // expected-error@+1 {{Value shape [8, 1] is not consistent with tensor descriptor}} + // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}} %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32> @@ -134,22 +119,10 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) { } // ----- -func.func @test_store_nd_layout(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) { - %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> - !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - // expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1] for tensor descriptor}} - xegpu.store_nd %data, %1 - : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - return -} - -// ----- -func.func @test_store_nd_layout(%dst: memref<24x32xf32>, %data: vector<2xf32>) { - %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> - !xegpu.tensor_desc<16xf32, #xegpu.layout> - // expected-error@+1 {{Result shape [2] is not consistent with distributed vector shape [1, 1] for tensor descriptor}} - xegpu.store_nd %data, %1 - : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout> +func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<4xf32>) { + %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32> + // expected-error@+1 {{Value shape [4] is not a valid distribution for tensor descriptor}} + xegpu.store_nd %data, %1 : vector<4xf32>, !xegpu.tensor_desc<16xf32> return } @@ -269,45 +242,23 @@ func.func @test_create_tdesc_layout_3(%src: ui64) 
{ } // ----- -func.func @test_load_gather_layout_1(%src: ui64) { +func.func @test_load_gather_simt_1(%src: ui64) { %0 = arith.constant dense<1>: vector<4xi1> %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - // expected-error@+1 {{Result shape [1, 2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}} - %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> -> vector<1x2xf32> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + // expected-error@+1 {{Result shape [6] is not a valid distribution for tensor descriptor}} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<6xf32> return } // ----- -func.func @test_load_gather_layout_2(%src: ui64) { +func.func @test_store_scatter_simt_1(%src: ui64) { %0 = arith.constant dense<1>: vector<4xi1> %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - // expected-error@+1 {{esult shape [2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}} - %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> -> vector<2xf32> - return -} - - -// ----- -func.func @test_store_scatter_layout_1(%src: ui64) { - %0 = arith.constant dense<1>: vector<4xi1> - %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %val = arith.constant dense<2.9>: vector<1x2xf32> - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - // expected-error@+1 {{Result shape [1, 2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}} - xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> - return -} - -// ----- -func.func @test_store_scatter_layout_2(%src: ui64) { - %0 = arith.constant dense<1>: vector<4xi1> - %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %val = arith.constant dense<2.9>: vector<2xf32> - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - // expected-error@+1 {{esult shape [2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}} - xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> + %val = arith.constant dense<2.9>: vector<6xf32> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + // expected-error@+1 {{Value shape [6] is not a valid distribution for tensor descriptor}} + xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint}> : vector<6xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> return } @@ -393,23 +344,6 @@ func.func @test_dpas_4(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) { return } -// ----- -func.func 
@test_dpas_layout_1(%a : vector<8x1xf16>, %b: vector<8x2xf16>) { - // expected-error@+1 {{layout attributes should be either set for all operands (for SIMT code) or not set at all (for SIMD code)}} - %1 = xegpu.dpas %a, %b {a_layout = #xegpu.layout} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32> - return -} - -// ----- -func.func @test_dpas_layout_2(%a : vector<8x1xf16>, %b: vector<4x2xf16>) { - // expected-error@+1 {{K-dimension mismatch}} - %1 = xegpu.dpas %a, %b {a_layout = #xegpu.layout, - b_layout = #xegpu.layout, - c_layout = #xegpu.layout} - : vector<8x1xf16>, vector<4x2xf16> -> vector<8x1xf32> - return -} - // ----- func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) { %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index e9895e0d0a71d..71e7e9bdda07d 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -125,11 +125,11 @@ gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) { // CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<8x16xf16>) { gpu.func @test_load_nd_simt(%src: memref<8x16xf16>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<8x16xf16, #xegpu.layout> -> vector<4x2xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16> %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<8x16xf16, #xegpu.layout> -> vector<4x2xf16> + : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16> gpu.return } @@ -144,10 +144,10 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) { // CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<8x16xf16>) { gpu.func @test_load_nd_simt_2(%src: memref<8x16xf16>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.layout> -> vector<1x1xf16> - %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.layout> -> vector<1x1xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16> -> vector<1xf16> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16> -> vector<1xf16> gpu.return } @@ -162,11 +162,10 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) { // 
CHECK: func @test_load_nd_simt_3(%[[arg0:.*]]: memref<24x32xf32>) { gpu.func @test_load_nd_simt_3(%src: memref<24x32xf32>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.layout> -> vector<8x1xf32> - %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.layout> -> vector<8x1xf32> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> gpu.return } @@ -181,11 +180,10 @@ gpu.func @test_load_nd_vc_4(%src: memref<24x32xf16>) { // CHECK: func @test_load_nd_simt_4(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @test_load_nd_simt_4(%src: memref<24x32xf16>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> - !xegpu.tensor_desc<16x16xf16, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout> -> vector<8x2xf16> - %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout> -> vector<8x2xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16> gpu.return } @@ -200,11 +198,10 @@ gpu.func @test_load_nd_vc_5(%src: memref<24x32xf32>) { // CHECK: func @test_load_nd_simt_5(%[[arg0:.*]]: memref<24x32xf32>) { gpu.func @test_load_nd_simt_5(%src: memref<24x32xf32>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - !xegpu.tensor_desc<32xf32, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32, #xegpu.layout> -> vector<2x1xf32> - %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32, #xegpu.layout> -> vector<2x1xf32> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : 
!xegpu.tensor_desc<32xf32> -> vector<2xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32> -> vector<2xf32> gpu.return } @@ -219,11 +216,11 @@ gpu.func @test_load_nd_vc_6(%src: memref<24x32xf16>) { // CHECK: func @test_load_nd_simt_6(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @test_load_nd_simt_6(%src: memref<24x32xf16>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> -> vector<32x1xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<32xf16> %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : - !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> -> vector<32x1xf16> + !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<32xf16> gpu.return } @@ -238,11 +235,11 @@ gpu.func @test_load_nd_vc_7(%src: memref<24x32xf16>) { // CHECK: func @test_load_nd_simt_7(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @test_load_nd_simt_7(%src: memref<24x32xf16>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> -> vector<16x2xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<32xf16> %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : - !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr, #xegpu.layout> -> vector<16x2xf16> + !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<32xf16> gpu.return } @@ -257,10 +254,10 @@ gpu.func @test_load_nd_vc_8(%src: memref<24x32xf32>) { // CHECK: func @test_load_nd_simt_8(%[[arg0:.*]]: memref<24x32xf32>) { gpu.func @test_load_nd_simt_8(%src: memref<24x32xf32>) { - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.layout> - %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.layout> - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : 
!xegpu.tensor_desc<16x8xf32, #xegpu.layout> -> vector<8x1xf32> - %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32, #xegpu.layout> -> vector<8x1xf32> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32> gpu.return } @@ -277,13 +274,12 @@ gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) { // CHECK: func @test_store_nd_simt(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @test_store_nd_simt(%src: memref<24x32xf16>) { - // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16> - %1 = arith.constant dense<1.0>: vector<48x1xf16> - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.layout> - %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> - !xegpu.tensor_desc<24x32xf16, #xegpu.layout> - // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.layout> - xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.layout> + // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48xf16> + %1 = arith.constant dense<1.0>: vector<48xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16> + %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16> + // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<48xf16>, !xegpu.tensor_desc<24x32xf16> + xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<48xf16>, !xegpu.tensor_desc<24x32xf16> gpu.return } @@ -303,13 +299,12 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) { // CHECK: func @test_store_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @test_store_nd_simt_2(%src: memref<24x32xf16>) { - // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2x1xf16> - %1 = arith.constant dense<1.0>: vector<2x1xf16> - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.layout> - %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> - !xegpu.tensor_desc<32xf16, #xegpu.layout> - // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.layout> - xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.layout> + // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16> + %1 = arith.constant dense<1.0>: vector<2xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16> + %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16> + // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = 
#xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf16>, !xegpu.tensor_desc<32xf16> + xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<2xf16>, !xegpu.tensor_desc<32xf16> gpu.return } @@ -425,10 +420,10 @@ gpu.func @test_load_simt(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> %1 = arith.constant dense<1>: vector<4xi1> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> -> vector<2x1xf32> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> -> vector<2x1xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<2xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<2xf32> gpu.return } @@ -451,10 +446,10 @@ gpu.func @test_load_simt_2(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> %1 = arith.constant dense<1>: vector<4xi1> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> - %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> - //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<4xi1> -> vector<1xf32> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<4xi1> -> vector<1xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<1xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<1xf32> gpu.return } @@ -477,10 +472,10 @@ gpu.func @test_load_simt_3(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> //CHECK: 
%[[cst1:.*]] = arith.constant dense : vector<4xi1> %1 = arith.constant dense<1>: vector<4xi1> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout> - %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout> - //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> -> vector<4x2xf16> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> -> vector<4x2xf16> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<8xf16> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<8xf16> gpu.return } @@ -507,12 +502,12 @@ gpu.func @test_store_simt(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> %1 = arith.constant dense<1>: vector<4xi1> - //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32> - %2 = arith.constant dense<2.9>: vector<2x1xf32> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> + //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32> + %2 = arith.constant dense<2.9>: vector<2xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> gpu.return } @@ -539,12 +534,12 @@ gpu.func @test_store_simt_2(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> %1 = arith.constant dense<1>: vector<4xi1> - //CHECK: 
%[[cst2:.*]] = arith.constant {{.*}} : vector<1x2xf16> - %2 = arith.constant dense<2.9>: vector<1x2xf16> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout> - %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout> - //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<4xi1> + //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2xf16> + %2 = arith.constant dense<2.9>: vector<2xf16> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> gpu.return } @@ -572,10 +567,10 @@ gpu.func @test_store_simt_3(%src: ui64) { %1 = arith.constant dense<1>: vector<4xi1> //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> %2 = arith.constant dense<2.9>: vector<1xf32> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> - %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> - //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<4xi1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<4xi1> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> gpu.return } @@ -635,15 +630,10 @@ gpu.func @test_dpas_vc(%a : vector<8x16xf16>, %b: vector<16x16xf16>) { gpu.return } -// CHECK: gpu.func @test_dpas_simt(%[[arg0:.*]]: vector<8x1xf16>, %[[arg1:.*]]: vector<8x2xf16>) -gpu.func @test_dpas_simt(%a : vector<8x1xf16>, %b: vector<8x2xf16>) { - // CHECK: xegpu.dpas %[[arg0]], %[[arg1]] {a_layout = #xegpu.layout, - // CHECK: b_layout = #xegpu.layout, - // CHECK: 
c_layout = #xegpu.layout} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32> - %1 = xegpu.dpas %a, %b {a_layout = #xegpu.layout, - b_layout = #xegpu.layout, - c_layout = #xegpu.layout} - : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32> +// CHECK: gpu.func @test_dpas_simt(%[[arg0:.*]]: vector<8xf16>, %[[arg1:.*]]: vector<16xf16>) +gpu.func @test_dpas_simt(%a : vector<8xf16>, %b: vector<16xf16>) { + // CHECK: xegpu.dpas %[[arg0]], %[[arg1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32> + %1 = xegpu.dpas %a, %b : vector<8xf16>, vector<16xf16> -> vector<8xf32> gpu.return } From 2159119977dfb62c11d808777529dd34ed0abd43 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 10 Apr 2025 20:25:00 +0000 Subject: [PATCH 2/7] refine verfier for load_nd and store_nd --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 4 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 53 +++++++++---------- mlir/test/Dialect/XeGPU/invalid.mlir | 19 +++++-- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 9af6eaf69aec3..5fa18754305ca 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -840,7 +840,9 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>] can be represented as `B: vector<8x16x2xf16>`. In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result, - which are represented as 1D vectors. + which are represented as 1D vectors. Please refer to [OpenCL Intel extentions] + (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html) + for more details about the fragment distribution. Note: on PVC, the hardware can perform load with VNNI transformation when data element type is 16-bit or lower precision, taking 2 or 4 elements from diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index fef39508c3bfe..1dafc9936107e 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -270,33 +270,31 @@ LogicalResult LoadNdOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - // Handling a 1D vector as the result can be complex. It may represent the - // outcome of a 1D block load in SIMD mode or a fragment of a block load - // result in SIMT mode. In the latter case, the tensor descriptor must be - // evenly distributed, with each lane holding an equally sized fragment of - // the result. Only subgroup size 8 or 16 is supported. - if (valueTy.getRank() == 1 && - valueTy.getNumElements() < tdescTy.getNumElements()) { + int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength(); + int valueElems = valueTy.getNumElements(); + + // If the result vector is 1D and has less elements than the tensor + // descriptor, it is supposed to be a SIMT op. The layout attribute in + // tensor_desc is not needed. + if (valueElems < tdescElems && valueTy.getRank() == 1) { // SIMT mode doesn't need LayoutAttr. if (tdescTy.getLayoutAttr()) return emitOpError() << "TensorDesc doesn't need LayoutAttr for SIMT code"; - int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength(); - int valueElems = valueTy.getNumElements(); - - int lanes = tdescElems % valueElems == 0 ? 
tdescElems / valueElems : -1; - if (lanes != 16 && lanes != 8) { + // For SIMT code, the load is evenly distributed across all lanes in a + // subgroup. Since subgroup size is arch dependent, we only check even + // distribution here. + if (tdescElems % valueElems) return emitOpError() << "Result shape " << makeString(getShapeOf(valueTy)) << " is not a valid distribution for tensor descriptor " << tdescTy; - } + return success(); } // Check SIMD mode. - auto array_len = tdescTy.getArrayLength(); // adjusted tensor descriptor shape tracks the expected shape of the result. auto tdescShape = getShapeOf(tdescTy); auto valueShape = getShapeOf(valueTy); @@ -328,6 +326,7 @@ LogicalResult LoadNdOp::verify() { } } + auto array_len = tdescTy.getArrayLength(); if (array_len > 1) { tdescShape.insert(tdescShape.begin(), array_len); } @@ -366,25 +365,23 @@ LogicalResult StoreNdOp::verify() { if (!isWriteHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - auto tdescShape = getShapeOf(dstTy); - auto valueShape = getShapeOf(valTy); + auto array_len = dstTy.getArrayLength(); + if (array_len > 1) + return emitOpError("array length is not supported by store_nd.\n"); + + auto tdescElems = dstTy.getNumElements(); + auto valueElems = valTy.getNumElements(); - // Similar to LoadNdOp, handling a 1D vector as the value can be complex. It - // may represent the input of a 1D block store in SIMD mode or a fragment of - // a block store input in SIMT mode. In the latter case, the tensor descriptor - // must be evenly distributed, with each lane holding an equally sized - // fragment of the input. Only subgroup size 8 or 16 is supported. - if (valTy.getRank() == 1 && valTy.getNumElements() < dstTy.getNumElements()) { + // Similar to LoadNdOp, if the value vector is 1D and has less elements than + // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute + // in tensor_desc is not needed. + if (valTy.getRank() == 1 && valueElems < tdescElems) { // SIMT mode doesn't need LayoutAttr. if (dstTy.getLayoutAttr()) return emitOpError() << "TensorDesc doesn't need LayoutAttr for SIMT code"; - int tdescElems = dstTy.getNumElements() * dstTy.getArrayLength(); - int valueElems = valueShape[0]; - - int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1; - if (lanes != 16 && lanes != 8) { + if (tdescElems % valueElems) { return emitOpError() << "Value shape " << makeString(getShapeOf(valTy)) << " is not a valid distribution for tensor descriptor " << dstTy; @@ -393,6 +390,8 @@ LogicalResult StoreNdOp::verify() { } // SIMD code should have the same shape as the tensor descriptor. 
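+  // For the SIMT branch above: e.g., vector<3xf32> is not a valid
+  // distribution of tensor_desc<16xf32> (16 is not divisible by 3), whereas
+  // vector<2xf32> or vector<4xf32> would be accepted.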
+ auto tdescShape = getShapeOf(dstTy); + auto valueShape = getShapeOf(valTy); if (tdescShape != valueShape) { return emitOpError() << "Value shape " << makeString(valueShape) << " is not consistent with tensor descriptor " diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index c0739d735dfec..a02427b6e317b 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -80,9 +80,9 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) { // ----- func.func @test_load_nd_layout(%src: memref<24x32xf32>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32> - // expected-error@+1 {{Result shape [8] is not a valid distribution for tensor descriptor}} + // expected-error@+1 {{Result shape [3] is not a valid distribution for tensor descriptor}} %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32> -> vector<8xf32> + l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32> -> vector<3xf32> return } @@ -119,10 +119,19 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) { } // ----- -func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<4xf32>) { +func.func @test_store_nd_vc_3(%dst: memref<24x32xf16>) { + %1 = arith.constant dense<1.0>: vector<2x24x32xf16> + %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr> + // expected-error@+1 {{array length is not supported by store_nd}} + xegpu.store_nd %1, %2: vector<2x24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr> + return +} + +// ----- +func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<3xf32>) { %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32> - // expected-error@+1 {{Value shape [4] is not a valid distribution for tensor descriptor}} - xegpu.store_nd %data, %1 : vector<4xf32>, !xegpu.tensor_desc<16xf32> + // expected-error@+1 {{Value shape [3] is not a valid distribution for tensor descriptor}} + xegpu.store_nd %data, %1 : vector<3xf32>, !xegpu.tensor_desc<16xf32> return } From 775d039bb7a5ba9fd91939411e2d69312879f1e0 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 15 Apr 2025 18:46:55 +0000 Subject: [PATCH 3/7] refine verifier for gather/scatter --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 62 +++++++++----------------- mlir/test/Dialect/XeGPU/invalid.mlir | 4 +- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 1dafc9936107e..f5205c5e7e5bc 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -547,30 +547,18 @@ LogicalResult LoadGatherOp::verify() { return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); auto chunkSize = tdescTy.getChunkSize(); - // for SIMT code, the value should be 1D vector with size of chunkSize. - if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) { - if (valueTy.getNumElements() != chunkSize) { + + // a valid shape for SIMT case + if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { + if (tdescTy.getLayoutAttr()) return emitOpError() - << "Result shape " << makeString(valueShape) - << " is not a valid distribution for tensor descriptor " - << tdescTy; - } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr. 
- if (tdescTy.getLayoutAttr()) - return emitOpError() - << "TensorDesc doesn't need LayoutAttr for SIMT code"; - if (getTransposeAttr()) - return emitOpError() << "doesn't need TransposeAttr for SIMT code"; - } - return success(); - } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) { - // for 1D vector and valueTy.getNumElements() == tdescShape[0] case, - // it is a valid SIMT code if chunkSize happens to be the same as - // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16> + << "TensorDesc doesn't need LayoutAttr for SIMT code"; + if (getTransposeAttr()) + return emitOpError() << "doesn't need TransposeAttr for SIMT code"; return success(); } - // For SIMD code verification. - if (tdescTy.getRank() == 2) { + if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) { if (!getTransposeAttr()) return emitOpError("load of rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); @@ -578,7 +566,8 @@ LogicalResult LoadGatherOp::verify() { if (tdescShape != valueShape) return emitOpError() << "Result shape " << makeString(valueShape) - << " is not consistent with tensor descriptor " + << " is neither a valid distribution for SIMT nor " + "consistent with the tensor descriptor for SIMD " << tdescTy; return success(); } @@ -613,30 +602,18 @@ LogicalResult StoreScatterOp::verify() { return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); auto chunkSize = tdescTy.getChunkSize(); - // for SIMT code, the value should be 1D vector with size of chunkSize. - if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) { - if (valueTy.getNumElements() != chunkSize) { + + // a valid shape for SIMT case + if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { + if (tdescTy.getLayoutAttr()) return emitOpError() - << "Value shape " << makeString(valueShape) - << " is not a valid distribution for tensor descriptor " - << tdescTy; - } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr. - if (tdescTy.getLayoutAttr()) - return emitOpError() - << "TensorDesc doesn't need LayoutAttr for SIMT code"; - if (getTransposeAttr()) - return emitOpError() << "doesn't need TransposeAttr for SIMT code"; - } - return success(); - } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) { - // for 1D vector and valueTy.getNumElements() == tdescShape[0] case, - // it is a valid SIMT code if chunkSize happens to be the same as - // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16> + << "TensorDesc doesn't need LayoutAttr for SIMT code"; + if (getTransposeAttr()) + return emitOpError() << "doesn't need TransposeAttr for SIMT code"; return success(); } - // for SIMD code verification. 
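As a point of reference, a minimal sketch of the SIMT gather/scatter form the refactored checks accept, modeled on test_load_gather_simt_1 but with a valid fragment shape; the offsets and the 4x2 descriptor are illustrative, the rule being a 1D per-lane value of chunk_size elements with no layout or transpose attributes:

```mlir
// Hypothetical valid counterpart to test_load_gather_simt_1: each work-item
// owns a 1D fragment of chunk_size (= 2) elements.
func.func @test_load_store_gather_simt(%src: ui64) {
  %mask = arith.constant dense<1> : vector<4xi1>
  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<4xindex>
      -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
  %frag = xegpu.load %tdesc, %mask : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2xf32>
  xegpu.store %frag, %tdesc, %mask : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
  return
}
```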
- if (tdescTy.getRank() == 2) { + if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) { if (!getTransposeAttr()) return emitOpError("Store of a rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); @@ -644,7 +621,8 @@ LogicalResult StoreScatterOp::verify() { if (tdescShape != valueShape) return emitOpError() << "Value shape " << makeString(valueShape) - << " is not consistent with tensor descriptor " + << " is neither a valid distribution for SIMT nor " + "consistent with the tensor descriptor for SIMD " << tdescTy; return success(); diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index a02427b6e317b..2a7436807f5f4 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -255,7 +255,7 @@ func.func @test_load_gather_simt_1(%src: ui64) { %0 = arith.constant dense<1>: vector<4xi1> %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - // expected-error@+1 {{Result shape [6] is not a valid distribution for tensor descriptor}} + // expected-error@+1 {{Result shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}} %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<6xf32> return } @@ -266,7 +266,7 @@ func.func @test_store_scatter_simt_1(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %val = arith.constant dense<2.9>: vector<6xf32> %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - // expected-error@+1 {{Value shape [6] is not a valid distribution for tensor descriptor}} + // expected-error@+1 {{Value shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}} xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint}> : vector<6xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> return } From 5520ce18138b5153d7ecb874fe10be78127d719e Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 15 Apr 2025 18:59:39 +0000 Subject: [PATCH 4/7] update comments --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 13 +++++++------ mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 1 - 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 269e445c3790c..b865b80f0075e 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -320,21 +320,22 @@ LogicalResult TensorDescType::verify( // --------------------------------------------------------------------- // Case 1: Regular loads/stores. // --------------------------------------------------------------------- -// Distributed vector shape must be: -// [chunk_size / lane_data_size, lane_data_size] -// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1). -// [lane_data_size] +// The following conditions must be met: +// * tensor_desc[0] == lane_layout[0] +// Distributed vector is a 1D vector with shape: +// [chunk_size] // --------------------------------------------------------------------- // Case 2: Block loads/stores // --------------------------------------------------------------------- // Additional definitions: // tensor_size = tensor_desc[0] * .. 
* tensor_desc[r-1] * array_length // n_distribution_units = tensor_size / distribution_unit_size +// fragment_size = n_distribution_units * lane_data_size // Given above definitions, the following conditions must be met: // * tensor_desc[0] % (lane_layout[0] × lane_data[0]) == 0 // * tensor_desc[1] % (lane_layout[1] × lane_data[1]) == 0 -// Distributed vector shape must be: -// [n_distribution_units, lane_data_size] +// Distributed vector is a 1D vector with shape: +// [fragment_size] FailureOr TensorDescType::getDistributedVectorType() { auto layout = llvm::dyn_cast_if_present(getLayout()); // It only works for subgroup level layout, which only has lane_layout diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index f5205c5e7e5bc..4305c0431cc7e 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -295,7 +295,6 @@ LogicalResult LoadNdOp::verify() { } // Check SIMD mode. - // adjusted tensor descriptor shape tracks the expected shape of the result. auto tdescShape = getShapeOf(tdescTy); auto valueShape = getShapeOf(valueTy); From 7072bc1bf5a36613adf5f0cdb201c4dbeb1b81f5 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 17 Apr 2025 15:41:57 +0000 Subject: [PATCH 5/7] refator verifiers for load_gather, store_scatter and dpas --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 144 ++++++++++--------------- mlir/test/Dialect/XeGPU/invalid.mlir | 11 +- 2 files changed, 66 insertions(+), 89 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 4305c0431cc7e..b02490909e067 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -101,6 +101,48 @@ static bool isEvenDistributed(llvm::ArrayRef shape, return true; } +static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, UnitAttr transposeAttr, function_ref emitError) { + + if (!tdescTy.isScattered()) + return emitError() << "Expects a scattered TensorDesc."; + + if (!valueTy) + return emitError() << "Expecting a vector type result."; + + auto maskShape = getShapeOf(maskTy); + auto valueShape = getShapeOf(valueTy); + auto tdescShape = getShapeOf(tdescTy); + auto chunkSize = tdescTy.getChunkSize(); + + if (valueTy.getElementType() != tdescTy.getElementType()) + return emitError() << "Value should have the same element type as TensorDesc."; + + if (tdescShape[0] != maskShape[0]) + return emitError() << "dim-0 of the Mask and TensorDesc should be the same."; + + // a valid shape for SIMT case + if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { + if (tdescTy.getLayoutAttr()) + return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code"; + if (transposeAttr) + return emitError() << "doesn't need TransposeAttr for SIMT code"; + return success(); + } + + if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) { + if (!transposeAttr) + return emitError() << "rank-2 tensor has to be transposed."; + transpose({1, 0}, tdescShape); + } + + if (tdescShape != valueShape) + return emitError() << "Value shape " << makeString(valueShape) + << " is neither a valid distribution for SIMT nor " + "consistent with the tensor descriptor for SIMD " + << tdescTy; + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -517,12 +559,6 @@ LogicalResult 
LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); - if (!valueTy) - return emitOpError("Expecting a vector type result.\n"); - - if (!tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); - if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -532,43 +568,8 @@ LogicalResult LoadGatherOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - auto tdescElemTy = tdescTy.getElementType(); - auto valueElemTy = getElementType(); - if (tdescElemTy != valueElemTy) - return emitOpError( - "Value should have the same element type as TensorDesc."); - - auto maskShape = getShapeOf(maskTy); - auto valueShape = getShapeOf(valueTy); - auto tdescShape = getShapeOf(tdescTy); - - if (tdescShape[0] != maskShape[0]) - return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); - - auto chunkSize = tdescTy.getChunkSize(); - - // a valid shape for SIMT case - if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { - if (tdescTy.getLayoutAttr()) - return emitOpError() - << "TensorDesc doesn't need LayoutAttr for SIMT code"; - if (getTransposeAttr()) - return emitOpError() << "doesn't need TransposeAttr for SIMT code"; - return success(); - } - - if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) { - if (!getTransposeAttr()) - return emitOpError("load of rank-2 tensor has to be transposed."); - transpose({1, 0}, tdescShape); - } - - if (tdescShape != valueShape) - return emitOpError() << "Result shape " << makeString(valueShape) - << " is neither a valid distribution for SIMT nor " - "consistent with the tensor descriptor for SIMD " - << tdescTy; - return success(); + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(), + [&]() { return emitOpError(); }); } //===----------------------------------------------------------------------===// @@ -576,8 +577,8 @@ LogicalResult LoadGatherOp::verify() { //===----------------------------------------------------------------------===// LogicalResult StoreScatterOp::verify() { auto tdescTy = getTensorDescType(); - if (!tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); + auto maskTy = getMaskType(); + auto valueTy = getValueType(); if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -588,43 +589,8 @@ LogicalResult StoreScatterOp::verify() { if (!isWriteHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - auto maskTy = getMaskType(); - auto valueTy = getValueType(); - - if (!valueTy) - return emitOpError("Expecting a vector type for the value.\n"); - - auto maskShape = getShapeOf(maskTy); - auto tdescShape = getShapeOf(tdescTy); - auto valueShape = getShapeOf(valueTy); - if (tdescShape[0] != maskShape[0]) - return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); - - auto chunkSize = tdescTy.getChunkSize(); - - // a valid shape for SIMT case - if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { - if (tdescTy.getLayoutAttr()) - return emitOpError() - << "TensorDesc doesn't need LayoutAttr for SIMT code"; - if (getTransposeAttr()) - return emitOpError() << "doesn't need TransposeAttr for SIMT code"; - return success(); - } - - if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) { - if (!getTransposeAttr()) - return emitOpError("Store of a rank-2 tensor has to be transposed."); - transpose({1, 
0}, tdescShape); - } - - if (tdescShape != valueShape) - return emitOpError() << "Value shape " << makeString(valueShape) - << " is neither a valid distribution for SIMT nor " - "consistent with the tensor descriptor for SIMD " - << tdescTy; - - return success(); + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(), + [&]() { return emitOpError(); }); } //===----------------------------------------------------------------------===// @@ -660,14 +626,18 @@ LogicalResult DpasOp::verify() { auto rhsShape = getRhsType().getShape(); auto resShape = getResultType().getShape(); - if (getAcc()) { - if (getAcc().getType() != getResultType()) - return emitOpError("Expecting the acc type to be the same as result."); - } + if (getAcc() && getAcc().getType() != getResultType()) + return emitOpError("Expecting the acc type to be the same as result."); - // SIMT code: skip the check since lack of semantic info at this level. + // SIMT code: the size of the B operand has to be a multiple of 32 bits. + // It skips the semantic check since lack of architecture information. // Users need to ensure the correctness. if (lhsRank == 1 && rhsRank == 1 && resRank == 1) { + auto numElems = getRhsType().getNumElements(); + auto elemTy = getRhsType().getElementType(); + auto factor = 32 / elemTy.getIntOrFloatBitWidth(); + if (numElems % factor != 0) + return emitOpError("Expecting B operand to be a multiple of 32 bits."); return success(); } else { // SIMD code if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2) diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 2a7436807f5f4..67ed89e11b4c9 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -255,7 +255,7 @@ func.func @test_load_gather_simt_1(%src: ui64) { %0 = arith.constant dense<1>: vector<4xi1> %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - // expected-error@+1 {{Result shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}} + // expected-error@+1 {{Value shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}} %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<6xf32> return } @@ -347,12 +347,19 @@ func.func @test_dpas_4(%a : vector<16x16xf16>, %b: vector<8x16x2xf16>) { } // ----- -func.func @test_dpas_4(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) { +func.func @test_dpas_5(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) { // expected-error@+1 {{N-dimension mismatch}} %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x8x2xf16> -> vector<8x16xf32> return } +// ----- +func.func @test_dpas_simt_1(%a : vector<8xf16>, %b: vector<15xf16>) { + // expected-error@+1 {{Expecting B operand to be a multiple of 32 bits}} + %1 = xegpu.dpas %a, %b : vector<8xf16>, vector<15xf16> -> vector<8xf32> + return +} + // ----- func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) { %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> From 605c99ec7291f8a7f3dec4e7951e196b4f6f27b5 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 17 Apr 2025 15:44:58 +0000 Subject: [PATCH 6/7] fix format --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 21 ++++++++++++++------- 1 file 
changed, 14 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index b02490909e067..1da2752f44b99 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -101,7 +101,10 @@ static bool isEvenDistributed(llvm::ArrayRef shape, return true; } -static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, UnitAttr transposeAttr, function_ref emitError) { +static LogicalResult +isValidGatherScatterParams(Type maskTy, VectorType valueTy, + TensorDescType tdescTy, UnitAttr transposeAttr, + function_ref emitError) { if (!tdescTy.isScattered()) return emitError() << "Expects a scattered TensorDesc."; @@ -115,10 +118,12 @@ static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, auto chunkSize = tdescTy.getChunkSize(); if (valueTy.getElementType() != tdescTy.getElementType()) - return emitError() << "Value should have the same element type as TensorDesc."; + return emitError() + << "Value should have the same element type as TensorDesc."; if (tdescShape[0] != maskShape[0]) - return emitError() << "dim-0 of the Mask and TensorDesc should be the same."; + return emitError() + << "dim-0 of the Mask and TensorDesc should be the same."; // a valid shape for SIMT case if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { @@ -568,8 +573,9 @@ LogicalResult LoadGatherOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(), - [&]() { return emitOpError(); }); + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + getTransposeAttr(), + [&]() { return emitOpError(); }); } //===----------------------------------------------------------------------===// @@ -589,8 +595,9 @@ LogicalResult StoreScatterOp::verify() { if (!isWriteHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(), - [&]() { return emitOpError(); }); + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + getTransposeAttr(), + [&]() { return emitOpError(); }); } //===----------------------------------------------------------------------===// From fb2506c4f2edd8da58f770f8281a891dbe6db7a7 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 17 Apr 2025 20:23:21 +0000 Subject: [PATCH 7/7] address comments --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 1da2752f44b99..e0e25365220b5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -646,19 +646,21 @@ LogicalResult DpasOp::verify() { if (numElems % factor != 0) return emitOpError("Expecting B operand to be a multiple of 32 bits."); return success(); - } else { // SIMD code - if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2) - return emitOpError( - "expecting lhs and result to be a 2D vector, and rhs to be either " - "2D or 3D (packed) vector."); - auto bK = rhsRank == 3 ? 
rhsShape[0] * rhsShape[2] : rhsShape[0]; - if (bK != lhsShape[1]) - return emitOpError("K-dimension mismatch."); - if (lhsShape[0] != resShape[0]) - return emitOpError("M-dimension mismatch."); - if (rhsShape[1] != resShape[1]) - return emitOpError("N-dimension mismatch."); } + + // SIMD code + if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2) + return emitOpError( + "expecting lhs and result to be a 2D vector, and rhs to be either " + "2D or 3D (packed) vector."); + auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0]; + if (bK != lhsShape[1]) + return emitOpError("K-dimension mismatch."); + if (lhsShape[0] != resShape[0]) + return emitOpError("M-dimension mismatch."); + if (rhsShape[1] != resShape[1]) + return emitOpError("N-dimension mismatch."); + return success(); }
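As a point of reference, a minimal sketch of a SIMT dpas that satisfies the new B-operand rule, modeled on test_dpas_simt_1; the fragment lengths are illustrative, the check only requiring that the B fragment be a multiple of 32 / bit_width elements (32 / 16 = 2 for f16):

```mlir
// Hypothetical valid counterpart to test_dpas_simt_1: 16 f16 elements are a
// multiple of the VNNI factor 2, so the B-operand check passes
// (vector<15xf16> is rejected).
func.func @test_dpas_simt_valid(%a : vector<8xf16>, %b: vector<16xf16>) {
  %1 = xegpu.dpas %a, %b : vector<8xf16>, vector<16xf16> -> vector<8xf32>
  return
}
```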