Commit 24dd68d

[Reprogram][AssignNpuDmaBdId] Add support for handling non-shim DMA ops (#1332)
- This commit adds support for handling non-shim DMA ops as part of the `assign-npu-dma-bd-ids` pass.
- This is being added to the AMDAIE dialect to make [DMA reprogramming](#1287) work.

Signed-off-by: Abhishek Varma <[email protected]>
1 parent 10fdd67 commit 24dd68d
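
At a glance, the behavioral change is: when BD ID assignment encounters an `npu.dma_cpy_nd` whose tile has no channel BD ID generator (i.e. a non-shim DMA op, which can appear once DMAs are reprogrammed), the pass now skips it and leaves the op untouched instead of emitting an error. The standalone sketch below mirrors that control flow with plain std:: types; the names (`Op`, `BdIdGenerator`, `assignBdIds`) and the fixed per-tile ID pool are hypothetical illustrations, not the pass's real API.

#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an npu.dma_cpy_nd op: just a name and the shim
// tile it touches (empty if the op does not involve a shim tile at all).
struct Op {
  std::string name;
  std::string shimTile;  // "" -> non-shim DMA op
};

// Toy per-shim-tile BD ID generator: hands out IDs from a small fixed pool.
class BdIdGenerator {
 public:
  explicit BdIdGenerator(uint32_t numIds) {
    for (uint32_t i = 0; i < numIds; ++i) free_.insert(i);
  }
  std::optional<uint32_t> assign() {
    if (free_.empty()) return std::nullopt;
    uint32_t id = *free_.begin();
    free_.erase(free_.begin());
    return id;
  }
  void release(uint32_t id) { free_.insert(id); }

 private:
  std::set<uint32_t> free_;
};

// Mirrors the new pass behavior: ops on shim tiles get a BD ID recorded in a
// map; ops on non-shim tiles are simply skipped instead of raising an error.
std::map<std::string, uint32_t> assignBdIds(
    const std::vector<Op> &ops,
    std::map<std::string, BdIdGenerator> &shimTileToGenerator) {
  std::map<std::string, uint32_t> shimDmaOpToBdId;
  for (const Op &op : ops) {
    auto it = shimTileToGenerator.find(op.shimTile);
    if (it == shimTileToGenerator.end()) continue;  // non-shim op: leave as-is
    if (std::optional<uint32_t> id = it->second.assign())
      shimDmaOpToBdId[op.name] = *id;
  }
  return shimDmaOpToBdId;
}

int main() {
  std::map<std::string, BdIdGenerator> generators;
  generators.emplace("tile_0_0", BdIdGenerator(/*numIds=*/16));
  std::vector<Op> ops = {{"dma_shim", "tile_0_0"}, {"dma_non_shim", ""}};
  for (const auto &[name, id] : assignBdIds(ops, generators))
    std::cout << name << " -> bd_id " << id << "\n";  // only dma_shim prints
}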

File tree

2 files changed: +93, -31 lines


compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignNpuDmaBdIds.cpp

Lines changed: 43 additions & 31 deletions
@@ -20,7 +20,8 @@ namespace mlir::iree_compiler::AMDAIE {
 namespace {
 
 /// Utility to retrieve a TileOp from a vector of tile values, while doing
-/// appropriate verifications.
+/// appropriate verifications. It's expected to return failure for non-shim
+/// tiles.
 template <CopyOpOperateOn OperateOn>
 FailureOr<AMDAIE::TileOp> getGeneratorTileOp(
     AMDAIE::NpuDmaCpyNdOp npuDmaOp,
@@ -52,9 +53,9 @@ FailureOr<AMDAIE::TileOp> getGeneratorTileOp(
            << tiles.size();
   }
   Value tile = tiles[0];
-  if (!shimTileToGeneratorMap.contains(tile))
-    return npuDmaOp.emitOpError()
-           << "no channel BD ID generator found for tile: " << tile;
+  // Since we can have non-Shim DMA ops as npu.dma_cpy_nd when reprogramming
+  // DMAs, we can simply return failure instead of emitting an error.
+  if (!shimTileToGeneratorMap.contains(tile)) return failure();
 
   auto tileOp = dyn_cast_if_present<AMDAIE::TileOp>(tile.getDefiningOp());
   if (!tileOp) return npuDmaOp.emitOpError() << "no tile op found";
@@ -244,16 +245,16 @@ class BdIdAssignmentUtil {
   DenseMap<AMDAIE::BdIdOp, SmallVector<uint32_t>> bdIdOpToBdIdsMap;
   // A mapping from DMAOp to its corresponding source/target BD IDs.
   DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp, 2>>
-      dmaOpToBdIdMap;
+      shimDmaOpToBdIdMap;
 
  public:
   BdIdAssignmentUtil(
       DenseMap<Value, ChannelBdIdGenerator> shimTileToGeneratorMap)
       : shimTileToGeneratorMap(std::move(shimTileToGeneratorMap)) {}
 
   DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp, 2>> &
-  getDmaOpToBdIdMap() {
-    return dmaOpToBdIdMap;
+  getShimDmaOpToBdIdMap() {
+    return shimDmaOpToBdIdMap;
   }
 
   /// Assign Bd Ids to each DmaBatch belonging to a particular Tile in the
@@ -347,9 +348,10 @@ class BdIdAssignmentUtil {
   }
 
   /// Assign required Bd Ids to the DmaOps of the current DmaBatch. This
-  /// assignment is tracked by maintaining `dmaOpToBdIdMap`, which essentially
-  /// maps a DmaOp to its source/target Bd Ids. Also, the API splits the
-  /// available BD IDs equally amongst all DmaOps in the DmaBatch when assigning
+  /// assignment is tracked by maintaining `shimDmaOpToBdIdMap`, which
+  /// essentially maps a DmaOp to its source/target Bd Ids. Also, the API splits
+  /// the available BD IDs equally amongst all DmaOps in the DmaBatch when
+  /// assigning
   LogicalResult assignRequiredBdIdsInCurrentBatch(
       IRRewriter &rewriter, AMDAIE::TileOp tileOp,
       std::unique_ptr<DmaBatch> &dmaBatch) {
@@ -410,12 +412,12 @@ class BdIdAssignmentUtil {
         AMDAIE::BdIdOp bdIdOp = rewriter.create<AMDAIE::BdIdOp>(
             rewriter.getUnknownLoc(), tileOp, affineApply.getResult());
         bdIdOpToBdIdsMap[bdIdOp] = bdIds;
-        if (!dmaOpToBdIdMap.contains(dmaOp)) {
+        if (!shimDmaOpToBdIdMap.contains(dmaOp)) {
           SmallVector<AMDAIE::BdIdOp, 2> bdIdOps = {nullptr, nullptr};
-          dmaOpToBdIdMap[dmaOp] = bdIdOps;
+          shimDmaOpToBdIdMap[dmaOp] = bdIdOps;
         }
 
-        dmaOpToBdIdMap[dmaOp][dmaTileData.bdIdMapIndex] = bdIdOp;
+        shimDmaOpToBdIdMap[dmaOp][dmaTileData.bdIdMapIndex] = bdIdOp;
       } else {
         // Assign a constant BD ID.
         std::optional<uint32_t> bdId = generator.getAndAssignBdId(
@@ -425,11 +427,11 @@ class BdIdAssignmentUtil {
             rewriter.getUnknownLoc(), rewriter.getIndexAttr(bdId.value()));
         AMDAIE::BdIdOp bdIdOp = rewriter.create<AMDAIE::BdIdOp>(
             rewriter.getUnknownLoc(), tileOp, constant.getResult());
-        if (!dmaOpToBdIdMap.contains(dmaOp)) {
+        if (!shimDmaOpToBdIdMap.contains(dmaOp)) {
           SmallVector<AMDAIE::BdIdOp, 2> bdIdOps = {nullptr, nullptr};
-          dmaOpToBdIdMap[dmaOp] = bdIdOps;
+          shimDmaOpToBdIdMap[dmaOp] = bdIdOps;
         }
-        dmaOpToBdIdMap[dmaOp][dmaTileData.bdIdMapIndex] = bdIdOp;
+        shimDmaOpToBdIdMap[dmaOp][dmaTileData.bdIdMapIndex] = bdIdOp;
       }
       // Reset to fetch next DmaOps' DmaTileData.
       dmaTileData.bdIdMapIndex = -1;
@@ -471,17 +473,17 @@ class BdIdAssignmentUtil {
   }
 
   /// DmaOps are assigned Bd Ids prior to invoking this function and a map
-  /// `dmaOpToBdIdMap` is maintained that maps a DmaOp to its source/target Bd
-  /// Ids. For each DmaOp in the list `dmaOps`, this API will check the
-  /// `dmaOpToBdIdMap` and release Bd Ids if assigned.
+  /// `shimDmaOpToBdIdMap` is maintained that maps a DmaOp to its source/target
+  /// Bd Ids. For each DmaOp in the list `dmaOps`, this API will check the
+  /// `shimDmaOpToBdIdMap` and release Bd Ids if assigned.
   LogicalResult releaseAssignedBdIdsInDmaOps(
       SmallVectorImpl<AMDAIE::NpuDmaCpyNdOp> &dmaOps) {
     // Release BD ID used by input DMA op.
     for (AMDAIE::NpuDmaCpyNdOp npuDmaOp : dmaOps) {
-      if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap[npuDmaOp][0]; bdIdOp) {
+      if (AMDAIE::BdIdOp bdIdOp = shimDmaOpToBdIdMap[npuDmaOp][0]; bdIdOp) {
        if (failed(releaseBdId(bdIdOp))) return failure();
      }
-      if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap[npuDmaOp][1]; bdIdOp) {
+      if (AMDAIE::BdIdOp bdIdOp = shimDmaOpToBdIdMap[npuDmaOp][1]; bdIdOp) {
        if (failed(releaseBdId(bdIdOp))) return failure();
      }
    }
@@ -490,20 +492,22 @@ class BdIdAssignmentUtil {
 };
 
 /// Traverse each DmaOp inside ControlCode and replace it with new new DmaOp
-/// that has Bd Ids assigned using `dmaOpToBdIdMap`.
+/// that has Bd Ids assigned using `shimDmaOpToBdIdMap`.
 static LogicalResult replaceDmaOps(
     IRRewriter &rewriter, AMDAIE::ControlCodeOp controlCodeOp,
     DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp, 2>>
-        &dmaOpToBdIdMap) {
+        &shimDmaOpToBdIdMap) {
   WalkResult res = controlCodeOp->walk([&](AMDAIE::NpuDmaCpyNdOp npuDmaOp) {
-    assert(dmaOpToBdIdMap.contains(npuDmaOp) && "No BD ID mapping found");
+    if (!shimDmaOpToBdIdMap.contains(npuDmaOp)) return WalkResult::advance();
     Value sourceBdId = nullptr;
     Value targetBdId = nullptr;
-    if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap[npuDmaOp][/*sourceBdIdIndex=*/0];
+    if (AMDAIE::BdIdOp bdIdOp =
+            shimDmaOpToBdIdMap[npuDmaOp][/*sourceBdIdIndex=*/0];
         bdIdOp) {
       sourceBdId = bdIdOp.getResult();
     }
-    if (AMDAIE::BdIdOp bdIdOp = dmaOpToBdIdMap[npuDmaOp][/*targetBdIdIndex=*/1];
+    if (AMDAIE::BdIdOp bdIdOp =
+            shimDmaOpToBdIdMap[npuDmaOp][/*targetBdIdIndex=*/1];
        bdIdOp) {
      targetBdId = bdIdOp.getResult();
    }
@@ -554,19 +558,26 @@ static TileDmaBatchGraph createTileDmaBatchGraph(
   };
 
   auto processNpuDmaCpyNdOp = [&](AMDAIE::NpuDmaCpyNdOp dmaOp) {
+    bool isShimDmaOp = false;
     if (dmaOp.getSource()) {
       FailureOr<AMDAIE::TileOp> tile =
           getGeneratorTileOp<CopyOpOperateOn::Source>(dmaOp,
                                                       shimTileToGeneratorMap);
-      if (succeeded(tile)) tileDmaBatchGraph.addDmaToBatch(*tile, dmaOp);
+      if (succeeded(tile)) {
+        tileDmaBatchGraph.addDmaToBatch(*tile, dmaOp);
+        isShimDmaOp = true;
+      }
     }
     if (dmaOp.getTarget()) {
       FailureOr<AMDAIE::TileOp> tile =
           getGeneratorTileOp<CopyOpOperateOn::Target>(dmaOp,
                                                       shimTileToGeneratorMap);
-      if (succeeded(tile)) tileDmaBatchGraph.addDmaToBatch(*tile, dmaOp);
+      if (succeeded(tile)) {
+        tileDmaBatchGraph.addDmaToBatch(*tile, dmaOp);
+        isShimDmaOp = true;
+      }
     }
-    if (!currDmaOp) currDmaOp = dmaOp;
+    if (!currDmaOp && isShimDmaOp) currDmaOp = dmaOp;
   };
 
   auto processNpuDmaWaitOp = [&](AMDAIE::NpuDmaWaitOp npuWaitOp) {
@@ -627,7 +638,8 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
   AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
   // Since a DMA op can have source and target, therefore we can have two BD IDs
   // for any DMA op. Hence we maintain a map from DMA op to a vector of BD IDs.
-  DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp>> dmaOpToBdIdMap;
+  DenseMap<AMDAIE::NpuDmaCpyNdOp, SmallVector<AMDAIE::BdIdOp>>
+      shimDmaOpToBdIdMap;
   TileDmaBatchGraph tileDmaBatchGraph = createTileDmaBatchGraph(
       workgroupOp, controlCodeOp, shimTileToGeneratorMap);
   tileDmaBatchGraph.inferRequiredBdIds();
@@ -636,7 +648,7 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
     return failure();
   }
   if (failed(replaceDmaOps(rewriter, controlCodeOp,
-                           bdIdAssignmentUtil.getDmaOpToBdIdMap())))
+                           bdIdAssignmentUtil.getShimDmaOpToBdIdMap())))
     return failure();
   // At this step we have all the information to traverse and perform the
   // replacements of the DMA Ops.
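
On the consumer side, `replaceDmaOps` can no longer assume every `npu.dma_cpy_nd` in the control code has a map entry, so the walk now advances past unmapped (non-shim) ops instead of asserting. Below is a minimal standalone sketch of that consumption pattern, using plain std:: containers and hypothetical names (`Entry`, `rewriteOps`) in place of the pass's MLIR types: each mapped op carries a two-slot {source, target} entry in which either slot may be unset.

#include <array>
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical two-slot entry: index 0 holds the source BD ID, index 1 the
// target BD ID; either may be absent, mirroring the {nullptr, nullptr} default.
using Entry = std::array<std::optional<uint32_t>, 2>;

// Mirrors the consumer walk: ops without a map entry (non-shim DMA ops) are
// skipped, mapped ops are "rewritten" with whichever BD IDs were assigned.
void rewriteOps(const std::vector<std::string> &ops,
                const std::map<std::string, Entry> &shimDmaOpToBdIds) {
  for (const std::string &op : ops) {
    auto it = shimDmaOpToBdIds.find(op);
    if (it == shimDmaOpToBdIds.end()) {
      std::cout << op << ": no BD IDs, left unchanged\n";  // advance past it
      continue;
    }
    const Entry &e = it->second;
    std::cout << op << ": source bd_id = "
              << (e[0] ? std::to_string(*e[0]) : "<none>")
              << ", target bd_id = "
              << (e[1] ? std::to_string(*e[1]) : "<none>") << "\n";
  }
}

int main() {
  std::map<std::string, Entry> shimDmaOpToBdIds = {
      {"dma_shim_source_only", Entry{0u, std::nullopt}}};
  rewriteOps({"dma_shim_source_only", "dma_non_shim"}, shimDmaOpToBdIds);
}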

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/assign_npu_dma_bd_ids.mlir

Lines changed: 50 additions & 0 deletions
@@ -842,3 +842,53 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: @non_shim_dma_op
+#executable_target_amdaie_pdi_fb = #hal.executable.target<"amd-aie", "amdaie-pdi-fb", {num_cols = 1 : i32, num_rows = 1 : i32, target_device = "npu1_4col", ukernels = "none"}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+module attributes {hal.executable.target = #executable_target_amdaie_pdi_fb} {
+  func.func @non_shim_dma_op() {
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    amdaie.workgroup {
+      %alloc = memref.alloc() : memref<1x1x8x4x8x4xi32, 2 : i32>
+      %tile_0_1 = amdaie.tile(%c0, %c1)
+      %alloc_0 = memref.alloc() : memref<1x1x32x32xi32, 1 : i32>
+      %lof_0_1 = amdaie.logicalobjectfifo.from_memref %alloc_0, {%tile_0_1} : memref<1x1x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>, 2>
+      %alloc_1 = memref.alloc() : memref<1x1x32x32xi32, 1 : i32>
+      %lof_0_1_2 = amdaie.logicalobjectfifo.from_memref %alloc_1, {%tile_0_1} : memref<1x1x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>, 2>
+      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<32x32xi32>
+      %assume_align = memref.assume_alignment %0, 64 : memref<32x32xi32>
+      %tile_0_0 = amdaie.tile(%c0, %c0)
+      %lof_0_0 = amdaie.logicalobjectfifo.from_memref %assume_align, {%tile_0_0} : memref<32x32xi32> -> !amdaie.logicalobjectfifo<memref<1024xi32>>
+      %channel = amdaie.channel(%tile_0_0, 1, port_type = DMA, direction = MM2S)
+      %channel_3 = amdaie.channel(%tile_0_1, 1, port_type = DMA, direction = S2MM)
+      // CHECK: %[[SHIM_CONNECTION:.*]] = amdaie.connection
+      // CHECK-SAME: memref<1024xi32>
+      %1 = amdaie.connection(%lof_0_1_2 {%channel_3}, %lof_0_0 {%channel}) {connection_type = #amdaie<connection_type Packet>} : (!amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>, 2>, !amdaie.logicalobjectfifo<memref<1024xi32>>)
+      %tile_0_2 = amdaie.tile(%c0, %c2)
+      %lof_0_2 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile_0_2} : memref<1x1x8x4x8x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>, 2>
+      %channel_4 = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = MM2S)
+      %channel_5 = amdaie.channel(%tile_0_2, 0, port_type = DMA, direction = S2MM)
+      // CHECK: %[[NON_SHIM_CONNECTION:.*]] = amdaie.connection
+      %2 = amdaie.connection(%lof_0_2 {%channel_5}, %lof_0_1 {%channel_4}) {connection_type = #amdaie<connection_type Circuit>} : (!amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>, 2>, !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>, 2>)
+      // CHECK: amdaie.controlcode {
+      amdaie.controlcode {
+        // CHECK: %[[BD_ID:.*]] = amdaie.bd_id
+        // CHECK: amdaie.npu.dma_cpy_nd async_source %[[SHIM_CONNECTION]]
+        // CHECK-SAME: bd_id = %[[BD_ID]]
+        %3 = amdaie.npu.dma_cpy_nd async_source %1(%lof_0_1_2[0, 0] [32, 32] [32, 1], %lof_0_0[0, 0] [32, 32] [32, 1]) : target_type = !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>, 2> source_type = !amdaie.logicalobjectfifo<memref<1024xi32>>
+        amdaie.npu.dma_wait(%3 : !amdaie.async_source_token)
+        // CHECK-NOT: amdaie.bd_id
+        // CHECK: amdaie.npu.dma_cpy_nd async_source %[[NON_SHIM_CONNECTION]]
+        // CHECK-NOT: bd_id =
+        %4 = amdaie.npu.dma_cpy_nd async_source %2(%lof_0_2[0, 0, 0] [32, 8, 4] [4, 128, 1], %lof_0_1[0, 0] [32, 32] [32, 1]) : target_type = !amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>, 2> source_type = !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>, 2>
+        amdaie.end
+      }
+    }
+    return
+  }
+}
