Skip to content

Commit 43b0edf

Browse files
authored
[VectorExt] Implement BufferizationInterface for transfer_gather (iree-org#21219)
1 parent 07203e6 commit 43b0edf

File tree

8 files changed

+130
-0
lines changed

8 files changed

+130
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3106,3 +3106,20 @@ func.func @cache_swizzle_resource_cast(%stride: index) {
31063106
// CHECK: %[[CAST:.+]] = amdgpu.fat_raw_buffer_cast %[[INPUT]]
31073107
// CHECK-SAME: cacheSwizzleStride(%[[TRUNC]]) resetOffset
31083108
// CHECK-SAME: memref<2xf32, #hal.descriptor_type<storage_buffer>> to memref<2xf32, #amdgpu.address_space<fat_raw_buffer>>
3109+
3110+
// -----
3111+
3112+
func.func @transfer_gather(%source : tensor<?x64xf16>, %indices: vector<8xindex>) -> vector<8x64xf16> {
3113+
%c0 = arith.constant 0 : index
3114+
%cst = arith.constant 0.0 : f16
3115+
%out = iree_vector_ext.transfer_gather %source[%c0, %c0][%indices: vector<8xindex>, None], %cst {
3116+
indexed_maps = [affine_map<(d0, d1) -> (d0)>]
3117+
} : tensor<?x64xf16>, vector<8x64xf16>
3118+
return %out : vector<8x64xf16>
3119+
}
3120+
3121+
// CHECK-LABEL: func.func @transfer_gather
3122+
// CHECK-SAME: %[[SOURCE:.+]]: tensor<?x64xf16>, %[[INDICES:.+]]: vector<8xindex>
3123+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
3124+
// CHECK: %[[BUFFER:.+]] = bufferization.to_buffer %[[SOURCE]]
3125+
// CHECK: iree_vector_ext.transfer_gather %[[BUFFER]][%[[C0]], %[[C0]]][%[[INDICES]]: vector<8xindex>, None]

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ iree_gentbl_cc_library(
3535
iree_compiler_cc_library(
3636
name = "VectorExtTransforms",
3737
srcs = [
38+
"BufferizationInterfaces.cpp",
3839
"Passes.cpp",
3940
"VectorExtFoldUnitExtentDims.cpp",
4041
"VectorizeIREEVectorExtOps.cpp",
4142
],
4243
hdrs = [
44+
"BufferizationInterfaces.h",
4345
"Passes.h",
4446
"Passes.h.inc",
4547
"Transforms.h",
@@ -51,6 +53,9 @@ iree_compiler_cc_library(
5153
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
5254
"@llvm-project//llvm:Support",
5355
"@llvm-project//mlir:ArithDialect",
56+
"@llvm-project//mlir:BufferizationDialect",
57+
"@llvm-project//mlir:BufferizationInterfaces",
58+
"@llvm-project//mlir:BufferizationTransforms",
5459
"@llvm-project//mlir:FunctionInterfaces",
5560
"@llvm-project//mlir:IR",
5661
"@llvm-project//mlir:LinalgDialect",
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.h"
8+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
9+
10+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12+
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
13+
#include "mlir/IR/BuiltinTypes.h"
14+
#include "mlir/IR/Value.h"
15+
16+
namespace mlir::iree_compiler::IREE::VectorExt {
17+
18+
using mlir::bufferization::AnalysisState;
19+
using mlir::bufferization::BufferizableOpInterface;
20+
using mlir::bufferization::BufferizationOptions;
21+
using mlir::bufferization::BufferizationState;
22+
using mlir::bufferization::BufferRelation;
23+
using mlir::bufferization::replaceOpWithNewBufferizedOp;
24+
25+
namespace {
26+
27+
struct TransferGatherOpInterface
28+
: public BufferizableOpInterface::ExternalModel<
29+
TransferGatherOpInterface, IREE::VectorExt::TransferGatherOp> {
30+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
31+
const AnalysisState &state) const {
32+
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
33+
"only tensor types expected");
34+
return true;
35+
}
36+
37+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
38+
const AnalysisState &state) const {
39+
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
40+
"only tensor types expected");
41+
return false;
42+
}
43+
44+
bufferization::AliasingValueList
45+
getAliasingValues(Operation *op, OpOperand &opOperand,
46+
const AnalysisState &state) const {
47+
return {};
48+
}
49+
50+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51+
const BufferizationOptions &options,
52+
BufferizationState &state) const {
53+
auto gatherOp = cast<IREE::VectorExt::TransferGatherOp>(op);
54+
assert(isa<TensorType>(gatherOp.getShapedType()) &&
55+
"only tensor types expected");
56+
FailureOr<Value> buffer =
57+
getBuffer(rewriter, gatherOp.getBase(), options, state);
58+
if (failed(buffer))
59+
return failure();
60+
replaceOpWithNewBufferizedOp<IREE::VectorExt::TransferGatherOp>(
61+
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
62+
gatherOp.getIndices(), gatherOp.getIndexVecs(), gatherOp.getIndexed(),
63+
gatherOp.getIndexedMaps(), gatherOp.getPermutationMap(),
64+
gatherOp.getPadding(), gatherOp.getMask(), gatherOp.getInBoundsAttr());
65+
return success();
66+
}
67+
};
68+
69+
} // namespace
70+
71+
void registerIREEVectorExtBufferizationInterfaces(DialectRegistry &registry) {
72+
registry.addExtension(
73+
+[](MLIRContext *context, IREEVectorExtDialect *dialect) {
74+
TransferGatherOp::attachInterface<TransferGatherOpInterface>(*context);
75+
});
76+
}
77+
78+
} // namespace mlir::iree_compiler::IREE::VectorExt
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_BUFFERIZATIONINTERFACES_H_
8+
#define IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_BUFFERIZATIONINTERFACES_H_
9+
10+
#include "mlir/IR/Dialect.h"
11+
12+
namespace mlir::iree_compiler::IREE::VectorExt {
13+
14+
// Register all interfaces needed for bufferization.
15+
void registerIREEVectorExtBufferizationInterfaces(DialectRegistry &registry);
16+
17+
} // namespace mlir::iree_compiler::IREE::VectorExt
18+
19+
#endif // IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_BUFFERIZATIONINTERFACES_H_

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,21 @@ iree_cc_library(
2323
NAME
2424
VectorExtTransforms
2525
HDRS
26+
"BufferizationInterfaces.h"
2627
"Passes.h"
2728
"Passes.h.inc"
2829
"Transforms.h"
2930
SRCS
31+
"BufferizationInterfaces.cpp"
3032
"Passes.cpp"
3133
"VectorExtFoldUnitExtentDims.cpp"
3234
"VectorizeIREEVectorExtOps.cpp"
3335
DEPS
3436
::PassesIncGen
3537
LLVMSupport
3638
MLIRArithDialect
39+
MLIRBufferizationDialect
40+
MLIRBufferizationTransforms
3741
MLIRFunctionInterfaces
3842
MLIRIR
3943
MLIRLinalgDialect

compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ iree_compiler_cc_library(
9191
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
9292
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
9393
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:BufferizationInterfaces",
94+
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
95+
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms",
9496
"//compiler/src/iree/compiler/Codegen/Utils",
9597
"//compiler/src/iree/compiler/Dialect/HAL/IR",
9698
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",

compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
1111
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
1212
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.h"
13+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
14+
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.h"
1315
#include "iree/compiler/Codegen/Utils/Utils.h"
1416
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
1517
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
@@ -675,6 +677,7 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
675677

676678
// Register IREE operations.
677679
registerIREEGPUBufferizationInterfaces(registry);
680+
IREE::VectorExt::registerIREEVectorExtBufferizationInterfaces(registry);
678681
registry.addExtension(
679682
+[](MLIRContext *ctx, IREE::TensorExt::IREETensorExtDialect *dialect) {
680683
// DispatchTensorLoadOp

compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ iree_cc_library(
7979
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
8080
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
8181
iree::compiler::Codegen::Dialect::GPU::Transforms::BufferizationInterfaces
82+
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
83+
iree::compiler::Codegen::Dialect::VectorExt::Transforms::VectorExtTransforms
8284
iree::compiler::Codegen::Utils
8385
iree::compiler::Dialect::HAL::IR
8486
iree::compiler::Dialect::LinalgExt::IR

0 commit comments

Comments
 (0)