diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 2b91e3422e..82680d2787 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -12,6 +12,7 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( + ABFS, GCS, LOCAL, PARQUET, @@ -106,7 +107,7 @@ def decode( # Don't override default protocol -for protocol in [LOCAL, S3, GCS]: +for protocol in [LOCAL, S3, GCS, ABFS]: StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False) StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False) StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index d37b9aff37..cdb26a87c2 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -39,6 +39,7 @@ # Protocols BIGQUERY = "bq" S3 = "s3" +ABFS = "abfs" GCS = "gs" LOCAL = "/" diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py index 92451a4db4..e8d88f26c3 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py @@ -26,7 +26,7 @@ from flytekit import StructuredDatasetTransformerEngine, logger from flytekit.configuration import internal -from flytekit.types.structured.structured_dataset import GCS, S3 +from flytekit.types.structured.structured_dataset import ABFS, GCS, S3 from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler @@ -41,6 +41,9 @@ def _register(protocol: str): StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), True, True) +if importlib.util.find_spec("adlfs"): + _register(ABFS) + if importlib.util.find_spec("s3fs"): _register(S3) diff --git a/plugins/flytekit-data-fsspec/setup.py b/plugins/flytekit-data-fsspec/setup.py index c43d7c1481..3678a0b518 100644 --- a/plugins/flytekit-data-fsspec/setup.py +++ b/plugins/flytekit-data-fsspec/setup.py @@ -22,6 +22,7 @@ install_requires=plugin_requires, extras_require={ # https://github.com/fsspec/filesystem_spec/blob/master/setup.py#L36 + "abfs": ["adlfs>=2022.2.0"], "aws": ["s3fs>=2021.7.0"], "gcp": ["gcsfs>=2021.7.0"], }, diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index f79e97efb5..06b1127504 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -7,6 +7,7 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( + ABFS, GCS, LOCAL, PARQUET, @@ -62,7 +63,7 @@ def decode( return pl.read_parquet(path) -for protocol in [LOCAL, S3, GCS]: +for protocol in [LOCAL, S3, GCS, ABFS]: StructuredDatasetTransformerEngine.register( PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=False ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 651860d4b7..9fef590bcc 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -7,7 +7,11 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( + ABFS, + GCS, + LOCAL, PARQUET, + S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -48,6 +52,6 @@ def decode( return user_ctx.spark_session.read.parquet(flyte_value.uri) -for protocol in ["/", "s3"]: +for protocol in [LOCAL, S3, GCS, ABFS]: StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=False) StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=False)