@@ -67,7 +67,7 @@ def __class_getitem__(self, params):
6767 # No dimension provided, e.g., Vector[np.float32]
6868 dtype = params
6969 # Use NDArray for supported numeric dtypes, else list
70- if DtypeRegistry . get_by_dtype ( dtype ) is not None :
70+ if dtype in DtypeRegistry . _DTYPE_TO_KIND :
7171 return Annotated [NDArray [dtype ], VectorInfo (dim = None )]
7272 return Annotated [list [dtype ], VectorInfo (dim = None )]
7373 else :
@@ -79,7 +79,7 @@ def __class_getitem__(self, params):
7979 if typing .get_origin (dim_literal ) is Literal
8080 else None
8181 )
82- if DtypeRegistry . get_by_dtype ( dtype ) is not None :
82+ if dtype in DtypeRegistry . _DTYPE_TO_KIND :
8383 return Annotated [NDArray [dtype ], VectorInfo (dim = dim_val )]
8484 return Annotated [list [dtype ], VectorInfo (dim = dim_val )]
8585
@@ -119,34 +119,28 @@ class DtypeRegistry:
119119 Maps NumPy dtypes to their CocoIndex type kind.
120120 """
121121
122- _DTYPE_TO_KIND : dict [type , str ] = {
122+ _DTYPE_TO_KIND : dict [ElementType , str ] = {
123123 np .float32 : "Float32" ,
124124 np .float64 : "Float64" ,
125125 np .int64 : "Int64" ,
126126 }
127127
128128 @classmethod
129- def get_by_dtype (cls , dtype : Any ) -> tuple [type , str ] | None :
130- """Get the NumPy dtype and its CocoIndex kind by dtype."""
129+ def validate_dtype_and_get_kind (cls , dtype : ElementType ) -> str :
130+ """
131+ Validate that the given dtype is supported, and get its CocoIndex kind by dtype.
132+ """
131133 if dtype is Any :
132134 raise TypeError (
133135 "NDArray for Vector must use a concrete numpy dtype, got `Any`."
134136 )
135137 kind = cls ._DTYPE_TO_KIND .get (dtype )
136- return None if kind is None else (dtype , kind )
137-
138- @classmethod
139- def validate_and_get_dtype_info (cls , dtype : Any ) -> tuple [type , str ]:
140- """
141- Validate that the given dtype is supported.
142- """
143- dtype_info = cls .get_by_dtype (dtype )
144- if dtype_info is None :
138+ if kind is None :
145139 raise ValueError (
146140 f"Unsupported NumPy dtype in NDArray: { dtype } . "
147141 f"Supported dtypes: { cls ._DTYPE_TO_KIND .keys ()} "
148142 )
149- return dtype_info
143+ return kind
150144
151145
152146@dataclasses .dataclass
@@ -227,7 +221,8 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
227221 elif kind != "Struct" :
228222 raise ValueError (f"Unexpected type kind for struct: { kind } " )
229223 elif is_numpy_number_type (t ):
230- np_number_type , kind = DtypeRegistry .validate_and_get_dtype_info (t )
224+ np_number_type = t
225+ kind = DtypeRegistry .validate_dtype_and_get_kind (t )
231226 elif base_type is collections .abc .Sequence or base_type is list :
232227 args = typing .get_args (t )
233228 elem_type = args [0 ]
@@ -249,7 +244,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
249244 kind = "Vector"
250245 np_number_type = t
251246 elem_type = extract_ndarray_scalar_dtype (np_number_type )
252- _ = DtypeRegistry .validate_and_get_dtype_info (elem_type )
247+ _ = DtypeRegistry .validate_dtype_and_get_kind (elem_type )
253248 vector_info = VectorInfo (dim = None ) if vector_info is None else vector_info
254249
255250 elif base_type is collections .abc .Mapping or base_type is dict :
0 commit comments