@@ -1069,27 +1069,34 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
10691069 return commonConversionPrecondition (rewriter, preconditionType, op);
10701070}
10711071
1072- // / Verify that source and destination element types meet the precondition for
1073- // / the supported aligned conversion cases. Alignment means that the either the
1074- // / source element type is multiple of the destination element type or the other
1075- // / way around.
1072+ // / Verify that `subByteVecType` and `dstType` are aligned. Alignment
1073+ // / means that:
1074+ // / 1. The `dstType` element type is a multiple of the
1075+ // / `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1076+ // / is not supported). Let this multiple be `N`.
1077+ // / 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1078+ // / multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1079+ // / not supported).
10761080// /
1077- // / NOTE: This method assumes that common conversion preconditions are met.
1081+ // / NOTE: This method assumes that common conversion preconditions are met. In
1082+ // / particular, the element type of `dstType` is assumed to be a multi-byte
1083+ // / type (e.g. i8, i16, i32).
10781084static LogicalResult alignedConversionPrecondition (PatternRewriter &rewriter,
1079- VectorType srcType ,
1085+ VectorType subByteVecType ,
10801086 VectorType dstType,
10811087 Operation *op) {
1082- if (!srcType || !dstType)
1088+ if (!subByteVecType || !dstType)
10831089 return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1084- unsigned srcElemBitwidth = srcType .getElementTypeBitWidth ();
1090+ unsigned srcElemBitwidth = subByteVecType .getElementTypeBitWidth ();
10851091 unsigned dstElemBitwidth = dstType.getElementTypeBitWidth ();
10861092
10871093 // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
10881094 if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
10891095 (dstElemBitwidth % srcElemBitwidth) != 0 )
10901096 return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
10911097
1092- if ((srcType.getShape ().back () % 2 ) != 0 )
1098+ const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
1099+ if ((subByteVecType.getShape ().back () % numSrcElemsPerDestElem) != 0 )
10931100 return rewriter.notifyMatchFailure (
10941101 op, " Not an even number of i4 elements in trailing dim" );
10951102
0 commit comments