@@ -1668,10 +1668,7 @@ namespace {
16681668 if (!elementsAttr.isSplat ()) {
16691669 return false ;
16701670 }
1671- if (!elementsAttr.template getSplatValue <APFloat>().isZero ()) {
1672- return false ;
1673- }
1674- return true ;
1671+ return elementsAttr.template getSplatValue <APFloat>().isZero ();
16751672}
16761673
16771674template <typename LN_TYPE, typename MATCH_OP_TYPE,
@@ -1687,9 +1684,9 @@ struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase
16871684 [[nodiscard]] virtual bool doExisitingScaleAndBiasAllowFusion (
16881685 LN_TYPE lnOp) const = 0;
16891686
1690- LogicalResult verifyAndCalculateNewReshapeShapes (Operation *reshapeOp,
1691- MATCH_OP_TYPE matchOp, PatternRewriter &rewriter, Value scaleOrBias ,
1692- SmallVectorImpl< int64_t > &newScaleOrBiasShape ) const {
1687+ FailureOr<SmallVector< int64_t >> verifyAndCalculateNewReshapeShapes (
1688+ Operation *reshapeOp, MATCH_OP_TYPE matchOp, PatternRewriter &rewriter,
1689+ Value scaleOrBias ) const {
16931690 // if we have a reshape, check that the add/mul is not changing the shape
16941691 // by broadcasting
16951692 auto reshapeResultType =
@@ -1770,11 +1767,11 @@ struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase
17701767 matchOp, " bias/scale shape is not compatible with reshape input" );
17711768 }
17721769 }
1773-
1770+ SmallVector< int64_t > newScaleOrBiasShape;
17741771 newScaleOrBiasShape.push_back (reshapeInShape[reshapeInComputationDim]);
17751772 newScaleOrBiasShape.append (
17761773 reshapeInShape.size () - reshapeInComputationDim - 1 , 1 );
1777- return success () ;
1774+ return newScaleOrBiasShape ;
17781775 }
17791776
17801777 LogicalResult matchAndRewrite (
@@ -1827,10 +1824,12 @@ struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase
18271824 lnOp, " existing scale and bias do not allow fusion" );
18281825
18291826 if (reshapeOp) {
1830- if (failed (verifyAndCalculateNewReshapeShapes (reshapeOp, matchOp,
1831- rewriter, scaleOrBias, newScaleOrBiasShape))) {
1827+ auto newShape = verifyAndCalculateNewReshapeShapes (
1828+ reshapeOp, matchOp, rewriter, scaleOrBias);
1829+ if (failed (newShape)) {
18321830 return failure ();
18331831 }
1832+ newScaleOrBiasShape = std::move (*newShape);
18341833 }
18351834
18361835 // Norms only support unidirectional broadcasting to x
@@ -1888,10 +1887,7 @@ struct PropagateScaleIntoLayerNormPattern
18881887 if (!elementsAttr.isSplat ()) {
18891888 return false ;
18901889 }
1891- if (!elementsAttr.template getSplatValue <APFloat>().isExactlyValue (1.0 )) {
1892- return false ;
1893- }
1894- return true ;
1890+ return elementsAttr.template getSplatValue <APFloat>().isExactlyValue (1.0 );
18951891 }
18961892};
18971893
0 commit comments