@@ -798,9 +798,15 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
798798 // transformations such as padding and bufferization since the
799799 // extract/insert slice pairs make the accessed iteration argument
800800 // subdomains explicit.
801-
801+ SmallVector<int64_t > domainSizes;
802+ // FIXME: tileToPartialReduction adds the new init tensor to the output
803+ // but doesn't update the indexing type of the index map causing a crash.
804+ // isAllParallelLoops
805+ if (linalgOp.getNumParallelLoops () == linalgOp.getNumLoops ()) {
806+ domainSizes = linalgOp.getStaticLoopRanges ();
807+ }
802808 Type operandType = opOperand.get ().getType ();
803- if (!isTiled (map, tileSizes, linalgOp. getStaticLoopRanges () ) &&
809+ if (!isTiled (map, tileSizes, domainSizes ) &&
804810 !(isa<RankedTensorType>(operandType) &&
805811 linalgOp.isDpsInit (&opOperand))) {
806812 allSliceParams.push_back (std::nullopt );
@@ -812,7 +818,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
812818
813819 allSliceParams.push_back (computeSliceParameters (
814820 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
815- omitPartialTileCheck, linalgOp. getStaticLoopRanges () ));
821+ omitPartialTileCheck, domainSizes ));
816822 }
817823
818824 return allSliceParams;
0 commit comments