Skip to content

Commit edd033d

Browse files
committed
Apply suggestion from code review
1 parent e434b53 commit edd033d

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,10 +1668,7 @@ namespace {
16681668
if (!elementsAttr.isSplat()) {
16691669
return false;
16701670
}
1671-
if (!elementsAttr.template getSplatValue<APFloat>().isZero()) {
1672-
return false;
1673-
}
1674-
return true;
1671+
return elementsAttr.template getSplatValue<APFloat>().isZero();
16751672
}
16761673

16771674
template <typename LN_TYPE, typename MATCH_OP_TYPE,
@@ -1687,9 +1684,9 @@ struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase
16871684
[[nodiscard]] virtual bool doExisitingScaleAndBiasAllowFusion(
16881685
LN_TYPE lnOp) const = 0;
16891686

1690-
LogicalResult verifyAndCalculateNewReshapeShapes(Operation *reshapeOp,
1691-
MATCH_OP_TYPE matchOp, PatternRewriter &rewriter, Value scaleOrBias,
1692-
SmallVectorImpl<int64_t> &newScaleOrBiasShape) const {
1687+
FailureOr<SmallVector<int64_t>> verifyAndCalculateNewReshapeShapes(
1688+
Operation *reshapeOp, MATCH_OP_TYPE matchOp, PatternRewriter &rewriter,
1689+
Value scaleOrBias) const {
16931690
// if we have a reshape, check that the add/mul is not changing the shape
16941691
// by broadcasting
16951692
auto reshapeResultType =
@@ -1770,11 +1767,11 @@ struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase
17701767
matchOp, "bias/scale shape is not compatible with reshape input");
17711768
}
17721769
}
1773-
1770+
SmallVector<int64_t> newScaleOrBiasShape;
17741771
newScaleOrBiasShape.push_back(reshapeInShape[reshapeInComputationDim]);
17751772
newScaleOrBiasShape.append(
17761773
reshapeInShape.size() - reshapeInComputationDim - 1, 1);
1777-
return success();
1774+
return newScaleOrBiasShape;
17781775
}
17791776

17801777
LogicalResult matchAndRewrite(
@@ -1827,10 +1824,12 @@ struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase
18271824
lnOp, "existing scale and bias do not allow fusion");
18281825

18291826
if (reshapeOp) {
1830-
if (failed(verifyAndCalculateNewReshapeShapes(reshapeOp, matchOp,
1831-
rewriter, scaleOrBias, newScaleOrBiasShape))) {
1827+
auto newShape = verifyAndCalculateNewReshapeShapes(
1828+
reshapeOp, matchOp, rewriter, scaleOrBias);
1829+
if (failed(newShape)) {
18321830
return failure();
18331831
}
1832+
newScaleOrBiasShape = std::move(*newShape);
18341833
}
18351834

18361835
// Norms only support unidirectional broadcasting to x
@@ -1888,10 +1887,7 @@ struct PropagateScaleIntoLayerNormPattern
18881887
if (!elementsAttr.isSplat()) {
18891888
return false;
18901889
}
1891-
if (!elementsAttr.template getSplatValue<APFloat>().isExactlyValue(1.0)) {
1892-
return false;
1893-
}
1894-
return true;
1890+
return elementsAttr.template getSplatValue<APFloat>().isExactlyValue(1.0);
18951891
}
18961892
};
18971893

0 commit comments

Comments
 (0)