@@ -1122,29 +1122,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11221122static void getGenericEffectsImpl (
11231123 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11241124 &effects,
1125- ValueRange results, const ValueRange inputOperands,
1126- ValueRange outputOperands) {
1127- for (auto operand : inputOperands) {
1125+ LinalgOp linalgOp) {
1126+ SmallVector<Value> inputOperands = linalgOp. getDpsInputs ();
1127+ for (auto [index, operand] : llvm::enumerate ( inputOperands) ) {
11281128 if (!llvm::isa<MemRefType>(operand.getType ()))
11291129 continue ;
1130- effects.emplace_back (MemoryEffects::Read::get (), operand,
1131- SideEffects::DefaultResource::get ());
1130+ if (linalgOp.payloadUsesValueFromOperand (&linalgOp->getOpOperand (index))) {
1131+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
1132+ /* effectOnFullRegion=*/ true ,
1133+ SideEffects::DefaultResource::get ());
1134+ }
11321135 }
1133- for (auto operand : outputOperands) {
1136+ unsigned inputOperandSize = inputOperands.size ();
1137+
1138+ for (auto [index, operand] : llvm::enumerate (linalgOp.getDpsInits ())) {
11341139 if (!llvm::isa<MemRefType>(operand.getType ()))
11351140 continue ;
1136- effects.emplace_back (MemoryEffects::Read::get (), operand,
1137- SideEffects::DefaultResource::get ());
1138- effects.emplace_back (MemoryEffects::Write::get (), operand,
1141+ if (linalgOp.payloadUsesValueFromOperand (
1142+ &linalgOp->getOpOperand (index + inputOperandSize))) {
1143+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
1144+ /* effectOnFullRegion=*/ true ,
1145+ SideEffects::DefaultResource::get ());
1146+ }
1147+ effects.emplace_back (MemoryEffects::Write::get (), operand, /* stage=*/ 0 ,
1148+ /* effectOnFullRegion=*/ true ,
11391149 SideEffects::DefaultResource::get ());
11401150 }
11411151}
11421152
11431153void GenericOp::getEffects (
11441154 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11451155 &effects) {
1146- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1147- getDpsInits ());
1156+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
11481157}
11491158
11501159LogicalResult GenericOp::verify () { return success (); }
@@ -1492,8 +1501,7 @@ ArrayAttr MapOp::getIndexingMaps() {
14921501void MapOp::getEffects (
14931502 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14941503 &effects) {
1495- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1496- getDpsInits ());
1504+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
14971505}
14981506
14991507// ===----------------------------------------------------------------------===//
@@ -1561,8 +1569,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15611569void ReduceOp::getEffects (
15621570 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15631571 &effects) {
1564- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1565- getDpsInits ());
1572+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
15661573}
15671574
15681575static ParseResult parseDenseI64ArrayAttr (OpAsmParser &parser,
@@ -1846,8 +1853,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18461853void TransposeOp::getEffects (
18471854 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18481855 &effects) {
1849- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1850- getDpsInits ());
1856+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
18511857}
18521858
18531859LogicalResult TransposeOp::fold (FoldAdaptor adaptor,
@@ -1984,8 +1990,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19841990void BroadcastOp::getEffects (
19851991 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19861992 &effects) {
1987- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1988- getDpsInits ());
1993+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
19891994}
19901995
19911996void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -2513,8 +2518,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
25132518void SoftmaxOp::getEffects (
25142519 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
25152520 &effects) {
2516- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
2517- getDpsInits ());
2521+ for (Value operand : getDpsInputs ()) {
2522+ if (!llvm::isa<MemRefType>(operand.getType ()))
2523+ continue ;
2524+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
2525+ /* effectOnFullRegion=*/ true ,
2526+ SideEffects::DefaultResource::get ());
2527+ }
2528+ for (Value operand : getDpsInits ()) {
2529+ if (!llvm::isa<MemRefType>(operand.getType ()))
2530+ continue ;
2531+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
2532+ /* effectOnFullRegion=*/ true ,
2533+ SideEffects::DefaultResource::get ());
2534+ effects.emplace_back (MemoryEffects::Write::get (), operand, /* stage=*/ 0 ,
2535+ /* effectOnFullRegion=*/ true ,
2536+ SideEffects::DefaultResource::get ());
2537+ }
25182538}
25192539
25202540// Helper functions for softmax decomposition.
0 commit comments