Skip to content

Commit b3ba670

Browse files
committed
Address feedback
1 parent b8da87e commit b3ba670

File tree

4 files changed

+95
-49
lines changed

4 files changed

+95
-49
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
3333
"Print the result of the subgroup map propagation analysis and exit.">];
3434
}
3535

36-
def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute", "::mlir::gpu::GPUModuleOp"> {
36+
def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
3737
let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
3838
let description = [{
3939
This transform pass distributes the workgroup level computation to

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
//===----------------------------------------------------------------------===//
88
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
99

10+
#include "mlir/Dialect/Affine/Utils.h"
11+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1012
#include "mlir/Dialect/Index/IR/IndexDialect.h"
13+
#include "mlir/Dialect/Index/IR/IndexOps.h"
1114
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1215
#include "mlir/Dialect/Utils/IndexingUtils.h"
1316
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1417
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
1518
#include "mlir/Transforms/DialectConversion.h"
16-
#include <mlir/Dialect/GPU/IR/GPUDialect.h>
17-
#include <mlir/Dialect/Index/IR/IndexOps.h>
1819

1920
namespace mlir {
2021
namespace xegpu {
@@ -70,15 +71,6 @@ namespace {
7071
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
7172
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
7273

73-
// Convert linear subgroup ID to 2D coordinates
74-
// TODO: Delinearize for nD
75-
SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
76-
Location loc, Value sgID,
77-
Value sgDimX, Value sgDimY) const {
78-
return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
79-
rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
80-
}
81-
8274
// Calculate offset for each subgroup
8375
SmallVector<OpFoldResult>
8476
calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
@@ -144,7 +136,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
144136

145137
// TODO : Handle order attribute
146138
// Get the subgroup ID
147-
auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
139+
auto linearSgId =
140+
rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
148141

149142
// Create constants for layout dimensions
150143
SmallVector<Value> sgLayoutDim(sgLayout.size());
@@ -156,9 +149,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
156149
sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
157150
}
158151

159-
// Delinearize the 1D subgroup id into 2d
160-
SmallVector<Value> sgIds = delinearizeSubgroupId(
161-
rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
152+
auto deLinearizeSgId =
153+
affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
154+
if (failed(deLinearizeSgId))
155+
return failure();
156+
SmallVector<Value> sgIds = *deLinearizeSgId;
162157

163158
// Calculate distribution unit shape and local offsets for subgroup
164159
SmallVector<int64_t> distUnitShape(sgLayout.size());
@@ -267,9 +262,9 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
267262
if (!originalLayout)
268263
return failure();
269264

265+
size_t i = 0;
270266
SmallVector<Value> newDpasOps;
271267
for (auto aVec : adaptor.getLhs()) {
272-
size_t i = 0;
273268
for (auto bVec : adaptor.getRhs()) {
274269

275270
llvm::SmallVector<Value> operands({aVec, bVec});

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ gpu.module @test_round_robin_assignment {
66
gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
77
// CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
88
// CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
9-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
9+
// CHECK-NOT: xegpu.create_nd_tdesc
10+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
11+
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
1012
gpu.return
1113
}
1214

@@ -17,18 +19,26 @@ gpu.module @test_round_robin_assignment {
1719
// CHECK-COUNT-12: xegpu.load_nd %{{.*}}
1820
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
1921
// CHECK-SAME-COUNT-12: -> vector<2x2xf32>
20-
%load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
22+
// CHECK-NOT: xegpu.load_nd
23+
%load = xegpu.load_nd %tdesc
24+
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
25+
-> vector<24x32xf32>
2126
gpu.return
2227
}
2328

2429
// CHECK-LABEL: test_store_nd
2530
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
2631
gpu.func @test_store_nd(%src: memref<24x32xf32>) {
27-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
32+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
33+
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
2834
// CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
2935
// CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
30-
%load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
31-
xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
36+
// CHECK-NOT: xegpu.store_nd
37+
%load = xegpu.load_nd %tdesc
38+
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
39+
-> vector<24x32xf32>
40+
xegpu.store_nd %load, %tdesc
41+
: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
3242
gpu.return
3343
}
3444

@@ -38,7 +48,9 @@ gpu.module @test_round_robin_assignment {
3848
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
3949
// CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
4050
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
41-
%update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
51+
// CHECK-NOT: xegpu.update_nd_offset
52+
%update = xegpu.update_nd_offset %tdesc, [0, 16]
53+
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
4254
gpu.return
4355
}
4456

@@ -47,28 +59,45 @@ gpu.module @test_round_robin_assignment {
4759
gpu.func @test_dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
4860
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
4961
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
62+
// CHECK-NOT: xegpu.create_nd_tdesc
5063
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
5164
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
65+
// CHECK-NOT: xegpu.create_nd_tdesc
5266
// CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
5367
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
68+
// CHECK-NOT: xegpu.create_nd_tdesc
5469
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
5570
// CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
5671
// CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
57-
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
58-
%load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
59-
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
60-
%load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
61-
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
62-
%dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
72+
// CHECK-NOT: xegpu.dpas
73+
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
74+
-> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
75+
%load_a = xegpu.load_nd %tdesc_a
76+
: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
77+
-> vector<8x8xf32>
78+
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
79+
-> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
80+
%load_b = xegpu.load_nd %tdesc_b
81+
: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
82+
-> vector<8x8xf32>
83+
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
84+
-> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
85+
%dpas = xegpu.dpas %load_a, %load_b
86+
{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
87+
: vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
6388
gpu.return
6489
}
6590

6691
// CHECK-LABEL: test_prefetch_nd_tdesc
6792
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
6893
gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
69-
// CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
70-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
71-
xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
94+
// CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
95+
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
96+
// CHECK-NOT: xegpu.prefetch_nd
97+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
98+
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
99+
xegpu.prefetch_nd %tdesc
100+
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
72101
gpu.return
73102
}
74103
}

0 commit comments

Comments
 (0)