Skip to content

Commit 4e5bcb8

Browse files
committed
Implement double buffering
1 parent 6c2ed1b commit 4e5bcb8

File tree

6 files changed

+106
-32
lines changed

6 files changed

+106
-32
lines changed

tensorforge/backend/instructions/builders/ptr_manip_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class GetElementPtrBuilder(AbstractBuilder):
1212
def __init__(self, context: Context, scopes: Scopes):
1313
super(GetElementPtrBuilder, self).__init__(context, scopes)
1414

15-
def build(self, src: Symbol, include_extra_offset: bool = True):
15+
def build(self, src: Symbol, include_extra_offset: bool = True, batch_offset = 0):
1616
self._reset()
1717

1818
dstype = src.stype
@@ -32,6 +32,6 @@ def build(self, src: Symbol, include_extra_offset: bool = True):
3232
self._scopes.add_symbol(dest)
3333

3434
if src.stype != SymbolType.Data:
35-
self._instructions.append(GetElementPtr(self._context, src, dest, include_extra_offset))
35+
self._instructions.append(GetElementPtr(self._context, src, dest, include_extra_offset, batch_offset))
3636

3737
src.add_user(self)

tensorforge/backend/instructions/memory/load.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def gen_code_inner(self, writer: Writer) -> None:
332332

333333
start = (total_size // granularity) * granularity
334334

335-
elif self._context.get_vm().get_hw_descr().vendor in ['amd']:
335+
elif self._context.get_vm().get_hw_descr().vendor in ['amd'] and False:
336336

337337
# float4 load
338338

@@ -354,10 +354,11 @@ def gen_code_inner(self, writer: Writer) -> None:
354354
for g in [4, 2, 1]: # [4, 3, 2, 1]
355355
# 4x4
356356
# writer(f'const auto f{g}idx = (threadIdx.x % {g}) * {self._num_threads} + (threadIdx.x / {g}) * {g};')
357+
total_count_g = (total_count // g) * g
357358

358-
writer(f'const auto f{g}idx = ((threadIdx.x / {16 // g}) % {g}) * {self._num_threads} + (threadIdx.x % {16 // g}) * {g} + (threadIdx.x / 16) * 16;')
359+
if start != total_count_g:
360+
writer(f'const auto f{g}idx = ((threadIdx.x / {16 // g}) % {g}) * {self._num_threads} + (threadIdx.x % {16 // g}) * {g} + (threadIdx.x / 16) * 16;')
359361

360-
total_count_g = (total_count // g) * g
361362
for i in range(start, total_count_g, g):
362363
sidx = i // lead_count
363364
ridx = i % lead_count

tensorforge/backend/instructions/ptr_manip.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from .abstract_instruction import AbstractInstruction
2-
from tensorforge.common.vm.vm import VM
2+
from tensorforge.common.context import Context
33
from tensorforge.common.helper import get_extra_offset_name, Addressing
44
from tensorforge.common.basic_types import GeneralLexicon, DataFlowDirection, StridedAddressing
55
from tensorforge.common.exceptions import GenerationError
66

77
class GetElementPtr(AbstractInstruction):
88
def __init__(self,
9-
vm: VM,
9+
context: Context,
1010
src,
1111
dest,
12-
include_extra_offset=True):
13-
super(GetElementPtr, self).__init__(vm)
12+
include_extra_offset=True,
13+
batch_offset=0):
14+
super(GetElementPtr, self).__init__(context)
1415
self._src = src
1516
self._dest = dest
1617
self._include_extra_offset = include_extra_offset
1718
self._is_ready = True
19+
self._batch_offset = batch_offset
1820

1921
def gen_code(self, writer):
2022

@@ -30,21 +32,21 @@ def gen_code(self, writer):
3032

3133
address = ''
3234
if isinstance(batch_addressing, StridedAddressing):
33-
main_offset = f'{GeneralLexicon.BATCH_ID_NAME} * {batch_addressing.stride}'
35+
main_offset = f'{GeneralLexicon.BATCH_ID_NAME}{self._batch_offset} * {batch_addressing.stride}'
3436
sub_offset = f'{batch_obj.get_offset_to_first_element()}'
3537
address = f'{main_offset} + {batch_addressing.offset} + {sub_offset}{extra_offset}'
3638
rhs = f'&{self._src.name}[{address}]'
3739
lhs = 'const ' if self._src.obj.direction == DataFlowDirection.SOURCE else ''
3840
lhs += f'{datatype} * const {self._vm.get_lexic().restrict_kw} {self._dest.name}'
3941
if batch_addressing == Addressing.STRIDED:
40-
main_offset = f'{GeneralLexicon.BATCH_ID_NAME} * {batch_obj.get_real_volume()}'
42+
main_offset = f'{GeneralLexicon.BATCH_ID_NAME}{self._batch_offset} * {batch_obj.get_real_volume()}'
4143
sub_offset = f'{batch_obj.get_offset_to_first_element()}'
4244
address = f'{main_offset} + {sub_offset}{extra_offset}'
4345
rhs = f'&{self._src.name}[{address}]'
4446
lhs = 'const ' if self._src.obj.direction == DataFlowDirection.SOURCE else ''
4547
lhs += f'{datatype} * const {self._vm.get_lexic().restrict_kw} {self._dest.name}'
4648
elif batch_addressing == Addressing.PTR_BASED:
47-
main_offset = f'{GeneralLexicon.BATCH_ID_NAME}'
49+
main_offset = f'{GeneralLexicon.BATCH_ID_NAME}{self._batch_offset}'
4850
sub_offset = f'{batch_obj.get_offset_to_first_element()}'
4951
address = f'{main_offset}][{sub_offset}{extra_offset}'
5052
src_suffix = '_ptr' if self._vm.get_lexic()._backend == 'targetdart' else ''
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import List
2+
from .abstract import AbstractTransformer, Context, AbstractInstruction
3+
from tensorforge.backend.instructions.compute import ComputeInstruction
4+
from tensorforge.backend.instructions.memory import AbstractShrMemWrite, MemoryInstruction
5+
from tensorforge.backend.instructions.memory.load import LoadInstruction, LoadWait, GlbToRegLoader
6+
from tensorforge.backend.instructions.memory.store import StoreRegToReg
7+
from tensorforge.backend.instructions.ptr_manip import GetElementPtr
8+
from tensorforge.backend.instructions.allocate import RegisterAlloc
9+
from tensorforge.backend.symbol import SymbolType, Symbol
10+
from copy import deepcopy
11+
12+
class MultiBuffer(AbstractTransformer):
13+
def __init__(self,
14+
context: Context,
15+
instructions: List[AbstractInstruction]):
16+
super(MultiBuffer, self).__init__(context, instructions)
17+
self._global_instrs = []
18+
19+
def apply(self) -> None:
20+
globalinstrs = []
21+
newinstrs = []
22+
23+
epmap = {}
24+
25+
for i, instr in enumerate(self._instrs):
26+
if isinstance(instr, LoadInstruction) and not isinstance(instr, LoadWait):
27+
newregs = deepcopy(instr._dest.obj)
28+
newregs.name = f'preload_{newregs.name}'
29+
newregsym = Symbol(newregs.name, SymbolType.Register, newregs)
30+
newregsym.data_view = instr._dest.data_view
31+
newregsym.num_threads = instr._dest.num_threads
32+
newregsym.datatype = instr._dest.datatype
33+
newsym = Symbol(f'next_{instr._src.name}', instr._src.stype, instr._src.obj)
34+
newsym.data_view = instr._src.data_view
35+
newsym.num_threads = instr._src.num_threads
36+
newsym.datatype = instr._src.datatype
37+
newload1 = GlbToRegLoader(self._context, newsym, newregsym, instr._num_threads, instr._linearize)
38+
newload2 = GlbToRegLoader(self._context, newsym, newregsym, instr._num_threads, instr._linearize)
39+
globalinstrs += [GetElementPtr(self._context, epmap[instr._src.name], newsym, batch_offset=1)]
40+
globalinstrs += [RegisterAlloc(self._context, newregsym, 0, 0.0)]
41+
globalinstrs += [newload1]
42+
newinstrs += [GetElementPtr(self._context, epmap[instr._src.name], newsym, batch_offset=1)]
43+
newinstrs += [LoadWait(newload1)]
44+
newinstrs += [StoreRegToReg(self._context, newregsym, instr._dest, instr._num_threads)]
45+
newinstrs += [newload2]
46+
elif isinstance(instr, GetElementPtr) or isinstance(instr, RegisterAlloc):
47+
newinstrs += [instr]
48+
49+
# hack
50+
if isinstance(instr, GetElementPtr):
51+
epmap[instr._dest.name] = instr._src
52+
else:
53+
self._global_instrs += globalinstrs
54+
self._instrs = newinstrs + self._instrs[i:]
55+
break

tensorforge/backend/opt/optimizer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .sync_block import SyncThreadsOpt
1010
from .remove_redundancy import RemoveRedundancyOpt
1111
from .memmove import MoveLoads
12+
from .multibuffer import MultiBuffer
1213

1314
class OptimizationStage:
1415
def __init__(self,
@@ -19,6 +20,7 @@ def __init__(self,
1920
self._context = context
2021
self._shr_mem: ShrMemObject = shr_mem
2122
self._instrs: List[AbstractInstruction] = instructions
23+
self._global_instrs: List[AbstractInstruction] = []
2224
self._num_instrs: int = len(instructions)
2325
self._user_options = context.get_user_options()
2426
self._num_threads = num_threads
@@ -28,6 +30,11 @@ def optimize(self):
2830
opt.apply()
2931
self._instrs = opt.get_instructions()
3032

33+
opt = MultiBuffer(self._context, self._instrs)
34+
opt.apply()
35+
self._instrs = opt.get_instructions()
36+
self._global_instrs = opt._global_instrs
37+
3138
opt = LivenessAnalysis(self._context, self._instrs)
3239
opt.apply()
3340
live_map: Dict[int, Set[Symbol]] = opt.get_live_map()
@@ -63,3 +70,6 @@ def optimize(self):
6370

6471
def get_instructions(self):
6572
return self._instrs
73+
74+
def get_global_instructions(self):
75+
return self._global_instrs

tensorforge/generators/generator.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def generate(self):
142142
num_threads=self._num_threads)
143143
opt.optimize()
144144
self._section.ir = opt.get_instructions()
145+
self._section.global_ir += opt.get_global_instructions()
145146

146147
# add final sync for persistent threads
147148
if self._persistent_threading or self._clusterlaunchcontrol:
@@ -173,6 +174,24 @@ def _generate_kernel(self):
173174

174175
for i,section in enumerate(self._sections):
175176
with writer.AnonymousScope():
177+
178+
offset = []
179+
idx = i - 1
180+
for ssection in reversed(self._sections[:i]):
181+
if ssection.barrier:
182+
break
183+
offset += [f'{GeneralLexicon.NUM_ELEMENTS}{idx}']
184+
idx -= 1
185+
186+
stride = f'({vm.get_lexic().grid_dim_x} * {vm.get_lexic().block_dim_y})'
187+
if len(offset) == 0:
188+
start = self._get_2d_block_id()
189+
else:
190+
start = f'({self._get_2d_block_id()} + {" + ".join(offset)}) % {stride}'
191+
192+
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}_start = {start};')
193+
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}1 = {GeneralLexicon.BATCH_ID_NAME}_start < {GeneralLexicon.NUM_ELEMENTS}{i} ? {GeneralLexicon.BATCH_ID_NAME}_start : 0;')
194+
176195
for instruction in section.global_ir:
177196
if instruction.is_ready():
178197
instruction.gen_code(writer)
@@ -193,37 +212,24 @@ def generate_inner():
193212
# TODO: OMP target
194213
# TODO: maybe iterate over adjacent elements? (for indirect pointers)
195214

196-
offset = []
197-
idx = i - 1
198-
for ssection in reversed(self._sections[:i]):
199-
if ssection.barrier:
200-
break
201-
offset += [f'{GeneralLexicon.NUM_ELEMENTS}{idx}']
202-
idx -= 1
203-
204-
stride = f'({vm.get_lexic().grid_dim_x} * {vm.get_lexic().block_dim_y})'
205-
if len(offset) == 0:
206-
start = self._get_2d_block_id()
207-
else:
208-
start = f'({self._get_2d_block_id()} + {" + ".join(offset)}) % {stride}'
209-
210-
with writer.For(f'size_t {GeneralLexicon.BATCH_ID_NAME} = {start}; {GeneralLexicon.BATCH_ID_NAME} < {GeneralLexicon.NUM_ELEMENTS}{i}; {GeneralLexicon.BATCH_ID_NAME} += {stride}'):
215+
with writer.For(f'size_t {GeneralLexicon.BATCH_ID_NAME}0 = {start}; {GeneralLexicon.BATCH_ID_NAME}0 < {GeneralLexicon.NUM_ELEMENTS}{i}; {GeneralLexicon.BATCH_ID_NAME}0 += {stride}'):
216+
writer(f'const auto {GeneralLexicon.BATCH_ID_NAME}1 = {GeneralLexicon.BATCH_ID_NAME}0 + {stride} < {GeneralLexicon.NUM_ELEMENTS}{i} ? {GeneralLexicon.BATCH_ID_NAME}0 + {stride} : {GeneralLexicon.BATCH_ID_NAME}0;')
211217
generate_inner()
212218
elif self._clusterlaunchcontrol:
213219
writer(f'__shared__ tensorforge::ClusterLaunchCtrl launchctrl;')
214220
writer(f'int phase = 0;')
215221
writer(f'launchctrl.init();')
216-
writer(f'size_t {GeneralLexicon.BATCH_ID_NAME} = {self._get_2d_block_id()};')
222+
writer(f'size_t {GeneralLexicon.BATCH_ID_NAME}0 = {self._get_2d_block_id()};')
217223
with writer.While(f'true'):
218224
writer('launchctrl.setupNext();')
219225
with writer.If(f'{self._get_element_size_guard(i)}'):
220226
generate_inner()
221227
writer('const auto nextBlock = launchctrl.queryNext(phase);')
222228
with writer.If('!nextBlock.has_value()'):
223229
writer('break;')
224-
writer(f'{GeneralLexicon.BATCH_ID_NAME} = {self._get_2d_block_id("nextBlock.value()")};')
230+
writer(f'{GeneralLexicon.BATCH_ID_NAME}0 = {self._get_2d_block_id("nextBlock.value()")};')
225231
else:
226-
writer(f'const size_t {GeneralLexicon.BATCH_ID_NAME} = {self._get_2d_block_id()};')
232+
writer(f'const size_t {GeneralLexicon.BATCH_ID_NAME}0 = {self._get_2d_block_id()};')
227233
with writer.If(f'{self._get_element_size_guard(i)}'):
228234
generate_inner()
229235

@@ -612,10 +618,10 @@ def _get_2d_block_id(self, block=None):
612618
return f'{lexic.thread_idx_y} + {lexic.block_dim_y} * ({block})'
613619

614620
def _get_element_size_guard(self, i):
615-
return f'{GeneralLexicon.BATCH_ID_NAME} < {GeneralLexicon.NUM_ELEMENTS}{i}'
621+
return f'{GeneralLexicon.BATCH_ID_NAME}0 < {GeneralLexicon.NUM_ELEMENTS}{i}'
616622

617623
def _get_flag_guard(self, writer, i):
618624
writer(f'bool allowed = true;')
619625
with writer.If(f'{GeneralLexicon.FLAGS_NAME}{i} != nullptr'):
620-
writer(f'allowed = static_cast<bool>({GeneralLexicon.FLAGS_NAME}{i}[{GeneralLexicon.BATCH_ID_NAME}]);')
626+
writer(f'allowed = static_cast<bool>({GeneralLexicon.FLAGS_NAME}{i}[{GeneralLexicon.BATCH_ID_NAME}0]);')
621627
return 'allowed'

0 commit comments

Comments
 (0)