Skip to content

Commit d054bd0

Browse files
committed
fix: refactor suitableAxis func in recompose pass
1 parent fb69354 commit d054bd0

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -579,62 +579,62 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
579579

580580
private:
581581
// Return the reduced dimensions as bit vector.
582-
static FailureOr<llvm::SmallBitVector> getReducedAxis(
583-
Operation *op, int64_t xRank, int64_t &axis) {
582+
static bool getReducedAxis(Operation *op, int64_t xRank, int64_t &axis,
583+
llvm::SmallBitVector &reducedAxis) {
584584
SmallVector<int64_t> axes; // The axes attribute/operand of the ReduceMeanOp
585+
reducedAxis.resize(xRank, false);
585586
if (auto reduceOpV13 = mlir::dyn_cast<ONNXReduceMeanV13Op>(op)) {
586587
if (reduceOpV13.getKeepdims() != 1)
587-
return success(reportFailure("need keepdims = 1"));
588+
return reportFailure("need keepdims = 1");
588589
ArrayAttr axesAttr = reduceOpV13.getAxesAttr();
589590
for (size_t i = 0; i < axesAttr.size(); ++i) {
590591
axes.emplace_back(onnx_mlir::ArrayAttrIntVal(axesAttr, i));
591592
}
592593
} else if (auto reduceOp = mlir::dyn_cast<ONNXReduceMeanOp>(op)) {
593594
if (reduceOp.getKeepdims() != 1)
594-
return success(reportFailure("need keepdims = 1"));
595+
return reportFailure("need keepdims = 1");
595596
Value axesValue = reduceOp.getAxes();
596597
if (isa<NoneType>(axesValue.getType())) {
597598
if (reduceOp.getNoopWithEmptyAxes()) {
598599
// No reduction
599-
return success(
600-
reportFailure("needs a reduction on at least one dimension"));
600+
return reportFailure("needs a reduction on at least one dimension");
601601
} else {
602602
// Reduction on all dimensions
603603
axis = 0;
604-
return llvm::SmallBitVector(xRank, true);
604+
reducedAxis.set(0, xRank);
605+
return true;
605606
}
606607
}
607608
if (!onnx_mlir::getI64ValuesFromONNXConstantOp(axesValue, axes)) {
608-
return success(reportFailure("only static axes are supported"));
609+
return reportFailure("only static axes are supported");
609610
}
610611
} else {
611612
llvm_unreachable("ReduceMean is the only supported op");
612613
}
613614

614615
// Record axes value in bit vector.
615-
llvm::SmallBitVector reduceAxes(xRank, false);
616616
for (int64_t axe : axes) {
617617
int64_t a = onnx_mlir::getAxisInRange(axe, xRank);
618-
reduceAxes[a] = true;
618+
reducedAxis[a] = true;
619619
}
620-
return reduceAxes;
620+
return true;
621621
}
622622

623623
// Check if the axis is suitable for Layernorm.
624624
static bool suitableAxis(Operation *op, int64_t xRank, int64_t &axis) {
625-
auto reduceAxes = getReducedAxis(op, xRank, axis);
626-
if (failed(reduceAxes))
625+
llvm::SmallBitVector reducedAxis;
626+
if (!getReducedAxis(op, xRank, axis, reducedAxis))
627627
return false;
628628

629629
// Check that we have a "false"* "true"+ pattern.
630630
bool foundFirstAxis = false;
631631
for (int64_t i = 0; i < xRank; ++i) {
632632
if (!foundFirstAxis) {
633-
if (reduceAxes.value()[i]) {
633+
if (reducedAxis[i]) {
634634
foundFirstAxis = true;
635635
axis = i;
636636
}
637-
} else if (!reduceAxes.value()[i]) {
637+
} else if (!reducedAxis[i]) {
638638
// Once we found an axis, we must reduce all subsequent dimensions.
639639
return false;
640640
}
@@ -651,14 +651,14 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
651651
// represented as Layernorm with axis = 2.
652652
static FailureOr<SmallVector<int64_t>> isAxisSuitableWithTranspose(
653653
Operation *op, int64_t xRank, int64_t &axis) {
654-
auto reduceAxes = getReducedAxis(op, xRank, axis);
655-
if (failed(reduceAxes))
654+
llvm::SmallBitVector reducedAxis;
655+
if (!getReducedAxis(op, xRank, axis, reducedAxis))
656656
return failure();
657657

658658
SmallVector<int64_t> reducedIdx;
659659
SmallVector<int64_t> nonReducedIdx;
660660
for (int64_t i = 0; i < xRank; ++i) {
661-
auto &array = reduceAxes.value()[i] ? reducedIdx : nonReducedIdx;
661+
auto &array = reducedAxis[i] ? reducedIdx : nonReducedIdx;
662662
array.push_back(i);
663663
}
664664

0 commit comments

Comments
 (0)