@@ -42,8 +42,8 @@ def __init__(self, *args, **kwargs):
42
42
super ().__init__ (* args , ** kwargs , encoding = "utf-8" )
43
43
44
44
45
- @ abstractmethod
46
- class ReadableFile :
45
+ class ReadableFile ( ABC ):
46
+ @ abstractmethod
47
47
def as_reader (self , cls : type [IOBase ]) -> AsyncContextManager [IOBase ]:
48
48
"""Return a reader for the file.
49
49
@@ -56,8 +56,8 @@ def as_reader(self, cls: type[IOBase]) -> AsyncContextManager[IOBase]:
56
56
operations in the context manager on exit
57
57
(i.e after the yield statement).
58
58
"""
59
- raise NotImplementedError
60
59
60
+ @abstractmethod
61
61
def path_like (self ) -> Path :
62
62
"""Return a Path object that represents the path to the file.
63
63
@@ -69,7 +69,6 @@ def path_like(self) -> Path:
69
69
This will be called to sniff the file format and compression format of
70
70
the file from the suffixes of the path.
71
71
"""
72
- raise NotImplementedError
73
72
74
73
@asynccontextmanager
75
74
async def popped_suffix_tempfile (
@@ -91,6 +90,9 @@ async def popped_suffix_tempfile(
91
90
with tempfile .NamedTemporaryFile (suffix = "" .join (new .suffixes )) as fp :
92
91
yield Path (fp .name ), fp
93
92
93
def __repr__(self) -> str:
    """Represent the file by the string form of its ``path_like()`` value."""
    location = self.path_like()
    return str(location)
95
+
94
96
95
97
class LocalFile (ReadableFile ):
96
98
"""A class that represents a local file that can be read.
@@ -391,6 +393,15 @@ def read_file_from_handle(
391
393
return [safe_load (reader )]
392
394
393
395
396
class TempFile(LocalFile):
    """A ``LocalFile`` that is displayed using the name of the file it came from.

    The ``repr`` of an instance is its ``original_path`` rather than the
    location of the temporary file backing it.
    """

    def __init__(self, path: Path, pointer=None, original_path: str = ""):
        """Store the temporary file and the label to show for it.

        Args:
            path: Location of the temporary file; forwarded to ``LocalFile``.
            pointer: Optional file object; forwarded to ``LocalFile``.
            original_path: Label describing where the contents came from.
                When empty, the name of ``self.path`` is used instead.
        """
        super().__init__(path, pointer)
        # NOTE(review): self.path is presumably set by LocalFile.__init__ -- confirm.
        if original_path:
            self.original_path = original_path
        else:
            self.original_path = self.path.name

    def __repr__(self):
        """Return ``original_path``, or the inherited repr when it is empty."""
        if self.original_path:
            return self.original_path
        return super().__repr__()
403
+
404
+
394
405
class GzipFileFormat (CompressionCodec , alias = ".gz" ):
395
406
"""Compression codec for Gzip files.
396
407
@@ -405,13 +416,13 @@ class GzipFileFormat(CompressionCodec, alias=".gz"):
405
416
"""
406
417
407
418
@asynccontextmanager
async def decompress_file(self) -> AsyncIterator[ReadableFile]:
    """Decompress the gzip file into a temporary file and yield it.

    The decompressed bytes are streamed into the temporary file supplied by
    ``self.file.popped_suffix_tempfile()`` instead of being read into memory
    in one piece, so peak memory use does not grow with the decompressed
    size.

    Yields:
        A ``TempFile`` wrapping the decompressed data, labelled with
        ``str(self.file)`` so it is still reported under the source's name.
    """
    # Local import so this block does not depend on the module import list.
    import shutil

    async with self.file.popped_suffix_tempfile() as (new_path, temp_file):
        async with self.file.as_reader(BufferedReader) as reader:
            with gzip.GzipFile(fileobj=reader) as decompressor:
                # Copies in bounded chunks until the gzip stream is exhausted.
                shutil.copyfileobj(decompressor, temp_file)
        # Rewind so the consumer reads the decompressed data from the start.
        temp_file.seek(0)
        yield TempFile(new_path, temp_file, str(self.file))
415
426
416
427
417
428
class Bz2FileFormat (CompressionCodec , alias = ".bz2" ):
@@ -428,14 +439,14 @@ class Bz2FileFormat(CompressionCodec, alias=".bz2"):
428
439
"""
429
440
430
441
@asynccontextmanager
async def decompress_file(self) -> AsyncIterator[ReadableFile]:
    """Decompress the bz2 file into a temporary file and yield it.

    The source is read in 1 MiB chunks, each chunk is passed through a
    ``bz2.BZ2Decompressor`` and the output is appended to the temporary file
    supplied by ``self.file.popped_suffix_tempfile()``.

    Yields:
        A ``TempFile`` wrapping the decompressed data, labelled with
        ``str(self.file)``.
    """
    chunk_size = 1024 * 1024
    async with self.file.popped_suffix_tempfile() as (new_path, temp_file):
        async with self.file.as_reader(BufferedReader) as reader:
            inflater = bz2.BZ2Decompressor()
            # Stop on the first empty read, which marks the end of the input.
            while (chunk := reader.read(chunk_size)) != b"":
                temp_file.write(inflater.decompress(chunk))
        # Rewind so the consumer reads the decompressed data from the start.
        temp_file.seek(0)
        yield TempFile(new_path, temp_file, str(self.file))
439
450
440
451
441
452
class LocalFileSource (FileSource , alias = "local" ):
@@ -539,7 +550,10 @@ def archive_if_required(self, key: str):
539
550
540
551
def path_like(self) -> Path:
    """Return the object key as a ``Path``, applying any format override.

    When ``object_format`` is set it replaces the final suffix of the key;
    otherwise the key is returned as a path unchanged.
    """
    key_path = Path(self.key)
    return key_path.with_suffix(self.object_format) if self.object_format else key_path
543
557
544
558
@asynccontextmanager
545
559
async def as_reader (self , reader : IOBase ):
@@ -549,6 +563,9 @@ async def as_reader(self, reader: IOBase):
549
563
yield reader (BytesIO (streaming_body .read ()))
550
564
self .archive_if_required (self .key )
551
565
566
def __repr__(self) -> str:
    """Represent the object as an ``s3://bucket/key`` style location."""
    bucket, key = self.bucket, self.key
    return f"s3://{bucket}/{key}"
568
+
552
569
553
570
class S3FileSource (FileSource , alias = "s3" ):
554
571
"""A class that represents a source of files stored in S3.
@@ -558,41 +575,62 @@ class S3FileSource(FileSource, alias="s3"):
558
575
bucket and yield instances of S3File that can be read by the pipeline.
559
576
560
577
The class also has a method to archive the file after it has been read.
578
+
579
+ The class can also filter the objects returned by the prefix scan in the
580
+ following ways:
581
+ - Specifying object_format OR suffix will filter the objects via endswith
582
+ - Providing strings to object_format AND suffix will:
583
+ - Filter the objects via endswith with suffix (a blank string will match all)
584
+ - Process each object as if it ended with the contents of object_format
561
585
"""
562
586
563
587
@classmethod
def from_file_data(
    cls,
    bucket: str,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    archive_dir: Optional[str] = None,
    object_format: Optional[str] = None,
    **aws_client_args,
):
    """Build an instance from plain settings, creating the S3 client.

    Args:
        bucket: Name of the bucket to read from.
        prefix: Optional key prefix passed through to the constructor.
        suffix: Optional key suffix passed through to the constructor.
        archive_dir: Optional archive location passed through to the
            constructor.
        object_format: Optional format passed through to the constructor.
        **aws_client_args: Forwarded to ``AwsClientFactory``, whose
            ``make_client("s3")`` result becomes ``s3_client``.
    """
    client_factory = AwsClientFactory(**aws_client_args)
    settings = {
        "bucket": bucket,
        "prefix": prefix,
        "suffix": suffix,
        "archive_dir": archive_dir,
        "object_format": object_format,
    }
    return cls(s3_client=client_factory.make_client("s3"), **settings)
579
605
580
606
def __init__(
    self,
    *,
    bucket: str,
    s3_client,
    archive_dir: Optional[str] = None,
    object_format: Optional[str] = None,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
):
    """Store the settings of the source; all arguments are keyword-only.

    Args:
        bucket: Name of the bucket to read from.
        s3_client: Client object used to talk to S3.
        archive_dir: Optional archive location.
        object_format: Optional object format.
        prefix: Optional key prefix; a missing or empty value is stored
            as ``""``.
        suffix: Optional key suffix; stored as given, so ``None`` and ``""``
            remain distinguishable.
    """
    self.s3_client = s3_client
    self.bucket = bucket
    self.prefix = prefix if prefix else ""
    self.suffix = suffix
    self.object_format = object_format
    self.archive_dir = archive_dir
622
+
623
def object_is_not_in_archive(self, key: str) -> bool:
    """Return True unless ``key`` starts with the configured ``archive_dir``.

    When no ``archive_dir`` is configured every key is accepted.
    """
    if not self.archive_dir:
        return True
    return not key.startswith(self.archive_dir)
625
+
626
def key_matches_suffix(self, key: str) -> bool:
    """Return whether ``key`` ends with the configured ending.

    An explicit ``suffix`` (including the empty string, which matches every
    key) takes precedence. Without one, a non-empty ``object_format`` is used
    as the required ending. With neither, every key matches.
    """
    if self.suffix is not None:
        required_ending = self.suffix
    elif self.object_format:
        required_ending = self.object_format
    else:
        return True
    return key.endswith(required_ending)
596
634
597
635
def find_keys_in_bucket (self ) -> Iterable [str ]:
598
636
# Returns all keys in the bucket that are not in the archive dir,
@@ -601,10 +639,14 @@ def find_keys_in_bucket(self) -> Iterable[str]:
601
639
page_iterator = paginator .paginate (Bucket = self .bucket , Prefix = self .prefix )
602
640
for page in page_iterator :
603
641
keys = (obj ["Key" ] for obj in page .get ("Contents" , []))
604
- yield from filter (
605
- lambda k : not self .object_is_in_archive (k )
606
- and k .endswith (self .object_format if self .object_format else "" ),
607
- keys ,
642
+ yield from (
643
+ filter (
644
+ self .key_matches_suffix ,
645
+ filter (
646
+ self .object_is_not_in_archive ,
647
+ keys ,
648
+ ),
649
+ )
608
650
)
609
651
610
652
async def get_files (self ):
@@ -614,13 +656,38 @@ async def get_files(self):
614
656
s3_client = self .s3_client ,
615
657
bucket = self .bucket ,
616
658
archive_dir = self .archive_dir ,
617
- object_format = self .object_format ,
659
+ # for backwards compatibility:
660
+ # -- Only override object_format if suffix is provided.
661
+ # -- To treat ALL files as object_format, set suffix to "".
662
+ object_format = (
663
+ self .object_format
664
+ if self .object_format and self .suffix is not None
665
+ else None
666
+ ),
618
667
)
619
668
620
669
def describe(self) -> str:
    """Return a short description listing the settings that are in effect.

    Unset or empty settings are left out. The one exception is ``suffix``:
    an empty string there is a deliberate setting (it matches every key and
    is distinct from ``None`` elsewhere in this class), so it is reported
    whenever it is not ``None``.
    """
    settings = {
        "bucket": self.bucket,
        "archive_dir": self.archive_dir,
        "object_format": self.object_format,
        "prefix": self.prefix,
        "suffix": self.suffix,
    }
    data = {
        name: value
        for name, value in settings.items()
        # Keep suffix="" visible; dropping it would misreport the filtering.
        if value or (name == "suffix" and value is not None)
    }
    return f"S3FileSource{data}"
682
+
683
def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is an ``S3FileSource`` with the same settings.

    Every setting stored by the constructor takes part in the comparison,
    including ``suffix``: two sources that differ only in suffix select
    different keys and must not compare equal.
    """
    # NOTE(review): defining __eq__ without __hash__ makes instances
    # unhashable; add a matching __hash__ if they are used in sets or as
    # dict keys.
    if not isinstance(other, S3FileSource):
        return False
    return (
        self.s3_client == other.s3_client
        and self.bucket == other.bucket
        and self.prefix == other.prefix
        and self.suffix == other.suffix
        and self.archive_dir == other.archive_dir
        and self.object_format == other.object_format
    )
625
692
626
693
@@ -669,7 +736,8 @@ def __init__(self, file_sources: Sequence[FileSource]) -> None:
669
736
self .logger = getLogger (__name__ )
670
737
671
738
async def read_file (
672
- self , file : ReadableFile
739
+ self ,
740
+ file : ReadableFile ,
673
741
) -> AsyncGenerator [JsonLikeDocument , None ]:
674
742
intermediaries : list [AsyncContextManager [ReadableFile ]] = []
675
743
@@ -688,6 +756,14 @@ async def read_file(
688
756
continue
689
757
except MissingFromRegistryError :
690
758
pass
759
+ except OSError as e :
760
+ self .logger .warning (
761
+ "Failed to decompress %s file. "
762
+ "Please ensure the file is in the correct format." ,
763
+ file ,
764
+ extra = {"exception" : str (e )},
765
+ )
766
+ break
691
767
692
768
# If we didn't find a compression codec, try to find a file format
693
769
# codec that can read the file. If a file format codec is found,
@@ -703,7 +779,9 @@ async def read_file(
703
779
pass
704
780
except Exception as e :
705
781
self .logger .warning (
706
- "Failed to parse %s file. Please ensure the file is in the correct format." ,
782
+ "Failed to parse %s file (at path %s). "
783
+ "Please ensure the file is in the correct format." ,
784
+ file ,
707
785
file .path_like (),
708
786
extra = {"exception" : str (e )},
709
787
)
0 commit comments