@@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build(
3033
3033
}
3034
3034
3035
3035
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne (
3036
- transform::TransformRewriter &rewriter, LinalgOp target,
3036
+ transform::TransformRewriter &rewriter, Operation * target,
3037
3037
transform::ApplyToEachResultList &results,
3038
3038
transform::TransformState &state) {
3039
3039
rewriter.setInsertionPoint (target);
3040
+
3041
+ auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3042
+ if (!partialReductionOp) {
3043
+ return emitSilenceableFailure (
3044
+ target->getLoc (),
3045
+ " Operation should implement PartialReductionOpInterface" );
3046
+ }
3040
3047
SmallVector<OpFoldResult> numThreads =
3041
3048
getAsOpFoldResult (rewriter.getI64ArrayAttr (getNumThreads ()));
3042
3049
SmallVector<OpFoldResult> tileSizes =
@@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3058
3065
extractFromIntegerArrayAttr<unsigned >(getReductionDims ());
3059
3066
if (reductionDims.empty ()) {
3060
3067
for (auto [idx, iteratorType] :
3061
- llvm::enumerate (target. getIteratorTypesArray ())) {
3068
+ llvm::enumerate (partialReductionOp. getLoopIteratorTypes ())) {
3062
3069
if (iteratorType == utils::IteratorType::reduction)
3063
3070
reductionDims.push_back (idx);
3064
3071
}
3065
3072
}
3066
3073
options.setReductionDims (reductionDims);
3067
- FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF (
3068
- rewriter, cast<TilingInterface>(target. getOperation ()) , options);
3074
+ FailureOr<scf::SCFTilingResult> result =
3075
+ scf::tileUsingSCF ( rewriter, partialReductionOp , options);
3069
3076
3070
3077
if (failed (result)) {
3071
3078
auto diag = emitSilenceableError () << " could not tile reduction" ;
0 commit comments