Skip to content

Commit becbe40

Browse files
authored
Batch control packets by tile for aie2p (#2604)
1 parent 7b9a467 commit becbe40

File tree

2 files changed

+125
-47
lines changed

2 files changed

+125
-47
lines changed

lib/Dialect/AIEX/Transforms/AIECtrlPacketToDma.cpp

Lines changed: 93 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -82,56 +82,102 @@ struct AIECtrlPacketToDmaPass : AIECtrlPacketToDmaBase<AIECtrlPacketToDmaPass> {
8282
auto newBlockArg = newSeq.getBody().addArgument(ctrlPktMemrefType, loc);
8383
builder.setInsertionPointToStart(&newSeq.getBody().front());
8484

85+
// Collect all npu.control_packet ops into 'batches'; on AIE2p, consecutive ops targeting the same tile are merged into a single batch
86+
struct BatchInfo { // One pending DMA transfer covering one or more control packets sent to the same tile.
87+
TileID tileId; // Destination tile (col, row) decoded from the control packet address; used to decide whether the next packet can join this batch.
88+
int64_t startOffset; // Value of ddrOffset when the batch was started; emitted as the innermost static offset of the DMA memcpy.
89+
int64_t totalSize; // Sum of the sizes of all packets in the batch (each: data length + ctrl info word + packet header); emitted as the innermost static size.
90+
std::string shimDmaAllocName; // Shim DMA allocation symbol, built as "ctrlpkt_col<col>_mm2s_chan<shimChan>"; passed as the memcpy metadata.
91+
int shimChan; // Shim channel looked up from the row-to-shim-channel map for the destination row; used as the channel of the sync op.
92+
};
93+
std::vector<BatchInfo> batches;
94+
8595
int64_t ddrOffset = 0;
8696
Block &entry = f.getBody().front();
97+
98+
// First pass: collect and batch control packet operations
99+
for (auto &o : entry) {
100+
auto ctrlPktOp = dyn_cast<NpuControlPacketOp>(&o);
101+
if (!ctrlPktOp)
102+
continue;
103+
int col = ctrlPktOp.getColumnFromAddr();
104+
int row = ctrlPktOp.getRowFromAddr();
105+
106+
// Calculate control packet size
107+
int64_t ctrlPktSize = 0;
108+
auto data = ctrlPktOp.getData();
109+
if (data)
110+
ctrlPktSize = data->size();
111+
else if (ctrlPktOp.getLength())
112+
ctrlPktSize = *ctrlPktOp.getLength();
113+
ctrlPktSize++; // Ctrl info word
114+
ctrlPktSize++; // Packet header
115+
116+
// Check if we can batch with the previous packet
117+
if (targetModel.getTargetArch() == AIEArch::AIE2p && !batches.empty() &&
118+
batches.back().tileId == TileID{col, row}) {
119+
// Add to existing batch
120+
batches.back().totalSize += ctrlPktSize;
121+
} else {
122+
// Start new batch
123+
auto rowToShimChanMap =
124+
getRowToShimChanMap(targetModel, WireBundle::DMA);
125+
int shimChan = rowToShimChanMap[row];
126+
127+
std::string shimDmaAllocName = "ctrlpkt";
128+
shimDmaAllocName += "_col" + std::to_string(col);
129+
shimDmaAllocName += "_mm2s";
130+
shimDmaAllocName += "_chan" + std::to_string(shimChan);
131+
132+
batches.push_back({TileID{col, row}, ddrOffset, ctrlPktSize,
133+
shimDmaAllocName, shimChan});
134+
}
135+
ddrOffset += ctrlPktSize;
136+
}
137+
138+
// Second pass: emit batched operations in original order
139+
auto batchIt = batches.begin();
140+
87141
for (auto &o : entry) {
88-
llvm::TypeSwitch<Operation *>(&o).Case<NpuControlPacketOp>(
89-
[&](auto op) {
90-
// Destination tile info
91-
int col = op.getColumnFromAddr();
92-
int row = op.getRowFromAddr();
93-
94-
// Control packet offset (to raw data at ddr) and size
95-
int64_t ctrlPktSize = 0;
96-
auto data = op.getData();
97-
if (data)
98-
ctrlPktSize = data->size();
99-
else if (op.getLength())
100-
ctrlPktSize = *op.getLength();
101-
ctrlPktSize++; // Ctrl info word
102-
ctrlPktSize++; // Packet header
103-
104-
const std::vector<int64_t> staticOffsets = {0, 0, 0, ddrOffset};
105-
ddrOffset += ctrlPktSize;
106-
const std::vector<int64_t> staticSizes = {1, 1, 1, ctrlPktSize};
107-
const std::vector<int64_t> staticStrides = {0, 0, 0, 1};
108-
109-
// Shim dma alloc symbol name
110-
std::string shimDmaAllocName = "ctrlpkt";
111-
shimDmaAllocName += "_col" + std::to_string(col);
112-
shimDmaAllocName += "_mm2s";
113-
auto rowToShimChanMap =
114-
getRowToShimChanMap(targetModel, WireBundle::DMA);
115-
int shimChan = rowToShimChanMap[row];
116-
shimDmaAllocName += "_chan" + std::to_string(shimChan);
117-
118-
StringRef metadata = builder.getStringAttr(shimDmaAllocName);
119-
builder.create<NpuDmaMemcpyNdOp>(
120-
builder.getUnknownLoc(), newBlockArg, SmallVector<Value>{},
121-
SmallVector<Value>{}, SmallVector<Value>{},
122-
ArrayRef(staticOffsets), ArrayRef(staticSizes),
123-
ArrayRef(staticStrides), nullptr, metadata, 0, true, 0, 0, 0,
124-
0, 0, 0);
125-
126-
auto shimRow = builder.getI32IntegerAttr(0);
127-
auto shimCol = builder.getI32IntegerAttr(col);
128-
auto dir = builder.getI32IntegerAttr(1); // MM2S
129-
auto chan = builder.getI32IntegerAttr(shimChan);
130-
auto col_num = builder.getI32IntegerAttr(1);
131-
auto row_num = builder.getI32IntegerAttr(1);
132-
builder.create<AIEX::NpuSyncOp>(loc, shimCol, shimRow, dir, chan,
133-
col_num, row_num);
134-
});
142+
auto ctrlPktOp = dyn_cast<NpuControlPacketOp>(&o);
143+
if (!ctrlPktOp)
144+
continue;
145+
146+
assert(batchIt != batches.end() &&
147+
"Expected control packet to be in a batch");
148+
149+
int col = ctrlPktOp.getColumnFromAddr();
150+
int row = ctrlPktOp.getRowFromAddr();
151+
152+
// Check if this is the first packet of a new batch,
153+
// otherwise skip it.
154+
if (batchIt->tileId != TileID{col, row})
155+
continue;
156+
157+
// Emit the batched DMA operation for this (col, row) pair
158+
const std::vector<int64_t> staticOffsets = {0, 0, 0,
159+
batchIt->startOffset};
160+
const std::vector<int64_t> staticSizes = {1, 1, 1, batchIt->totalSize};
161+
const std::vector<int64_t> staticStrides = {0, 0, 0, 1};
162+
163+
StringRef metadata = builder.getStringAttr(batchIt->shimDmaAllocName);
164+
builder.create<NpuDmaMemcpyNdOp>(
165+
builder.getUnknownLoc(), newBlockArg, SmallVector<Value>{},
166+
SmallVector<Value>{}, SmallVector<Value>{}, ArrayRef(staticOffsets),
167+
ArrayRef(staticSizes), ArrayRef(staticStrides), nullptr, metadata,
168+
0, true, 0, 0, 0, 0, 0, 0);
169+
170+
auto shimRow = builder.getI32IntegerAttr(0);
171+
auto shimCol = builder.getI32IntegerAttr(col);
172+
auto dir = builder.getI32IntegerAttr(1); // MM2S
173+
auto chan = builder.getI32IntegerAttr(batchIt->shimChan);
174+
auto col_num = builder.getI32IntegerAttr(1);
175+
auto row_num = builder.getI32IntegerAttr(1);
176+
builder.create<AIEX::NpuSyncOp>(loc, shimCol, shimRow, dir, chan,
177+
col_num, row_num);
178+
++batchIt;
179+
if (batchIt == batches.end())
180+
break;
135181
}
136182

137183
erased.push_back(f);

test/dialect/AIEX/ctrl_pkt_to_dma.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,35 @@ aie.device(npu1_1col) {
4242
aie.shim_dma_allocation @ctrlpkt_col0_mm2s_chan0(MM2S, 0, 0)
4343
memref.global "public" @ctrlpkt_col0_mm2s_chan0 : memref<2048xi32>
4444
}
45+
46+
// -----
47+
48+
// CHECK-LABEL: aie.device(npu1) {
49+
// CHECK: aiex.npu.dma_memcpy_nd(%arg0[0, 0, 0, 0][1, 1, 1, 3][0, 0, 0, 1])
50+
// CHECK: aiex.npu.dma_memcpy_nd(%arg0[0, 0, 0, 3][1, 1, 1, 3][0, 0, 0, 1])
51+
// CHECK: aiex.npu.dma_memcpy_nd(%arg0[0, 0, 0, 6][1, 1, 1, 3][0, 0, 0, 1])
52+
// CHECK: aiex.npu.dma_memcpy_nd(%arg0[0, 0, 0, 9][1, 1, 1, 3][0, 0, 0, 1])
53+
// CHECK: aiex.npu.dma_memcpy_nd(%arg0[0, 0, 0, 12][1, 1, 1, 3][0, 0, 0, 1])
54+
aie.device(npu1) {
55+
aiex.runtime_sequence() {
56+
aiex.control_packet {address = 0 : ui32, data = array<i32: 100>, opcode = 0 : i32, stream_id = 0 : i32}
57+
aiex.control_packet {address = 4 : ui32, data = array<i32: 200>, opcode = 0 : i32, stream_id = 0 : i32}
58+
aiex.control_packet {address = 8 : ui32, data = array<i32: 300>, opcode = 0 : i32, stream_id = 0 : i32}
59+
aiex.control_packet {address = 12 : ui32, data = array<i32: 400>, opcode = 0 : i32, stream_id = 0 : i32}
60+
aiex.control_packet {address = 16 : ui32, data = array<i32: 500>, opcode = 0 : i32, stream_id = 0 : i32}
61+
}
62+
}
63+
64+
// -----
65+
66+
// CHECK-LABEL: aie.device(npu2) {
67+
// CHECK: aiex.npu.dma_memcpy_nd(%arg0[0, 0, 0, 0][1, 1, 1, 15][0, 0, 0, 1])
68+
aie.device(npu2) {
69+
aiex.runtime_sequence() {
70+
aiex.control_packet {address = 0 : ui32, data = array<i32: 100>, opcode = 0 : i32, stream_id = 0 : i32}
71+
aiex.control_packet {address = 4 : ui32, data = array<i32: 200>, opcode = 0 : i32, stream_id = 0 : i32}
72+
aiex.control_packet {address = 8 : ui32, data = array<i32: 300>, opcode = 0 : i32, stream_id = 0 : i32}
73+
aiex.control_packet {address = 12 : ui32, data = array<i32: 400>, opcode = 0 : i32, stream_id = 0 : i32}
74+
aiex.control_packet {address = 16 : ui32, data = array<i32: 500>, opcode = 0 : i32, stream_id = 0 : i32}
75+
}
76+
}

0 commit comments

Comments
 (0)