Skip to content

Commit a906d51

Browse files
add a CL device<->host transfer mapper
1 parent 5f61931 commit a906d51

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

examples/demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import pytato as pt
66

77

8-
n = pt.make_size_param("n")
9-
a = pt.make_placeholder(name="a", shape=(n, n), dtype=np.float64)
8+
an = np.random.randn(20, 20)
9+
a = pt.make_placeholder(an)
1010

1111
a2a = a@(2*a)
1212
aat = a@a.T

pytato/transform/__init__.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1951,4 +1951,52 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
19511951

19521952
# }}}
19531953

1954+
1955+
# {{{ TransferMapper
1956+
1957+
class TransferMapper(CopyMapper):
1958+
def __init__(self, to_device: bool, queue: Any, allocator: Any = None) -> None:
1959+
super().__init__()
1960+
self.to_device = to_device
1961+
self.queue = queue
1962+
self.allocator = allocator
1963+
1964+
def map_data_wrapper(self, expr: DataWrapper) -> Array:
1965+
import sys
1966+
if "pyopencl" not in sys.modules:
1967+
return super().map_data_wrapper(expr)
1968+
1969+
from pyopencl.array import Array as CLArray, to_device
1970+
1971+
if isinstance(expr.data, CLArray) and not self.to_device:
1972+
data = expr.data.get()
1973+
return DataWrapper(
1974+
data=data,
1975+
shape=expr.shape,
1976+
axes=expr.axes,
1977+
tags=expr.tags,
1978+
non_equality_tags=expr.non_equality_tags)
1979+
elif isinstance(expr.data, np.ndarray) and self.to_device:
1980+
data = to_device(self.queue, expr.data, allocator=self.allocator)
1981+
return DataWrapper(
1982+
data=data,
1983+
shape=expr.shape,
1984+
axes=expr.axes,
1985+
tags=expr.tags,
1986+
non_equality_tags=expr.non_equality_tags)
1987+
1988+
return super().map_data_wrapper(expr)
1989+
1990+
1991+
def transfer_to_device(expr: ArrayOrNames, queue: Any,
1992+
allocator: Any = None) -> ArrayOrNames:
1993+
return TransferMapper(True, queue, allocator)(expr)
1994+
1995+
1996+
def transfer_to_host(expr: ArrayOrNames, queue: Any,
1997+
allocator: Any = None) -> ArrayOrNames:
1998+
return TransferMapper(False, queue, allocator)(expr)
1999+
2000+
# }}}
2001+
19542002
# vim: foldmethod=marker

0 commit comments

Comments
 (0)