1616class Future :
1717
1818 _initTemplate : NodeTemplate
19+ _allocTemplate : NodeTemplate
1920 _deinitTemplate : NodeTemplate
2021 _waitTemplate : NodeTemplate
2122
@@ -28,6 +29,9 @@ def _operatorRepresentation(self, comment: str = "") -> OperatorRepresentation:
2829 def init (self , comment : str = "" ) -> CodeSnippet :
2930 return CodeSnippet (self ._initTemplate , self ._operatorRepresentation (comment ))
3031
32+ def alloc (self , comment : str = "" ) -> CodeSnippet :
33+ return CodeSnippet (self ._allocTemplate , self ._operatorRepresentation (comment ))
34+
3135 def deinit (self , comment : str = "" ) -> CodeSnippet :
3236 return CodeSnippet (self ._deinitTemplate , self ._operatorRepresentation (comment ))
3337
@@ -105,12 +109,12 @@ def transfer(self,
105109 strideLoc : Tuple [int , ...],
106110 direction : DmaDirection ,
107111 future : Future ,
108- comment : str = "" ) -> List [CodeSnippet ]:
112+ comment : str = "" ) -> Tuple [ List [CodeSnippet ], List [ CodeSnippet ], List [ CodeSnippet ] ]:
109113 self .checkTransfer (ctxt , externalBuffer , localBuffer , shape , strideExt , strideLoc , direction )
110114 opRepr = self .transferOpRepr (externalBuffer , localBuffer , shape , strideExt , strideLoc , direction , future ,
111115 comment )
112116 template = self ._transferTemplates [len (shape )]
113- return [CodeSnippet (template , opRepr )]
117+ return [future . alloc ( comment )], [ CodeSnippet (template , opRepr )], [ ]
114118
115119 def setup (self ) -> List [CodeSnippet ]:
116120 return []
@@ -122,6 +126,7 @@ def teardown(self) -> List[CodeSnippet]:
122126class EmptyFuture (Future ):
123127
124128 _initTemplate = NodeTemplate ("" )
129+ _allocTemplate = NodeTemplate ("" )
125130 _deinitTemplate = NodeTemplate ("" )
126131 _waitTemplate = NodeTemplate ("" )
127132
@@ -158,23 +163,25 @@ def transfer(self,
158163 strideLoc : Tuple [int , ...],
159164 direction : DmaDirection ,
160165 future : Future ,
161- comment : str = "" ) -> List [CodeSnippet ]:
166+ comment : str = "" ) -> Tuple [ List [CodeSnippet ], List [ CodeSnippet ], List [ CodeSnippet ] ]:
162167 tmpFuture = self .dma .getFuture (future .name .removesuffix ("_future" ))
163168 callStack = []
164- callStack .append (tmpFuture .init ())
165- callStack .extend (
166- self .dma .transfer (ctxt ,
167- externalBuffer ,
168- localBuffer ,
169- shape ,
170- strideExt ,
171- strideLoc ,
172- direction ,
173- tmpFuture ,
174- comment = comment ))
175- callStack .append (tmpFuture .wait ())
176- callStack .append (tmpFuture .deinit ())
177- return callStack
169+ callStack .append (tmpFuture .init (comment ))
170+ callStack .append (tmpFuture .alloc (comment ))
171+ _ , dma_code , _ = self .dma .transfer (ctxt ,
172+ externalBuffer ,
173+ localBuffer ,
174+ shape ,
175+ strideExt ,
176+ strideLoc ,
177+ direction ,
178+ tmpFuture ,
179+ comment = comment )
180+ callStack .extend (dma_code )
181+ callStack .append (tmpFuture .wait (comment ))
182+ callStack .append (tmpFuture .deinit (comment ))
183+
184+ return [], callStack , []
178185
179186 def setup (self ) -> List [CodeSnippet ]:
180187 return self .dma .setup ()
@@ -239,7 +246,7 @@ def transfer(self,
239246 direction : DmaDirection ,
240247 future : Future ,
241248 strideExtPad : int = 0 ,
242- comment : str = "" ) -> List [CodeSnippet ]:
249+ comment : str = "" ) -> Tuple [ List [CodeSnippet ], List [ CodeSnippet ], List [ CodeSnippet ] ]:
243250 transferRank = len (shape )
244251 kernelRank = self .nearestSupportedTransferRank (transferRank )
245252
@@ -275,18 +282,19 @@ def transfer(self,
275282 "offset" : "ext_offset"
276283 }))
277284
278- callStack .extend (
279- self .dma .transfer (ctxt ,
280- externalBufferOffseted ,
281- localBufferOffseted ,
282- shape [- kernelRank :],
283- strideExt [- kernelRank :],
284- strideLoc [- kernelRank :],
285- direction ,
286- future ,
287- comment = comment ))
285+ alloc_code , dma_code , deinit_code = self .dma .transfer (ctxt ,
286+ externalBufferOffseted ,
287+ localBufferOffseted ,
288+ shape [- kernelRank :],
289+ strideExt [- kernelRank :],
290+ strideLoc [- kernelRank :],
291+ direction ,
292+ future ,
293+ comment = comment )
294+
295+ callStack .extend (dma_code )
288296 callStack .append (CodeSnippet (self .NestedForLoopCloseTemplate (nestedLoopDepth ), {}))
289- return callStack
297+ return alloc_code , callStack , deinit_code
290298 elif kernelRank == transferRank :
291299 return self .dma .transfer (ctxt ,
292300 externalBuffer ,
0 commit comments