Skip to content

Commit fd89c96

Browse files
authored
Add a canonicalizer for dma bds (#1948)
1 parent a0b89ad commit fd89c96

File tree

3 files changed

+219
-2
lines changed

3 files changed

+219
-2
lines changed

include/aie/Dialect/AIE/IR/AIEOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,8 @@ def AIE_DMAStartOp: AIE_Op<"dma_start", [
10081008
bool isSend() { return getChannelDir() == DMAChannelDir::MM2S; }
10091009
bool isRecv() { return getChannelDir() == DMAChannelDir::S2MM; }
10101010
}];
1011+
1012+
let hasCanonicalizer = 1;
10111013
}
10121014

10131015
def AIE_DMAOp: AIE_Op<"dma", [

lib/Dialect/AIE/IR/AIEDialect.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1995,6 +1995,128 @@ int MemTileDMAOp::colIndex() { return getTileOp().colIndex(); }
19951995

19961996
int MemTileDMAOp::rowIndex() { return getTileOp().rowIndex(); }
19971997

1998+
//===----------------------------------------------------------------------===//
1999+
// DMAStartOp
2000+
//===----------------------------------------------------------------------===//
2001+
2002+
static LogicalResult FoldDMAStartOp(DMAStartOp op, PatternRewriter &rewriter) {
2003+
2004+
llvm::SetVector<Block *> reachable;
2005+
SmallVector<Block *, 16> worklist;
2006+
Block *firstBD = op.getSuccessor(0);
2007+
reachable.insert(firstBD);
2008+
worklist.push_back(firstBD);
2009+
while (!worklist.empty()) {
2010+
Block *block = worklist.pop_back_val();
2011+
if (block->empty())
2012+
continue;
2013+
auto successors = block->getTerminator()->getSuccessors();
2014+
for (auto *i : successors) {
2015+
if (!reachable.contains(i)) {
2016+
reachable.insert(i);
2017+
worklist.push_back(i);
2018+
}
2019+
}
2020+
}
2021+
2022+
// BD chain ends with an EndOp, indicating non-repeating pattern: BD chain
2023+
// folding not applicable.
2024+
if (isa<EndOp>((reachable.back())->getTerminator()))
2025+
return failure();
2026+
2027+
// Check for identical bds.
2028+
auto areIdenticalUseLocks = [](UseLockOp op1, UseLockOp op2) {
2029+
if (!op1 || !op2)
2030+
return false;
2031+
if (op1.getLock() != op2.getLock())
2032+
return false;
2033+
if (op1.getAction() != op2.getAction())
2034+
return false;
2035+
if (op1.getValue() != op2.getValue())
2036+
return false;
2037+
return true;
2038+
};
2039+
auto areIdenticalDmaBDOps = [](DMABDOp op1, DMABDOp op2) {
2040+
if (!op1 || !op2)
2041+
return false;
2042+
if (op1.getBuffer() != op2.getBuffer())
2043+
return false;
2044+
if (op1.getOffset() != op2.getOffset())
2045+
return false;
2046+
if (op1.getLen() != op2.getLen())
2047+
return false;
2048+
if (op1.getDimensions() != op2.getDimensions())
2049+
return false;
2050+
if (op1.getPadDimensions() != op2.getPadDimensions())
2051+
return false;
2052+
if (op1.getPadValue() != op2.getPadValue())
2053+
return false;
2054+
if (op1.getPacket() != op2.getPacket())
2055+
return false;
2056+
return true;
2057+
};
2058+
auto areIdenticalBDs = [areIdenticalUseLocks,
2059+
areIdenticalDmaBDOps](Block *b1, Block *b2) {
2060+
auto b1OpRange = b1->without_terminator();
2061+
auto b2OpRange = b2->without_terminator();
2062+
if (llvm::range_size(b1OpRange) != llvm::range_size(b2OpRange))
2063+
return false;
2064+
auto b1It = b1OpRange.begin();
2065+
auto b2It = b2OpRange.begin();
2066+
while (b1It != b1OpRange.end()) {
2067+
if ((*b1It).getName().getStringRef() != (*b2It).getName().getStringRef())
2068+
return false;
2069+
2070+
if (auto b1UseLockOp = dyn_cast<UseLockOp>(*b1It)) {
2071+
auto b2UseLockOp = dyn_cast<UseLockOp>(*b2It);
2072+
if (!areIdenticalUseLocks(b1UseLockOp, b2UseLockOp))
2073+
return false;
2074+
} else if (auto b1DMABDOp = dyn_cast<DMABDOp>(*b1It)) {
2075+
auto b2DMABDOp = dyn_cast<DMABDOp>(*b2It);
2076+
if (!areIdenticalDmaBDOps(b1DMABDOp, b2DMABDOp))
2077+
return false;
2078+
}
2079+
2080+
b1It++;
2081+
b2It++;
2082+
}
2083+
return true;
2084+
};
2085+
2086+
// Get a vector of unique BDs.
2087+
SmallVector<Block *> uniquePattern;
2088+
auto patternIt = reachable.begin();
2089+
while (patternIt != reachable.end() &&
2090+
llvm::none_of(uniquePattern, [patternIt, areIdenticalBDs](Block *b1) {
2091+
return areIdenticalBDs(*patternIt, b1);
2092+
})) {
2093+
uniquePattern.push_back(*patternIt);
2094+
patternIt++;
2095+
}
2096+
2097+
unsigned idx = 0;
2098+
while (patternIt != reachable.end()) {
2099+
// BD repetition found. Check if repeating pattern.
2100+
if (!areIdenticalBDs(*patternIt, uniquePattern[idx]))
2101+
return failure();
2102+
patternIt++;
2103+
idx = (++idx) % uniquePattern.size();
2104+
}
2105+
2106+
// Repeating BD chains detected. Erasing repetitions.
2107+
auto lastBDTerm = dyn_cast<NextBDOp>(reachable.back()->getTerminator());
2108+
auto lastUniqueBDTerm =
2109+
dyn_cast<NextBDOp>(uniquePattern.back()->getTerminator());
2110+
lastUniqueBDTerm.setSuccessor(lastBDTerm.getSuccessor());
2111+
2112+
return success();
2113+
}
2114+
2115+
void DMAStartOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2116+
MLIRContext *context) {
2117+
patterns.add(FoldDMAStartOp);
2118+
}
2119+
19982120
//===----------------------------------------------------------------------===//
19992121
// SwitchboxOp
20002122
//===----------------------------------------------------------------------===//

test/dialect/AIE/canonicalize-mem.mlir

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,44 @@
2222
// CHECK-NEXT: ^bb3: // 2 preds: ^bb0, ^bb2
2323
// CHECK-NEXT: aie.end
2424
// CHECK-NEXT: }
25-
// CHECK-NEXT: }
25+
26+
// CHECK: %[[TILE_1_2:.*]] = aie.tile(1, 2)
27+
// CHECK-DAG: %[[BUF_0:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_0"} : memref<256xi32>
28+
// CHECK-DAG: %[[BUF_1:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_1"} : memref<256xi32>
29+
// CHECK-DAG: %[[BUF_2:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_2"} : memref<256xi32>
30+
// CHECK-DAG: %[[BUF_3:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_3"} : memref<256xi32>
31+
// CHECK-DAG: %[[LOCK_0:.*]] = aie.lock(%{{.*}}, 0)
32+
// CHECK: aie.mem(%[[TILE_1_2]]) {
33+
// CHECK-NEXT: %[[VAL_0:.*]] = aie.dma_start(MM2S, 0, ^bb2, ^bb1)
34+
// CHECK-NEXT: ^bb1: // pred: ^bb0
35+
// CHECK-NEXT: %[[VAL_1:.*]] = aie.dma_start(MM2S, 1, ^bb5, ^bb4)
36+
// CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb3
37+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
38+
// CHECK-NEXT: aie.dma_bd(%[[BUF_0]] : memref<256xi32>, 0, 256)
39+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
40+
// CHECK-NEXT: aie.next_bd ^bb3
41+
// CHECK-NEXT: ^bb3: // pred: ^bb2
42+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
43+
// CHECK-NEXT: aie.dma_bd(%[[BUF_1]] : memref<256xi32>, 0, 256)
44+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
45+
// CHECK-NEXT: aie.next_bd ^bb2
46+
// CHECK-NEXT: ^bb4: // pred: ^bb1
47+
// CHECK-NEXT: aie.end
48+
// CHECK-NEXT: ^bb5: // 2 preds: ^bb1, ^bb6
49+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
50+
// CHECK-NEXT: aie.dma_bd(%[[BUF_2]] : memref<256xi32>, 0, 128)
51+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
52+
// CHECK-NEXT: aie.next_bd ^bb6
53+
// CHECK-NEXT: ^bb6: // pred: ^bb5
54+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
55+
// CHECK-NEXT: aie.dma_bd(%[[BUF_2]] : memref<256xi32>, 128, 128)
56+
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
57+
// CHECK-NEXT: aie.next_bd ^bb5
2658

2759
module @test {
2860
%t1 = aie.tile(1, 1)
2961

30-
%mem13 = aie.mem(%t1) {
62+
%mem11 = aie.mem(%t1) {
3163
%dma0 = aie.dma_start("MM2S", 0, ^bd0, ^end)
3264
^bd0:
3365
aie.next_bd ^bd1 // point to the next BD, or termination
@@ -36,4 +68,65 @@ module @test {
3668
^end:
3769
aie.end
3870
}
71+
72+
73+
%t2 = aie.tile(1, 2)
74+
75+
%buf_0 = aie.buffer(%t2) { sym_name = "buf_0" } : memref<256xi32>
76+
%buf_1 = aie.buffer(%t2) { sym_name = "buf_1" } : memref<256xi32>
77+
%buf_2 = aie.buffer(%t2) { sym_name = "buf_2" } : memref<256xi32>
78+
%buf_3 = aie.buffer(%t2) { sym_name = "buf_3" } : memref<256xi32>
79+
80+
%lock_0 = aie.lock(%t2, 0)
81+
%lock_1 = aie.lock(%t2, 1)
82+
%lock_2 = aie.lock(%t2, 0)
83+
%lock_3 = aie.lock(%t2, 0)
84+
85+
%mem12 = aie.mem(%t2) {
86+
%start1 = aie.dma_start("MM2S", 0, ^bd0, ^dma0)
87+
^dma0:
88+
%start2 = aie.dma_start("MM2S", 1, ^bd4, ^end)
89+
^bd0:
90+
aie.use_lock(%lock_0, Acquire, 1)
91+
aie.dma_bd(%buf_0 : memref<256xi32>, 0, 256)
92+
aie.use_lock(%lock_0, Release, 0)
93+
aie.next_bd ^bd1
94+
^bd1:
95+
aie.use_lock(%lock_0, Acquire, 1)
96+
aie.dma_bd(%buf_1 : memref<256xi32>, 0, 256)
97+
aie.use_lock(%lock_0, Release, 0)
98+
aie.next_bd ^bd2
99+
^bd2:
100+
aie.use_lock(%lock_0, Acquire, 1)
101+
aie.dma_bd(%buf_0 : memref<256xi32>, 0, 256)
102+
aie.use_lock(%lock_0, Release, 0)
103+
aie.next_bd ^bd3
104+
^bd3:
105+
aie.use_lock(%lock_0, Acquire, 1)
106+
aie.dma_bd(%buf_1 : memref<256xi32>, 0, 256)
107+
aie.use_lock(%lock_0, Release, 0)
108+
aie.next_bd ^bd0
109+
^end:
110+
aie.end
111+
^bd4:
112+
aie.use_lock(%lock_0, Acquire, 1)
113+
aie.dma_bd(%buf_2 : memref<256xi32>, 0, 128)
114+
aie.use_lock(%lock_0, Release, 0)
115+
aie.next_bd ^bd5
116+
^bd5:
117+
aie.use_lock(%lock_0, Acquire, 1)
118+
aie.dma_bd(%buf_2 : memref<256xi32>, 128, 128)
119+
aie.use_lock(%lock_0, Release, 0)
120+
aie.next_bd ^bd6
121+
^bd6:
122+
aie.use_lock(%lock_0, Acquire, 1)
123+
aie.dma_bd(%buf_2 : memref<256xi32>, 0, 128)
124+
aie.use_lock(%lock_0, Release, 0)
125+
aie.next_bd ^bd7
126+
^bd7:
127+
aie.use_lock(%lock_0, Acquire, 1)
128+
aie.dma_bd(%buf_2 : memref<256xi32>, 128, 128)
129+
aie.use_lock(%lock_0, Release, 0)
130+
aie.next_bd ^bd4
131+
}
39132
}

0 commit comments

Comments
 (0)