|
2 | 2 | #
|
3 | 3 | # SPDX-License-Identifier: Apache-2.0
|
4 | 4 |
|
5 |
| -from warnings import warn |
6 | 5 |
|
7 |
| -import dpctl |
8 |
| -from numba.core.types import Array as NpArrayType |
9 |
| - |
10 |
| -from numba_dpex.core.exceptions import ( |
11 |
| - ComputeFollowsDataInferenceError, |
12 |
| - ExecutionQueueInferenceError, |
13 |
| -) |
| 6 | +from numba_dpex.core.exceptions import ExecutionQueueInferenceError |
14 | 7 | from numba_dpex.core.types import USMNdArray
|
15 | 8 |
|
16 | 9 |
|
@@ -50,131 +43,34 @@ def chk_compute_follows_data_compliance(usm_array_arglist):
|
50 | 43 | def determine_kernel_launch_queue(args, argtypes, kernel_name):
|
51 | 44 | """Determines the queue where the kernel is to be launched.
|
52 | 45 |
|
53 |
| - The execution queue is derived using the following algorithm. In future, |
54 |
| - support for ``numpy.ndarray`` and ``dpctl.device_context`` is to be |
55 |
| - removed and queue derivation will follows Python Array API's |
56 |
| - "compute follows data" logic. |
57 |
| -
|
58 |
| - Check if there are array arguments. |
59 |
| - True: |
60 |
| - Check if all array arguments are of type numpy.ndarray |
61 |
| - (numba.types.Array) |
62 |
| - True: |
63 |
| - Check if the kernel was invoked from within a |
64 |
| - dpctl.device_context. |
65 |
| - True: |
66 |
| - Provide a deprecation warning for device_context use and |
67 |
| - point to using dpctl.tensor.usm_ndarray or dpnp.ndarray |
68 |
| -
|
69 |
| - return dpctl.get_current_queue |
70 |
| - False: |
71 |
| - Raise ExecutionQueueInferenceError |
72 |
| - False: |
73 |
| - Check if all of the arrays are USMNdarray |
74 |
| - True: |
75 |
| - Check if execution queue could be inferred using |
76 |
| - compute follows data rules |
77 |
| - True: |
78 |
| - return the compute follows data inferred queue |
79 |
| - False: |
80 |
| - Raise ComputeFollowsDataInferenceError |
81 |
| - False: |
82 |
| - Raise ComputeFollowsDataInferenceError |
83 |
| - False: |
84 |
| - Check if the kernel was invoked from within a dpctl.device_context. |
85 |
| - True: |
86 |
| - Provide a deprecation warning for device_context use and |
87 |
| - point to using dpctl.tensor.usm_ndarray of dpnp.ndarray |
88 |
| -
|
89 |
| - return dpctl.get_current_queue |
90 |
| - False: |
91 |
| - Raise ExecutionQueueInferenceError |
| 46 | + The execution queue is derived following Python Array API's |
| 47 | + "compute follows data" programming model. |
92 | 48 |
|
93 | 49 | Args:
|
94 |
| - args : A list of arguments passed to the kernel stored in the |
95 |
| - launcher. |
96 | 50 | argtypes : The Numba inferred type for each argument.
|
| 51 | + kernel_name : The name of the kernel function |
97 | 52 |
|
98 | 53 | Returns:
|
99 | 54 | A queue the common queue used to allocate the arrays. If no such
|
100 | 55 | queue exists, then raises an Exception.
|
101 | 56 |
|
102 | 57 | Raises:
|
103 |
| - ComputeFollowsDataInferenceError: If the queue could not be inferred |
104 |
| - using compute follows data rules. |
105 |
| - ExecutionQueueInferenceError: If the queue could not be inferred |
106 |
| - using the dpctl queue manager. |
| 58 | + ExecutionQueueInferenceError: If the queue could not be inferred. |
107 | 59 | """
|
108 | 60 |
|
109 |
| - # FIXME: The args parameter is not needed once numpy support is removed |
110 |
| - |
111 |
| - # Needed as USMNdArray derives from Array |
112 |
| - array_argnums = [ |
113 |
| - i |
114 |
| - for i, _ in enumerate(args) |
115 |
| - if isinstance(argtypes[i], NpArrayType) |
116 |
| - and not isinstance(argtypes[i], USMNdArray) |
117 |
| - ] |
118 | 61 | usmarray_argnums = [
|
119 | 62 | i for i, _ in enumerate(args) if isinstance(argtypes[i], USMNdArray)
|
120 | 63 | ]
|
121 | 64 |
|
122 |
| - # if usm and non-usm array arguments are getting mixed, then the |
123 |
| - # execution queue cannot be inferred using compute follows data rules. |
124 |
| - if array_argnums and usmarray_argnums: |
125 |
| - raise ComputeFollowsDataInferenceError( |
126 |
| - array_argnums, usmarray_argnum_list=usmarray_argnums |
| 65 | + usm_array_args = [ |
| 66 | + argtype for i, argtype in enumerate(argtypes) if i in usmarray_argnums |
| 67 | + ] |
| 68 | + |
| 69 | + queue = chk_compute_follows_data_compliance(usm_array_args) |
| 70 | + |
| 71 | + if not queue: |
| 72 | + raise ExecutionQueueInferenceError( |
| 73 | + kernel_name, usmarray_argnum_list=usmarray_argnums |
127 | 74 | )
|
128 |
| - elif array_argnums and not usmarray_argnums: |
129 |
| - if dpctl.is_in_device_context(): |
130 |
| - warn( |
131 |
| - "Support for dpctl.device_context to specify the " |
132 |
| - + "execution queue is deprecated. " |
133 |
| - + "Use dpctl.tensor.usm_ndarray based array " |
134 |
| - + "containers instead. ", |
135 |
| - DeprecationWarning, |
136 |
| - stacklevel=2, |
137 |
| - ) |
138 |
| - warn( |
139 |
| - "Support for NumPy ndarray objects as kernel arguments is " |
140 |
| - + "deprecated. Use dpctl.tensor.usm_ndarray based array " |
141 |
| - + "containers instead. ", |
142 |
| - DeprecationWarning, |
143 |
| - stacklevel=2, |
144 |
| - ) |
145 |
| - return dpctl.get_current_queue() |
146 |
| - else: |
147 |
| - raise ExecutionQueueInferenceError(kernel_name) |
148 |
| - elif usmarray_argnums and not array_argnums: |
149 |
| - if dpctl.is_in_device_context(): |
150 |
| - warn( |
151 |
| - "dpctl.device_context ignored as the kernel arguments " |
152 |
| - + "are dpctl.tensor.usm_ndarray based array containers." |
153 |
| - ) |
154 |
| - usm_array_args = [ |
155 |
| - argtype |
156 |
| - for i, argtype in enumerate(argtypes) |
157 |
| - if i in usmarray_argnums |
158 |
| - ] |
159 |
| - |
160 |
| - queue = chk_compute_follows_data_compliance(usm_array_args) |
161 | 75 |
|
162 |
| - if not queue: |
163 |
| - raise ComputeFollowsDataInferenceError( |
164 |
| - kernel_name, usmarray_argnum_list=usmarray_argnums |
165 |
| - ) |
166 |
| - |
167 |
| - return queue |
168 |
| - else: |
169 |
| - if dpctl.is_in_device_context(): |
170 |
| - warn( |
171 |
| - "Support for dpctl.device_context to specify the " |
172 |
| - + "execution queue is deprecated. " |
173 |
| - + "Use dpctl.tensor.usm_ndarray based array " |
174 |
| - + "containers instead. ", |
175 |
| - DeprecationWarning, |
176 |
| - stacklevel=2, |
177 |
| - ) |
178 |
| - return dpctl.get_current_queue() |
179 |
| - else: |
180 |
| - raise ExecutionQueueInferenceError(kernel_name) |
| 76 | + return queue |
0 commit comments