Skip to content

Commit d754db3

Browse files
committed
Next memmove correction
1 parent 067e178 commit d754db3

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

tensorforge/backend/opt/memmove.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from tensorforge.backend.instructions.memory import AbstractShrMemWrite, MemoryInstruction
55
from tensorforge.backend.instructions.memory.load import LoadInstruction, LoadWait
66
from tensorforge.backend.instructions.ptr_manip import GetElementPtr
7+
from tensorforge.backend.instructions.allocate import RegisterAlloc
78
from tensorforge.backend.symbol import SymbolType
89

910
class MoveLoads(AbstractTransformer):
@@ -15,19 +16,26 @@ def __init__(self,
1516
def apply(self) -> None:
1617
instrsOut = []
1718
stored = []
19+
def clear_stored(instrsOut):
20+
while len(stored) > 0:
21+
delayed = stored.pop(0)
22+
instrsOut += [delayed]
1823
for instr in reversed(self._instrs):
1924
if isinstance(instr, LoadInstruction):
2025
instrsOut += [LoadWait(instr)]
21-
while len(stored) > 0:
22-
delayed = stored.pop()
23-
instrsOut += [delayed]
26+
clear_stored(instrsOut)
2427
stored.append(instr)
28+
elif isinstance(instr, RegisterAlloc):
29+
for st in stored:
30+
if st._dest is instr._dest:
31+
stored.append(instr)
32+
break
33+
else:
34+
instrsOut += [instr]
2535
else:
26-
if not isinstance(instr, ComputeInstruction):
27-
while len(stored) > 0:
28-
delayed = stored.pop()
29-
instrsOut += [delayed]
36+
if isinstance(instr, GetElementPtr):
37+
clear_stored(instrsOut)
3038
instrsOut += [instr]
31-
instrsOut += stored[::-1]
39+
clear_stored(instrsOut)
3240

3341
self._instrs = instrsOut[::-1]

0 commit comments

Comments
 (0)