Skip to content

Commit f31b80e

Browse files
committed
Add if checks for output waiting, succeed in opRepr fixup of multibuffer buffs and futures
1 parent 838456d commit f31b80e

File tree

1 file changed

+37
-12
lines changed

1 file changed

+37
-12
lines changed

Deeploy/TilingExtension/CodeTransformationPasses/DoubleBufferingTilingCodeGeneration.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class DoubleBufferingTilingCodeGeneration(TilingCodeGeneration):
2222
_lineComment = NodeTemplate("\n// ${comment}")
2323

2424
_moveTileInCheckOpenStatement = NodeTemplate("if ((${tileIdxVar}) < ${numTiles}[*${tileIdxPtr}+1]) {")
25+
_waitTileOutCheckOpenStatment = NodeTemplate("if ((${tileIdxVar}) >= ${bufferCount}) {")
2526

2627
# LMACAN: The brackets around ${tileIdxVar} are important to ensure correct order
2728
# of the modulo operation. Breaking case without the brackets is when we
@@ -91,24 +92,40 @@ def _ioCalls(
9192
externalBufferShape,
9293
override_type = VoidType)
9394

95+
bufferFutures = [self.dma.getFuture(tensorName, direction, copyIdx = i) for i in range(self.bufferCount)]
96+
9497
# TODO: The generate dma transfer calls cannot be called multiple times in the loop because it hoists
9598
# but it has to accept a future so now I'm stuck. I'm thinking hoisting shouldn't pollute the
9699
# function that creates the call... or I do it in a hacky way where I overwrite the *future*
97100
# keyword in the opRepr in the second call...
98-
transferCalls = self._generateDmaTransferCalls(ctxt, tensorName, rectangles, tileIdxVar, buff,
99-
externalBufferRef, direction, future)
100-
101-
caseBlocks = []
102-
for i, buff in enumerate(multibufferMap[tensorName]):
103-
future = self.dma.getFuture(tensorName, direction, copyIdx = i)
104-
futures[i].add(future)
105-
106-
block = [future.alloc()]
107-
#block.extend() TODO: add transferCalls
108-
caseBlocks.append(block)
101+
# UPDATE: Inside CodeSnippets don't have the same name for local/external buffers and I don't want to change that.
102+
# ... so back to the separation of codegen and hoisting.
103+
# UPDATE: Or! We search all the opReprs for the initial value and swap them
104+
firstTransferCalls = self._generateDmaTransferCalls(ctxt, tensorName, rectangles, tileIdxVar,
105+
multibufferMap[tensorName][0], externalBufferRef,
106+
direction, bufferFutures[0])
107+
108+
caseBlocks = [[bufferFutures[0].alloc()] + firstTransferCalls]
109+
for future, buff in zip(bufferFutures[1:], multibufferMap[tensorName][1:]):
110+
transferCalls = []
111+
for call in firstTransferCalls:
112+
# LMACAN: Fixup the operatorRepresentation
113+
opRepr = {}
114+
for key, value in call.operatorRepresentation.items():
115+
if value == bufferFutures[0].name:
116+
opRepr[key] = future.name
117+
elif value == multibufferMap[tensorName][0].name:
118+
opRepr[key] = buff.name
119+
else:
120+
opRepr[key] = value
121+
transferCalls.append(CodeSnippet(call.template, opRepr))
122+
caseBlocks.append([future.alloc()] + transferCalls)
109123

110124
calls.extend(self._switch(caseBlocks, tileIdxVar))
111125

126+
for futureSet, future in zip(futures, bufferFutures):
127+
futureSet.add(future)
128+
112129
referenceUpdate = self._generateExternalReferenceUpdate(ctxt, tensorName, rectangles, tileIdxVar,
113130
externalBufferRef)
114131
if referenceUpdate is not None:
@@ -190,8 +207,14 @@ def _tilingLoop(self, ctxt: NetworkContext, executionBlock: ExecutionBlock,
190207
egressDmaTransferCalls += egressCalls
191208

192209
egressDmaWaitStatements = [CodeSnippet(self._lineComment, {"comment": "OUTPUT WAITING"})]
210+
egressDmaWaitStatements.append(
211+
CodeSnippet(self._waitTileOutCheckOpenStatment, {
212+
"tileIdxVar": "TILING_I",
213+
"bufferCount": self.bufferCount
214+
}))
193215
egressDmaWaitStatements += self._switch([[f.wait() for f in futureSet] for futureSet in egressFutures],
194216
"TILING_I")
217+
egressDmaWaitStatements.append(CodeSnippet(self._blockClose, {}))
195218

196219
allFutures = set()
197220
for futureSet in ingressFutures + egressFutures:
@@ -202,7 +225,9 @@ def _tilingLoop(self, ctxt: NetworkContext, executionBlock: ExecutionBlock,
202225
setupStatements.extend(firstIngressCalls)
203226

204227
teardownStatements: List[CodeSnippet] = []
205-
teardownStatements.extend([f.wait() for futureSet in egressFutures for f in futureSet])
228+
totalNumTiles = len(tilingSchedule.inputLoadSchedule) # TODO: Is this the best way to do this?
229+
remainingEgressFutures = egressFutures[:min(totalNumTiles, len(egressFutures))]
230+
teardownStatements.extend([f.wait() for futureSet in remainingEgressFutures for f in futureSet])
206231
teardownStatements.extend(f.deinit() for f in allFutures)
207232

208233
closeLoopStatements = [CodeSnippet(self._closeTileLoopTemplate, {**operatorRepresentation})]

0 commit comments

Comments
 (0)