Skip to content

Commit 4d2c284

Browse files
committed
Add alignment handling
Signed-off-by: dchigarev <[email protected]>
1 parent d97a3c2 commit 4d2c284

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,12 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
620620
Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
621621
Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
622622

623+
if (auto alignment = gatherOp.getAlignment()) {
624+
flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
625+
alignment.value())
626+
.getResult();
627+
}
628+
623629
auto xeGatherOp = xegpu::LoadGatherOp::create(
624630
rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
625631
/*chunk_size=*/IntegerAttr{},
@@ -653,6 +659,12 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
653659
Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
654660
Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
655661

662+
if (auto alignment = scatterOp.getAlignment()) {
663+
flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
664+
alignment.value())
665+
.getResult();
666+
}
667+
656668
xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
657669
flatMemref, localOffsets, scatterOp.getMask(),
658670
/*chunk_size=*/IntegerAttr{},

mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,30 @@ gpu.func @no_load_non_unit_inner_stride(
158158
// CHECK: vector.gather
159159
}
160160

161+
// -----
162+
gpu.module @xevm_module {
163+
gpu.func @load_1D_aligned(%source: memref<8x16x32xf32>,
164+
%off1: index, %off2: index, %off3: index,
165+
%indices: vector<8xindex>, %mask: vector<8xi1>,
166+
%pass_thru: vector<8xf32>) -> vector<8xf32> {
167+
%0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
168+
%pass_thru {alignment = 256} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
169+
gpu.return %0 : vector<8xf32>
170+
}
171+
// CHECK-LABEL: @load_1D_aligned(
172+
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
173+
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
174+
// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>
175+
// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>
176+
// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
177+
// CHECK-COUNT-2: arith.muli {{.*}} : index
178+
// CHECK-COUNT-2: arith.addi {{.*}} : index
179+
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
180+
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
181+
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
182+
// CHECK: %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
183+
// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
184+
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
185+
// CHECK: gpu.return %[[RES]] : vector<8xf32>
186+
}
187+

mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,26 @@ gpu.func @no_store_non_unit_inner_stride(
123123
// CHECK-LABEL: @no_store_non_unit_inner_stride(
124124
// CHECK: vector.scatter
125125
}
126+
127+
// -----
128+
gpu.module @xevm_module {
129+
gpu.func @store_1D_aligned(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
130+
%off1: index, %off2: index, %off3: index,
131+
%indices: vector<8xindex>, %mask: vector<8xi1>) {
132+
vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {alignment = 256}
133+
: memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
134+
gpu.return
135+
}
136+
// CHECK-LABEL: @store_1D_aligned(
137+
// CHECK-SAME: %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
138+
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
139+
// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
140+
// CHECK-COUNT-2: arith.muli {{.*}} : index
141+
// CHECK-COUNT-2: arith.addi {{.*}} : index
142+
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
143+
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
144+
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
145+
// CHECK: %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
146+
// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
147+
// CHECK: gpu.return
148+
}

0 commit comments

Comments
 (0)