Skip to content

Commit f98dc9a

Browse files
silee2Priyanshu3820
authored andcommitted
[MLIR][Conversion] XeGPU to XeVM: Use adaptor for getting base address from memref. (llvm#168610)
adaptor already lowers memref to base address. Conversion patterns should use it instead of generating code to get base address from memref.
1 parent 11bed0d commit f98dc9a

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,9 @@ class CreateNdDescToXeVMPattern
194194
if (!sourceMemrefTy.hasRank()) {
195195
return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
196196
}
197-
baseAddr =
198-
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
199-
// Cast index to i64.
200-
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
197+
// Access adaptor after failure check to avoid rolling back generated code
198+
// for materialization cast.
199+
baseAddr = adaptor.getSource();
201200
} else {
202201
baseAddr = adaptor.getSource();
203202
if (baseAddr.getType() != i64Ty) {

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ gpu.module @create_nd_tdesc {
77
// CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
88
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
99
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
10+
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
11+
// CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
1012
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
1113
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
1214
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -27,9 +29,9 @@ gpu.module @create_nd_tdesc {
2729
// CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
2830
%srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
2931

30-
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
3132
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
3233
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
34+
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
3335
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
3436
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
3537
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
@@ -52,8 +54,6 @@ gpu.module @create_nd_tdesc {
5254
// CHECK: %[[C16:.*]] = arith.constant 16 : index
5355
%BLOCK_DMODEL = arith.constant 16 : index
5456
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
55-
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
56-
// CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
5757
// CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
5858
// CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
5959
// CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32

mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ gpu.module @load_store_check {
99

1010
// CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32>
1111
%srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32>
12+
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
13+
// CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
1214
// CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32>
1315
%dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32>
16+
// CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
17+
// CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
1418

15-
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
16-
// CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
1719
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
1820
// CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64
1921
// CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
@@ -22,8 +24,6 @@ gpu.module @load_store_check {
2224
%loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
2325
: !xegpu.tensor_desc<32xf32> -> vector<2xf32>
2426

25-
// CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
26-
// CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
2727
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
2828
// CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64
2929
// CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>

mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
22

33
gpu.module @load_store_check {
4+
// CHECK-LABEL: gpu.func @load_store(
5+
// CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
46
gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
57
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
68
%dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
79

8-
// CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
10+
// CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
11+
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
12+
// CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
13+
// CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
14+
// CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
15+
// CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
916
// CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
1017
// CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
1118
// CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
@@ -42,9 +49,8 @@ gpu.module @load_store_check {
4249
//CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
4350
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
4451

45-
// CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
4652
// CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
47-
// CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
53+
// CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
4854
// CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
4955
// CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
5056
// CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>

0 commit comments

Comments
 (0)