diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index a2bfa721f2515..c2335eecc3781 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { let hasVerifier = 1; } -def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>, - AllElementTypesMatch<["value", "TensorDesc"]>, - AllElementCountsMatch<["value", "TensorDesc"]>]> { +def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> { let summary = "load a set of scattered data points from memory."; let description = [{ It (aka. load) load data per each work-item. The output @@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"] let hasVerifier = 1; } -def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>, - AllElementTypesMatch<["value", "TensorDesc"]>]> { +def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> { let summary = "store data to scattered memory locations."; let description = [{ It (aka. store) stores data to scattered memory locations. The value is typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 15c435f1fa257..58bb931b6fa1e 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -453,7 +453,22 @@ LogicalResult CreateDescOp::verify() { if (shape != tdescShape) return emitOpError("Incorrect TensorDesc shape. 
") << "Expected is " << makeString(shape) << "\n"; - + if (auto sgMap = tdescTy.getSGMapAttr()) { + // A work-item's slice of the TensorDesc with shape [sg_size] or + // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively, + // the mapping should reflect that. + if (sgMap.getWiData()[0] > 1) + return emitOpError("TensorDesc's SG map only supports multiple elements " + "contiguous along rows."); + if (chunkSize != static_cast(sgMap.getWiData()[1])) + return emitOpError( + "TensorDesc's chunkSize must match WI's data mapping."); + if (int rank = tdescTy.getRank(); + (sgMap.getWiLayout()[2 - rank] != tdescShape[0])) + return emitOpError("Detected a conflict between SG map's work-item " + "layout and TensorDesc shape. Check the index of " + "`subgroup_size` in WI layout map."); + } return success(); } @@ -512,10 +527,23 @@ LogicalResult LoadGatherOp::verify() { if (tdescTy.getRank() == 2) { if (!getTransposeAttr()) - return emitOpError("load_gather has to be transposed."); + return emitOpError("load of rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); } + if (auto sgMap = tdescTy.getSGMapAttr()) { + auto valueVecTy = cast(valueTy); + const int32_t wiData = + sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1]; + // All represent the same concept: a number of row elements to store. + if (valueVecTy.getNumElements() != wiData || + valueVecTy.getNumElements() != tdescTy.getChunkSize()) { + return emitOpError("Chunk size, vector size and wi_data must match."); + } + // Work-item's slice (i.e., vector shape to load) is [1] or [1, chunk_size]. 
+ tdescShape[tdescTy.getRank() - 1] = 1; + } + if (valueShape != tdescShape) return emitOpError("Unexpected result shape") << "(Expected shape: " << makeString(tdescShape) @@ -551,10 +579,23 @@ LogicalResult StoreScatterOp::verify() { if (tdescTy.getRank() == 2) { if (!getTransposeAttr()) - return emitOpError("load_gather has to be transposed."); + return emitOpError("Store of a rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); } + if (auto sgMap = tdescTy.getSGMapAttr()) { + auto valueVecTy = cast<VectorType>(valueTy); + const int32_t wiData = + sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1]; + // All represent the same concept: a number of row elements to store. + if (valueVecTy.getNumElements() != wiData || + valueVecTy.getNumElements() != tdescTy.getChunkSize()) { + return emitOpError("Chunk size, vector size and wi_data must match."); + } + // Work-item's slice (i.e., vector to store) is [1] or [1, chunk_size]. + tdescShape[tdescTy.getRank() - 1] = 1; + } + if (valueShape != tdescShape) return emitOpError("Unexpected value shape") << "(Expected shape: " << makeString(tdescShape) diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir index d7174a489888a..dcd6b01974cf3 100644 --- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir +++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir @@ -163,11 +163,69 @@ gpu.func @test_create_tdesc_vc_1(%src: memref) { gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> 
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> gpu.return } +// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) { +gpu.func @test_load_with_sg_map(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<2x1xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<2x1xf32> + gpu.return +} + +// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) { +gpu.func @test_load_with_sg_map_2(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + //CHECK: 
%[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> -> vector<1xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> -> vector<1xf32> + gpu.return +} + +// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) { +gpu.func @test_store_with_sg_map(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32> + %2 = arith.constant dense<2.9>: vector<2x1xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + gpu.return +} + +// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) { +gpu.func @test_store_with_sg_map_2(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: 
vector<4xi1> + //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> + %2 = arith.constant dense<2.9>: vector<1xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> + gpu.return +} + + + // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) { gpu.func @test_prefetch_vc(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 7816bff0582f8..201f72120cf2c 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -170,6 +170,83 @@ func.func @test_prefetch_vc_2(%src: ui64) { return } +// ----- +func.func @test_create_tdesc_sg_map_1(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. 
Check the index of `subgroup_size` in WI layout map}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + return +} + +// ----- +func.func @test_create_tdesc_sg_map_2(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + return +} + +// ----- +func.func @test_create_tdesc_sg_map_3(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + return +} + +// ----- +func.func @test_load_gather_sg_map_1(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<1x2xf32> + return +} + +// ----- +func.func @test_load_gather_sg_map_2(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : 
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<2xf32> + return +} + +// ----- +func.func @test_load_gather_sg_map_3(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Chunk size, vector size and wi_data must match}} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<1xf32> + return +} + + +// ----- +func.func @test_store_scatter_sg_map_1(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %val = arith.constant dense<2.9>: vector<1x2xf32> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}} + xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + return +} + +// ----- +func.func @test_store_scatter_sg_map_2(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %val = arith.constant dense<2.9>: vector<2xf32> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}} + xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + return +} + // ----- func.func 
@test_load_gather_vc_1(%src: memref<24x32xf16>) { %0 = arith.constant dense<1>: vector<4xi1>