Skip to content

Commit 8941125

Browse files
committed
Enable blocking for scatter ops with offsets
1 parent 8aa4997 commit 8941125

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
161161
xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
162162
return getTileShape(op->getOpResult(0));
163163
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
164-
xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
164+
xegpu::StoreMatrixOp>(op))
165165
return getTileShape(op->getOpOperand(0));
166-
if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
166+
if (isa<xegpu::StoreNdOp>(op))
167167
return getTileShape(op->getOpOperand(1));
168168

169+
// Handle LoadGatherOp and StoreScatterOp (with and without offset)
170+
if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
171+
if (loadGatherOp.getOffsets())
172+
return getTileShape(loadGatherOp->getOpResult(0));
173+
else
174+
return getTileShape(loadGatherOp->getOpOperand(0));
175+
}
176+
177+
if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
178+
return getTileShape(storeScatterOp.getOffsets()
179+
? storeScatterOp->getOpOperand(0)
180+
: storeScatterOp->getOpOperand(1));
181+
169182
if (isa<xegpu::DpasOp>(op)) {
170183
std::optional<SmallVector<int64_t>> aTile =
171184
getTileShape(op->getOpOperand(0));

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,3 +584,105 @@ gpu.module @test_kernel {
584584
gpu.return
585585
}
586586
}
587+
588+
// -----
589+
gpu.module @test_kernel {
590+
// CHECK-LABEL: load_with_offsets
591+
// CHECK-SAME: [[arg0:%.+]]: ui64
592+
// CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
593+
gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
594+
%cst = arith.constant dense<[
595+
0, 8, 16, 24, 32, 40, 48, 56,
596+
64, 72, 80, 88, 96, 104, 112, 120,
597+
128, 136, 144, 152, 160, 168, 176, 184,
598+
192, 200, 208, 216, 224, 232, 240, 248
599+
]> : vector<32xindex>
600+
601+
%c17 = arith.constant 17: index
602+
%mask = vector.create_mask %c17: vector<32xi1>
603+
%ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
604+
605+
gpu.return %ld : vector<32xf32>
606+
}
607+
}
608+
609+
// -----
610+
gpu.module @test_kernel {
611+
// CHECK-LABEL: store_with_offsets
612+
// CHECK-SAME: [[arg0:%.+]]: ui64
613+
// CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
614+
gpu.func @store_with_offsets(%src: ui64) {
615+
%cst = arith.constant dense<[
616+
0, 8, 16, 24, 32, 40, 48, 56,
617+
64, 72, 80, 88, 96, 104, 112, 120,
618+
128, 136, 144, 152, 160, 168, 176, 184,
619+
192, 200, 208, 216, 224, 232, 240, 248
620+
]> : vector<32xindex>
621+
622+
%c17 = arith.constant 17: index
623+
%mask = vector.create_mask %c17: vector<32xi1>
624+
625+
%st_vec = arith.constant dense<1023.0>: vector<32xf32>
626+
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [16]>,
627+
layout_operand_2 = #xegpu.layout<inst_data = [16]>,
628+
layout_operand_3 = #xegpu.layout<inst_data = [16]>,
629+
l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
630+
631+
gpu.return
632+
}
633+
}
634+
635+
// -----
636+
gpu.module @test_kernel {
637+
// CHECK-LABEL: load_with_offsets_chunk
638+
// CHECK-SAME: [[arg0:%.+]]: ui64
639+
// CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
640+
// CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
641+
// CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
642+
// CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
643+
// CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
644+
// CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
645+
gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
646+
%cst = arith.constant dense<[
647+
0, 8, 16, 24, 32, 40, 48, 56,
648+
64, 72, 80, 88, 96, 104, 112, 120,
649+
128, 136, 144, 152, 160, 168, 176, 184,
650+
192, 200, 208, 216, 224, 232, 240, 248
651+
]> : vector<32xindex>
652+
653+
%c17 = arith.constant 17: index
654+
%mask = vector.create_mask %c17: vector<32xi1>
655+
%ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
656+
gpu.return %ld : vector<32x4xf32>
657+
}
658+
}
659+
660+
// -----
661+
gpu.module @test_kernel {
662+
// CHECK-LABEL: store_with_offsets_chunk
663+
// CHECK-SAME: [[arg0:%.+]]: ui64
664+
// CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32
665+
// CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
666+
// CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
667+
// CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
668+
// CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
669+
// CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
670+
gpu.func @store_with_offsets_chunk(%src: ui64) {
671+
%cst = arith.constant dense<[
672+
0, 8, 16, 24, 32, 40, 48, 56,
673+
64, 72, 80, 88, 96, 104, 112, 120,
674+
128, 136, 144, 152, 160, 168, 176, 184,
675+
192, 200, 208, 216, 224, 232, 240, 248
676+
]> : vector<32xindex>
677+
678+
%c17 = arith.constant 17: index
679+
%mask = vector.create_mask %c17: vector<32xi1>
680+
681+
%st_vec = arith.constant dense<1023.>: vector<32x4xf32>
682+
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout_operand_0 = #xegpu.layout<inst_data = [16, 2]>,
683+
layout_operand_2 = #xegpu.layout<inst_data = [16, 2]>,
684+
layout_operand_3 = #xegpu.layout<inst_data = [16, 2]>,
685+
l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
686+
gpu.return
687+
}
688+
}

0 commit comments

Comments
 (0)