@@ -1069,27 +1069,34 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1069
1069
return commonConversionPrecondition (rewriter, preconditionType, op);
1070
1070
}
1071
1071
1072
- // / Verify that source and destination element types meet the precondition for
1073
- // / the supported aligned conversion cases. Alignment means that the either the
1074
- // / source element type is multiple of the destination element type or the other
1075
- // / way around.
1072
+ // / Verify that `subByteVecType` and `dstType` are aligned. Alignment
1073
+ // / means that:
1074
+ // / 1. The `dstType` element type is a multiple of the
1075
+ // / `subByteVecType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1076
+ // / is not supported). Let this multiple be `N`.
1077
+ // / 2. The number of the (trailing) elements in `subByteVecType` is a
1078
+ // / multiple of the `N` from item 1 (e.g., when targeting i8, 2xi4 is OK, but 3xi4 is
1079
+ // / not supported).
1076
1080
// /
1077
- // / NOTE: This method assumes that common conversion preconditions are met.
1081
+ // / NOTE: This method assumes that common conversion preconditions are met. In
1082
+ // / particular, the element type of `dstType` is assumed to be a multi-byte
1083
+ // / type (e.g. i8, i16, i32).
1078
1084
static LogicalResult alignedConversionPrecondition (PatternRewriter &rewriter,
1079
- VectorType srcType ,
1085
+ VectorType subByteVecType ,
1080
1086
VectorType dstType,
1081
1087
Operation *op) {
1082
- if (!srcType || !dstType)
1088
+ if (!subByteVecType || !dstType)
1083
1089
return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1084
- unsigned srcElemBitwidth = srcType .getElementTypeBitWidth ();
1090
+ unsigned srcElemBitwidth = subByteVecType .getElementTypeBitWidth ();
1085
1091
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth ();
1086
1092
1087
1093
// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
1088
1094
if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
1089
1095
(dstElemBitwidth % srcElemBitwidth) != 0 )
1090
1096
return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1091
1097
1092
- if ((srcType.getShape ().back () % 2 ) != 0 )
1098
+ const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
1099
+ if ((subByteVecType.getShape ().back () % numSrcElemsPerDestElem) != 0 )
1093
1100
return rewriter.notifyMatchFailure (
1094
1101
op, " Not an even number of i4 elements in trailing dim" );
1095
1102
0 commit comments