Skip to content

Commit d11b876

Browse files
authored
[Stream] Enable batch affinity queries in SpecializeEncoding pass. (#19975)
The returned function (i.e., `ResolveLayoutAttrFn`) can be very inefficient because there could be other data-flow analyses in a run. The revision updates the `ResolveLayoutAttrFn` API. Now it accepts a list of queries, and it stores the results in a map of `SetVector<Attribute>`. In the encoding specialization pass, the revision introduces the `StreamTensorOpUpdater` class. There are two phases in the updater: the class caches all the queries in `init()`, and updates all the encodings in `run()`. The `init` method is introduced because initialization can fail. We do not put this work in the constructor because we cannot signal errors from constructors. See https://google.github.io/styleguide/cppguide.html#Doing_Work_in_Constructors The pass achieves a 440x speed-up for one of the SDXL compilations. The lit test configuration change (i.e., `--pass-pipeline='builtin.module(iree-stream-specialize-encodings)'`) is needed because we want to validate failures for unsupported encodings. --------- Signed-off-by: hanhanW <[email protected]>
1 parent d3cfe11 commit d11b876

File tree

4 files changed

+238
-112
lines changed

4 files changed

+238
-112
lines changed

compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -123,30 +123,43 @@ class HALAffinityAnalysisDialectInterface
123123
: public IREE::Stream::AffinityAnalysisDialectInterface {
124124
public:
125125
using AffinityAnalysisDialectInterface::AffinityAnalysisDialectInterface;
126+
127+
// Returns a function that gathers the corresponding
128+
// EncodingLayoutAttrInterface attributes for each
129+
// (IREE::Stream::Affinity, Operation) query. The attribute is extracted from
130+
// the `encoding` field in the HAL::ExecutableTargetAttr configuration. If the
131+
// `encoding` is not present, the target attribute is returned.
126132
IREE::Stream::ResolveLayoutAttrFn
127133
makeLayoutAttrResolver(ModuleOp moduleOp) const {
128-
return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op,
129-
SetVector<Attribute> &layoutAttrs) -> LogicalResult {
130-
// This needs to be in the lambda because the moduleOp could be modified..
134+
return [=](ArrayRef<IREE::Stream::AffinityAndOpPair> batchQueries,
135+
llvm::DenseMap<IREE::Stream::AffinityAndOpPair,
136+
SetVector<Attribute>> &layoutAttrs)
137+
-> LogicalResult {
138+
// This needs to be in the lambda because the moduleOp could be modified.
131139
IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
132140
if (failed(deviceAnalysis.run())) {
133-
return op->emitError("failed to run DeviceAnalysis");
141+
return moduleOp->emitError("failed to run DeviceAnalysis");
134142
}
135-
SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
136-
deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
137-
resultSet);
138-
for (auto targetAttr : resultSet) {
139-
Attribute result = targetAttr;
140-
if (auto attr = targetAttr.getConfiguration().getNamed("encoding")) {
141-
if (auto encodingLayoutAttr =
142-
dyn_cast<IREE::Encoding::EncodingLayoutAttrInterface>(
143-
attr->getValue())) {
144-
result = encodingLayoutAttr.cloneWithSimplifiedConfig(
145-
targetAttr.getConfiguration());
143+
144+
for (IREE::Stream::AffinityAndOpPair key : batchQueries) {
145+
auto [affinityAttr, op] = key;
146+
SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
147+
deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
148+
resultSet);
149+
for (auto targetAttr : resultSet) {
150+
Attribute result = targetAttr;
151+
if (auto attr = targetAttr.getConfiguration().getNamed("encoding")) {
152+
if (auto encodingLayoutAttr =
153+
dyn_cast<IREE::Encoding::EncodingLayoutAttrInterface>(
154+
attr->getValue())) {
155+
result = encodingLayoutAttr.cloneWithSimplifiedConfig(
156+
targetAttr.getConfiguration());
157+
}
146158
}
159+
layoutAttrs[key].insert(result);
147160
}
148-
layoutAttrs.insert(result);
149161
}
162+
150163
return success();
151164
};
152165
};

compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616

1717
namespace mlir::iree_compiler::IREE::Stream {
1818

19+
using AffinityAndOpPair = std::pair<AffinityAttr, Operation *>;
20+
21+
// The function could be slow, if any data flow analysis is involved. Thus, the
22+
// API provides the batch mode.
1923
using ResolveLayoutAttrFn = std::function<LogicalResult(
20-
AffinityAttr, Operation *, SetVector<Attribute> &)>;
24+
ArrayRef<AffinityAndOpPair> batchQueries,
25+
llvm::DenseMap<AffinityAndOpPair, SetVector<Attribute>> &layoutAttrs)>;
2126

2227
class AffinityAnalysisDialectInterface
2328
: public DialectInterface::Base<AffinityAnalysisDialectInterface> {

compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp

Lines changed: 166 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/Support/Debug.h"
1818
#include "llvm/Support/LogicalResult.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
20+
#include "mlir/IR/BuiltinOps.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/PatternMatch.h"
2223
#include "mlir/IR/SymbolTable.h"
@@ -57,6 +58,8 @@ SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
5758
return results;
5859
}
5960

61+
} // namespace
62+
6063
// Returns an updated encoding attribute if the type is a RankedTensorType
6164
// and an EncodingAttr is present. Otherwise, returns std::nullopt. The
6265
// method uses the EncodingLayoutAttrInterface from the EncodingAttr to
@@ -274,14 +277,140 @@ static RankedTensorType cloneWithEncoding(RankedTensorType type,
274277
encodingAttr);
275278
}
276279

280+
/// Returns all the stream tensor ops that implement AffinityOpInterface, where
281+
/// a stream affinity indicates the kind of enviroment the ops are expected run
282+
/// in.
283+
static SmallVector<IREE::Stream::AffinityOpInterface>
284+
collectStreamTensorOps(FunctionOpInterface funcOp) {
285+
SmallVector<IREE::Stream::AffinityOpInterface> result;
286+
funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) {
287+
// Only need to update encoding types for ops that have TensorPhaseOp trait.
288+
if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
289+
return;
290+
}
291+
292+
// Bail out if the operation does not have an affinity attribute.
293+
auto affinityAttr = affinityOp.getAffinityAttr();
294+
if (!affinityAttr) {
295+
return;
296+
}
297+
result.push_back(affinityOp);
298+
});
299+
return result;
300+
}
301+
302+
namespace {
303+
304+
// Adds the resolved layouts to all tensor types on stream tensor ops, if
305+
// encodings are present. Most of stream tensor ops implement
306+
// AffinityOpInterface, where a stream affinity indicates the kind of
307+
// enviroment the ops are expected run in. When an encoding is present in the
308+
// tensor type, the method resolves the layouts, strips outdated information,
309+
// and adds the resolved layouts to the encodings. The updated encodings should
310+
// have enough information for other lowering transformations.
311+
// TODO(hanchung): Add support for stream.tensor.load ops and
312+
// stream.tensor.store ops. They are not affinity ops, so additional analysis
313+
// will be needed in the work.
314+
class StreamTensorOpUpdater {
315+
public:
316+
explicit StreamTensorOpUpdater(ModuleOp moduleOp) : moduleOp(moduleOp){};
317+
~StreamTensorOpUpdater() {}
318+
319+
// Collects the stream tensor op candidates, and prepares all the needed
320+
// information for the update. This must be called once before calling `run`.
321+
// Note that all the ops are unmodified after the execution.
322+
LogicalResult init();
323+
324+
// Adds the resolved layouts to all tensor types of `streamOps`, if encodings
325+
// are present.
326+
LogicalResult run();
327+
328+
private:
329+
// Appends the query from the `affinityOp` to `queries`. Note that most of
330+
// operations only care the execution affinity. There are outliers (e.g.,
331+
// tensor dispatch op, etc.) that need to resolve affinities for
332+
// operand resources.
333+
LogicalResult addQuery(IREE::Stream::AffinityAnalysis &affinityAnalysis,
334+
IREE::Stream::AffinityOpInterface affinityOp);
335+
336+
// The list of the queries that can be used for batch affinity queries. The
337+
// analysis could be very expensive because it could apply the whole program
338+
// data flow analysis.
339+
SmallVector<IREE::Stream::AffinityAndOpPair> queries;
340+
341+
// The layout resolvers for each query.
342+
llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
343+
cachedLayoutAttrs;
344+
345+
// Input moduleOp. The op is not expected to be updated during the query.
346+
// Because data flow analaysis can be involved. Modifying the IR invalidates
347+
// the state and may lead to crashes as pointer references into the IR
348+
// structure are retained.
349+
ModuleOp moduleOp;
350+
351+
// The ops that need to be updated.
352+
SmallVector<IREE::Stream::AffinityOpInterface> streamOps;
353+
354+
// The layout resolver function, which is used to resolve layouts for
355+
// encodings. See StreamInterfaces.h for more details.
356+
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr;
357+
};
358+
359+
} // namespace
360+
361+
LogicalResult StreamTensorOpUpdater::init() {
362+
auto usedDialects = gatherUsedDialectInterfaces<
363+
IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
364+
if (usedDialects.size() != 1) {
365+
return moduleOp.emitError("expected only one dialect implementing "
366+
"AffinityAnalysisDialectInterface");
367+
}
368+
resolveLayoutAttr = usedDialects[0]->makeLayoutAttrResolver(moduleOp);
369+
370+
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
371+
streamOps.append(collectStreamTensorOps(funcOp));
372+
}
373+
374+
return success();
375+
}
376+
377+
LogicalResult StreamTensorOpUpdater::addQuery(
378+
IREE::Stream::AffinityAnalysis &affinityAnalysis,
379+
IREE::Stream::AffinityOpInterface affinityOp) {
380+
queries.emplace_back(affinityOp.getAffinityAttr(), affinityOp);
381+
382+
if (auto dispatchOp =
383+
dyn_cast<IREE::Stream::TensorDispatchOp>(affinityOp.getOperation())) {
384+
for (auto [operand, typeAttr] :
385+
llvm::zip_equal(dispatchOp.getMixedOperands(),
386+
dispatchOp.getOperandEncodings().getValue())) {
387+
auto type = cast<TypeAttr>(typeAttr).getValue();
388+
// Skip if the operand type is not AffinityType.
389+
if (!isa<IREE::Stream::AffinityTypeInterface>(type)) {
390+
continue;
391+
}
392+
SmallVector<IREE::Stream::AffinityAttr> affinityAttrs;
393+
if (!affinityAnalysis.tryLookupResourceAffinity(operand, affinityAttrs)) {
394+
return failure();
395+
}
396+
for (auto affinity : affinityAttrs) {
397+
queries.emplace_back(affinity, affinityOp);
398+
}
399+
}
400+
}
401+
402+
return success();
403+
}
404+
277405
/// Updates the operand encondings and result encodings for the `dispatchOp`
278406
/// with resolved layouts.
279-
static LogicalResult
280-
updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
281-
IREE::Stream::AffinityAnalysis &affinityAnalysis,
282-
IREE::Stream::TensorDispatchOp dispatchOp,
283-
const SetVector<Attribute> &resLayoutResolvers,
284-
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
407+
static LogicalResult updateTensorDispatchOp(
408+
RewriterBase &rewriter, ModuleOp moduleOp,
409+
IREE::Stream::AffinityAnalysis &affinityAnalysis,
410+
IREE::Stream::TensorDispatchOp dispatchOp,
411+
const SetVector<Attribute> &resLayoutResolvers,
412+
llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
413+
&cachedLayoutAttrs) {
285414
SmallVector<Type> newOperandEncodings;
286415
for (auto [operand, typeAttr] :
287416
llvm::zip_equal(dispatchOp.getMixedOperands(),
@@ -299,11 +428,11 @@ updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
299428
if (affinityAttrs.size() != 1) {
300429
return failure();
301430
}
302-
SetVector<Attribute> layoutResolvers;
303-
if (failed(
304-
resolveLayoutAttr(affinityAttrs[0], moduleOp, layoutResolvers))) {
305-
return dispatchOp.emitError("failed on making layout resolvers");
306-
}
431+
432+
IREE::Stream::AffinityAndOpPair key(affinityAttrs[0], dispatchOp);
433+
assert(cachedLayoutAttrs.contains(key) &&
434+
"the (affinity, dispatchOp) query is invalid");
435+
const SetVector<Attribute> &layoutResolvers = cachedLayoutAttrs[key];
307436

308437
std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
309438
getEncodingWithNewLayouts(type, layoutResolvers);
@@ -325,7 +454,6 @@ updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
325454
newResultEncodings.push_back(type);
326455
continue;
327456
}
328-
329457
std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
330458
getEncodingWithNewLayouts(type, resLayoutResolvers);
331459
if (!encodingAttr) {
@@ -472,53 +600,34 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op,
472600
return success();
473601
}
474602

475-
/// Adds the resolved layouts to all tensor types on stream tensor ops, if
476-
/// encodings are present. Most of stream tensor ops implement
477-
/// AffinityOpInterface, where a stream affinity indicates the kind of
478-
/// enviroment the ops are expected run in. When an encoding is present in the
479-
/// tensor type, the method resolves the layouts, strips outdated information,
480-
/// and adds the resolved layouts to the encodings. The updated encodings should
481-
/// have enough information for other lowering transformations.
482-
/// TODO(hanchung): Add support for stream.tensor.load ops and
483-
/// stream.tensor.store ops. They are not affinity ops, so additional analysis
484-
/// will be needed in the work.
485-
static LogicalResult addLayoutsToTensorPhaseOps(
486-
ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis,
487-
FunctionOpInterface funcOp,
488-
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
489-
SmallVector<IREE::Stream::AffinityOpInterface> candidates;
490-
funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) {
491-
// Only need to update encoding types for ops that have TensorPhaseOp trait.
492-
if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
493-
return;
494-
}
603+
LogicalResult StreamTensorOpUpdater::run() {
604+
IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
605+
if (failed(affinityAnalysis.run())) {
606+
return moduleOp.emitError("failed on running affinity analysis");
607+
}
495608

496-
// Bail out if the operation does not have an affinity attribute.
497-
auto affinityAttr = affinityOp.getAffinityAttr();
498-
if (!affinityAttr) {
499-
return;
609+
for (auto op : streamOps) {
610+
if (failed(addQuery(affinityAnalysis, op))) {
611+
return failure();
500612
}
501-
candidates.push_back(affinityOp);
502-
});
613+
}
503614

504-
if (candidates.empty()) {
505-
return success();
615+
if (failed(resolveLayoutAttr(queries, cachedLayoutAttrs))) {
616+
return failure();
506617
}
507618

508-
IRRewriter rewriter(funcOp.getContext());
509-
for (auto affinityOp : candidates) {
510-
auto affinityAttr = affinityOp.getAffinityAttr();
511-
SetVector<Attribute> layoutResolvers;
512-
if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layoutResolvers))) {
513-
return affinityOp.emitError("failed on making layout resolvers");
514-
}
619+
IRRewriter rewriter(moduleOp.getContext());
620+
for (auto affinityOp : streamOps) {
621+
const SetVector<Attribute> &layoutResolvers =
622+
cachedLayoutAttrs[IREE::Stream::AffinityAndOpPair(
623+
affinityOp.getAffinityAttr(), affinityOp)];
515624

516625
LogicalResult result =
517626
TypeSwitch<Operation *, LogicalResult>(affinityOp)
518627
.Case<IREE::Stream::TensorDispatchOp>([&](auto op) {
519628
return updateTensorDispatchOp(rewriter, moduleOp,
520629
affinityAnalysis, op,
521-
layoutResolvers, resolveLayoutAttr);
630+
layoutResolvers, cachedLayoutAttrs);
522631
})
523632
.Case<IREE::Stream::TensorSizeOfOp>([&](auto op) {
524633
return updateTensorSizeOfOp(rewriter, op, layoutResolvers);
@@ -549,36 +658,26 @@ static LogicalResult addLayoutsToTensorPhaseOps(
549658
}
550659
return success();
551660
}
552-
} // namespace
553661

662+
namespace {
554663
struct SpecializeEncodingsPass
555664
: public impl::SpecializeEncodingsPassBase<SpecializeEncodingsPass> {
556665
void runOnOperation() override {
557666
ModuleOp moduleOp = getOperation();
558-
auto usedDialects = gatherUsedDialectInterfaces<
559-
IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
560-
if (usedDialects.size() != 1) {
561-
moduleOp.emitError("expected only one dialect implementing "
562-
"AffinityAnalysisDialectInterface");
667+
668+
StreamTensorOpUpdater streamTensorOpUpdater(moduleOp);
669+
if (failed(streamTensorOpUpdater.init())) {
670+
moduleOp.emitError("failed to initialize StreamTensorOpUpdater");
563671
return signalPassFailure();
564672
}
565-
566-
IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
567-
if (failed(affinityAnalysis.run())) {
568-
moduleOp.emitError("failed on running affinity analysis");
673+
if (failed(streamTensorOpUpdater.run())) {
674+
moduleOp.emitError(
675+
"failed to add layouts to Stream::TensorPhaseOp with encodings");
569676
return signalPassFailure();
570677
}
571678

572679
SymbolTable symbolTable(moduleOp);
573-
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr =
574-
usedDialects[0]->makeLayoutAttrResolver(moduleOp);
575680
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
576-
if (failed(addLayoutsToTensorPhaseOps(moduleOp, affinityAnalysis, funcOp,
577-
resolveLayoutAttr))) {
578-
funcOp.emitError(
579-
"failed on adding layouts to Stream::TensorPhaseOp with encodings");
580-
return signalPassFailure();
581-
}
582681
if (failed(duplicateExecutablesPerLayoutVariant(moduleOp, symbolTable,
583682
funcOp))) {
584683
funcOp.emitError("failed on executable duplication");
@@ -587,5 +686,6 @@ struct SpecializeEncodingsPass
587686
}
588687
}
589688
};
689+
} // namespace
590690

591691
} // namespace mlir::iree_compiler::IREE::Stream

0 commit comments

Comments
 (0)