Skip to content

Commit 9f80663

Browse files
Max191keshavvinayak01
authored andcommitted
[Encoding] Remove ambiguity from encoding propagation interface methods (iree-org#21415)
The `EncodingPropagationAttrInterface` has shared methods for propagation upwards and downwards, and the methods only take a `Value` parameter, which is not enough to determine (1) Which operation to propagate through, or (2) Which direction to propagate. This PR splits the interface methods into separate methods for upwards and downwards propagation. The upwards propagation methods now take an `OpResult`, and the downwards propagation methods take an `OpOperand *`, to indicate for which operation/operand the propagation is happening for. There are no test changes, because the propagation interface methods are already tested in `DispatchCreation/test/propagate_encodings.mlir`. Signed-off-by: Max Dawkins <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 3c76642 commit 9f80663

File tree

4 files changed

+60
-26
lines changed

4 files changed

+60
-26
lines changed

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -345,19 +345,20 @@ def IREEEncoding_PropagationAttrInterface :
345345
AttrInterface<"EncodingPropagationAttrInterface"> {
346346
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";
347347
let description = [{
348-
Interface used to query new encoding attributes that can be propagated to
349-
the operands and result of the operation.
348+
Interface used to query new encoding attributes resulting from propagation
349+
to the operands and results of operations.
350350
}];
351351

352352
let methods = [
353353
InterfaceMethod<
354354
[{
355-
Returns true if the encodings can be propagated across the operation.
355+
Returns true if the encoding can be propagated down through the
356+
target's owner operation.
356357
}],
357358
/*retTy=*/"bool",
358-
/*methodName=*/"isPropagable",
359+
/*methodName=*/"isPropagableDown",
359360
/*args=*/(ins
360-
"::mlir::Value":$target
361+
"::mlir::OpOperand *":$target
361362
),
362363
/*methodBody=*/"",
363364
/*defaultImplementation=*/[{
@@ -366,14 +367,47 @@ def IREEEncoding_PropagationAttrInterface :
366367
>,
367368
InterfaceMethod<
368369
[{
369-
Returns the new encodings for operand and result types for the given
370+
Returns true if the encoding can be propagated up through the
371+
target's owner operation.
372+
}],
373+
/*retTy=*/"bool",
374+
/*methodName=*/"isPropagableUp",
375+
/*args=*/(ins
376+
"::mlir::OpResult":$target
377+
),
378+
/*methodBody=*/"",
379+
/*defaultImplementation=*/[{
380+
return false;
381+
}]
382+
>,
383+
InterfaceMethod<
384+
[{
385+
Returns the new encodings for operand and result types for the
386+
target's owner operation after propagating the encoding down through
387+
the operation.
388+
}],
389+
/*retTy=*/
390+
"llvm::FailureOr<::mlir::iree_compiler::IREE::Encoding::PropagationEncoding>",
391+
/*methodName=*/"generateSinkingEncodings",
392+
/*args=*/(ins
393+
"::mlir::OpOperand *":$target
394+
),
395+
/*methodBody=*/"",
396+
/*defaultImplementation=*/[{
397+
return failure();
398+
}]
399+
>,
400+
InterfaceMethod<
401+
[{
402+
Returns the new encodings for operand and result types for the
403+
target's owner operation after propagating the encoding up through the
370404
operation.
371405
}],
372406
/*retTy=*/
373407
"llvm::FailureOr<::mlir::iree_compiler::IREE::Encoding::PropagationEncoding>",
374-
/*methodName=*/"generateEncodings",
408+
/*methodName=*/"generateBubblingEncodings",
375409
/*args=*/(ins
376-
"::mlir::Value":$target
410+
"::mlir::OpResult":$target
377411
),
378412
/*methodBody=*/"",
379413
/*defaultImplementation=*/[{

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ struct PropagationResult {
7070
// the propagation.
7171
SmallVector<Operation *> generatedEncodingOps;
7272

73-
// The new corresponding result that is created by the propagation. It is
74-
// returned to the caller for further transformation or replacement.
75-
Value replacement;
73+
// The new results created after propagating an encoding through an operation.
74+
// It is returned to the caller for further transformation or replacement.
75+
SmallVector<Value> replacements;
7676
};
7777

7878
} // namespace mlir::iree_compiler::IREE::Encoding

compiler/src/iree/compiler/DispatchCreation/PropagateEncodings.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,30 +46,30 @@ struct PropagateEncodingsPass
4646

4747
LogicalResult SwapEncodingOpWithTensorCollapseShapeOp::matchAndRewrite(
4848
IREE::Encoding::SetEncodingOp encodingOp, PatternRewriter &rewriter) const {
49-
Value target = encodingOp.getSource();
49+
auto collapseOp =
50+
encodingOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
51+
if (!collapseOp) {
52+
return rewriter.notifyMatchFailure(encodingOp,
53+
"expected a collapse_shape producer");
54+
}
55+
auto target = cast<OpResult>(collapseOp.getResult());
5056
auto propagationAttrInterface =
5157
dyn_cast<IREE::Encoding::EncodingPropagationAttrInterface>(
5258
encodingOp.getResultType().getEncoding());
5359
if (!propagationAttrInterface ||
54-
!propagationAttrInterface.isPropagable(target)) {
60+
!propagationAttrInterface.isPropagableUp(target)) {
5561
return rewriter.notifyMatchFailure(
5662
encodingOp, "the propagation attribute interface isn't defined or the "
5763
"target isn't propagable");
5864
}
5965
// Get the encoding attributes for the operands and results of the operation.
6066
FailureOr<IREE::Encoding::PropagationEncoding> propagationEncodings =
61-
propagationAttrInterface.generateEncodings(target);
67+
propagationAttrInterface.generateBubblingEncodings(target);
6268
if (failed(propagationEncodings)) {
6369
return rewriter.notifyMatchFailure(encodingOp,
6470
"not able to determine propagation "
6571
"attributes for operands and results");
6672
}
67-
auto collapseOp =
68-
encodingOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
69-
if (!collapseOp) {
70-
return rewriter.notifyMatchFailure(encodingOp,
71-
"expected a collapse_shape producer");
72-
}
7373
if (!IREE::Flow::isNonNullAndOutsideDispatch(encodingOp) ||
7474
!IREE::Flow::isNonNullAndOutsideDispatch(collapseOp)) {
7575
return rewriter.notifyMatchFailure(
@@ -91,7 +91,7 @@ LogicalResult SwapEncodingOpWithTensorCollapseShapeOp::matchAndRewrite(
9191
return rewriter.notifyMatchFailure(
9292
encodingOp, "not able to propagate encodings and find replacement");
9393
}
94-
rewriter.replaceOp(encodingOp, maybeResult->replacement);
94+
rewriter.replaceOp(encodingOp, maybeResult->replacements[0]);
9595
return success();
9696
}
9797

compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ namespace {
1818
struct ContractionAttrPropagationInterface
1919
: public IREE::Encoding::EncodingPropagationAttrInterface::ExternalModel<
2020
ContractionAttrPropagationInterface, IREE::Encoding::MatmulKAttr> {
21-
bool isPropagable(Attribute attr, Value target) const {
21+
bool isPropagableUp(Attribute attr, OpResult target) const {
2222
auto encoding = cast<IREE::Encoding::MatmulKAttr>(attr);
23-
Operation *attachedToOperation = target.getDefiningOp();
23+
Operation *attachedToOperation = target.getOwner();
2424
if (!attachedToOperation) {
2525
return false;
2626
}
@@ -42,11 +42,11 @@ struct ContractionAttrPropagationInterface
4242
}
4343

4444
FailureOr<IREE::Encoding::PropagationEncoding>
45-
generateEncodings(Attribute attr, Value target) const {
45+
generateBubblingEncodings(Attribute attr, OpResult target) const {
4646
auto encoding = cast<IREE::Encoding::MatmulKAttr>(attr);
4747
return TypeSwitch<Operation *,
4848
FailureOr<IREE::Encoding::PropagationEncoding>>(
49-
target.getDefiningOp())
49+
target.getOwner())
5050
.Case<tensor::CollapseShapeOp>([&](auto collapseOp) {
5151
ArrayRef<int32_t> kDims = encoding.getKDims().asArrayRef();
5252
SmallVector<ReassociationIndices, 4> reassociationMaps =
@@ -98,7 +98,7 @@ struct ContractionOpPropagationInterface
9898
loc, resultEncodingType, newEncodingOp,
9999
collapseOp.getReassociationIndices());
100100
IREE::Encoding::PropagationResult result;
101-
result.replacement = newCollapseOp;
101+
result.replacements = {newCollapseOp};
102102
result.generatedEncodingOps.push_back(newEncodingOp.getDefiningOp());
103103
return result;
104104
})

0 commit comments

Comments
 (0)