@@ -107,7 +107,7 @@ static void addAMDAIEBufferizePasses(OpPassManager &pm,
107107}
108108
109109void addAMDAIEToAIEPasses (OpPassManager &passManager,
110- bool insertLoopAroundCoreBlock) {
110+ bool insertLoopAroundCoreBlock, bool reprogramDmas ) {
111111 // The infinite loop insertion transformation needs to be called before the
112112 // `AcquireReleaseToUseLock` pass as the latter will perform loop unrolling
113113 // based on the objFifo depths.
@@ -123,9 +123,7 @@ void addAMDAIEToAIEPasses(OpPassManager &passManager,
123123 passManager.addPass (createAMDAIEAddNoAliasFunctionArgumentsPass ());
124124 {
125125 AMDAIELowerToAIEOptions options;
126- // TODO(avarma): In follow-up PRs this will be replaced by a global flag.
127- // Currently setting as `false`.
128- options.reprogramDmas = /* reprogramDmas=*/ false ;
126+ options.reprogramDmas = reprogramDmas;
129127 passManager.addPass (createAMDAIELowerToAIEPass (options));
130128 }
131129 passManager.addPass (createAMDAIERemoveMemorySpacePass ());
@@ -672,7 +670,7 @@ void buildAMDAIETransformPassPipeline(
672670 PacketFlowStrategy packetFlowStrategy, bool enableCoalescingLoops,
673671 bool enableCollapsingUnitDims, OutliningStrategy enableFunctionOutlining,
674672 int callReplication, bool insertLoopAroundCoreBlock, bool enableCtrlPkt,
675- uint32_t coreStackSize) {
673+ uint32_t coreStackSize, bool reprogramDmas ) {
676674 OpPassManager &modulePassManager = variantPassManager.nest <ModuleOp>();
677675 {
678676 FunctionLikeNest funcPassManager (modulePassManager);
@@ -707,7 +705,8 @@ void buildAMDAIETransformPassPipeline(
707705 modulePassManager, packetFlowStrategy, useTilePipeline,
708706 enableVectorizationPasses, enableCoalescingLoops,
709707 enableCollapsingUnitDims, enableFunctionOutlining, callReplication,
710- insertLoopAroundCoreBlock, numCols, enableCtrlPkt, coreStackSize);
708+ insertLoopAroundCoreBlock, numCols, enableCtrlPkt, coreStackSize,
709+ reprogramDmas);
711710 } else if (useLowerToAIEPipeline == LowerToAIEPassPipeline::AIR) {
712711 addMLIRAIRLoweringPasses (modulePassManager, device, useTilePipeline,
713712 matmulElementwiseFusion,
@@ -733,7 +732,7 @@ void addAMDAIEObjectFifoLoweringPasses(
733732 bool enableCoalescingLoops, bool enableCollapsingUnitDims,
734733 OutliningStrategy enableFunctionOutlining, int callReplication,
735734 bool insertLoopAroundCoreBlock, uint32_t numCols, bool enableCtrlPkt,
736- uint32_t coreStackSize) {
735+ uint32_t coreStackSize, bool reprogramDmas ) {
737736 passManager.addPass (createEraseHALDescriptorTypeFromMemRefPass ());
738737 passManager.addPass (memref::createFoldMemRefAliasOpsPass ());
739738
@@ -796,18 +795,28 @@ void addAMDAIEObjectFifoLoweringPasses(
796795
797796 passManager.addPass (createCSEPass ());
798797 passManager.addPass (createCanonicalizerPass ());
799- passManager.addPass (createAMDAIEAssignLogicalObjectFifoDepthPass ());
798+ {
799+ AMDAIEAssignLogicalObjectFifoDepthOptions options;
800+ // TODO(avarma): In case reprogramming Dmas, we currently disable double
801+ // buffering. Relax the constraint later after modifying
802+ // controlcode-lowering and controlcode-to-transaction-binary pass to work
803+ // with double buffering.
804+ if (reprogramDmas) {
805+ options.l2BufferDepth = 1 ;
806+ options.l1BufferDepth = 1 ;
807+ }
808+ passManager.addPass (createAMDAIEAssignLogicalObjectFifoDepthPass (options));
809+ }
800810
801811 passManager.addPass (createAMDAIEAssignTilesPass ());
802812 passManager.addPass (createCSEPass ());
803813 passManager.addPass (createCanonicalizerPass ());
804814
805- passManager.addPass (createAMDAIEDmaToCircularDmaPass ());
815+ if (!reprogramDmas) passManager.addPass (createAMDAIEDmaToCircularDmaPass ());
816+
806817 {
807818 AMDAIECreateAIEWorkgroupOptions options;
808- // TODO(avarma): In follow-up PRs this will be replaced by a global flag.
809- // Currently setting as `false`.
810- options.reprogramDmas = /* reprogramDmas=*/ false ;
819+ options.reprogramDmas = reprogramDmas;
811820 passManager.addNestedPass <func::FuncOp>(
812821 createAMDAIECreateAIEWorkgroupPass (options));
813822 }
@@ -874,11 +883,30 @@ void addAMDAIEObjectFifoLoweringPasses(
874883
875884 passManager.addPass (createAMDAIENpuDmaToHalfDmaCpyNdPass ());
876885 passManager.addPass (createAMDAIEInsertDmaBdChainPass ());
877- passManager.addPass (createAMDAIEFoldDmaWaitsPass ());
878- passManager.addPass (createAMDAIEControlCodeLoweringPass ());
886+ // TODO(avarma): Currently with fold dma wait pass, in case of DMA
887+ // reprogramming we get ALL zeroes. To be triaged/fixed later in order to
888+ // relax this constraint and optimize the wait ops.
889+ if (!reprogramDmas) passManager.addPass (createAMDAIEFoldDmaWaitsPass ());
890+
891+ {
892+ AMDAIEControlCodeLoweringOptions options;
893+ options.reprogramDmas = reprogramDmas;
894+ passManager.addPass (createAMDAIEControlCodeLoweringPass (options));
895+ }
896+ if (reprogramDmas) {
897+ passManager.addPass (createAMDAIEAssignBDIDsPass ());
898+ {
899+ // For Conv ops use basic sequential scheme to avoid numerical error.
900+ // TODO: Find a better working scheme for Conv ops
901+ AMDAIEAssignBufferAddressOptions options;
902+ if (useTilePipeline == TilePassPipeline::ConvDecomposePipeline)
903+ options.allocScheme = AllocScheme::Sequential;
904+ passManager.addPass (createAMDAIEAssignBufferAddressPass (options));
905+ }
906+ }
879907 passManager.addPass (createAMDAIEControlCodeToTransactionPass ());
880908
881- addAMDAIEToAIEPasses (passManager, insertLoopAroundCoreBlock);
909+ addAMDAIEToAIEPasses (passManager, insertLoopAroundCoreBlock, reprogramDmas );
882910
883911 // Now lower using the AIE passes from MLIR-AIE.
884912 addMLIRAIELoweringPasses (passManager, useTilePipeline);
0 commit comments