Skip to content

Commit 5f7b471

Browse files
authored
[Stream] Add layouts to encodings for all stream tensor AffinityOp. (#19726)
The revision adds support for the rest of the AffinityOps that have the TensorPhase trait, i.e., the TensorCloneOp, TensorSliceOp, TensorFillOp, and TensorUpdateOp ops. It is tricky to handle encodings for transfer ops, so only the encoding in the fill op is updated. If other operations have tensor encodings, it returns a failure for now. There are two stream tensor ops that do not implement the AffinityOpInterface, so they are not supported within the revision: the stream.tensor.load op and the stream.tensor.store op. We should be able to track the resource affinity for these two ops, but it requires additional analysis; thus, they are not scoped within the revision. The revision also adds the missing documentation to the `addLayoutsToTensorPhaseOps` method. --------- Signed-off-by: hanhanW <[email protected]>
1 parent ac46df5 commit 5f7b471

File tree

2 files changed

+197
-3
lines changed

2 files changed

+197
-3
lines changed

compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,22 @@ updateTensorSizeOfOp(RewriterBase &rewriter,
358358
return success();
359359
}
360360

361+
/// Updates the target encoding of `op` with resolved layouts.
362+
static LogicalResult
363+
updateTensorFillOp(RewriterBase &rewriter, IREE::Stream::TensorFillOp op,
364+
const SetVector<Attribute> &layoutResolvers) {
365+
auto encodingType = dyn_cast<RankedTensorType>(op.getTargetEncoding());
366+
std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
367+
getEncodingWithNewLayouts(encodingType, layoutResolvers);
368+
if (!encodingAttr) {
369+
return success();
370+
}
371+
rewriter.modifyOpInPlace(op, [&] {
372+
op.setTargetEncoding(cloneWithEncoding(encodingType, encodingAttr.value()));
373+
});
374+
return success();
375+
}
376+
361377
/// Returns failure if `op` has encoding. The EncodingAttr has padding
362378
/// semantic, a constant op with such encoding can not be resolved at this
363379
/// moment.
@@ -375,7 +391,70 @@ updateTensorConstantOp(RewriterBase &rewriter,
375391
return success();
376392
}
377393

378-
/// Updates the result_encoding for `op`. The op have to define a
394+
/// Returns a failure if there are encodings in target encoding type or update
395+
/// encoding type.
396+
static LogicalResult updateTensorUpdateOp(RewriterBase &rewriter,
397+
IREE::Stream::TensorUpdateOp op) {
398+
auto targetEncodingType = dyn_cast<RankedTensorType>(op.getTargetEncoding());
399+
if (targetEncodingType && targetEncodingType.getEncoding()) {
400+
return failure();
401+
}
402+
auto updateEncodingType = dyn_cast<RankedTensorType>(op.getUpdateEncoding());
403+
if (updateEncodingType && updateEncodingType.getEncoding()) {
404+
return failure();
405+
}
406+
return success();
407+
}
408+
409+
/// Returns a failure if there are encodings in source encoding type or result
410+
/// encoding type.
411+
static LogicalResult updateTensorCloneOp(RewriterBase &rewriter,
412+
IREE::Stream::TensorCloneOp op) {
413+
auto sourceEncodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
414+
if (sourceEncodingType && sourceEncodingType.getEncoding()) {
415+
return failure();
416+
}
417+
auto resultEncodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
418+
if (resultEncodingType && resultEncodingType.getEncoding()) {
419+
return failure();
420+
}
421+
return success();
422+
}
423+
424+
/// Returns a failure if there are encodings in source encoding type or result
425+
/// encoding type.
426+
static LogicalResult updateTensorSliceOp(RewriterBase &rewriter,
427+
IREE::Stream::TensorSliceOp op) {
428+
auto sourceEncodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
429+
if (sourceEncodingType && sourceEncodingType.getEncoding()) {
430+
return failure();
431+
}
432+
auto resultEncodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
433+
if (resultEncodingType && resultEncodingType.getEncoding()) {
434+
return failure();
435+
}
436+
return success();
437+
}
438+
439+
/// Updates the source_encoding for `op`. The op has to define a
440+
/// `source_encoding` parameter.
441+
template <typename OpTy>
442+
static LogicalResult
443+
updateSourceEncoding(RewriterBase &rewriter, OpTy op,
444+
const SetVector<Attribute> &layoutResolvers) {
445+
auto encodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
446+
std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
447+
getEncodingWithNewLayouts(encodingType, layoutResolvers);
448+
if (!encodingAttr) {
449+
return success();
450+
}
451+
rewriter.modifyOpInPlace(op, [&] {
452+
op.setSourceEncoding(cloneWithEncoding(encodingType, encodingAttr.value()));
453+
});
454+
return success();
455+
}
456+
457+
/// Updates the result_encoding for `op`. The op has to define a
379458
/// `result_encoding` parameter.
380459
template <typename OpTy>
381460
static LogicalResult
@@ -393,6 +472,16 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op,
393472
return success();
394473
}
395474

475+
/// Adds the resolved layouts to all tensor types on stream tensor ops, if
476+
/// encodings are present. Most of stream tensor ops implement
477+
/// AffinityOpInterface, where a stream affinity indicates the kind of
478+
/// environment the ops are expected to run in. When an encoding is present in the
479+
/// tensor type, the method resolves the layouts, strips outdated information,
480+
/// and adds the resolved layouts to the encodings. The updated encodings should
481+
/// have enough information for other lowering transformations.
482+
/// TODO(hanchung): Add support for stream.tensor.load ops and
483+
/// stream.tensor.store ops. They are not affinity ops, so additional analysis
484+
/// will be needed in the work.
396485
static LogicalResult addLayoutsToTensorPhaseOps(
397486
ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis,
398487
FunctionOpInterface funcOp,
@@ -424,7 +513,6 @@ static LogicalResult addLayoutsToTensorPhaseOps(
424513
return affinityOp.emitError("failed on making layout resolvers");
425514
}
426515

427-
// TODO(hanchung): Update other Stream operations.
428516
LogicalResult result =
429517
TypeSwitch<Operation *, LogicalResult>(affinityOp)
430518
.Case<IREE::Stream::TensorDispatchOp>([&](auto op) {
@@ -442,6 +530,15 @@ static LogicalResult addLayoutsToTensorPhaseOps(
442530
.Case<IREE::Stream::TensorConstantOp>([&](auto op) {
443531
return updateTensorConstantOp(rewriter, op, layoutResolvers);
444532
})
533+
.Case<IREE::Stream::TensorFillOp>([&](auto op) {
534+
return updateTensorFillOp(rewriter, op, layoutResolvers);
535+
})
536+
.Case<IREE::Stream::TensorCloneOp>(
537+
[&](auto op) { return updateTensorCloneOp(rewriter, op); })
538+
.Case<IREE::Stream::TensorSliceOp>(
539+
[&](auto op) { return updateTensorSliceOp(rewriter, op); })
540+
.Case<IREE::Stream::TensorUpdateOp>(
541+
[&](auto op) { return updateTensorUpdateOp(rewriter, op); })
445542
.Default([](Operation *op) {
446543
return op->emitOpError("Unhandled stream op");
447544
});

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,39 @@ module {
6565

6666
// -----
6767

68+
#map0 = affine_map<(m, n, k) -> (m, k)>
69+
#map1 = affine_map<(m, n, k) -> (k, n)>
70+
#map2 = affine_map<(m, n, k) -> (m, n)>
71+
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
72+
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
73+
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
74+
module {
75+
util.global private @device_a = #device_target_local_0_
76+
77+
util.func public @tensor_fill_op(%arg0: f32, %arg1: !stream.resource<*>, %arg2: index, %arg3: index) {
78+
%c0 = arith.constant 0 : index
79+
%c1 = arith.constant 1 : index
80+
%0 = stream.tensor.fill on(#hal.device.affinity<@device_a>)
81+
%arg0, %arg1[%c0, %c0 for %c1, %c1] : f32
82+
-> tensor<?x4xf32, #encoding>{%arg2} in %arg1 as !stream.resource<*>{%arg3}
83+
util.return
84+
}
85+
}
86+
// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.encoding<{{.+}} layouts = [#iree_encoding.specialized_encoding<123, tensor<?x4xf32>>]
87+
// CHECK: #[[TARGET:.+]] = #hal.device.target
88+
// CHECK: util.global private @[[$DEVICE:.+]] = #[[TARGET]]
89+
// CHECK-LABEL: util.func public @tensor_fill_op
90+
// CHECK: stream.tensor.fill on(#hal.device.affinity<@[[$DEVICE]]>)
91+
// CHECK-SAME: f32 -> tensor<?x4xf32, #[[$ENCODING]]>
92+
93+
// -----
94+
6895
// Checks that the stream.tensor.constant op with encoding is not supported.
6996

7097
#map0 = affine_map<(m, n, k) -> (m, k)>
7198
#map1 = affine_map<(m, n, k) -> (k, n)>
7299
#map2 = affine_map<(m, n, k) -> (m, n)>
73-
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
100+
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
74101
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
75102
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
76103
module {
@@ -85,6 +112,76 @@ module {
85112

86113
// -----
87114

115+
// Checks that the stream.tensor.clone op with encoding is not supported.
116+
117+
#map0 = affine_map<(m, n, k) -> (m, k)>
118+
#map1 = affine_map<(m, n, k) -> (k, n)>
119+
#map2 = affine_map<(m, n, k) -> (m, n)>
120+
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
121+
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
122+
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
123+
module {
124+
util.global private @device_a = #device_target_local_0_
125+
126+
// expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
127+
util.func public @tensor_clone_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
128+
%0 = stream.tensor.clone on(#hal.device.affinity<@device_a>)
129+
%arg0 : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
130+
-> tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
131+
util.return
132+
}
133+
}
134+
135+
// -----
136+
137+
// Checks that the stream.tensor.slice op with encoding is not supported.
138+
139+
#map0 = affine_map<(m, n, k) -> (m, k)>
140+
#map1 = affine_map<(m, n, k) -> (k, n)>
141+
#map2 = affine_map<(m, n, k) -> (m, n)>
142+
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
143+
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
144+
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
145+
module {
146+
util.global private @device_a = #device_target_local_0_
147+
148+
// expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
149+
util.func public @tensor_slice_op_with_encoding(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
150+
%c0 = arith.constant 0 : index
151+
%c1 = arith.constant 1 : index
152+
%1 = stream.tensor.slice on(#hal.device.affinity<@device_a>)
153+
%arg0[%c0, %c1 for %arg3, %c1] : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
154+
-> tensor<?x1xf32, #encoding>{%arg3} in !stream.resource<*>{%arg4}
155+
util.return
156+
}
157+
}
158+
159+
// -----
160+
161+
// Checks that the stream.tensor.update op with encoding is not supported.
162+
163+
#map0 = affine_map<(m, n, k) -> (m, k)>
164+
#map1 = affine_map<(m, n, k) -> (k, n)>
165+
#map2 = affine_map<(m, n, k) -> (m, n)>
166+
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
167+
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
168+
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
169+
module {
170+
util.global private @device_a = #device_target_local_0_
171+
172+
// expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
173+
util.func public @tensor_update_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) {
174+
%c0 = arith.constant 0 : index
175+
%c1 = arith.constant 1 : index
176+
%0 = stream.tensor.update on(#hal.device.affinity<@device_a>)
177+
%arg0, %arg2[%c0, %c0] : tensor<2x2xf32, #encoding> in !stream.resource<*>{%arg1}
178+
-> tensor<?x4xf32, #encoding>{%arg3} in %arg2 as !stream.resource<*>{%arg4}
179+
util.return
180+
}
181+
}
182+
183+
// -----
184+
88185
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
89186
#map = affine_map<(d0) -> (d0)>
90187
#map0 = affine_map<(m, n, k) -> (m, k)>

0 commit comments

Comments
 (0)