Skip to content

Commit 7019948

Browse files
committed
Skip XeGPU nd ops when trimming leading unit dims
1 parent 19278f9 commit 7019948

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,17 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
147147
if (layout && layout.isForSubgroup()) {
148148
if (!layout.getEffectiveInstDataAsInt().empty()) {
149149
SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
150-
// Remove all leading unit dimensions from inst_data
151-
while (!instData.empty() && instData.front() == 1)
152-
instData.erase(instData.begin());
150+
// Remove leading unit dimensions from inst_data
151+
// Skip it for xegpu nd ops since it will be 2D
152+
Operation *definingOp = value.getDefiningOp();
153+
bool skipLeadingUnitDimRemoval =
154+
definingOp &&
155+
(isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
156+
xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
157+
if (!skipLeadingUnitDimRemoval) {
158+
while (!instData.empty() && instData.front() == 1)
159+
instData.erase(instData.begin());
160+
}
153161
return instData;
154162
}
155163

@@ -359,7 +367,6 @@ void XeGPUBlockingPass::runOnOperation() {
359367
// To create a new attribute with a different chunk_size:
360368
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
361369
ctx, tdescTy.getMemorySpace(), blockedChunkSize);
362-
363370
encoding = newEncoding;
364371
}
365372
}

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -711,31 +711,44 @@ gpu.module @test_kernel {
711711
}
712712

713713
// -----
714+
#l = #xegpu.layout<inst_data = [1, 16]>
714715
gpu.module @test_kernel {
715-
// CHECK-LABEL: store_scatter
716-
// CHECK-SAME: [[arg0:%.+]]: ui64
717-
// CHECK-DAG: [[cst:%.+]] = arith.constant dense<true> : vector<16xi1>
718-
// CHECK-DAG: [[cst_0:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
719-
// CHECK-DAG: [[cst_1:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
720-
// CHECK-DAG: [[cst_2:%.+]] = arith.constant dense<1.023000e+03> : vector<16xf32>
721-
// CHECK: xegpu.store [[cst_2]], [[arg0]][[[cst_0]]], [[cst]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
722-
// CHECK: xegpu.store [[cst_2]], [[arg0]][[[cst_1]]], [[cst]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
723-
gpu.func @store_scatter(%src: ui64) {
724-
%cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
725-
[0, 8, 16, 24, 32, 40, 48, 56,
726-
64, 72, 80, 88, 96, 104, 112, 120,
727-
128, 136, 144, 152, 160, 168, 176, 184,
728-
192, 200, 208, 216, 224, 232, 240, 248]
729-
]]> : vector<1x1x32xindex>
716+
// CHECK-LABEL: load_store_nd_with_offsets
717+
// CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
718+
// CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
719+
// CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
720+
// CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
721+
// CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
722+
// CHECK: [[tdesc_b:%.+]] = xegpu.create_nd_tdesc [[arg1]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
723+
// CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
724+
// CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
725+
// CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
726+
// CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
727+
// CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
728+
// CHECK: [[cast_a0:%.+]] = vector.shape_cast [[ld_a0]] : vector<1x16xf32> to vector<16xf32>
729+
// CHECK: [[cast_b0:%.+]] = vector.shape_cast [[ld_b0]] : vector<1x16xf32> to vector<16xf32>
730+
// CHECK: [[add0:%.+]] = arith.addf [[cast_a0]], [[cast_b0]] : vector<16xf32>
731+
// CHECK: [[ins0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x32xf32>
732+
// CHECK: [[cast_a1:%.+]] = vector.shape_cast [[ld_a1]] : vector<1x16xf32> to vector<16xf32>
733+
// CHECK: [[cast_b1:%.+]] = vector.shape_cast [[ld_b1]] : vector<1x16xf32> to vector<16xf32>
734+
// CHECK: [[add1:%.+]] = arith.addf [[cast_a1]], [[cast_b1]] : vector<16xf32>
735+
// CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
736+
// CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
737+
// CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
738+
// CHECK: xegpu.store_nd [[ext0]], [[tdesc_c]][[[c0]], [[c0]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
739+
// CHECK: xegpu.store_nd [[ext1]], [[tdesc_c]][[[c0]], [[c16]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
740+
gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
741+
%c0 = arith.constant 0 : index
730742

731-
%mask = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<true> : vector<1x1x32xi1>
743+
%a_tdesc = xegpu.create_nd_tdesc %A : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
744+
%b_tdesc = xegpu.create_nd_tdesc %B : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
745+
%c_tdesc = xegpu.create_nd_tdesc %C : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
732746

733-
%st_vec = arith.constant dense<1023.0> : vector<1x1x32xf32>
734-
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [1, 1, 16]>,
735-
layout_operand_2 = #xegpu.layout<inst_data = [1, 1, 16]>,
736-
layout_operand_3 = #xegpu.layout<inst_data = [1, 1, 16]>,
737-
l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
747+
%a = xegpu.load_nd %a_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
748+
%b = xegpu.load_nd %b_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
738749

739-
gpu.return
750+
%result = arith.addf %a, %b {layout_result_0 = #l} : vector<1x32xf32>
751+
xegpu.store_nd %result, %c_tdesc[%c0, %c0] : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #l>
752+
gpu.return
740753
}
741-
}
754+
}

0 commit comments

Comments (0)