@@ -36,23 +36,36 @@ def cudf_from_avro_util(schema: dict, records: list) -> cudf.DataFrame:
3636 return cudf .read_avro (buffer )
3737
3838
39- avro_type_params = [
40- ("boolean" , "bool" ),
41- ("int" , "int32" ),
42- ("long" , "int64" ),
43- ("float" , "float32" ),
44- ("double" , "float64" ),
45- ("bytes" , "str" ),
46- ("string" , "str" ),
47- ]
48-
49-
50- @pytest .mark .parametrize ("avro_type, expected_dtype" , avro_type_params )
39+ @pytest .fixture (
40+ params = [
41+ ("boolean" , "bool" ),
42+ ("int" , "int32" ),
43+ ("long" , "int64" ),
44+ ("float" , "float32" ),
45+ ("double" , "float64" ),
46+ ("bytes" , "str" ),
47+ ("string" , "str" ),
48+ ]
49+ )
50+ def avro_type_params (request ):
51+ return request .param
52+
53+
54+ @pytest .fixture (params = [True , False ])
55+ def nullable (request ):
56+ return request .param
57+
58+
59+ @pytest .fixture (params = [True , False ])
60+ def prepend_null (request ):
61+ return request .param
62+
63+
5164@pytest .mark .parametrize ("namespace" , [None , "root_ns" ])
52- @pytest .mark .parametrize ("nullable" , [True , False ])
5365def test_can_detect_dtype_from_avro_type (
54- avro_type , expected_dtype , namespace , nullable
66+ avro_type_params , namespace , nullable
5567):
68+ avro_type , expected_dtype = avro_type_params
5669 avro_type = avro_type if not nullable else ["null" , avro_type ]
5770
5871 schema = fastavro .parse_schema (
@@ -73,12 +86,11 @@ def test_can_detect_dtype_from_avro_type(
7386 assert_eq (expected , actual )
7487
7588
76- @pytest .mark .parametrize ("avro_type, expected_dtype" , avro_type_params )
7789@pytest .mark .parametrize ("namespace" , [None , "root_ns" ])
78- @pytest .mark .parametrize ("nullable" , [True , False ])
7990def test_can_detect_dtype_from_avro_type_nested (
80- avro_type , expected_dtype , namespace , nullable
91+ avro_type_params , namespace , nullable
8192):
93+ avro_type , expected_dtype = avro_type_params
8294 avro_type = avro_type if not nullable else ["null" , avro_type ]
8395
8496 schema_leaf = {
@@ -146,8 +158,8 @@ def test_can_parse_single_value(avro_type, cudf_type, avro_val, cudf_val):
146158 assert_eq (expected , actual )
147159
148160
149- @ pytest . mark . parametrize ( "avro_type, cudf_type" , avro_type_params )
150- def test_can_parse_single_null ( avro_type , cudf_type ):
161+ def test_can_parse_single_null ( avro_type_params ):
162+ avro_type , expected_dtype = avro_type_params
151163 schema_root = {
152164 "name" : "root" ,
153165 "type" : "record" ,
@@ -159,14 +171,14 @@ def test_can_parse_single_null(avro_type, cudf_type):
159171 actual = cudf_from_avro_util (schema_root , records )
160172
161173 expected = cudf .DataFrame (
162- {"prop" : cudf .Series (data = [None ], dtype = cudf_type )}
174+ {"prop" : cudf .Series (data = [None ], dtype = expected_dtype )}
163175 )
164176
165177 assert_eq (expected , actual )
166178
167179
168- @ pytest . mark . parametrize ( "avro_type, cudf_type" , avro_type_params )
169- def test_can_parse_no_data ( avro_type , cudf_type ):
180+ def test_can_parse_no_data ( avro_type_params ):
181+ avro_type , expected_dtype = avro_type_params
170182 schema_root = {
171183 "name" : "root" ,
172184 "type" : "record" ,
@@ -177,16 +189,18 @@ def test_can_parse_no_data(avro_type, cudf_type):
177189
178190 actual = cudf_from_avro_util (schema_root , records )
179191
180- expected = cudf .DataFrame ({"prop" : cudf .Series (data = [], dtype = cudf_type )})
192+ expected = cudf .DataFrame (
193+ {"prop" : cudf .Series (data = [], dtype = expected_dtype )}
194+ )
181195
182196 assert_eq (expected , actual )
183197
184198
185199@pytest .mark .xfail (
186200 reason = "cudf avro reader is unable to parse zero-field metadata."
187201)
188- @ pytest . mark . parametrize ( "avro_type, cudf_type" , avro_type_params )
189- def test_can_parse_no_fields ( avro_type , cudf_type ):
202+ def test_can_parse_no_fields ( avro_type_params ):
203+ avro_type , expected_dtype = avro_type_params
190204 schema_root = {
191205 "name" : "root" ,
192206 "type" : "record" ,
@@ -251,26 +265,15 @@ def test_avro_decompression(set_decomp_env_vars, rows, codec):
251265 assert_eq (expected_df , got_df )
252266
253267
254- avro_logical_type_params = [
255- # (avro logical type, avro primitive type, cudf expected dtype)
256- ("date" , "int" , "datetime64[s]" ),
257- ]
258-
259-
260- @pytest .mark .parametrize (
261- "logical_type, primitive_type, expected_dtype" , avro_logical_type_params
262- )
263268@pytest .mark .parametrize ("namespace" , [None , "root_ns" ])
264- @pytest .mark .parametrize ("nullable" , [True , False ])
265- @pytest .mark .parametrize ("prepend_null" , [True , False ])
266269def test_can_detect_dtypes_from_avro_logical_type (
267- logical_type ,
268- primitive_type ,
269- expected_dtype ,
270270 namespace ,
271271 nullable ,
272272 prepend_null ,
273273):
274+ logical_type = "date"
275+ primitive_type = "int"
276+ expected_dtype = "datetime64[s]"
274277 avro_type = [{"logicalType" : logical_type , "type" : primitive_type }]
275278 if nullable :
276279 if prepend_null :
@@ -303,8 +306,6 @@ def get_days_from_epoch(date: datetime.date | None) -> int | None:
303306
304307
305308@pytest .mark .parametrize ("namespace" , [None , "root_ns" ])
306- @pytest .mark .parametrize ("nullable" , [True , False ])
307- @pytest .mark .parametrize ("prepend_null" , [True , False ])
308309@pytest .mark .skipif (
309310 PANDAS_VERSION < PANDAS_CURRENT_SUPPORTED_VERSION ,
310311 reason = "Fails in older versions of pandas (datetime(9999, ...) too large)" ,
0 commit comments