Skip to content

Commit 7ad625d

Browse files
committed
save work
1 parent 5c1c908 commit 7ad625d

File tree

2 files changed

+33
-19
lines changed

2 files changed

+33
-19
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,13 +1042,6 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
10421042
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
10431043
unsigned operandIdx = operand->getOperandNumber();
10441044

1045-
auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
1046-
if (!srcTypedVal)
1047-
return rewriter.notifyMatchFailure(
1048-
descOp, "expecting a memref typed value as the source");
1049-
1050-
auto descOffsets = descOp.getMixedOffsets();
1051-
10521045
xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
10531046
if (!layout)
10541047
return rewriter.notifyMatchFailure(

mlir/test/Dialect/XeGPU/subgroup-distribution.mlir

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
// RUN: mlir-opt -xegpu-subgroup-distribute -split-input-file %s | FileCheck %s
22

3-
// CHECK-LABEL: gpu.func @test_store_nd_1d
3+
// CHECK-LABEL: gpu.func @store_nd_1d
44
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
55
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
66
// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
77
// CHECK: xegpu.store_nd %[[CST]], %[[T0]] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
88
// CHECK: gpu.return
99
gpu.module @test {
10-
gpu.func @test_store_nd_1d(%arg0: memref<16xf32>){
10+
gpu.func @store_nd_1d(%arg0: memref<16xf32>){
1111
%c0 = arith.constant 0 : index
1212
%1 = arith.constant dense<1.000000e+00> : vector<16xf32>
1313
%0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
@@ -17,13 +17,13 @@ gpu.func @test_store_nd_1d(%arg0: memref<16xf32>){
1717
}
1818

1919
// -----
20-
// CHECK-LABEL: gpu.func @test_store_nd_2d
20+
// CHECK-LABEL: gpu.func @store_nd_2d
2121
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
2222
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf16>
2323
// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
2424
// CHECK: xegpu.store_nd %[[CST]], %[[T0]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
2525
gpu.module @test {
26-
gpu.func @test_store_nd_2d(%arg0: memref<16x16xf16>){
26+
gpu.func @store_nd_2d(%arg0: memref<16x16xf16>){
2727
%c0 = arith.constant 0 : index
2828
%1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
2929
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
@@ -35,14 +35,14 @@ gpu.func @test_store_nd_2d(%arg0: memref<16x16xf16>){
3535

3636

3737
// -----
38-
// CHECK-LABEL: gpu.func @test_load_nd_1d
38+
// CHECK-LABEL: gpu.func @load_nd_1d
3939
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16xf32>) {
4040
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
4141
// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
4242
// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
4343
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
4444
gpu.module @test {
45-
gpu.func @test_load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
45+
gpu.func @load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
4646
%c0 = arith.constant 0 : index
4747
%0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
4848
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
@@ -53,14 +53,14 @@ gpu.func @test_load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
5353
}
5454

5555
// -----
56-
// CHECK-LABEL: gpu.func @test_load_nd_2d
56+
// CHECK-LABEL: gpu.func @load_nd_2d
5757
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
5858
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
5959
// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
6060
// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
6161
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
6262
gpu.module @test {
63-
gpu.func @test_load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
63+
gpu.func @load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
6464
%c0 = arith.constant 0 : index
6565
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
6666
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -71,7 +71,7 @@ gpu.func @test_load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
7171
}
7272

7373
// -----
74-
// CHECK-LABEL: gpu.func @test_load_nd_array_length
74+
// CHECK-LABEL: gpu.func @load_nd_array_length
7575
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
7676
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
7777
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
@@ -81,7 +81,7 @@ gpu.func @test_load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
8181
// CHECK-DAG: %[[T5:.*]] = vector.shape_cast %[[T3]] : vector<16x1xf16> to vector<16xf16>
8282
// CHECK: xegpu.store_nd %[[T5]], %[[T4]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
8383
gpu.module @test {
84-
gpu.func @test_load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
84+
gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
8585
%c0 = arith.constant 0 : index
8686
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
8787
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
@@ -93,7 +93,7 @@ gpu.func @test_load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x1
9393
}
9494

9595
// -----
96-
// CHECK-LABEL: gpu.func @test_dpas
96+
// CHECK-LABEL: gpu.func @dpas
9797
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: vector<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: vector<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: vector<8x16xf32>, %[[ARG3:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
9898
// CHECK: %[[T1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] args(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
9999
// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>, memref<8x16xf32>) -> (vector<8x1xf16>, vector<16x1xf16>, vector<8x1xf32>) {
@@ -107,7 +107,7 @@ gpu.func @test_load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x1
107107
// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG3]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
108108
// CHECK: xegpu.store_nd %[[T5]], %[[T6]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
109109
gpu.module @test {
110-
gpu.func @test_dpas(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>, %arg3: vector<8x16xf32>, %arg2: memref<8x16xf32>){
110+
gpu.func @dpas(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>, %arg3: vector<8x16xf32>, %arg2: memref<8x16xf32>){
111111
%c0 = arith.constant 0 : index
112112
%0 = xegpu.dpas %arg0, %arg1, %arg3 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
113113
%3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -139,3 +139,24 @@ gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
139139
gpu.return
140140
}
141141
}
142+
143+
// -----
144+
gpu.module @test {
145+
// CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
146+
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
147+
// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
148+
// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
149+
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
150+
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
151+
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
152+
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
153+
gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
154+
%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
155+
%c0 = arith.constant 0 : index
156+
%0 = xegpu.create_nd_tdesc %arg0 [%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16>
157+
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
158+
%2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16>
159+
xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
160+
gpu.return
161+
}
162+
}

0 commit comments

Comments
 (0)