@@ -1951,4 +1951,52 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
19511951
19521952# }}}
19531953
1954+
1955+ # {{{ TransferMapper
1956+
1957+ class TransferMapper (CopyMapper ):
1958+ def __init__ (self , to_device : bool , queue : Any , allocator : Any = None ) -> None :
1959+ super ().__init__ ()
1960+ self .to_device = to_device
1961+ self .queue = queue
1962+ self .allocator = allocator
1963+
1964+ def map_data_wrapper (self , expr : DataWrapper ) -> Array :
1965+ import sys
1966+ if "pyopencl" not in sys .modules :
1967+ return super ().map_data_wrapper (expr )
1968+
1969+ from pyopencl .array import Array as CLArray , to_device
1970+
1971+ if isinstance (expr .data , CLArray ) and not self .to_device :
1972+ data = expr .data .get ()
1973+ return DataWrapper (
1974+ data = data ,
1975+ shape = expr .shape ,
1976+ axes = expr .axes ,
1977+ tags = expr .tags ,
1978+ non_equality_tags = expr .non_equality_tags )
1979+ elif isinstance (expr .data , np .ndarray ) and self .to_device :
1980+ data = to_device (self .queue , expr .data , allocator = self .allocator )
1981+ return DataWrapper (
1982+ data = data ,
1983+ shape = expr .shape ,
1984+ axes = expr .axes ,
1985+ tags = expr .tags ,
1986+ non_equality_tags = expr .non_equality_tags )
1987+
1988+ return super ().map_data_wrapper (expr )
1989+
1990+
1991+ def transfer_to_device (expr : ArrayOrNames , queue : Any ,
1992+ allocator : Any = None ) -> ArrayOrNames :
1993+ return TransferMapper (True , queue , allocator )(expr )
1994+
1995+
1996+ def transfer_to_host (expr : ArrayOrNames , queue : Any ,
1997+ allocator : Any = None ) -> ArrayOrNames :
1998+ return TransferMapper (False , queue , allocator )(expr )
1999+
2000+ # }}}
2001+
19542002# vim: foldmethod=marker
0 commit comments