Skip to content

Commit 2153a8a

Browse files
committed
Address feedback
1 parent 7f4e202 commit 2153a8a

File tree

3 files changed

+94
-69
lines changed

3 files changed

+94
-69
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
1+
//===- XeGPUWgToSg.cpp - XeGPU Workgroup to Subgroup Pass -----------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -25,15 +25,10 @@ namespace xegpu {
2525
} // namespace xegpu
2626
} // namespace mlir
2727

28-
#define DEBUG_TYPE "xegpu-wg-to-sg"
29-
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
30-
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
31-
3228
using namespace mlir;
3329

3430
namespace {
3531

36-
// clang-format off
3732
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
3833
/// from a workgroup descriptor. It replaces the offsets and sizes with
3934
/// appropriate values for the subgroup.
@@ -42,11 +37,14 @@ namespace {
4237
/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
4338
/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
4439
/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
45-
/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
40+
/// is converted to 9 subgroup level operations based on the sg_layout &
41+
/// sg_data:
4642
/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
47-
/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
43+
/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
44+
/// lane_data = [1, 1]>>
4845
///
49-
/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
46+
/// The sg_layout and sg_data attributes are dropped after the pass as they are
47+
/// no longer needed.
5048
///
5149
/// 24x24 matrix distribution example:
5250
/// sg_layout = [4, 4], sg_data = [2, 2]
@@ -69,9 +67,8 @@ namespace {
6967
/// | 2x2 2x2 2x2 2x2 |
7068
/// +------------------------+
7169
///
72-
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
73-
/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
74-
// clang-format on
70+
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
71+
/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
7572
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
7673
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
7774

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,70 @@
11
// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
22

33
gpu.module @test_round_robin_assignment {
4-
// CHECK: test_create_nd_tdesc
5-
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
4+
// CHECK-LABEL: test_create_nd_tdesc
5+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
66
gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
7-
// CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
7+
// CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
8+
// CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
89
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
910
gpu.return
1011
}
1112

12-
// CHECK: test_load_nd_tdesc
13-
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
13+
// CHECK-LABEL: test_load_nd_tdesc
14+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
1415
gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
1516
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
16-
// CHECK-COUNT-12: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
17+
// CHECK-COUNT-12: xegpu.load_nd %{{.*}}
18+
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
19+
// CHECK-SAME-COUNT-12: -> vector<2x2xf32>
1720
%load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
1821
gpu.return
1922
}
2023

21-
// CHECK: test_store_nd
22-
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
24+
// CHECK-LABEL: test_store_nd
25+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
2326
gpu.func @test_store_nd(%src: memref<24x32xf32>) {
2427
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
25-
// CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
28+
// CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
29+
// CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
2630
%load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
2731
xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
2832
gpu.return
2933
}
3034

31-
// CHECK: test_update_nd
32-
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
35+
// CHECK-LABEL: test_update_nd
36+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
3337
gpu.func @test_update_nd(%src: memref<24x32xf32>){
3438
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
35-
// CHECK-COUNT-12: %[[UPDATE:.*]] = xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
39+
// CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
40+
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
3641
%update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
3742
gpu.return
3843
}
3944

40-
// CHECK: test_dpas
41-
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
42-
// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
43-
// CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
44-
gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
45-
// CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
46-
// CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
47-
// CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
48-
// CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
49-
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
50-
%load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
51-
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
52-
%load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
53-
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x24xf32> -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
54-
%dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
45+
// CHECK-LABEL: test_dpas
46+
// CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
47+
gpu.func @test_dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
48+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
49+
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
50+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
51+
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
52+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
53+
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
54+
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
55+
// CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
56+
// CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
57+
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
58+
%load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
59+
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
60+
%load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
61+
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
62+
%dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
5563
gpu.return
5664
}
5765

58-
// CHECK: test_prefetch_nd_tdesc
59-
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
66+
// CHECK-LABEL: test_prefetch_nd_tdesc
67+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
6068
gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
6169
// CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
6270
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>

0 commit comments

Comments
 (0)