Skip to content

Commit 25d82ea

Browse files
authored
[aievec] Updates on vector.reduction support (#2618)
1 parent d9695bd commit 25d82ea

File tree

2 files changed

+158
-18
lines changed

2 files changed

+158
-18
lines changed

lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,10 @@ static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
419419
}
420420

421421
template <typename DstOpTy>
422-
static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
423-
vector::ReductionOp srcOp,
424-
int shiftIndex, Value curValue) {
422+
static aievec::ExtElemOp
423+
generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
424+
vector::ReductionOp srcOp, int shiftIndex,
425+
Value curValue) {
425426
assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
426427
"shiftIndex must be power of 2");
427428

@@ -447,8 +448,8 @@ static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
447448

448449
auto zeroConstOp =
449450
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
450-
rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
451-
zeroConstOp.getResult());
451+
return rewriter.create<aievec::ExtElemOp>(loc, scalarType, curOp,
452+
zeroConstOp.getResult());
452453
}
453454

454455
static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
@@ -1476,6 +1477,8 @@ using LowerVectorMinimumFOpToAIEVecMinOp =
14761477
LowerVectorMinMaxOpToAIEVecMinMaxOp<arith::MinimumFOp, aievec::MinOp>;
14771478
using LowerVectorMaximumFOpToAIEVecMaxOp =
14781479
LowerVectorMinMaxOpToAIEVecMinMaxOp<arith::MaximumFOp, aievec::MaxOp>;
1480+
using LowerVectorMaxNumFFOpToAIEVecMaxOp =
1481+
LowerVectorMinMaxOpToAIEVecMinMaxOp<arith::MaxNumFOp, aievec::MaxOp>;
14791482

14801483
template <typename SrcOpTy, typename CmpTy>
14811484
struct LowerVectorCmpOpToAIEVecCmpOp : OpConversionPattern<SrcOpTy> {
@@ -1591,8 +1594,14 @@ struct LowerVectorReductionMinOp : OpConversionPattern<vector::ReductionOp> {
15911594
return failure();
15921595

15931596
int shiftIndex = laneSize / 2;
1594-
generateAIEVecOpsForReductionOp<aievec::MinOp>(rewriter, srcOp, shiftIndex,
1595-
srcOp.getVector());
1597+
auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MinOp>(
1598+
rewriter, srcOp, shiftIndex, srcOp.getVector());
1599+
1600+
if (srcOp.getAcc())
1601+
rewriter.replaceOpWithNewOp<arith::MinimumFOp>(
1602+
srcOp, reduceResultOp.getResult(), srcOp.getAcc());
1603+
else
1604+
rewriter.replaceOp(srcOp, reduceResultOp);
15961605
return success();
15971606
}
15981607
};
@@ -1605,7 +1614,8 @@ struct LowerVectorReductionMaxOp : OpConversionPattern<vector::ReductionOp> {
16051614
ConversionPatternRewriter &rewriter) const override {
16061615
if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI &&
16071616
kind != vector::CombiningKind::MAXUI &&
1608-
kind != vector::CombiningKind::MAXIMUMF)
1617+
kind != vector::CombiningKind::MAXIMUMF &&
1618+
kind != vector::CombiningKind::MAXNUMF)
16091619
return failure();
16101620

16111621
auto vType = cast<VectorType>(srcOp.getVector().getType());
@@ -1617,8 +1627,14 @@ struct LowerVectorReductionMaxOp : OpConversionPattern<vector::ReductionOp> {
16171627
return failure();
16181628

16191629
int shiftIndex = laneSize / 2;
1620-
generateAIEVecOpsForReductionOp<aievec::MaxOp>(rewriter, srcOp, shiftIndex,
1621-
srcOp.getVector());
1630+
auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MaxOp>(
1631+
rewriter, srcOp, shiftIndex, srcOp.getVector());
1632+
1633+
if (srcOp.getAcc())
1634+
rewriter.replaceOpWithNewOp<arith::MaximumFOp>(
1635+
srcOp, reduceResultOp.getResult(), srcOp.getAcc());
1636+
else
1637+
rewriter.replaceOp(srcOp, reduceResultOp);
16221638
return success();
16231639
}
16241640
};
@@ -1659,11 +1675,22 @@ struct LowerVectorReductionAddIntOp : OpConversionPattern<vector::ReductionOp> {
16591675
loc, lExtOp.getResult().getType(), lExtOp.getResult(),
16601676
rExtOp.getResult());
16611677
shiftIndex /= 2;
1662-
generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1678+
auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
16631679
rewriter, srcOp, shiftIndex, addElemOp.getResult());
1664-
} else
1665-
generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1680+
if (srcOp.getAcc())
1681+
rewriter.replaceOpWithNewOp<arith::AddIOp>(
1682+
srcOp, reduceResultOp.getResult(), srcOp.getAcc());
1683+
else
1684+
rewriter.replaceOp(srcOp, reduceResultOp);
1685+
} else {
1686+
auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
16661687
rewriter, srcOp, shiftIndex, srcOp.getVector());
1688+
if (srcOp.getAcc())
1689+
rewriter.replaceOpWithNewOp<arith::AddIOp>(
1690+
srcOp, reduceResultOp.getResult(), srcOp.getAcc());
1691+
else
1692+
rewriter.replaceOp(srcOp, reduceResultOp);
1693+
}
16671694

16681695
return success();
16691696
}
@@ -1717,8 +1744,14 @@ struct LowerVectorReductionAddFloatOp
17171744

17181745
auto zeroConstOp =
17191746
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1720-
rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
1721-
zeroConstOp.getResult());
1747+
auto reduceResultOp = rewriter.create<aievec::ExtElemOp>(
1748+
srcOp.getLoc(), scalarType, curOp, zeroConstOp.getResult());
1749+
1750+
if (srcOp.getAcc())
1751+
rewriter.replaceOpWithNewOp<arith::AddFOp>(
1752+
srcOp, reduceResultOp.getResult(), srcOp.getAcc());
1753+
else
1754+
rewriter.replaceOp(srcOp, reduceResultOp);
17221755
return success();
17231756
}
17241757
};
@@ -1779,8 +1812,14 @@ struct LowerVectorReductionAddBfloat16Op
17791812

17801813
auto zeroConstOp =
17811814
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1782-
rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, concatOp,
1783-
zeroConstOp.getResult());
1815+
auto reduceResultOp = rewriter.create<aievec::ExtElemOp>(
1816+
srcOp.getLoc(), scalarType, concatOp, zeroConstOp.getResult());
1817+
1818+
if (srcOp.getAcc())
1819+
rewriter.replaceOpWithNewOp<arith::AddFOp>(
1820+
srcOp, reduceResultOp.getResult(), srcOp.getAcc());
1821+
else
1822+
rewriter.replaceOp(srcOp, reduceResultOp);
17841823
return success();
17851824
}
17861825
};
@@ -3125,6 +3164,7 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
31253164
LowerVectorMinimumFOpToAIEVecMinOp,
31263165
LowerVectorMaxSIOpToAIEVecMaxOp,
31273166
LowerVectorMaximumFOpToAIEVecMaxOp,
3167+
LowerVectorMaxNumFFOpToAIEVecMaxOp,
31283168
LowerVectorCmpIOpToAIEVecCmpOp,
31293169
LowerVectorCmpFOpToAIEVecCmpOp,
31303170
LowerVectorSelectOpToAIEVecSelOp,
@@ -3705,6 +3745,17 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target,
37053745
return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
37063746
});
37073747

3748+
target.addDynamicallyLegalOp<arith::MaxNumFOp>([=](arith::MaxNumFOp op) {
3749+
auto resultType = dyn_cast<VectorType>(op.getType());
3750+
if (!resultType)
3751+
return true;
3752+
3753+
auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3754+
unsigned laneSize = getVectorLaneSize(resultType);
3755+
3756+
return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3757+
});
3758+
37083759
target.addDynamicallyLegalOp<arith::CmpIOp>([=](arith::CmpIOp op) {
37093760
auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
37103761
if (!lhsType)
@@ -3746,7 +3797,8 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target,
37463797
kind != vector::CombiningKind::MINIMUMF &&
37473798
kind != vector::CombiningKind::MAXSI &&
37483799
kind != vector::CombiningKind::MAXUI &&
3749-
kind != vector::CombiningKind::MAXIMUMF)
3800+
kind != vector::CombiningKind::MAXIMUMF &&
3801+
kind != vector::CombiningKind::MAXNUMF)
37503802
return true;
37513803

37523804
auto vType = dyn_cast<VectorType>(op.getVector().getType());

test/Conversion/VectorToAIEVec/test-reduce.mlir

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,91 @@ func.func @reduce_max_bf16(%arg0: vector<32xbf16>) -> bf16 {
347347
// CHECK: return %[[EXTELEM]] : bf16
348348
return %0 : bf16
349349
}
350+
351+
// Test: vector.reduction <add> over vector<16xf32> with a scalar accumulator
// operand (%cst). The expected output is four shift + add_elem steps using
// the shift constants 32, 16, 8 and 4, an ext_elem at index constant 0, and a
// final scalar arith.addf that folds the accumulator into the reduced value.
// CHECK-LABEL:func @reduce_add_f32_w_acc
352+
// CHECK-SAME: %[[SRC:.*]]: vector<16xf32>
353+
func.func @reduce_add_f32_w_acc(%arg0: vector<16xf32>) -> f32 {
354+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
355+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
356+
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
357+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
358+
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
359+
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
360+
// CHECK: %[[SHIFT32:.*]] = aievec.shift %[[SRC]], %[[SRC]], %[[C32]] {isAcc = false} : vector<16xf32>, vector<16xf32>, i32, vector<16xf32>
361+
// CHECK: %[[CASTL1:.*]] = aievec.cast %[[SRC]] {isResAcc = true} : vector<16xf32>, vector<16xf32>
362+
// CHECK: %[[CASTR1:.*]] = aievec.cast %[[SHIFT32]] {isResAcc = true} : vector<16xf32>, vector<16xf32>
363+
// CHECK: %[[ADD1:.*]] = aievec.add_elem %[[CASTL1]], %[[CASTR1]] : vector<16xf32>
364+
// CHECK: %[[CAST1:.*]] = aievec.cast %[[ADD1]] {isResAcc = false} : vector<16xf32>, vector<16xf32>
365+
// CHECK: %[[SHIFT16:.*]] = aievec.shift %[[CAST1]], %[[CAST1]], %[[C16]] {isAcc = false} : vector<16xf32>, vector<16xf32>, i32, vector<16xf32>
366+
// CHECK: %[[CASTR2:.*]] = aievec.cast %[[SHIFT16]] {isResAcc = true} : vector<16xf32>, vector<16xf32>
367+
// CHECK: %[[ADD2:.*]] = aievec.add_elem %[[ADD1]], %[[CASTR2]] : vector<16xf32>
368+
// CHECK: %[[CAST2:.*]] = aievec.cast %[[ADD2]] {isResAcc = false} : vector<16xf32>, vector<16xf32>
369+
// CHECK: %[[SHIFT8:.*]] = aievec.shift %[[CAST2]], %[[CAST2]], %[[C8]] {isAcc = false} : vector<16xf32>, vector<16xf32>, i32, vector<16xf32>
370+
// CHECK: %[[CASTR3:.*]] = aievec.cast %[[SHIFT8]] {isResAcc = true} : vector<16xf32>, vector<16xf32>
371+
// CHECK: %[[ADD3:.*]] = aievec.add_elem %[[ADD2]], %[[CASTR3]] : vector<16xf32>
372+
// CHECK: %[[CAST3:.*]] = aievec.cast %[[ADD3]] {isResAcc = false} : vector<16xf32>, vector<16xf32>
373+
// CHECK: %[[SHIFT4:.*]] = aievec.shift %[[CAST3]], %[[CAST3]], %[[C4]] {isAcc = false} : vector<16xf32>, vector<16xf32>, i32, vector<16xf32>
374+
// CHECK: %[[CASTR4:.*]] = aievec.cast %[[SHIFT4]] {isResAcc = true} : vector<16xf32>, vector<16xf32>
375+
// CHECK: %[[ADD4:.*]] = aievec.add_elem %[[ADD3]], %[[CASTR4]] : vector<16xf32>
376+
// CHECK: %[[CAST4:.*]] = aievec.cast %[[ADD4]] {isResAcc = false} : vector<16xf32>, vector<16xf32>
377+
// CHECK: %[[EXTELEM:.*]] = aievec.ext_elem %[[CAST4]], %[[C0]] : vector<16xf32>, i32, f32
378+
// The accumulator is applied after the vector reduction, as a scalar add on
// the extracted element.
// CHECK: %[[ADD5:.*]] = arith.addf %[[EXTELEM]], %[[CST]] : f32
379+
%cst = arith.constant 0.000000e+00 : f32
380+
%0 = vector.reduction <add>, %arg0, %cst : vector<16xf32> into f32
381+
// CHECK: return %[[ADD5]] : f32
382+
return %0 : f32
383+
}
384+
385+
// Test: vector.reduction <minimumf> over vector<32xbf16> with a scalar
// accumulator operand (%cst). The expected output is five shift + min steps
// using the shift constants 32, 16, 8, 4 and 2, an ext_elem at index
// constant 0, and a final scalar arith.minimumf against the accumulator.
// Every FileCheck variable used below is captured inside this function, so
// the test does not depend on captures made by other tests in the file.
// CHECK-LABEL:func @reduce_min_bf16_w_acc
386+
// CHECK-SAME: %[[SRC:.*]]: vector<32xbf16>
387+
func.func @reduce_min_bf16_w_acc(%arg0: vector<32xbf16>) -> bf16 {
388+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
389+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
390+
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
391+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
392+
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
393+
// CHECK-DAG: %[[CST:.*]] = arith.constant 9.982440e+08 : bf16
394+
// CHECK: %[[SHIFT32:.*]] = aievec.shift %[[SRC]], %[[SRC]], %[[C32]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
395+
// CHECK: %[[MIN1:.*]] = aievec.min %[[SRC]], %[[SHIFT32]] : vector<32xbf16>
396+
// CHECK: %[[SHIFT16:.*]] = aievec.shift %[[MIN1]], %[[MIN1]], %[[C16]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
397+
// CHECK: %[[MIN2:.*]] = aievec.min %[[MIN1]], %[[SHIFT16]] : vector<32xbf16>
398+
// CHECK: %[[SHIFT8:.*]] = aievec.shift %[[MIN2]], %[[MIN2]], %[[C8]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
399+
// CHECK: %[[MIN3:.*]] = aievec.min %[[MIN2]], %[[SHIFT8]] : vector<32xbf16>
400+
// CHECK: %[[SHIFT4:.*]] = aievec.shift %[[MIN3]], %[[MIN3]], %[[C4]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
401+
// CHECK: %[[MIN4:.*]] = aievec.min %[[MIN3]], %[[SHIFT4]] : vector<32xbf16>
402+
// CHECK: %[[SHIFT5:.*]] = aievec.shift %[[MIN4]], %[[MIN4]], %[[C2]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
403+
// CHECK: %[[MIN5:.*]] = aievec.min %[[MIN4]], %[[SHIFT5]] : vector<32xbf16>
404+
// The scalar is extracted from the last min result (MIN5).
// CHECK: %[[EXTELEM:.*]] = aievec.ext_elem %[[MIN5]], %[[C0]] : vector<32xbf16>, i32, bf16
405+
// CHECK: %[[MIN6:.*]] = arith.minimumf %[[EXTELEM]], %[[CST]] : bf16
406+
%cst = arith.constant 1.0e+09 : bf16
407+
%0 = vector.reduction <minimumf>, %arg0, %cst : vector<32xbf16> into bf16
408+
// CHECK: return %[[MIN6]] : bf16
409+
return %0 : bf16
410+
}
411+
412+
// Test: vector.reduction <maximumf> over vector<32xbf16> with a scalar
// accumulator operand (%cst). The expected output is five shift + max steps
// using the shift constants 32, 16, 8, 4 and 2, an ext_elem at index
// constant 0, and a final scalar arith.maximumf against the accumulator.
// Every FileCheck variable used below is captured inside this function, so
// the test does not depend on captures made by other tests in the file.
// CHECK-LABEL:func @reduce_max_bf16_w_acc
413+
// CHECK-SAME: %[[SRC:.*]]: vector<32xbf16>
414+
func.func @reduce_max_bf16_w_acc(%arg0: vector<32xbf16>) -> bf16 {
415+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
416+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
417+
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
418+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
419+
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
420+
// CHECK-DAG: %[[CST:.*]] = arith.constant -9.982440e+08 : bf16
421+
// CHECK: %[[SHIFT32:.*]] = aievec.shift %[[SRC]], %[[SRC]], %[[C32]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
422+
// CHECK: %[[MAX1:.*]] = aievec.max %[[SRC]], %[[SHIFT32]] : vector<32xbf16>
423+
// CHECK: %[[SHIFT16:.*]] = aievec.shift %[[MAX1]], %[[MAX1]], %[[C16]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
424+
// CHECK: %[[MAX2:.*]] = aievec.max %[[MAX1]], %[[SHIFT16]] : vector<32xbf16>
425+
// CHECK: %[[SHIFT8:.*]] = aievec.shift %[[MAX2]], %[[MAX2]], %[[C8]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
426+
// CHECK: %[[MAX3:.*]] = aievec.max %[[MAX2]], %[[SHIFT8]] : vector<32xbf16>
427+
// CHECK: %[[SHIFT4:.*]] = aievec.shift %[[MAX3]], %[[MAX3]], %[[C4]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
428+
// CHECK: %[[MAX4:.*]] = aievec.max %[[MAX3]], %[[SHIFT4]] : vector<32xbf16>
429+
// CHECK: %[[SHIFT5:.*]] = aievec.shift %[[MAX4]], %[[MAX4]], %[[C2]] {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
430+
// CHECK: %[[MAX5:.*]] = aievec.max %[[MAX4]], %[[SHIFT5]] : vector<32xbf16>
431+
// CHECK: %[[EXTELEM:.*]] = aievec.ext_elem %[[MAX5]], %[[C0]] : vector<32xbf16>, i32, bf16
432+
// CHECK: %[[MAX6:.*]] = arith.maximumf %[[EXTELEM]], %[[CST]] : bf16
433+
%cst = arith.constant -1.0e+09 : bf16
434+
%0 = vector.reduction <maximumf>, %arg0, %cst : vector<32xbf16> into bf16
435+
// CHECK: return %[[MAX6]] : bf16
436+
return %0 : bf16
437+
}

0 commit comments

Comments
 (0)