Commit d62c043

Order support in xetile.init_tile/load_tile (#707)
order support
1 parent 9779a93 commit d62c043
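
This commit teaches xetile.init_tile and xetile.load_tile to honor the tile's order attribute: order = [1, 0] is the default row-major layout, while order = [0, 1] marks a tile whose source memref is column-major (e.g. strided<[1, 64]>). Adapted from the lit test added in this commit (names as in the test):

    %B_tile = xetile.init_tile %B[%c0, %c0]
        : memref<64x128xf16, strided<[1, 64], offset: 0>>
        -> !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0, 1]>>
    %B_val = xetile.load_tile %B_tile
        : !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0, 1]>> -> vector<16x32xf16>

During lowering to XeGPU, such a tile is described in its in-memory (swapped) shape and loaded with a transposed xegpu.load_nd, as the diffs below show.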

File tree

5 files changed: +91 -16 lines changed

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 37 additions & 4 deletions
@@ -294,16 +294,23 @@ class SgInitTileOpPattern
     auto loc = op.getLoc();
     auto source = op.getSource();
     auto tileTy = op.getType();
-    auto innerBlk = tileTy.getInnerBlocks();
-    auto shape = tileTy.getShape();
+    auto innerBlocks = tileTy.getInnerBlocks();
+    auto shape = llvm::to_vector(tileTy.getShape());
     auto indexType = rewriter.getIndexType();

     if (tileTy.getRank() != 2)
       return op.emitOpError("The tile shape should be 2D.");

-    if (!innerBlk || innerBlk.size() != 2)
+    if (!innerBlocks || innerBlocks.size() != 2)
       return op.emitOpError("Missing valid innerBlock for the tile in op.");

+    // Need to make a copy, so we can swap values.
+    auto innerBlk = llvm::to_vector(innerBlocks.asArrayRef());
+    if (tileTy.getOrder().asArrayRef() == mlir::ArrayRef({0, 1})) {
+      std::swap(innerBlk[0], innerBlk[1]);
+      std::swap(shape[0], shape[1]);
+    }
+
     // using array_length for load if dim1 of innerBlocks
     // is smaller than dim 1 of shape.
     auto array_length =
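
Worked through on the B tile from the new test: a logical 16x32 f16 tile over a column-major source gets its shape and inner blocks swapped, so the tensor descriptor is created in the in-memory 32x16 shape (the test's tensor_desc<32x16xf16>). A minimal standalone C++ sketch of just this shape handling (no MLIR dependencies; names are illustrative):

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<int64_t> shape = {16, 32};    // logical tile shape (rows, cols)
      std::vector<int64_t> order = {0, 1};      // column-major order attribute
      if (order == std::vector<int64_t>{0, 1})  // order [0, 1]: describe the tile
        std::swap(shape[0], shape[1]);          // in its in-memory layout
      assert(shape[0] == 32 && shape[1] == 16); // -> tensor_desc<32x16xf16>
      return 0;
    }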
@@ -330,6 +337,7 @@ class SgInitTileOpPattern

     auto offsetsX = offsets[0];
     auto offsetsY = offsets[1];
+
     auto tDescTy = xegpu::TensorDescType::get(
         innerBlk, tileTy.getElementType(), xegpu::MemoryScope::GLOBAL,
         array_length, true /*boundary_check*/, {} /*scattered*/,
@@ -450,8 +458,26 @@ struct SgLoadTileOpPattern
       vnniAttr = rewriter.getI32IntegerAttr(axis);
     }

-    // TODO: add transpose info
     mlir::DenseI64ArrayAttr transposeAttr;
+    auto srcOrder = tileTy.getOrder();
+    if (srcOrder.asArrayRef() == mlir::ArrayRef({1, 0})) {
+      // Nothing to do
+    } else if (srcOrder.asArrayRef() == mlir::ArrayRef({0, 1})) {
+      auto elemWidth = elemTy.getIntOrFloatBitWidth();
+      if (elemWidth == 32) {
+        transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0});
+      } else if (elemWidth == 16 && vnniAttr && vnniAttr.getInt() == 0) {
+        transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0});
+        transposeBitWidthAttr = rewriter.getI32IntegerAttr(32);
+        vnniAttr = nullptr;
+      } else {
+        return ((mlir::PatternRewriter &)rewriter)
+            .notifyMatchFailure(op, "Unsupported element type for transpose");
+      }
+    } else {
+      return ((mlir::PatternRewriter &)rewriter)
+          .notifyMatchFailure(op, "Unsupported order");
+    }

     rewriter.setInsertionPoint(op);
     llvm::SmallVector<::mlir::Value> xegpuOps;
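
The new branch matches what the added test expects for f16: an order = [0, 1] tile that would otherwise be loaded with vnni_axis = 0 is loaded with transpose = [1, 0] plus transpose_bit_width = 32 instead, and the vnni attribute is dropped. Adapted from the test's CHECK line (%b_desc is an illustrative name):

    %v = xegpu.load_nd %b_desc {mode = vc, transpose = [1, 0], transpose_bit_width = 32,
        l1_hint = cached, l2_hint = cached, l3_hint = cached}
        : !xegpu.tensor_desc<32x16xf16> -> vector<8x32x2xf16>

A 32-bit element type takes the plain transpose = [1, 0] path; any other element width, or f16 with a different vnni axis, is rejected with a match failure.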
@@ -461,10 +487,17 @@ struct SgLoadTileOpPattern
       auto shape = tdescTy.getShape().vec();
       auto array_length = tdescTy.getArrayLength();

+      if (transposeAttr)
+        std::swap(shape[0], shape[1]);
+
       if (vnniAttr) {
         auto axis = vnniAttr.getInt();
         shape[axis] /= factor;
         shape.push_back(factor);
+      } else if (transposeBitWidthAttr) {
+        auto axis = 0;
+        shape[axis] /= factor;
+        shape.push_back(factor);
       }

       if (array_length != 1)
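
Worked through for the transposed f16 load in the new test: the descriptor shape {32, 16} is swapped to {16, 32} because transposeAttr is set, and the transpose_bit_width path then packs axis 0 by factor, here 32 / 16 = 2 (assuming factor is the same packing factor computed earlier in this pattern), giving {8, 32, 2}, i.e. vector<8x32x2xf16>. A standalone C++ sketch of that arithmetic:

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<int64_t> shape = {32, 16}; // tensor_desc<32x16xf16>
      int64_t factor = 32 / 16;              // 32-bit lanes / f16 element width
      std::swap(shape[0], shape[1]);         // transposeAttr set   -> {16, 32}
      shape[0] /= factor;                    // transpose_bit_width -> {8, 32}
      shape.push_back(factor);               //                     -> {8, 32, 2}
      assert((shape == std::vector<int64_t>{8, 32, 2}));
      return 0;
    }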

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 8 additions & 6 deletions
@@ -483,12 +483,14 @@ struct InitTileOpPattern : public XeTileConversion<xetile::InitTileOp> {

       int factor = 32 / elementSize;
       vnni = false;
-      innerBlocks = mlir::DenseI64ArrayAttr::get(
-          getContext(),
-          getInnerBlockSizes<Load>(
-              op.getOperation(), mlir::FloatType::getF32(getContext()),
-              tileTy.getShape()[0], (tileTy.getShape()[1]) * factor,
-              this->uArchInterface, vnni, transpose));
+      llvm::SmallVector<int64_t, 2> innerBlock = getInnerBlockSizes<Load>(
+          op.getOperation(), mlir::FloatType::getF32(getContext()),
+          tileTy.getShape()[1], (tileTy.getShape()[0]) / factor,
+          this->uArchInterface, vnni, transpose);
+      std::swap(innerBlock[0], innerBlock[1]);
+      innerBlock[0] *= factor;
+      innerBlocks = mlir::DenseI64ArrayAttr::get(getContext(), innerBlock);
+
     } else if (transpose && elementSize < 32) {
       return rewriter.notifyMatchFailure(op, "Invalid transpose.");
     } else {
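
Working backward from the updated blocking.mlir CHECK line below (inner_blocks = [4, 30] for the order = [0, 1] tile<76x90xf16>): with f16, factor = 32 / 16 = 2, the helper is now called with the swapped and scaled dims (90, 76 / 2 = 38) and must return {30, 2} for the numbers to come out; the swap and the factor scaling then yield [4, 30]. A standalone C++ sketch of the post-processing (the {30, 2} helper result is inferred from the test, not taken from the source):

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    int main() {
      int64_t factor = 32 / 16;                  // f16: two elements per 32 bits
      std::vector<int64_t> innerBlock = {30, 2}; // inferred getInnerBlockSizes result
      std::swap(innerBlock[0], innerBlock[1]);   // -> {2, 30}
      innerBlock[0] *= factor;                   // -> {4, 30}
      assert((innerBlock == std::vector<int64_t>{4, 30}));
      return 0;
    }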

lib/Utils/XeArch.cpp

Lines changed: 7 additions & 5 deletions
@@ -95,12 +95,14 @@ XePVCuArch::get2DLoadConfig(mlir::Operation *op, int element_data_size,
            << "transpose and transform are not supported together";
   }

+  // FIXME: We do support transpose on f16 with transpose_bit_width == 32;
+  // disable the check for now.
   // only d32 and d64 are supported for transpose operations
-  if ((transpose) && (element_data_size != 32 && element_data_size != 64)) {
-    return op->emitOpError()
-           << "transposed load only supports d32 and d64 data sizes. "
-           << "Given element data size: d" << element_data_size;
-  }
+  // if ((transpose) && (element_data_size != 32 && element_data_size != 64)) {
+  //   return op->emitOpError()
+  //          << "transposed load only supports d32 and d64 data sizes. "
+  //          << "Given element data size: d" << element_data_size;
+  // }

   // only d8 and d16 are supported for VNNI transform operations
   if ((vnni) && (element_data_size != 8 && element_data_size != 16)) {
(new test file)

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+// RUN: imex-opt --split-input-file --xetile-blocking --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK-LABEL: @test_func
+// CHECK-SAME: (%[[A:.*]]: memref<128x64xf16>, %[[B:.*]]: memref<64x128xf16, strided<[1, 64]>>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[D1:.*]] = xegpu.create_nd_tdesc %[[A]][%[[C0]], %[[C0]]] {mode = vc} : memref<128x64xf16> -> !xegpu.tensor_desc<32x16xf16>
+// CHECK: %[[D2:.*]] = xegpu.create_nd_tdesc %[[B]][%[[C0]], %[[C0]]] {mode = vc} : memref<64x128xf16, strided<[1, 64]>> -> !xegpu.tensor_desc<32x16xf16>
+// CHECK: %{{.*}} = xegpu.load_nd %[[D1]] {mode = vc, vnni_axis = 1, l1_hint = cached, l2_hint = cached, l3_hint = cached} : !xegpu.tensor_desc<32x16xf16> -> vector<32x8x2xf16>
+// CHECK: %{{.*}} = xegpu.load_nd %[[D2]] {mode = vc, transpose = [1, 0], transpose_bit_width = 32, l1_hint = cached, l2_hint = cached, l3_hint = cached} : !xegpu.tensor_desc<32x16xf16> -> vector<8x32x2xf16>
+// CHECK: %[[D3:.*]] = xegpu.update_nd_offset %[[D1]], [%[[C0]], %[[C16]]] {mode = vc} : !xegpu.tensor_desc<32x16xf16> -> !xegpu.tensor_desc<32x16xf16>
+// CHECK: %[[D4:.*]] = xegpu.update_nd_offset %[[D2]], [%[[C16]], %[[C0]]] {mode = vc} : !xegpu.tensor_desc<32x16xf16> -> !xegpu.tensor_desc<32x16xf16>
+// CHECK: %{{.*}} = xegpu.load_nd %[[D3]] {mode = vc, vnni_axis = 1, l1_hint = cached, l2_hint = cached, l3_hint = cached} : !xegpu.tensor_desc<32x16xf16> -> vector<32x8x2xf16>
+// CHECK: %{{.*}} = xegpu.load_nd %[[D4]] {mode = vc, transpose = [1, 0], transpose_bit_width = 32, l1_hint = cached, l2_hint = cached, l3_hint = cached} : !xegpu.tensor_desc<32x16xf16> -> vector<8x32x2xf16>
+gpu.module @test_kernel {
+  func.func @test_func(%A : memref<128x64xf16>, %B : memref<64x128xf16, strided<[1, 64], offset: 0>>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c16 = arith.constant 16 : index
+    %A_block_iter0 = xetile.init_tile %A[%c0, %c0] : memref<128x64xf16> -> !xetile.tile<32x16xf16>
+    %B_block_iter0 = xetile.init_tile %B[%c0, %c0] : memref<64x128xf16, strided<[1, 64], offset: 0>> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0, 1]>>
+
+    %A_block_value0 = xetile.load_tile %A_block_iter0 : !xetile.tile<32x16xf16> -> vector<32x16xf16>
+    %B_block_value0 = xetile.load_tile %B_block_iter0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0,1]>> -> vector<16x32xf16>
+
+    %mma_out0 = xetile.tile_mma %A_block_value0, %B_block_value0 : vector<32x16xf16>, vector<16x32xf16> -> vector<32x32xf32>
+
+    %A_block_iter1 = xetile.update_tile_offset %A_block_iter0, [%c0, %c16] : !xetile.tile<32x16xf16>, index, index -> !xetile.tile<32x16xf16>
+    %B_block_iter1 = xetile.update_tile_offset %B_block_iter0, [%c16, %c0] : !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0,1]>>, index, index -> !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0,1]>>
+
+    %A_block_value1 = xetile.load_tile %A_block_iter1 : !xetile.tile<32x16xf16> -> vector<32x16xf16>
+    %B_block_value1 = xetile.load_tile %B_block_iter1 : !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0,1]>> -> vector<16x32xf16>
+
+    %mma_out1 = xetile.tile_mma %A_block_value1, %B_block_value1, %mma_out0 : vector<32x16xf16>, vector<16x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
+
+    return
+  }
+}

test/Dialect/XeTile/Transforms/blocking.mlir

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ gpu.module @test_kernel {
 //CHECK: gpu.func @tile_mma(%[[arg0:.*]]: memref<128x128xf16>, %[[arg1:.*]]: memref<128x128xf16>)
 //CHECK: %[[c0:.*]] = arith.constant 0 : index
 //CHECK: %[[R0:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<128x128xf16> -> !xetile.tile<90x76xf16, #xetile.tile_attr<inner_blocks = [30, 19]>>
-//CHECK: %[[R1:.*]] = xetile.init_tile %[[arg1]][%[[c0]], %[[c0]]] : memref<128x128xf16> -> !xetile.tile<76x90xf16, #xetile.tile_attr<order = [0, 1], inner_blocks = [19, 6]>>
+//CHECK: %[[R1:.*]] = xetile.init_tile %[[arg1]][%[[c0]], %[[c0]]] : memref<128x128xf16> -> !xetile.tile<76x90xf16, #xetile.tile_attr<order = [0, 1], inner_blocks = [4, 30]>>
 gpu.func @tile_mma(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   %c0 = arith.constant 0 : index
   %1 = xetile.init_tile %a[%c0, %c0] : memref<128x128xf16> -> !xetile.tile<90x76xf16>
