Skip to content

Commit 057ef6b

Browse files
AndraBiscaabiscagithub-actions[bot]
authored
Object FIFO: fix DMA channel detection (#1933)
Co-authored-by: AndraBisca <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent bd40321 commit 057ef6b

File tree

12 files changed

+677
-521
lines changed

12 files changed

+677
-521
lines changed

lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -69,47 +69,65 @@ class LockAnalysis {
6969
};
7070

7171
//===----------------------------------------------------------------------===//
72-
// TileDMA Channel Analysis
72+
// DMA Channel Analysis
7373
//===----------------------------------------------------------------------===//
7474
class DMAChannelAnalysis {
75-
DenseMap<Value, int> masterChannelsPerTile;
76-
DenseMap<Value, int> slaveChannelsPerTile;
75+
DenseMap<std::tuple<Value, DMAChannelDir, int>, int> channelsPerTile;
7776

7877
public:
7978
DMAChannelAnalysis(DeviceOp &device) {
80-
// go over the channels used for each tile and update the master/slave
81-
// channel maps
79+
// go over the channels used for each tile and update channel map
8280
for (auto memOp : device.getOps<MemOp>()) {
8381
Region &r = memOp.getBody();
8482
for (auto &bl : r.getBlocks()) {
8583
for (auto op : bl.getOps<DMAStartOp>()) {
86-
if (op.isSend())
87-
getMasterDMAChannel(memOp.getTile());
88-
else
89-
getSlaveDMAChannel(memOp.getTile());
84+
channelsPerTile[{memOp.getTile(), op.getChannelDir(),
85+
op.getChannelIndex()}] = 1;
86+
}
87+
}
88+
}
89+
for (auto memOp : device.getOps<MemTileDMAOp>()) {
90+
Region &r = memOp.getBody();
91+
for (auto &bl : r.getBlocks()) {
92+
for (auto op : bl.getOps<DMAStartOp>()) {
93+
channelsPerTile[{memOp.getTile(), op.getChannelDir(),
94+
op.getChannelIndex()}] = 1;
95+
}
96+
}
97+
}
98+
for (auto memOp : device.getOps<ShimDMAOp>()) {
99+
Region &r = memOp.getBody();
100+
for (auto &bl : r.getBlocks()) {
101+
for (auto op : bl.getOps<DMAStartOp>()) {
102+
channelsPerTile[{memOp.getTile(), op.getChannelDir(),
103+
op.getChannelIndex()}] = 1;
90104
}
91105
}
92106
}
93107
}
94108

95-
/// Given an AIE tile, returns its next usable master channel.
96-
DMAChannel getMasterDMAChannel(Value tile) {
97-
if (masterChannelsPerTile.find(tile) == masterChannelsPerTile.end())
98-
masterChannelsPerTile[tile] = 0;
99-
else
100-
masterChannelsPerTile[tile]++;
101-
DMAChannel dmaChan = {DMAChannelDir::MM2S, masterChannelsPerTile[tile]};
102-
return dmaChan;
103-
}
104-
105-
/// Given an AIE tile, returns its next usable slave channel.
106-
DMAChannel getSlaveDMAChannel(Value tile) {
107-
if (slaveChannelsPerTile.find(tile) == slaveChannelsPerTile.end())
108-
slaveChannelsPerTile[tile] = 0;
109-
else
110-
slaveChannelsPerTile[tile]++;
111-
DMAChannel dmaChan = {DMAChannelDir::S2MM, slaveChannelsPerTile[tile]};
112-
return dmaChan;
109+
/// Given a tile and DMAChannelDir, returns next usable channel index for
110+
/// that tile.
111+
int getDMAChannelIndex(TileOp tileOp, DMAChannelDir dir) {
112+
const auto &targetModel = getTargetModel(tileOp);
113+
int maxChannelNum = 0;
114+
if (tileOp.isShimTile())
115+
maxChannelNum = 2;
116+
else {
117+
if (dir == DMAChannelDir::MM2S)
118+
maxChannelNum = targetModel.getNumSourceSwitchboxConnections(
119+
tileOp.getCol(), tileOp.getRow(), WireBundle::DMA);
120+
else
121+
maxChannelNum = targetModel.getNumDestSwitchboxConnections(
122+
tileOp.getCol(), tileOp.getRow(), WireBundle::DMA);
123+
}
124+
for (int i = 0; i < maxChannelNum; i++)
125+
if (int usageCnt = channelsPerTile[{tileOp.getResult(), dir, i}];
126+
usageCnt == 0) {
127+
channelsPerTile[{tileOp.getResult(), dir, i}] = 1;
128+
return i;
129+
}
130+
return -1;
113131
}
114132
};
115133

@@ -1518,8 +1536,12 @@ struct AIEObjectFifoStatefulTransformPass
15181536
// rely on shared memory and share the same buffers.
15191537
for (auto &[producer, consumers] : splitFifos) {
15201538
// create producer tile DMA
1521-
DMAChannel producerChan =
1522-
dmaAnalysis.getMasterDMAChannel(producer.getProducerTile());
1539+
int producerChanIndex = dmaAnalysis.getDMAChannelIndex(
1540+
producer.getProducerTileOp(), DMAChannelDir::MM2S);
1541+
if (producerChanIndex == -1)
1542+
producer.getProducerTileOp().emitOpError(
1543+
"number of output DMA channel exceeded!");
1544+
DMAChannel producerChan = {DMAChannelDir::MM2S, producerChanIndex};
15231545
createDMA(device, builder, producer, producerChan.direction,
15241546
producerChan.channel, 0, producer.getDimensionsToStreamAttr(),
15251547
producer.getPadDimensionsAttr());
@@ -1535,8 +1557,12 @@ struct AIEObjectFifoStatefulTransformPass
15351557
for (auto consumer : consumers) {
15361558

15371559
// create consumer tile DMA
1538-
DMAChannel consumerChan =
1539-
dmaAnalysis.getSlaveDMAChannel(consumer.getProducerTile());
1560+
int consumerChanIndex = dmaAnalysis.getDMAChannelIndex(
1561+
consumer.getProducerTileOp(), DMAChannelDir::S2MM);
1562+
if (consumerChanIndex == -1)
1563+
consumer.getProducerTileOp().emitOpError(
1564+
"number of input DMA channel exceeded!");
1565+
DMAChannel consumerChan = {DMAChannelDir::S2MM, consumerChanIndex};
15401566
BDDimLayoutArrayAttr consumerDims =
15411567
consumer.getDimensionsFromStreamPerConsumer()[0];
15421568
createDMA(device, builder, consumer, consumerChan.direction,

test/npu-xrt/adjacent_memtile_access/three_memtiles/aie.mlir

Lines changed: 0 additions & 240 deletions
This file was deleted.

0 commit comments

Comments
 (0)