Skip to content

Commit 48c9a8a

Browse files
authored
[MLIR][XeGPU] Enable blocking for scatter ops with offsets (llvm#162896)
The unroll patterns for these ops were added in the previous PR but the getTileShape method was not changed to handle these ops and hence blocking pass was not kicking in.
1 parent b86503e commit 48c9a8a

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
161161
xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
162162
return getTileShape(op->getOpResult(0));
163163
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
164-
xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
164+
xegpu::StoreMatrixOp>(op))
165165
return getTileShape(op->getOpOperand(0));
166-
if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
166+
if (isa<xegpu::StoreNdOp>(op))
167167
return getTileShape(op->getOpOperand(1));
168168

169+
// Handle LoadGatherOp and StoreScatterOp (with and without offset)
170+
if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
171+
if (loadGatherOp.getOffsets())
172+
return getTileShape(loadGatherOp->getOpResult(0));
173+
else
174+
return getTileShape(loadGatherOp->getOpOperand(0));
175+
}
176+
177+
if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
178+
return getTileShape(storeScatterOp.getOffsets()
179+
? storeScatterOp->getOpOperand(0)
180+
: storeScatterOp->getOpOperand(1));
181+
169182
if (isa<xegpu::DpasOp>(op)) {
170183
std::optional<SmallVector<int64_t>> aTile =
171184
getTileShape(op->getOpOperand(0));

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,3 +584,101 @@ gpu.module @test_kernel {
584584
gpu.return
585585
}
586586
}
587+
588+
// -----
589+
gpu.module @test_kernel {
590+
// CHECK-LABEL: load_with_offsets
591+
// CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
592+
gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
593+
%cst = arith.constant dense<[
594+
0, 8, 16, 24, 32, 40, 48, 56,
595+
64, 72, 80, 88, 96, 104, 112, 120,
596+
128, 136, 144, 152, 160, 168, 176, 184,
597+
192, 200, 208, 216, 224, 232, 240, 248
598+
]> : vector<32xindex>
599+
600+
%c17 = arith.constant 17: index
601+
%mask = vector.create_mask %c17: vector<32xi1>
602+
%ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
603+
604+
gpu.return %ld : vector<32xf32>
605+
}
606+
}
607+
608+
// -----
609+
gpu.module @test_kernel {
610+
// CHECK-LABEL: store_with_offsets
611+
// CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
612+
gpu.func @store_with_offsets(%src: ui64) {
613+
%cst = arith.constant dense<[
614+
0, 8, 16, 24, 32, 40, 48, 56,
615+
64, 72, 80, 88, 96, 104, 112, 120,
616+
128, 136, 144, 152, 160, 168, 176, 184,
617+
192, 200, 208, 216, 224, 232, 240, 248
618+
]> : vector<32xindex>
619+
620+
%c17 = arith.constant 17: index
621+
%mask = vector.create_mask %c17: vector<32xi1>
622+
623+
%st_vec = arith.constant dense<1023.0>: vector<32xf32>
624+
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [16]>,
625+
layout_operand_2 = #xegpu.layout<inst_data = [16]>,
626+
layout_operand_3 = #xegpu.layout<inst_data = [16]>,
627+
l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
628+
629+
gpu.return
630+
}
631+
}
632+
633+
// -----
634+
gpu.module @test_kernel {
635+
// CHECK-LABEL: load_with_offsets_chunk
636+
// CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
637+
// CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
638+
// CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
639+
// CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
640+
// CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
641+
// CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
642+
gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
643+
%cst = arith.constant dense<[
644+
0, 8, 16, 24, 32, 40, 48, 56,
645+
64, 72, 80, 88, 96, 104, 112, 120,
646+
128, 136, 144, 152, 160, 168, 176, 184,
647+
192, 200, 208, 216, 224, 232, 240, 248
648+
]> : vector<32xindex>
649+
650+
%c17 = arith.constant 17: index
651+
%mask = vector.create_mask %c17: vector<32xi1>
652+
%ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
653+
gpu.return %ld : vector<32x4xf32>
654+
}
655+
}
656+
657+
// -----
658+
gpu.module @test_kernel {
659+
// CHECK-LABEL: store_with_offsets_chunk
660+
// CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32
661+
// CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
662+
// CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
663+
// CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
664+
// CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
665+
// CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
666+
gpu.func @store_with_offsets_chunk(%src: ui64) {
667+
%cst = arith.constant dense<[
668+
0, 8, 16, 24, 32, 40, 48, 56,
669+
64, 72, 80, 88, 96, 104, 112, 120,
670+
128, 136, 144, 152, 160, 168, 176, 184,
671+
192, 200, 208, 216, 224, 232, 240, 248
672+
]> : vector<32xindex>
673+
674+
%c17 = arith.constant 17: index
675+
%mask = vector.create_mask %c17: vector<32xi1>
676+
677+
%st_vec = arith.constant dense<1023.>: vector<32x4xf32>
678+
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout_operand_0 = #xegpu.layout<inst_data = [16, 2]>,
679+
layout_operand_2 = #xegpu.layout<inst_data = [16, 2]>,
680+
layout_operand_3 = #xegpu.layout<inst_data = [16, 2]>,
681+
l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
682+
gpu.return
683+
}
684+
}

0 commit comments

Comments
 (0)