Skip to content

Commit ae29ec5

Browse files
authored
Clean up and speed up Shim DMA Allocation Retrieval (#2644)
1 parent e4a47e6 commit ae29ec5

File tree

136 files changed

+132
-671
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

136 files changed

+132
-671
lines changed

include/aie/Dialect/AIE/IR/AIEDialect.h

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -200,23 +200,6 @@ int32_t getBufferBaseAddress(mlir::Operation *bufOp);
200200
#define GET_OP_CLASSES
201201
#include "aie/Dialect/AIE/IR/AIEOps.h.inc"
202202

203-
namespace xilinx::AIE {
204-
class DeviceOp;
205-
class ShimDMAAllocationOp;
206-
struct ShimDMAllocationGetter {
207-
public:
208-
std::optional<AIE::ShimDMAAllocationOp> get(DeviceOp dev,
209-
mlir::StringRef sym_name);
210-
211-
private:
212-
llvm::DenseMap<std::pair<DeviceOp, mlir::StringRef>,
213-
std::optional<AIE::ShimDMAAllocationOp>>
214-
allocGetter;
215-
std::optional<AIE::ShimDMAAllocationOp>
216-
cachelessGet(DeviceOp dev, mlir::StringRef sym_name);
217-
};
218-
} // namespace xilinx::AIE
219-
220203
namespace xilinx::AIE {
221204

222205
void collectTiles(DeviceOp &device,

include/aie/Dialect/AIE/IR/AIEOps.td

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1583,7 +1583,7 @@ def AIE_PutCascadeOp: AIE_Op<"put_cascade", []> {
15831583
let assemblyFormat = [{ `(` $cascade_value `:` type($cascade_value) `)` attr-dict }];
15841584
}
15851585

1586-
def AIE_ShimDMAAllocationOp : AIE_Op<"shim_dma_allocation", [HasParent<"DeviceOp">]> {
1586+
def AIE_ShimDMAAllocationOp : AIE_Op<"shim_dma_allocation", [HasParent<"DeviceOp">, Symbol]> {
15871587
let summary = "Runtime allocation information for a single shim DMA";
15881588
let description = [{
15891589
This op exists for cases where shim_dma configuration is performed outside of MLIR-AIE
@@ -1607,7 +1607,7 @@ def AIE_ShimDMAAllocationOp : AIE_Op<"shim_dma_allocation", [HasParent<"DeviceOp
16071607
}];
16081608

16091609
let arguments = (
1610-
ins FlatSymbolRefAttr:$sym_name,
1610+
ins SymbolNameAttr:$sym_name,
16111611
DMAChannelDir:$channel_dir,
16121612
AIEI64Attr:$channel_index,
16131613
AIEI64Attr:$col,
@@ -1626,15 +1626,6 @@ def AIE_ShimDMAAllocationOp : AIE_Op<"shim_dma_allocation", [HasParent<"DeviceOp
16261626
static ::xilinx::AIE::ShimDMAAllocationOp getForSymbol(::xilinx::AIE::DeviceOp device, ::llvm::StringRef symbol);
16271627
}];
16281628

1629-
let builders = [
1630-
OpBuilder<(ins "mlir::StringRef":$sym_name, "DMAChannelDir":$dir, "int":$channel_index,
1631-
"int":$col, "bool":$plio), [{
1632-
build($_builder, $_state, ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), sym_name),
1633-
DMAChannelDirAttr::get(odsBuilder.getContext(), dir),
1634-
$_builder.getI64IntegerAttr(channel_index),
1635-
$_builder.getI64IntegerAttr(col), $_builder.getBoolAttr(plio), nullptr);
1636-
}]>
1637-
];
16381629
}
16391630

16401631
def AIE_ObjectFifoCreateOp: AIE_Op<"objectfifo", [HasParent<"DeviceOp">, Symbol]> {

include/aie/Dialect/AIEX/IR/AIEX.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -690,7 +690,7 @@ def AIE_NpuDmaMemcpyNdOp: AIEX_Op<"npu.dma_memcpy_nd", [
690690
ConfinedAttr<DenseI64ArrayAttr, [DenseArrayCount<4>]>:$static_sizes,
691691
ConfinedAttr<DenseI64ArrayAttr, [DenseArrayCount<4>]>:$static_strides,
692692
OptionalAttr<PacketInfoAttr>:$packet,
693-
FlatSymbolRefAttr:$metadata,
693+
SymbolRefAttr:$metadata,
694694
I64Attr:$id,
695695
DefaultValuedOptionalAttr<BoolAttr, "false">:$issue_token,
696696
DefaultValuedOptionalAttr<I64Attr, "0">:$d0_zero_before,
@@ -1096,7 +1096,7 @@ def AIE_DMAConfigureTaskForOp : AIEX_Op<"dma_configure_task_for", [HasParent<"Ru
10961096
let summary = "As dma_configure_task, but specify tile, direction and channel by reference to a Shim DMA allocation op";
10971097

10981098
let arguments = (
1099-
ins FlatSymbolRefAttr:$alloc,
1099+
ins SymbolRefAttr:$alloc,
11001100
DefaultValuedOptionalAttr<BoolAttr, "false">:$issue_token,
11011101
DefaultValuedOptionalAttr<I32Attr, "0">:$repeat_count
11021102
);

lib/Dialect/AIE/IR/AIEDialect.cpp

Lines changed: 4 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -90,49 +90,6 @@ void AIEDialect::initialize() {
9090
addInterfaces<AIEInlinerInterface, AIEDialectFoldInterface>();
9191
}
9292

93-
// Helper class to get a ShimDMAAllocationOp for a given <device, symbol name>
94-
// pair. An object of this class is invalidated if, for any symbol_name, a
95-
// ShimDMAAllocationOp that uses it changes, as the cache is not updated in
96-
// this case.
97-
98-
// Return the first ShimDMAAllocationOp nested inside the DeviceOp 'dev' that
99-
// uses the symbol 'sym_name'
100-
std::optional<xilinx::AIE::ShimDMAAllocationOp>
101-
xilinx::AIE::ShimDMAllocationGetter::get(DeviceOp dev, StringRef sym_name) {
102-
auto key = std::make_pair(dev, sym_name);
103-
auto it = allocGetter.find(key);
104-
if (it != allocGetter.end()) {
105-
return it->second;
106-
}
107-
108-
// Call cachelessGet to look for the allocation operation
109-
auto allocOp = cachelessGet(dev, sym_name);
110-
111-
// Only cache the value if it's not empty (i.e., not std::nullopt)
112-
if (allocOp.has_value()) {
113-
allocGetter[key] = allocOp; // Cache it
114-
}
115-
116-
return allocOp; // Return the found or empty optional
117-
}
118-
119-
// Finding the ShimDMAAllocationOp for a given <DeviceOp, symbol_name> pair
120-
// can be slow when the symbol is used in many places. This version of the
121-
// function is only called when the cache does not have a ShimDMAAllocationOp
122-
// stored from a previous lookup.
123-
std::optional<xilinx::AIE::ShimDMAAllocationOp>
124-
xilinx::AIE::ShimDMAllocationGetter::cachelessGet(DeviceOp dev,
125-
mlir::StringRef sym_name) {
126-
auto *sym = dev.lookupSymbol(sym_name);
127-
if (!sym)
128-
return std::nullopt;
129-
auto uses = SymbolTable::getSymbolUses(sym, dev);
130-
for (auto use : *uses)
131-
if (auto infoOp = dyn_cast<AIE::ShimDMAAllocationOp>(use.getUser()))
132-
return infoOp;
133-
return std::nullopt;
134-
}
135-
13693
// Helper methods to retrieve the encoding associated to a burst length,
13794
// or to find the highest available burst length if the requested one is 0
13895
// (default value).
@@ -2367,11 +2324,10 @@ void BDChainOp::print(OpAsmPrinter &printer) {
23672324

23682325
ShimDMAAllocationOp ShimDMAAllocationOp::getForSymbol(DeviceOp device,
23692326
llvm::StringRef symbol) {
2370-
auto alloc_ops = device.getOps<ShimDMAAllocationOp>();
2371-
for (auto it = alloc_ops.begin(); it != alloc_ops.end(); ++it) {
2372-
AIE::ShimDMAAllocationOp a = *it;
2373-
if (a.getSymName() == symbol) {
2374-
return a;
2327+
Operation *maybeOp = device.lookupSymbol(symbol);
2328+
if (maybeOp) {
2329+
if (ShimDMAAllocationOp op = dyn_cast<ShimDMAAllocationOp>(maybeOp)) {
2330+
return op;
23752331
}
23762332
}
23772333
return nullptr;

lib/Dialect/AIE/Transforms/AIEGenerateColumnControlOverlay.cpp

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -363,13 +363,8 @@ struct AIEGenerateColumnControlOverlayPass
363363

364364
builder.create<AIE::ShimDMAAllocationOp>(
365365
builder.getUnknownLoc(), StringRef(dma_name), dir,
366-
rowToShimChanMap[tOp.rowIndex()], shimTile.colIndex(), false);
367-
MemRefType memref_ty = MemRefType::get(
368-
ArrayRef<int64_t>{2048}, IntegerType::get(builder.getContext(), 32),
369-
nullptr, 0);
370-
builder.create<memref::GlobalOp>(builder.getUnknownLoc(), dma_name,
371-
builder.getStringAttr("public"),
372-
memref_ty, nullptr, false, nullptr);
366+
rowToShimChanMap[tOp.rowIndex()], shimTile.colIndex(), false,
367+
nullptr);
373368
}
374369
}
375370

lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp

Lines changed: 39 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1635,20 +1635,27 @@ struct AIEObjectFifoStatefulTransformPass
16351635
/// shimDMAAllocationOp containing the channelDir, channelIndex and
16361636
/// shimTile col assigned by the objectFifo lowering.
16371637
void createObjectFifoAllocationInfo(OpBuilder &builder, MLIRContext *ctx,
1638-
FlatSymbolRefAttr obj_fifo, int colIndex,
1639-
DMAChannelDir channelDir,
1638+
ObjectFifoCreateOp &objFifoOp,
1639+
int colIndex, DMAChannelDir channelDir,
16401640
int channelIndex, bool plio,
16411641
std::optional<PacketInfoAttr> packet) {
16421642
PacketInfoAttr packetInfo = nullptr;
16431643
if (packet)
16441644
packetInfo = *packet;
1645-
builder.create<ShimDMAAllocationOp>(builder.getUnknownLoc(), obj_fifo,
1645+
std::string alloc_name = getShimAllocationName(objFifoOp.getName());
1646+
// SymbolRefAttr::get(ctx, objFifoOp.getName())
1647+
builder.create<ShimDMAAllocationOp>(builder.getUnknownLoc(),
1648+
StringAttr::get(ctx, alloc_name),
16461649
DMAChannelDirAttr::get(ctx, channelDir),
16471650
builder.getI64IntegerAttr(channelIndex),
16481651
builder.getI64IntegerAttr(colIndex),
16491652
builder.getBoolAttr(plio), packetInfo);
16501653
}
16511654

1655+
static std::string getShimAllocationName(llvm::StringRef objFifoName) {
1656+
return (objFifoName + "_shim_alloc").str();
1657+
}
1658+
16521659
/// Function used to verify that an objectfifo is present in at most one
16531660
/// ObjectFifoLinkOp.
16541661
void verifyObjectFifoLinks(DeviceOp &device) {
@@ -1908,9 +1915,9 @@ struct AIEObjectFifoStatefulTransformPass
19081915

19091916
if (producer.getProducerTileOp().isShimTile())
19101917
createObjectFifoAllocationInfo(
1911-
builder, ctx, SymbolRefAttr::get(ctx, producer.getName()),
1912-
producer.getProducerTileOp().colIndex(), producerChan.direction,
1913-
producerChan.channel, producer.getPlio(), bdPacket);
1918+
builder, ctx, producer, producer.getProducerTileOp().colIndex(),
1919+
producerChan.direction, producerChan.channel, producer.getPlio(),
1920+
bdPacket);
19141921

19151922
PacketFlowOp packetflow;
19161923
if (clPacketSwObjectFifos) {
@@ -1963,9 +1970,9 @@ struct AIEObjectFifoStatefulTransformPass
19631970

19641971
if (consumer.getProducerTileOp().isShimTile())
19651972
createObjectFifoAllocationInfo(
1966-
builder, ctx, SymbolRefAttr::get(ctx, producer.getName()),
1967-
consumer.getProducerTileOp().colIndex(), consumerChan.direction,
1968-
consumerChan.channel, producer.getPlio(), {});
1973+
builder, ctx, producer, consumer.getProducerTileOp().colIndex(),
1974+
consumerChan.direction, consumerChan.channel, producer.getPlio(),
1975+
{});
19691976

19701977
if (!clPacketSwObjectFifos) {
19711978
// create flow
@@ -2233,34 +2240,41 @@ struct AIEObjectFifoStatefulTransformPass
22332240
if (res.wasInterrupted())
22342241
return signalPassFailure();
22352242
}
2236-
// make global symbols to replace the to be erased ObjectFifoCreateOps
2237-
for (auto createOp : device.getOps<ObjectFifoCreateOp>()) {
2238-
builder.setInsertionPointToStart(device.getBody());
2239-
auto sym_name = createOp.getName();
2240-
createOp->setAttr(SymbolTable::getSymbolAttrName(),
2241-
builder.getStringAttr("__erase_" + sym_name));
2242-
auto memrefType = llvm::cast<AIEObjectFifoType>(createOp.getElemType())
2243-
.getElementType();
2244-
builder.create<memref::GlobalOp>(builder.getUnknownLoc(), sym_name,
2245-
builder.getStringAttr("public"),
2246-
memrefType, nullptr, false, nullptr);
2247-
}
22482243

22492244
//===------------------------------------------------------------------===//
22502245
// Remove old ops
22512246
//===------------------------------------------------------------------===//
22522247
SetVector<Operation *> opsToErase;
22532248
device.walk([&](Operation *op) {
2254-
if (isa<ObjectFifoCreateOp, ObjectFifoLinkOp,
2255-
ObjectFifoRegisterExternalBuffersOp, ObjectFifoAcquireOp,
2256-
ObjectFifoSubviewAccessOp, ObjectFifoReleaseOp,
2257-
ObjectFifoAllocateOp>(op))
2249+
if (isa<ObjectFifoLinkOp, ObjectFifoRegisterExternalBuffersOp,
2250+
ObjectFifoAcquireOp, ObjectFifoSubviewAccessOp,
2251+
ObjectFifoReleaseOp, ObjectFifoAllocateOp>(op))
22582252
opsToErase.insert(op);
22592253
});
22602254
SmallVector<Operation *> sorted{opsToErase.begin(), opsToErase.end()};
22612255
computeTopologicalSorting(sorted);
22622256
for (auto *op : llvm::reverse(sorted))
22632257
op->erase();
2258+
2259+
//===------------------------------------------------------------------===//
2260+
// Replace any remaining uses of object fifo symbol with symbol of its shim
2261+
// dma allocation.
2262+
//===------------------------------------------------------------------===//
2263+
opsToErase.clear();
2264+
for (auto createOp : device.getOps<ObjectFifoCreateOp>()) {
2265+
std::string shimAllocName = getShimAllocationName(createOp.getName());
2266+
if (failed(SymbolTable::replaceAllSymbolUses(
2267+
createOp.getNameAttr(), builder.getStringAttr(shimAllocName),
2268+
device))) {
2269+
createOp.emitError(
2270+
"failed to replace symbol uses with shim allocation");
2271+
return signalPassFailure();
2272+
}
2273+
opsToErase.insert(createOp);
2274+
}
2275+
for (auto *op : opsToErase) {
2276+
op->erase();
2277+
}
22642278
}
22652279
};
22662280

lib/Dialect/AIEX/IR/AIEXDialect.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -453,10 +453,10 @@ LogicalResult AIEX::NpuDmaMemcpyNdOp::verify() {
453453
// even if it exceeds the maximum stride/wrap size of any one dimension,
454454
// and simply do not lower any data layout transformations, since there is
455455
// no other way to express this at the dma_memcpy_nd interface otherwise.
456-
AIE::ShimDMAllocationGetter allocGetter;
457456
AIE::DeviceOp dev = getOperation()->getParentOfType<AIE::DeviceOp>();
458-
if (auto allocOp = allocGetter.get(dev, getMetadata())) {
459-
int col = allocOp->getCol();
457+
if (auto allocOp = AIE::ShimDMAAllocationOp::getForSymbol(
458+
dev, getMetadata().getRootReference())) {
459+
int col = allocOp.getCol();
460460
bool skipTransformationChecks = isLinearTransferWithoutTransformation();
461461
if (failed(verifyStridesWraps(*this, buffer, col, 0, inputSizes,
462462
inputStrides, hardwareSizes, hardwareStrides,

lib/Dialect/AIEX/Transforms/AIECtrlPacketToDma.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,8 @@ struct AIECtrlPacketToDmaPass : AIECtrlPacketToDmaBase<AIECtrlPacketToDmaPass> {
160160
const std::vector<int64_t> staticSizes = {1, 1, 1, batchIt->totalSize};
161161
const std::vector<int64_t> staticStrides = {0, 0, 0, 1};
162162

163-
StringRef metadata = builder.getStringAttr(batchIt->shimDmaAllocName);
163+
SymbolRefAttr metadata =
164+
SymbolRefAttr::get(builder.getContext(), batchIt->shimDmaAllocName);
164165
builder.create<NpuDmaMemcpyNdOp>(
165166
builder.getUnknownLoc(), newBlockArg, SmallVector<Value>{},
166167
SmallVector<Value>{}, SmallVector<Value>{}, ArrayRef(staticOffsets),

0 commit comments

Comments
 (0)