@@ -169,16 +169,26 @@ def _dtypes(self, kind):
169
169
int32 = torch .int32
170
170
int64 = torch .int64
171
171
uint8 = torch .uint8
172
- # uint16, uint32, and uint64 are present in newer versions of pytorch,
173
- # but they aren't generally supported by the array API functions, so
174
- # we omit them from this function.
172
+ try :
173
+ # pytorch >= 2.3
174
+ uint16 = torch .uint16
175
+ uint32 = torch .uint32
176
+ uint64 = torch .uint64
177
+ uint_kinds = {
178
+ "uint16" : uint16 ,
179
+ "uint32" : uint32 ,
180
+ "uint64" : uint64 ,
181
+ }
182
+ except AttributeError :
183
+ uint_kinds = {}
184
+
175
185
float32 = torch .float32
176
186
float64 = torch .float64
177
187
complex64 = torch .complex64
178
188
complex128 = torch .complex128
179
189
180
190
if kind is None :
181
- return {
191
+ kinds = {
182
192
"bool" : bool ,
183
193
"int8" : int8 ,
184
194
"int16" : int16 ,
@@ -190,6 +200,8 @@ def _dtypes(self, kind):
190
200
"complex64" : complex64 ,
191
201
"complex128" : complex128 ,
192
202
}
203
+ kinds .update (uint_kinds )
204
+ return kinds
193
205
if kind == "bool" :
194
206
return {"bool" : bool }
195
207
if kind == "signed integer" :
@@ -200,17 +212,21 @@ def _dtypes(self, kind):
200
212
"int64" : int64 ,
201
213
}
202
214
if kind == "unsigned integer" :
203
- return {
215
+ kinds = {
204
216
"uint8" : uint8 ,
205
217
}
218
+ kinds .update (uint_kinds )
219
+ return kinds
206
220
if kind == "integral" :
207
- return {
221
+ kinds = {
208
222
"int8" : int8 ,
209
223
"int16" : int16 ,
210
224
"int32" : int32 ,
211
225
"int64" : int64 ,
212
226
"uint8" : uint8 ,
213
227
}
228
+ kinds .update (uint_kinds )
229
+ return kinds
214
230
if kind == "real floating" :
215
231
return {
216
232
"float32" : float32 ,
@@ -222,7 +238,7 @@ def _dtypes(self, kind):
222
238
"complex128" : complex128 ,
223
239
}
224
240
if kind == "numeric" :
225
- return {
241
+ kinds = {
226
242
"int8" : int8 ,
227
243
"int16" : int16 ,
228
244
"int32" : int32 ,
@@ -233,6 +249,9 @@ def _dtypes(self, kind):
233
249
"complex64" : complex64 ,
234
250
"complex128" : complex128 ,
235
251
}
252
+ kinds .update (uint_kinds )
253
+ return kinds
254
+
236
255
if isinstance (kind , tuple ):
237
256
res = {}
238
257
for k in kind :
0 commit comments