@@ -56,6 +56,11 @@ class Variable(TypedAstNode):
56
56
'stack' if memory should be allocated on the stack, represents stack arrays and scalars.
57
57
'alias' if object allows access to memory stored in another variable.
58
58
59
+ memory_location: str, default: 'host'
60
+ 'host' the variable can only be accessed by the CPU.
61
+ 'device' the variable can only be accessed by the GPU.
62
+ 'managed' the variable can be accessed by CPU and GPU and is being managed by the Cuda API (memory transfer is being done implicitly).
63
+
59
64
is_const : bool, default: False
60
65
Indicates if object is a const argument of a function.
61
66
@@ -98,7 +103,7 @@ class Variable(TypedAstNode):
98
103
>>> Variable(PythonNativeInt(), DottedName('matrix', 'n_rows'))
99
104
matrix.n_rows
100
105
"""
101
- __slots__ = ('_name' , '_alloc_shape' , '_memory_handling' , '_is_const' , '_is_target' ,
106
+ __slots__ = ('_name' , '_alloc_shape' , '_memory_handling' , '_memory_location' , ' _is_const' , '_is_target' ,
102
107
'_is_optional' , '_allows_negative_indexes' , '_cls_base' , '_is_argument' , '_is_temp' ,
103
108
'_shape' ,'_is_private' ,'_class_type' )
104
109
_attribute_nodes = ()
@@ -109,6 +114,7 @@ def __init__(
109
114
name ,
110
115
* ,
111
116
memory_handling = 'stack' ,
117
+ memory_location = 'host' ,
112
118
is_const = False ,
113
119
is_target = False ,
114
120
is_optional = False ,
@@ -141,6 +147,10 @@ def __init__(
141
147
raise ValueError ("memory_handling must be 'heap', 'stack' or 'alias'" )
142
148
self ._memory_handling = memory_handling
143
149
150
+ if memory_location not in ('host' , 'device' , 'managed' ):
151
+ raise ValueError ("memory_location must be 'host', 'device' or 'managed'" )
152
+ self ._memory_location = memory_location
153
+
144
154
if not isinstance (is_const , bool ):
145
155
raise TypeError ('is_const must be a boolean.' )
146
156
self ._is_const = is_const
@@ -323,6 +333,36 @@ def cls_base(self):
323
333
"""
324
334
return self ._cls_base
325
335
336
+ @property
337
+ def memory_location (self ):
338
+ """ Indicates whether a Variable has a dynamic size
339
+ """
340
+ return self ._memory_location
341
+
342
+ @memory_location .setter
343
+ def memory_location (self , memory_location ):
344
+ if memory_location not in ('host' , 'device' , 'managed' ):
345
+ raise ValueError ("memory_location must be 'host', 'device' or 'managed'" )
346
+ self ._memory_location = memory_location
347
+
348
+ @property
349
+ def on_host (self ):
350
+ """ Indicates if memory is only accessible by the CPU
351
+ """
352
+ return self .memory_location == 'host'
353
+
354
+ @property
355
+ def on_device (self ):
356
+ """ Indicates if memory is only accessible by the GPU
357
+ """
358
+ return self .memory_location == 'device'
359
+
360
+ @property
361
+ def is_managed (self ):
362
+ """ Indicates if memory is being managed by CUDA API
363
+ """
364
+ return self .memory_location == 'managed'
365
+
326
366
@property
327
367
def is_const (self ):
328
368
"""
0 commit comments