77more details.
88
99"""
10- from torch import (
11- asarray ,
12- get_default_dtype ,
13- device ,
14- empty ,
15- bool ,
16- int8 ,
17- int16 ,
18- int32 ,
19- int64 ,
20- uint8 ,
21- uint16 ,
22- uint32 ,
23- uint64 ,
24- float32 ,
25- float64 ,
26- complex64 ,
27- complex128 ,
28- )
10+ import torch
2911
3012from functools import cache
3113
@@ -130,7 +112,7 @@ def default_device(self):
130112 'cpu'
131113
132114 """
133- return device ("cpu" )
115+ return torch . device ("cpu" )
134116
135117 def default_dtypes (self , * , device = None ):
136118 """
@@ -165,80 +147,32 @@ def default_dtypes(self, *, device=None):
165147 'indexing': torch.int64}
166148
167149 """
168- default_floating = get_default_dtype ()
169- default_complex = complex64 if default_floating == float32 else complex128
170- default_integral = asarray (0 , device = device ).dtype
150+ default_floating = torch . get_default_dtype ()
151+ default_complex = torch . complex64 if default_floating == torch . float32 else torch . complex128
152+ default_integral = torch . asarray (0 , device = device ).dtype
171153 return {
172154 "real floating" : default_floating ,
173155 "complex floating" : default_complex ,
174156 "integral" : default_integral ,
175157 "indexing" : default_integral ,
176158 }
177159
178- @cache
179- def dtypes (self , * , device = None , kind = None ):
180- """
181- The array API data types supported by PyTorch.
182-
183- Note that this function only returns data types that are defined by
184- the array API.
185-
186- Parameters
187- ----------
188- device : str, optional
189- The device to get the data types for.
190- kind : str or tuple of str, optional
191- The kind of data types to return. If ``None``, all data types are
192- returned. If a string, only data types of that kind are returned.
193- If a tuple, a dictionary containing the union of the given kinds
194- is returned. The following kinds are supported:
195-
196- - ``'bool'``: boolean data types (i.e., ``bool``).
197- - ``'signed integer'``: signed integer data types (i.e., ``int8``,
198- ``int16``, ``int32``, ``int64``).
199- - ``'unsigned integer'``: unsigned integer data types (i.e.,
200- ``uint8``, ``uint16``, ``uint32``, ``uint64``).
201- - ``'integral'``: integer data types. Shorthand for ``('signed
202- integer', 'unsigned integer')``.
203- - ``'real floating'``: real-valued floating-point data types
204- (i.e., ``float32``, ``float64``).
205- - ``'complex floating'``: complex floating-point data types (i.e.,
206- ``complex64``, ``complex128``).
207- - ``'numeric'``: numeric data types. Shorthand for ``('integral',
208- 'real floating', 'complex floating')``.
209-
210- Returns
211- -------
212- dtypes : dict
213- A dictionary mapping the names of data types to the corresponding
214- PyTorch data types.
215-
216- See Also
217- --------
218- __array_namespace_info__.capabilities,
219- __array_namespace_info__.default_device,
220- __array_namespace_info__.default_dtypes,
221- __array_namespace_info__.devices
222-
223- Examples
224- --------
225- >>> info = np.__array_namespace_info__()
226- >>> info.dtypes(kind='signed integer')
227- {'int8': numpy.int8,
228- 'int16': numpy.int16,
229- 'int32': numpy.int32,
230- 'int64': numpy.int64}
231-
232- """
233- res = self ._dtypes (kind )
234- for k , v in res .copy ().items ():
235- try :
236- empty ((0 ,), dtype = v , device = device )
237- except :
238- del res [k ]
239- return res
240160
241161 def _dtypes (self , kind ):
162+ bool = torch .bool
163+ int8 = torch .int8
164+ int16 = torch .int16
165+ int32 = torch .int32
166+ int64 = torch .int64
167+ uint8 = getattr (torch , "uint8" , None )
168+ uint16 = getattr (torch , "uint16" , None )
169+ uint32 = getattr (torch , "uint32" , None )
170+ uint64 = getattr (torch , "uint64" , None )
171+ float32 = torch .float32
172+ float64 = torch .float64
173+ complex64 = torch .complex64
174+ complex128 = torch .complex128
175+
242176 if kind is None :
243177 return {
244178 "bool" : bool ,
@@ -314,6 +248,72 @@ def _dtypes(self, kind):
314248 return res
315249 raise ValueError (f"unsupported kind: { kind !r} " )
316250
251+ @cache
252+ def dtypes (self , * , device = None , kind = None ):
253+ """
254+ The array API data types supported by PyTorch.
255+
256+ Note that this function only returns data types that are defined by
257+ the array API.
258+
259+ Parameters
260+ ----------
261+ device : str, optional
262+ The device to get the data types for.
263+ kind : str or tuple of str, optional
264+ The kind of data types to return. If ``None``, all data types are
265+ returned. If a string, only data types of that kind are returned.
266+ If a tuple, a dictionary containing the union of the given kinds
267+ is returned. The following kinds are supported:
268+
269+ - ``'bool'``: boolean data types (i.e., ``bool``).
270+ - ``'signed integer'``: signed integer data types (i.e., ``int8``,
271+ ``int16``, ``int32``, ``int64``).
272+ - ``'unsigned integer'``: unsigned integer data types (i.e.,
273+ ``uint8``, ``uint16``, ``uint32``, ``uint64``).
274+ - ``'integral'``: integer data types. Shorthand for ``('signed
275+ integer', 'unsigned integer')``.
276+ - ``'real floating'``: real-valued floating-point data types
277+ (i.e., ``float32``, ``float64``).
278+ - ``'complex floating'``: complex floating-point data types (i.e.,
279+ ``complex64``, ``complex128``).
280+ - ``'numeric'``: numeric data types. Shorthand for ``('integral',
281+ 'real floating', 'complex floating')``.
282+
283+ Returns
284+ -------
285+ dtypes : dict
286+ A dictionary mapping the names of data types to the corresponding
287+ PyTorch data types.
288+
289+ See Also
290+ --------
291+ __array_namespace_info__.capabilities,
292+ __array_namespace_info__.default_device,
293+ __array_namespace_info__.default_dtypes,
294+ __array_namespace_info__.devices
295+
296+ Examples
297+ --------
298+ >>> info = np.__array_namespace_info__()
299+ >>> info.dtypes(kind='signed integer')
300+ {'int8': numpy.int8,
301+ 'int16': numpy.int16,
302+ 'int32': numpy.int32,
303+ 'int64': numpy.int64}
304+
305+ """
306+ res = self ._dtypes (kind )
307+ for k , v in res .copy ().items ():
308+ if v is None :
309+ del res [k ]
310+ continue
311+ try :
312+ torch .empty ((0 ,), dtype = v , device = device )
313+ except :
314+ del res [k ]
315+ return res
316+
317317 @cache
318318 def devices (self ):
319319 """
@@ -343,7 +343,7 @@ def devices(self):
343343 # message of torch.device to get the list of all possible types of
344344 # device:
345345 try :
346- device ('notadevice' )
346+ torch . device ('notadevice' )
347347 except RuntimeError as e :
348348 # The error message is something like:
349349 # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
@@ -358,7 +358,7 @@ def devices(self):
358358 i = 0
359359 while True :
360360 try :
361- a = empty ((0 ,), device = device (device_name , index = i ))
361+ a = torch . empty ((0 ,), device = torch . device (device_name , index = i ))
362362 if a .device in devices :
363363 break
364364 devices .append (a .device )
0 commit comments