@@ -49,27 +49,26 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
4949 int64_t currTargetShape = targetShape[targetDim];
5050 while (sourceDim < (sourceShape.size () - 1 ) &&
5151 sourceShape[sourceDim] != ShapedType::kDynamic &&
52- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
52+ (currTargetShape == ShapedType::kDynamic ||
53+ prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape)) {
5354 prodOfCollapsedDims *= sourceShape[sourceDim];
5455 currIndices.push_back (sourceDim++);
5556 }
5657
58+ if (sourceDim >= sourceShape.size ())
59+ return std::nullopt ;
60+
5761 // If the current expanded dimension is dynamic, then the collapsed
5862 // dimensions should also be dynamic and product of all previous unprocessed
5963 // dimensions of the expanded shape should be 1.
6064 if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1 ))
62- return std::nullopt ;
63-
64- // If the collapsed dim is dynamic, the current expanded dim should also
65- // be dynamic.
66- if (currTargetShape == ShapedType::kDynamic &&
67- sourceShape[sourceDim] != ShapedType::kDynamic )
65+ currTargetShape != ShapedType::kDynamic )
6866 return std::nullopt ;
6967
7068 // For static shapes, if the product of dimensions of the expanded shape
7169 // should match the collapsed dimension shape.
72- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
70+ if (sourceShape[sourceDim] != ShapedType::kDynamic &&
71+ prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
7372 return std::nullopt ;
7473
7574 currIndices.push_back (sourceDim++);
@@ -315,11 +314,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
315314 // have proven that these are not sliced. In this case we just take
316315 // the full extent of each dimension in the reassociation list.
317316 if (linearizedDimensions[it.index ()]) {
318- llvm::append_range (
319- offsetsSizesAndStrides,
320- llvm::map_range (it. value (), [&]( int64_t idx) -> Range {
321- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
322- }));
317+ llvm::append_range (offsetsSizesAndStrides,
318+ llvm::map_range (it. value (), [&]( int64_t idx) -> Range {
319+ return {zeroAttr, collapseShapeInputShape[idx],
320+ oneAttr};
321+ }));
323322 continue ;
324323 }
325324
0 commit comments