@@ -82,56 +82,102 @@ struct AIECtrlPacketToDmaPass : AIECtrlPacketToDmaBase<AIECtrlPacketToDmaPass> {
8282 auto newBlockArg = newSeq.getBody ().addArgument (ctrlPktMemrefType, loc);
8383 builder.setInsertionPointToStart (&newSeq.getBody ().front ());
8484
85+ // Collect all npu.control_packet ops, grouped by location in 'batches'
86+ struct BatchInfo {
87+ TileID tileId;
88+ int64_t startOffset;
89+ int64_t totalSize;
90+ std::string shimDmaAllocName;
91+ int shimChan;
92+ };
93+ std::vector<BatchInfo> batches;
94+
8595 int64_t ddrOffset = 0 ;
8696 Block &entry = f.getBody ().front ();
97+
98+ // First pass: collect and batch control packet operations
99+ for (auto &o : entry) {
100+ auto ctrlPktOp = dyn_cast<NpuControlPacketOp>(&o);
101+ if (!ctrlPktOp)
102+ continue ;
103+ int col = ctrlPktOp.getColumnFromAddr ();
104+ int row = ctrlPktOp.getRowFromAddr ();
105+
106+ // Calculate control packet size
107+ int64_t ctrlPktSize = 0 ;
108+ auto data = ctrlPktOp.getData ();
109+ if (data)
110+ ctrlPktSize = data->size ();
111+ else if (ctrlPktOp.getLength ())
112+ ctrlPktSize = *ctrlPktOp.getLength ();
113+ ctrlPktSize++; // Ctrl info word
114+ ctrlPktSize++; // Packet header
115+
116+ // Check if we can batch with the previous packet
117+ if (targetModel.getTargetArch () == AIEArch::AIE2p && !batches.empty () &&
118+ batches.back ().tileId == TileID{col, row}) {
119+ // Add to existing batch
120+ batches.back ().totalSize += ctrlPktSize;
121+ } else {
122+ // Start new batch
123+ auto rowToShimChanMap =
124+ getRowToShimChanMap (targetModel, WireBundle::DMA);
125+ int shimChan = rowToShimChanMap[row];
126+
127+ std::string shimDmaAllocName = " ctrlpkt" ;
128+ shimDmaAllocName += " _col" + std::to_string (col);
129+ shimDmaAllocName += " _mm2s" ;
130+ shimDmaAllocName += " _chan" + std::to_string (shimChan);
131+
132+ batches.push_back ({TileID{col, row}, ddrOffset, ctrlPktSize,
133+ shimDmaAllocName, shimChan});
134+ }
135+ ddrOffset += ctrlPktSize;
136+ }
137+
138+ // Second pass: emit batched operations in original order
139+ auto batchIt = batches.begin ();
140+
87141 for (auto &o : entry) {
88- llvm::TypeSwitch<Operation *>(&o).Case <NpuControlPacketOp>(
89- [&](auto op) {
90- // Destination tile info
91- int col = op.getColumnFromAddr ();
92- int row = op.getRowFromAddr ();
93-
94- // Control packet offset (to raw data at ddr) and size
95- int64_t ctrlPktSize = 0 ;
96- auto data = op.getData ();
97- if (data)
98- ctrlPktSize = data->size ();
99- else if (op.getLength ())
100- ctrlPktSize = *op.getLength ();
101- ctrlPktSize++; // Ctrl info word
102- ctrlPktSize++; // Packet header
103-
104- const std::vector<int64_t > staticOffsets = {0 , 0 , 0 , ddrOffset};
105- ddrOffset += ctrlPktSize;
106- const std::vector<int64_t > staticSizes = {1 , 1 , 1 , ctrlPktSize};
107- const std::vector<int64_t > staticStrides = {0 , 0 , 0 , 1 };
108-
109- // Shim dma alloc symbol name
110- std::string shimDmaAllocName = " ctrlpkt" ;
111- shimDmaAllocName += " _col" + std::to_string (col);
112- shimDmaAllocName += " _mm2s" ;
113- auto rowToShimChanMap =
114- getRowToShimChanMap (targetModel, WireBundle::DMA);
115- int shimChan = rowToShimChanMap[row];
116- shimDmaAllocName += " _chan" + std::to_string (shimChan);
117-
118- StringRef metadata = builder.getStringAttr (shimDmaAllocName);
119- builder.create <NpuDmaMemcpyNdOp>(
120- builder.getUnknownLoc (), newBlockArg, SmallVector<Value>{},
121- SmallVector<Value>{}, SmallVector<Value>{},
122- ArrayRef (staticOffsets), ArrayRef (staticSizes),
123- ArrayRef (staticStrides), nullptr , metadata, 0 , true , 0 , 0 , 0 ,
124- 0 , 0 , 0 );
125-
126- auto shimRow = builder.getI32IntegerAttr (0 );
127- auto shimCol = builder.getI32IntegerAttr (col);
128- auto dir = builder.getI32IntegerAttr (1 ); // MM2S
129- auto chan = builder.getI32IntegerAttr (shimChan);
130- auto col_num = builder.getI32IntegerAttr (1 );
131- auto row_num = builder.getI32IntegerAttr (1 );
132- builder.create <AIEX::NpuSyncOp>(loc, shimCol, shimRow, dir, chan,
133- col_num, row_num);
134- });
142+ auto ctrlPktOp = dyn_cast<NpuControlPacketOp>(&o);
143+ if (!ctrlPktOp)
144+ continue ;
145+
146+ assert (batchIt != batches.end () &&
147+ " Expected control packet to be in a batch" );
148+
149+ int col = ctrlPktOp.getColumnFromAddr ();
150+ int row = ctrlPktOp.getRowFromAddr ();
151+
152+ // Check if this is the first packet of a new batch,
153+ // otherwise skip it.
154+ if (batchIt->tileId != TileID{col, row})
155+ continue ;
156+
157+ // Emit the batched DMA operation for this (col, row) pair
158+ const std::vector<int64_t > staticOffsets = {0 , 0 , 0 ,
159+ batchIt->startOffset };
160+ const std::vector<int64_t > staticSizes = {1 , 1 , 1 , batchIt->totalSize };
161+ const std::vector<int64_t > staticStrides = {0 , 0 , 0 , 1 };
162+
163+ StringRef metadata = builder.getStringAttr (batchIt->shimDmaAllocName );
164+ builder.create <NpuDmaMemcpyNdOp>(
165+ builder.getUnknownLoc (), newBlockArg, SmallVector<Value>{},
166+ SmallVector<Value>{}, SmallVector<Value>{}, ArrayRef (staticOffsets),
167+ ArrayRef (staticSizes), ArrayRef (staticStrides), nullptr , metadata,
168+ 0 , true , 0 , 0 , 0 , 0 , 0 , 0 );
169+
170+ auto shimRow = builder.getI32IntegerAttr (0 );
171+ auto shimCol = builder.getI32IntegerAttr (col);
172+ auto dir = builder.getI32IntegerAttr (1 ); // MM2S
173+ auto chan = builder.getI32IntegerAttr (batchIt->shimChan );
174+ auto col_num = builder.getI32IntegerAttr (1 );
175+ auto row_num = builder.getI32IntegerAttr (1 );
176+ builder.create <AIEX::NpuSyncOp>(loc, shimCol, shimRow, dir, chan,
177+ col_num, row_num);
178+ ++batchIt;
179+ if (batchIt == batches.end ())
180+ break ;
135181 }
136182
137183 erased.push_back (f);
0 commit comments