@@ -22,6 +22,7 @@ class DoubleBufferingTilingCodeGeneration(TilingCodeGeneration):
2222 _lineComment = NodeTemplate ("\n // ${comment}" )
2323
2424 _moveTileInCheckOpenStatement = NodeTemplate ("if ((${tileIdxVar}) < ${numTiles}[*${tileIdxPtr}+1]) {" )
25+ _waitTileOutCheckOpenStatment = NodeTemplate ("if ((${tileIdxVar}) >= ${bufferCount}) {" )
2526
2627 # LMACAN: The brackets around ${tileIdxVar} are important to ensure correct order
2728 # of the modulo operation. Breaking case without the brackets is when we
@@ -91,24 +92,40 @@ def _ioCalls(
9192 externalBufferShape ,
9293 override_type = VoidType )
9394
95+ bufferFutures = [self .dma .getFuture (tensorName , direction , copyIdx = i ) for i in range (self .bufferCount )]
96+
9497 # TODO: The generate dma transfer calls cannot be called multiple times in the loop because it hoists
9598 # but it has to accept a future so now I'm stuck. I'm thinking hoisting shouldn't polute the
9699 # function that creates the call... or I do it in a hacky way where I overwrite the *future*
97100 # keyword in the opRepr in the second call...
98- transferCalls = self ._generateDmaTransferCalls (ctxt , tensorName , rectangles , tileIdxVar , buff ,
99- externalBufferRef , direction , future )
100-
101- caseBlocks = []
102- for i , buff in enumerate (multibufferMap [tensorName ]):
103- future = self .dma .getFuture (tensorName , direction , copyIdx = i )
104- futures [i ].add (future )
105-
106- block = [future .alloc ()]
107- #block.extend() TODO: add transferCalls
108- caseBlocks .append (block )
101+ # UPDATE: Inside CodeSnippets don't have the same name for local/external buffers and I don't want to change that.
102+ # ... so back to the seperation of codegen and hoisting.
103+ # UPDATA: Or! We search all the opReprs for the initial value and swap them
104+ firstTransferCalls = self ._generateDmaTransferCalls (ctxt , tensorName , rectangles , tileIdxVar ,
105+ multibufferMap [tensorName ][0 ], externalBufferRef ,
106+ direction , bufferFutures [0 ])
107+
108+ caseBlocks = [[bufferFutures [0 ].alloc ()] + firstTransferCalls ]
109+ for future , buff in zip (bufferFutures [1 :], multibufferMap [tensorName ][1 :]):
110+ transferCalls = []
111+ for call in firstTransferCalls :
112+ # LMACAN: Fixup the operatorRepresentation
113+ opRepr = {}
114+ for key , value in call .operatorRepresentation .items ():
115+ if value == bufferFutures [0 ].name :
116+ opRepr [key ] = future .name
117+ elif value == multibufferMap [tensorName ][0 ].name :
118+ opRepr [key ] = buff .name
119+ else :
120+ opRepr [key ] = value
121+ transferCalls .append (CodeSnippet (call .template , opRepr ))
122+ caseBlocks .append ([future .alloc ()] + transferCalls )
109123
110124 calls .extend (self ._switch (caseBlocks , tileIdxVar ))
111125
126+ for futureSet , future in zip (futures , bufferFutures ):
127+ futureSet .add (future )
128+
112129 referenceUpdate = self ._generateExternalReferenceUpdate (ctxt , tensorName , rectangles , tileIdxVar ,
113130 externalBufferRef )
114131 if referenceUpdate is not None :
@@ -190,8 +207,14 @@ def _tilingLoop(self, ctxt: NetworkContext, executionBlock: ExecutionBlock,
190207 egressDmaTransferCalls += egressCalls
191208
192209 egressDmaWaitStatements = [CodeSnippet (self ._lineComment , {"comment" : "OUTPUT WAITING" })]
210+ egressDmaWaitStatements .append (
211+ CodeSnippet (self ._waitTileOutCheckOpenStatment , {
212+ "tileIdxVar" : "TILING_I" ,
213+ "bufferCount" : self .bufferCount
214+ }))
193215 egressDmaWaitStatements += self ._switch ([[f .wait () for f in futureSet ] for futureSet in egressFutures ],
194216 "TILING_I" )
217+ egressDmaWaitStatements .append (CodeSnippet (self ._blockClose , {}))
195218
196219 allFutures = set ()
197220 for futureSet in ingressFutures + egressFutures :
@@ -202,7 +225,9 @@ def _tilingLoop(self, ctxt: NetworkContext, executionBlock: ExecutionBlock,
202225 setupStatements .extend (firstIngressCalls )
203226
204227 teardownStatements : List [CodeSnippet ] = []
205- teardownStatements .extend ([f .wait () for futureSet in egressFutures for f in futureSet ])
228+ totalNumTiles = len (tilingSchedule .inputLoadSchedule ) # TODO: Is this the best way to do this?
229+ remainingEgressFutures = egressFutures [:min (totalNumTiles , len (egressFutures ))]
230+ teardownStatements .extend ([f .wait () for futureSet in remainingEgressFutures for f in futureSet ])
206231 teardownStatements .extend (f .deinit () for f in allFutures )
207232
208233 closeLoopStatements = [CodeSnippet (self ._closeTileLoopTemplate , {** operatorRepresentation })]
0 commit comments