@@ -37,10 +37,12 @@ using FunctionLikeNest =
3737 MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;
3838
3939// ===----------------------------------------------------------------------===//
40- // Utilities
40+ // --iree-stream-cleanup-pipeline
4141// ===----------------------------------------------------------------------===//
4242
43- static void addCleanupPatterns (OpPassManager &passManager) {
43+ static void buildStreamCleanupPassPipeline (
44+ OpPassManager &passManager,
45+ const IREE::Stream::TransformOptions &transformOptions) {
4446 FunctionLikeNest (passManager)
4547 // Standard MLIR cleanup.
4648 .addPass (mlir::createCanonicalizerPass)
@@ -84,7 +86,7 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager,
8486
8587 // Cleanup the program prior to outlining constants in case there is
8688 // propagation or fusion that needs to happen first.
87- addCleanupPatterns (passManager);
89+ buildStreamCleanupPassPipeline (passManager, transformOptions );
8890
8991 // ----------------------------------------------------------------------------
9092 // Conversion
@@ -114,7 +116,7 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager,
114116 passManager.addPass (mlir::createInlinerPass ());
115117
116118 // Cleanup globals that were created during conversion.
117- addCleanupPatterns (passManager);
119+ buildStreamCleanupPassPipeline (passManager, transformOptions );
118120
119121 // Bring all initializers together so that we can schedule them.
120122 passManager.addPass (IREE::Util::createCombineInitializersPass ());
@@ -160,7 +162,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
160162 passManager.addNestedPass <IREE::Stream::ExecutableOp>(
161163 IREE::Stream::createEncodeDeviceTensorsPass ());
162164
163- addCleanupPatterns (passManager);
165+ buildStreamCleanupPassPipeline (passManager, transformOptions );
164166
165167 // Everything must now be in stream.async.* form but we don't yet have
166168 // lifetime assigned.
@@ -186,7 +188,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
186188 // change and it makes the IR cleaner.
187189 passManager.addPass (IREE::Stream::createRefineUsagePass ());
188190
189- addCleanupPatterns (passManager);
191+ buildStreamCleanupPassPipeline (passManager, transformOptions );
190192
191193 // Verify all stream.async.* op access ranges that we can by taking advantage
192194 // of statically available information or that which we can infer from data
@@ -207,6 +209,13 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
207209 // Group concurrently executable work into waves.
208210 .addPass (IREE::Stream::createScheduleConcurrencyPass);
209211
212+ // When synchronous initialization is requested we need to separate any work
213+ // behind a timepoint in the initializer from the consumers of that timepoint.
214+ if (transformOptions.initializationMode ==
215+ IREE::Stream::InitializationMode::Synchronous) {
216+ passManager.addPass (IREE::Stream::createSyncInitializersPass ());
217+ }
218+
210219 // Materialize timepoints across the entire module. This simplifies scheduling
211220 // of the timeline as we can shake the IR and see what timepoints we still
212221 // have left.
@@ -217,7 +226,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
217226 // for partitioning/placement before turning them into opaque dispatches.
218227 passManager.addPass (IREE::Stream::createMaterializeBuiltinsPass ());
219228
220- addCleanupPatterns (passManager);
229+ buildStreamCleanupPassPipeline (passManager, transformOptions );
221230
222231 // Everything must now be in stream.async.* form.
223232 passManager.addPass (IREE::Stream::createVerifyLoweringToAsyncPass ());
@@ -245,13 +254,17 @@ void buildStreamCmdPassPipeline(OpPassManager &passManager,
245254 // Layout packed slices to emit the arithmetic required for all resource
246255 // offsets. This enables us to propagate the subviews across the program
247256 // below.
248- .addPass (IREE::Stream::createLayoutSlicesPass);
257+ .addPass (IREE::Stream::createLayoutSlicesPass)
258+
259+ // Apply canonicalization patterns to clean up subview ops prior to
260+ // propagating subranges.
261+ .addPass (mlir::createCanonicalizerPass);
249262
250263 // Propagate subviews throughout the program to unify resource storage access.
251264 // After propagation many resource SSA values can be deduped or folded by the
252265 // cleanup patterns.
253266 passManager.addPass (IREE::Util::createPropagateSubrangesPass ());
254- addCleanupPatterns (passManager);
267+ buildStreamCleanupPassPipeline (passManager, transformOptions );
255268
256269 // TODO(benvanik): outline streams (ala dispatch regions). Note that we may
257270 // want to do this earlier to enable better deduplication but that makes the
@@ -270,7 +283,7 @@ void buildStreamOptimizationPassPipeline(
270283 OpPassManager &passManager, const TransformOptions &transformOptions) {
271284 // Forming streams involves a fair amount of subgraph stitching, which can
272285 // cause duplication. Run CSE to collapse.
273- addCleanupPatterns (passManager);
286+ buildStreamCleanupPassPipeline (passManager, transformOptions );
274287
275288 // If any scf ops crept in we get rid of them here. We should be able to
276289 // support them all the way through the stream dialect but some passes are not
@@ -290,7 +303,7 @@ void buildStreamOptimizationPassPipeline(
290303 OpPassManager ipoPipeline (mlir::ModuleOp::getOperationName ());
291304
292305 // IPO and other cleanups.
293- addCleanupPatterns (ipoPipeline);
306+ buildStreamCleanupPassPipeline (ipoPipeline, transformOptions );
294307
295308 // TODO(#9747): elide timepoints that are known-reached due to host
296309 // synchronization via stream.timepoint.await.
@@ -333,7 +346,7 @@ void buildStreamOptimizationPassPipeline(
333346
334347 // Folding operands requires that canonicalization/CSE folds the inputs that
335348 // we check for.
336- addCleanupPatterns (passManager);
349+ buildStreamCleanupPassPipeline (passManager, transformOptions );
337350 passManager.addPass (IREE::Stream::createFoldUniformOperandsPass ());
338351
339352 // Only want to specialize after we've added all the operands we need above.
@@ -383,7 +396,7 @@ void buildStreamTransformPassPipeline(
383396 // ----------------------------------------------------------------------------
384397
385398 // Final cleanup after we optimize dispatches and fuse operands and bindings.
386- addCleanupPatterns (passManager);
399+ buildStreamCleanupPassPipeline (passManager, transformOptions );
387400
388401 // Symbol DCE any remaining variables/functions that are now no longer
389402 // required.
@@ -404,6 +417,13 @@ void registerStreamPasses() {
404417 registerPasses ();
405418
406419 // Pipelines.
420+ PassPipelineRegistration<TransformOptions> cleanupPassPipeline (
421+ " iree-stream-cleanup-pipeline" ,
422+ " Runs the cleanup passes that are performed between stages of the full "
423+ " stream pipeline." ,
424+ [](OpPassManager &passManager, const TransformOptions &transformOptions) {
425+ buildStreamCleanupPassPipeline (passManager, transformOptions);
426+ });
407427 PassPipelineRegistration<TransformOptions> tensorPassPipeline (
408428 " iree-stream-tensor-transformation-pipeline" ,
409429 " Lowers source dialects into stream.tensor.* IR." ,
0 commit comments