3535using namespace xilinx ;
3636
3737namespace mlir ::iree_compiler::AMDAIE {
38-
39- // / Compute the 'global' repetition count: the product over all dimensions with
40- // / zero stride of the size of the dimension.
41- // /
42- // / The case where sizes and strides are empty is a special case, and '0' is
43- // / returned.
44- static int64_t getRepetitionCount (ArrayRef<OpFoldResult> sizes,
45- ArrayRef<OpFoldResult> strides) {
46- assert (sizes.size () == strides.size () &&
47- " expected stride and size vectors of same size" );
48- if (strides.empty ()) return 0 ;
49- size_t repetitionCount{1 };
50- for (uint32_t i = 0 ; i < strides.size (); ++i) {
51- if (!isConstantIntValue (strides[i], 0 )) continue ;
52- std::optional<int64_t > maybeSize = getConstantIntValue (sizes[i]);
53- assert (maybeSize.has_value () &&
54- " expected constant size in this zero stride dimension" );
55- assert (maybeSize.value () >= 0 && " expected a non-negative size" );
56- repetitionCount *= maybeSize.value ();
57- }
58- return repetitionCount;
59- }
60-
61- // / Utility to retrieve the common repetition count from all producers and
62- // / consumers of a logical objectFifo.
63- static FailureOr<size_t > getRepetitionCount (LogicalObjFifoOpInterface op) {
64- SmallVector<int64_t > repetitionCounts;
65- auto appendRepetitionCount = [&](ArrayRef<OpFoldResult> sizes,
66- ArrayRef<OpFoldResult> strides) {
67- size_t repetitionCount = getRepetitionCount (sizes, strides);
68- if (repetitionCount != 0 ) repetitionCounts.push_back (repetitionCount);
69- };
70-
71- for (Operation *userOp : op->getUsers ()) {
72- if (auto connectionOp = dyn_cast<AMDAIE::ConnectionOp>(userOp)) {
73- FailureOr<AMDAIE::NpuCircularDmaCpyNdOp> maybeNpuDmaUserOp =
74- connectionOp.getNpuCircularDmaCpyNdUser ();
75-
76- if (failed (maybeNpuDmaUserOp)) continue ;
77-
78- AMDAIE::NpuCircularDmaCpyNdOp npuDma = maybeNpuDmaUserOp.value ();
79-
80- if (connectionOp.getTarget () &&
81- dyn_cast_if_present<LogicalObjFifoOpInterface>(
82- connectionOp.getTarget ().getDefiningOp ()) == op) {
83- appendRepetitionCount (npuDma.getTargetMixedSizes (),
84- npuDma.getTargetMixedStrides ());
85- }
86-
87- if (connectionOp.getSource () &&
88- dyn_cast_if_present<LogicalObjFifoOpInterface>(
89- connectionOp.getSource ().getDefiningOp ()) == op) {
90- appendRepetitionCount (npuDma.getSourceMixedSizes (),
91- npuDma.getSourceMixedStrides ());
92- }
93- }
94- }
95-
96- // merge the repetition counts:
97- if (repetitionCounts.empty ()) return 1 ;
98- int64_t combinedRepetitionCount =
99- *std::min_element (repetitionCounts.begin (), repetitionCounts.end ());
100-
101- // if any of the repetition counts are not divisible by the combined
102- // repetition count, that's a problem:
103- if (!std::all_of (
104- repetitionCounts.begin (), repetitionCounts.end (),
105- [&](size_t c) { return c % combinedRepetitionCount == 0 ; })) {
106- return op.emitOpError ()
107- << " could not resolved a common repetition count based on the "
108- " individual repetition counts: "
109- << getArrayString<int64_t >(repetitionCounts);
110- }
111- return combinedRepetitionCount;
112- }
113-
11438// ===----------------------------------------------------------------------===//
11539// AIEDeviceBuilder utilities
11640// ===----------------------------------------------------------------------===//
@@ -334,44 +258,6 @@ void AIEDeviceBuilder::eraseOp(Operation *op) {
334258 rewriter.eraseOp (op);
335259}
336260
337- LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic (
338- SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
339- SmallVector<int64_t > &newSizes, SmallVector<int64_t > &newStrides,
340- size_t repetitionCount, uint8_t memSpace,
341- function_ref<InFlightDiagnostic()> emitError) {
342- if (failed (foldRepetitionCount (rewriter.getContext (), sizes, strides,
343- repetitionCount))) {
344- return emitError () << " could not fold repetition counts from sizes: "
345- << getConstantIntValuesString (sizes)
346- << " strides: " << getConstantIntValuesString (strides)
347- << " repetitionCount: " << repetitionCount << " ." ;
348- }
349- SmallVector<OpFoldResult> offsets (
350- strides.size (), getAsIndexOpFoldResult (rewriter.getContext (), 0 ));
351- (void )foldUnitDims (rewriter.getContext (), offsets, sizes, strides);
352-
353- DmaDimConfig dmaDimConfig (deviceModel, memSpace);
354- SmallVector<int64_t > maxSizes = dmaDimConfig.getMaxSizes (offsets.size ());
355- SmallVector<OpFoldResult> linearOffsets, linearSizes, linearStrides;
356- (void )foldLinearDims (
357- rewriter.getContext (), offsets, sizes, strides, linearOffsets,
358- linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) {
359- return idxFromEnd < maxSizes.size () &&
360- size <= maxSizes[maxSizes.size () - idxFromEnd - 1 ];
361- });
362- std::optional<SmallVector<int64_t >> maybeStaticSizes =
363- getConstantIntValues (linearSizes);
364- std::optional<SmallVector<int64_t >> maybeStaticStrides =
365- getConstantIntValues (linearStrides);
366- if (!maybeStaticSizes || !maybeStaticStrides) {
367- return emitError ()
368- << " found dynamic sizes or strides which is not supported" ;
369- }
370- newSizes = std::move (maybeStaticSizes.value ());
371- newStrides = std::move (maybeStaticStrides.value ());
372- return success ();
373- }
374-
375261void AIEDeviceBuilder::remapOperands (Operation *op) {
376262 for (int i = 0 ; i < op->getNumOperands (); ++i) {
377263 Value operand = op->getOperand (i);
@@ -577,6 +463,7 @@ LogicalResult AIEDeviceBuilder::bufferToAIE(AMDAIE::BufferOp bufferOp,
577463LogicalResult AIEDeviceBuilder::connectionToAIE (
578464 AMDAIE::ConnectionOp connectionOp, Block *deviceBlock,
579465 int &connectionIndex) {
466+ if (reprogramDmas) return success ();
580467 LLVM_DEBUG (llvm::dbgs () << " Convert [AMDAIE::ConnectionOp]\n " );
581468 rewriter.setInsertionPoint (deviceBlock->getTerminator ());
582469 SmallVector<AMDAIE::ChannelOp> producerChannels;
@@ -705,7 +592,7 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
705592 std::make_pair (consumerLocks[0 ], producerLocks[0 ]);
706593 SmallVector<int64_t > canonicalizedSizes, canonicalizedStrides;
707594 if (failed (foldDimsAndReturnAsStatic (
708- maybeNpuDmaUserOp->getSourceMixedSizes (),
595+ rewriter, deviceModel, maybeNpuDmaUserOp->getSourceMixedSizes (),
709596 maybeNpuDmaUserOp->getSourceMixedStrides (), canonicalizedSizes,
710597 canonicalizedStrides, repetitionCount.value (),
711598 maybeSourceMemSpace.value (),
@@ -810,7 +697,7 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
810697 std::make_pair (producerLocks[0 ], consumerLocks[0 ]);
811698 SmallVector<int64_t > canonicalizedSizes, canonicalizedStrides;
812699 if (failed (foldDimsAndReturnAsStatic (
813- maybeNpuDmaUserOp->getTargetMixedSizes (),
700+ rewriter, deviceModel, maybeNpuDmaUserOp->getTargetMixedSizes (),
814701 maybeNpuDmaUserOp->getTargetMixedStrides (), canonicalizedSizes,
815702 canonicalizedStrides, repetitionCount.value (),
816703 maybeTargetMemSpace.value (),
@@ -993,7 +880,7 @@ LogicalResult AIEDeviceBuilder::workgroupToAIE(AMDAIE::WorkgroupOp workgroupOp,
993880 if (failed (connectionToAIE (dmaOp, deviceBlock, connectionIndex))) {
994881 return WalkResult::interrupt ();
995882 }
996- return WalkResult::advance ();
883+ return WalkResult::skip ();
997884 })
998885 .Case <AMDAIE::ControlCodeOp>([&](auto controlCodeOp) {
999886 // Skip control code as it should already be translated into firmware
@@ -1178,6 +1065,9 @@ LogicalResult AIEDeviceBuilder::lowerToAIE(ModuleOp moduleOp) {
11781065class AMDAIELowerToAIEPass
11791066 : public impl::AMDAIELowerToAIEBase<AMDAIELowerToAIEPass> {
11801067 public:
1068+ AMDAIELowerToAIEPass (const AMDAIELowerToAIEOptions &options)
1069+ : AMDAIELowerToAIEBase(options) {}
1070+
11811071 void getDependentDialects (DialectRegistry ®istry) const override {
11821072 registry.insert <mlir::memref::MemRefDialect, AMDAIEDialect,
11831073 xilinx::AIE::AIEDialect, xilinx::AIEX::AIEXDialect>();
@@ -1196,13 +1086,15 @@ class AMDAIELowerToAIEPass
11961086 return signalPassFailure ();
11971087 }
11981088 AMDAIEDeviceModel deviceModel = getDeviceModel (maybeDevice.value ());
1199- AIEDeviceBuilder builder (moduleOp.getContext (), std::move (deviceModel));
1089+ AIEDeviceBuilder builder (moduleOp.getContext (), std::move (deviceModel),
1090+ reprogramDmas);
12001091 if (failed (builder.lowerToAIE (moduleOp))) return signalPassFailure ();
12011092 }
12021093};
12031094
1204- std::unique_ptr<Pass> createAMDAIELowerToAIEPass () {
1205- return std::make_unique<AMDAIELowerToAIEPass>();
1095+ std::unique_ptr<Pass> createAMDAIELowerToAIEPass (
1096+ AMDAIELowerToAIEOptions options) {
1097+ return std::make_unique<AMDAIELowerToAIEPass>(options);
12061098}
12071099
12081100} // namespace mlir::iree_compiler::AMDAIE
0 commit comments