Commit 78b057d
Handle non-unit inner stride
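The gather/scatter lowering to XeGPU no longer bails out when the innermost memref stride is not 1: per-lane indices are multiplied by the innermost stride before being added to the base offset, and constant stride values are only materialized when every stride is static.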
Signed-off-by: dchigarev <[email protected]>
1 parent 62c5c38

3 files changed: +126 −14 lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 16 additions & 11 deletions
@@ -99,19 +99,12 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
 
 // Common preconditions for the lowering of vector.gather and vector.scatter:
 // 1. Source is a memref.
-// 2. The innermost dimension of the memref is contiguous (stride == 1)
 static LogicalResult gatherScatterPreconditions(PatternRewriter &rewriter,
                                                 Operation *op, Type baseType) {
   auto srcTy = dyn_cast<MemRefType>(baseType);
   if (!srcTy)
     return rewriter.notifyMatchFailure(op, "Expects memref source");
 
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
-    return rewriter.notifyMatchFailure(
-        op, "Buffer must be contiguous in the innermost dimension");
-
   return success();
 }
 
@@ -219,9 +212,14 @@ computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
       return {{}, offsetVal};
-    // Wrap static strides as MLIR values
-    for (int64_t s : intStrides)
-      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
+      return ShapedType::isDynamic(strideVal);
+    });
+
+    if (!hasDynamicStrides)
+      for (int64_t s : intStrides)
+        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+
     if (!ShapedType::isDynamic(offset))
       offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
   }
@@ -389,13 +387,20 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
   Value indices = gatScatOp.getIndices();
   VectorType vecType = cast<VectorType>(indices.getType());
 
+  Value strideVector =
+      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
+          .getResult();
+  Value stridedIndices =
+      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
+
   Value baseVector =
       vector::BroadcastOp::create(
           rewriter, loc,
           VectorType::get(vecType.getShape(), rewriter.getIndexType()),
           baseOffset)
           .getResult();
-  return arith::AddIOp::create(rewriter, loc, baseVector, indices).getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
+      .getResult();
 }
 
 template <
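The core of the change is the offset computation: each lane's index is now scaled by the innermost stride before the scalar base offset is added. Below is a minimal standalone sketch of that logic (illustrative only; the patch inlines it in computeOffsets and reads the stride from strides.back(), so the helper name linearizeOffsets and the innerStride parameter are hypothetical):

// Sketch under the same headers/namespaces as VectorToXeGPU.cpp
// (mlir::arith, mlir::vector). Computes, for every lane:
//   offset[lane] = baseOffset + indices[lane] * innerStride
static Value linearizeOffsets(PatternRewriter &rewriter, Location loc,
                              Value indices, Value baseOffset,
                              Value innerStride) {
  auto vecType = cast<VectorType>(indices.getType()); // vector<Nxindex>
  // Splat the scalar innermost stride and scale the indices element-wise.
  Value strideVec =
      vector::BroadcastOp::create(rewriter, loc, vecType, innerStride)
          .getResult();
  Value scaled =
      arith::MulIOp::create(rewriter, loc, strideVec, indices).getResult();
  // Splat the scalar base offset and add it to every lane.
  Value baseVec =
      vector::BroadcastOp::create(rewriter, loc, vecType, baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVec, scaled).getResult();
}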

mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir

Lines changed: 64 additions & 0 deletions
@@ -185,3 +185,67 @@ gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS]] : vector<8xi1>, vector<8xf16>
 // CHECK: gpu.return %[[RES]] : vector<8xf16>
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_1D(
+    %source: memref<32xf32, strided<[?], offset: ?>>,
+    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
+    : memref<32xf32, strided<[?], offset: ?>>,
+      vector<8xindex>, vector<8xi1>, vector<8xf32>
+    into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @non_unit_inner_stride_1D(
+// CHECK-SAME: %[[SRC:.+]]: memref<32xf32, strided<[?], offset: ?>>,
+// CHECK-SAME: %[[OFF1:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>, %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SZ:.+]], %[[STRIDE:.+]] = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF1]], %[[STRIDE]] : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK: %[[STRD_VEC:.+]] = vector.broadcast %[[STRIDE]] : index to vector<8xindex>
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli %[[STRD_VEC]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32xf32, strided<[?], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_3D(
+    %source: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask, %pass_thru
+    : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+      vector<8xindex>, vector<8xi1>, vector<8xf32>
+    into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @non_unit_inner_stride_3D(
+// CHECK-SAME: %[[SRC:.+]]: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF0]], %[[STRIDES]]#0 : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli {{.*}}%[[INDICES]]{{.*}} : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
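In the 3-D gather above, the scalar base offset accumulates as the metadata offset plus %off0 * %STRIDES#0, %off1 * %STRIDES#1, and %off2 * %STRIDES#2 (hence the extra muli/addi pairs), and each lane then loads from base + %indices[lane] * %STRIDES#2, the innermost stride. The scatter tests below exercise the same addressing.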

mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir

Lines changed: 46 additions & 3 deletions
@@ -118,15 +118,58 @@ gpu.func @store_dynamic_source2(%vec: vector<8x16xf32>, %source: memref<?x8x16xf
 
 // -----
 gpu.module @xevm_module {
-gpu.func @no_store_non_unit_inner_stride(
+gpu.func @non_unit_inner_stride_1D(
     %vec: vector<8xf32>, %source: memref<32xf32, strided<[?], offset: ?>>,
     %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>) {
   vector.scatter %source[%off][%indices], %mask, %vec
     : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32>
   gpu.return
 }
-// CHECK-LABEL: @no_store_non_unit_inner_stride(
-// CHECK: vector.scatter
+// CHECK-LABEL: @non_unit_inner_stride_1D(
+// CHECK-SAME: %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<32xf32, strided<[?], offset: ?>>,
+// CHECK-SAME: %[[OFF1:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SZ:.+]], %[[STRIDE:.+]] = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF1]], %[[STRIDE]] : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK: %[[STRD_VEC:.+]] = vector.broadcast %[[STRIDE]] : index to vector<8xindex>
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli %[[STRD_VEC]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32xf32, strided<[?], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_3D(
+    %vec: vector<8xf32>,
+    %source: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+    : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+      vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @non_unit_inner_stride_3D(
+// CHECK-SAME: %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF0]], %[[STRIDES]]#0 : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli {{.*}}%[[INDICES]]{{.*}} : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
 }
 
 // -----
