@@ -579,62 +579,62 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
579579
580580private:
581581 // Return the reduced dimensions as bit vector.
582- static FailureOr<llvm::SmallBitVector> getReducedAxis (
583- Operation *op, int64_t xRank, int64_t &axis ) {
582+ static bool getReducedAxis (Operation *op, int64_t xRank, int64_t &axis,
583+ llvm::SmallBitVector &reducedAxis ) {
584584 SmallVector<int64_t > axes; // The axes attribute/operand of the ReduceMeanOp
585+ reducedAxis.resize (xRank, false );
585586 if (auto reduceOpV13 = mlir::dyn_cast<ONNXReduceMeanV13Op>(op)) {
586587 if (reduceOpV13.getKeepdims () != 1 )
587- return success ( reportFailure (" need keepdims = 1" ) );
588+ return reportFailure (" need keepdims = 1" );
588589 ArrayAttr axesAttr = reduceOpV13.getAxesAttr ();
589590 for (size_t i = 0 ; i < axesAttr.size (); ++i) {
590591 axes.emplace_back (onnx_mlir::ArrayAttrIntVal (axesAttr, i));
591592 }
592593 } else if (auto reduceOp = mlir::dyn_cast<ONNXReduceMeanOp>(op)) {
593594 if (reduceOp.getKeepdims () != 1 )
594- return success ( reportFailure (" need keepdims = 1" ) );
595+ return reportFailure (" need keepdims = 1" );
595596 Value axesValue = reduceOp.getAxes ();
596597 if (isa<NoneType>(axesValue.getType ())) {
597598 if (reduceOp.getNoopWithEmptyAxes ()) {
598599 // No reduction
599- return success (
600- reportFailure (" needs a reduction on at least one dimension" ));
600+ return reportFailure (" needs a reduction on at least one dimension" );
601601 } else {
602602 // Reduction on all dimensions
603603 axis = 0 ;
604- return llvm::SmallBitVector (xRank, true );
604+ reducedAxis.set (0 , xRank);
605+ return true ;
605606 }
606607 }
607608 if (!onnx_mlir::getI64ValuesFromONNXConstantOp (axesValue, axes)) {
608- return success ( reportFailure (" only static axes are supported" ) );
609+ return reportFailure (" only static axes are supported" );
609610 }
610611 } else {
611612 llvm_unreachable (" ReduceMean is the only supported op" );
612613 }
613614
614615 // Record axes value in bit vector.
615- llvm::SmallBitVector reduceAxes (xRank, false );
616616 for (int64_t axe : axes) {
617617 int64_t a = onnx_mlir::getAxisInRange (axe, xRank);
618- reduceAxes [a] = true ;
618+ reducedAxis [a] = true ;
619619 }
620- return reduceAxes ;
620+ return true ;
621621 }
622622
623623 // Check if the axis is suitable for Layernorm.
624624 static bool suitableAxis (Operation *op, int64_t xRank, int64_t &axis) {
625- auto reduceAxes = getReducedAxis (op, xRank, axis) ;
626- if (failed (reduceAxes ))
625+ llvm::SmallBitVector reducedAxis ;
626+ if (! getReducedAxis (op, xRank, axis, reducedAxis ))
627627 return false ;
628628
629629 // Check that we have a "false"* "true"+ pattern.
630630 bool foundFirstAxis = false ;
631631 for (int64_t i = 0 ; i < xRank; ++i) {
632632 if (!foundFirstAxis) {
633- if (reduceAxes. value () [i]) {
633+ if (reducedAxis [i]) {
634634 foundFirstAxis = true ;
635635 axis = i;
636636 }
637- } else if (!reduceAxes. value () [i]) {
637+ } else if (!reducedAxis [i]) {
638638 // Once we found an axis, we must reduce all subsequent dimensions.
639639 return false ;
640640 }
@@ -651,14 +651,14 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
651651 // represented as Layernorm with axis = 2.
652652 static FailureOr<SmallVector<int64_t >> isAxisSuitableWithTranspose (
653653 Operation *op, int64_t xRank, int64_t &axis) {
654- auto reduceAxes = getReducedAxis (op, xRank, axis) ;
655- if (failed (reduceAxes ))
654+ llvm::SmallBitVector reducedAxis ;
655+ if (! getReducedAxis (op, xRank, axis, reducedAxis ))
656656 return failure ();
657657
658658 SmallVector<int64_t > reducedIdx;
659659 SmallVector<int64_t > nonReducedIdx;
660660 for (int64_t i = 0 ; i < xRank; ++i) {
661- auto &array = reduceAxes. value () [i] ? reducedIdx : nonReducedIdx;
661+ auto &array = reducedAxis [i] ? reducedIdx : nonReducedIdx;
662662 array.push_back (i);
663663 }
664664
0 commit comments