1717#include " llvm/Support/Debug.h"
1818#include " llvm/Support/LogicalResult.h"
1919#include " mlir/IR/BuiltinAttributes.h"
20+ #include " mlir/IR/BuiltinOps.h"
2021#include " mlir/IR/BuiltinTypes.h"
2122#include " mlir/IR/PatternMatch.h"
2223#include " mlir/IR/SymbolTable.h"
@@ -57,6 +58,8 @@ SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
5758 return results;
5859}
5960
61+ } // namespace
62+
6063// Returns an updated encoding attribute if the type is a RankedTensorType
6164// and an EncodingAttr is present. Otherwise, returns std::nullopt. The
6265// method uses the EncodingLayoutAttrInterface from the EncodingAttr to
@@ -274,14 +277,140 @@ static RankedTensorType cloneWithEncoding(RankedTensorType type,
274277 encodingAttr);
275278}
276279
280+ // / Returns all the stream tensor ops that implement AffinityOpInterface, where
281+ // / a stream affinity indicates the kind of enviroment the ops are expected run
282+ // / in.
283+ static SmallVector<IREE::Stream::AffinityOpInterface>
284+ collectStreamTensorOps (FunctionOpInterface funcOp) {
285+ SmallVector<IREE::Stream::AffinityOpInterface> result;
286+ funcOp.walk ([&](IREE::Stream::AffinityOpInterface affinityOp) {
287+ // Only need to update encoding types for ops that have TensorPhaseOp trait.
288+ if (!affinityOp->hasTrait <OpTrait::IREE::Stream::TensorPhaseOp>()) {
289+ return ;
290+ }
291+
292+ // Bail out if the operation does not have an affinity attribute.
293+ auto affinityAttr = affinityOp.getAffinityAttr ();
294+ if (!affinityAttr) {
295+ return ;
296+ }
297+ result.push_back (affinityOp);
298+ });
299+ return result;
300+ }
301+
302+ namespace {
303+
304+ // Adds the resolved layouts to all tensor types on stream tensor ops, if
305+ // encodings are present. Most of stream tensor ops implement
306+ // AffinityOpInterface, where a stream affinity indicates the kind of
307+ // enviroment the ops are expected run in. When an encoding is present in the
308+ // tensor type, the method resolves the layouts, strips outdated information,
309+ // and adds the resolved layouts to the encodings. The updated encodings should
310+ // have enough information for other lowering transformations.
311+ // TODO(hanchung): Add support for stream.tensor.load ops and
312+ // stream.tensor.store ops. They are not affinity ops, so additional analysis
313+ // will be needed in the work.
314+ class StreamTensorOpUpdater {
315+ public:
316+ explicit StreamTensorOpUpdater (ModuleOp moduleOp) : moduleOp(moduleOp){};
317+ ~StreamTensorOpUpdater () {}
318+
319+ // Collects the stream tensor op candidates, and prepares all the needed
320+ // information for the update. This must be called once before calling `run`.
321+ // Note that all the ops are unmodified after the execution.
322+ LogicalResult init ();
323+
324+ // Adds the resolved layouts to all tensor types of `streamOps`, if encodings
325+ // are present.
326+ LogicalResult run ();
327+
328+ private:
329+ // Appends the query from the `affinityOp` to `queries`. Note that most of
330+ // operations only care the execution affinity. There are outliers (e.g.,
331+ // tensor dispatch op, etc.) that need to resolve affinities for
332+ // operand resources.
333+ LogicalResult addQuery (IREE::Stream::AffinityAnalysis &affinityAnalysis,
334+ IREE::Stream::AffinityOpInterface affinityOp);
335+
336+ // The list of the queries that can be used for batch affinity queries. The
337+ // analysis could be very expensive because it could apply the whole program
338+ // data flow analysis.
339+ SmallVector<IREE::Stream::AffinityAndOpPair> queries;
340+
341+ // The layout resolvers for each query.
342+ llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
343+ cachedLayoutAttrs;
344+
345+ // Input moduleOp. The op is not expected to be updated during the query.
346+ // Because data flow analaysis can be involved. Modifying the IR invalidates
347+ // the state and may lead to crashes as pointer references into the IR
348+ // structure are retained.
349+ ModuleOp moduleOp;
350+
351+ // The ops that need to be updated.
352+ SmallVector<IREE::Stream::AffinityOpInterface> streamOps;
353+
354+ // The layout resolver function, which is used to resolve layouts for
355+ // encodings. See StreamInterfaces.h for more details.
356+ IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr;
357+ };
358+
359+ } // namespace
360+
361+ LogicalResult StreamTensorOpUpdater::init () {
362+ auto usedDialects = gatherUsedDialectInterfaces<
363+ IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
364+ if (usedDialects.size () != 1 ) {
365+ return moduleOp.emitError (" expected only one dialect implementing "
366+ " AffinityAnalysisDialectInterface" );
367+ }
368+ resolveLayoutAttr = usedDialects[0 ]->makeLayoutAttrResolver (moduleOp);
369+
370+ for (auto funcOp : moduleOp.getOps <FunctionOpInterface>()) {
371+ streamOps.append (collectStreamTensorOps (funcOp));
372+ }
373+
374+ return success ();
375+ }
376+
377+ LogicalResult StreamTensorOpUpdater::addQuery (
378+ IREE::Stream::AffinityAnalysis &affinityAnalysis,
379+ IREE::Stream::AffinityOpInterface affinityOp) {
380+ queries.emplace_back (affinityOp.getAffinityAttr (), affinityOp);
381+
382+ if (auto dispatchOp =
383+ dyn_cast<IREE::Stream::TensorDispatchOp>(affinityOp.getOperation ())) {
384+ for (auto [operand, typeAttr] :
385+ llvm::zip_equal (dispatchOp.getMixedOperands (),
386+ dispatchOp.getOperandEncodings ().getValue ())) {
387+ auto type = cast<TypeAttr>(typeAttr).getValue ();
388+ // Skip if the operand type is not AffinityType.
389+ if (!isa<IREE::Stream::AffinityTypeInterface>(type)) {
390+ continue ;
391+ }
392+ SmallVector<IREE::Stream::AffinityAttr> affinityAttrs;
393+ if (!affinityAnalysis.tryLookupResourceAffinity (operand, affinityAttrs)) {
394+ return failure ();
395+ }
396+ for (auto affinity : affinityAttrs) {
397+ queries.emplace_back (affinity, affinityOp);
398+ }
399+ }
400+ }
401+
402+ return success ();
403+ }
404+
277405// / Updates the operand encodings and result encodings for the `dispatchOp`
278406// / with resolved layouts.
279- static LogicalResult
280- updateTensorDispatchOp (RewriterBase &rewriter, ModuleOp moduleOp,
281- IREE::Stream::AffinityAnalysis &affinityAnalysis,
282- IREE::Stream::TensorDispatchOp dispatchOp,
283- const SetVector<Attribute> &resLayoutResolvers,
284- IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
407+ static LogicalResult updateTensorDispatchOp (
408+ RewriterBase &rewriter, ModuleOp moduleOp,
409+ IREE::Stream::AffinityAnalysis &affinityAnalysis,
410+ IREE::Stream::TensorDispatchOp dispatchOp,
411+ const SetVector<Attribute> &resLayoutResolvers,
412+ llvm::DenseMap<IREE::Stream::AffinityAndOpPair, SetVector<Attribute>>
413+ &cachedLayoutAttrs) {
285414 SmallVector<Type> newOperandEncodings;
286415 for (auto [operand, typeAttr] :
287416 llvm::zip_equal (dispatchOp.getMixedOperands (),
@@ -299,11 +428,11 @@ updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
299428 if (affinityAttrs.size () != 1 ) {
300429 return failure ();
301430 }
302- SetVector<Attribute> layoutResolvers;
303- if ( failed (
304- resolveLayoutAttr (affinityAttrs[ 0 ], moduleOp, layoutResolvers))) {
305- return dispatchOp. emitError ( " failed on making layout resolvers " );
306- }
431+
432+ IREE::Stream::AffinityAndOpPair key (affinityAttrs[ 0 ], dispatchOp);
433+ assert (cachedLayoutAttrs. contains (key) &&
434+ " the (affinity, dispatchOp) query is invalid " );
435+ const SetVector<Attribute> &layoutResolvers = cachedLayoutAttrs[key];
307436
308437 std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
309438 getEncodingWithNewLayouts (type, layoutResolvers);
@@ -325,7 +454,6 @@ updateTensorDispatchOp(RewriterBase &rewriter, ModuleOp moduleOp,
325454 newResultEncodings.push_back (type);
326455 continue ;
327456 }
328-
329457 std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
330458 getEncodingWithNewLayouts (type, resLayoutResolvers);
331459 if (!encodingAttr) {
@@ -472,53 +600,34 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op,
472600 return success ();
473601}
474602
475- // / Adds the resolved layouts to all tensor types on stream tensor ops, if
476- // / encodings are present. Most of stream tensor ops implement
477- // / AffinityOpInterface, where a stream affinity indicates the kind of
478- // / enviroment the ops are expected run in. When an encoding is present in the
479- // / tensor type, the method resolves the layouts, strips outdated information,
480- // / and adds the resolved layouts to the encodings. The updated encodings should
481- // / have enough information for other lowering transformations.
482- // / TODO(hanchung): Add support for stream.tensor.load ops and
483- // / stream.tensor.store ops. They are not affinity ops, so additional analysis
484- // / will be needed in the work.
485- static LogicalResult addLayoutsToTensorPhaseOps (
486- ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis,
487- FunctionOpInterface funcOp,
488- IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
489- SmallVector<IREE::Stream::AffinityOpInterface> candidates;
490- funcOp.walk ([&](IREE::Stream::AffinityOpInterface affinityOp) {
491- // Only need to update encoding types for ops that have TensorPhaseOp trait.
492- if (!affinityOp->hasTrait <OpTrait::IREE::Stream::TensorPhaseOp>()) {
493- return ;
494- }
603+ LogicalResult StreamTensorOpUpdater::run () {
604+ IREE::Stream::AffinityAnalysis affinityAnalysis (moduleOp);
605+ if (failed (affinityAnalysis.run ())) {
606+ return moduleOp.emitError (" failed on running affinity analysis" );
607+ }
495608
496- // Bail out if the operation does not have an affinity attribute.
497- auto affinityAttr = affinityOp.getAffinityAttr ();
498- if (!affinityAttr) {
499- return ;
609+ for (auto op : streamOps) {
610+ if (failed (addQuery (affinityAnalysis, op))) {
611+ return failure ();
500612 }
501- candidates.push_back (affinityOp);
502- });
613+ }
503614
504- if (candidates. empty ( )) {
505- return success ();
615+ if (failed ( resolveLayoutAttr (queries, cachedLayoutAttrs) )) {
616+ return failure ();
506617 }
507618
508- IRRewriter rewriter (funcOp.getContext ());
509- for (auto affinityOp : candidates) {
510- auto affinityAttr = affinityOp.getAffinityAttr ();
511- SetVector<Attribute> layoutResolvers;
512- if (failed (resolveLayoutAttr (affinityAttr, moduleOp, layoutResolvers))) {
513- return affinityOp.emitError (" failed on making layout resolvers" );
514- }
619+ IRRewriter rewriter (moduleOp.getContext ());
620+ for (auto affinityOp : streamOps) {
621+ const SetVector<Attribute> &layoutResolvers =
622+ cachedLayoutAttrs[IREE::Stream::AffinityAndOpPair (
623+ affinityOp.getAffinityAttr (), affinityOp)];
515624
516625 LogicalResult result =
517626 TypeSwitch<Operation *, LogicalResult>(affinityOp)
518627 .Case <IREE::Stream::TensorDispatchOp>([&](auto op) {
519628 return updateTensorDispatchOp (rewriter, moduleOp,
520629 affinityAnalysis, op,
521- layoutResolvers, resolveLayoutAttr );
630+ layoutResolvers, cachedLayoutAttrs );
522631 })
523632 .Case <IREE::Stream::TensorSizeOfOp>([&](auto op) {
524633 return updateTensorSizeOfOp (rewriter, op, layoutResolvers);
@@ -549,36 +658,26 @@ static LogicalResult addLayoutsToTensorPhaseOps(
549658 }
550659 return success ();
551660}
552- } // namespace
553661
662+ namespace {
554663struct SpecializeEncodingsPass
555664 : public impl::SpecializeEncodingsPassBase<SpecializeEncodingsPass> {
556665 void runOnOperation () override {
557666 ModuleOp moduleOp = getOperation ();
558- auto usedDialects = gatherUsedDialectInterfaces<
559- IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
560- if (usedDialects.size () != 1 ) {
561- moduleOp.emitError (" expected only one dialect implementing "
562- " AffinityAnalysisDialectInterface" );
667+
668+ StreamTensorOpUpdater streamTensorOpUpdater (moduleOp);
669+ if (failed (streamTensorOpUpdater.init ())) {
670+ moduleOp.emitError (" failed to initialize StreamTensorOpUpdater" );
563671 return signalPassFailure ();
564672 }
565-
566- IREE::Stream::AffinityAnalysis affinityAnalysis (moduleOp);
567- if (failed (affinityAnalysis.run ())) {
568- moduleOp.emitError (" failed on running affinity analysis" );
673+ if (failed (streamTensorOpUpdater.run ())) {
674+ moduleOp.emitError (
675+ " failed to add layouts to Stream::TensorPhaseOp with encodings" );
569676 return signalPassFailure ();
570677 }
571678
572679 SymbolTable symbolTable (moduleOp);
573- IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr =
574- usedDialects[0 ]->makeLayoutAttrResolver (moduleOp);
575680 for (auto funcOp : moduleOp.getOps <FunctionOpInterface>()) {
576- if (failed (addLayoutsToTensorPhaseOps (moduleOp, affinityAnalysis, funcOp,
577- resolveLayoutAttr))) {
578- funcOp.emitError (
579- " failed on adding layouts to Stream::TensorPhaseOp with encodings" );
580- return signalPassFailure ();
581- }
582681 if (failed (duplicateExecutablesPerLayoutVariant (moduleOp, symbolTable,
583682 funcOp))) {
584683 funcOp.emitError (" failed on executable duplication" );
@@ -587,5 +686,6 @@ struct SpecializeEncodingsPass
587686 }
588687 }
589688};
689+ } // namespace
590690
591691} // namespace mlir::iree_compiler::IREE::Stream
0 commit comments