Skip to content

Commit 84c77f8

Browse files
hanhanW authored and pstarkcdpr committed
[DT] Support partial load/store for identity encoding resolver. (iree-org#22360)
The revision adds the support of partial load/store lowering for identity encoding resolver; it removes the checks from `MaterializeTensorExtDispatchTensorLoadOp` and `MaterializeTensorExtDispatchTensorStoreOp` because they belong to encoding resolver implementation details. The data-tiling encoding resolvers all have the check: https://github.com/iree-org/iree/blob/fcae3fcd1f5032a24ca00d913a6f026cb37edcf1/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h#L136-L141 The check for padding resolver: https://github.com/iree-org/iree/blob/4127b869cc72b230b56c331e53db7ca71de067b1/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp#L596-L603 Signed-off-by: hanhanW <[email protected]>
1 parent c0e8d3c commit 84c77f8

File tree

3 files changed

+92
-28
lines changed

3 files changed

+92
-28
lines changed

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,6 @@ struct MaterializeTensorExtDispatchTensorLoadOp
229229
matchAndRewrite(IREE::TensorExt::DispatchTensorLoadOp loadOp,
230230
OpAdaptor adaptor,
231231
ConversionPatternRewriter &rewriter) const override {
232-
// Only handle operations where the load covers the entire
233-
// `!iree_tensor_ext.dispatch.tensor` type.
234-
// TODO(ravishankarm): Relax this for partial loads.
235-
if (!loadOp.isLoadOfWholeSource()) {
236-
return rewriter.notifyMatchFailure(loadOp, "unhandled partial loads");
237-
}
238-
239232
auto sourceType = loadOp.getSourceType();
240233
auto boundTensorType = cast<RankedTensorType>(sourceType.getBoundType());
241234
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(
@@ -272,13 +265,6 @@ struct MaterializeTensorExtDispatchTensorStoreOp
272265
matchAndRewrite(IREE::TensorExt::DispatchTensorStoreOp storeOp,
273266
OpAdaptor adaptor,
274267
ConversionPatternRewriter &rewriter) const override {
275-
// Only handle operations where the store covers the entire
276-
// `!iree_tensor_ext.dispatch.tensor` type.
277-
// TODO(ravishankarm): Relax this for partial stores.
278-
if (!storeOp.isStoreToWholeTarget()) {
279-
return rewriter.notifyMatchFailure(storeOp, "unhandled partial stores");
280-
}
281-
282268
auto targetType = storeOp.getTargetType();
283269
auto boundTensorType = cast<RankedTensorType>(targetType.getBoundType());
284270
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(

compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_for_iree_ops.mlir

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
1717
#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
1818
#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
19-
func.func @matmul_lowering_f32f32f32_identity_resolver() attributes {
19+
func.func @matmul_lowering_f32f32f32_identity_resolver_full_slices() attributes {
2020
hal.executable.target = #hal.executable.target<"llvm-cpu", "whatever", {iree.encoding.resolver = #iree_encoding.identity_resolver<>}>
2121
} {
2222
%c0 = arith.constant 0 : index
@@ -48,7 +48,7 @@ func.func @matmul_lowering_f32f32f32_identity_resolver() attributes {
4848
-> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
4949
return
5050
}
51-
// CHECK-LABEL: func @matmul_lowering_f32f32f32_identity_resolver()
51+
// CHECK-LABEL: func @matmul_lowering_f32f32f32_identity_resolver_full_slices()
5252
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
5353
// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(0)
5454
// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(1)
@@ -73,6 +73,93 @@ func.func @matmul_lowering_f32f32f32_identity_resolver() attributes {
7373

7474
// -----
7575

76+
#pipeline_layout = #hal.pipeline.layout<constants = 12, bindings = [
77+
#hal.pipeline.binding<storage_buffer>,
78+
#hal.pipeline.binding<storage_buffer>,
79+
#hal.pipeline.binding<storage_buffer>
80+
]>
81+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
82+
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
83+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
84+
#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
85+
#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
86+
#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
87+
func.func @matmul_lowering_f32f32f32_identity_resolver_partial_slices() attributes {
88+
hal.executable.target = #hal.executable.target<"llvm-cpu", "whatever", {iree.encoding.resolver = #iree_encoding.identity_resolver<>}>
89+
} {
90+
%c0 = arith.constant 0 : index
91+
%M = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
92+
%N = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
93+
%K = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
94+
%sizeM = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
95+
%sizeN = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : index
96+
%sizeK = hal.interface.constant.load layout(#pipeline_layout) ordinal(5) : index
97+
%offsetM = hal.interface.constant.load layout(#pipeline_layout) ordinal(6) : index
98+
%offsetN = hal.interface.constant.load layout(#pipeline_layout) ordinal(7) : index
99+
%offsetK = hal.interface.constant.load layout(#pipeline_layout) ordinal(8) : index
100+
%strideM = hal.interface.constant.load layout(#pipeline_layout) ordinal(9) : index
101+
%strideN = hal.interface.constant.load layout(#pipeline_layout) ordinal(10) : index
102+
%strideK = hal.interface.constant.load layout(#pipeline_layout) ordinal(11) : index
103+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
104+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_lhs>>{%M, %K}
105+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0)
106+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_rhs>>{%K, %N}
107+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0)
108+
: !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
109+
%3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [%offsetM, %offsetK], sizes = [%sizeM, %sizeK], strides = [%strideM, %strideK]
110+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_lhs>>{%M, %K}
111+
-> tensor<?x?xf32, #encoding_lhs>
112+
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [%offsetK, %offsetN], sizes = [%sizeK, %sizeN], strides = [%strideK, %strideN]
113+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_rhs>>{%K, %N}
114+
-> tensor<?x?xf32, #encoding_rhs>
115+
%5 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [%offsetM, %offsetN], sizes = [%sizeM, %sizeN], strides = [%strideM, %strideN]
116+
: !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
117+
-> tensor<?x?xf32, #encoding_result>
118+
%6 = linalg.matmul
119+
ins(%3, %4 : tensor<?x?xf32, #encoding_lhs>,
120+
tensor<?x?xf32, #encoding_rhs>)
121+
outs(%5 : tensor<?x?xf32, #encoding_result>)
122+
-> tensor<?x?xf32, #encoding_result>
123+
iree_tensor_ext.dispatch.tensor.store %6, %2, offsets = [%offsetM, %offsetN], sizes = [%sizeM, %sizeN], strides = [%strideM, %strideN]
124+
: tensor<?x?xf32, #encoding_result>
125+
-> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
126+
return
127+
128+
}
129+
// CHECK-LABEL: func @matmul_lowering_f32f32f32_identity_resolver_partial_slices()
130+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
131+
// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(0)
132+
// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(1)
133+
// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(2)
134+
// CHECK-DAG: %[[SIZE_M:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(3)
135+
// CHECK-DAG: %[[SIZE_N:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(4)
136+
// CHECK-DAG: %[[SIZE_K:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(5)
137+
// CHECK-DAG: %[[OFFSET_M:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(6)
138+
// CHECK-DAG: %[[OFFSET_N:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(7)
139+
// CHECK-DAG: %[[OFFSET_K:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(8)
140+
// CHECK-DAG: %[[STRIDE_M:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(9)
141+
// CHECK-DAG: %[[STRIDE_N:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(10)
142+
// CHECK-DAG: %[[STRIDE_K:.+]] = hal.interface.constant.load layout(#pipeline_layout) ordinal(11)
143+
// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
144+
// CHECK-SAME: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%[[M]], %[[K]]}
145+
// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
146+
// CHECK-SAME: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%[[K]], %[[N]]}
147+
// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
148+
// CHECK-SAME: !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32>>{%[[M]], %[[N]]}
149+
// CHECK: %[[LHS:.+]] = iree_tensor_ext.dispatch.tensor.load %[[LHS_BINDING]]
150+
// CHECK-SAME: offsets = [%[[OFFSET_M]], %[[OFFSET_K]]], sizes = [%[[SIZE_M]], %[[SIZE_K]]], strides = [%[[STRIDE_M]], %[[STRIDE_K]]]
151+
// CHECK: %[[RHS:.+]] = iree_tensor_ext.dispatch.tensor.load %[[RHS_BINDING]]
152+
// CHECK-SAME: offsets = [%[[OFFSET_K]], %[[OFFSET_N]]], sizes = [%[[SIZE_K]], %[[SIZE_N]]], strides = [%[[STRIDE_K]], %[[STRIDE_N]]]
153+
// CHECK: %[[OUTS:.+]] = iree_tensor_ext.dispatch.tensor.load %[[OUTS_BINDING]]
154+
// CHECK-SAME: offsets = [%[[OFFSET_M]], %[[OFFSET_N]]], sizes = [%[[SIZE_M]], %[[SIZE_N]]], strides = [%[[STRIDE_M]], %[[STRIDE_N]]]
155+
// CHECK: %[[RES:.+]] = linalg.matmul
156+
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
157+
// CHECK-SAME: outs(%[[OUTS]] :
158+
// CHECK: iree_tensor_ext.dispatch.tensor.store %[[RES]], %[[OUTS_BINDING]]
159+
// CHECK-SAME: offsets = [%[[OFFSET_M]], %[[OFFSET_N]]], sizes = [%[[SIZE_M]], %[[SIZE_N]]], strides = [%[[STRIDE_M]], %[[STRIDE_N]]]
160+
161+
// -----
162+
76163
//----------------------------------------------------------------------------//
77164
// Test suite using CPU encoding resolvers.
78165
//----------------------------------------------------------------------------//

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -550,18 +550,9 @@ LogicalResult IdentityResolverAttr::getOffsetsSizesStrides(
550550
SmallVectorImpl<OpFoldResult> &newOffsets,
551551
SmallVectorImpl<OpFoldResult> &newSizes,
552552
SmallVectorImpl<OpFoldResult> &newStrides) const {
553-
// Only handle cases where the slice spans the whole
554-
// `!iree_tensor_ext.dispatch.tensor` type.
555-
// TODO(hanchung): Enable partial slices. It was copied from pattern's
556-
// implementaion, i.e., the users, and it can be dropped after we move the
557-
// checks to the interface implementations.
558-
if (!type.doesSliceSpanWholeTensor(dynamicDims, offsets, sizes, strides)) {
559-
return failure();
560-
}
561-
auto boundTensorType = cast<RankedTensorType>(type.getBoundType());
562-
newSizes = getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
563-
newOffsets.resize(newSizes.size(), builder.getIndexAttr(0));
564-
newStrides.resize(newSizes.size(), builder.getIndexAttr(1));
553+
newSizes.assign(sizes.begin(), sizes.end());
554+
newOffsets.assign(offsets.begin(), offsets.end());
555+
newStrides.assign(strides.begin(), strides.end());
565556
return success();
566557
}
567558

0 commit comments

Comments
 (0)