@@ -1718,11 +1718,9 @@ static bool isBroadcastLike(Operation *op) {
17181718 return false ;
17191719
17201720 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1721- // Condition 1: dst has hight rank.
1722- // Condition 2: src shape is a suffix of dst shape.
1723- //
17241721 // Note that checking that dst shape has a prefix of 1s is not sufficient,
1725- // for example (2,3) -> (1,3,2) is not broadcast-like.
1722+ // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
1723+ // is that the source shape is a suffix of the destination shape.
17261724 VectorType srcType = shapeCast.getSourceVectorType ();
17271725 ArrayRef<int64_t > srcShape = srcType.getShape ();
17281726 uint64_t srcRank = srcType.getRank ();
@@ -1734,16 +1732,16 @@ static bool isBroadcastLike(Operation *op) {
17341732// /
17351733// / Example:
17361734// /
1737- // / broadcast extract
1738- // / (3, 4) --------> (2, 3, 4) ------> (4)
1735+ // / broadcast extract [1][2]
1736+ // / (3, 4) --------> (2, 3, 4) ---------------- > (4)
17391737// /
17401738// / becomes
1741- // / extract
1742- // / (3,4) ---------------------------> (4)
1739+ // / extract [1]
1740+ // / (3,4) ------------------------------------- > (4)
17431741// /
17441742// /
1745- // / The variable names used in this implementation use names which correspond to
1746- // / the above shapes as,
1743+ // / The variable names used in this implementation correspond to the above
1744+ // / shapes as,
17471745// /
17481746// / - (3, 4) is `input` shape.
17491747// / - (2, 3, 4) is `broadcast` shape.
@@ -1775,14 +1773,15 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17751773 if (extractRank > inputRank)
17761774 return Value ();
17771775
1778- // Proof by contradiction that, at this point, input is a vector.
1779- // Suppose input is a scalar.
1780- // ==> inputRank is 0.
1781- // ==> extractRank is 0 (because extractRank <= inputRank).
1782- // ==> extract is scalar (because rank-0 extraction is always scalar).
1783- // ==> input and extract are scalar, so same type.
1784- // ==> returned early (check same type).
1785- // Contradiction!
1776+ // The above condition guarantees that input is a vector:
1777+ //
1778+ // If input is a scalar:
1779+ // 1) inputRank is 0, so
1780+ // 2) extractRank is 0 (because extractRank <= inputRank), so
1781+ // 3) extract is scalar (because rank-0 extraction is always scalar), s0
1782+ // 4) input and extract are scalar, so same type.
1783+ // But then we should have returned earlier when the types were compared for
1784+ // equivalence. So input is not a scalar at this point.
17861785 assert (inputType && " input must be a vector type because of previous checks" );
17871786 ArrayRef<int64_t > inputShape = inputType.getShape ();
17881787
0 commit comments