diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md index f4fbe0c8d8..d3a22bbbff 100644 --- a/mkdocs/docs/configuration.md +++ b/mkdocs/docs/configuration.md @@ -213,13 +213,47 @@ PyIceberg uses [S3FileSystem](https://arrow.apache.org/docs/python/generated/pya ### PyArrow - +#### PyArrow Specific Properties | Key | Example | Description | | ------------------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | pyarrow.use-large-types-on-read | True | Use large PyArrow types i.e. [large_string](https://arrow.apache.org/docs/python/generated/pyarrow.large_string.html), [large_binary](https://arrow.apache.org/docs/python/generated/pyarrow.large_binary.html) and [large_list](https://arrow.apache.org/docs/python/generated/pyarrow.large_list.html) field types on table scans. The default value is True. | - +#### Advanced FileSystem Configuration + +When using `PyArrowFileIO`, you can **pass additional configuration properties directly to the underlying PyArrow filesystem implementations**. This feature enables you to use any PyArrow filesystem option without requiring explicit PyIceberg support. + +PyIceberg first processes its own supported properties for each filesystem, then passes any remaining properties with the appropriate prefix directly to the PyArrow filesystem constructor. This approach ensures: + +1. PyIceberg's built-in properties take precedence +2. Advanced PyArrow options are automatically supported +3. New PyArrow features become available immediately + +##### Configuration Format + +Use this format for additional properties: + +```txt +{fs_scheme}.{parameter_name}={value} +``` + +Where: + +- `{fs_scheme}` is the filesystem scheme (e.g., `s3`, `hdfs`, `gcs`, `adls`, `file`) +- `{parameter_name}` must match the exact parameter name expected by the PyArrow filesystem constructor +- `{value}` must be the correct type expected by the underlying filesystem (string, integer, boolean, etc.) 
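+
+##### Example
+
+As an illustrative sketch (the credential values and the `load_frequency` setting below are placeholders), PyIceberg's own properties and raw PyArrow pass-through options can be mixed in the same property map:
+
+```python
+from pyiceberg.io.pyarrow import PyArrowFileIO
+
+file_io = PyArrowFileIO(
+    properties={
+        "s3.access-key-id": "admin",         # PyIceberg property, consumed first
+        "s3.secret-access-key": "password",  # PyIceberg property, consumed first
+        "s3.region": "us-east-1",            # PyIceberg property, consumed first
+        "s3.load_frequency": 900,            # forwarded as load_frequency=900 to S3FileSystem
+    }
+)
+```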
+ +##### Supported Prefixes and FileSystems + +| Property Prefix | FileSystem | Example | Description | +|-----------------|------------------------------------------------------------------------------------------------------|-----------------------------|-----------------------------------------------------| +| `s3.` | [S3FileSystem](https://arrow.apache.org/docs/python/generated/pyarrow.fs.S3FileSystem.html) | `s3.load_frequency=900` | Passed as `load_frequency=900` to S3FileSystem | +| `hdfs.` | [HadoopFileSystem](https://arrow.apache.org/docs/python/generated/pyarrow.fs.HadoopFileSystem.html) | `hdfs.replication=3` | Passed as `replication=3` to HadoopFileSystem | +| `gcs.` | [GcsFileSystem](https://arrow.apache.org/docs/python/generated/pyarrow.fs.GcsFileSystem.html) | `gcs.project_id=test` | Passed as `project_id='test'` to GcsFileSystem | +| `adls.` | [AzureFileSystem](https://arrow.apache.org/docs/python/generated/pyarrow.fs.AzureFileSystem.html) | `adls.account_name=foo` | Passed as `account_name=foo` to AzureFileSystem | +| `file.` | [LocalFileSystem](https://arrow.apache.org/docs/python/generated/pyarrow.fs.LocalFileSystem.html) | `file.use_mmap=true` | Passed as `use_mmap=True` to LocalFileSystem | + +**Note:** Refer to the PyArrow documentation for each filesystem to understand the available parameters and their expected types. Property values are passed directly to PyArrow, so they must match the exact parameter names and types expected by the filesystem constructors. ## Location Providers diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 2797371028..b3cb02a0db 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -97,6 +97,7 @@ AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, GCS_DEFAULT_LOCATION, + GCS_PROJECT_ID, GCS_SERVICE_HOST, GCS_TOKEN, GCS_TOKEN_EXPIRES_AT_MS, @@ -182,7 +183,14 @@ from pyiceberg.utils.datetime import millis_to_datetime from pyiceberg.utils.decimal import unscaled_to_decimal from pyiceberg.utils.deprecated import deprecation_message -from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int +from pyiceberg.utils.properties import ( + convert_str_to_bool, + filter_properties, + get_first_property_value_with_tracking, + properties_with_prefix, + property_as_bool, + property_as_int, +) from pyiceberg.utils.singleton import Singleton from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string @@ -420,84 +428,132 @@ def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSyste else: raise ValueError(f"Unrecognized filesystem type in URI: {scheme}") + def _resolve_s3_region( + self, provided_region: Optional[str], resolve_region_override: Any, bucket: Optional[str] + ) -> Optional[str]: + """ + Resolve S3 region based on configuration and optional bucket-based resolution. 
+ + Args: + provided_region: Region explicitly provided in configuration + resolve_region_override: Whether to resolve region from bucket (can be string or bool) + bucket: Bucket name for region resolution + + Returns: + The resolved region string, or None if no region could be determined + """ + # Handle resolve_region_override conversion + should_resolve_region = False + if resolve_region_override is not None: + should_resolve_region = convert_str_to_bool(resolve_region_override) + + # If no region provided or explicit resolve requested, try to resolve from bucket + if provided_region is None or should_resolve_region: + resolved_region = _cached_resolve_s3_region(bucket=bucket) + + # Warn if resolved region differs from provided region + if provided_region is not None and resolved_region and resolved_region != provided_region: + logger.warning( + f"PyArrow FileIO overriding S3 bucket region for bucket {bucket}: " + f"provided region {provided_region}, actual region {resolved_region}" + ) + + return resolved_region or provided_region + + return provided_region + def _initialize_oss_fs(self) -> FileSystem: from pyarrow.fs import S3FileSystem - client_kwargs: Dict[str, Any] = { - "endpoint_override": self.properties.get(S3_ENDPOINT), - "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), - "region": get_first_property_value(self.properties, S3_REGION, AWS_REGION), - "force_virtual_addressing": property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, True), - } - - if proxy_uri := self.properties.get(S3_PROXY_URI): + properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith(("s3.", "client."))) + used_keys: set[str] = set() + + def get_property_with_tracking(*keys: str) -> str | None: + return get_first_property_value_with_tracking(properties, used_keys, *keys) + + client_kwargs: Properties = {} + + if endpoint := get_property_with_tracking(S3_ENDPOINT, "s3.endpoint_override"): + client_kwargs["endpoint_override"] = endpoint + if access_key := get_property_with_tracking(S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID, "s3.access_key"): + client_kwargs["access_key"] = access_key + if secret_key := get_property_with_tracking(S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, "s3.secret_key"): + client_kwargs["secret_key"] = secret_key + if session_token := get_property_with_tracking(S3_SESSION_TOKEN, AWS_SESSION_TOKEN, "s3.session_token"): + client_kwargs["session_token"] = session_token + if region := get_property_with_tracking(S3_REGION, AWS_REGION): + client_kwargs["region"] = region + _ = get_property_with_tracking( + S3_RESOLVE_REGION + ) # this feature is only available for S3. 
Use `get` here so it does not get passed down to the S3FileSystem constructor + if force_virtual_addressing := get_property_with_tracking(S3_FORCE_VIRTUAL_ADDRESSING, "s3.force_virtual_addressing"): + client_kwargs["force_virtual_addressing"] = convert_str_to_bool(force_virtual_addressing) + else: + # For Alibaba OSS protocol, default to True + client_kwargs["force_virtual_addressing"] = True + if proxy_uri := get_property_with_tracking(S3_PROXY_URI, "s3.proxy_options"): client_kwargs["proxy_options"] = proxy_uri - - if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT): + if connect_timeout := get_property_with_tracking(S3_CONNECT_TIMEOUT, "s3.connect_timeout"): client_kwargs["connect_timeout"] = float(connect_timeout) - - if request_timeout := self.properties.get(S3_REQUEST_TIMEOUT): + if request_timeout := get_property_with_tracking(S3_REQUEST_TIMEOUT, "s3.request_timeout"): client_kwargs["request_timeout"] = float(request_timeout) - - if role_arn := get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN): + if role_arn := get_property_with_tracking(S3_ROLE_ARN, AWS_ROLE_ARN, "s3.role_arn"): client_kwargs["role_arn"] = role_arn - - if session_name := get_first_property_value(self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME): + if session_name := get_property_with_tracking(S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME, "s3.session_name"): client_kwargs["session_name"] = session_name + # get the rest of the properties with the `s3.` prefix that are not already evaluated + remaining_s3_props = properties_with_prefix({k: v for k, v in self.properties.items() if k not in used_keys}, "s3.") + client_kwargs = {**remaining_s3_props, **client_kwargs} return S3FileSystem(**client_kwargs) def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem: from pyarrow.fs import S3FileSystem - provided_region = get_first_property_value(self.properties, S3_REGION, AWS_REGION) + properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith(("s3.", "client."))) + used_keys: set[str] = set() - # Do this when we don't provide the region at all, or when we explicitly enable it - if provided_region is None or property_as_bool(self.properties, S3_RESOLVE_REGION, False) is True: - # Resolve region from netloc(bucket), fallback to user-provided region - # Only supported by buckets hosted by S3 - bucket_region = _cached_resolve_s3_region(bucket=netloc) or provided_region - if provided_region is not None and bucket_region != provided_region: - logger.warning( - f"PyArrow FileIO overriding S3 bucket region for bucket {netloc}: " - f"provided region {provided_region}, actual region {bucket_region}" - ) - else: - bucket_region = provided_region - - client_kwargs: Dict[str, Any] = { - "endpoint_override": self.properties.get(S3_ENDPOINT), - "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), - "region": bucket_region, - } + def get_property_with_tracking(*keys: str) -> str | None: + return get_first_property_value_with_tracking(properties, used_keys, *keys) - if proxy_uri := self.properties.get(S3_PROXY_URI): - client_kwargs["proxy_options"] = proxy_uri + client_kwargs: Properties = {} - if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT): - client_kwargs["connect_timeout"] = float(connect_timeout) + # Handle S3 
region configuration with optional auto-resolution + client_kwargs["region"] = self._resolve_s3_region( + provided_region=get_property_with_tracking(S3_REGION, AWS_REGION), + resolve_region_override=get_property_with_tracking(S3_RESOLVE_REGION), + bucket=netloc, + ) - if request_timeout := self.properties.get(S3_REQUEST_TIMEOUT): + if endpoint := get_property_with_tracking(S3_ENDPOINT, "s3.endpoint_override"): + client_kwargs["endpoint_override"] = endpoint + if access_key := get_property_with_tracking(S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID, "s3.access_key"): + client_kwargs["access_key"] = access_key + if secret_key := get_property_with_tracking(S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, "s3.secret_key"): + client_kwargs["secret_key"] = secret_key + if session_token := get_property_with_tracking(S3_SESSION_TOKEN, AWS_SESSION_TOKEN, "s3.session_token"): + client_kwargs["session_token"] = session_token + if proxy_uri := get_property_with_tracking(S3_PROXY_URI, "s3.proxy_options"): + client_kwargs["proxy_options"] = proxy_uri + if connect_timeout := get_property_with_tracking(S3_CONNECT_TIMEOUT, "s3.connect_timeout"): + client_kwargs["connect_timeout"] = float(connect_timeout) + if request_timeout := get_property_with_tracking(S3_REQUEST_TIMEOUT, "s3.request_timeout"): client_kwargs["request_timeout"] = float(request_timeout) - - if role_arn := get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN): + if role_arn := get_property_with_tracking(S3_ROLE_ARN, AWS_ROLE_ARN, "s3.role_arn"): client_kwargs["role_arn"] = role_arn - - if session_name := get_first_property_value(self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME): + if session_name := get_property_with_tracking(S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME, "s3.session_name"): client_kwargs["session_name"] = session_name - if self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING) is not None: - client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, False) - - if (retry_strategy_impl := self.properties.get(S3_RETRY_STRATEGY_IMPL)) and ( - retry_instance := _import_retry_strategy(retry_strategy_impl) - ): - client_kwargs["retry_strategy"] = retry_instance + if force_virtual_addressing := get_property_with_tracking(S3_FORCE_VIRTUAL_ADDRESSING, "s3.force_virtual_addressing"): + client_kwargs["force_virtual_addressing"] = convert_str_to_bool(force_virtual_addressing) + # Handle retry strategy special case + if retry_strategy_impl := get_property_with_tracking(S3_RETRY_STRATEGY_IMPL, "s3.retry_strategy"): + if retry_instance := _import_retry_strategy(retry_strategy_impl): + client_kwargs["retry_strategy"] = retry_instance + # get the rest of the properties with the `s3.` prefix that are not already evaluated + remaining_s3_props = properties_with_prefix({k: v for k, v in self.properties.items() if k not in used_keys}, "s3.") + client_kwargs = {**remaining_s3_props, **client_kwargs} return S3FileSystem(**client_kwargs) def _initialize_azure_fs(self) -> FileSystem: @@ -512,68 +568,110 @@ def _initialize_azure_fs(self) -> FileSystem: from pyarrow.fs import AzureFileSystem - client_kwargs: Dict[str, str] = {} + properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith("adls.")) + used_keys: set[str] = set() + + def get_property_with_tracking(*keys: str) -> str | None: + return get_first_property_value_with_tracking(properties, used_keys, *keys) - if account_name := self.properties.get(ADLS_ACCOUNT_NAME): + client_kwargs: Properties = {} + + if 
account_name := get_property_with_tracking(ADLS_ACCOUNT_NAME, "adls.account_name"): client_kwargs["account_name"] = account_name - if account_key := self.properties.get(ADLS_ACCOUNT_KEY): + if account_key := get_property_with_tracking(ADLS_ACCOUNT_KEY, "adls.account_key"): client_kwargs["account_key"] = account_key - if blob_storage_authority := self.properties.get(ADLS_BLOB_STORAGE_AUTHORITY): + if blob_storage_authority := get_property_with_tracking(ADLS_BLOB_STORAGE_AUTHORITY, "adls.blob_storage_authority"): client_kwargs["blob_storage_authority"] = blob_storage_authority - if dfs_storage_authority := self.properties.get(ADLS_DFS_STORAGE_AUTHORITY): + if dfs_storage_authority := get_property_with_tracking(ADLS_DFS_STORAGE_AUTHORITY, "adls.dfs_storage_authority"): client_kwargs["dfs_storage_authority"] = dfs_storage_authority - if blob_storage_scheme := self.properties.get(ADLS_BLOB_STORAGE_SCHEME): + if blob_storage_scheme := get_property_with_tracking(ADLS_BLOB_STORAGE_SCHEME, "adls.blob_storage_scheme"): client_kwargs["blob_storage_scheme"] = blob_storage_scheme - if dfs_storage_scheme := self.properties.get(ADLS_DFS_STORAGE_SCHEME): + if dfs_storage_scheme := get_property_with_tracking(ADLS_DFS_STORAGE_SCHEME, "adls.dfs_storage_scheme"): client_kwargs["dfs_storage_scheme"] = dfs_storage_scheme - if sas_token := self.properties.get(ADLS_SAS_TOKEN): + if sas_token := get_property_with_tracking(ADLS_SAS_TOKEN, "adls.sas_token"): client_kwargs["sas_token"] = sas_token + # get the rest of the properties with the `adls.` prefix that are not already evaluated + remaining_adls_props = properties_with_prefix({k: v for k, v in self.properties.items() if k not in used_keys}, "adls.") + client_kwargs = {**remaining_adls_props, **client_kwargs} return AzureFileSystem(**client_kwargs) def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem: from pyarrow.fs import HadoopFileSystem - hdfs_kwargs: Dict[str, Any] = {} if netloc: return HadoopFileSystem.from_uri(f"{scheme}://{netloc}") - if host := self.properties.get(HDFS_HOST): - hdfs_kwargs["host"] = host - if port := self.properties.get(HDFS_PORT): + + properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith("hdfs.")) + used_keys: set[str] = set() + + def get_property_with_tracking(*keys: str) -> str | None: + return get_first_property_value_with_tracking(properties, used_keys, *keys) + + client_kwargs: Properties = {} + + if host := get_property_with_tracking(HDFS_HOST): + client_kwargs["host"] = host + if port := get_property_with_tracking(HDFS_PORT): # port should be an integer type - hdfs_kwargs["port"] = int(port) - if user := self.properties.get(HDFS_USER): - hdfs_kwargs["user"] = user - if kerb_ticket := self.properties.get(HDFS_KERB_TICKET): - hdfs_kwargs["kerb_ticket"] = kerb_ticket + client_kwargs["port"] = int(port) + if user := get_property_with_tracking(HDFS_USER): + client_kwargs["user"] = user + if kerb_ticket := get_property_with_tracking(HDFS_KERB_TICKET, "hdfs.kerb_ticket"): + client_kwargs["kerb_ticket"] = kerb_ticket - return HadoopFileSystem(**hdfs_kwargs) + # get the rest of the properties with the `hdfs.` prefix that are not already evaluated + remaining_hdfs_props = properties_with_prefix({k: v for k, v in self.properties.items() if k not in used_keys}, "hdfs.") + client_kwargs = {**remaining_hdfs_props, **client_kwargs} + return HadoopFileSystem(**client_kwargs) def _initialize_gcs_fs(self) -> FileSystem: from pyarrow.fs import GcsFileSystem - gcs_kwargs: Dict[str, Any] = 
{} - if access_token := self.properties.get(GCS_TOKEN): - gcs_kwargs["access_token"] = access_token - if expiration := self.properties.get(GCS_TOKEN_EXPIRES_AT_MS): - gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration)) - if bucket_location := self.properties.get(GCS_DEFAULT_LOCATION): - gcs_kwargs["default_bucket_location"] = bucket_location - if endpoint := self.properties.get(GCS_SERVICE_HOST): + properties = filter_properties(self.properties, key_predicate=lambda k: k.startswith("gcs.")) + used_keys: set[str] = set() + + def get_property_with_tracking(*keys: str) -> str | None: + return get_first_property_value_with_tracking(properties, used_keys, *keys) + + client_kwargs: Properties = {} + + if access_token := get_property_with_tracking(GCS_TOKEN, "gcs.access_token"): + client_kwargs["access_token"] = access_token + if expiration := get_property_with_tracking(GCS_TOKEN_EXPIRES_AT_MS, "gcs.credential_token_expiration"): + client_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration)) + if bucket_location := get_property_with_tracking(GCS_DEFAULT_LOCATION, "gcs.default_bucket_location"): + client_kwargs["default_bucket_location"] = bucket_location + if endpoint := get_property_with_tracking(GCS_SERVICE_HOST): url_parts = urlparse(endpoint) - gcs_kwargs["scheme"] = url_parts.scheme - gcs_kwargs["endpoint_override"] = url_parts.netloc + client_kwargs["scheme"] = url_parts.scheme + client_kwargs["endpoint_override"] = url_parts.netloc + if ( + scheme := get_property_with_tracking("gcs.scheme") + ) and "scheme" not in client_kwargs: # GCS_SERVICE_HOST takes precedence + client_kwargs["scheme"] = scheme + if ( + endpoint_override := get_property_with_tracking("gcs.endpoint_override") + ) and "endpoint_override" not in client_kwargs: # GCS_SERVICE_HOST takes precedence + client_kwargs["endpoint_override"] = endpoint_override + + if project_id := get_property_with_tracking(GCS_PROJECT_ID, "gcs.project_id"): + client_kwargs["project_id"] = project_id - return GcsFileSystem(**gcs_kwargs) + # get the rest of the properties with the `gcs.` prefix that are not already evaluated + remaining_gcs_props = properties_with_prefix({k: v for k, v in self.properties.items() if k not in used_keys}, "gcs.") + client_kwargs = {**remaining_gcs_props, **client_kwargs} + return GcsFileSystem(**client_kwargs) def _initialize_local_fs(self) -> FileSystem: - return PyArrowLocalFileSystem() + client_kwargs = properties_with_prefix(self.properties, "file.") + return PyArrowLocalFileSystem(**client_kwargs) def new_input(self, location: str) -> PyArrowFile: """Get a PyArrowFile instance to read bytes from the file at the given location. 
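The precedence behaviour shared by all of the rewritten `_initialize_*_fs` methods comes down to one dictionary merge: explicitly handled PyIceberg properties are written into `client_kwargs`, and only the prefixed keys that were not consumed along the way are layered underneath. A minimal standalone sketch of that rule (property names and values are illustrative, not taken from a real configuration):

```python
# Standalone sketch of the pass-through/precedence rule used above.
properties = {
    "s3.access-key-id": "explicit-access-key",  # PyIceberg-specific key
    "s3.access_key": "passed-access-key",       # raw PyArrow pass-through key
    "s3.load_frequency": 900,                   # raw PyArrow pass-through key
}

# Keys consumed while building the explicit kwargs are tracked ...
used_keys = {"s3.access-key-id", "s3.access_key"}
client_kwargs = {"access_key": "explicit-access-key"}

# ... so only the remaining "s3."-prefixed keys are stripped and merged in,
# and the explicit kwargs win on any collision because they come last.
remaining = {
    key[len("s3."):]: value
    for key, value in properties.items()
    if key.startswith("s3.") and key not in used_keys
}
client_kwargs = {**remaining, **client_kwargs}

assert client_kwargs == {"load_frequency": 900, "access_key": "explicit-access-key"}
```

This is also why `test_iceberg_cred_properties_take_precedence` below expects only the explicit credentials to reach `S3FileSystem`.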
diff --git a/pyiceberg/utils/properties.py b/pyiceberg/utils/properties.py index 2b228f6e41..8809d78d8e 100644 --- a/pyiceberg/utils/properties.py +++ b/pyiceberg/utils/properties.py @@ -17,6 +17,7 @@ from typing import ( Any, + Callable, Dict, Optional, ) @@ -68,6 +69,13 @@ def property_as_bool( return default +def convert_str_to_bool(value: Any) -> bool: + """Convert string or other value to boolean, handling string representations properly.""" + if isinstance(value, str): + return strtobool(value) + return bool(value) + + def get_first_property_value( properties: Properties, *property_names: str, @@ -78,8 +86,60 @@ def get_first_property_value( return None +def get_first_property_value_with_tracking(props: Properties, used_keys: set[str], *keys: str) -> Optional[Any]: + """Tracks all candidate keys and returns the first value found.""" + used_keys.update(keys) + for key in keys: + if key in props: + return props[key] + return None + + def get_header_properties( properties: Properties, ) -> Properties: header_prefix_len = len(HEADER_PREFIX) return {key[header_prefix_len:]: value for key, value in properties.items() if key.startswith(HEADER_PREFIX)} + + +def properties_with_prefix( + properties: Properties, + prefix: str, +) -> Properties: + """ + Return subset of provided map with keys matching the provided prefix. Matching is case-sensitive and the matching prefix is removed from the keys in returned map. + + Args: + properties: input map + prefix: prefix to choose keys from input map + + Returns: + subset of input map with keys starting with provided prefix and prefix trimmed out + """ + if not properties: + return {} + + return {key[len(prefix) :]: value for key, value in properties.items() if key.startswith(prefix)} + + +def filter_properties( + properties: Properties, + key_predicate: Callable[[str], bool], +) -> Properties: + """ + Filter the properties map by the provided key predicate. 
+ + Args: + properties: input map + key_predicate: predicate to choose keys from input map + + Returns: + subset of input map with keys satisfying the predicate + """ + if not properties: + return {} + + if key_predicate is None: + raise ValueError("Invalid key predicate: None") + + return {key: value for key, value in properties.items() if key_predicate(key)} diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 4f121ba3bc..236e5b3b16 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -412,6 +412,144 @@ def test_pyarrow_unified_session_properties() -> None: ) +def test_s3_pyarrow_specific_properties() -> None: + pyarrow_file_io = PyArrowFileIO( + { + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "user", + "s3.secret-access-key": "pass", + "s3.load_frequency": 900, + "s3.region": "us-east-1", + } + ) + + # Test that valid PyArrow properties work without error + with patch("pyarrow.fs.S3FileSystem") as mock_s3fs: + pyarrow_file_io._initialize_s3_fs(None) + + # Verify that properties are passed through correctly + mock_s3fs.assert_called_with( + endpoint_override="http://localhost:9000", + access_key="user", + secret_key="pass", + load_frequency=900, + region="us-east-1", + ) + + # Test that invalid PyArrow properties raise TypeError + with pytest.raises(TypeError) as exc_info: + pyarrow_file_io = PyArrowFileIO( + { + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + "s3.unknown_property": "val", + } + ) + pyarrow_file_io._initialize_s3_fs(None) + + assert "got an unexpected keyword argument 'unknown_property'" in str(exc_info.value) + + +def test_iceberg_cred_properties_take_precedence() -> None: + session_properties: Properties = { + "s3.access-key-id": "explicit-access-key", + "s3.secret-access-key": "explicit-secret-key", + "s3.region": "us-east-1", + # These should be ignored because explicit properties take precedence + "s3.access_key": "passed-access-key", + "s3.secret_key": "passed-secret-key", + } + + with patch("pyarrow.fs.S3FileSystem") as mock_s3fs: + s3_fileio = PyArrowFileIO(properties=session_properties) + + s3_fileio._initialize_s3_fs(None) + + # Assert that explicit properties are used from above + mock_s3fs.assert_called_with( + access_key="explicit-access-key", + secret_key="explicit-secret-key", + region="us-east-1", + ) + + +def test_hdfs_pyarrow_specific_properties() -> None: + hdfs_properties: Properties = { + "hdfs.host": "localhost", + "hdfs.port": "9000", + "hdfs.user": "user", + "hdfs.kerberos_ticket": "test", + # non iceberg properties + "hdfs.replication": 3, + "hdfs.block_size": 134217728, + } + + with patch("pyarrow.fs.HadoopFileSystem") as mock_hdfs: + hdfs_fileio = PyArrowFileIO(properties=hdfs_properties) + hdfs_fileio._initialize_hdfs_fs("hdfs", None) + + mock_hdfs.assert_called_with( + host="localhost", + port=9000, + user="user", + kerb_ticket="test", + replication=3, + block_size=134217728, + ) + + +def test_local_filesystem_pyarrow_specific_properties() -> None: + local_properties: Properties = {"file.buffer_size": 8192, "file.use_mmap": True} + + with patch("pyiceberg.io.pyarrow.PyArrowLocalFileSystem") as mock_local: + local_fileio = PyArrowFileIO(properties=local_properties) + local_fileio._initialize_local_fs() + + mock_local.assert_called_with( + buffer_size=8192, + use_mmap=True, + ) + + +def test_gcs_pyarrow_specific_properties() -> None: + pyarrow_file_io = PyArrowFileIO( + { + "gcs.project-id": "project", + "gcs.oauth2.token": 
"test", + "gcs.default-bucket-location": "loc", + } + ) + + with patch("pyarrow.fs.GcsFileSystem") as mock_gcs: + pyarrow_file_io._initialize_gcs_fs() + + mock_gcs.assert_called_with( + project_id="project", + access_token="test", + default_bucket_location="loc", + ) + + +@skip_if_pyarrow_too_old +def test_pyarrow_adls_pyarrow_specific_properties() -> None: + pyarrow_file_io = PyArrowFileIO( + {"adls.account-name": "user", "adls.account-key": "pass", "adls.sas-token": "testsas", "adls.client_id": "client"} + ) + + # Test that valid PyArrow properties work without error + with patch("pyarrow.fs.AzureFileSystem") as mock_azure: + pyarrow_file_io._initialize_azure_fs() + + # Verify that properties are passed through correctly + mock_azure.assert_called_with( + account_name="user", + account_key="pass", + sas_token="testsas", + client_id="client", + ) + + def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None: actual = schema_to_pyarrow(table_schema_nested) expected = """foo: large_string @@ -2609,6 +2747,17 @@ def _s3_region_map(bucket: str) -> str: assert pyarrow_file_io.new_input(f"oss://{bucket_region[0]}/path/to/file")._filesystem.region == user_provided_region +def test_pyarrow_filesystem_properties() -> None: + pyarrow_file_io = PyArrowFileIO({"s3.load_frequency": 200}) + pyarrow_file_io.new_input("s3://bucket/path/to/file") + + with pytest.raises(TypeError) as exc_info: + pyarrow_file_io = PyArrowFileIO({"s3.unknown_property": "val"}) + pyarrow_file_io.new_input("s3://bucket/path/to/file") + + assert "got an unexpected keyword argument 'unknown_property'" in str(exc_info.value) + + def test_pyarrow_io_multi_fs() -> None: pyarrow_file_io = PyArrowFileIO({"s3.region": "ap-southeast-1"}) @@ -2638,3 +2787,22 @@ def test_retry_strategy_not_found() -> None: io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "pyiceberg.DoesNotExist"}) with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"): io.new_input("s3://bucket/path/to/file") + + +def test_hdfs_filesystem_properties_with_netloc() -> None: + """Test that HDFS filesystem uses from_uri when netloc is provided.""" + hdfs_properties: Properties = { + "hdfs.host": "localhost", + "hdfs.port": "9000", + "hdfs.user": "testuser", + } + + with patch("pyarrow.fs.HadoopFileSystem") as mock_hdfs: + hdfs_fileio = PyArrowFileIO(properties=hdfs_properties) + filename = str(uuid.uuid4()) + + # When netloc is provided, it should use from_uri instead of properties + hdfs_fileio.new_input(location=f"hdfs://testcluster:8020/{filename}") + + # Verify that from_uri is called instead of constructor with properties + mock_hdfs.from_uri.assert_called_with("hdfs://testcluster:8020")