@@ -160,15 +160,15 @@ def default_interleaved_field_map() -> dict[str, str]:
160160 def __post_init__ (self ) -> None :
161161 """Validate reader configuration."""
162162 super ().__post_init__ ()
163- if self . sample_format not in _SUPPORTED_SAMPLE_FORMATS :
164- msg = f"Unsupported sample_format=' { self .sample_format } '. Expected one of: auto, simple, interleaved"
165- raise ValueError ( msg )
166- if self . modalities_to_load not in _SUPPORTED_MODALITIES_TO_LOAD :
167- msg = f"Unsupported modalities_to_load=' { self . modalities_to_load } '. Expected one of: all, image, text"
168- raise ValueError ( msg )
169- if self . error_handling not in _SUPPORTED_ERROR_HANDLING :
170- msg = f"Unsupported error_handling ='{ self . error_handling } '. Expected one of: raise, skip, log "
171- raise ValueError (msg )
163+ for field_name , value , supported in (
164+ ( " sample_format" , self .sample_format , _SUPPORTED_SAMPLE_FORMATS ),
165+ ( "modalities_to_load" , self . modalities_to_load , _SUPPORTED_MODALITIES_TO_LOAD ),
166+ ( "error_handling" , self . error_handling , _SUPPORTED_ERROR_HANDLING ),
167+ ):
168+ if value not in supported :
169+ options = ", " . join ( sorted ( supported ))
170+ msg = f"Unsupported { field_name } ='{ value } '. Expected one of: { options } "
171+ raise ValueError (msg )
172172 default_map = self .default_interleaved_field_map ()
173173 unknown = sorted (set (self .interleaved_field_map or {}) - set (default_map ))
174174 if unknown :
@@ -244,25 +244,22 @@ def _rows_from_text_member(
244244 source : RowSource ,
245245 ) -> list [dict [str , object ]]:
246246 if suffix == ".json" :
247- parsed = self ._maybe_rows_from_interleaved_json_member (payload , source , state , member_name )
248- if parsed is not None :
249- return parsed
247+ if payload is None :
248+ msg = f"JSON member '{ member_name } ' missing payload bytes"
249+ raise WebDatasetMemberParseError (msg )
250+ try :
251+ return self ._rows_from_interleaved_json (payload , source , state )
252+ except WebDatasetMemberParseError :
253+ if self .sample_format == "interleaved" :
254+ raise
250255 if not self ._loads_modality ("text" ):
251256 return []
252257 sid , position = self ._next_sample_and_position (state .sample_counters , member_name , "text" )
253258 text_content = self ._decode_text_payload (payload , member_name )
254259 content_type = "application/json" if suffix == ".json" else "text/plain"
255260 if suffix == ".json" :
256261 _record_metadata_row (state , sid , text_content or "{}" )
257- return [
258- self ._text_row (
259- sid = sid ,
260- position = position ,
261- source_shard = source .source_shard ,
262- content_type = content_type ,
263- text_content = text_content ,
264- )
265- ]
262+ return [self ._text_row (sid = sid , position = position , source_shard = source .source_shard , content_type = content_type , text_content = text_content )]
266263
267264 def _rows_from_interleaved_json (
268265 self ,
@@ -275,11 +272,11 @@ def _rows_from_interleaved_json(
275272 sample_payload = dict (decoded )
276273 sample_payload .pop (self .interleaved_field_map ["segments" ], None )
277274 _record_metadata_row (state , sample_id , self ._json_or_none (sample_payload ) or "{}" )
278- rows : list [dict [str , object ]] = []
279275 field_map = self .interleaved_field_map
280276 modality_field = field_map ["modality" ]
281277 text_field = field_map ["text" ]
282278 content_key_field = field_map ["content_key" ]
279+ rows : list [dict [str , object ]] = []
283280 for idx , segment in enumerate (segments ):
284281 modality = _required_segment_str (segment , modality_field )
285282 if modality not in _SUPPORTED_INTERLEAVED_MODALITIES :
@@ -288,53 +285,31 @@ def _rows_from_interleaved_json(
288285 "in WebDatasetReaderStage (supported: text, image)"
289286 )
290287 raise WebDatasetMemberParseError (msg )
291- if self ._loads_modality (modality ):
292- if modality == "text" :
293- rows .append (
294- self ._text_row (
295- sid = sample_id ,
296- position = idx ,
297- source_shard = source .source_shard ,
298- content_type = "text/plain" ,
299- text_content = _required_segment_str (segment , text_field ),
300- element_metadata_json = self ._json_or_none (segment ),
301- )
302- )
303- else :
304- rows .append (
305- self ._image_row (
306- sid = sample_id ,
307- position = idx ,
308- source = source ,
309- content_key = _required_segment_str (segment , content_key_field ),
310- element_metadata_json = self ._json_or_none (segment ),
311- )
288+ if not self ._loads_modality (modality ):
289+ continue
290+ if modality == "text" :
291+ rows .append (
292+ self ._text_row (
293+ sid = sample_id ,
294+ position = idx ,
295+ source_shard = source .source_shard ,
296+ content_type = "text/plain" ,
297+ text_content = _required_segment_str (segment , text_field ),
298+ element_metadata_json = self ._json_or_none (segment ),
312299 )
300+ )
301+ continue
302+ rows .append (
303+ self ._image_row (
304+ sid = sample_id ,
305+ position = idx ,
306+ source = source ,
307+ content_key = _required_segment_str (segment , content_key_field ),
308+ element_metadata_json = self ._json_or_none (segment ),
309+ )
310+ )
313311 return rows
314312
315- def _maybe_rows_from_interleaved_json_member (
316- self ,
317- payload : bytes | None ,
318- source : RowSource ,
319- state : RowBuildState ,
320- member_name : str ,
321- ) -> list [dict [str , object ]] | None :
322- if payload is None :
323- msg = f"JSON member '{ member_name } ' missing payload bytes"
324- raise WebDatasetMemberParseError (msg )
325- try :
326- parsed = self ._rows_from_interleaved_json (payload , source , state )
327- except WebDatasetMemberParseError :
328- if self .sample_format == "interleaved" :
329- raise
330- return None
331- except KeyError as err :
332- if self .sample_format == "interleaved" :
333- msg = f"Interleaved JSON missing required field: { err } "
334- raise WebDatasetMemberParseError (msg ) from err
335- return None
336- return parsed
337-
338313 @staticmethod
339314 def _decode_text_payload (payload : bytes | None , member_name : str ) -> str :
340315 if payload is None :
@@ -368,15 +343,7 @@ def _rows_from_binary_member(
368343 if not self ._loads_modality (modality ):
369344 return []
370345 sid , position = self ._next_sample_and_position (state .sample_counters , member_name , modality )
371- return [
372- self ._image_row (
373- sid = sid ,
374- position = position ,
375- source = source ,
376- content_key = member_name ,
377- binary_content = payload if self .load_binary else None ,
378- )
379- ]
346+ return [self ._image_row (sid = sid , position = position , source = source , content_key = member_name , binary_content = payload if self .load_binary else None )]
380347
381348 def _next_sample_and_position (
382349 self ,
0 commit comments