
Commit dde9f3b

nbpatel authored and mahesh-attarde committed
[MLIR][XeGPU] Add unroll pattern for load_gather and store_scatter with offsets (llvm#159453)
This PR adds unrolling/blocking patterns for load_gather and store_scatter ops with offsets.
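In rough terms, a gathered load whose layout requests smaller per-instruction tiles is rewritten into several narrower loads over split offset and mask vectors. A simplified before/after sketch, condensed from the load_with_offsets test added below (the %offsets/%off*/%m*/%ld* names are illustrative, and cache hints are omitted):

    // Before blocking: one 32-lane gathered load; inst_data = [16] requests 16-lane tiles.
    %ld = xegpu.load %src[%offsets], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>}
          : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>

    // After blocking: offsets and mask are split in half and two 16-lane loads are issued,
    // whose results are recombined into the original vector<32xf32> value.
    %ld0 = xegpu.load %src[%off0], %m0 <{chunk_size = 1 : i64}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
    %ld1 = xegpu.load %src[%off1], %m1 <{chunk_size = 1 : i64}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>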
1 parent 061631e commit dde9f3b

File tree

3 files changed: 340 additions, 41 deletions

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 191 additions & 1 deletion
@@ -537,6 +537,195 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  }
};

/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
/// load).
/// It unrolls the offsets and mask operands accordingly, and creates multiple
/// LoadGatherOp with the unrolled operands.
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered load)
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    // Unroll mask and offsets with correct shape
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      // For chunked loads, mask and offsets have one less dimension
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

/// This pattern handles the unrolling of StoreScatterOp with offsets
/// (scattered store).
/// It unrolls the offsets and mask operands accordingly, and creates multiple
/// StoreScatterOp with the unrolled operands.
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered store)
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    SmallVector<int64_t> targetMaskShape(*targetShape);
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,

@@ -766,6 +955,7 @@ void mlir::xegpu::populateXeGPUUnrollPatterns(
       .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
            UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
            UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
-           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
+           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
+           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
       patterns.getContext(), options);
}
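For the chunked case (chunk_size > 1), the blocked value tile keeps a trailing chunk dimension while the mask and offset tiles drop it; each base offset tile is then reused numNewChunks = chunkSize / blockedChunkSize times, with i * blockedChunkSize added element-wise for chunk i. A minimal sketch of the IR this emits for one extra chunk, assuming chunk_size = 4 blocked to inst_data = [16, 2] as in the chunked tests below (so blockedChunkSize = 2; the %off_* names are illustrative):

    // Increment for the second chunk (i = 1): broadcast i * blockedChunkSize
    // and add it to the base offsets tile.
    %c2    = arith.constant 2 : index
    %inc   = vector.broadcast %c2 : index to vector<16xindex>
    %off_1 = arith.addi %off_0, %inc : vector<16xindex>

This is why the chunked tests below check offset constants such as dense<[2, 10, 18, ...]>, which are the base offsets shifted by the blocked chunk size.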

mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Lines changed: 91 additions & 0 deletions
@@ -210,6 +210,27 @@ gpu.module @test {
    gpu.return %ld : vector<32xf32>
  }

  //-----

  // CHECK-LABEL: load_with_offsets
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
  gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>
    %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>

    gpu.return %ld : vector<32xf32>
  }

  //-----

  // CHECK-LABEL: prefetch

@@ -254,6 +275,28 @@ gpu.module @test {

    gpu.return
  }

  //-----

  // CHECK-LABEL: store_with_offsets
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
  gpu.func @store_with_offsets(%src: ui64) {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>

    %st_vec = arith.constant dense<1023.0>: vector<32xf32>
    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>

    gpu.return
  }

  //-----
  // CHECK-LABEL: create_tdesc_step_chunk

@@ -319,6 +362,29 @@ gpu.module @test {
    gpu.return %ld : vector<32x4xf32>
  }

  //-----
  // CHECK-LABEL: load_with_offsets_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
  // CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
  gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>
    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
    gpu.return %ld : vector<32x4xf32>
  }

  //-----
  // CHECK-LABEL: store_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64

@@ -342,6 +408,31 @@ gpu.module @test {
    gpu.return
  }

  //-----
  // CHECK-LABEL: store_with_offsets_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32>
  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
  // CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
  gpu.func @store_with_offsets_chunk(%src: ui64) {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>

    %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
    gpu.return
  }

  //-----
  // CHECK-LABEL: prefetch_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64