Skip to content

Commit 0938623

Browse files
authored
[Codegen] Add ops to translate between tensor and memref (#20619)
This PR adds 2 new ops to go from tensors to memrefs and memrefs to tensors. The new ops are: - `iree_codegen.load_from_memref` - Loads a tensor from a memref with compatible shape and same element type. - `iree_codegen.store_to_memref` - Stores a tensor to a memref with compatible shape and same element type. These ops can be used to bufferize hal.interface.subspan and iree_tensor_ext.dispatch.tensor.store ops early on, which allows for certain transformations that are only possible with memref semantics to happen without needing to bufferize the entire dispatch. This PR only adds the basic op definitions, and various op interface implementations (bufferization, subset insertion, etc.) will come as follow-ups. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 21072b6 commit 0938623

File tree

6 files changed

+216
-0
lines changed

6 files changed

+216
-0
lines changed

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,49 @@ void ExtractStridedMetadataOp::getAsmResultNames(
6565
setNameFn(getStrides().front(), "strides");
6666
}
6767
}
68+
69+
//===----------------------------------------------------------------------===//
70+
// LoadFromMemrefOp
71+
//===----------------------------------------------------------------------===//
72+
73+
LogicalResult LoadFromMemrefOp::verify() {
74+
RankedTensorType resultType = getResult().getType();
75+
MemRefType sourceType = getSource().getType();
76+
if (failed(verifyCompatibleShape(resultType.getShape(),
77+
sourceType.getShape())) ||
78+
resultType.getElementType() != sourceType.getElementType()) {
79+
return emitOpError("source and result shapes must be compatible and "
80+
"element types must match");
81+
}
82+
return success();
83+
}
84+
85+
void LoadFromMemrefOp::getEffects(
86+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
87+
&effects) {
88+
effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
89+
SideEffects::DefaultResource::get());
90+
}
91+
92+
//===----------------------------------------------------------------------===//
93+
// StoreToMemrefOp
94+
//===----------------------------------------------------------------------===//
95+
96+
LogicalResult StoreToMemrefOp::verify() {
97+
RankedTensorType valueType = getValue().getType();
98+
MemRefType targetType = getTarget().getType();
99+
if (failed(
100+
verifyCompatibleShape(valueType.getShape(), targetType.getShape())) ||
101+
valueType.getElementType() != targetType.getElementType()) {
102+
return emitOpError("value and target shapes must be compatible and element "
103+
"types must match");
104+
}
105+
return success();
106+
}
107+
108+
void StoreToMemrefOp::getEffects(
109+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
110+
&effects) {
111+
effects.emplace_back(MemoryEffects::Write::get(), &getTargetMutable(),
112+
SideEffects::DefaultResource::get());
113+
}

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,5 +171,56 @@ def IREECodegen_NullPointerOp :
171171
let assemblyFormat = "attr-dict";
172172
}
173173

174+
//===----------------------------------------------------------------------===//
175+
// LoadFrom/StoreToMemref Ops
176+
//===----------------------------------------------------------------------===//
177+
178+
def IREECodegen_LoadFromMemrefOp : Op<IREECodegen_Dialect, "load_from_memref",
179+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
180+
let summary = [{loads a tensor from a memref}];
181+
let description = [{
182+
Loads a tensor from a memref with a compatible shape and the same element
183+
type. The read_only attribute indicates that the source buffer from which
184+
the tensor is read is a read only buffer. This hint can be used by
185+
bufferization to determine whether or not the buffer that this op reads
186+
from may be written.
187+
}];
188+
189+
let arguments = (ins
190+
AnyStridedMemRef:$source,
191+
UnitAttr:$read_only
192+
);
193+
let results = (outs
194+
AnyRankedTensor:$result
195+
);
196+
197+
let assemblyFormat = [{
198+
$source attr-dict `:` type($source) `->` type($result)
199+
}];
200+
201+
let hasVerifier = 1;
202+
}
203+
204+
def IREECodegen_StoreToMemrefOp : Op<IREECodegen_Dialect, "store_to_memref",
205+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
206+
let summary = [{stores a tensor into a memref}];
207+
let description = [{
208+
Stores a tensor into a memref with a compatible shape and the same element
209+
type.
210+
}];
211+
212+
let arguments = (ins
213+
AnyRankedTensor:$value,
214+
AnyStridedMemRef:$target
215+
);
216+
let results = (outs);
217+
218+
let assemblyFormat = [{
219+
$value `,` $target
220+
attr-dict `:` type($value) `into` type($target)
221+
}];
222+
223+
let hasVerifier = 1;
224+
}
174225

175226
#endif // IREE_CODEGEN_DIALECT_IREECODEGENOPS

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ iree_lit_test_suite(
2020
[
2121
"invalid.mlir",
2222
"lowering_config_attr.mlir",
23+
"roundtrip.mlir",
2324
"ukernel_ops.mlir",
2425
"ukernel_ops_cse.mlir",
2526
"workgroup_mapping_attrs.mlir",

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ iree_lit_test_suite(
1616
SRCS
1717
"invalid.mlir"
1818
"lowering_config_attr.mlir"
19+
"roundtrip.mlir"
1920
"ukernel_ops.mlir"
2021
"ukernel_ops_cse.mlir"
2122
"workgroup_mapping_attrs.mlir"

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,35 @@ module {
1010
return
1111
}
1212
}
13+
14+
// -----
15+
16+
func.func @load_from_memref_invalid_shape(%arg0: memref<5xf32>) -> tensor<4xf32> {
17+
// expected-error @+1 {{source and result shapes must be compatible and element types must match}}
18+
%value = iree_codegen.load_from_memref %arg0 : memref<5xf32> -> tensor<4xf32>
19+
return %value : tensor<4xf32>
20+
}
21+
22+
// -----
23+
24+
func.func @load_from_memref_invalid_element_type(%arg0: memref<4xf32>) -> tensor<4xf16> {
25+
// expected-error @+1 {{source and result shapes must be compatible and element types must match}}
26+
%value = iree_codegen.load_from_memref %arg0 : memref<4xf32> -> tensor<4xf16>
27+
return %value : tensor<4xf16>
28+
}
29+
30+
// -----
31+
32+
func.func @store_to_memref_invalid_shape(%arg0: tensor<4xf32>, %arg1: memref<5xf32>) {
33+
// expected-error @+1 {{value and target shapes must be compatible and element types must match}}
34+
iree_codegen.store_to_memref %arg0, %arg1 : tensor<4xf32> into memref<5xf32>
35+
return
36+
}
37+
38+
// -----
39+
40+
func.func @store_to_memref_invalid_element_type(%arg0: tensor<4xf16>, %arg1: memref<4xf32>) {
41+
// expected-error @+1 {{value and target shapes must be compatible and element types must match}}
42+
iree_codegen.store_to_memref %arg0, %arg1 : tensor<4xf16> into memref<4xf32>
43+
return
44+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// RUN: iree-opt --split-input-file %s | FileCheck %s
2+
3+
func.func @load_from_memref(%arg0: memref<4xf32>) -> tensor<4xf32> {
4+
%value = iree_codegen.load_from_memref %arg0 : memref<4xf32> -> tensor<4xf32>
5+
return %value : tensor<4xf32>
6+
}
7+
// CHECK-LABEL: func.func @load_from_memref(
8+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
9+
// CHECK: iree_codegen.load_from_memref %[[ARG0]]
10+
// CHECK-SAME: : memref<4xf32> -> tensor<4xf32>
11+
12+
// -----
13+
14+
func.func @load_from_memref_read_only(%arg0: memref<4xf32>) -> tensor<4xf32> {
15+
%value = iree_codegen.load_from_memref %arg0 {read_only} : memref<4xf32> -> tensor<4xf32>
16+
return %value : tensor<4xf32>
17+
}
18+
// CHECK-LABEL: func.func @load_from_memref_read_only(
19+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
20+
// CHECK: iree_codegen.load_from_memref %[[ARG0]] {read_only}
21+
// CHECK-SAME: : memref<4xf32> -> tensor<4xf32>
22+
23+
// -----
24+
25+
func.func @load_from_memref_mixed_static_dynamic(%arg0: memref<?x4xf32>) -> tensor<4x?xf32> {
26+
%value = iree_codegen.load_from_memref %arg0 : memref<?x4xf32> -> tensor<4x?xf32>
27+
return %value : tensor<4x?xf32>
28+
}
29+
// CHECK-LABEL: func.func @load_from_memref_mixed_static_dynamic(
30+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
31+
// CHECK: iree_codegen.load_from_memref %[[ARG0]]
32+
// CHECK-SAME: : memref<?x4xf32> -> tensor<4x?xf32>
33+
34+
// -----
35+
36+
func.func @load_from_strided_memref(
37+
%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>
38+
) -> tensor<?x?xf32> {
39+
%value = iree_codegen.load_from_memref %arg0
40+
: memref<?x?xf32, strided<[?, 1], offset: ?>> -> tensor<?x?xf32>
41+
return %value : tensor<?x?xf32>
42+
}
43+
// CHECK-LABEL: func.func @load_from_strided_memref(
44+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]:
45+
// CHECK: iree_codegen.load_from_memref %[[ARG0]]
46+
// CHECK-SAME: : memref<?x?xf32, strided<[?, 1], offset: ?>> -> tensor<?x?xf32>
47+
48+
// -----
49+
50+
func.func @store_to_memref(%arg0: tensor<4xf32>, %arg1: memref<4xf32>) {
51+
iree_codegen.store_to_memref %arg0, %arg1 : tensor<4xf32> into memref<4xf32>
52+
return
53+
}
54+
// CHECK-LABEL: func.func @store_to_memref(
55+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
56+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
57+
// CHECK: iree_codegen.store_to_memref %[[ARG0]], %[[ARG1]]
58+
// CHECK-SAME: : tensor<4xf32> into memref<4xf32>
59+
60+
// -----
61+
62+
func.func @store_to_memref_mixed_static_dynamic(%arg0: tensor<4x?xf32>, %arg1: memref<?x4xf32>) {
63+
iree_codegen.store_to_memref %arg0, %arg1 : tensor<4x?xf32> into memref<?x4xf32>
64+
return
65+
}
66+
// CHECK-LABEL: func.func @store_to_memref_mixed_static_dynamic(
67+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
68+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
69+
// CHECK: iree_codegen.store_to_memref %[[ARG0]], %[[ARG1]]
70+
// CHECK-SAME: : tensor<4x?xf32> into memref<?x4xf32>
71+
72+
// -----
73+
74+
func.func @store_to_strided_memref(
75+
%arg0: tensor<?x?xf32>, %arg1: memref<?x?xf32, strided<[?, 1], offset: ?>>
76+
) {
77+
iree_codegen.store_to_memref %arg0, %arg1
78+
: tensor<?x?xf32> into memref<?x?xf32, strided<[?, 1], offset: ?>>
79+
return
80+
}
81+
// CHECK-LABEL: func.func @store_to_strided_memref(
82+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]:
83+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]:
84+
// CHECK: iree_codegen.store_to_memref %[[ARG0]], %[[ARG1]]
85+
// CHECK-SAME: : tensor<?x?xf32> into memref<?x?xf32, strided<[?, 1], offset: ?>>

0 commit comments

Comments
 (0)