Skip to content

Commit 89da9fc

Browse files
authored
[TensorExt][NFC] Remove dependency on Encoding through ExternalModel (#20612)
There is no strong dependency of `DispatchTensorType` on the Encoding dialect as the encoding can be just any attribute and is optional, so I don't think there's a need to have the TensorExt dialect explicitly depend on the Encoding just for the interface. Decoupling them would avoid potential circular dependency issues (I ran into that while working on #20160 (comment)), remove the need for other (potential) users of the TensorExt dialect to use the encoding and allow users to bring their own encoding interfaces while avoiding having to bring in the Encoding dialect. Signed-off-by: Jorn Tuyls <[email protected]>
1 parent b86ed92 commit 89da9fc

File tree

5 files changed

+29
-15
lines changed

5 files changed

+29
-15
lines changed

compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ iree_compiler_cc_library(
4949
],
5050
deps = [
5151
":TensorExtOpsGen",
52-
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
5352
"//compiler/src/iree/compiler/Dialect/Util/IR",
5453
"@llvm-project//llvm:Support",
5554
"@llvm-project//mlir:ArithDialect",

compiler/src/iree/compiler/Dialect/TensorExt/IR/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ iree_cc_library(
3434
MLIRIR
3535
MLIRInferTypeOpInterface
3636
MLIRTensorDialect
37-
iree::compiler::Dialect::Encoding::IR
3837
iree::compiler::Dialect::Util::IR
3938
PUBLIC
4039
)

compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,6 @@ bool DispatchTensorType::hasStaticShape(ArrayRef<int64_t> shape) const {
108108
return hasStaticShape() && getShape() == shape;
109109
}
110110

111-
Type DispatchTensorType::getEncodingType() const { return getBoundType(); }
112-
113-
Type DispatchTensorType::updateEncoding(Attribute encoding) const {
114-
return DispatchTensorType::get(getAccess(), getShape(), getBoundElementType(),
115-
encoding);
116-
}
117-
118111
LogicalResult
119112
DispatchTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
120113
uint32_t access, Type boundType) {

compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#ifndef IREE_COMPILER_DIALECT_TENSOREXT_IR_TENSOREXTTYPES_H_
88
#define IREE_COMPILER_DIALECT_TENSOREXT_IR_TENSOREXTTYPES_H_
99

10-
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
1110
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtDialect.h"
1211
#include "llvm/ADT/DenseMapInfo.h"
1312
#include "llvm/ADT/SmallVector.h"
@@ -40,8 +39,7 @@ enum class TensorAccess : uint32_t {
4039
// we can't extend it and reuse all of this.
4140
class DispatchTensorType
4241
: public Type::TypeBase<DispatchTensorType, Type,
43-
detail::DispatchTensorTypeStorage,
44-
IREE::Encoding::EncodingTypeInterface::Trait> {
42+
detail::DispatchTensorTypeStorage> {
4543
public:
4644
using ImplType = detail::DispatchTensorTypeStorage;
4745

@@ -126,9 +124,6 @@ class DispatchTensorType
126124
}
127125
return llvm::cast<RankedTensorType>(boundType);
128126
}
129-
130-
Type getEncodingType() const;
131-
Type updateEncoding(Attribute encoding) const;
132127
};
133128

134129
void printType(DispatchTensorType &type, DialectAsmPrinter &p);

compiler/src/iree/compiler/ExternalInterfaces/TensorExtExternalModels.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@
66

77
#include "iree/compiler/ExternalInterfaces/TensorExtExternalModels.h"
88

9+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
910
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
1011
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1112

1213
namespace mlir::iree_compiler {
1314
namespace {
1415

16+
//===----------------------------------------------------------------------===//
17+
// Op Interfaces
18+
//===----------------------------------------------------------------------===//
19+
1520
struct DispatchTensorLoadOpInterface
1621
: public ValueBoundsOpInterface::ExternalModel<
1722
DispatchTensorLoadOpInterface,
@@ -36,6 +41,27 @@ struct WorkloadOrdinalOpInterface
3641
}
3742
};
3843

44+
//===----------------------------------------------------------------------===//
45+
// Type Interfaces
46+
//===----------------------------------------------------------------------===//
47+
48+
struct EncodingTypeExternalModel
49+
: public IREE::Encoding::EncodingTypeInterface::ExternalModel<
50+
EncodingTypeExternalModel, IREE::TensorExt::DispatchTensorType> {
51+
52+
Type getEncodingType(Type type) const {
53+
auto dispatchTensorType = cast<IREE::TensorExt::DispatchTensorType>(type);
54+
return dispatchTensorType.getBoundType();
55+
}
56+
57+
Type updateEncoding(Type type, Attribute encoding) const {
58+
auto dispatchTensorType = cast<IREE::TensorExt::DispatchTensorType>(type);
59+
return IREE::TensorExt::DispatchTensorType::get(
60+
dispatchTensorType.getAccess(), dispatchTensorType.getShape(),
61+
dispatchTensorType.getBoundElementType(), encoding);
62+
}
63+
};
64+
3965
} // namespace
4066

4167
void registerTensorExtExternalModels(DialectRegistry &registry) {
@@ -45,6 +71,8 @@ void registerTensorExtExternalModels(DialectRegistry &registry) {
4571
DispatchTensorLoadOpInterface>(*ctx);
4672
IREE::TensorExt::DispatchWorkloadOrdinalOp::attachInterface<
4773
WorkloadOrdinalOpInterface>(*ctx);
74+
IREE::TensorExt::DispatchTensorType::attachInterface<
75+
EncodingTypeExternalModel>(*ctx);
4876
});
4977
}
5078

0 commit comments

Comments
 (0)