11import dataclasses
22import datetime
33import uuid
4- from typing import (
5- Annotated ,
6- List ,
7- Dict ,
8- Literal ,
9- Any ,
10- get_args ,
11- NamedTuple ,
12- )
13- from collections .abc import Sequence , Mapping
14- import pytest
4+ from collections .abc import Mapping , Sequence
5+ from typing import Annotated , Any , Dict , List , Literal , NamedTuple , get_args , get_origin
6+
157import numpy as np
8+ import pytest
169from numpy .typing import NDArray
1710
1811from cocoindex .typing import (
19- analyze_type_info ,
20- Vector ,
21- VectorInfo ,
22- TypeKind ,
23- TypeAttr ,
12+ AnalyzedTypeInfo ,
2413 Float32 ,
2514 Float64 ,
15+ TypeAttr ,
16+ TypeKind ,
17+ Vector ,
18+ VectorInfo ,
19+ analyze_type_info ,
2620 encode_enriched_type ,
27- AnalyzedTypeInfo ,
2821)
2922
3023
@@ -42,61 +35,57 @@ class SimpleNamedTuple(NamedTuple):
4235def test_ndarray_float32_no_dim () -> None :
4336 typ = NDArray [np .float32 ]
4437 result = analyze_type_info (typ )
45- assert result == AnalyzedTypeInfo (
46- kind = "Vector" ,
47- vector_info = VectorInfo (dim = None ),
48- elem_type = Float32 ,
49- key_type = None ,
50- struct_type = None ,
51- np_number_type = np .ndarray [tuple [int , ...], np .dtype [np .float32 ]],
52- attrs = None ,
53- nullable = False ,
54- )
38+ assert result .kind == "Vector"
39+ assert result .vector_info == VectorInfo (dim = None )
40+ assert result .elem_type == Float32
41+ assert result .key_type is None
42+ assert result .struct_type is None
43+ assert result .nullable is False
44+ assert result .np_number_type is not None
45+ assert get_origin (result .np_number_type ) == np .ndarray
46+ assert get_args (result .np_number_type )[1 ] == np .dtype [np .float32 ]
5547
5648
5749def test_vector_float32_no_dim () -> None :
5850 typ = Vector [np .float32 ]
5951 result = analyze_type_info (typ )
60- assert result == AnalyzedTypeInfo (
61- kind = "Vector" ,
62- vector_info = VectorInfo (dim = None ),
63- elem_type = Float32 ,
64- key_type = None ,
65- struct_type = None ,
66- np_number_type = np .ndarray [tuple [int , ...], np .dtype [np .float32 ]],
67- attrs = None ,
68- nullable = False ,
69- )
52+ assert result .kind == "Vector"
53+ assert result .vector_info == VectorInfo (dim = None )
54+ assert result .elem_type == Float32
55+ assert result .key_type is None
56+ assert result .struct_type is None
57+ assert result .nullable is False
58+ assert result .np_number_type is not None
59+ assert get_origin (result .np_number_type ) == np .ndarray
60+ assert get_args (result .np_number_type )[1 ] == np .dtype [np .float32 ]
7061
7162
7263def test_ndarray_float64_with_dim () -> None :
7364 typ = Annotated [NDArray [np .float64 ], VectorInfo (dim = 128 )]
7465 result = analyze_type_info (typ )
75- assert result == AnalyzedTypeInfo (
76- kind = "Vector" ,
77- vector_info = VectorInfo (dim = 128 ),
78- elem_type = Float64 ,
79- key_type = None ,
80- struct_type = None ,
81- np_number_type = np .ndarray [tuple [int , ...], np .dtype [np .float64 ]],
82- attrs = None ,
83- nullable = False ,
84- )
66+ assert result .kind == "Vector"
67+ assert result .vector_info == VectorInfo (dim = 128 )
68+ assert result .elem_type == Float64
69+ assert result .key_type is None
70+ assert result .struct_type is None
71+ assert result .nullable is False
72+ assert result .np_number_type is not None
73+ assert get_origin (result .np_number_type ) == np .ndarray
74+ assert get_args (result .np_number_type )[1 ] == np .dtype [np .float64 ]
8575
8676
8777def test_vector_float32_with_dim () -> None :
8878 typ = Vector [np .float32 , Literal [384 ]]
8979 result = analyze_type_info (typ )
90- assert result == AnalyzedTypeInfo (
91- kind = "Vector" ,
92- vector_info = VectorInfo (dim = 384 ),
93- elem_type = Float32 ,
94- key_type = None ,
95- struct_type = None ,
96- np_number_type = np .ndarray [tuple [int , ...], np .dtype [np .float32 ]],
97- attrs = None ,
98- nullable = False ,
99- )
80+ assert result .kind == "Vector"
81+ assert result .vector_info == VectorInfo (dim = 384 )
82+ assert result .elem_type == Float32
83+ assert result .key_type is None
84+ assert result .struct_type is None
85+ assert result .nullable is False
86+ assert result .np_number_type is not None
87+ assert get_origin (result .np_number_type ) == np .ndarray
88+ assert get_args (result .np_number_type )[1 ] == np .dtype [np .float32 ]
10089
10190
10291def test_ndarray_int64_no_dim () -> None :
@@ -105,37 +94,39 @@ def test_ndarray_int64_no_dim() -> None:
10594 assert result .kind == "Vector"
10695 assert result .vector_info == VectorInfo (dim = None )
10796 assert get_args (result .elem_type ) == (int , TypeKind ("Int64" ))
108- assert not result .nullable
97+ assert result .nullable is False
98+ assert result .np_number_type is not None
99+ assert get_origin (result .np_number_type ) == np .ndarray
100+ assert get_args (result .np_number_type )[1 ] == np .dtype [np .int64 ]
109101
110102
111103def test_nullable_ndarray () -> None :
112104 typ = NDArray [np .float32 ] | None
113105 result = analyze_type_info (typ )
114- assert result == AnalyzedTypeInfo (
115- kind = "Vector" ,
116- vector_info = VectorInfo (dim = None ),
117- elem_type = Float32 ,
118- key_type = None ,
119- struct_type = None ,
120- np_number_type = np .ndarray [tuple [int , ...], np .dtype [np .float32 ]],
121- attrs = None ,
122- nullable = True ,
123- )
106+ assert result .kind == "Vector"
107+ assert result .vector_info == VectorInfo (dim = None )
108+ assert result .elem_type == Float32
109+ assert result .key_type is None
110+ assert result .struct_type is None
111+ assert result .nullable is True
112+ assert result .np_number_type is not None
113+ assert get_origin (result .np_number_type ) == np .ndarray
114+ assert get_args (result .np_number_type )[1 ] == np .dtype [np .float32 ]
124115
125116
126- def test_scalar_numpy_types ():
117+ def test_scalar_numpy_types () -> None :
127118 for np_type , expected_kind in [
128119 (np .int64 , "Int64" ),
129120 (np .float32 , "Float32" ),
130121 (np .float64 , "Float64" ),
131122 ]:
132123 type_info = analyze_type_info (np_type )
133- assert (
134- type_info .kind == expected_kind
135- ), f"Expected { expected_kind } for { np_type } , got { type_info . kind } "
136- assert (
137- type_info .np_number_type == np_type
138- ), f"Expected { np_type } , got { type_info . np_number_type } "
124+ assert type_info . kind == expected_kind , (
125+ f"Expected { expected_kind } for { np_type } , got { type_info .kind } "
126+ )
127+ assert type_info . np_number_type == np_type , (
128+ f"Expected { np_type } , got { type_info .np_number_type } "
129+ )
139130 assert type_info .elem_type is None
140131 assert type_info .vector_info is None
141132
@@ -144,7 +135,7 @@ def test_vector_str() -> None:
144135 typ = Vector [str ]
145136 result = analyze_type_info (typ )
146137 assert result .kind == "Vector"
147- assert result .elem_type == str
138+ assert result .elem_type is str
148139 assert result .vector_info == VectorInfo (dim = None )
149140
150141
@@ -160,7 +151,7 @@ def test_non_numpy_vector() -> None:
160151 typ = Vector [float , Literal [3 ]]
161152 result = analyze_type_info (typ )
162153 assert result .kind == "Vector"
163- assert result .elem_type == float
154+ assert result .elem_type is float
164155 assert result .vector_info == VectorInfo (dim = 3 )
165156
166157
@@ -504,16 +495,16 @@ def test_encode_enriched_type_nullable() -> None:
504495 assert result ["nullable" ] is True
505496
506497
507- def test_encode_scalar_numpy_types_schema ():
498+ def test_encode_scalar_numpy_types_schema () -> None :
508499 for np_type , expected_kind in [
509500 (np .int64 , "Int64" ),
510501 (np .float32 , "Float32" ),
511502 (np .float64 , "Float64" ),
512503 ]:
513504 schema = encode_enriched_type (np_type )
514- assert (
515- schema [" type" ][ " kind" ] == expected_kind
516- ), f"Expected { expected_kind } for { np_type } , got { schema [ 'type' ][ 'kind' ] } "
505+ assert schema [ "type" ][ "kind" ] == expected_kind , (
506+ f"Expected { expected_kind } for { np_type } , got { schema [' type' ][ ' kind' ] } "
507+ )
517508 assert not schema .get ("nullable" , False )
518509
519510
0 commit comments