Skip to content

Commit 68b1750

Browse files
committed
fix
1 parent 2896b34 commit 68b1750

File tree

2 files changed

+33
-20
lines changed

2 files changed

+33
-20
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -937,6 +937,17 @@ removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
937937
return newAttrs;
938938
}
939939

940+
/// Helper function to check if the layout is packed. A layout is packed if its
941+
/// lane_data is 2D and lane_data[0] != 1 (data packed from col dimension).
/// Returns false when `layout` equals a default-constructed LayoutAttr, or
/// when lane_data is missing or does not have exactly two entries.
942+
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
943+
// A default-constructed LayoutAttr is treated as "not packed".
if (layout == xegpu::LayoutAttr())
944+
return false;
945+
auto laneData = layout.getLaneData();
946+
// Only a present, two-entry lane_data can describe a packed layout.
if (!laneData || laneData.size() != 2)
947+
return false;
948+
// Packed iff the first lane_data entry differs from 1.
return laneData.asArrayRef()[0] != 1;
949+
}
950+
940951
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
941952
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
942953
/// contained within a WarpExecuteOnLane0Op.
@@ -1265,18 +1276,20 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
12651276
dropLayouts(loadOp.getTensorDescType()); /// Distributed tensor
12661277
/// descriptor type does not
12671278
/// contain layout info.
1268-
Value newLoadOp = rewriter.create<xegpu::LoadNdOp>(
1279+
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
12691280
newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
12701281
resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
12711282
distributedTensorDescTy, rewriter),
12721283
removeTemporaryLayoutAttributes(loadOp->getAttrs()));
1284+
/// Set the packed attribute if the layout requires it.
1285+
newLoadOp.setPacked(hasPackedLayout(layout));
12731286
Value distributedVal = newWarpOp.getResult(operandIdx);
12741287
/// There can be a conflict between the vector type distributed by the
12751288
/// warp op and (xegpu-specific) distributed type supported by the load
12761289
/// op. Resolve these mismatches by inserting a cast.
1277-
newLoadOp =
1278-
resolveDistributedTy(newLoadOp, distributedTypeByWarpOp, rewriter);
1279-
rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
1290+
auto tyResolvedVal = resolveDistributedTy(
1291+
newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
1292+
rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
12801293
return success();
12811294
}
12821295
};

mlir/test/Dialect/XeGPU/subgroup-distribution.mlir

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22

33
// CHECK-LABEL: gpu.func @test_store_nd_1d
44
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
5-
// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
6-
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
5+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
6+
// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
77
// CHECK: xegpu.store_nd %[[CST]], %[[T0]] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
88
// CHECK: gpu.return
99
gpu.module @test {
@@ -19,8 +19,8 @@ gpu.func @test_store_nd_1d(%arg0: memref<16xf32>){
1919
// -----
2020
// CHECK-LABEL: gpu.func @test_store_nd_2d
2121
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
22-
// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf16>
23-
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
22+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf16>
23+
// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
2424
// CHECK: xegpu.store_nd %[[CST]], %[[T0]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
2525
gpu.module @test {
2626
gpu.func @test_store_nd_2d(%arg0: memref<16x16xf16>){
@@ -38,8 +38,8 @@ gpu.func @test_store_nd_2d(%arg0: memref<16x16xf16>){
3838
// CHECK-LABEL: gpu.func @test_load_nd_1d
3939
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16xf32>) {
4040
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
41-
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
42-
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
41+
// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
42+
// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
4343
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
4444
gpu.module @test {
4545
gpu.func @test_load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
@@ -56,8 +56,8 @@ gpu.func @test_load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
5656
// CHECK-LABEL: gpu.func @test_load_nd_2d
5757
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
5858
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
59-
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
60-
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
59+
// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
60+
// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
6161
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
6262
gpu.module @test {
6363
gpu.func @test_load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
@@ -77,8 +77,8 @@ gpu.func @test_load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
7777
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
7878
// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<32xf16> to vector<2x16x1xf16>
7979
// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<16x1xf16> from vector<2x16x1xf16>
80-
// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
81-
// CHECK: %[[T5:.*]] = vector.shape_cast %[[T3]] : vector<16x1xf16> to vector<16xf16>
80+
// CHECK-DAG: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
81+
// CHECK-DAG: %[[T5:.*]] = vector.shape_cast %[[T3]] : vector<16x1xf16> to vector<16xf16>
8282
// CHECK: xegpu.store_nd %[[T5]], %[[T4]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
8383
gpu.module @test {
8484
gpu.func @test_load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
@@ -100,9 +100,9 @@ gpu.func @test_load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x1
100100
// CHECK: ^bb0(%[[ARG4:[0-9a-zA-Z]+]]: vector<8x16xf16>, %[[ARG5:[0-9a-zA-Z]+]]: vector<16x16xf16>, %[[ARG6:[0-9a-zA-Z]+]]: vector<8x16xf32>, %[[ARG7:[0-9a-zA-Z]+]]: memref<8x16xf32>):
101101
// CHECK: gpu.yield %[[ARG4]], %[[ARG5]], %[[ARG6]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
102102
// CHECK: }
103-
// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]]#0 : vector<8x1xf16> to vector<8xf16>
104-
// CHECK: %[[T3:.*]] = vector.shape_cast %[[T1]]#1 : vector<16x1xf16> to vector<16xf16>
105-
// CHECK: %[[T4:.*]] = vector.shape_cast %[[T1]]#2 : vector<8x1xf32> to vector<8xf32>
103+
// CHECK-DAG: %[[T2:.*]] = vector.shape_cast %[[T1]]#0 : vector<8x1xf16> to vector<8xf16>
104+
// CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[T1]]#1 : vector<16x1xf16> to vector<16xf16>
105+
// CHECK-DAG: %[[T4:.*]] = vector.shape_cast %[[T1]]#2 : vector<8x1xf32> to vector<8xf32>
106106
// CHECK: %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[T4]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
107107
// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG3]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
108108
// CHECK: xegpu.store_nd %[[T5]], %[[T6]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -120,11 +120,11 @@ gpu.func @test_dpas(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>, %arg3: ve
120120
// CHECK-LABEL: gpu.func @load_dpas_store
121121
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
122122
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
123-
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
123+
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
124124
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
125125
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
126-
// CHECK: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
127-
// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
126+
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
127+
// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
128128
// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
129129
gpu.module @test {
130130
gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg3: memref<8x16xf32>){

0 commit comments

Comments
 (0)