From d23c97d170efbc1cf9ec5a2e4b10f057d88edb23 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 23 Jan 2025 10:24:11 +0100 Subject: [PATCH 1/8] [mlir][xegpu] TensorDesc verifier Adds XeGPU tensor descriptor type verifier. The checks focus on ensuring that provided subgroup map is valid with respect to the underlying data. --- .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 56 ++++++++++++++++++- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index d09c5c1870d50..494f11f041b71 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", }]; let hasCustomAssemblyFormat = true; - + let genVerifyDecl = 1; } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index eb01b15de75c6..42d59da2f7a92 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { if (parser.parseGreater()) return {}; - return TensorDescType::get(parser.getContext(), shape, elementType, - encoding.value_or(mlir::Attribute()), - sg_map.value_or(mlir::Attribute())); + return TensorDescType::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, + parser.getContext(), shape, elementType, + encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute())); } void TensorDescType::print(::mlir::AsmPrinter &printer) const { @@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef shape, return Base::get(context, shape, elementType, attr, sg_map); } +LogicalResult TensorDescType::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + llvm::ArrayRef shape, mlir::Type 
elementType, + mlir::Attribute encoding, mlir::Attribute sg_map) { + size_t rank = shape.size(); + if (rank > 2) + return emitError() << "desc shape rank exceeds 2"; + + if (auto sgMapAttr = llvm::dyn_cast_if_present(sg_map)) { + ArrayRef wiLayout = sgMapAttr.getWiLayout(); + ArrayRef wiData = sgMapAttr.getWiData(); + + if (rank == 1) { + if (wiLayout[0] != 1 || wiData[0] != 1) + return emitError() << "outer layout and data mapping must be 1 " + "for 1D tensor"; + } + + // For 1D tensor, pad the shape with an outer unit dimension to allow common + // validation logic. + SmallVector tensorShape(shape.begin(), shape.end()); + if (rank == 1) + tensorShape = {1, tensorShape.back()}; + + size_t dims = tensorShape.size(); + for (size_t i = 0; i < dims; ++i) { + uint32_t numElemPerWi = wiLayout[i] * wiData[i]; + if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0) + return emitError() << "cannot map " << tensorShape[i] + << " elements into " << wiLayout[i] << " by " + << wiData[i] << " tiles"; + } + + if (llvm::isa_and_nonnull(encoding)) { + auto scatterAttr = llvm::dyn_cast(encoding); + if (wiData[0] != 1) + return emitError() + << "cannot map over non-contiguous scattered elements"; + + unsigned chunkSize = scatterAttr.getChunkSize().getInt(); + if (wiData[1] > chunkSize) + return emitError() + << "too few contiguous elements for work item mapping"; + } + } + + return success(); +} + } // namespace xegpu } // namespace mlir From 5e3e7a80e907c0e4160cc00b76cd3f3f6d6c5e56 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 24 Jan 2025 17:25:29 +0100 Subject: [PATCH 2/8] Test cases --- mlir/test/Dialect/XeGPU/invalid.mlir | 79 +++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 201f72120cf2c..975b4aea84fe2 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -315,4 +315,81 @@ func.func 
@test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector // expected-error@+1 {{failed to verify that all of {tensorDesc, value, result} have same shape}} xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32> return -} \ No newline at end of file +} + +// ----- +func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{desc shape rank exceeds 2}} + !xegpu.tensor_desc<16x2x2xf32> + return +} + +// ----- +func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}} + !xegpu.tensor_desc<16xf32, #xegpu.sg_map> + return +} + +// ----- +func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}} + !xegpu.tensor_desc<16xf32, #xegpu.sg_map> + return +} + +// ----- +func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}} + !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> + return +} + +// ----- +func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}} + !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> + return +} + +// ----- +func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}} + !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> + return +} + +// ----- +func.func 
@tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}} + !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> + return +} + +// ----- +func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) { + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> + // expected-error@+1 {{cannot map over non-contiguous scattered elements}} + !xegpu.tensor_desc<4x2xf32, + #xegpu.scatter_tdesc_attr, + #xegpu.sg_map> + return +} + +// ----- +func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) { + %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> + // expected-error@+1 {{too few contiguous elements for work item mapping}} + !xegpu.tensor_desc<16xf32, + #xegpu.scatter_tdesc_attr, + #xegpu.sg_map> + return +} From 7aba5a69eca63bd4f0924f48df4bbe0f3c5a8439 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 27 Jan 2025 14:40:36 +0100 Subject: [PATCH 3/8] Update load/store verifier + tests --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 26 +++++++------ mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 22 +++++++++++ mlir/test/Dialect/XeGPU/invalid.mlir | 52 ++++++++++++++++++++++++-- 3 files changed, 86 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index cd883baa986b8..996fb36382033 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) { // each dimension. 
static bool isArgShapesValid(ArrayRef descShape, ArrayRef valShape, SGMapAttr sgMap) { - if (descShape == valShape) { - if (!sgMap) - return true; - - // this can be relaxed if necessary by supporting non-2d shapes distribution - // until the constraints are defined this lives here instead of the tensor - // descriptor type. - return valShape.size() == sgMap.getWiLayout().size(); - } + // Equal shapes with no distribution - no further verification needed. + if (descShape == valShape && !sgMap) + return true; + // Unknown distribution - cannot perform operation on partial shape. if (!sgMap) return false; - if (valShape.size() != descShape.size()) + // Invalid rank or mixed rank usage. + size_t descRank = descShape.size(); + if (descRank > 2 || valShape.size() != descRank) return false; + // For 1D, SG map is guaranteed to be unit size in the outer dimension. + // Only take the distribution over the innermost dimension for validation. + ArrayRef wiLayout = sgMap.getWiLayout(); + SmallVector mapLayout(wiLayout.begin(), wiLayout.end()); + if (descRank == 1) + mapLayout = {wiLayout.back()}; + for (const auto &[factor, dim, expected] : - llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) { + llvm::zip_equal(mapLayout, valShape, descShape)) { if (factor * dim != expected) return false; } diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir index dcd6b01974cf3..8af1b600ad0a4 100644 --- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir +++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir @@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) { gpu.return } +// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) { +gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) { + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + !xegpu.tensor_desc<32xf32, #xegpu.sg_map> + // CHECK: %[[R1:.*]] = 
xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map> -> vector<2xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map> -> vector<2xf32> + gpu.return +} + // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) { // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16> @@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) { gpu.return } +// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) { +gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) { + // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16> + %1 = arith.constant dense<1.0>: vector<2xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map> + %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> + !xegpu.tensor_desc<32xf16, #xegpu.sg_map> + // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map> + xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map> + gpu.return +} + // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) { gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) { // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 975b4aea84fe2..5ed93db2be502 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : 
memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} - %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x2xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + -> vector<8x2xf32> return } // ----- func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - !xegpu.tensor_desc<16xf32, #xegpu.sg_map> + !xegpu.tensor_desc<16xf32, #xegpu.sg_map> // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} - %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map> -> vector<16xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<16xf32, #xegpu.sg_map> + -> vector<8xf32> + return +} + +// ----- +func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) { + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + !xegpu.tensor_desc<8x16xf32> + // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32> return } @@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) { return } +// ----- +func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) { + %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> + !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} + xegpu.store_nd %data, %1 + : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + return +} + +// ----- +func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) { + %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> + 
!xegpu.tensor_desc<16xf32, #xegpu.sg_map> + // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} + xegpu.store_nd %data, %1 + : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map> + return +} + +// ----- +func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) { + %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> + !xegpu.tensor_desc<8x16xf32> + // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} + xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32> + return +} + // ----- func.func @test_update_nd_offset_1(%dst: memref<16xf16>) { %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex> From fa82d893c6b8ee3ab9ebaa5fd97bf00553f10504 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 27 Jan 2025 15:16:33 +0100 Subject: [PATCH 4/8] Refactor --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 42d59da2f7a92..ef0ea38027c45 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -257,7 +257,7 @@ LogicalResult TensorDescType::verify( << wiData[i] << " tiles"; } - if (llvm::isa_and_nonnull(encoding)) { + if (mlir::isa_and_nonnull(encoding)) { auto scatterAttr = llvm::dyn_cast(encoding); if (wiData[0] != 1) return emitError() From 5df9cf0089676d2650effacba226ecfdf4fdd523 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 7 Feb 2025 13:53:20 +0100 Subject: [PATCH 5/8] Improve scattered verification --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 56 ++++++++++++++-------- mlir/test/Dialect/XeGPU/invalid.mlir | 34 ++++++++----- 2 files changed, 58 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index ef0ea38027c45..077a924dfad26 100644 --- 
a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -229,8 +229,24 @@ LogicalResult TensorDescType::verify( llvm::ArrayRef shape, mlir::Type elementType, mlir::Attribute encoding, mlir::Attribute sg_map) { size_t rank = shape.size(); - if (rank > 2) - return emitError() << "desc shape rank exceeds 2"; + if (rank != 1 && rank != 2) + return emitError() << "expected 1D or 2D tensor"; + + // Scattered attribute imposes extra restriction on tensor descriptor. + // Block attribute can only be validated further against data transfer + // operations. + auto scatterAttr = mlir::dyn_cast_if_present(encoding); + if (scatterAttr) { + // Expected tensor ranks for scattered data: + // - 1D tensor for fully non-contiguous elements (chunk size == 1) + // - 2D tensor for scattered blocks (chunk size > 1) + IntegerAttr chunkAttr = scatterAttr.getChunkSize(); + unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1; + if (rank == 1 && chunkSize != 1) + return emitError() << "expected non-contiguous elements for 1D tensor"; + if (rank == 2 && chunkSize < 2) + return emitError() << "expected chunk blocks for 2D tensor"; + } if (auto sgMapAttr = llvm::dyn_cast_if_present(sg_map)) { ArrayRef wiLayout = sgMapAttr.getWiLayout(); @@ -238,8 +254,22 @@ LogicalResult TensorDescType::verify( if (rank == 1) { if (wiLayout[0] != 1 || wiData[0] != 1) - return emitError() << "outer layout and data mapping must be 1 " - "for 1D tensor"; + return emitError() + << "outer layout distribution and data mapping must be 1 " + "for 1D tensor"; + } + + if (scatterAttr) { + // Validate subgroup mapping rules for scattered tensors. + if (wiData[0] != 1) + return emitError() + << "cannot map over non-contiguous scattered row elements"; + + IntegerAttr chunkAttr = scatterAttr.getChunkSize(); + unsigned chunkSize = chunkAttr ? 
chunkAttr.getInt() : 1; + if (wiData[1] != chunkSize) + return emitError() << "work item data mapping must match the number of " + "contiguous elements"; } // For 1D tensor, pad the shape with an outer unit dimension to allow common @@ -252,21 +282,9 @@ LogicalResult TensorDescType::verify( for (size_t i = 0; i < dims; ++i) { uint32_t numElemPerWi = wiLayout[i] * wiData[i]; if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0) - return emitError() << "cannot map " << tensorShape[i] - << " elements into " << wiLayout[i] << " by " - << wiData[i] << " tiles"; - } - - if (mlir::isa_and_nonnull(encoding)) { - auto scatterAttr = llvm::dyn_cast(encoding); - if (wiData[0] != 1) - return emitError() - << "cannot map over non-contiguous scattered elements"; - - unsigned chunkSize = scatterAttr.getChunkSize().getInt(); - if (wiData[1] > chunkSize) - return emitError() - << "too few contiguous elements for work item mapping"; + return emitError() << "cannot distribute " << tensorShape[i] << " over " + << wiLayout[i] << " work items with " << wiData[i] + << " elements each"; } } diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 5ed93db2be502..733eb1559d6fb 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -183,8 +183,8 @@ func.func @test_create_tdesc_vc_1(%src: ui64) { // ----- func.func @test_create_tdesc_vc_2(%src: ui64) { %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex> - // expected-error@+1 {{Incorrect TensorDesc shape}} %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex> + // expected-error@+1 {{expected chunk blocks for 2D tensor}} -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>> return } @@ -219,7 +219,7 @@ func.func @test_prefetch_vc_2(%src: ui64) { // ----- func.func @test_create_tdesc_sg_map_1(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - // expected-error@+1 {{Detected a conflict between 
SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}} + // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}} %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> return } @@ -227,7 +227,7 @@ func.func @test_create_tdesc_sg_map_1(%src: ui64) { // ----- func.func @test_create_tdesc_sg_map_2(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - // expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}} + // expected-error@+1 {{cannot map over non-contiguous scattered row elements}} %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> return } @@ -235,7 +235,7 @@ func.func @test_create_tdesc_sg_map_2(%src: ui64) { // ----- func.func @test_create_tdesc_sg_map_3(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - // expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}} + // expected-error@+1 {{work item data mapping must match the number of contiguous elements}} %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> return } @@ -366,15 +366,23 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector // ----- func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{desc shape rank exceeds 2}} + // expected-error@+1 {{expected 1D or 2D tensor}} !xegpu.tensor_desc<16x2x2xf32> return } +// ----- +func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) { + %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + // expected-error@+1 {{expected 1D or 2D tensor}} + !xegpu.tensor_desc + return +} + // ----- func.func 
@tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}} + // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}} !xegpu.tensor_desc<16xf32, #xegpu.sg_map> return } @@ -382,7 +390,7 @@ func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) { // ----- func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}} + // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}} !xegpu.tensor_desc<16xf32, #xegpu.sg_map> return } @@ -390,7 +398,7 @@ func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) { // ----- func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}} + // expected-error@+1 {{cannot distribute 8 over 16 work items with 1 elements each}} !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> return } @@ -398,7 +406,7 @@ func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) { // ----- func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}} + // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}} !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> return } @@ -406,7 +414,7 @@ func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) { // ----- func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}} + // expected-error@+1 
{{cannot distribute 4 over 2 work items with 4 elements each}} !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> return } @@ -414,7 +422,7 @@ func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) { // ----- func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}} + // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}} !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map> return } @@ -423,7 +431,7 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) { func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> - // expected-error@+1 {{cannot map over non-contiguous scattered elements}} + // expected-error@+1 {{cannot map over non-contiguous scattered row elements}} !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> @@ -433,7 +441,7 @@ func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) { // ----- func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) { %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> - // expected-error@+1 {{too few contiguous elements for work item mapping}} + // expected-error@+1 {{work item data mapping must match the number of contiguous elements}} !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> From f502c66de3e230a364ead78f3d3404f7368ca10d Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 7 Feb 2025 14:01:52 +0100 Subject: [PATCH 6/8] Remove TensorDesc invariant checks from op verifier --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 3 +++ mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 17 +---------------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp 
b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 077a924dfad26..0f17c6a8fa98d 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -261,6 +261,9 @@ LogicalResult TensorDescType::verify( if (scatterAttr) { // Validate subgroup mapping rules for scattered tensors. + // A work-item's slice of the tensor with shape [sg_size] or + // [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively, + // the mapping should reflect that. if (wiData[0] != 1) return emitError() << "cannot map over non-contiguous scattered row elements"; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 996fb36382033..476689fae4e25 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -458,22 +458,7 @@ LogicalResult CreateDescOp::verify() { if (shape != tdescShape) return emitOpError("Incorrect TensorDesc shape. ") << "Expected is " << makeString(shape) << "\n"; - if (auto sgMap = tdescTy.getSGMapAttr()) { - // A work-item's slice of the TensorDesc with shape [sg_size] or - // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively, - // the mapping should reflect that. - if (sgMap.getWiData()[0] > 1) - return emitOpError("TensorDesc's SG map only supports multiple elements " - "contiguous along rows."); - if (chunkSize != static_cast(sgMap.getWiData()[1])) - return emitOpError( - "TensorDesc's chunkSize must match WI's data mapping."); - if (int rank = tdescTy.getRank(); - (sgMap.getWiLayout()[2 - rank] != tdescShape[0])) - return emitOpError("Detected a conflict between SG map's work-item " - "layout and TensorDesc shape. 
Check the index of " - "`subgroup_size` in WI layout map."); - } + return success(); } From 50c628323506d2af53cfa8ea62259485cadc348c Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 7 Feb 2025 14:20:45 +0100 Subject: [PATCH 7/8] Add more chunk_size test cases --- mlir/test/Dialect/XeGPU/invalid.mlir | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 733eb1559d6fb..48e8c2808abda 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -447,3 +447,23 @@ func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<1 #xegpu.sg_map> return } + +// ----- +func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) { + %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> + // expected-error@+1 {{expected non-contiguous elements for 1D tensor}} + !xegpu.tensor_desc<16xf32, + #xegpu.scatter_tdesc_attr, + #xegpu.sg_map> + return +} + +// ----- +func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) { + %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> + // expected-error@+1 {{expected chunk blocks for 2D tensor}} + !xegpu.tensor_desc<16x2xf32, + #xegpu.scatter_tdesc_attr, + #xegpu.sg_map> + return +} From a72c12f8f05f8f7456e442f40882f809ac10cdd1 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 7 Feb 2025 14:59:55 +0100 Subject: [PATCH 8/8] Move memory space check to type verifier --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 11 ++++++++--- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 ---- mlir/test/Dialect/XeGPU/invalid.mlir | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 0f17c6a8fa98d..becc32d122697 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ 
b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -232,9 +232,6 @@ LogicalResult TensorDescType::verify( if (rank != 1 && rank != 2) return emitError() << "expected 1D or 2D tensor"; - // Scattered attribute imposes extra restriction on tensor descriptor. - // Block attribute can only be validated further against data transfer - // operations. auto scatterAttr = mlir::dyn_cast_if_present(encoding); if (scatterAttr) { // Expected tensor ranks for scattered data: @@ -248,6 +245,14 @@ LogicalResult TensorDescType::verify( return emitError() << "expected chunk blocks for 2D tensor"; } + if (auto blockAttr = + mlir::dyn_cast_if_present(encoding)) { + MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace(); + if (rank == 2 && memorySpaceAttr && + memorySpaceAttr.getValue() == MemorySpace::SLM) + return emitError() << "SLM is not supported for 2D block tensor"; + } + if (auto sgMapAttr = llvm::dyn_cast_if_present(sg_map)) { ArrayRef wiLayout = sgMapAttr.getWiLayout(); ArrayRef wiData = sgMapAttr.getWiData(); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 476689fae4e25..e06d99ac20bb7 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -231,10 +231,6 @@ LogicalResult CreateNdDescOp::verify() { if (getType().isScattered()) return emitOpError("Expects a non-scattered TensorDesc.\n"); - if (getType().getRank() == 2 && - tdescMemorySpace == static_cast(MemorySpace::SLM)) - return emitOpError("SLM is not supported for 2D Block TensorDesc.\n"); - return success(); } diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 48e8c2808abda..9162e0012f6d5 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -17,7 +17,7 @@ func.func @test_create_nd_tdesc_vc_2(%src: memref<24x32xf32>) { // ----- func.func @test_create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) { - // expected-error@+1 {{SLM is not 
supported for 2D Block TensorDesc}} + // expected-error@+1 {{SLM is not supported for 2D block tensor}} %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> return }