31
31
"Info" ,
32
32
]
33
33
34
- from dataclasses import dataclass
35
34
from typing import (
36
35
Any ,
37
36
List ,
45
44
Protocol ,
46
45
)
47
46
from enum import Enum
47
+ from .data_types import DType
48
48
49
49
array = TypeVar ("array" , bound = "_array" )
50
50
device = TypeVar ("device" )
51
- dtype = TypeVar ("dtype" )
52
- Device = TypeVar ("Device" )
53
- Dtype = TypeVar ("Dtype" )
51
+ dtype = TypeVar ("dtype" , bound = DType )
52
+ device_ = TypeVar ("device_" ) # only used in this file
53
+ dtype_ = TypeVar ("dtype_" , bound = DType ) # only used in this file
54
54
SupportsDLPack = TypeVar ("SupportsDLPack" )
55
55
SupportsBufferProtocol = TypeVar ("SupportsBufferProtocol" )
56
56
PyCapsule = TypeVar ("PyCapsule" )
@@ -149,12 +149,12 @@ def dtypes(
149
149
)
150
150
151
151
152
- class _array (Protocol [array , Dtype , Device , PyCapsule ]): # type: ignore
152
+ class _array (Protocol [array , dtype_ , device_ , PyCapsule ]): # type: ignore
153
153
def __init__ (self : array ) -> None :
154
154
"""Initialize the attributes for the array object class."""
155
155
156
156
@property
157
- def dtype (self : array ) -> Dtype :
157
+ def dtype (self : array ) -> dtype_ :
158
158
"""
159
159
Data type of the array elements.
160
160
@@ -165,7 +165,7 @@ def dtype(self: array) -> Dtype:
165
165
"""
166
166
167
167
@property
168
- def device (self : array ) -> Device :
168
+ def device (self : array ) -> device_ :
169
169
"""
170
170
Hardware device the array data resides on.
171
171
@@ -1344,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
1344
1344
"""
1345
1345
1346
1346
def to_device (
1347
- self : array , device : Device , / , * , stream : Optional [Union [int , Any ]] = None
1347
+ self : array , device : device_ , / , * , stream : Optional [Union [int , Any ]] = None
1348
1348
) -> array :
1349
1349
"""
1350
1350
Copy the array from the device on which it currently resides to the specified ``device``.
0 commit comments