Commit 955a2ee

refactor all the fs
1 parent 1bad384 commit 955a2ee

File tree

1 file changed: +77 -77 lines changed

pyiceberg/io/pyarrow.py

Lines changed: 77 additions & 77 deletions
@@ -83,35 +83,24 @@
 )
 from pyiceberg.expressions.visitors import visit as boolean_expression_visit
 from pyiceberg.io import (
-    ADLS_ACCOUNT_HOST,
     ADLS_ACCOUNT_KEY,
     ADLS_ACCOUNT_NAME,
     ADLS_BLOB_STORAGE_AUTHORITY,
     ADLS_BLOB_STORAGE_SCHEME,
-    ADLS_CLIENT_ID,
-    ADLS_CLIENT_SECRET,
-    ADLS_CONNECTION_STRING,
     ADLS_DFS_STORAGE_AUTHORITY,
     ADLS_DFS_STORAGE_SCHEME,
     ADLS_SAS_TOKEN,
-    ADLS_TENANT_ID,
     AWS_ACCESS_KEY_ID,
     AWS_REGION,
     AWS_ROLE_ARN,
     AWS_ROLE_SESSION_NAME,
     AWS_SECRET_ACCESS_KEY,
     AWS_SESSION_TOKEN,
-    GCS_ACCESS,
-    GCS_CACHE_TIMEOUT,
-    GCS_CONSISTENCY,
     GCS_DEFAULT_LOCATION,
     GCS_PROJECT_ID,
-    GCS_REQUESTER_PAYS,
     GCS_SERVICE_HOST,
-    GCS_SESSION_KWARGS,
     GCS_TOKEN,
     GCS_TOKEN_EXPIRES_AT_MS,
-    GCS_VERSION_AWARE,
     HDFS_HOST,
     HDFS_KERB_TICKET,
     HDFS_PORT,
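The dropped imports above are the per-property constants (ADLS client credentials and connection strings, GCS session tuning, and so on) that the old code mapped by hand. After this refactor, any property under the relevant prefix that is not explicitly consumed appears to flow to the filesystem constructor through a generic pass-through, so these constants are no longer referenced in this module. A minimal standalone sketch of that pattern, using illustrative names rather than the real PyArrowFileIO internals:

# Sketch only: filter properties by prefix, map the well-known keys
# explicitly, and forward any remaining prefixed keys as-is.
def build_fs_kwargs(properties: dict[str, str], prefix: str, known: dict[str, str]) -> dict[str, object]:
    used_keys: set[str] = set()
    kwargs: dict[str, object] = {}
    # Explicitly handled properties, e.g. {"adls.account-name": "account_name"}.
    for prop_key, kwarg_name in known.items():
        if (value := properties.get(prop_key)) is not None:
            kwargs[kwarg_name] = value
            used_keys.add(prop_key)
    # Every other "<prefix>" property is passed straight through, which is why
    # constants such as ADLS_CLIENT_ID no longer need dedicated handling here.
    remaining = {
        k.removeprefix(prefix): v
        for k, v in properties.items()
        if k.startswith(prefix) and k not in used_keys
    }
    # Explicit mappings win over pass-through keys, mirroring the
    # {**remaining_props, **client_kwargs} merges in the diff below.
    return {**remaining, **kwargs}


# Example: an unrecognized "adls.client_id" key is forwarded untouched.
print(build_fs_kwargs(
    {"adls.account-name": "acct", "adls.client_id": "abc"},
    "adls.",
    {"adls.account-name": "account_name"},
))
# -> {'client_id': 'abc', 'account_name': 'acct'}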
@@ -516,9 +505,8 @@ def _initialize_oss_fs(self) -> FileSystem:
 
         properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith(("s3.", "client.", "oss.")))
         used_keys: set[str] = set()
-        client_kwargs = {}
-
         get = lambda *keys: self._get_first_property_value_with_tracking(properties, used_keys, *keys)  # noqa: E731
+        client_kwargs = {}
 
         if endpoint := get(S3_ENDPOINT, "oss.endpoint_override"):
             client_kwargs["endpoint_override"] = endpoint
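The `get` lambda delegates to `self._get_first_property_value_with_tracking`, which is not part of this diff, so its exact behavior can only be inferred. A plausible sketch, stated as an assumption:

from typing import Optional

# Hypothetical reconstruction of the helper behind the `get` lambda: return
# the value of the first key that is set, and record every queried alias so
# the later "remaining properties" pass-through can skip it.
def get_first_property_value_with_tracking(
    properties: dict[str, str], used_keys: set[str], *keys: str
) -> Optional[str]:
    for key in keys:
        used_keys.add(key)  # assumption: aliases are marked used even when absent
        if (value := properties.get(key)) is not None:
            return value
    return None

Under that reading, moving `client_kwargs = {}` below the lambda here (and in `_initialize_s3_fs` below) is a pure reordering: nothing touches `client_kwargs` before the first `if` branch.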
@@ -557,9 +545,8 @@ def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem:
 
         properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith(("s3.", "client.")))
         used_keys: set[str] = set()
-        client_kwargs = {}
-
         get = lambda *keys: self._get_first_property_value_with_tracking(properties, used_keys, *keys)  # noqa: E731
+        client_kwargs = {}
 
         if endpoint := get(S3_ENDPOINT, "s3.endpoint_override"):
             client_kwargs["endpoint_override"] = endpoint
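The S3 path follows the same shape as OSS above. A practical effect of the dual lookup keys (the documented property first, then the raw pyarrow parameter name under the same prefix) is that either spelling selects the endpoint. A usage sketch, assuming S3_ENDPOINT resolves to the documented "s3.endpoint" property:

from pyiceberg.io.pyarrow import PyArrowFileIO

# Both should end up as endpoint_override on the underlying S3 filesystem.
io_documented = PyArrowFileIO(properties={"s3.endpoint": "http://localhost:9000"})
io_passthrough = PyArrowFileIO(properties={"s3.endpoint_override": "http://localhost:9000"})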
@@ -612,26 +599,36 @@ def _initialize_azure_fs(self) -> FileSystem:
 
         from pyarrow.fs import AzureFileSystem
 
-        # Mapping from PyIceberg properties to AzureFileSystem parameter names
-        property_mapping = {
-            ADLS_ACCOUNT_NAME: "account_name",
-            ADLS_ACCOUNT_KEY: "account_key",
-            ADLS_BLOB_STORAGE_AUTHORITY: "blob_storage_authority",
-            ADLS_DFS_STORAGE_AUTHORITY: "dfs_storage_authority",
-            ADLS_BLOB_STORAGE_SCHEME: "blob_storage_scheme",
-            ADLS_DFS_STORAGE_SCHEME: "dfs_storage_scheme",
-            ADLS_SAS_TOKEN: "sas_token",
-            ADLS_CLIENT_ID: "client_id",
-            ADLS_CLIENT_SECRET: "client_secret",
-            ADLS_TENANT_ID: "tenant_id",
-        }
+        properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith("adls."))
+        used_keys: set[str] = set()
+        get = lambda *keys: self._get_first_property_value_with_tracking(properties, used_keys, *keys)  # noqa: E731
+        client_kwargs = {}
 
-        special_properties = {
-            ADLS_CONNECTION_STRING,
-            ADLS_ACCOUNT_HOST,
-        }
+        if account_name := get(ADLS_ACCOUNT_NAME, "adls.account_name"):
+            client_kwargs["account_name"] = account_name
+
+        if account_key := get(ADLS_ACCOUNT_KEY, "adls.account_key"):
+            client_kwargs["account_key"] = account_key
+
+        if blob_storage_authority := get(ADLS_BLOB_STORAGE_AUTHORITY, "adls.blob_storage_authority"):
+            client_kwargs["blob_storage_authority"] = blob_storage_authority
+
+        if dfs_storage_authority := get(ADLS_DFS_STORAGE_AUTHORITY, "adls.dfs_storage_authority"):
+            client_kwargs["dfs_storage_authority"] = dfs_storage_authority
 
-        client_kwargs = self._process_basic_properties(property_mapping, special_properties, "adls")
+        if blob_storage_scheme := get(ADLS_BLOB_STORAGE_SCHEME, "adls.blob_storage_scheme"):
+            client_kwargs["blob_storage_scheme"] = blob_storage_scheme
+
+        if dfs_storage_scheme := get(ADLS_DFS_STORAGE_SCHEME, "adls.dfs_storage_scheme"):
+            client_kwargs["dfs_storage_scheme"] = dfs_storage_scheme
+
+        if sas_token := get(ADLS_SAS_TOKEN, "adls.sas_token"):
+            client_kwargs["sas_token"] = sas_token
+
+        remaining_adls_props = {
+            k.removeprefix("adls."): v for k, v in self.properties.items() if k.startswith("adls.") and k not in used_keys
+        }
+        client_kwargs = {**remaining_adls_props, **client_kwargs}
         return AzureFileSystem(**client_kwargs)
 
     def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem:
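In the Azure initializer, the explicit `if` chain covers only the long-standing AzureFileSystem parameters; anything else under the "adls." prefix now reaches the constructor through `remaining_adls_props`, and the `{**remaining_adls_props, **client_kwargs}` merge means an explicitly mapped property always beats a raw pass-through key of the same name. A hedged usage sketch, assuming the installed pyarrow AzureFileSystem accepts client_id, client_secret, and tenant_id keyword arguments:

from pyiceberg.io.pyarrow import PyArrowFileIO

# Service-principal style configuration: only account_name is explicitly
# mapped above; the other three ride along via the "adls." pass-through
# (assuming this pyarrow version accepts them as keyword arguments).
io = PyArrowFileIO(
    properties={
        "adls.account_name": "myaccount",
        "adls.client_id": "00000000-0000-0000-0000-000000000000",
        "adls.client_secret": "***",
        "adls.tenant_id": "11111111-1111-1111-1111-111111111111",
    }
)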
@@ -640,59 +637,62 @@ def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem:
         if netloc:
             return HadoopFileSystem.from_uri(f"{scheme}://{netloc}")
 
-        # Mapping from PyIceberg properties to HadoopFileSystem parameter names
-        property_mapping = {
-            HDFS_HOST: "host",
-            HDFS_PORT: "port",
-            HDFS_USER: "user",
-            HDFS_KERB_TICKET: "kerb_ticket",
-        }
-
-        hdfs_kwargs = self._process_basic_properties(property_mapping, set(), "hdfs")
-
-        # Handle port conversion to int
-        if "port" in hdfs_kwargs:
-            hdfs_kwargs["port"] = int(hdfs_kwargs["port"])
+        properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith("hdfs."))
+        used_keys: set[str] = set()
+        get = lambda *keys: self._get_first_property_value_with_tracking(properties, used_keys, *keys)  # noqa: E731
+        client_kwargs = {}
 
-        return HadoopFileSystem(**hdfs_kwargs)
+        if host := get(HDFS_HOST):
+            client_kwargs["host"] = host
+        if port := get(HDFS_PORT):
+            # port should be an integer type
+            client_kwargs["port"] = int(port)
+        if user := get(HDFS_USER):
+            client_kwargs["user"] = user
+        if kerb_ticket := get(HDFS_KERB_TICKET, "hdfs.kerb_ticket"):
+            client_kwargs["kerb_ticket"] = kerb_ticket
+
+        remaining_hdfs_props = {
+            k.removeprefix("hdfs."): v for k, v in self.properties.items() if k.startswith("hdfs.") and k not in used_keys
+        }
+        client_kwargs = {**remaining_hdfs_props, **client_kwargs}
+        return HadoopFileSystem(**client_kwargs)
 
     def _initialize_gcs_fs(self) -> FileSystem:
         from pyarrow.fs import GcsFileSystem
 
-        # Mapping from PyIceberg properties to GcsFileSystem parameter names
-        property_mapping = {
-            GCS_TOKEN: "access_token",
-            GCS_DEFAULT_LOCATION: "default_bucket_location",
-            GCS_PROJECT_ID: "project_id",
-        }
-
-        # Properties that need special handling
-        special_properties = {
-            GCS_TOKEN_EXPIRES_AT_MS,
-            GCS_SERVICE_HOST,
-            GCS_ACCESS,
-            GCS_CONSISTENCY,
-            GCS_CACHE_TIMEOUT,
-            GCS_REQUESTER_PAYS,
-            GCS_SESSION_KWARGS,
-            GCS_VERSION_AWARE,
-        }
-
-        gcs_kwargs = self._process_basic_properties(property_mapping, special_properties, "gcs")
-
-        if expiration := self.properties.get(GCS_TOKEN_EXPIRES_AT_MS):
-            gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration))
+        properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith("gcs."))
+        used_keys: set[str] = set()
+        get = lambda *keys: self._get_first_property_value_with_tracking(properties, used_keys, *keys)  # noqa: E731
+        client_kwargs = {}
 
-        if endpoint := self.properties.get(GCS_SERVICE_HOST):
+        if access_token := get(GCS_TOKEN, "gcs.access_token"):
+            client_kwargs["access_token"] = access_token
+        if expiration := get(GCS_TOKEN_EXPIRES_AT_MS, "gcs.credential_token_expiration"):
+            client_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration))
+        if bucket_location := get(GCS_DEFAULT_LOCATION, "gcs.default_bucket_location"):
+            client_kwargs["default_bucket_location"] = bucket_location
+        if endpoint := get(GCS_SERVICE_HOST):
             url_parts = urlparse(endpoint)
-            gcs_kwargs["scheme"] = url_parts.scheme
-            gcs_kwargs["endpoint_override"] = url_parts.netloc
-
-        return GcsFileSystem(**gcs_kwargs)
+            client_kwargs["scheme"] = url_parts.scheme
+            client_kwargs["endpoint_override"] = url_parts.netloc
+        if (scheme := get("gcs.scheme")) and "scheme" not in client_kwargs:
+            client_kwargs["scheme"] = scheme
+        if (endpoint_override := get("gcs.endpoint_override")) and "endpoint_override" not in client_kwargs:
+            client_kwargs["endpoint_override"] = endpoint_override
+
+        if project_id := get(GCS_PROJECT_ID, "gcs.project_id"):
+            client_kwargs["project_id"] = project_id
+
+        remaining_gcs_props = {
+            k.removeprefix("gcs."): v for k, v in self.properties.items() if k.startswith("gcs.") and k not in used_keys
+        }
+        client_kwargs = {**remaining_gcs_props, **client_kwargs}
+        return GcsFileSystem(**client_kwargs)
 
     def _initialize_local_fs(self) -> FileSystem:
-        local_kwargs = self._process_basic_properties({}, set(), "file")
-        return PyArrowLocalFileSystem(**local_kwargs)
+        client_kwargs = {k.removeprefix("file."): v for k, v in self.properties.items() if k.startswith("file.")}
+        return PyArrowLocalFileSystem(**client_kwargs)
 
     def new_input(self, location: str) -> PyArrowFile:
         """Get a PyArrowFile instance to read bytes from the file at the given location.
