Skip to content

Commit eab6282

Browse files
committed
compiler: Pass ctx down to _map_function_on_high_bw_mem
1 parent 5f680d4 commit eab6282

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

devito/passes/iet/definitions.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ class DeviceAwareDataManager(DataManager):
562562
def __init__(self, options=None, **kwargs):
563563
self.gpu_fit = options['gpu-fit']
564564
self.gpu_create = options['gpu-create']
565-
self.pmode = options.get('place-transfers')
565+
self.gpu_place_transfers = options.get('place-transfers')
566566

567567
super().__init__(**kwargs)
568568

@@ -595,7 +595,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage):
595595

596596
storage.update(obj, site, maps=mmap, unmaps=unmap)
597597

598-
def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=False):
598+
def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm,
599+
read_only=False, **kwargs):
599600
"""
600601
Map a Function already defined in the host memory in to the device high
601602
bandwidth memory.
@@ -628,42 +629,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F
628629
storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs)
629630

630631
@iet_pass
631-
def place_transfers(self, iet, data_movs=None, **kwargs):
632+
def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs):
632633
"""
633634
Create a new IET with host-device data transfers. This requires mapping
634635
symbols to the suitable memory spaces.
635636
"""
636-
if not self.pmode:
637+
if not self.gpu_place_transfers:
637638
return iet, {}
638639

639-
@singledispatch
640-
def _place_transfers(iet, data_movs):
640+
if not isinstance(iet, EntryFunction):
641641
return iet, {}
642642

643-
@_place_transfers.register(EntryFunction)
644-
def _(iet, data_movs):
645-
reads, writes = data_movs
643+
reads, writes = data_movs
646644

647-
# Special symbol which gives user code control over data deallocations
648-
devicerm = DeviceRM()
645+
# Special symbol which gives user code control over data deallocations
646+
devicerm = DeviceRM()
649647

650-
storage = Storage()
651-
for i in filter_sorted(writes):
652-
if i.is_Array:
653-
self._map_array_on_high_bw_mem(iet, i, storage)
654-
else:
655-
self._map_function_on_high_bw_mem(iet, i, storage, devicerm)
656-
for i in filter_sorted(reads - writes):
657-
if i.is_Array:
658-
self._map_array_on_high_bw_mem(iet, i, storage)
659-
else:
660-
self._map_function_on_high_bw_mem(iet, i, storage, devicerm, True)
661-
662-
iet, efuncs = self._inject_definitions(iet, storage)
648+
storage = Storage()
649+
for i in filter_sorted(writes):
650+
if i.is_Array:
651+
self._map_array_on_high_bw_mem(iet, i, storage)
652+
else:
653+
self._map_function_on_high_bw_mem(
654+
iet, i, storage, devicerm, ctx=ctx
655+
)
656+
for i in filter_sorted(reads - writes):
657+
if i.is_Array:
658+
self._map_array_on_high_bw_mem(iet, i, storage)
659+
else:
660+
self._map_function_on_high_bw_mem(
661+
iet, i, storage, devicerm, read_only=True, ctx=ctx
662+
)
663663

664-
return iet, {'efuncs': efuncs}
664+
iet, efuncs = self._inject_definitions(iet, storage)
665665

666-
return _place_transfers(iet, data_movs=data_movs)
666+
return iet, {'efuncs': efuncs}
667667

668668
@iet_pass
669669
def place_devptr(self, iet, **kwargs):

0 commit comments

Comments
 (0)