Skip to content

Commit a7a248e

Browse files
authored
[Integrate] Update bufferization related codes for upstream custom types support. (iree-org#21250)
1 parent 8c3d87d commit a7a248e

File tree

6 files changed

+62
-12
lines changed

6 files changed

+62
-12
lines changed

compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ class EliminateEmptyTensorsPass final
7878
: public impl::EliminateEmptyTensorsPassBase<EliminateEmptyTensorsPass> {
7979
public:
8080
void getDependentDialects(DialectRegistry &registry) const override {
81-
registry.insert<tensor::TensorDialect>();
81+
// BufferizationDialect is needed for using type interfaces, like
82+
// TensorLikeType. Because the builtin types, e.g., RankedTensorType, etc.,
83+
// implement the type interface in
84+
// bufferization::BufferizationDialect::initialize().
85+
registry
86+
.insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
8287
}
8388

8489
void runOnOperation() override;

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1515
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1616
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/Value.h"
1920

@@ -108,7 +109,7 @@ struct BarrierRegionOpBufferizationInterface
108109
SmallVector<Value> &invocationStack) const {
109110
auto barrierOp = cast<IREE::GPU::BarrierRegionOp>(op);
110111

111-
FailureOr<BaseMemRefType> memrefType = failure();
112+
FailureOr<mlir::bufferization::BufferLikeType> memrefType = failure();
112113
if (auto opResult = dyn_cast<OpResult>(value)) {
113114
int64_t resultNum = opResult.getResultNumber();
114115
memrefType = bufferization::getBufferType(
@@ -121,7 +122,7 @@ struct BarrierRegionOpBufferizationInterface
121122
}
122123
if (failed(memrefType))
123124
return failure();
124-
return memrefType;
125+
return cast<BaseMemRefType>(*memrefType);
125126
}
126127

127128
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -146,10 +147,13 @@ struct BarrierRegionOpBufferizationInterface
146147
tensorizedOperands.push_back(replacement);
147148
continue;
148149
}
149-
tensorizedOperands.push_back(rewriter
150-
.create<bufferization::ToTensorOp>(
151-
replacement.getLoc(), replacement)
152-
.getResult());
150+
tensorizedOperands.push_back(
151+
rewriter
152+
.create<bufferization::ToTensorOp>(
153+
replacement.getLoc(),
154+
memref::getTensorTypeFromMemRefType(replacement.getType()),
155+
replacement)
156+
.getResult());
153157
}
154158

155159
rewriter.setInsertionPoint(barrierOp);
@@ -205,7 +209,7 @@ struct ValueBarrierOpBufferizationInterface
205209
state, invocationStack);
206210
if (failed(srcMemrefType))
207211
return failure();
208-
return srcMemrefType;
212+
return cast<BaseMemRefType>(*srcMemrefType);
209213
}
210214

211215
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -341,8 +345,9 @@ struct BufferResourceCastOpBufferizationInterface
341345
if (failed(srcMemrefType))
342346
return failure();
343347

344-
if (!hasStorageBufferMemSpace(srcMemrefType.value())) {
345-
return srcMemrefType;
348+
auto baseMemrefType = cast<BaseMemRefType>(srcMemrefType.value());
349+
if (!hasStorageBufferMemSpace(baseMemrefType)) {
350+
return baseMemrefType;
346351
}
347352

348353
auto rankedSrcType = cast<MemRefType>(srcMemrefType.value());

compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ iree_compiler_cc_library(
4444
"//compiler/src/iree/compiler/Dialect/Util/IR",
4545
"@llvm-project//llvm:Support",
4646
"@llvm-project//mlir:ArithDialect",
47+
"@llvm-project//mlir:BufferizationInterfaces",
4748
"@llvm-project//mlir:GPUDialect",
4849
"@llvm-project//mlir:IR",
4950
"@llvm-project//mlir:LinalgDialect",

compiler/src/iree/compiler/ExternalInterfaces/TensorExtExternalModels.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
1010
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
11+
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.h"
12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
1114
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1215

1316
namespace mlir::iree_compiler {
@@ -62,6 +65,39 @@ struct EncodingTypeExternalModel
6265
}
6366
};
6467

68+
struct TensorLikeTypeExternalModel
69+
: bufferization::TensorLikeType::ExternalModel<
70+
TensorLikeTypeExternalModel, IREE::TensorExt::DispatchTensorType> {
71+
FailureOr<bufferization::BufferLikeType> getBufferType(
72+
Type type, const bufferization::BufferizationOptions &options,
73+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
74+
auto dispatchTensorType = cast<IREE::TensorExt::DispatchTensorType>(type);
75+
auto tensorType = cast<TensorType>(dispatchTensorType.asRankedTensorType());
76+
auto memSpace = options.defaultMemorySpaceFn(tensorType);
77+
if (!memSpace.has_value()) {
78+
return emitError() << "could not infer memory space";
79+
}
80+
return cast<bufferization::BufferLikeType>(
81+
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
82+
}
83+
84+
LogicalResult verifyCompatibleBufferType(
85+
Type type, bufferization::BufferLikeType bufferType,
86+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
87+
auto dispatchTensorType = cast<IREE::TensorExt::DispatchTensorType>(type);
88+
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
89+
auto memrefType = cast<ShapedType>(bufferType);
90+
if (dispatchTensorType.getShape() != memrefType.getShape()) {
91+
return emitError() << "shapes do not match";
92+
}
93+
if (dispatchTensorType.getBoundElementType() !=
94+
memrefType.getElementType()) {
95+
return emitError() << "element types do not match";
96+
}
97+
return success();
98+
}
99+
};
100+
65101
} // namespace
66102

67103
void registerTensorExtExternalModels(DialectRegistry &registry) {
@@ -72,7 +108,7 @@ void registerTensorExtExternalModels(DialectRegistry &registry) {
72108
IREE::TensorExt::DispatchWorkloadOrdinalOp::attachInterface<
73109
WorkloadOrdinalOpInterface>(*ctx);
74110
IREE::TensorExt::DispatchTensorType::attachInterface<
75-
EncodingTypeExternalModel>(*ctx);
111+
EncodingTypeExternalModel, TensorLikeTypeExternalModel>(*ctx);
76112
});
77113
}
78114

compiler/src/iree/compiler/Preprocessing/Common/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def TransposeMatmulPass : Pass<"iree-preprocessing-transpose-matmul-pass"> {
159159
)}]>
160160
];
161161
let dependentDialects = [
162+
// TODO(hanchung): Remove the dep after switching upstream patterns to not
163+
// use bufferization::hasTensorSemantics method.
164+
"mlir::bufferization::BufferizationDialect",
162165
"mlir::linalg::LinalgDialect",
163166
];
164167
}

third_party/llvm-project

Submodule llvm-project updated 29 files

0 commit comments

Comments
 (0)