@@ -20,7 +20,8 @@ namespace mlir::iree_compiler::AMDAIE {
2020namespace {
2121
2222// / Utility to retrieve a TileOp from a vector of tile values, while doing
23- // / appropriate verifications.
23+ // / appropriate verifications. It's expected to return failure for non-shim
24+ // / tiles.
2425template <CopyOpOperateOn OperateOn>
2526FailureOr<AMDAIE::TileOp> getGeneratorTileOp (
2627 AMDAIE::NpuDmaCpyNdOp npuDmaOp,
@@ -52,9 +53,9 @@ FailureOr<AMDAIE::TileOp> getGeneratorTileOp(
5253 << tiles.size ();
5354 }
5455 Value tile = tiles[0 ];
55- if (!shimTileToGeneratorMap. contains (tile))
56- return npuDmaOp. emitOpError ()
57- << " no channel BD ID generator found for tile: " << tile ;
56+ // Since we can have non-Shim DMA ops as npu.dma_cpy_nd when reprogramming
57+ // DMAs, we can simply return failure instead of emitting an error.
58+ if (!shimTileToGeneratorMap. contains ( tile)) return failure () ;
5859
5960 auto tileOp = dyn_cast_if_present<AMDAIE::TileOp>(tile.getDefiningOp ());
6061 if (!tileOp) return npuDmaOp.emitOpError () << " no tile op found" ;
@@ -244,16 +245,16 @@ class BdIdAssignmentUtil {
244245 DenseMap<AMDAIE::BdIdOp, SmallVector<uint32_t >> bdIdOpToBdIdsMap;
245246 // A mapping from DMAOp to its corresponding source/target BD IDs.
246247 DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp, 2 >>
247- dmaOpToBdIdMap ;
248+ shimDmaOpToBdIdMap ;
248249
249250 public:
250251 BdIdAssignmentUtil (
251252 DenseMap<Value, ChannelBdIdGenerator> shimTileToGeneratorMap)
252253 : shimTileToGeneratorMap(std::move(shimTileToGeneratorMap)) {}
253254
254255 DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp, 2 >> &
255- getDmaOpToBdIdMap () {
256- return dmaOpToBdIdMap ;
256+ getShimDmaOpToBdIdMap () {
257+ return shimDmaOpToBdIdMap ;
257258 }
258259
259260 // / Assign Bd Ids to each DmaBatch belonging to a particular Tile in the
@@ -347,9 +348,10 @@ class BdIdAssignmentUtil {
347348 }
348349
349350 // / Assign required Bd Ids to the DmaOps of the current DmaBatch. This
350- // / assignment is tracked by maintaining `dmaOpToBdIdMap`, which essentially
351- // / maps a DmaOp to its source/target Bd Ids. Also, the API splits the
352- // / available BD IDs equally amongst all DmaOps in the DmaBatch when assigning
351+ // / assignment is tracked by maintaining `shimDmaOpToBdIdMap`, which
352+ // / essentially maps a DmaOp to its source/target Bd Ids. Also, the API splits
353+ // / the available BD IDs equally amongst all DmaOps in the DmaBatch when
354+ // / assigning
353355 LogicalResult assignRequiredBdIdsInCurrentBatch (
354356 IRRewriter &rewriter, AMDAIE::TileOp tileOp,
355357 std::unique_ptr<DmaBatch> &dmaBatch) {
@@ -410,12 +412,12 @@ class BdIdAssignmentUtil {
410412 AMDAIE::BdIdOp bdIdOp = rewriter.create <AMDAIE::BdIdOp>(
411413 rewriter.getUnknownLoc (), tileOp, affineApply.getResult ());
412414 bdIdOpToBdIdsMap[bdIdOp] = bdIds;
413- if (!dmaOpToBdIdMap .contains (dmaOp)) {
415+ if (!shimDmaOpToBdIdMap .contains (dmaOp)) {
414416 SmallVector<AMDAIE::BdIdOp, 2 > bdIdOps = {nullptr , nullptr };
415- dmaOpToBdIdMap [dmaOp] = bdIdOps;
417+ shimDmaOpToBdIdMap [dmaOp] = bdIdOps;
416418 }
417419
418- dmaOpToBdIdMap [dmaOp][dmaTileData.bdIdMapIndex ] = bdIdOp;
420+ shimDmaOpToBdIdMap [dmaOp][dmaTileData.bdIdMapIndex ] = bdIdOp;
419421 } else {
420422 // Assign a constant BD ID.
421423 std::optional<uint32_t > bdId = generator.getAndAssignBdId (
@@ -425,11 +427,11 @@ class BdIdAssignmentUtil {
425427 rewriter.getUnknownLoc (), rewriter.getIndexAttr (bdId.value ()));
426428 AMDAIE::BdIdOp bdIdOp = rewriter.create <AMDAIE::BdIdOp>(
427429 rewriter.getUnknownLoc (), tileOp, constant.getResult ());
428- if (!dmaOpToBdIdMap .contains (dmaOp)) {
430+ if (!shimDmaOpToBdIdMap .contains (dmaOp)) {
429431 SmallVector<AMDAIE::BdIdOp, 2 > bdIdOps = {nullptr , nullptr };
430- dmaOpToBdIdMap [dmaOp] = bdIdOps;
432+ shimDmaOpToBdIdMap [dmaOp] = bdIdOps;
431433 }
432- dmaOpToBdIdMap [dmaOp][dmaTileData.bdIdMapIndex ] = bdIdOp;
434+ shimDmaOpToBdIdMap [dmaOp][dmaTileData.bdIdMapIndex ] = bdIdOp;
433435 }
434436 // Reset to fetch next DmaOps' DmaTileData.
435437 dmaTileData.bdIdMapIndex = -1 ;
@@ -471,17 +473,17 @@ class BdIdAssignmentUtil {
471473 }
472474
473475 // / DmaOps are assigned Bd Ids prior to invoking this function and a map
474- // / `dmaOpToBdIdMap ` is maintained that maps a DmaOp to its source/target Bd
475- // / Ids. For each DmaOp in the list `dmaOps`, this API will check the
476- // / `dmaOpToBdIdMap ` and release Bd Ids if assigned.
476+ // / `shimDmaOpToBdIdMap ` is maintained that maps a DmaOp to its source/target
477+ // / Bd Ids. For each DmaOp in the list `dmaOps`, this API will check the
478+ // / `shimDmaOpToBdIdMap ` and release Bd Ids if assigned.
477479 LogicalResult releaseAssignedBdIdsInDmaOps (
478480 SmallVectorImpl<AMDAIE::NpuDmaCpyNdOp> &dmaOps) {
479481 // Release BD ID used by input DMA op.
480482 for (AMDAIE::NpuDmaCpyNdOp npuDmaOp : dmaOps) {
481- if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap [npuDmaOp][0 ]; bdIdOp) {
483+ if (AMDAIE::BdIdOp bdIdOp = shimDmaOpToBdIdMap [npuDmaOp][0 ]; bdIdOp) {
482484 if (failed (releaseBdId (bdIdOp))) return failure ();
483485 }
484- if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap [npuDmaOp][1 ]; bdIdOp) {
486+ if (AMDAIE::BdIdOp bdIdOp = shimDmaOpToBdIdMap [npuDmaOp][1 ]; bdIdOp) {
485487 if (failed (releaseBdId (bdIdOp))) return failure ();
486488 }
487489 }
@@ -490,20 +492,22 @@ class BdIdAssignmentUtil {
490492};
491493
492494// / Traverse each DmaOp inside ControlCode and replace it with new new DmaOp
493- // / that has Bd Ids assigned using `dmaOpToBdIdMap `.
495+ // / that has Bd Ids assigned using `shimDmaOpToBdIdMap `.
494496static LogicalResult replaceDmaOps (
495497 IRRewriter &rewriter, AMDAIE::ControlCodeOp controlCodeOp,
496498 DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp, 2 >>
497- &dmaOpToBdIdMap ) {
499+ &shimDmaOpToBdIdMap ) {
498500 WalkResult res = controlCodeOp->walk ([&](AMDAIE::NpuDmaCpyNdOp npuDmaOp) {
499- assert (dmaOpToBdIdMap .contains (npuDmaOp) && " No BD ID mapping found " );
501+ if (!shimDmaOpToBdIdMap .contains (npuDmaOp)) return WalkResult::advance ( );
500502 Value sourceBdId = nullptr ;
501503 Value targetBdId = nullptr ;
502- if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap[npuDmaOp][/* sourceBdIdIndex=*/ 0 ];
504+ if (AMDAIE::BdIdOp bdIdOp =
505+ shimDmaOpToBdIdMap[npuDmaOp][/* sourceBdIdIndex=*/ 0 ];
503506 bdIdOp) {
504507 sourceBdId = bdIdOp.getResult ();
505508 }
506- if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap[npuDmaOp][/* targetBdIdIndex=*/ 1 ];
509+ if (AMDAIE::BdIdOp bdIdOp =
510+ shimDmaOpToBdIdMap[npuDmaOp][/* targetBdIdIndex=*/ 1 ];
507511 bdIdOp) {
508512 targetBdId = bdIdOp.getResult ();
509513 }
@@ -554,19 +558,26 @@ static TileDmaBatchGraph createTileDmaBatchGraph(
554558 };
555559
556560 auto processNpuDmaCpyNdOp = [&](AMDAIE::NpuDmaCpyNdOp dmaOp) {
561+ bool isShimDmaOp = false ;
557562 if (dmaOp.getSource ()) {
558563 FailureOr<AMDAIE::TileOp> tile =
559564 getGeneratorTileOp<CopyOpOperateOn::Source>(dmaOp,
560565 shimTileToGeneratorMap);
561- if (succeeded (tile)) tileDmaBatchGraph.addDmaToBatch (*tile, dmaOp);
566+ if (succeeded (tile)) {
567+ tileDmaBatchGraph.addDmaToBatch (*tile, dmaOp);
568+ isShimDmaOp = true ;
569+ }
562570 }
563571 if (dmaOp.getTarget ()) {
564572 FailureOr<AMDAIE::TileOp> tile =
565573 getGeneratorTileOp<CopyOpOperateOn::Target>(dmaOp,
566574 shimTileToGeneratorMap);
567- if (succeeded (tile)) tileDmaBatchGraph.addDmaToBatch (*tile, dmaOp);
575+ if (succeeded (tile)) {
576+ tileDmaBatchGraph.addDmaToBatch (*tile, dmaOp);
577+ isShimDmaOp = true ;
578+ }
568579 }
569- if (!currDmaOp) currDmaOp = dmaOp;
580+ if (!currDmaOp && isShimDmaOp ) currDmaOp = dmaOp;
570581 };
571582
572583 auto processNpuDmaWaitOp = [&](AMDAIE::NpuDmaWaitOp npuWaitOp) {
@@ -627,7 +638,8 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
627638 AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode ();
628639 // Since a DMA op can have source and target, therefore we can have two BD IDs
629640 // for any DMA op. Hence we maintain a map from DMA op to a vector of BD IDs.
630- DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp>> dmaOpToBdIdMap;
641+ DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp>>
642+ shimDmaOpToBdIdMap;
631643 TileDmaBatchGraph tileDmaBatchGraph = createTileDmaBatchGraph (
632644 workgroupOp, controlCodeOp, shimTileToGeneratorMap);
633645 tileDmaBatchGraph.inferRequiredBdIds ();
@@ -636,7 +648,7 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
636648 return failure ();
637649 }
638650 if (failed (replaceDmaOps (rewriter, controlCodeOp,
639- bdIdAssignmentUtil.getDmaOpToBdIdMap ())))
651+ bdIdAssignmentUtil.getShimDmaOpToBdIdMap ())))
640652 return failure ();
641653 // At this step we have all the information to traverse and perform the
642654 // replacements of the DMA Ops.
0 commit comments