@@ -152,11 +152,11 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
152152// -----
153153gpu.module @xevm_module {
154154gpu.func @gather_from_subview (%source: memref <4096 x4096 xf16 >,
155- %off1: index , %off2: index ,
155+ %memref_off: index , % off1: index , %off2: index ,
156156 %indices: vector <8 xindex >,
157157 %mask: vector <8 xi1 >,
158158 %pass_thru: vector <8 xf16 >) -> vector <8 xf16 > {
159- %subview = memref.subview %source [%off1 , %off2 ] [256 , 256 ] [1 , 1 ]
159+ %subview = memref.subview %source [%memref_off , %memref_off ] [256 , 256 ] [1 , 1 ]
160160 : memref <4096 x4096 xf16 >
161161 to memref <256 x256 xf16 , strided <[4096 , 1 ], offset : ?>>
162162 %0 = vector.gather %subview [%off1 , %off2 ][%indices ], %mask , %pass_thru
@@ -167,15 +167,15 @@ gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
167167}
168168// CHECK-LABEL: @gather_from_subview(
169169// CHECK-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
170- // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
170+ // CHECK-SAME: %[[MEMREF_OFF:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
171171// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>,
172172// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
173173// CHECK-SAME: %[[PASS:.+]]: vector<8xf16>) -> vector<8xf16> {
174- // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
174+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[MEMREF_OFF]], %[[MEMREF_OFF]]] [256, 256] [1, 1]
175175// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
176- // CHECK: arith.muli {{.*}} : index
176+ // CHECK: arith.muli {{.*}}%[[OFF1]]{{.*}} : index
177177// CHECK: arith.addi %[[OFFSET]]{{.*}} : index
178- // CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}} : index
178+ // CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}}%[[OFF2]]{{.*}} : index
179179// CHECK: %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
180180// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
181181// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
0 commit comments