Skip to content

Commit d9612b9

Browse files
authored
Add aiex.npu.blockwrite operation (Xilinx#1638)
1 parent 99ef5ff commit d9612b9

File tree

8 files changed

+191
-99
lines changed

8 files changed

+191
-99
lines changed

include/aie/Dialect/AIEX/IR/AIEX.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,22 @@ def AIE_NpuWrite32Op: AIEX_Op<"npu.write32", []> {
683683
}];
684684
}
685685

686+
// BLOCKWRITE
687+
// Block-write op: takes a memref operand ($data) and an unsigned 32-bit
// address attribute, and produces no results. The custom assembly prints the
// operand in parentheses, then the attribute dictionary, then the operand type.
def AIE_NpuBlockWriteOp: AIEX_Op<"npu.blockwrite", []> {
  let summary = "blockwrite operator";
  let description = [{
    blockwrite operator
  }];
  let arguments = (ins AnyMemRef:$data, UI32Attr:$address);
  let results = (outs);
  let assemblyFormat = [{
    `(` $data `)` attr-dict `:` type($data)
  }];
}
701+
686702
// OP_SYNC
687703
def AIE_NpuSyncOp: AIEX_Op<"npu.sync", []> {
688704
let summary = "sync operator";

lib/Dialect/AIEX/Transforms/AIEDmaToNpu.cpp

Lines changed: 107 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ struct ShimDMAllocationGetter {
7070
};
7171
} // namespace
7272

73-
struct RtpToNpuPattern : OpConversionPattern<NpuWriteRTPOp> {
73+
struct RtpToWrite32Pattern : OpConversionPattern<NpuWriteRTPOp> {
7474
using OpConversionPattern::OpConversionPattern;
7575

76-
RtpToNpuPattern(MLIRContext *context, PatternBenefit benefit = 1)
76+
RtpToWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
7777
: OpConversionPattern(context, benefit) {}
7878

7979
LogicalResult
@@ -118,12 +118,12 @@ struct RtpToNpuPattern : OpConversionPattern<NpuWriteRTPOp> {
118118
}
119119
};
120120

121-
struct PushToNpuPattern : OpConversionPattern<NpuPushQueueOp> {
121+
struct PushQueuetoWrite32Pattern : OpConversionPattern<NpuPushQueueOp> {
122122

123123
public:
124124
using OpConversionPattern::OpConversionPattern;
125125

126-
PushToNpuPattern(MLIRContext *context, PatternBenefit benefit = 1)
126+
PushQueuetoWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
127127
: OpConversionPattern(context, benefit) {}
128128

129129
LogicalResult
@@ -343,16 +343,16 @@ struct DmaToNpuPattern : OpConversionPattern<NpuDmaMemcpyNdOp> {
343343
/// Convert NpuDmaWaitOp into NpuSyncOp by retrieving the necessary
344344
/// information from the ShimDMAAllocationOp referenced through the
345345
/// symbol argument of this op.
346-
struct DmaWaitToNpuPattern : OpConversionPattern<NpuDmaWaitOp> {
346+
struct DmaWaitToSyncPattern : OpConversionPattern<NpuDmaWaitOp> {
347347

348348
private:
349349
ShimDMAllocationGetter &allocGetter;
350350

351351
public:
352352
using OpConversionPattern::OpConversionPattern;
353353

354-
DmaWaitToNpuPattern(MLIRContext *context, ShimDMAllocationGetter &getter,
355-
PatternBenefit benefit = 1)
354+
DmaWaitToSyncPattern(MLIRContext *context, ShimDMAllocationGetter &getter,
355+
PatternBenefit benefit = 1)
356356
: OpConversionPattern(context, benefit), allocGetter(getter) {}
357357

358358
LogicalResult
@@ -374,11 +374,103 @@ struct DmaWaitToNpuPattern : OpConversionPattern<NpuDmaWaitOp> {
374374
op, shimDmaAllocOp->getCol(), /* row */ 0,
375375
static_cast<uint32_t>(shimDmaAllocOp->getChannelDir()),
376376
shimDmaAllocOp->getChannelIndex(), 1, 1);
377+
378+
return success();
379+
}
380+
};
381+
382+
/// Rewrites an NpuWriteBdOp into an NpuBlockWriteOp.
///
/// The eight buffer-descriptor register words (DMA_BDX_0 .. DMA_BDX_7) are
/// packed from the op's attributes and stored as the initial value of a newly
/// created private memref.global of type memref<8xi32>, inserted just before
/// the enclosing function. The original op is then replaced by a
/// memref.get_global of that symbol feeding an NpuBlockWriteOp whose address
/// attribute is the computed register address of the buffer descriptor.
///
/// The rewrite fails (without modifying the IR) when the op is not nested in
/// both a device op and a function.
struct WriteBdToBlockWritePattern : OpConversionPattern<NpuWriteBdOp> {
  using OpConversionPattern::OpConversionPattern;

  WriteBdToBlockWritePattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit) {}

  LogicalResult
  matchAndRewrite(NpuWriteBdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Both ancestors are dereferenced below; bail out instead of crashing
    // when the op sits in an unexpected place.
    AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
    if (!dev)
      return rewriter.notifyMatchFailure(op, "no parent device operation");
    auto parentFunc = op->getParentOfType<func::FuncOp>();
    if (!parentFunc)
      return rewriter.notifyMatchFailure(op, "no parent function operation");

    const AIE::AIETargetModel &tm = dev.getTargetModel();

    // Tile base (column/row shifted into place) plus the offset of this BD's
    // register group: 0x20 bytes, i.e. eight 32-bit words, per BD.
    auto bd_id = op.getBdId();
    uint32_t bd_addr = (op.getColumn() << tm.getColumnShift()) |
                       (op.getRow() << tm.getRowShift()) |
                       (0x1D000 + bd_id * 0x20);

    std::vector<uint32_t> words(8, 0);

    // DMA_BDX_0
    words[0] = op.getBufferLength();

    // DMA_BDX_1
    words[1] = op.getBufferOffset();

    // DMA_BDX_2
    // En Packet , OoO BD ID , Packet ID , Packet Type
    words[2] |= (op.getEnablePacket() & 0x1) << 30;
    words[2] |= (op.getOutOfOrderId() & 0x3f) << 24;
    words[2] |= (op.getPacketId() & 0x1f) << 19;
    words[2] |= (op.getPacketType() & 0x7) << 16;

    // DMA_BDX_3
    // TODO: Secure Access
    words[3] |= (op.getD0Size() & 0x3ff) << 20;
    words[3] |= op.getD0Stride() & 0xfffff;

    // DMA_BDX_4
    words[4] = 0x80000000; // burst length;
    words[4] |= (op.getD1Size() & 0x3ff) << 20;
    words[4] |= op.getD1Stride() & 0xfffff;

    // DMA_BDX_5
    // TODO: SIMID, AxCache, AXQoS
    words[5] = op.getD2Stride() & 0xfffff;

    // DMA_BDX_6
    words[6] |= (op.getIterationCurrent() & 0x3f) << 26;
    words[6] |= (op.getIterationSize() & 0x3f) << 20;
    words[6] |= op.getIterationStride() & 0xfffff;

    // DMA_BDX_7
    // TODO: TLAST Suppress
    // The lock values are 7-bit fields: the release value sits in bits
    // [24:18] directly below valid_bd (bit 25) and the acquire value in bits
    // [11:5] directly below lock_acq_enable (bit 12). A wider mask would leak
    // into those neighbouring flags, so mask with 0x7f.
    words[7] |= (op.getNextBd() & 0xf) << 27;
    words[7] |= (op.getUseNextBd() & 0x1) << 26;
    words[7] |= (op.getValidBd() & 0x1) << 25;
    words[7] |= (op.getLockRelVal() & 0x7f) << 18;
    words[7] |= (op.getLockRelId() & 0xf) << 13;
    words[7] |= (op.getLockAcqEnable() & 0x1) << 12;
    words[7] |= (op.getLockAcqVal() & 0x7f) << 5;
    words[7] |= op.getLockAcqId() & 0xf;

    MemRefType memrefType = MemRefType::get({8}, rewriter.getI32Type());
    TensorType tensorType = RankedTensorType::get({8}, rewriter.getI32Type());
    memref::GlobalOp global = nullptr;
    {
      // Create the global next to the function, restoring the insertion
      // point afterwards so the replacement ops land at the original op.
      OpBuilder::InsertionGuard guard(rewriter);
      std::string name = "blockwrite_data_";
      rewriter.setInsertionPoint(parentFunc);
      // Pick the first numeric suffix not already used by a symbol in the
      // device.
      int id = 0;
      while (dev.lookupSymbol(name + std::to_string(id)))
        id++;
      name += std::to_string(id);
      global = rewriter.create<memref::GlobalOp>(
          op->getLoc(), name, rewriter.getStringAttr("private"), memrefType,
          DenseElementsAttr::get<uint32_t>(tensorType, words), true, nullptr);
    }
    auto memref = rewriter.create<memref::GetGlobalOp>(op->getLoc(), memrefType,
                                                       global.getName());
    (void)rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(
        op, memref.getResult(), rewriter.getUI32IntegerAttr(bd_addr));
    return success();
  }
};
380467

381468
struct AIEDmaToNpuPass : AIEDmaToNpuBase<AIEDmaToNpuPass> {
469+
470+
void getDependentDialects(DialectRegistry &registry) const override {
471+
registry.insert<memref::MemRefDialect>();
472+
}
473+
382474
void runOnOperation() override {
383475

384476
ShimDMAllocationGetter cachingGetter;
@@ -387,18 +479,22 @@ struct AIEDmaToNpuPass : AIEDmaToNpuBase<AIEDmaToNpuPass> {
387479

388480
ConversionTarget target(getContext());
389481
target.addLegalDialect<AIEXDialect>();
482+
target.addLegalDialect<memref::MemRefDialect>();
390483
target.addLegalOp<AIE::BufferOp>();
391484
target.addLegalOp<AIE::ShimDMAAllocationOp>();
392-
target.addIllegalOp<NpuWriteRTPOp>();
485+
393486
target.addIllegalOp<NpuDmaMemcpyNdOp>();
394487
target.addIllegalOp<NpuDmaWaitOp>();
395488
target.addIllegalOp<NpuPushQueueOp>();
489+
target.addIllegalOp<NpuWriteRTPOp>();
490+
target.addIllegalOp<NpuWriteBdOp>();
396491

397492
RewritePatternSet patterns(&getContext());
398493
patterns.insert<DmaToNpuPattern>(&getContext(), cachingGetter);
399-
patterns.insert<DmaWaitToNpuPattern>(&getContext(), cachingGetter);
400-
patterns.insert<PushToNpuPattern>(&getContext());
401-
patterns.insert<RtpToNpuPattern>(&getContext());
494+
patterns.insert<DmaWaitToSyncPattern>(&getContext(), cachingGetter);
495+
patterns.insert<PushQueuetoWrite32Pattern>(&getContext());
496+
patterns.insert<RtpToWrite32Pattern>(&getContext());
497+
patterns.insert<WriteBdToBlockWritePattern>(&getContext());
402498

403499
if (failed(applyPartialConversion(device, target, std::move(patterns))))
404500
signalPassFailure();

lib/Targets/AIETargetNPU.cpp

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -104,66 +104,50 @@ void appendAddressPatch(std::vector<uint32_t> &instructions,
104104
words[11] = 0;
105105
}
106106

107-
void appendWriteBdShimTile(std::vector<uint32_t> &instructions,
108-
NpuWriteBdOp op) {
107+
void appendBlockWrite(std::vector<uint32_t> &instructions, NpuBlockWriteOp op) {
109108

110-
auto words = reserveAndGetTail(instructions, 12);
111-
const AIETargetModel &tm = op->getParentOfType<DeviceOp>().getTargetModel();
109+
Value memref = op.getData();
110+
int64_t width = cast<MemRefType>(memref.getType()).getElementTypeBitWidth();
111+
if (width != 32) {
112+
op.emitWarning("Only 32-bit data type is supported for now");
113+
return;
114+
}
115+
116+
memref::GetGlobalOp getGlobal = memref.getDefiningOp<memref::GetGlobalOp>();
117+
if (!getGlobal) {
118+
op.emitError("Only MemRefs from memref.get_global are supported");
119+
return;
120+
}
121+
122+
auto global = dyn_cast_if_present<memref::GlobalOp>(
123+
op->getParentOfType<AIE::DeviceOp>().lookupSymbol(getGlobal.getName()));
124+
if (!global) {
125+
op.emitError("Global symbol not found");
126+
return;
127+
}
128+
129+
auto initVal = global.getInitialValue();
130+
if (!initVal) {
131+
op.emitError("Global symbol has no initial value");
132+
return;
133+
}
134+
135+
auto data = dyn_cast<DenseIntElementsAttr>(*initVal);
136+
if (!data) {
137+
op.emitError("Global symbol initial value is not a dense int array");
138+
return;
139+
}
112140

113141
// XAIE_IO_BLOCKWRITE
142+
auto words = reserveAndGetTail(instructions, data.size() + 4);
114143
words[0] = TXN_OPC_BLOCKWRITE;
115144
words[1] = 0;
116-
117-
// RegOff
118-
auto bd_id = op.getBdId();
119-
uint32_t bd_addr = (op.getColumn() << tm.getColumnShift()) |
120-
(op.getRow() << tm.getRowShift()) |
121-
(0x1D000 + bd_id * 0x20);
122-
words[2] = bd_addr; // ADDR
145+
words[2] = op.getAddress();
123146
words[3] = words.size() * sizeof(uint32_t); // Operation Size
124147

125-
// DMA_BDX_0
126-
words[4] = op.getBufferLength();
127-
128-
// DMA_BDX_1
129-
words[5] = op.getBufferOffset();
130-
131-
// DMA_BDX_2
132-
// En Packet , OoO BD ID , Packet ID , Packet Type
133-
words[6] |= (op.getEnablePacket() & 0x1) << 30;
134-
words[6] |= (op.getOutOfOrderId() & 0x3f) << 24;
135-
words[6] |= (op.getPacketId() & 0x1f) << 19;
136-
words[6] |= (op.getPacketType() & 0x7) << 16;
137-
138-
// DMA_BDX_3
139-
// TODO: Secure Access
140-
words[7] |= (op.getD0Size() & 0x3ff) << 20;
141-
words[7] |= op.getD0Stride() & 0xfffff;
142-
143-
// DMA_BDX_4
144-
words[8] = 0x80000000; // burst length;
145-
words[8] |= (op.getD1Size() & 0x3ff) << 20;
146-
words[8] |= op.getD1Stride() & 0xfffff;
147-
148-
// DMA_BDX_5
149-
// TODO: SIMID, AxCache, AXQoS
150-
words[9] = op.getD2Stride() & 0xfffff;
151-
152-
// DMA_BDX_6
153-
words[10] |= (op.getIterationCurrent() & 0x3f) << 26;
154-
words[10] |= (op.getIterationSize() & 0x3f) << 20;
155-
words[10] |= op.getIterationStride() & 0xfffff;
156-
157-
// DMA_BDX_7
158-
// TODO: TLAST Suppress
159-
words[11] |= (op.getNextBd() & 0xf) << 27;
160-
words[11] |= (op.getUseNextBd() & 0x1) << 26;
161-
words[11] |= (op.getValidBd() & 0x1) << 25;
162-
words[11] |= (op.getLockRelVal() & 0xef) << 18;
163-
words[11] |= (op.getLockRelId() & 0xf) << 13;
164-
words[11] |= (op.getLockAcqEnable() & 0x1) << 12;
165-
words[11] |= (op.getLockAcqVal() & 0xef) << 5;
166-
words[11] |= op.getLockAcqId() & 0xf;
148+
unsigned i = 4;
149+
for (auto d : data)
150+
words[i++] = d.getZExtValue();
167151
}
168152

169153
} // namespace
@@ -195,13 +179,13 @@ std::vector<uint32_t> xilinx::AIE::AIETranslateToNPU(ModuleOp module) {
195179
count++;
196180
appendWrite32(instructions, op);
197181
})
198-
.Case<NpuAddressPatchOp>([&](auto op) {
182+
.Case<NpuBlockWriteOp>([&](auto op) {
199183
count++;
200-
appendAddressPatch(instructions, op);
184+
appendBlockWrite(instructions, op);
201185
})
202-
.Case<NpuWriteBdOp>([&](auto op) {
186+
.Case<NpuAddressPatchOp>([&](auto op) {
203187
count++;
204-
appendWriteBdShimTile(instructions, op);
188+
appendAddressPatch(instructions, op);
205189
});
206190
}
207191
}

test/Conversion/DmaToNpu/aiert_insts.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
//===----------------------------------------------------------------------===//
88

99
// RUN: aie-opt --aie-dma-to-npu %s | FileCheck %s
10-
// CHECK: aiex.npu.writebd {bd_id = 1 : i32, buffer_length = 32 : i32, buffer_offset = 0 : i32, column = 0 : i32, d0_size = 0 : i32, d0_stride = 0 : i32, d1_size = 0 : i32, d1_stride = 0 : i32, d2_stride = 0 : i32, enable_packet = 0 : i32, iteration_current = 0 : i32, iteration_size = 0 : i32, iteration_stride = 0 : i32, lock_acq_enable = 0 : i32, lock_acq_id = 0 : i32, lock_acq_val = 0 : i32, lock_rel_id = 0 : i32, lock_rel_val = 0 : i32, next_bd = 0 : i32, out_of_order_id = 0 : i32, packet_id = 0 : i32, packet_type = 0 : i32, row = 0 : i32, use_next_bd = 0 : i32, valid_bd = 1 : i32}
10+
// CHECK: aiex.npu.blockwrite(%{{.*}}) {address = 118816 : ui32} : memref<8xi32>
1111
// CHECK: aiex.npu.write32 {address = 119300 : ui32, column = 0 : i32, row = 0 : i32, value = 2147483649 : ui32}
12-
// CHECK: aiex.npu.writebd {bd_id = 0 : i32, buffer_length = 32 : i32, buffer_offset = 128 : i32, column = 0 : i32, d0_size = 8 : i32, d0_stride = 0 : i32, d1_size = 2 : i32, d1_stride = 7 : i32, d2_stride = 15 : i32, enable_packet = 0 : i32, iteration_current = 0 : i32, iteration_size = 0 : i32, iteration_stride = 0 : i32, lock_acq_enable = 0 : i32, lock_acq_id = 0 : i32, lock_acq_val = 0 : i32, lock_rel_id = 0 : i32, lock_rel_val = 0 : i32, next_bd = 0 : i32, out_of_order_id = 0 : i32, packet_id = 0 : i32, packet_type = 0 : i32, row = 0 : i32, use_next_bd = 0 : i32, valid_bd = 1 : i32}
12+
// CHECK: aiex.npu.blockwrite(%{{.*}}) {address = 118784 : ui32} : memref<8xi32>
1313
// CHECK: aiex.npu.write32 {address = 119316 : ui32, column = 0 : i32, row = 0 : i32, value = 0 : ui32}
1414

1515
module {
@@ -24,8 +24,8 @@ module {
2424
%c8 = arith.constant 8 : i64
2525
%c16 = arith.constant 16 : i64
2626
%c32 = arith.constant 32 : i64
27-
aiex.npu.dma_memcpy_nd (0, 0, %out[%c0,%c0,%c0,%c0][%c1,%c1,%c1,%c32][%c0,%c0,%c0, %c1]) { metadata = @of_toMem, id = 1 : i64 } : memref<64xi32>
28-
aiex.npu.dma_memcpy_nd (0, 0, %in[%c0,%c2,%c0,%c0][%c1,%c2,%c2,%c8][%c0,%c16,%c8, %c1]) { metadata = @of_fromMem, id = 0 : i64 } : memref<4x2x8xi32>
27+
aiex.npu.dma_memcpy_nd (0, 0, %out[%c0,%c0,%c0,%c0][%c1,%c1,%c1,%c32][%c0,%c0,%c0, %c1]) { metadata = @of_toMem, id = 1 : i64, issue_token = true } : memref<64xi32>
28+
aiex.npu.dma_memcpy_nd (0, 0, %in[%c0,%c2,%c0,%c0][%c1,%c2,%c2,%c8][%c0,%c16,%c8, %c1]) { metadata = @of_fromMem, id = 0 : i64, issue_token = false } : memref<4x2x8xi32>
2929
return
3030
}
3131
aie.shim_dma_allocation @of_fromMem (MM2S, 0, 0)

0 commit comments

Comments
 (0)