Skip to content

Commit 296a812

Browse files
committed
Continue work on buffering
1 parent 4e5bcb8 commit 296a812

File tree

7 files changed

+112
-31
lines changed

7 files changed

+112
-31
lines changed

tensorforge/backend/instructions/memory/load.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(self, **kwargs):
2828
self._num_threads = kwargs['num_threads']
2929
self._permute: None = kwargs['permute']
3030
self._manual_unroll_threshold = 4
31+
self._no_memcpy = kwargs['no_memcpy'] if 'no_memcpy' in kwargs else False
3132

3233
if 'max_load_offset' in kwargs:
3334
self._max_load_offset = kwargs['max_load_offset']
@@ -54,7 +55,7 @@ def __init__(self, **kwargs):
5455
self._shr_mem.add_user(self)
5556
self._is_ready: bool = False
5657

57-
self._use_cuda_memcpy = self._context.get_vm().get_hw_descr().vendor == 'nvidia'
58+
self._use_cuda_memcpy = self._context.get_vm().get_hw_descr().vendor == 'nvidia' and not self._no_memcpy
5859

5960
if self._permute is None:
6061
self._permute = [i for i in range(len(self._src.obj.shape))]
@@ -252,8 +253,8 @@ def get_permute(self) -> List[int]:
252253
return self._permute
253254

254255
def _check(self) -> None:
255-
if self._src.stype != SymbolType.Global:
256-
raise InternalError('shr-load: `src` operand is not in global mem.')
256+
#if self._src.stype != SymbolType.Global:
257+
# raise InternalError('shr-load: `src` operand is not in global mem.')
257258

258259
if not isinstance(self._src.obj, Tensor):
259260
raise InternalError(f'shr-load: `src` operand is not a tensor, instead: {self._src.obj}')

tensorforge/backend/instructions/memory/store.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,7 @@ def __init__(self,
211211
context: Context,
212212
src: Symbol,
213213
dest: Symbol,
214-
alpha: float,
215-
beta: float,
216-
num_compute_threads: int,
217-
num_active_threads: int):
214+
num_threads: int):
218215
super(StoreShrMemToGlb, self).__init__(context)
219216

220217
#if src.stype != SymbolType.SharedMem:
@@ -225,8 +222,6 @@ def __init__(self,
225222

226223
self._dest = dest
227224
self._src = src
228-
self._alpha = alpha
229-
self._beta = beta
230225
self._num_threads = num_active_threads
231226
self._is_ready = True
232227

tensorforge/backend/instructions/ptr_manip.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@ def __init__(self,
1010
src,
1111
dest,
1212
include_extra_offset=True,
13-
batch_offset=0):
13+
batch_offset=0,
14+
update_dest=None,
15+
pipeline = False):
1416
super(GetElementPtr, self).__init__(context)
1517
self._src = src
1618
self._dest = dest
1719
self._include_extra_offset = include_extra_offset
1820
self._is_ready = True
1921
self._batch_offset = batch_offset
22+
self._update_dest = update_dest
23+
self._pipeline = pipeline
2024

2125
def gen_code(self, writer):
2226

@@ -30,21 +34,23 @@ def gen_code(self, writer):
3034

3135
datatype = self._vm._fp_type if self._src.obj.datatype is None else self._src.obj.datatype
3236

37+
const_mod = '' if self._pipeline else 'const'
38+
3339
address = ''
3440
if isinstance(batch_addressing, StridedAddressing):
3541
main_offset = f'{GeneralLexicon.BATCH_ID_NAME}{self._batch_offset} * {batch_addressing.stride}'
3642
sub_offset = f'{batch_obj.get_offset_to_first_element()}'
3743
address = f'{main_offset} + {batch_addressing.offset} + {sub_offset}{extra_offset}'
3844
rhs = f'&{self._src.name}[{address}]'
3945
lhs = 'const ' if self._src.obj.direction == DataFlowDirection.SOURCE else ''
40-
lhs += f'{datatype} * const {self._vm.get_lexic().restrict_kw} {self._dest.name}'
46+
lhs += f'{datatype} *{const_mod} {self._vm.get_lexic().restrict_kw} {self._dest.name}'
4147
if batch_addressing == Addressing.STRIDED:
4248
main_offset = f'{GeneralLexicon.BATCH_ID_NAME}{self._batch_offset} * {batch_obj.get_real_volume()}'
4349
sub_offset = f'{batch_obj.get_offset_to_first_element()}'
4450
address = f'{main_offset} + {sub_offset}{extra_offset}'
4551
rhs = f'&{self._src.name}[{address}]'
4652
lhs = 'const ' if self._src.obj.direction == DataFlowDirection.SOURCE else ''
47-
lhs += f'{datatype} * const {self._vm.get_lexic().restrict_kw} {self._dest.name}'
53+
lhs += f'{datatype} *{const_mod} {self._vm.get_lexic().restrict_kw} {self._dest.name}'
4854
elif batch_addressing == Addressing.PTR_BASED:
4955
main_offset = f'{GeneralLexicon.BATCH_ID_NAME}{self._batch_offset}'
5056
sub_offset = f'{batch_obj.get_offset_to_first_element()}'
@@ -57,19 +63,23 @@ def gen_code(self, writer):
5763
rhs = f'(tensorforge::SpacePtrRestrict<{lhs}, tensorforge::GlobalMemspace>){rhs}'
5864
lhs = f'auto {self._dest.name}'
5965
else:
60-
lhs += f'{datatype} * const {self._vm.get_lexic().restrict_kw} {self._dest.name}'
66+
lhs += f'{datatype} *{const_mod} {self._vm.get_lexic().restrict_kw} {self._dest.name}'
6167
elif batch_addressing == Addressing.NONE:
6268
address = f'{batch_obj.get_offset_to_first_element()}'
6369
rhs = f'&{self._src.name}[{address}]'
6470
lhs = 'const ' if self._src.obj.direction == DataFlowDirection.SOURCE else ''
65-
lhs += f'{datatype} * const {self._vm.get_lexic().restrict_kw} {self._dest.name}'
71+
lhs += f'{datatype} *{const_mod} {self._vm.get_lexic().restrict_kw} {self._dest.name}'
6672
elif batch_addressing == Addressing.SCALAR:
6773
rhs = f'{self._src.name}'
6874
lhs = f'{datatype} {self._dest.name}'
6975
else:
7076
GenerationError(f'unknown addressing of {self._src.name}, given {batch_addressing}')
7177

72-
writer(f'{lhs} = {rhs};')
78+
if self._update_dest:
79+
writer(f'const auto {self._update_dest.name} = {self._dest.name};')
80+
writer(f'{self._dest.name} = {rhs};')
81+
else:
82+
writer(f'{lhs} = {rhs};')
7383

7484
def __str__(self) -> str:
7585
return f'{self._dest.name} = getelementptr_b2g {self._src.name};'

tensorforge/backend/opt/multibuffer.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from .abstract import AbstractTransformer, Context, AbstractInstruction
33
from tensorforge.backend.instructions.compute import ComputeInstruction
44
from tensorforge.backend.instructions.memory import AbstractShrMemWrite, MemoryInstruction
5-
from tensorforge.backend.instructions.memory.load import LoadInstruction, LoadWait, GlbToRegLoader
6-
from tensorforge.backend.instructions.memory.store import StoreRegToReg
5+
from tensorforge.backend.instructions.memory.load import LoadInstruction, LoadWait, GlbToRegLoader, GlbToShrLoader
6+
from tensorforge.backend.instructions.memory.store import StoreRegToReg, StoreShrMemToGlb
77
from tensorforge.backend.instructions.ptr_manip import GetElementPtr
88
from tensorforge.backend.instructions.allocate import RegisterAlloc
99
from tensorforge.backend.symbol import SymbolType, Symbol
@@ -12,18 +12,23 @@
1212
class MultiBuffer(AbstractTransformer):
1313
def __init__(self,
1414
context: Context,
15-
instructions: List[AbstractInstruction]):
15+
instructions: List[AbstractInstruction],
16+
shm, scopes):
1617
super(MultiBuffer, self).__init__(context, instructions)
1718
self._global_instrs = []
19+
self._shm = shm
20+
self._shm_symbol = scopes.get_symbol(self._shm)
1821

1922
def apply(self) -> None:
23+
earlystop = False
24+
2025
globalinstrs = []
2126
newinstrs = []
2227

2328
epmap = {}
2429

2530
for i, instr in enumerate(self._instrs):
26-
if isinstance(instr, LoadInstruction) and not isinstance(instr, LoadWait):
31+
if isinstance(instr, GlbToRegLoader):
2732
newregs = deepcopy(instr._dest.obj)
2833
newregs.name = f'preload_{newregs.name}'
2934
newregsym = Symbol(newregs.name, SymbolType.Register, newregs)
@@ -43,13 +48,35 @@ def apply(self) -> None:
4348
newinstrs += [LoadWait(newload1)]
4449
newinstrs += [StoreRegToReg(self._context, newregsym, instr._dest, instr._num_threads)]
4550
newinstrs += [newload2]
46-
elif isinstance(instr, GetElementPtr) or isinstance(instr, RegisterAlloc):
51+
elif isinstance(instr, GlbToShrLoader):
52+
newshrsym = Symbol(f'preload_{instr._dest.name}', SymbolType.SharedMem, instr._dest.obj)
53+
newshrsym.data_view = instr._dest.data_view
54+
newshrsym.num_threads = instr._dest.num_threads
55+
newshrsym.datatype = instr._dest.datatype
56+
newsym = Symbol(f'next_{instr._src.name}', instr._src.stype, instr._src.obj)
57+
newsym.data_view = instr._src.data_view
58+
newsym.num_threads = instr._src.num_threads
59+
newsym.datatype = instr._src.datatype
60+
newload1 = GlbToShrLoader(context=self._context, src=newsym, dest=newshrsym, shr_mem=self._shm_symbol, num_threads=instr._num_threads, permute=None)
61+
newload2 = GlbToShrLoader(context=self._context, src=newsym, dest=newshrsym, shr_mem=self._shm_symbol, num_threads=instr._num_threads, permute=None)
62+
globalinstrs += [GetElementPtr(self._context, epmap[instr._src.name], newsym, batch_offset=1)]
63+
globalinstrs += [newload1]
64+
newinstrs += [GetElementPtr(self._context, epmap[instr._src.name], newsym, batch_offset=1)]
65+
newinstrs += [LoadWait(newload1)]
66+
newinstrs += [GlbToShrLoader(context=self._context, src=newshrsym, dest=instr._dest, shr_mem=self._shm_symbol, num_threads=instr._num_threads, permute=None, no_memcpy=True)]
67+
newinstrs += [newload2]
68+
elif isinstance(instr, GetElementPtr) or isinstance(instr, RegisterAlloc) or isinstance(instr, LoadWait):
4769
newinstrs += [instr]
4870

4971
# hack
5072
if isinstance(instr, GetElementPtr):
5173
epmap[instr._dest.name] = instr._src
5274
else:
53-
self._global_instrs += globalinstrs
54-
self._instrs = newinstrs + self._instrs[i:]
55-
break
75+
if earlystop:
76+
newinstrs += self._instrs[i:]
77+
break
78+
else:
79+
newinstrs += [instr]
80+
81+
self._instrs = newinstrs
82+
self._global_instrs += globalinstrs

tensorforge/backend/opt/optimizer.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,40 @@
1010
from .remove_redundancy import RemoveRedundancyOpt
1111
from .memmove import MoveLoads
1212
from .multibuffer import MultiBuffer
13+
from .ptrpipe import PtrPipe
1314

1415
class OptimizationStage:
1516
def __init__(self,
1617
context: Context,
1718
shr_mem: ShrMemObject,
1819
instructions: List[AbstractInstruction],
19-
num_threads: int):
20+
num_threads: int,
21+
scopes):
2022
self._context = context
2123
self._shr_mem: ShrMemObject = shr_mem
2224
self._instrs: List[AbstractInstruction] = instructions
2325
self._global_instrs: List[AbstractInstruction] = []
2426
self._num_instrs: int = len(instructions)
2527
self._user_options = context.get_user_options()
2628
self._num_threads = num_threads
29+
self._scopes = scopes
2730

2831
def optimize(self):
2932
opt = MoveLoads(self._context, self._instrs)
3033
opt.apply()
3134
self._instrs = opt.get_instructions()
3235

33-
opt = MultiBuffer(self._context, self._instrs)
36+
opt = MultiBuffer(self._context, self._instrs, self._shr_mem, self._scopes)
3437
opt.apply()
3538
self._instrs = opt.get_instructions()
3639
self._global_instrs = opt._global_instrs
3740

38-
opt = LivenessAnalysis(self._context, self._instrs)
41+
opt = PtrPipe(self._context, self._instrs)
42+
opt.apply()
43+
self._instrs = opt.get_instructions()
44+
self._global_instrs += opt._global_instrs
45+
46+
opt = LivenessAnalysis(self._context, self._global_instrs + self._instrs)
3947
opt.apply()
4048
live_map: Dict[int, Set[Symbol]] = opt.get_live_map()
4149

tensorforge/backend/opt/ptrpipe.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import List
2+
from .abstract import AbstractTransformer, Context, AbstractInstruction
3+
from tensorforge.backend.instructions.compute import ComputeInstruction
4+
from tensorforge.backend.instructions.memory import AbstractShrMemWrite, MemoryInstruction
5+
from tensorforge.backend.instructions.memory.load import LoadInstruction, LoadWait, GlbToRegLoader
6+
from tensorforge.backend.instructions.memory.store import StoreRegToReg
7+
from tensorforge.backend.instructions.ptr_manip import GetElementPtr
8+
from tensorforge.backend.instructions.allocate import RegisterAlloc
9+
from tensorforge.backend.symbol import SymbolType, Symbol
10+
from copy import deepcopy
11+
12+
class PtrPipe(AbstractTransformer):
  """Pointer-pipelining transformer.

  Rewrites every `GetElementPtr` in the instruction stream so that the pointer
  for the NEXT batch element is computed one iteration ahead:

  * into `_global_instrs` (loop prologue) goes a `pipeline=True` GEP that
    initializes a mutable `preload{offset}_...` pointer for the first iteration;
  * in the loop body, the original GEP is replaced by a GEP with
    `update_dest=<original dest>`, which (per `GetElementPtr.gen_code`)
    publishes the previously preloaded pointer under the original name and
    then advances the preload pointer to `batch_offset + 1`.

  NOTE(review): the original GEP instruction itself is dropped from the body —
  assumed intentional, since the `update_dest` GEP takes over producing
  `instr._dest`; confirm against `GetElementPtr.gen_code`.
  """
  def __init__(self,
               context: Context,
               instructions: List[AbstractInstruction]):
    super(PtrPipe, self).__init__(context, instructions)
    # Prologue instructions emitted once before the batch loop.
    self._global_instrs = []

  def apply(self) -> None:
    """Transform `self._instrs` in place; fills `self._global_instrs`."""
    globalinstrs = []
    newinstrs = []

    # Unused loop index removed: only the instruction itself is needed.
    for instr in self._instrs:
      if isinstance(instr, GetElementPtr):
        # Mutable pointer symbol that carries the "next element" address
        # across iterations; aliases the source object of the original GEP.
        newdest = Symbol(f'preload{instr._batch_offset}_{instr._src.name}', instr._src.stype, instr._src.obj)
        # In-loop GEP: publish the preloaded pointer as the original dest,
        # then advance the preload pointer by one batch element.
        newgep = GetElementPtr(self._context, instr._src, newdest, batch_offset=instr._batch_offset + 1, update_dest=instr._dest)
        newinstrs += [newgep]
        # Prologue GEP: declare the preload pointer non-const (`pipeline=True`)
        # so the in-loop GEP may reassign it.
        newgepstart = GetElementPtr(self._context, instr._src, newdest, batch_offset=instr._batch_offset + 1, pipeline=True)
        globalinstrs += [newgepstart]
      else:
        newinstrs += [instr]

    self._instrs = newinstrs
    self._global_instrs = globalinstrs

tensorforge/generators/generator.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def generate(self):
139139
opt = OptimizationStage(context=self._context,
140140
shr_mem=self._section.shr_mem_obj,
141141
instructions=self._section.ir,
142-
num_threads=self._num_threads)
142+
num_threads=self._num_threads,
143+
scopes = self._scopes)
143144
opt.optimize()
144145
self._section.ir = opt.get_instructions()
145146
self._section.global_ir += opt.get_global_instructions()
@@ -191,6 +192,7 @@ def _generate_kernel(self):
191192

192193
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}_start = {start};')
193194
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}1 = {GeneralLexicon.BATCH_ID_NAME}_start < {GeneralLexicon.NUM_ELEMENTS}{i} ? {GeneralLexicon.BATCH_ID_NAME}_start : 0;')
195+
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}2 = {GeneralLexicon.BATCH_ID_NAME}1 + {stride} < {GeneralLexicon.NUM_ELEMENTS}{i} ? {GeneralLexicon.BATCH_ID_NAME}1 + {stride} : {GeneralLexicon.BATCH_ID_NAME}1;')
194196

195197
for instruction in section.global_ir:
196198
if instruction.is_ready():
@@ -214,6 +216,7 @@ def generate_inner():
214216

215217
with writer.For(f'size_t {GeneralLexicon.BATCH_ID_NAME}0 = {start}; {GeneralLexicon.BATCH_ID_NAME}0 < {GeneralLexicon.NUM_ELEMENTS}{i}; {GeneralLexicon.BATCH_ID_NAME}0 += {stride}'):
216218
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}1 = {GeneralLexicon.BATCH_ID_NAME}0 + {stride} < {GeneralLexicon.NUM_ELEMENTS}{i} ? {GeneralLexicon.BATCH_ID_NAME}0 + {stride} : {GeneralLexicon.BATCH_ID_NAME}0;')
219+
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}2 = {GeneralLexicon.BATCH_ID_NAME}1 + {stride} < {GeneralLexicon.NUM_ELEMENTS}{i} ? {GeneralLexicon.BATCH_ID_NAME}1 + {stride} : {GeneralLexicon.BATCH_ID_NAME}1;')
217220
generate_inner()
218221
elif self._clusterlaunchcontrol:
219222
writer(f'__shared__ tensorforge::ClusterLaunchCtrl launchctrl;')
@@ -621,7 +624,10 @@ def _get_element_size_guard(self, i):
621624
return f'{GeneralLexicon.BATCH_ID_NAME}0 < {GeneralLexicon.NUM_ELEMENTS}{i}'
622625

623626
def _get_flag_guard(self, writer, i):
624-
writer(f'bool allowed = true;')
625-
with writer.If(f'{GeneralLexicon.FLAGS_NAME}{i} != nullptr'):
626-
writer(f'allowed = static_cast<bool>({GeneralLexicon.FLAGS_NAME}{i}[{GeneralLexicon.BATCH_ID_NAME}0]);')
627-
return 'allowed'
627+
if False:
628+
writer(f'bool allowed = true;')
629+
with writer.If(f'{GeneralLexicon.FLAGS_NAME}{i} != nullptr'):
630+
writer(f'allowed = static_cast<bool>({GeneralLexicon.FLAGS_NAME}{i}[{GeneralLexicon.BATCH_ID_NAME}0]);')
631+
return 'allowed'
632+
else:
633+
return 'true'

0 commit comments

Comments
 (0)