@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
47954795 return newOperands;
47964796}
47974797
4798+ // Given the (potentially) updated packed type, `newPackedTy`, generates an
4799+ // updated mixed-tile-sizes attribute. A tile size is updated only
4800+ // when:
4801+ // * a dim from newPackedTy is static, and
4802+ // * the corresponding size from mixedTiles is still dynamic.
4803+ // Otherwise, the original tile size is preserved.
4804+ // Note - packed-type-dim and mixed-tile-size should always match!
4805+ static SmallVector<OpFoldResult>
4806+ getNewMixedTileSizes (PatternRewriter &rewriter, Type newPackedTy,
4807+ SmallVector<OpFoldResult> mixedTiles) {
4808+ SmallVector<OpFoldResult> newMixedTileSizes;
4809+ for (auto it : llvm::zip (cast<ShapedType>(newPackedTy)
4810+ .getShape ()
4811+ .take_back (mixedTiles.size ()),
4812+ mixedTiles)) {
4813+ int64_t shape = std::get<0 >(it);
4814+ if (shape == ShapedType::kDynamic ) {
4815+ newMixedTileSizes.push_back (std::get<1 >(it));
4816+ continue ;
4817+ }
4818+
4819+ // If the current result dim is static, update the dynamic mixed-size
4820+ // (provided the original value is dynamic).
4821+ OpFoldResult tile = std::get<1 >(it);
4822+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4823+ // Already a constant
4824+ newMixedTileSizes.push_back (tile);
4825+ } else {
4826+ assert (getConstantIntValue (tile).value () == shape &&
4827+ " tile size and dim size don't match!" );
4828+ newMixedTileSizes.push_back (
4829+ (rewriter.getIntegerAttr (rewriter.getIndexType (), shape)));
4830+ }
4831+ }
4832+
4833+ return newMixedTileSizes;
4834+ }
4835+
47984836// / Folds a tensor.cast op into a consuming tensor::PackOp op if the
47994837// / `tensor.cast` has source that is more static than the consuming op.
48004838// /
@@ -4821,28 +4859,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48214859 SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
48224860
48234861 // Get the updated mixed-tile-sizes attribute.
4824- SmallVector<OpFoldResult> newMixedTileSizes;
4825- for (auto it : llvm::zip (cast<ShapedType>(newResultTypes[0 ])
4826- .getShape ()
4827- .take_back (op.getMixedTiles ().size ()),
4828- op.getMixedTiles ())) {
4829- int64_t shape = std::get<0 >(it);
4830- if (shape == ShapedType::kDynamic ) {
4831- newMixedTileSizes.push_back (std::get<1 >(it));
4832- continue ;
4833- }
4834-
4835- if (Attribute attr =
4836- llvm::dyn_cast_if_present<Attribute>(std::get<1 >(it))) {
4837- // Already a constant
4838- newMixedTileSizes.push_back (std::get<1 >(it));
4839- } else {
4840- assert (getConstantIntValue (std::get<1 >(it)).value () == shape &&
4841- " tile size and dim size don't match!" );
4842- newMixedTileSizes.push_back (
4843- (rewriter.getIntegerAttr (rewriter.getIndexType (), shape)));
4844- }
4845- }
4862+ SmallVector<OpFoldResult> newMixedTileSizes =
4863+ getNewMixedTileSizes (rewriter, newResultTypes[0 ], op.getMixedTiles ());
48464864
48474865 // Clone op.
48484866 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -4873,7 +4891,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48734891// / Example:
48744892// / ```mlir
48754893// / %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
4876- // / %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32 > -> tensor<7x?xi32>
4894+ // / %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32 > -> tensor<7x?xi32>
48774895// / ```
48784896// /
48794897// / folds into:
@@ -4894,32 +4912,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
48944912 Value sourceTensor = newOperands[0 ];
48954913
48964914 // Get the updated mixed-tile-sizes attribute.
4897- SmallVector<OpFoldResult> newMixedTileSizes;
4898- for (auto it : llvm::zip (cast<ShapedType>(sourceTensor.getType ())
4899- .getShape ()
4900- .take_back (op.getMixedTiles ().size ()),
4901- op.getMixedTiles ())) {
4902- int64_t shape = std::get<0 >(it);
4903- // If the current source shape is dynamic, just preserve this mixed
4904- // size.
4905- if (shape == ShapedType::kDynamic ) {
4906- newMixedTileSizes.push_back (std::get<1 >(it));
4907- continue ;
4908- }
4909-
4910- // If the current source is static, update the dynamic mixed-size
4911- // (provided the original value is dynamic).
4912- if (Attribute attr =
4913- llvm::dyn_cast_if_present<Attribute>(std::get<1 >(it))) {
4914- // Already a constant
4915- newMixedTileSizes.push_back (std::get<1 >(it));
4916- } else {
4917- assert (getConstantIntValue (std::get<1 >(it)).value () == shape &&
4918- " tile size and dim size don't match!" );
4919- newMixedTileSizes.push_back (
4920- (rewriter.getIntegerAttr (rewriter.getIndexType (), shape)));
4921- }
4922- }
4915+ SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes (
4916+ rewriter, sourceTensor.getType (), op.getMixedTiles ());
49234917
49244918 // Clone op.
49254919 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
0 commit comments