diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d50..494f11f041b71 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   }];

   let hasCustomAssemblyFormat = true;
-
+  let genVerifyDecl = 1;
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c6..becc32d122697 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   if (parser.parseGreater())
     return {};

-  return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()),
-                             sg_map.value_or(mlir::Attribute()));
+  return TensorDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), shape, elementType,
+      encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
 }

 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,81 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }

+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank != 1 && rank != 2)
+    return emitError() << "expected 1D or 2D tensor";
+
+  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
+  if (scatterAttr) {
+    // Expected tensor ranks for scattered data:
+    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
+    //   - 2D tensor for scattered blocks (chunk size > 1)
+    IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+    unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+    if (rank == 1 && chunkSize != 1)
+      return emitError() << "expected non-contiguous elements for 1D tensor";
+    if (rank == 2 && chunkSize < 2)
+      return emitError() << "expected chunk blocks for 2D tensor";
+  }
+
+  if (auto blockAttr =
+          mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+    if (rank == 2 && memorySpaceAttr &&
+        memorySpaceAttr.getValue() == MemorySpace::SLM)
+      return emitError() << "SLM is not supported for 2D block tensor";
+  }
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError()
+               << "outer layout distribution and data mapping must be 1 "
+                  "for 1D tensor";
+    }
+
+    if (scatterAttr) {
+      // Validate subgroup mapping rules for scattered tensors.
+      // A work-item's slice of the tensor with shape [sg_size] or
+      // [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively,
+      // the mapping should reflect that.
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered row elements";
+
+      IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+      unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+      if (wiData[1] != chunkSize)
+        return emitError() << "work item data mapping must match the number of "
+                              "contiguous elements";
+    }
+
+    // For 1D tensor, pad the shape with an outer unit dimension to allow
+    // common validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    size_t dims = tensorShape.size();
+    for (size_t i = 0; i < dims; ++i) {
+      uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+      if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+        return emitError() << "cannot distribute " << tensorShape[i] << " over "
+                           << wiLayout[i] << " work items with " << wiData[i]
+                           << " elements each";
+    }
+  }
+
+  return success();
+}
+
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index cd883baa986b8..e06d99ac20bb7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 // each dimension.
 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                              ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  if (descShape == valShape) {
-    if (!sgMap)
-      return true;
-
-    // this can be relaxed if necessary by supporting non-2d shapes distribution
-    // until the constraints are defined this lives here instead of the tensor
-    // descriptor type.
-    return valShape.size() == sgMap.getWiLayout().size();
-  }
+  // Equal shapes with no distribution - no further verification needed.
+  if (descShape == valShape && !sgMap)
+    return true;

+  // Unknown distribution - cannot perform operation on partial shape.
   if (!sgMap)
     return false;

-  if (valShape.size() != descShape.size())
+  // Invalid rank or mixed rank usage.
+  size_t descRank = descShape.size();
+  if (descRank > 2 || valShape.size() != descRank)
     return false;

+  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+  // Only take the distribution over the innermost dimension for validation.
+  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+  if (descRank == 1)
+    mapLayout = {wiLayout.back()};
+
   for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+       llvm::zip_equal(mapLayout, valShape, descShape)) {
     if (factor * dim != expected)
       return false;
   }
@@ -227,10 +231,6 @@ LogicalResult CreateNdDescOp::verify() {
   if (getType().isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");

-  if (getType().getRank() == 2 &&
-      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
-    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
-
   return success();
 }
@@ -454,22 +454,7 @@ LogicalResult CreateDescOp::verify() {
   if (shape != tdescShape)
     return emitOpError("Incorrect TensorDesc shape. ")
            << "Expected is " << makeString(shape) << "\n";
-  if (auto sgMap = tdescTy.getSGMapAttr()) {
-    // A work-item's slice of the TensorDesc with shape [sg_size] or
-    // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
-    // the mapping should reflect that.
-    if (sgMap.getWiData()[0] > 1)
-      return emitOpError("TensorDesc's SG map only supports multiple elements "
-                         "contiguous along rows.");
-    if (chunkSize != static_cast<int64_t>(sgMap.getWiData()[1]))
-      return emitOpError(
-          "TensorDesc's chunkSize must match WI's data mapping.");
-    if (int rank = tdescTy.getRank();
-        (sgMap.getWiLayout()[2 - rank] != tdescShape[0]))
-      return emitOpError("Detected a conflict between SG map's work-item "
-                         "layout and TensorDesc shape. Check the index of "
-                         "`subgroup_size` in WI layout map.");
-  }
+
   return success();
 }
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index dcd6b01974cf3..8af1b600ad0a4 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }

+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<32xf32, #xegpu.sg_map>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map> -> vector<2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map> -> vector<2xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
   gpu.return
 }

+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<32xf16, #xegpu.sg_map>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 201f72120cf2c..9162e0012f6d5 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -17,7 +17,7 @@ func.func @test_create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {

 // -----
 func.func @test_create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
-  // expected-error@+1 {{SLM is not supported for 2D Block TensorDesc}}
+  // expected-error@+1 {{SLM is not supported for 2D block tensor}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
   return
 }
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint,
+                          l2_hint = #xegpu.cache_hint}>
+    : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map>
+    -> vector<8x2xf32>
   return
 }

 // -----
 func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<16xf32, #xegpu.sg_map>
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map> -> vector<16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint,
+                          l2_hint = #xegpu.cache_hint}>
+    : !xegpu.tensor_desc<16xf32, #xegpu.sg_map>
+    -> vector<8xf32>
   return
 }
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint,
+                          l2_hint = #xegpu.cache_hint}>
+    : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
+  return
+}
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
   return
 }

+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
@@ -137,8 +183,8 @@ func.func @test_create_tdesc_vc_1(%src: ui64) {
 // -----
 func.func @test_create_tdesc_vc_2(%src: ui64) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
-  // expected-error@+1 {{Incorrect TensorDesc shape}}
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
+  // expected-error@+1 {{expected chunk blocks for 2D tensor}}
         -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
   return
 }
@@ -173,7 +219,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_1(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}}
+  // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>
   return
 }
@@ -181,7 +227,7 @@ func.func @test_create_tdesc_sg_map_1(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_2(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}}
+  // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>
   return
 }
@@ -189,7 +235,7 @@ func.func @test_create_tdesc_sg_map_2(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_3(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}}
+  // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>
   return
 }
@@ -315,4 +361,109 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
   // expected-error@+1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
   return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{expected 1D or 2D tensor}}
+      !xegpu.tensor_desc<16x2x2xf32>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{expected 1D or 2D tensor}}
+      !xegpu.tensor_desc
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 8 over 16 work items with 1 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 4 over 2 work items with 4 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+      // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
+      !xegpu.tensor_desc<4x2xf32,
+        #xegpu.scatter_tdesc_attr,
+        #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
+      !xegpu.tensor_desc<16xf32,
+        #xegpu.scatter_tdesc_attr,
+        #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{expected non-contiguous elements for 1D tensor}}
+      !xegpu.tensor_desc<16xf32,
+        #xegpu.scatter_tdesc_attr,
+        #xegpu.sg_map>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{expected chunk blocks for 2D tensor}}
+      !xegpu.tensor_desc<16x2xf32,
+        #xegpu.scatter_tdesc_attr,
+        #xegpu.sg_map>
+  return
+}