Commit e06f7ea

S3 source filtering (#422)
* chore: improve tests
* fix: make s3 handle files as if they were object_format
* chore: lint
* chore: cleanup
* fix: tests on windows
* fix: tests on windows (parquet)
* fix: tests on windows (open twice)
* fix: additional testcase (no extension)
* one more test case
* improve test clarity
* add missing cases
* make changes non-breaking
* more tests for more edge cases
* remove vestigial code
* new edge case around blank vs None suffix
* new test didn't test what it said it tested
* logic error in blank vs None
* test-style unification
* more test cleanup
1 parent 44f2070 commit e06f7ea

File tree

4 files changed: +936 -555 lines changed


nodestream/pipeline/extractors/files.py

Lines changed: 98 additions & 20 deletions
@@ -42,8 +42,8 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, encoding="utf-8")


-@abstractmethod
-class ReadableFile:
+class ReadableFile(ABC):
+    @abstractmethod
     def as_reader(self, cls: type[IOBase]) -> AsyncContextManager[IOBase]:
         """Return a reader for the file.

@@ -56,8 +56,8 @@ def as_reader(self, cls: type[IOBase]) -> AsyncContextManager[IOBase]:
         operations in the context manager on exit
         (i.e after the yield statement).
         """
-        raise NotImplementedError

+    @abstractmethod
     def path_like(self) -> Path:
         """Return a Path object that represents the path to the file.

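Note on the two hunks above: @abstractmethod only takes effect on methods of a class backed by ABCMeta, which is why the decorator moves off the class and onto the methods while ReadableFile now subclasses ABC; the raise NotImplementedError bodies become redundant because an abstract class refuses to instantiate at all. A minimal, self-contained sketch of the pattern (Reader and its method are hypothetical stand-ins, not nodestream API):

from abc import ABC, abstractmethod

class Reader(ABC):                      # hypothetical stand-in for ReadableFile
    @abstractmethod
    def path_like(self):
        """Subclasses must override; no `raise NotImplementedError` needed."""

class GoodReader(Reader):
    def path_like(self):
        return "somewhere"

GoodReader()                            # fine
try:
    Reader()                            # abstract classes cannot be instantiated
except TypeError as e:
    print(e)                            # "Can't instantiate abstract class Reader..."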
@@ -69,7 +69,6 @@ def path_like(self) -> Path:
         This will be called to sniff the file format and compression format of
         the file from the suffixes of the path.
         """
-        raise NotImplementedError

     @asynccontextmanager
     async def popped_suffix_tempfile(
@@ -91,6 +90,9 @@ async def popped_suffix_tempfile(
         with tempfile.NamedTemporaryFile(suffix="".join(new.suffixes)) as fp:
             yield Path(fp.name), fp

+    def __repr__(self) -> str:
+        return str(self.path_like())
+

 class LocalFile(ReadableFile):
     """A class that represents a local file that can be read.
@@ -391,6 +393,15 @@ def read_file_from_handle(
         return [safe_load(reader)]


+class TempFile(LocalFile):
+    def __init__(self, path: Path, pointer=None, original_path: str = ""):
+        super().__init__(path, pointer)
+        self.original_path = original_path or self.path.name
+
+    def __repr__(self):
+        return self.original_path if self.original_path else super().__repr__()
+
+
 class GzipFileFormat(CompressionCodec, alias=".gz"):
     """Compression codec for Gzip files.

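Note: TempFile exists so that log lines mention a file's original location rather than the throwaway temp path it was decompressed into. A simplified sketch of the repr behavior (class name and paths hypothetical):

from pathlib import Path

class NamedTemp:                        # simplified stand-in for TempFile above
    def __init__(self, path: Path, original_path: str = ""):
        self.path = path
        self.original_path = original_path or path.name

    def __repr__(self):
        return self.original_path

# Log messages built with "%s" interpolation of the file object now show the
# source rather than an opaque tempfile name:
print(repr(NamedTemp(Path("/tmp/tmp3f9a.json"), "s3://my-bucket/data/key.json.gz")))
# -> s3://my-bucket/data/key.json.gz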
@@ -405,13 +416,13 @@ class GzipFileFormat(CompressionCodec, alias=".gz"):
     """

     @asynccontextmanager
-    async def decompress_file(self) -> ReadableFile:
+    async def decompress_file(self) -> AsyncIterator[ReadableFile]:
         async with self.file.popped_suffix_tempfile() as (new_path, temp_file):
             async with self.file.as_reader(BufferedReader) as reader:
                 with gzip.GzipFile(fileobj=reader) as decompressor:
                     temp_file.write(decompressor.read())
             temp_file.seek(0)
-            yield LocalFile(new_path, temp_file)
+            yield TempFile(new_path, temp_file, str(self.file))


 class Bz2FileFormat(CompressionCodec, alias=".bz2"):
@@ -428,14 +439,14 @@ class Bz2FileFormat(CompressionCodec, alias=".bz2"):
     """

     @asynccontextmanager
-    async def decompress_file(self) -> ReadableFile:
+    async def decompress_file(self) -> AsyncIterator[ReadableFile]:
         async with self.file.popped_suffix_tempfile() as (new_path, temp_file):
             async with self.file.as_reader(BufferedReader) as reader:
                 decompressor = bz2.BZ2Decompressor()
                 for chunk in iter(lambda: reader.read(1024 * 1024), b""):
                     temp_file.write(decompressor.decompress(chunk))
             temp_file.seek(0)
-            yield LocalFile(new_path, temp_file)
+            yield TempFile(new_path, temp_file, str(self.file))


 class LocalFileSource(FileSource, alias="local"):
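Note: unlike the gzip codec, which inflates the whole stream in one read(), the bz2 codec above decompresses incrementally, so a large object never has to fit in memory at once. The same stdlib pattern works standalone (file names hypothetical):

import bz2

decompressor = bz2.BZ2Decompressor()
with open("input.bz2", "rb") as src, open("output.bin", "wb") as dst:
    # Read 1 MiB at a time until read() returns b"" at EOF.
    for chunk in iter(lambda: src.read(1024 * 1024), b""):
        dst.write(decompressor.decompress(chunk))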
@@ -539,7 +550,10 @@ def archive_if_required(self, key: str):

     def path_like(self) -> Path:
         path = Path(self.key)
-        return path.with_suffix(self.object_format or path.suffix)
+        if self.object_format:
+            return path.with_suffix(self.object_format)
+
+        return path

     @asynccontextmanager
     async def as_reader(self, reader: IOBase):
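Note: the rewrite makes path_like a no-op when no object_format is configured, instead of routing every key through with_suffix. For reference, pathlib's with_suffix only ever replaces the final suffix:

from pathlib import Path

print(Path("data/key").suffix)                        # '' -- no extension
print(Path("data/key").with_suffix(".json"))          # data/key.json
print(Path("data/key.csv").with_suffix(".json"))      # data/key.json
print(Path("data/key.json.gz").with_suffix(".json"))  # data/key.json.json -- only
                                                      # the last suffix is replaced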
@@ -549,6 +563,9 @@ async def as_reader(self, reader: IOBase):
             yield reader(BytesIO(streaming_body.read()))
         self.archive_if_required(self.key)

+    def __repr__(self) -> str:
+        return f"s3://{self.bucket}/{self.key}"
+

 class S3FileSource(FileSource, alias="s3"):
     """A class that represents a source of files stored in S3.
@@ -558,41 +575,62 @@ class S3FileSource(FileSource, alias="s3"):
     bucket and yield instances of S3File that can be read by the pipeline.

     The class also has a method to archive the file after it has been read.
+
+    The class can also filter the objects returned by the prefix scan in the
+    following ways:
+    - Specifying object_format OR suffix will filter the objects via endswith
+    - Providing strings to object_format AND suffix will:
+      - Filter the objects via endswith with suffix (a blank string will match all)
+      - Process each object as if it ended with the contents of object_format
     """

     @classmethod
     def from_file_data(
         cls,
         bucket: str,
         prefix: Optional[str] = None,
+        suffix: Optional[str] = None,
         archive_dir: Optional[str] = None,
         object_format: Optional[str] = None,
         **aws_client_args,
     ):
         return cls(
             bucket=bucket,
             prefix=prefix,
+            suffix=suffix,
             archive_dir=archive_dir,
             object_format=object_format,
             s3_client=AwsClientFactory(**aws_client_args).make_client("s3"),
         )

     def __init__(
         self,
+        *,
         bucket: str,
         s3_client,
         archive_dir: Optional[str] = None,
         object_format: Optional[str] = None,
         prefix: Optional[str] = None,
+        suffix: Optional[str] = None,
     ):
         self.bucket = bucket
         self.s3_client = s3_client
         self.archive_dir = archive_dir
         self.object_format = object_format
         self.prefix = prefix or ""
+        self.suffix = suffix
+
+    def object_is_not_in_archive(self, key: str) -> bool:
+        return not key.startswith(self.archive_dir) if self.archive_dir else True
+
+    def key_matches_suffix(self, key: str) -> bool:
+        if self.suffix is not None:
+            return key.endswith(self.suffix)

-    def object_is_in_archive(self, key: str) -> bool:
-        return key.startswith(self.archive_dir) if self.archive_dir else False
+        if self.object_format:
+            return key.endswith(self.object_format)
+
+        return True

     def find_keys_in_bucket(self) -> Iterable[str]:
         # Returns all keys in the bucket that are not in the archive dir,
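Note: the docstring's filtering rules reduce to a small decision table, including the blank ("") versus None suffix edge case the commit message calls out. A standalone restatement of key_matches_suffix (a sketch, not nodestream API):

def key_matches(key, suffix, object_format):
    # Mirrors S3FileSource.key_matches_suffix: an explicit suffix wins
    # (note "" is not None, so a blank suffix matches every key);
    # otherwise object_format doubles as the filter; otherwise match all.
    if suffix is not None:
        return key.endswith(suffix)
    if object_format:
        return key.endswith(object_format)
    return True

assert key_matches("a/b.json", suffix=None, object_format=".json") is True
assert key_matches("a/b.txt", suffix=None, object_format=".json") is False
assert key_matches("a/b.txt", suffix="", object_format=".json") is True  # blank matches all
assert key_matches("a/b.txt", suffix=None, object_format=None) is True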
@@ -601,10 +639,14 @@ def find_keys_in_bucket(self) -> Iterable[str]:
         page_iterator = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)
         for page in page_iterator:
             keys = (obj["Key"] for obj in page.get("Contents", []))
-            yield from filter(
-                lambda k: not self.object_is_in_archive(k)
-                and k.endswith(self.object_format if self.object_format else ""),
-                keys,
+            yield from (
+                filter(
+                    self.key_matches_suffix,
+                    filter(
+                        self.object_is_not_in_archive,
+                        keys,
+                    ),
+                )
             )

     async def get_files(self):
@@ -614,13 +656,38 @@ async def get_files(self):
             s3_client=self.s3_client,
             bucket=self.bucket,
             archive_dir=self.archive_dir,
-            object_format=self.object_format,
+            # for backwards compatibility:
+            # -- Only override object_format if suffix is provided.
+            # -- To treat ALL files as object_format, set suffix to "".
+            object_format=(
+                self.object_format
+                if self.object_format and self.suffix is not None
+                else None
+            ),
         )

     def describe(self) -> str:
+        data = {
+            k: v
+            for k, v in {
+                "bucket": self.bucket,
+                "archive_dir": self.archive_dir,
+                "object_format": self.object_format,
+                "prefix": self.prefix,
+                "suffix": self.suffix,
+            }.items()
+            if v
+        }
+        return f"S3FileSource{data}"
+
+    def __eq__(self, other: Any) -> bool:
         return (
-            f"S3FileSource{{bucket: {self.bucket}, prefix: {self.prefix}, "
-            f"archive_dir: {self.archive_dir}, object_format: {self.object_format}}}"
+            isinstance(other, S3FileSource)
+            and self.s3_client == other.s3_client
+            and self.bucket == other.bucket
+            and self.prefix == other.prefix
+            and self.archive_dir == other.archive_dir
+            and self.object_format == other.object_format
         )

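Note: the backwards-compatibility comment in the hunk above boils down to one conditional: object_format is only forwarded to each S3File (and so used to reinterpret the file's format) when a suffix was explicitly configured. Restated as a plain function (a sketch, not nodestream API):

def forwarded_object_format(object_format, suffix):
    # Old configs that used object_format purely as a key filter keep their
    # behavior; only an explicit suffix opts into per-file format override.
    return object_format if object_format and suffix is not None else None

assert forwarded_object_format(".json", None) is None    # old configs: filter only
assert forwarded_object_format(".json", "") == ".json"   # treat ALL keys as .json
assert forwarded_object_format(".json", ".gz") == ".json"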
@@ -669,7 +736,8 @@ def __init__(self, file_sources: Sequence[FileSource]) -> None:
         self.logger = getLogger(__name__)

     async def read_file(
-        self, file: ReadableFile
+        self,
+        file: ReadableFile,
     ) -> AsyncGenerator[JsonLikeDocument, None]:
         intermediaries: list[AsyncContextManager[ReadableFile]] = []

@@ -688,6 +756,14 @@ async def read_file(
                 continue
             except MissingFromRegistryError:
                 pass
+            except OSError as e:
+                self.logger.warning(
+                    "Failed to decompress %s file. "
+                    "Please ensure the file is in the correct format.",
+                    file,
+                    extra={"exception": str(e)},
+                )
+                break

         # If we didn't find a compression codec, try to find a file format
         # codec that can read the file. If a file format codec is found,
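Note: OSError is the right net for stdlib decompression failures here: gzip.BadGzipFile (raised for bad magic bytes) has been an OSError subclass since Python 3.8, and bz2 likewise surfaces corrupt streams as OSError. A quick demonstration of the gzip case:

import gzip
import io

try:
    gzip.GzipFile(fileobj=io.BytesIO(b"definitely not gzip")).read()
except OSError as e:          # gzip.BadGzipFile is an OSError subclass
    print(f"decompression failed: {e}")
# -> decompression failed: Not a gzipped file (b'de')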
@@ -703,7 +779,9 @@ async def read_file(
                 pass
             except Exception as e:
                 self.logger.warning(
-                    "Failed to parse %s file. Please ensure the file is in the correct format.",
+                    "Failed to parse %s file (at path %s). "
+                    "Please ensure the file is in the correct format.",
+                    file,
                     file.path_like(),
                     extra={"exception": str(e)},
                 )

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nodestream"
-version = "0.14.9"
+version = "0.14.10"
 description = "A Fast, Declarative ETL for Graph Databases."
 license = "GPL-3.0-only"
 authors = [
