22from .abstract import AbstractTransformer , Context , AbstractInstruction
33from tensorforge .backend .instructions .compute import ComputeInstruction
44from tensorforge .backend .instructions .memory import AbstractShrMemWrite , MemoryInstruction
5- from tensorforge .backend .instructions .memory .load import LoadInstruction , LoadWait , GlbToRegLoader
6- from tensorforge .backend .instructions .memory .store import StoreRegToReg
5+ from tensorforge .backend .instructions .memory .load import LoadInstruction , LoadWait , GlbToRegLoader , GlbToShrLoader
6+ from tensorforge .backend .instructions .memory .store import StoreRegToReg , StoreShrMemToGlb
77from tensorforge .backend .instructions .ptr_manip import GetElementPtr
88from tensorforge .backend .instructions .allocate import RegisterAlloc
99from tensorforge .backend .symbol import SymbolType , Symbol
1212class MultiBuffer (AbstractTransformer ):
def __init__(self,
             context: Context,
             instructions: List[AbstractInstruction],
             shm, scopes):
    """Set up the transformer and resolve the shared-memory symbol.

    Args:
        context: Forwarded unchanged to ``AbstractTransformer.__init__``.
        instructions: Instruction list, forwarded unchanged to
            ``AbstractTransformer.__init__``.
        shm: Key handed to ``scopes.get_symbol``; presumably the
            shared-memory object of the kernel -- TODO confirm with caller.
        scopes: Object exposing ``get_symbol``; used once, here, to look up
            the symbol registered for ``shm``.
    """
    super(MultiBuffer, self).__init__(context, instructions)
    # Instructions collected by apply(); starts out empty.
    self._global_instrs = []
    self._shm = shm
    # Resolved once up front so apply() can pass it as ``shr_mem`` to the
    # loaders it constructs.
    self._shm_symbol = scopes.get_symbol(self._shm)
1821
1922 def apply (self ) -> None :
23+ earlystop = False
24+
2025 globalinstrs = []
2126 newinstrs = []
2227
2328 epmap = {}
2429
2530 for i , instr in enumerate (self ._instrs ):
26- if isinstance (instr , LoadInstruction ) and not isinstance ( instr , LoadWait ):
31+ if isinstance (instr , GlbToRegLoader ):
2732 newregs = deepcopy (instr ._dest .obj )
2833 newregs .name = f'preload_{ newregs .name } '
2934 newregsym = Symbol (newregs .name , SymbolType .Register , newregs )
@@ -43,13 +48,35 @@ def apply(self) -> None:
4348 newinstrs += [LoadWait (newload1 )]
4449 newinstrs += [StoreRegToReg (self ._context , newregsym , instr ._dest , instr ._num_threads )]
4550 newinstrs += [newload2 ]
46- elif isinstance (instr , GetElementPtr ) or isinstance (instr , RegisterAlloc ):
51+ elif isinstance (instr , GlbToShrLoader ):
52+ newshrsym = Symbol (f'preload_{ instr ._dest .name } ' , SymbolType .SharedMem , instr ._dest .obj )
53+ newshrsym .data_view = instr ._dest .data_view
54+ newshrsym .num_threads = instr ._dest .num_threads
55+ newshrsym .datatype = instr ._dest .datatype
56+ newsym = Symbol (f'next_{ instr ._src .name } ' , instr ._src .stype , instr ._src .obj )
57+ newsym .data_view = instr ._src .data_view
58+ newsym .num_threads = instr ._src .num_threads
59+ newsym .datatype = instr ._src .datatype
60+ newload1 = GlbToShrLoader (context = self ._context , src = newsym , dest = newshrsym , shr_mem = self ._shm_symbol , num_threads = instr ._num_threads , permute = None )
61+ newload2 = GlbToShrLoader (context = self ._context , src = newsym , dest = newshrsym , shr_mem = self ._shm_symbol , num_threads = instr ._num_threads , permute = None )
62+ globalinstrs += [GetElementPtr (self ._context , epmap [instr ._src .name ], newsym , batch_offset = 1 )]
63+ globalinstrs += [newload1 ]
64+ newinstrs += [GetElementPtr (self ._context , epmap [instr ._src .name ], newsym , batch_offset = 1 )]
65+ newinstrs += [LoadWait (newload1 )]
66+ newinstrs += [GlbToShrLoader (context = self ._context , src = newshrsym , dest = instr ._dest , shr_mem = self ._shm_symbol , num_threads = instr ._num_threads , permute = None , no_memcpy = True )]
67+ newinstrs += [newload2 ]
68+ elif isinstance (instr , GetElementPtr ) or isinstance (instr , RegisterAlloc ) or isinstance (instr , LoadWait ):
4769 newinstrs += [instr ]
4870
4971 # hack
5072 if isinstance (instr , GetElementPtr ):
5173 epmap [instr ._dest .name ] = instr ._src
5274 else :
53- self ._global_instrs += globalinstrs
54- self ._instrs = newinstrs + self ._instrs [i :]
55- break
75+ if earlystop :
76+ newinstrs += self ._instrs [i :]
77+ break
78+ else :
79+ newinstrs += [instr ]
80+
81+ self ._instrs = newinstrs
82+ self ._global_instrs += globalinstrs
0 commit comments