Skip to content

Commit 1760d70

Browse files
[Reprogram] Modify controlcode-lowering and lower-to-aie for DMA reprogramming (#1330)
-- This commit includes modifications to `controlcode-lowering` and `lower-to-aie` passes for DMA reprogramming. -- This is being added to AMDAIE dialect to make [DMA reprogramming](#1287) work. Signed-off-by: Abhishek Varma <[email protected]>
1 parent e4c1861 commit 1760d70

File tree

9 files changed

+691
-132
lines changed

9 files changed

+691
-132
lines changed

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEControlCodeLowering.cpp

Lines changed: 434 additions & 0 deletions
Large diffs are not rendered by default.

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp

Lines changed: 12 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -35,82 +35,6 @@
3535
using namespace xilinx;
3636

3737
namespace mlir::iree_compiler::AMDAIE {
38-
39-
/// Compute the 'global' repetition count: the product over all dimensions with
40-
/// zero stride of the size of the dimension.
41-
///
42-
/// The case where sizes and strides are empty is a special case, and '0' is
43-
/// returned.
44-
static int64_t getRepetitionCount(ArrayRef<OpFoldResult> sizes,
45-
ArrayRef<OpFoldResult> strides) {
46-
assert(sizes.size() == strides.size() &&
47-
"expected stride and size vectors of same size");
48-
if (strides.empty()) return 0;
49-
size_t repetitionCount{1};
50-
for (uint32_t i = 0; i < strides.size(); ++i) {
51-
if (!isConstantIntValue(strides[i], 0)) continue;
52-
std::optional<int64_t> maybeSize = getConstantIntValue(sizes[i]);
53-
assert(maybeSize.has_value() &&
54-
"expected constant size in this zero stride dimension");
55-
assert(maybeSize.value() >= 0 && "expected a non-negative size");
56-
repetitionCount *= maybeSize.value();
57-
}
58-
return repetitionCount;
59-
}
60-
61-
/// Utility to retrieve the common repetition count from all producers and
62-
/// consumers of a logical objectFifo.
63-
static FailureOr<size_t> getRepetitionCount(LogicalObjFifoOpInterface op) {
64-
SmallVector<int64_t> repetitionCounts;
65-
auto appendRepetitionCount = [&](ArrayRef<OpFoldResult> sizes,
66-
ArrayRef<OpFoldResult> strides) {
67-
size_t repetitionCount = getRepetitionCount(sizes, strides);
68-
if (repetitionCount != 0) repetitionCounts.push_back(repetitionCount);
69-
};
70-
71-
for (Operation *userOp : op->getUsers()) {
72-
if (auto connectionOp = dyn_cast<AMDAIE::ConnectionOp>(userOp)) {
73-
FailureOr<AMDAIE::NpuCircularDmaCpyNdOp> maybeNpuDmaUserOp =
74-
connectionOp.getNpuCircularDmaCpyNdUser();
75-
76-
if (failed(maybeNpuDmaUserOp)) continue;
77-
78-
AMDAIE::NpuCircularDmaCpyNdOp npuDma = maybeNpuDmaUserOp.value();
79-
80-
if (connectionOp.getTarget() &&
81-
dyn_cast_if_present<LogicalObjFifoOpInterface>(
82-
connectionOp.getTarget().getDefiningOp()) == op) {
83-
appendRepetitionCount(npuDma.getTargetMixedSizes(),
84-
npuDma.getTargetMixedStrides());
85-
}
86-
87-
if (connectionOp.getSource() &&
88-
dyn_cast_if_present<LogicalObjFifoOpInterface>(
89-
connectionOp.getSource().getDefiningOp()) == op) {
90-
appendRepetitionCount(npuDma.getSourceMixedSizes(),
91-
npuDma.getSourceMixedStrides());
92-
}
93-
}
94-
}
95-
96-
// merge the repetition counts:
97-
if (repetitionCounts.empty()) return 1;
98-
int64_t combinedRepetitionCount =
99-
*std::min_element(repetitionCounts.begin(), repetitionCounts.end());
100-
101-
// if any of the repetition counts are not divisible by the combined
102-
// repetition count, that's a problem:
103-
if (!std::all_of(
104-
repetitionCounts.begin(), repetitionCounts.end(),
105-
[&](size_t c) { return c % combinedRepetitionCount == 0; })) {
106-
return op.emitOpError()
107-
<< " could not resolved a common repetition count based on the "
108-
"individual repetition counts: "
109-
<< getArrayString<int64_t>(repetitionCounts);
110-
}
111-
return combinedRepetitionCount;
112-
}
113-
11438
//===----------------------------------------------------------------------===//
11539
// AIEDeviceBuilder utilities
11640
//===----------------------------------------------------------------------===//
@@ -334,44 +258,6 @@ void AIEDeviceBuilder::eraseOp(Operation *op) {
334258
rewriter.eraseOp(op);
335259
}
336260

337-
LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
338-
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
339-
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
340-
size_t repetitionCount, uint8_t memSpace,
341-
function_ref<InFlightDiagnostic()> emitError) {
342-
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides,
343-
repetitionCount))) {
344-
return emitError() << "could not fold repetition counts from sizes: "
345-
<< getConstantIntValuesString(sizes)
346-
<< " strides: " << getConstantIntValuesString(strides)
347-
<< " repetitionCount: " << repetitionCount << ".";
348-
}
349-
SmallVector<OpFoldResult> offsets(
350-
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
351-
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides);
352-
353-
DmaDimConfig dmaDimConfig(deviceModel, memSpace);
354-
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(offsets.size());
355-
SmallVector<OpFoldResult> linearOffsets, linearSizes, linearStrides;
356-
(void)foldLinearDims(
357-
rewriter.getContext(), offsets, sizes, strides, linearOffsets,
358-
linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) {
359-
return idxFromEnd < maxSizes.size() &&
360-
size <= maxSizes[maxSizes.size() - idxFromEnd - 1];
361-
});
362-
std::optional<SmallVector<int64_t>> maybeStaticSizes =
363-
getConstantIntValues(linearSizes);
364-
std::optional<SmallVector<int64_t>> maybeStaticStrides =
365-
getConstantIntValues(linearStrides);
366-
if (!maybeStaticSizes || !maybeStaticStrides) {
367-
return emitError()
368-
<< "found dynamic sizes or strides which is not supported";
369-
}
370-
newSizes = std::move(maybeStaticSizes.value());
371-
newStrides = std::move(maybeStaticStrides.value());
372-
return success();
373-
}
374-
375261
void AIEDeviceBuilder::remapOperands(Operation *op) {
376262
for (int i = 0; i < op->getNumOperands(); ++i) {
377263
Value operand = op->getOperand(i);
@@ -577,6 +463,7 @@ LogicalResult AIEDeviceBuilder::bufferToAIE(AMDAIE::BufferOp bufferOp,
577463
LogicalResult AIEDeviceBuilder::connectionToAIE(
578464
AMDAIE::ConnectionOp connectionOp, Block *deviceBlock,
579465
int &connectionIndex) {
466+
if (reprogramDmas) return success();
580467
LLVM_DEBUG(llvm::dbgs() << "Convert [AMDAIE::ConnectionOp]\n");
581468
rewriter.setInsertionPoint(deviceBlock->getTerminator());
582469
SmallVector<AMDAIE::ChannelOp> producerChannels;
@@ -705,7 +592,7 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
705592
std::make_pair(consumerLocks[0], producerLocks[0]);
706593
SmallVector<int64_t> canonicalizedSizes, canonicalizedStrides;
707594
if (failed(foldDimsAndReturnAsStatic(
708-
maybeNpuDmaUserOp->getSourceMixedSizes(),
595+
rewriter, deviceModel, maybeNpuDmaUserOp->getSourceMixedSizes(),
709596
maybeNpuDmaUserOp->getSourceMixedStrides(), canonicalizedSizes,
710597
canonicalizedStrides, repetitionCount.value(),
711598
maybeSourceMemSpace.value(),
@@ -810,7 +697,7 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
810697
std::make_pair(producerLocks[0], consumerLocks[0]);
811698
SmallVector<int64_t> canonicalizedSizes, canonicalizedStrides;
812699
if (failed(foldDimsAndReturnAsStatic(
813-
maybeNpuDmaUserOp->getTargetMixedSizes(),
700+
rewriter, deviceModel, maybeNpuDmaUserOp->getTargetMixedSizes(),
814701
maybeNpuDmaUserOp->getTargetMixedStrides(), canonicalizedSizes,
815702
canonicalizedStrides, repetitionCount.value(),
816703
maybeTargetMemSpace.value(),
@@ -993,7 +880,7 @@ LogicalResult AIEDeviceBuilder::workgroupToAIE(AMDAIE::WorkgroupOp workgroupOp,
993880
if (failed(connectionToAIE(dmaOp, deviceBlock, connectionIndex))) {
994881
return WalkResult::interrupt();
995882
}
996-
return WalkResult::advance();
883+
return WalkResult::skip();
997884
})
998885
.Case<AMDAIE::ControlCodeOp>([&](auto controlCodeOp) {
999886
// Skip control code as it should already be translated into firmware
@@ -1178,6 +1065,9 @@ LogicalResult AIEDeviceBuilder::lowerToAIE(ModuleOp moduleOp) {
11781065
class AMDAIELowerToAIEPass
11791066
: public impl::AMDAIELowerToAIEBase<AMDAIELowerToAIEPass> {
11801067
public:
1068+
AMDAIELowerToAIEPass(const AMDAIELowerToAIEOptions &options)
1069+
: AMDAIELowerToAIEBase(options) {}
1070+
11811071
void getDependentDialects(DialectRegistry &registry) const override {
11821072
registry.insert<mlir::memref::MemRefDialect, AMDAIEDialect,
11831073
xilinx::AIE::AIEDialect, xilinx::AIEX::AIEXDialect>();
@@ -1196,13 +1086,15 @@ class AMDAIELowerToAIEPass
11961086
return signalPassFailure();
11971087
}
11981088
AMDAIEDeviceModel deviceModel = getDeviceModel(maybeDevice.value());
1199-
AIEDeviceBuilder builder(moduleOp.getContext(), std::move(deviceModel));
1089+
AIEDeviceBuilder builder(moduleOp.getContext(), std::move(deviceModel),
1090+
reprogramDmas);
12001091
if (failed(builder.lowerToAIE(moduleOp))) return signalPassFailure();
12011092
}
12021093
};
12031094

1204-
std::unique_ptr<Pass> createAMDAIELowerToAIEPass() {
1205-
return std::make_unique<AMDAIELowerToAIEPass>();
1095+
std::unique_ptr<Pass> createAMDAIELowerToAIEPass(
1096+
AMDAIELowerToAIEOptions options) {
1097+
return std::make_unique<AMDAIELowerToAIEPass>(options);
12061098
}
12071099

12081100
} // namespace mlir::iree_compiler::AMDAIE

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ using BDDimLayoutAndLength = std::pair<AIE::BDDimLayoutArrayAttr, int64_t>;
3232
/// `amdaie.workgroup`.
3333
class AIEDeviceBuilder {
3434
public:
35-
AIEDeviceBuilder(MLIRContext *ctx, AMDAIEDeviceModel deviceModel)
36-
: rewriter(ctx), deviceModel(std::move(deviceModel)) {}
35+
AIEDeviceBuilder(MLIRContext *ctx, AMDAIEDeviceModel deviceModel,
36+
bool reprogramDmas)
37+
: rewriter(ctx),
38+
deviceModel(std::move(deviceModel)),
39+
reprogramDmas(reprogramDmas) {}
3740

3841
LogicalResult lowerToAIE(ModuleOp moduleOp);
3942

@@ -98,14 +101,6 @@ class AIEDeviceBuilder {
98101
/// might be used after `op` is erased.
99102
void eraseOp(Operation *op);
100103

101-
/// Utility to fold the provided repetition count, unit dims, linear dims and
102-
/// to convert the sizes and strides into static versions and return them.
103-
LogicalResult foldDimsAndReturnAsStatic(
104-
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
105-
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
106-
size_t repetitionCount, uint8_t memSpace,
107-
function_ref<InFlightDiagnostic()> emitError);
108-
109104
/// Utility to remap the provided operation's operands.
110105
void remapOperands(Operation *op);
111106

@@ -127,6 +122,9 @@ class AIEDeviceBuilder {
127122
connectionToSourceTargetMemOps;
128123
/// Map from connection ops to the flow ops they have been converted into.
129124
DenseMap<AMDAIE::ConnectionOp, SmallVector<Operation *>> connectionToFlowOps;
125+
/// Set using the pass' flag `reprogram-dmas` and is used to enable/disable
126+
/// reprogramming of DMAs.
127+
bool reprogramDmas;
130128
};
131129

132130
} // namespace mlir::iree_compiler::AMDAIE

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,13 @@ void addAMDAIEToAIEPasses(OpPassManager &passManager,
121121
passManager.addPass(createAMDAIESinkIntoCorePass());
122122
passManager.addPass(createCanonicalizerPass());
123123
passManager.addPass(createAMDAIEAddNoAliasFunctionArgumentsPass());
124-
passManager.addPass(createAMDAIELowerToAIEPass());
124+
{
125+
AMDAIELowerToAIEOptions options;
126+
// TODO(avarma): In follow-up PRs this will be replaced by a global flag.
127+
// Currently setting as `false`.
128+
options.reprogramDmas = /*reprogramDmas=*/false;
129+
passManager.addPass(createAMDAIELowerToAIEPass(options));
130+
}
125131
passManager.addPass(createAMDAIERemoveMemorySpacePass());
126132
passManager.addPass(createCanonicalizerPass());
127133
}

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ std::unique_ptr<Pass> createAMDAIELowerFuncArgsPass();
273273

274274
/// Create pass to lower from the AMDAIE dialect to the AIE/AIEX dialects.
275275
void addAMDAIEToAIEPasses(OpPassManager &);
276-
std::unique_ptr<Pass> createAMDAIELowerToAIEPass();
276+
std::unique_ptr<Pass> createAMDAIELowerToAIEPass(
277+
AMDAIELowerToAIEOptions options = {});
277278

278279
/// Create pass to lower a sequence of operation(s) to a iree_codegen.ukernel.*
279280
/// operation.

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ def AMDAIEControlCodeLowering :
215215
let options = [
216216
Option<"argIdxOffset", "arg-idx-offset", "int32_t", /*default=*/"0",
217217
"The offset to be added to the argument index.">,
218+
Option<"reprogramDmas", "reprogram-dmas", "bool", /*default=*/"false",
219+
"Flag to reprogram DMAs. When enabled, no circular DMAs will be produced">,
218220
];
219221
}
220222

@@ -639,6 +641,10 @@ def AMDAIELowerToAIE :
639641
Pass<"iree-amdaie-lower-to-aie", "ModuleOp"> {
640642
let summary = "Lower from the AMDAIE dialect to the AIE/AIEX dialects";
641643
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIELowerToAIEPass()";
644+
let options = [
645+
Option<"reprogramDmas", "reprogram-dmas", "bool", /*default=*/"false",
646+
"Flag to reprogram DMAs. When enabled, no circular DMAs will be produced">,
647+
];
642648
}
643649

644650
def AMDAIELowerToUKernels :

0 commit comments

Comments
 (0)