Skip to content

Commit e39e2c8

Browse files
committed
refactor
1 parent f965370 commit e39e2c8

File tree

1 file changed

+69
-74
lines changed

1 file changed

+69
-74
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 69 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,6 @@
130130
S3_ROLE_SESSION_NAME,
131131
S3_SECRET_ACCESS_KEY,
132132
S3_SESSION_TOKEN,
133-
S3_SIGNER_ENDPOINT,
134-
S3_SIGNER_URI,
135133
FileIO,
136134
InputFile,
137135
InputStream,
@@ -197,7 +195,11 @@
197195
from pyiceberg.utils.datetime import millis_to_datetime
198196
from pyiceberg.utils.decimal import unscaled_to_decimal
199197
from pyiceberg.utils.deprecated import deprecation_message
200-
from pyiceberg.utils.properties import get_first_property_value, properties_with_prefix, property_as_bool, property_as_int
198+
from pyiceberg.utils.properties import (
199+
filter_properties,
200+
property_as_bool,
201+
property_as_int,
202+
)
201203
from pyiceberg.utils.singleton import Singleton
202204
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
203205

@@ -461,51 +463,76 @@ def _process_basic_properties(
461463

462464
return client_kwargs
463465

466+
def _get_first_property_value_with_tracking(self, props: Properties, used_keys: set[str], *keys: str) -> Optional[Any]:
467+
"""Tracks all candidate keys and returns the first value found."""
468+
used_keys.update(keys)
469+
for key in keys:
470+
if key in props:
471+
return props[key]
472+
return None
473+
464474
def _initialize_oss_fs(self) -> FileSystem:
465475
from pyarrow.fs import S3FileSystem
466476

477+
properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith(("oss.", "client.")))
478+
used_keys: set[str] = set()
467479
client_kwargs = {}
468-
if endpoint := get_first_property_value(self.properties, S3_ENDPOINT, "oss.endpoint_override"):
480+
481+
get = lambda *keys: self._get_first_property_value_with_tracking(properties, used_keys, *keys) # noqa: E731
482+
483+
if endpoint := get(S3_ENDPOINT, "oss.endpoint_override"):
469484
client_kwargs["endpoint_override"] = endpoint
470-
if access_key := get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID, "oss.access_key"):
485+
if access_key := get(S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID, "oss.access_key"):
471486
client_kwargs["access_key"] = access_key
472-
if secret_key := get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, "oss.secret_key"):
487+
if secret_key := get(S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, "oss.secret_key"):
473488
client_kwargs["secret_key"] = secret_key
474-
if session_token := get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN, "oss.session_token"):
489+
if session_token := get(S3_SESSION_TOKEN, AWS_SESSION_TOKEN, "oss.session_token"):
475490
client_kwargs["session_token"] = session_token
476-
if region := get_first_property_value(self.properties, S3_REGION, AWS_REGION, "oss.region"):
491+
if region := get(S3_REGION, AWS_REGION, "oss.region"):
477492
client_kwargs["region"] = region
478-
# Check for force_virtual_addressing in order of preference, defaulting to True if not found
479-
if force_virtual_addressing := get_first_property_value(
480-
self.properties, S3_FORCE_VIRTUAL_ADDRESSING, "oss.force_virtual_addressing"
481-
):
493+
# Check for force_virtual_addressing in order of preference. For oss FS, defaulting to True if not found
494+
if force_virtual_addressing := get(S3_FORCE_VIRTUAL_ADDRESSING, "oss.force_virtual_addressing"):
482495
if isinstance(force_virtual_addressing, str): # S3_FORCE_VIRTUAL_ADDRESSING's value can be a string
483496
force_virtual_addressing = strtobool(force_virtual_addressing)
484497
client_kwargs["force_virtual_addressing"] = force_virtual_addressing
485498
else:
486499
client_kwargs["force_virtual_addressing"] = True
487-
if proxy_uri := get_first_property_value(self.properties, S3_PROXY_URI, "oss.proxy_options"):
500+
if proxy_uri := get(S3_PROXY_URI, "oss.proxy_options"):
488501
client_kwargs["proxy_options"] = proxy_uri
489-
if connect_timeout := get_first_property_value(self.properties, S3_CONNECT_TIMEOUT, "oss.connect_timeout"):
502+
if connect_timeout := get(S3_CONNECT_TIMEOUT, "oss.connect_timeout"):
490503
client_kwargs["connect_timeout"] = float(connect_timeout)
491-
if request_timeout := get_first_property_value(self.properties, S3_REQUEST_TIMEOUT, "oss.request_timeout"):
504+
if request_timeout := get(S3_REQUEST_TIMEOUT, "oss.request_timeout"):
492505
client_kwargs["request_timeout"] = float(request_timeout)
493-
if role_arn := get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN, "oss.role_arn"):
506+
if role_arn := get(S3_ROLE_ARN, AWS_ROLE_ARN, "oss.role_arn"):
494507
client_kwargs["role_arn"] = role_arn
495-
if session_name := get_first_property_value(
496-
self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME, "oss.session_name"
497-
):
508+
if session_name := get(S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME, "oss.session_name"):
498509
client_kwargs["session_name"] = session_name
499510

500-
oss_properties = properties_with_prefix(self.properties, prefix="oss.")
501-
client_kwargs = {**oss_properties, **client_kwargs}
511+
remaining_oss_props = {
512+
k.removeprefix("oss."): v for k, v in self.properties.items() if k.startswith("oss.") and k not in used_keys
513+
}
514+
client_kwargs = {**remaining_oss_props, **client_kwargs}
502515
return S3FileSystem(**client_kwargs)
503516

504517
def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem:
505518
from pyarrow.fs import S3FileSystem
506519

507-
provided_region = get_first_property_value(self.properties, S3_REGION, AWS_REGION)
520+
properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith(("s3.", "client.")))
521+
used_keys: set[str] = set()
522+
client_kwargs = {}
523+
524+
get = lambda *keys: self._get_first_property_value_with_tracking(properties, used_keys, *keys) # noqa: E731
508525

526+
if endpoint := get(S3_ENDPOINT, "s3.endpoint_override"):
527+
client_kwargs["endpoint_override"] = endpoint
528+
if access_key := get(S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID, "s3.access_key"):
529+
client_kwargs["access_key"] = access_key
530+
if secret_key := get(S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, "s3.secret_key"):
531+
client_kwargs["secret_key"] = secret_key
532+
if session_token := get(S3_SESSION_TOKEN, AWS_SESSION_TOKEN, "s3.session_token"):
533+
client_kwargs["session_token"] = session_token
534+
535+
provided_region = get(S3_REGION, AWS_REGION)
509536
# Do this when we don't provide the region at all, or when we explicitly enable it
510537
if provided_region is None or property_as_bool(self.properties, S3_RESOLVE_REGION, False) is True:
511538
# Resolve region from netloc(bucket), fallback to user-provided region
@@ -518,66 +545,34 @@ def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem:
518545
)
519546
else:
520547
bucket_region = provided_region
521-
522-
# Mapping from PyIceberg properties to S3FileSystem parameter names
523-
property_mapping = {
524-
S3_ENDPOINT: "endpoint_override",
525-
S3_PROXY_URI: "proxy_options",
526-
S3_CONNECT_TIMEOUT: "connect_timeout",
527-
S3_REQUEST_TIMEOUT: "request_timeout",
528-
S3_RETRY_STRATEGY_IMPL: "retry_strategy",
529-
}
530-
531-
# Properties that need special handling
532-
special_properties = {
533-
S3_ACCESS_KEY_ID,
534-
S3_SECRET_ACCESS_KEY,
535-
S3_SESSION_TOKEN,
536-
S3_ROLE_ARN,
537-
S3_ROLE_SESSION_NAME,
538-
S3_RESOLVE_REGION,
539-
S3_REGION,
540-
S3_RETRY_STRATEGY_IMPL,
541-
S3_CONNECT_TIMEOUT,
542-
S3_REQUEST_TIMEOUT,
543-
S3_SIGNER_ENDPOINT,
544-
S3_SIGNER_URI,
545-
S3_FORCE_VIRTUAL_ADDRESSING,
546-
}
547-
548-
client_kwargs = self._process_basic_properties(property_mapping, special_properties, "s3")
549548
client_kwargs["region"] = bucket_region
549+
used_keys.add(S3_RESOLVE_REGION)
550550

551-
if S3_ACCESS_KEY_ID in self.properties or AWS_ACCESS_KEY_ID in self.properties:
552-
client_kwargs["access_key"] = get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID)
553-
554-
if S3_SECRET_ACCESS_KEY in self.properties or AWS_SECRET_ACCESS_KEY in self.properties:
555-
client_kwargs["secret_key"] = get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY)
556-
557-
if S3_SESSION_TOKEN in self.properties or AWS_SESSION_TOKEN in self.properties:
558-
client_kwargs["session_token"] = get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN)
559-
560-
if S3_ROLE_ARN in self.properties or AWS_ROLE_ARN in self.properties:
561-
client_kwargs["role_arn"] = get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN)
562-
563-
if S3_ROLE_SESSION_NAME in self.properties or AWS_ROLE_SESSION_NAME in self.properties:
564-
client_kwargs["session_name"] = get_first_property_value(self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME)
551+
if force_virtual_addressing := get(S3_FORCE_VIRTUAL_ADDRESSING, "s3.force_virtual_addressing"):
552+
if isinstance(force_virtual_addressing, str): # S3_FORCE_VIRTUAL_ADDRESSING's value can be a string
553+
force_virtual_addressing = strtobool(force_virtual_addressing)
554+
client_kwargs["force_virtual_addressing"] = force_virtual_addressing
565555

566-
if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT):
556+
if proxy_uri := get(S3_PROXY_URI, "s3.proxy_options"):
557+
client_kwargs["proxy_options"] = proxy_uri
558+
if connect_timeout := get(S3_CONNECT_TIMEOUT, "s3.connect_timeout"):
567559
client_kwargs["connect_timeout"] = float(connect_timeout)
568-
569-
if request_timeout := self.properties.get(S3_REQUEST_TIMEOUT):
560+
if request_timeout := get(S3_REQUEST_TIMEOUT, "s3.request_timeout"):
570561
client_kwargs["request_timeout"] = float(request_timeout)
571-
572-
if self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING) is not None:
573-
client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, False)
562+
if role_arn := get(S3_ROLE_ARN, AWS_ROLE_ARN, "s3.role_arn"):
563+
client_kwargs["role_arn"] = role_arn
564+
if session_name := get(S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME, "s3.session_name"):
565+
client_kwargs["session_name"] = session_name
574566

575567
# Handle retry strategy special case
576-
if (retry_strategy_impl := self.properties.get(S3_RETRY_STRATEGY_IMPL)) and (
577-
retry_instance := _import_retry_strategy(retry_strategy_impl)
578-
):
579-
client_kwargs["retry_strategy"] = retry_instance
568+
if retry_strategy_impl := get(S3_RETRY_STRATEGY_IMPL, "s3.retry_strategy"):
569+
if retry_instance := _import_retry_strategy(retry_strategy_impl):
570+
client_kwargs["retry_strategy"] = retry_instance
580571

572+
remaining_s3_props = {
573+
k.removeprefix("s3."): v for k, v in self.properties.items() if k.startswith("s3.") and k not in used_keys
574+
}
575+
client_kwargs = {**remaining_s3_props, **client_kwargs}
581576
return S3FileSystem(**client_kwargs)
582577

583578
def _initialize_azure_fs(self) -> FileSystem:

0 commit comments

Comments
 (0)