@@ -562,7 +562,7 @@ class DeviceAwareDataManager(DataManager):
562562 def __init__ (self , options = None , ** kwargs ):
563563 self .gpu_fit = options ['gpu-fit' ]
564564 self .gpu_create = options ['gpu-create' ]
565- self .pmode = options .get ('place-transfers' )
565+ self .gpu_place_transfers = options .get ('place-transfers' )
566566
567567 super ().__init__ (** kwargs )
568568
@@ -595,7 +595,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage):
595595
596596 storage .update (obj , site , maps = mmap , unmaps = unmap )
597597
598- def _map_function_on_high_bw_mem (self , site , obj , storage , devicerm , read_only = False ):
598+ def _map_function_on_high_bw_mem (self , site , obj , storage , devicerm ,
599+ read_only = False , ** kwargs ):
599600 """
600601 Map a Function already defined in the host memory in to the device high
601602 bandwidth memory.
@@ -628,42 +629,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F
628629 storage .update (obj , site , maps = mmap , unmaps = unmap , efuncs = efuncs )
629630
630631 @iet_pass
631- def place_transfers (self , iet , data_movs = None , ** kwargs ):
632+ def place_transfers (self , iet , data_movs = None , ctx = None , ** kwargs ):
632633 """
633634 Create a new IET with host-device data transfers. This requires mapping
634635 symbols to the suitable memory spaces.
635636 """
636- if not self .pmode :
637+ if not self .gpu_place_transfers :
637638 return iet , {}
638639
639- @singledispatch
640- def _place_transfers (iet , data_movs ):
640+ if not isinstance (iet , EntryFunction ):
641641 return iet , {}
642642
643- @_place_transfers .register (EntryFunction )
644- def _ (iet , data_movs ):
645- reads , writes = data_movs
643+ reads , writes = data_movs
646644
647- # Special symbol which gives user code control over data deallocations
648- devicerm = DeviceRM ()
645+ # Special symbol which gives user code control over data deallocations
646+ devicerm = DeviceRM ()
649647
650- storage = Storage ()
651- for i in filter_sorted (writes ):
652- if i .is_Array :
653- self ._map_array_on_high_bw_mem (iet , i , storage )
654- else :
655- self ._map_function_on_high_bw_mem (iet , i , storage , devicerm )
656- for i in filter_sorted (reads - writes ):
657- if i .is_Array :
658- self ._map_array_on_high_bw_mem (iet , i , storage )
659- else :
660- self ._map_function_on_high_bw_mem (iet , i , storage , devicerm , True )
661-
662- iet , efuncs = self ._inject_definitions (iet , storage )
648+ storage = Storage ()
649+ for i in filter_sorted (writes ):
650+ if i .is_Array :
651+ self ._map_array_on_high_bw_mem (iet , i , storage )
652+ else :
653+ self ._map_function_on_high_bw_mem (
654+ iet , i , storage , devicerm , ctx = ctx
655+ )
656+ for i in filter_sorted (reads - writes ):
657+ if i .is_Array :
658+ self ._map_array_on_high_bw_mem (iet , i , storage )
659+ else :
660+ self ._map_function_on_high_bw_mem (
661+ iet , i , storage , devicerm , read_only = True , ctx = ctx
662+ )
663663
664- return iet , { 'efuncs' : efuncs }
664+ iet , efuncs = self . _inject_definitions ( iet , storage )
665665
666- return _place_transfers ( iet , data_movs = data_movs )
666+ return iet , { 'efuncs' : efuncs }
667667
668668 @iet_pass
669669 def place_devptr (self , iet , ** kwargs ):
0 commit comments