Skip to content

Commit 4ce2415

Browse files
committed
Correct memref.subview test
Signed-off-by: dchigarev <[email protected]>
1 parent 78b057d commit 4ce2415

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,11 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
152152
// -----
153153
gpu.module @xevm_module {
154154
gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
155-
%off1: index, %off2: index,
155+
%memref_off: index, %off1: index, %off2: index,
156156
%indices: vector<8xindex>,
157157
%mask: vector<8xi1>,
158158
%pass_thru: vector<8xf16>) -> vector<8xf16> {
159-
%subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
159+
%subview = memref.subview %source[%memref_off, %memref_off] [256, 256] [1, 1]
160160
: memref<4096x4096xf16>
161161
to memref<256x256xf16, strided<[4096, 1], offset: ?>>
162162
%0 = vector.gather %subview[%off1, %off2][%indices], %mask, %pass_thru
@@ -167,15 +167,15 @@ gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
167167
}
168168
// CHECK-LABEL: @gather_from_subview(
169169
// CHECK-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
170-
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
170+
// CHECK-SAME: %[[MEMREF_OFF:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
171171
// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>,
172172
// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
173173
// CHECK-SAME: %[[PASS:.+]]: vector<8xf16>) -> vector<8xf16> {
174-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
174+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[MEMREF_OFF]], %[[MEMREF_OFF]]] [256, 256] [1, 1]
175175
// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
176-
// CHECK: arith.muli {{.*}} : index
176+
// CHECK: arith.muli {{.*}}%[[OFF1]]{{.*}} : index
177177
// CHECK: arith.addi %[[OFFSET]]{{.*}} : index
178-
// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}} : index
178+
// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}}%[[OFF2]]{{.*}} : index
179179
// CHECK: %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
180180
// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
181181
// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index

mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,10 @@ gpu.func @non_unit_inner_stride_3D(
176176
gpu.module @xevm_module {
177177
gpu.func @scatter_into_subview(%vals: vector<8xf16>,
178178
%source: memref<4096x4096xf16>,
179-
%off1: index, %off2: index,
179+
%memref_off: index, %off1: index, %off2: index,
180180
%indices: vector<8xindex>,
181181
%mask: vector<8xi1>) {
182-
%subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
182+
%subview = memref.subview %source[%memref_off, %memref_off] [256, 256] [1, 1]
183183
: memref<4096x4096xf16>
184184
to memref<256x256xf16, strided<[4096, 1], offset: ?>>
185185
vector.scatter %subview[%off1, %off2][%indices], %mask, %vals
@@ -190,13 +190,13 @@ gpu.func @scatter_into_subview(%vals: vector<8xf16>,
190190
// CHECK-LABEL: @scatter_into_subview(
191191
// CHECK-SAME: %[[VALS:.+]]: vector<8xf16>,
192192
// CHECK-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
193-
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
193+
// CHECK-SAME: %[[MEMREF_OFF:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
194194
// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
195-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
195+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[MEMREF_OFF]], %[[MEMREF_OFF]]] [256, 256] [1, 1]
196196
// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
197-
// CHECK: arith.muli {{.*}} : index
197+
// CHECK: arith.muli {{.*}}%[[OFF1]]{{.*}} : index
198198
// CHECK: arith.addi %[[OFFSET]]{{.*}} : index
199-
// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}} : index
199+
// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}}%[[OFF2]]{{.*}} : index
200200
// CHECK: %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
201201
// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
202202
// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index

0 commit comments

Comments
 (0)