@@ -70,10 +70,10 @@ struct ShimDMAllocationGetter {
7070};
7171} // namespace
7272
73- struct RtpToNpuPattern : OpConversionPattern<NpuWriteRTPOp> {
73+ struct RtpToWrite32Pattern : OpConversionPattern<NpuWriteRTPOp> {
7474 using OpConversionPattern::OpConversionPattern;
7575
76- RtpToNpuPattern (MLIRContext *context, PatternBenefit benefit = 1 )
76+ RtpToWrite32Pattern (MLIRContext *context, PatternBenefit benefit = 1 )
7777 : OpConversionPattern(context, benefit) {}
7878
7979 LogicalResult
@@ -118,12 +118,12 @@ struct RtpToNpuPattern : OpConversionPattern<NpuWriteRTPOp> {
118118 }
119119};
120120
121- struct PushToNpuPattern : OpConversionPattern<NpuPushQueueOp> {
121+ struct PushQueuetoWrite32Pattern : OpConversionPattern<NpuPushQueueOp> {
122122
123123public:
124124 using OpConversionPattern::OpConversionPattern;
125125
126- PushToNpuPattern (MLIRContext *context, PatternBenefit benefit = 1 )
126+ PushQueuetoWrite32Pattern (MLIRContext *context, PatternBenefit benefit = 1 )
127127 : OpConversionPattern(context, benefit) {}
128128
129129 LogicalResult
@@ -343,16 +343,16 @@ struct DmaToNpuPattern : OpConversionPattern<NpuDmaMemcpyNdOp> {
343343// / Convert NpuDmaWaitOp into NpuSyncOp by retrieving the necessary
344344// / information from the ShimDMAAllocationOp referenced through the
345345// / symbol argument of this op.
346- struct DmaWaitToNpuPattern : OpConversionPattern<NpuDmaWaitOp> {
346+ struct DmaWaitToSyncPattern : OpConversionPattern<NpuDmaWaitOp> {
347347
348348private:
349349 ShimDMAllocationGetter &allocGetter;
350350
351351public:
352352 using OpConversionPattern::OpConversionPattern;
353353
354- DmaWaitToNpuPattern (MLIRContext *context, ShimDMAllocationGetter &getter,
355- PatternBenefit benefit = 1 )
354+ DmaWaitToSyncPattern (MLIRContext *context, ShimDMAllocationGetter &getter,
355+ PatternBenefit benefit = 1 )
356356 : OpConversionPattern(context, benefit), allocGetter(getter) {}
357357
358358 LogicalResult
@@ -374,11 +374,103 @@ struct DmaWaitToNpuPattern : OpConversionPattern<NpuDmaWaitOp> {
374374 op, shimDmaAllocOp->getCol (), /* row */ 0 ,
375375 static_cast <uint32_t >(shimDmaAllocOp->getChannelDir ()),
376376 shimDmaAllocOp->getChannelIndex (), 1 , 1 );
377+
378+ return success ();
379+ }
380+ };
381+
382+ struct WriteBdToBlockWritePattern : OpConversionPattern<NpuWriteBdOp> {
383+ using OpConversionPattern::OpConversionPattern;
384+
385+ WriteBdToBlockWritePattern (MLIRContext *context, PatternBenefit benefit = 1 )
386+ : OpConversionPattern(context, benefit) {}
387+
388+ LogicalResult
389+ matchAndRewrite (NpuWriteBdOp op, OpAdaptor adaptor,
390+ ConversionPatternRewriter &rewriter) const override {
391+
392+ AIE::DeviceOp dev = op->getParentOfType <AIE::DeviceOp>();
393+ const AIE::AIETargetModel &tm = dev.getTargetModel ();
394+
395+ auto bd_id = op.getBdId ();
396+ uint32_t bd_addr = (op.getColumn () << tm.getColumnShift ()) |
397+ (op.getRow () << tm.getRowShift ()) |
398+ (0x1D000 + bd_id * 0x20 );
399+
400+ std::vector<uint32_t > words (8 , 0 );
401+
402+ // DMA_BDX_0
403+ words[0 ] = op.getBufferLength ();
404+
405+ // DMA_BDX_1
406+ words[1 ] = op.getBufferOffset ();
407+
408+ // DMA_BDX_2
409+ // En Packet , OoO BD ID , Packet ID , Packet Type
410+ words[2 ] |= (op.getEnablePacket () & 0x1 ) << 30 ;
411+ words[2 ] |= (op.getOutOfOrderId () & 0x3f ) << 24 ;
412+ words[2 ] |= (op.getPacketId () & 0x1f ) << 19 ;
413+ words[2 ] |= (op.getPacketType () & 0x7 ) << 16 ;
414+
415+ // DMA_BDX_3
416+ // TODO: Secure Access
417+ words[3 ] |= (op.getD0Size () & 0x3ff ) << 20 ;
418+ words[3 ] |= op.getD0Stride () & 0xfffff ;
419+
420+ // DMA_BDX_4
421+ words[4 ] = 0x80000000 ; // burst length;
422+ words[4 ] |= (op.getD1Size () & 0x3ff ) << 20 ;
423+ words[4 ] |= op.getD1Stride () & 0xfffff ;
424+
425+ // DMA_BDX_5
426+ // TODO: SIMID, AxCache, AXQoS
427+ words[5 ] = op.getD2Stride () & 0xfffff ;
428+
429+ // DMA_BDX_6
430+ words[6 ] |= (op.getIterationCurrent () & 0x3f ) << 26 ;
431+ words[6 ] |= (op.getIterationSize () & 0x3f ) << 20 ;
432+ words[6 ] |= op.getIterationStride () & 0xfffff ;
433+
434+ // DMA_BDX_7
435+ // TODO: TLAST Suppress
436+ words[7 ] |= (op.getNextBd () & 0xf ) << 27 ;
437+ words[7 ] |= (op.getUseNextBd () & 0x1 ) << 26 ;
438+ words[7 ] |= (op.getValidBd () & 0x1 ) << 25 ;
439+ words[7 ] |= (op.getLockRelVal () & 0xef ) << 18 ;
440+ words[7 ] |= (op.getLockRelId () & 0xf ) << 13 ;
441+ words[7 ] |= (op.getLockAcqEnable () & 0x1 ) << 12 ;
442+ words[7 ] |= (op.getLockAcqVal () & 0xef ) << 5 ;
443+ words[7 ] |= op.getLockAcqId () & 0xf ;
444+
445+ MemRefType memrefType = MemRefType::get ({8 }, rewriter.getI32Type ());
446+ TensorType tensorType = RankedTensorType::get ({8 }, rewriter.getI32Type ());
447+ memref::GlobalOp global = nullptr ;
448+ {
449+ OpBuilder::InsertionGuard guard (rewriter);
450+ std::string name = " blockwrite_data_" ;
451+ rewriter.setInsertionPoint (op->getParentOfType <func::FuncOp>());
452+ int id = 0 ;
453+ while (dev.lookupSymbol (name + std::to_string (id)))
454+ id++;
455+ name += std::to_string (id);
456+ global = rewriter.create <memref::GlobalOp>(
457+ op->getLoc (), name, rewriter.getStringAttr (" private" ), memrefType,
458+ DenseElementsAttr::get<uint32_t >(tensorType, words), true , nullptr );
459+ }
460+ auto memref = rewriter.create <memref::GetGlobalOp>(op->getLoc (), memrefType,
461+ global.getName ());
462+ (void )rewriter.replaceOpWithNewOp <NpuBlockWriteOp>(
463+ op, memref.getResult (), rewriter.getUI32IntegerAttr (bd_addr));
377464 return success ();
378465 }
379466};
380467
381468struct AIEDmaToNpuPass : AIEDmaToNpuBase<AIEDmaToNpuPass> {
469+
470+ void getDependentDialects (DialectRegistry ®istry) const override {
471+ registry.insert <memref::MemRefDialect>();
472+ }
473+
382474 void runOnOperation () override {
383475
384476 ShimDMAllocationGetter cachingGetter;
@@ -387,18 +479,22 @@ struct AIEDmaToNpuPass : AIEDmaToNpuBase<AIEDmaToNpuPass> {
387479
388480 ConversionTarget target (getContext ());
389481 target.addLegalDialect <AIEXDialect>();
482+ target.addLegalDialect <memref::MemRefDialect>();
390483 target.addLegalOp <AIE::BufferOp>();
391484 target.addLegalOp <AIE::ShimDMAAllocationOp>();
392- target. addIllegalOp <NpuWriteRTPOp>();
485+
393486 target.addIllegalOp <NpuDmaMemcpyNdOp>();
394487 target.addIllegalOp <NpuDmaWaitOp>();
395488 target.addIllegalOp <NpuPushQueueOp>();
489+ target.addIllegalOp <NpuWriteRTPOp>();
490+ target.addIllegalOp <NpuWriteBdOp>();
396491
397492 RewritePatternSet patterns (&getContext ());
398493 patterns.insert <DmaToNpuPattern>(&getContext (), cachingGetter);
399- patterns.insert <DmaWaitToNpuPattern>(&getContext (), cachingGetter);
400- patterns.insert <PushToNpuPattern>(&getContext ());
401- patterns.insert <RtpToNpuPattern>(&getContext ());
494+ patterns.insert <DmaWaitToSyncPattern>(&getContext (), cachingGetter);
495+ patterns.insert <PushQueuetoWrite32Pattern>(&getContext ());
496+ patterns.insert <RtpToWrite32Pattern>(&getContext ());
497+ patterns.insert <WriteBdToBlockWritePattern>(&getContext ());
402498
403499 if (failed (applyPartialConversion (device, target, std::move (patterns))))
404500 signalPassFailure ();
0 commit comments