@@ -81,10 +81,10 @@ def test_parquet_index_write(
8181
8282 for i , _ds in enumerate (ds ):
8383 idx = i % 5
84- assert len (_ds ) == 3
85- assert _ds [0 ] == pq_data ["name" ][idx ]
86- assert _ds [1 ] == pq_data ["weight" ][idx ]
87- assert _ds [2 ] == pq_data ["height" ][idx ]
84+ assert isinstance (_ds , dict )
85+ assert _ds ["name" ] == pq_data ["name" ][idx ]
86+ assert _ds ["weight" ] == pq_data ["weight" ][idx ]
87+ assert _ds ["height" ] == pq_data ["height" ][idx ]
8888
8989
9090@pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Fails on windows and test gets cancelled" )
@@ -168,7 +168,9 @@ def test_get_parquet_indexer_cls(pq_url, cls, expectation, monkeypatch, fsspec_m
168168@pytest .mark .usefixtures ("clean_pq_index_cache" )
169169@patch ("litdata.utilities.parquet._HF_HUB_AVAILABLE" , True )
170170@patch ("litdata.streaming.downloader._HF_HUB_AVAILABLE" , True )
171- def test_stream_hf_parquet_dataset (monkeypatch , huggingface_hub_fs_mock , pq_data ):
171+ @pytest .mark .parametrize (("pre_load_chunk" ), [False , True ])
172+ @pytest .mark .parametrize (("low_memory" ), [False , True ])
173+ def test_stream_hf_parquet_dataset (monkeypatch , huggingface_hub_fs_mock , pq_data , pre_load_chunk , low_memory ):
172174 hf_url = "hf://datasets/some_org/some_repo/some_path"
173175
174176 # Test case 1: Invalid item_loader
@@ -180,27 +182,18 @@ def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data
180182 assert len (ds ) == 25 # 5 datasets for 5 loops
181183 for i , _ds in enumerate (ds ):
182184 idx = i % 5
183- assert len (_ds ) == 3
184- assert _ds [0 ] == pq_data ["name" ][idx ]
185- assert _ds [1 ] == pq_data ["weight" ][idx ]
186- assert _ds [2 ] == pq_data ["height" ][idx ]
187-
188- # Test case 3: Streaming with ParquetLoader as item_loader and low_memory=False
189- ds = StreamingDataset (hf_url , item_loader = ParquetLoader (low_memory = False ))
190- assert len (ds ) == 25
191- for i , _ds in enumerate (ds ):
192- idx = i % 5
193- assert len (_ds ) == 3
194- assert _ds [0 ] == pq_data ["name" ][idx ]
195- assert _ds [1 ] == pq_data ["weight" ][idx ]
196- assert _ds [2 ] == pq_data ["height" ][idx ]
197-
198- # Test case 4: Streaming with ParquetLoader and low_memory=True
199- ds = StreamingDataset (hf_url , item_loader = ParquetLoader (low_memory = True ))
185+ assert isinstance (_ds , dict )
186+ assert _ds ["name" ] == pq_data ["name" ][idx ]
187+ assert _ds ["weight" ] == pq_data ["weight" ][idx ]
188+ assert _ds ["height" ] == pq_data ["height" ][idx ]
189+
190+ # Test case 3: Streaming with passing item_loader
191+ print ("pre_load_chunk" , pre_load_chunk , "low_memory" , low_memory )
192+ ds = StreamingDataset (hf_url , item_loader = ParquetLoader (pre_load_chunk , low_memory ))
200193 assert len (ds ) == 25
201194 for i , _ds in enumerate (ds ):
202195 idx = i % 5
203- assert len (_ds ) == 3
204- assert _ds [0 ] == pq_data ["name" ][idx ]
205- assert _ds [1 ] == pq_data ["weight" ][idx ]
206- assert _ds [2 ] == pq_data ["height" ][idx ]
196+ assert isinstance (_ds , dict )
197+ assert _ds ["name" ] == pq_data ["name" ][idx ]
198+ assert _ds ["weight" ] == pq_data ["weight" ][idx ]
199+ assert _ds ["height" ] == pq_data ["height" ][idx ]
0 commit comments