@@ -119,11 +119,11 @@ class DtypeRegistry:
119119 _mappings : dict [type , DtypeInfo ] = {
120120 np .float32 : DtypeInfo (np .float32 , "Float32" , float ),
121121 np .float64 : DtypeInfo (np .float64 , "Float64" , float ),
122- np .int32 : DtypeInfo (np .int32 , "Int32 " , int ),
122+ np .int32 : DtypeInfo (np .int32 , "Int64 " , int ),
123123 np .int64 : DtypeInfo (np .int64 , "Int64" , int ),
124- np .uint8 : DtypeInfo (np .uint8 , "UInt8 " , int ),
125- np .uint16 : DtypeInfo (np .uint16 , "UInt16 " , int ),
126- np .uint32 : DtypeInfo (np .uint32 , "UInt32 " , int ),
124+ np .uint8 : DtypeInfo (np .uint8 , "Int64 " , int ),
125+ np .uint16 : DtypeInfo (np .uint16 , "Int64 " , int ),
126+ np .uint32 : DtypeInfo (np .uint32 , "Int64 " , int ),
127127 }
128128
129129 @classmethod
@@ -135,20 +135,6 @@ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
135135 )
136136 return cls ._mappings .get (dtype )
137137
138- @staticmethod
139- def get_by_kind (kind : str ) -> DtypeInfo | None :
140- """Get DtypeInfo by kind."""
141- return next (
142- (info for info in DtypeRegistry ._mappings .values () if info .kind == kind ),
143- None ,
144- )
145-
146- @staticmethod
147- def rust_compatible_kind (kind : str ) -> str :
148- """Map to a Rust-compatible kind for schema encoding."""
149- # incompatible_integer_kinds = {"Int32", "UInt8", "UInt16", "UInt32", "UInt64"}
150- return "Int64" if "Int" in kind else kind
151-
152138 @staticmethod
153139 def supported_dtypes () -> KeysView [type ]:
154140 """Get a list of supported NumPy dtypes."""
@@ -167,6 +153,9 @@ class AnalyzedTypeInfo:
167153
168154 key_type : type | None # For element of KTable
169155 struct_type : type | None # For Struct, a dataclass or namedtuple
156+ np_number_type : (
157+ type | None
158+ ) # NumPy dtype for the element type, if represented by numpy.ndarray or a NumPy scalar
170159
171160 attrs : dict [str , Any ] | None
172161 nullable : bool = False
@@ -221,6 +210,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
221210 struct_type : type | None = None
222211 elem_type : ElementType | None = None
223212 key_type : type | None = None
213+ np_number_type : type | None = None
224214 if _is_struct_type (t ):
225215 struct_type = t
226216
@@ -254,11 +244,11 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
254244 if not dtype_args :
255245 raise ValueError ("Invalid dtype specification for NDArray" )
256246
257- numpy_dtype = dtype_args [0 ]
258- dtype_info = DtypeRegistry .get_by_dtype (numpy_dtype )
247+ np_number_type = dtype_args [0 ]
248+ dtype_info = DtypeRegistry .get_by_dtype (np_number_type )
259249 if dtype_info is None :
260250 raise ValueError (
261- f"Unsupported numpy dtype for NDArray: { numpy_dtype } . "
251+ f"Unsupported numpy dtype for NDArray: { np_number_type } . "
262252 f"Supported dtypes: { DtypeRegistry .supported_dtypes ()} "
263253 )
264254 elem_type = dtype_info .annotated_type
@@ -272,6 +262,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
272262 dtype_info = DtypeRegistry .get_by_dtype (t )
273263 if dtype_info is not None :
274264 kind = dtype_info .kind
265+ np_number_type = dtype_info .numpy_dtype
275266 elif t is bytes :
276267 kind = "Bytes"
277268 elif t is str :
@@ -301,6 +292,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
301292 elem_type = elem_type ,
302293 key_type = key_type ,
303294 struct_type = struct_type ,
295+ np_number_type = np_number_type ,
304296 attrs = attrs ,
305297 nullable = nullable ,
306298 )
@@ -355,9 +347,6 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
355347 raise ValueError ("Vector type must have an element type" )
356348 elem_type_info = analyze_type_info (type_info .elem_type )
357349 encoded_type ["element_type" ] = _encode_type (elem_type_info )
358- encoded_type ["element_type" ]["kind" ] = DtypeRegistry .rust_compatible_kind (
359- elem_type_info .kind
360- )
361350 encoded_type ["dimension" ] = type_info .vector_info .dim
362351
363352 elif type_info .kind in TABLE_TYPES :
0 commit comments