@@ -582,6 +582,15 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
582582
583583namespace {
584584
585+ // / Helper functon to return the index of the last dynamic dimension in `shape`.
586+ int64_t lastDynIndex (ArrayRef<int64_t > shape) {
587+ return static_cast <int64_t >(
588+ std::distance (
589+ std::find (shape.rbegin (), shape.rend (), ShapedType::kDynamic ),
590+ shape.rend ()) -
591+ 1 );
592+ }
593+
585594// / Rewrites contiguous row-major vector.transfer_read ops by inserting
586595// / memref.collapse_shape on the source so that the resulting
587596// / vector.transfer_read has a 1D source. Requires the source shape to be
@@ -631,8 +640,9 @@ class FlattenContiguousRowMajorTransferReadPattern
631640 return failure ();
632641
633642 // Determinine the first memref dimension to collapse
634- int64_t firstDimToCollapse =
635- sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ();
643+ int64_t firstDimToCollapse = std::max (
644+ lastDynIndex (sourceType.getShape ()),
645+ sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ());
636646
637647 // 1. Collapse the source memref
638648 Value collapsedSource =
@@ -725,8 +735,9 @@ class FlattenContiguousRowMajorTransferWritePattern
725735 return failure ();
726736
727737 // Determinine the first memref dimension to collapse
728- int64_t firstDimToCollapse =
729- sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ();
738+ int64_t firstDimToCollapse = std::max (
739+ lastDynIndex (sourceType.getShape ()),
740+ sourceType.getRank () - sourceType.getMaxCollapsableTrailingDims ());
730741
731742 // 1. Collapse the source memref
732743 Value collapsedSource =
0 commit comments