Skip to content

Commit 6d573d1

Browse files
authored
Add reduction_size option for xetile.reduction op. (#1058)
Add canonicalization pattern for new variant. Add test case for canonicalization.
1 parent aff7024 commit 6d573d1

File tree

5 files changed

+118
-4
lines changed

5 files changed

+118
-4
lines changed

include/imex/Dialect/XeTile/IR/XeTileOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,8 @@ def XeTile_ReductionOp: XeTile_Op<"reduction", []> {
540540

541541
let arguments = (ins Vector_CombiningKindAttr: $kind,
542542
XeTile_2DVector: $source,
543-
DenseI64ArrayAttr: $reduction_dims);
543+
DenseI64ArrayAttr: $reduction_dims,
544+
DefaultValuedAttr<I64Attr, "0">: $reduction_size);
544545
let results = (outs XeTile_2DVector: $result);
545546
let assemblyFormat = [{
546547
$kind `,` $source $reduction_dims attr-dict `:` type($source) `->` type($result)

lib/Dialect/XeTile/IR/XeTileOps.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,28 @@ mlir::LogicalResult TransposeOp::verify() {
322322
mlir::LogicalResult ReductionOp::verify() {
323323
auto dims = getReductionDims();
324324
auto resShape = getResult().getType().getShape();
325-
for (auto i : dims)
326-
if (resShape[i] != 1)
327-
return emitOpError("reduction dimension of result must have size 1");
325+
if (getReductionSize() > 0) {
326+
if (dims.size() > 1)
327+
// When reduction size is specified,
328+
// only a single dimension can be reduced.
329+
return emitOpError(
330+
"when reduction size is specified, only a single reduction "
331+
"dimension is allowed.");
332+
auto srcTy = getSource().getType();
333+
if (srcTy.getRank() != 2)
334+
return emitOpError(
335+
"when reduction size is specified, source must be a 2D vector.");
336+
auto redDim = dims.front();
337+
if (resShape[redDim] !=
338+
srcTy.getShape()[redDim] / static_cast<int64_t>(getReductionSize()))
339+
return emitOpError(
340+
"reduction size does not match the expected size of the result "
341+
"dimension after reduction.");
342+
} else {
343+
for (auto i : dims)
344+
if (resShape[i] != 1)
345+
return emitOpError("reduction dimension of result must have size 1");
346+
}
328347
return mlir::success();
329348
}
330349

lib/Dialect/XeTile/Transforms/Canonicalization.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,57 @@ struct RemoveRedundantTransposeOpPattern
365365
}
366366
};
367367

368+
// Remove XeTile reduction's reduction_size attribute and replace it with shape
369+
// casts around it.
370+
struct RemoveReductionSizePattern
371+
: public mlir::OpRewritePattern<imex::xetile::ReductionOp> {
372+
using mlir::OpRewritePattern<imex::xetile::ReductionOp>::OpRewritePattern;
373+
mlir::LogicalResult
374+
matchAndRewrite(imex::xetile::ReductionOp op,
375+
mlir::PatternRewriter &rewriter) const override {
376+
if (op.getReductionSize() == 0)
377+
return mlir::failure();
378+
379+
auto loc = op.getLoc();
380+
auto srcTy = op.getSource().getType();
381+
auto elemTy = srcTy.getElementType();
382+
auto reductionSize = static_cast<int64_t>(op.getReductionSize());
383+
// Get reduction dimension "i"
384+
// Op validation ensures that only a single reduction dimension is
385+
// present if reduction size is set to non zero.
386+
// Also source rank is restricted to 2D.
387+
auto redDim = op.getReductionDims().front();
388+
auto resultTy = op.getType();
389+
auto numRes = resultTy.getNumElements();
390+
llvm::SmallVector<int64_t> newShape;
391+
if (redDim == 0) {
392+
newShape.push_back(reductionSize);
393+
newShape.push_back(numRes);
394+
} else {
395+
newShape.push_back(numRes);
396+
newShape.push_back(reductionSize);
397+
}
398+
llvm::SmallVector<int64_t> newRedShape;
399+
if (redDim == 0) {
400+
newRedShape.push_back(1);
401+
newRedShape.push_back(numRes);
402+
} else {
403+
newRedShape.push_back(numRes);
404+
newRedShape.push_back(1);
405+
}
406+
mlir::Value newCast = rewriter.create<mlir::vector::ShapeCastOp>(
407+
loc, mlir::VectorType::get(newShape, elemTy), op.getSource());
408+
mlir::Value newReductionOp = rewriter.create<imex::xetile::ReductionOp>(
409+
loc, mlir::VectorType::get(newRedShape, elemTy), op.getKind(), newCast,
410+
op.getReductionDims());
411+
412+
auto shapeCastOp = rewriter.create<mlir::vector::ShapeCastOp>(
413+
op.getLoc(), op.getType(), newReductionOp);
414+
rewriter.replaceOp(op, shapeCastOp);
415+
return mlir::success();
416+
}
417+
};
418+
368419
struct XeTileCanonicalizationPass final
369420
: public imex::impl::XeTileCanonicalizationBase<
370421
XeTileCanonicalizationPass> {
@@ -430,6 +481,7 @@ struct XeTileCanonicalizationPass final
430481

431482
target.addLegalOp<mlir::memref::ReinterpretCastOp>();
432483
target.addLegalOp<imex::xetile::TransposeOp>();
484+
target.addLegalOp<mlir::vector::ShapeCastOp>();
433485
// Col-major tile creattion is not allowed.
434486
target.addDynamicallyLegalOp<imex::xetile::InitTileOp>(
435487
[&](imex::xetile::InitTileOp op) {
@@ -450,6 +502,11 @@ struct XeTileCanonicalizationPass final
450502
[&](imex::xetile::LoadTileOp op) {
451503
return isValidTile(op.getTileType());
452504
});
505+
// ReductionOp is legal if reduction size is 0.
506+
target.addDynamicallyLegalOp<imex::xetile::ReductionOp>(
507+
[&](imex::xetile::ReductionOp op) {
508+
return op.getReductionSize() == 0;
509+
});
453510
// If any iterArg of the forOp is a col-major tile, it is illegal.
454511
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp op) {
455512
for (auto arg : op.getRegionIterArgs()) {
@@ -475,6 +532,7 @@ struct XeTileCanonicalizationPass final
475532
.add<InitTileOpPattern, LoadTileOpPattern, UpdateTileOffsetOpPattern,
476533
PrefetchTilePattern, ScfForOpPattern, ScfYieldOpPattern>(
477534
typeConverter, context);
535+
patterns.add<RemoveReductionSizePattern>(context);
478536

479537
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
480538
std::move(patterns))))

test/Dialect/XeTile/IR/ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ func.func @test_reduce(%source: vector<8x16xf16>) {
330330
return
331331
}
332332

333+
func.func @test_reduce_reduction_size(%source: vector<8x16xf16>) {
334+
// CHECK: xetile.reduction {{.*}} [0] {reduction_size = 8 : i64} : vector<8x16xf16> -> vector<1x16xf16>
335+
%1 = xetile.reduction <add>, %source [0] { reduction_size = 8 : i64 } : vector<8x16xf16> -> vector<1x16xf16>
336+
return
337+
}
338+
333339
func.func @test_reduce_map(%source: vector<256x128xf16>) {
334340
// CHECK: xetile.reduction {{.*}} [1] {map1 = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 128]>, map2 = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 1]>} : vector<256x128xf16> -> vector<256x1xf16>
335341
%1 = xetile.reduction <add>, %source [1] {map1 = #wg_map_a, map2 = #wg_map_a2} : vector<256x128xf16> -> vector<256x1xf16>

test/Dialect/XeTile/Transforms/canonicalization.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,33 @@ gpu.module @test_module {
344344
//CHECK: %[[transpose:.*]] = memref.transpose %[[arg0]] (d0, d1) -> (d1, d0) : memref<512x128xf16, 3> to memref<128x512xf16, strided<[1, 128]>, 3>
345345
//CHECK: %[[r2:.*]] = xetile.init_tile %[[transpose]][16, 32] : memref<128x512xf16, strided<[1, 128]>, 3> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0, 1], memory_space = 3 : i64>>
346346
//CHECK: xetile.store_tile %[[r1]], %[[r2]] : vector<16x32xf16>, !xetile.tile<16x32xf16, #xetile.tile_attr<order = [0, 1], memory_space = 3 : i64>>
347+
348+
// -----
349+
gpu.module @test_module {
350+
// CHECK-LABEL: gpu.func @test_reduction_size
351+
// CHECK-SAME: (%[[ARG0:.*]]: vector<64x256xf32>
352+
gpu.func @test_reduction_size(%arg0: vector<64x256xf32>, %arg1: memref<64x8xf32>) {
353+
%cst0 = arith.constant 0 : index
354+
// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<64x256xf32> to vector<512x32xf32>
355+
// CHECK: %[[V1:.*]] = xetile.reduction <add>, %[[V0]] [1] : vector<512x32xf32> -> vector<512x1xf32>
356+
// CHECK: %[[V2:.*]] = vector.shape_cast %[[V1]] : vector<512x1xf32> to vector<64x8xf32>
357+
%reduced = xetile.reduction <add>, %arg0 [1] { reduction_size = 32 : i64 } : vector<64x256xf32> -> vector<64x8xf32>
358+
vector.store %reduced, %arg1[%cst0, %cst0] : memref<64x8xf32>, vector<64x8xf32>
359+
gpu.return
360+
}
361+
}
362+
363+
// -----
364+
gpu.module @test_module {
365+
// CHECK-LABEL: gpu.func @test_reduction_size
366+
// CHECK-SAME: (%[[ARG0:.*]]: vector<64x256xf32>
367+
gpu.func @test_reduction_size(%arg0: vector<64x256xf32>, %arg1: memref<2x256xf32>) {
368+
%cst0 = arith.constant 0 : index
369+
// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<64x256xf32> to vector<32x512xf32>
370+
// CHECK: %[[V1:.*]] = xetile.reduction <add>, %[[V0]] [0] : vector<32x512xf32> -> vector<1x512xf32>
371+
// CHECK: %[[V2:.*]] = vector.shape_cast %[[V1]] : vector<1x512xf32> to vector<2x256xf32>
372+
%reduced = xetile.reduction <add>, %arg0 [0] { reduction_size = 32 : i64 } : vector<64x256xf32> -> vector<2x256xf32>
373+
vector.store %reduced, %arg1[%cst0, %cst0] : memref<2x256xf32>, vector<2x256xf32>
374+
gpu.return
375+
}
376+
}

0 commit comments

Comments
 (0)