Skip to content

Commit e023c1a

Browse files
committed
add a corner unit test
1 parent d1584fc commit e023c1a

File tree

3 files changed

+78
-19
lines changed

3 files changed

+78
-19
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,44 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
185185
if (isa<LoopLikeOpInterface>(op))
186186
return false;
187187

188-
for (auto &opr : op->getOpOperands()) {
188+
auto isUnrollable = [&](Value value,
189+
ArrayRef<int64_t> tileShape) -> std::optional<bool> {
190+
Type valTy = value.getType();
191+
if (auto tdesc = dyn_cast<xegpu::TensorDescType>(valTy)) {
192+
xegpu::LayoutAttr layout = tdesc.getLayoutAttr();
193+
if (!layout)
194+
return std::nullopt;
195+
if (layout.isWgLayout())
196+
return false;
197+
if (layout.getInstData())
198+
return true;
199+
}
200+
201+
auto shapedType = dyn_cast<ShapedType>(valTy);
202+
if (shapedType && !llvm::equal(tileShape, shapedType.getShape()))
203+
return true;
204+
205+
return std::nullopt;
206+
};
207+
208+
for (OpOperand &opr : op->getOpOperands()) {
189209
std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
190-
auto shapedType = dyn_cast<ShapedType>(opr.get().getType());
191-
if (!shapedType || !tileShape)
210+
if (!tileShape)
192211
continue;
193212

194-
if (!llvm::equal(*tileShape, shapedType.getShape()))
195-
return true;
213+
std::optional<bool> unrollable = isUnrollable(opr.get(), *tileShape);
214+
if (unrollable.has_value())
215+
return unrollable.value();
196216
}
197217

198-
for (auto result : op->getOpResults()) {
218+
for (OpResult result : op->getOpResults()) {
199219
std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
200-
auto shapedType = dyn_cast<ShapedType>(result.getType());
201-
if (!shapedType || !tileShape)
220+
if (!tileShape)
202221
continue;
203222

204-
if (!llvm::equal(*tileShape, shapedType.getShape()))
205-
return true;
223+
std::optional<bool> unrollable = isUnrollable(result, *tileShape);
224+
if (unrollable.has_value())
225+
return unrollable.value();
206226
}
207227
return false;
208228
}

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
136136
ArrayRef<int64_t> shape = tdescTy.getShape();
137137

138138
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
139-
if (!targetShape || llvm::equal(*targetShape, shape))
139+
if (!targetShape)
140140
return failure();
141141

142142
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
@@ -187,10 +187,9 @@ struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
187187
PatternRewriter &rewriter) const override {
188188
Location loc = op.getLoc();
189189
xegpu::TensorDescType tdescTy = op.getTensorDescType();
190-
ArrayRef<int64_t> shape = tdescTy.getShape();
191190

192191
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
193-
if (!targetShape || llvm::equal(*targetShape, shape))
192+
if (!targetShape)
194193
return failure();
195194

196195
SmallVector<Type> convertedTdescTypes =
@@ -216,10 +215,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
216215
PatternRewriter &rewriter) const override {
217216
Location loc = op.getLoc();
218217
xegpu::TensorDescType tdescTy = op.getTensorDescType();
219-
ArrayRef<int64_t> shape = tdescTy.getShape();
220218

221219
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
222-
if (!targetShape || llvm::equal(*targetShape, shape))
220+
if (!targetShape)
223221
return failure();
224222

225223
SmallVector<Type> convertedTdescTypes =
@@ -243,10 +241,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
243241
Location loc = op.getLoc();
244242
VectorType valueTy = op.getType();
245243
xegpu::TensorDescType tdescTy = op.getTensorDescType();
246-
ArrayRef<int64_t> shape = tdescTy.getShape();
247244

248245
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
249-
if (!targetShape || llvm::equal(*targetShape, shape))
246+
if (!targetShape)
250247
return failure();
251248

252249
Type elemTy = tdescTy.getElementType();
@@ -278,10 +275,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
278275
Location loc = op.getLoc();
279276
VectorType valueTy = op.getValueType();
280277
xegpu::TensorDescType tdescTy = op.getTensorDescType();
281-
ArrayRef<int64_t> shape = tdescTy.getShape();
282278

283279
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
284-
if (!targetShape || llvm::equal(*targetShape, shape))
280+
if (!targetShape)
285281
return failure();
286282

287283
SmallVector<Type> convertedValTypes =

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,49 @@ gpu.module @test_kernel {
8282
}
8383
}
8484

85+
// -----
86+
#l1 = #xegpu.layout<inst_data = [8, 16]>
87+
#l2 = #xegpu.layout<inst_data = [16, 16]>
88+
gpu.module @test_kernel {
89+
gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
90+
%c0 = arith.constant 0 : index
91+
%c8 = arith.constant 8 : index
92+
%c16 = arith.constant 16 : index
93+
%c32 = arith.constant 32 : index
94+
%c1024 = arith.constant 1024 : index
95+
%block_id_x = gpu.block_id x
96+
%block_id_y = gpu.block_id y
97+
%m = arith.muli %block_id_x, %c8 : index
98+
%n = arith.muli %block_id_y, %c32 : index
99+
100+
%c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x32xf32, #l1>
101+
102+
//CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
103+
%c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<8x32xf32, #l1> -> vector<8x32xf32>
104+
105+
%a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16, #l1>
106+
%b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l2>
107+
%out:3 = scf.for %k = %c0 to %c1024 step %c16
108+
iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
109+
-> (!xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>) {
110+
//CHECK: %22 = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
111+
%a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16, #l1> -> vector<8x16xf16>
112+
//CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
113+
%b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l2> -> vector<16x32xf16>
114+
%c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1}: vector<8x16xf16>, vector<16x32xf16>, vector<8x32xf32> -> vector<8x32xf32>
115+
//CHECK: xegpu.update_nd_offset {{.*}} [%c0, %c32] : !xegpu.tensor_desc<8x16xf16>
116+
%a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16, #l1>
117+
//CHECK-COUNT-2: xegpu.update_nd_offset {{.*}} [%c32, %c0] : !xegpu.tensor_desc<16x16xf16>
118+
%b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<16x32xf16, #l2>
119+
scf.yield %a_next_tdesc, %b_next_tdesc, %c
120+
: !xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>
121+
}
122+
//CHECK-COUNT-2: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
123+
xegpu.store_nd %out#2, %c_tdesc: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #l1>
124+
gpu.return
125+
}
126+
}
127+
85128
// -----
86129
#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
87130
#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>

0 commit comments

Comments
 (0)