Skip to content

Commit bbe7f5c

Browse files
authored
Adding --iree-scheduling-initialization-mode= flag. (iree-org#19778)
This allows for choosing whether initializers return immediately with asynchronous work still pending or if they block and wait prior to returning. Users benchmarking will want to use synchronous mode while users wanting to overlap other work with initialization will want asynchronous mode. Since all existing frameworks operate with synchronous initialization the default is changed to that. For some spooky reason (iree-org#19795) this causes a few more onnx op tests to fail in addition to existing ones that were already failing. They've been xfailed for now because I cannot figure out what's going on or reproduce the issue. Fixes iree-org#19770.
1 parent c52eb68 commit bbe7f5c

File tree

20 files changed

+341
-47
lines changed

20 files changed

+341
-47
lines changed

compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,30 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
484484
getState() ^= resultUsage.getState();
485485
}
486486
})
487+
.Case<IREE::Stream::AsyncExecuteOp, IREE::Stream::AsyncConcurrentOp>(
488+
[&](auto op) {
489+
IREE::Stream::AsyncConcurrentOp c;
490+
// Take on the state from the internal usage.
491+
for (auto yieldOp :
492+
op.getClosureBodyRegion()
493+
.template getOps<IREE::Stream::YieldOp>()) {
494+
auto &yieldUsage = solver.getElementFor<ValueResourceUsage>(
495+
*this,
496+
Position::forValue(
497+
yieldOp.getOperand(result.getResultNumber())),
498+
DFX::Resolution::REQUIRED);
499+
getState() ^= yieldUsage.getState();
500+
}
501+
// If the result is passed through as a tied operand then also
502+
// inherit the original state.
503+
auto tiedOperand = op.getTiedResultOperand(result);
504+
if (tiedOperand) {
505+
auto &tiedUsage = solver.getElementFor<ValueResourceUsage>(
506+
*this, Position::forValue(tiedOperand),
507+
DFX::Resolution::REQUIRED);
508+
getState() ^= tiedUsage.getState();
509+
}
510+
})
487511
.Default([&](Operation *op) {});
488512
}
489513

@@ -805,6 +829,30 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
805829
getState() ^= resultUsage.getState();
806830
}
807831
})
832+
.Case<IREE::Stream::AsyncExecuteOp, IREE::Stream::AsyncConcurrentOp>(
833+
[&](auto op) {
834+
// Take on the traits of all ops within the execution region that
835+
// use the value and handle ties if needed.
836+
auto &operandUsage = solver.getElementFor<ValueResourceUsage>(
837+
*this,
838+
Position::forValue(
839+
op.getClosureBodyRegion().getArgument(operandIdx)),
840+
DFX::Resolution::REQUIRED);
841+
getState() ^= operandUsage.getState();
842+
for (auto result : op.getOperandTiedResults(operandIdx)) {
843+
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
844+
*this, Position::forValue(result),
845+
DFX::Resolution::REQUIRED);
846+
getState() ^= resultUsage.getState();
847+
}
848+
})
849+
.Case([&](IREE::Stream::YieldOp op) {
850+
// Take on the traits of the result of the parent operation.
851+
Value result = op->getParentOp()->getResult(operandIdx);
852+
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
853+
*this, Position::forValue(result), DFX::Resolution::REQUIRED);
854+
getState() ^= resultUsage.getState();
855+
})
808856
.Default([&](Operation *op) {});
809857
}
810858

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,9 +461,11 @@ struct ChainDependentAwaits : public OpRewritePattern<Op> {
461461
for (auto operand : llvm::enumerate(op.getResourceOperands())) {
462462
if (auto awaitOp =
463463
operand.value().template getDefiningOp<TimepointAwaitOp>()) {
464-
newTimepoints.push_back(awaitOp.getAwaitTimepoint());
465-
replacements.push_back(std::make_pair(
466-
operand.index(), awaitOp.getTiedResultOperand(operand.value())));
464+
if (!awaitOp.getSync()) {
465+
newTimepoints.push_back(awaitOp.getAwaitTimepoint());
466+
replacements.push_back(std::make_pair(
467+
operand.index(), awaitOp.getTiedResultOperand(operand.value())));
468+
}
467469
}
468470
}
469471
if (replacements.empty())
@@ -3050,7 +3052,9 @@ findSourceAwaitOp(Value resource) {
30503052
baseResource.getDefiningOp())) {
30513053
if (auto awaitOp = dyn_cast<IREE::Stream::TimepointAwaitOp>(
30523054
baseResource.getDefiningOp())) {
3053-
return {awaitOp, baseResource};
3055+
if (!awaitOp.getSync()) {
3056+
return {awaitOp, baseResource};
3057+
}
30543058
}
30553059
auto tiedValue = definingOp.getTiedResultOperand(baseResource);
30563060
if (!tiedValue)
@@ -3141,6 +3145,11 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern<TimepointAwaitOp> {
31413145
using OpRewritePattern::OpRewritePattern;
31423146
LogicalResult matchAndRewrite(TimepointAwaitOp op,
31433147
PatternRewriter &rewriter) const override {
3148+
// Don't move sync points as they may be implicitly guarding execution.
3149+
if (op.getSync()) {
3150+
return rewriter.notifyMatchFailure(op, "sync awaits cannot be moved");
3151+
}
3152+
31443153
// TODO(benvanik): amortize this dominance calculation.
31453154
DominanceInfo domInfo(op->getParentOp());
31463155

@@ -3197,6 +3206,7 @@ struct SinkSubviewsAcrossAwaits : public OpRewritePattern<TimepointAwaitOp> {
31973206
using OpRewritePattern::OpRewritePattern;
31983207
LogicalResult matchAndRewrite(TimepointAwaitOp op,
31993208
PatternRewriter &rewriter) const override {
3209+
rewriter.setInsertionPointAfter(op);
32003210
rewriter.startOpModification(op);
32013211
bool didChange = false;
32023212
for (auto operand : llvm::enumerate(op.getResourceOperands())) {
@@ -3276,7 +3286,7 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern<TimepointAwaitOp> {
32763286
if (dominanceInfo.dominates(use.getOwner(), op))
32773287
continue;
32783288
auto awaitOp = dyn_cast<TimepointAwaitOp>(use.getOwner());
3279-
if (!awaitOp)
3289+
if (!awaitOp || awaitOp.getSync())
32803290
continue;
32813291
// Ensure all dependencies of the await op are available.
32823292
if (!areAllOperandsDefinedBy(awaitOp, op, dominanceInfo)) {
@@ -3351,6 +3361,7 @@ struct FoldDuplicateAwaitResources : public OpRewritePattern<TimepointAwaitOp> {
33513361
// Create replacement op with deduped operands/results.
33523362
auto newOp = rewriter.create<IREE::Stream::TimepointAwaitOp>(
33533363
op.getLoc(), newOperands, newOperandSizes, op.getAwaitTimepoint());
3364+
newOp.setSync(op.getSync());
33543365

33553366
// Replace all duplicate results with the base results.
33563367
for (auto &replacement : replacements) {
@@ -3363,6 +3374,24 @@ struct FoldDuplicateAwaitResources : public OpRewritePattern<TimepointAwaitOp> {
33633374
}
33643375
};
33653376

3377+
struct ElideUnusedTimepointAwait : public OpRewritePattern<TimepointAwaitOp> {
3378+
using OpRewritePattern::OpRewritePattern;
3379+
LogicalResult matchAndRewrite(TimepointAwaitOp op,
3380+
PatternRewriter &rewriter) const override {
3381+
// If there are any uses the await is required to associate the timepoint.
3382+
if (!op.use_empty()) {
3383+
return failure();
3384+
}
3385+
// If the await is a sync point then we cannot elide it even if it has no
3386+
// uses.
3387+
if (op.getSync()) {
3388+
return rewriter.notifyMatchFailure(op, "sync ops cannot be elided");
3389+
}
3390+
rewriter.eraseOp(op);
3391+
return success();
3392+
}
3393+
};
3394+
33663395
} // namespace
33673396

33683397
void TimepointAwaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -3373,7 +3402,7 @@ void TimepointAwaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
33733402
results.insert<SinkSubviewsAcrossAwaits>(context);
33743403
results.insert<GroupAwaitsByTimepoint>(context);
33753404
results.insert<FoldDuplicateAwaitResources>(context);
3376-
results.insert<ElideUnusedOp<TimepointAwaitOp>>(context);
3405+
results.insert<ElideUnusedTimepointAwait>(context);
33773406
}
33783407

33793408
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3793,7 +3793,8 @@ def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [
37933793
Stream_StagingResource,
37943794
]>>:$resource_operands,
37953795
Variadic<Stream_Size>:$resource_operand_sizes,
3796-
Stream_Timepoint:$await_timepoint
3796+
Stream_Timepoint:$await_timepoint,
3797+
UnitAttr:$sync
37973798
);
37983799
let results = (outs
37993800
Variadic<AnyTypeOf<[
@@ -3803,6 +3804,7 @@ def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [
38033804
);
38043805

38053806
let assemblyFormat = [{
3807+
(`sync` $sync^)?
38063808
$await_timepoint `=` `` `>`
38073809
$resource_operands `:`
38083810
custom<ShapedTypeList>(type($resource_operands),

compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ iree_compiler_cc_library(
4040
"ScheduleExecution.cpp",
4141
"SpecializeDispatches.cpp",
4242
"SpecializeEncodings.cpp",
43+
"SyncInitializers.cpp",
4344
"VerifyAffinities.cpp",
4445
"VerifyAsyncAccessRanges.cpp",
4546
"VerifyLowerings.cpp",

compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ iree_cc_library(
4141
"ScheduleExecution.cpp"
4242
"SpecializeDispatches.cpp"
4343
"SpecializeEncodings.cpp"
44+
"SyncInitializers.cpp"
4445
"VerifyAffinities.cpp"
4546
"VerifyAsyncAccessRanges.cpp"
4647
"VerifyLowerings.cpp"

compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ using FunctionLikeNest =
3737
MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;
3838

3939
//===----------------------------------------------------------------------===//
40-
// Utilities
40+
// --iree-stream-cleanup-pipeline
4141
//===----------------------------------------------------------------------===//
4242

43-
static void addCleanupPatterns(OpPassManager &passManager) {
43+
static void buildStreamCleanupPassPipeline(
44+
OpPassManager &passManager,
45+
const IREE::Stream::TransformOptions &transformOptions) {
4446
FunctionLikeNest(passManager)
4547
// Standard MLIR cleanup.
4648
.addPass(mlir::createCanonicalizerPass)
@@ -84,7 +86,7 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager,
8486

8587
// Cleanup the program prior to outlining constants in case there is
8688
// propagation or fusion that needs to happen first.
87-
addCleanupPatterns(passManager);
89+
buildStreamCleanupPassPipeline(passManager, transformOptions);
8890

8991
//----------------------------------------------------------------------------
9092
// Conversion
@@ -114,7 +116,7 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager,
114116
passManager.addPass(mlir::createInlinerPass());
115117

116118
// Cleanup globals that were created during conversion.
117-
addCleanupPatterns(passManager);
119+
buildStreamCleanupPassPipeline(passManager, transformOptions);
118120

119121
// Bring all initializers together so that we can schedule them.
120122
passManager.addPass(IREE::Util::createCombineInitializersPass());
@@ -160,7 +162,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
160162
passManager.addNestedPass<IREE::Stream::ExecutableOp>(
161163
IREE::Stream::createEncodeDeviceTensorsPass());
162164

163-
addCleanupPatterns(passManager);
165+
buildStreamCleanupPassPipeline(passManager, transformOptions);
164166

165167
// Everything must now be in stream.async.* form but we don't yet have
166168
// lifetime assigned.
@@ -186,7 +188,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
186188
// change and it makes the IR cleaner.
187189
passManager.addPass(IREE::Stream::createRefineUsagePass());
188190

189-
addCleanupPatterns(passManager);
191+
buildStreamCleanupPassPipeline(passManager, transformOptions);
190192

191193
// Verify all stream.async.* op access ranges that we can by taking advantage
192194
// of statically available information or that which we can infer from data
@@ -207,6 +209,13 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
207209
// Group concurrently executable work into waves.
208210
.addPass(IREE::Stream::createScheduleConcurrencyPass);
209211

212+
// When synchronous initialization is requested we need to separate any work
213+
// behind a timepoint in the initializer from the consumers of that timepoint.
214+
if (transformOptions.initializationMode ==
215+
IREE::Stream::InitializationMode::Synchronous) {
216+
passManager.addPass(IREE::Stream::createSyncInitializersPass());
217+
}
218+
210219
// Materialize timepoints across the entire module. This simplifies scheduling
211220
// of the timeline as we can shake the IR and see what timepoints we still
212221
// have left.
@@ -217,7 +226,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
217226
// for partitioning/placement before turning them into opaque dispatches.
218227
passManager.addPass(IREE::Stream::createMaterializeBuiltinsPass());
219228

220-
addCleanupPatterns(passManager);
229+
buildStreamCleanupPassPipeline(passManager, transformOptions);
221230

222231
// Everything must now be in stream.async.* form.
223232
passManager.addPass(IREE::Stream::createVerifyLoweringToAsyncPass());
@@ -245,13 +254,17 @@ void buildStreamCmdPassPipeline(OpPassManager &passManager,
245254
// Layout packed slices to emit the arithmetic required for all resource
246255
// offsets. This enables us to propagate the subviews across the program
247256
// below.
248-
.addPass(IREE::Stream::createLayoutSlicesPass);
257+
.addPass(IREE::Stream::createLayoutSlicesPass)
258+
259+
// Apply canonicalization patterns to clean up subview ops prior to
260+
// propagating subranges.
261+
.addPass(mlir::createCanonicalizerPass);
249262

250263
// Propagate subviews throughout the program to unify resource storage access.
251264
// After propagation many resource SSA values can be deduped or folded by the
252265
// cleanup patterns.
253266
passManager.addPass(IREE::Util::createPropagateSubrangesPass());
254-
addCleanupPatterns(passManager);
267+
buildStreamCleanupPassPipeline(passManager, transformOptions);
255268

256269
// TODO(benvanik): outline streams (ala dispatch regions). Note that we may
257270
// want to do this earlier to enable better deduplication but that makes the
@@ -270,7 +283,7 @@ void buildStreamOptimizationPassPipeline(
270283
OpPassManager &passManager, const TransformOptions &transformOptions) {
271284
// Forming streams involves a fair amount of subgraph stitching, which can
272285
// cause duplication. Run CSE to collapse.
273-
addCleanupPatterns(passManager);
286+
buildStreamCleanupPassPipeline(passManager, transformOptions);
274287

275288
// If any scf ops crept in we get rid of them here. We should be able to
276289
// support them all the way through the stream dialect but some passes are not
@@ -290,7 +303,7 @@ void buildStreamOptimizationPassPipeline(
290303
OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());
291304

292305
// IPO and other cleanups.
293-
addCleanupPatterns(ipoPipeline);
306+
buildStreamCleanupPassPipeline(ipoPipeline, transformOptions);
294307

295308
// TODO(#9747): elide timepoints that are know-reached due to host
296309
// synchronization via stream.timepoint.await.
@@ -333,7 +346,7 @@ void buildStreamOptimizationPassPipeline(
333346

334347
// Folding operands requires that canonicalization/CSE folds the inputs that
335348
// we check for.
336-
addCleanupPatterns(passManager);
349+
buildStreamCleanupPassPipeline(passManager, transformOptions);
337350
passManager.addPass(IREE::Stream::createFoldUniformOperandsPass());
338351

339352
// Only want to specialize after we've added all the operands we need above.
@@ -383,7 +396,7 @@ void buildStreamTransformPassPipeline(
383396
//----------------------------------------------------------------------------
384397

385398
// Final cleanup after we optimize dispatches and fuse operands and bindings.
386-
addCleanupPatterns(passManager);
399+
buildStreamCleanupPassPipeline(passManager, transformOptions);
387400

388401
// Symbol DCE any remaining variables/functions that are now no longer
389402
// required.
@@ -404,6 +417,13 @@ void registerStreamPasses() {
404417
registerPasses();
405418

406419
// Pipelines.
420+
PassPipelineRegistration<TransformOptions> cleanupPassPipeline(
421+
"iree-stream-cleanup-pipeline",
422+
"Runs the cleanup passes that are performed between stages of the full "
423+
"stream pipeline.",
424+
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
425+
buildStreamCleanupPassPipeline(passManager, transformOptions);
426+
});
407427
PassPipelineRegistration<TransformOptions> tensorPassPipeline(
408428
"iree-stream-tensor-transformation-pipeline",
409429
"Lowers source dialects into stream.tensor.* IR.",

compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,22 @@ namespace mlir::iree_compiler::IREE::Stream {
2424
// Pipelines
2525
//===----------------------------------------------------------------------===//
2626

27-
// TODO(benvanik): find a way to share this with IREEVM.h w/o circular deps.
27+
// TODO(benvanik): find a way to share option enums with the top-level Options.h
28+
// w/o circular deps.
29+
30+
// Defines the behavior of initialization.
31+
enum class InitializationMode {
32+
// Synchronously initialize all parameters and globals prior to returning
33+
// from the module initializer.
34+
Synchronous = 0,
35+
// Asynchronously initialize all parameters and globals and return
36+
// immediately from the module initializer without waiting for them to
37+
// complete. Subsequent invocations will queue waiting for any dependencies
38+
// they have on the initialized values.
39+
Asynchronous = 1,
40+
};
41+
42+
// TODO(benvanik): find a way to share this with Options.h w/o circular deps.
2843
// Defines the output format of a dump pass.
2944
enum class DumpOutputFormat {
3045
// Dumping disabled.
@@ -40,7 +55,21 @@ enum class DumpOutputFormat {
4055
};
4156

4257
struct TransformOptions : public PassPipelineOptions<TransformOptions> {
43-
// TODO(benvanik): options for async/sync overrides.
58+
Option<InitializationMode> initializationMode{
59+
*this,
60+
"initialization-mode",
61+
llvm::cl::desc(
62+
"Specifies the initialization mode for parameters and globals."),
63+
llvm::cl::init(InitializationMode::Synchronous),
64+
llvm::cl::values(
65+
clEnumValN(InitializationMode::Synchronous, "sync",
66+
"Synchronously initialize all parameters and globals "
67+
"prior to returning from the module initializer."),
68+
clEnumValN(InitializationMode::Asynchronous, "async",
69+
"Asynchronously initialize all parameters and globals and "
70+
"return immediately from the module initializer without "
71+
"waiting for them to complete.")),
72+
};
4473

4574
Option<bool> optimizeBindings{
4675
*this,

0 commit comments

Comments
 (0)