@@ -492,7 +492,6 @@ getVectorDistributeReductionConfig(
492492 return loweringConfig;
493493}
494494
495- // TODO: Use IndexingMapInterface here instead of linalg::LinalgOp.
496495static LogicalResult
497496populateConfigInfo (const llvm::SetVector<linalg::LinalgOp> &computeOps,
498497 IREE::GPU::TargetAttr target, int64_t workgroupSize,
@@ -524,28 +523,42 @@ populateConfigInfo(const llvm::SetVector<linalg::LinalgOp> &computeOps,
524523 // LinalgOp with only parallel dims. This is needed if the op cannot be fused
525524 // with a reduction or introduces new loop dimensions.
526525 auto shouldAttachLoweringConfig = [&](linalg::LinalgOp linalgOp) -> bool {
527- // We want to attach a lowering config to this operation if it introduces
528- // a new dimension, when going by topological order in the backward slice.
529- // The only two ways to introduce a new dimension are:
530- //
531- // 1. We have a reduction dimension.
532- if (hasReductionIterator (linalgOp)) {
533- return true ;
534- }
535- // 2. There is no consumer which is a compute op (i.e., it already
536- // has some way of getting fused).
537- if (llvm::none_of (linalgOp->getUsers (), [&](Operation *user) {
526+ // If the operation has a gather, we want to fuse it with the
527+ // reduction.
528+ if (hasExternalCapture (cast<linalg::GenericOp>(linalgOp))) {
529+ return false ;
530+ }
531+ // If some of the users are in computeOps and some are outside of
532+ // computeOps; attach lowering config, since the op can't be fused.
533+ if (llvm::any_of (linalgOp->getUsers (),
534+ [&](Operation *user) {
535+ auto linalgUser = dyn_cast<linalg::LinalgOp>(user);
536+ return linalgUser && computeOps.contains (linalgUser);
537+ }) &&
538+ llvm::any_of (linalgOp->getUsers (), [&](Operation *user) {
538539 auto linalgUser = dyn_cast<linalg::LinalgOp>(user);
539- return linalgUser && computeOps. contains (linalgUser) ;
540+ return ! linalgUser;
540541 })) {
541542 return true ;
542543 }
543544
545+ // If the indexing map introduces new dimensions (more inputs than results),
546+ // attach a lowering config.
547+ for (OpOperand *operand : linalgOp.getDpsInputOperands ()) {
548+ int64_t operandIdx = linalgOp.getIndexingMapIndex (operand);
549+ AffineMap indexingMap = linalgOp.getIndexingMapsArray ()[operandIdx];
550+ if (indexingMap.getNumResults () > 0 &&
551+ indexingMap.getNumInputs () > indexingMap.getNumResults ()) {
552+ return true ;
553+ }
554+ }
555+
544556 return false ;
545557 };
546558
547559 for (linalg::LinalgOp linalgOp : computeOps) {
548- if (shouldAttachLoweringConfig (linalgOp)) {
560+ if (hasReductionIterator (linalgOp) ||
561+ shouldAttachLoweringConfig (linalgOp)) {
549562 auto loweringConfig = getVectorDistributeReductionConfig (
550563 linalgOp, target, sharedWgpTiles, workgroupSize, subgroupSize,
551564 threadLoads);
0 commit comments