Skip to content

Commit e2dc3ad

Browse files
committed
ENH: torch: add uintN type to __array_namespace_info__
1 parent bf43770 commit e2dc3ad

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

array_api_compat/torch/_info.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,26 @@ def _dtypes(self, kind):
169169
int32 = torch.int32
170170
int64 = torch.int64
171171
uint8 = torch.uint8
172-
# uint16, uint32, and uint64 are present in newer versions of pytorch,
173-
# but they aren't generally supported by the array API functions, so
174-
# we omit them from this function.
172+
try:
173+
# pytorch >= 2.3
174+
uint16 = torch.uint16
175+
uint32 = torch.uint32
176+
uint64 = torch.uint64
177+
uint_kinds = {
178+
"uint16": uint16,
179+
"uint32": uint32,
180+
"uint64": uint64,
181+
}
182+
except AttributeError:
183+
uint_kinds = {}
184+
175185
float32 = torch.float32
176186
float64 = torch.float64
177187
complex64 = torch.complex64
178188
complex128 = torch.complex128
179189

180190
if kind is None:
181-
return {
191+
kinds = {
182192
"bool": bool,
183193
"int8": int8,
184194
"int16": int16,
@@ -190,6 +200,8 @@ def _dtypes(self, kind):
190200
"complex64": complex64,
191201
"complex128": complex128,
192202
}
203+
kinds.update(uint_kinds)
204+
return kinds
193205
if kind == "bool":
194206
return {"bool": bool}
195207
if kind == "signed integer":
@@ -200,17 +212,21 @@ def _dtypes(self, kind):
200212
"int64": int64,
201213
}
202214
if kind == "unsigned integer":
203-
return {
215+
kinds= {
204216
"uint8": uint8,
205217
}
218+
kinds.update(uint_kinds)
219+
return kinds
206220
if kind == "integral":
207-
return {
221+
kinds= {
208222
"int8": int8,
209223
"int16": int16,
210224
"int32": int32,
211225
"int64": int64,
212226
"uint8": uint8,
213227
}
228+
kinds.update(uint_kinds)
229+
return kinds
214230
if kind == "real floating":
215231
return {
216232
"float32": float32,
@@ -222,7 +238,7 @@ def _dtypes(self, kind):
222238
"complex128": complex128,
223239
}
224240
if kind == "numeric":
225-
return {
241+
kinds = {
226242
"int8": int8,
227243
"int16": int16,
228244
"int32": int32,
@@ -233,6 +249,9 @@ def _dtypes(self, kind):
233249
"complex64": complex64,
234250
"complex128": complex128,
235251
}
252+
kinds.update(uint_kinds)
253+
return kinds
254+
236255
if isinstance(kind, tuple):
237256
res = {}
238257
for k in kind:

0 commit comments

Comments
 (0)