diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c1d9ac838..da3582766 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -33,9 +33,11 @@ jobs: fail-fast: false matrix: python-version: + - "3.9" - "3.10" - "3.11" - "3.12" + - "3.13" toolchain: - "stable" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b548ff18f..abcfcf321 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: - id: actionlint-docker - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.0 + rev: v0.9.10 hooks: # Run the linter. - id: ruff diff --git a/benchmarks/tpch/tpch.py b/benchmarks/tpch/tpch.py index fb86b12b6..bfb9ac398 100644 --- a/benchmarks/tpch/tpch.py +++ b/benchmarks/tpch/tpch.py @@ -59,13 +59,13 @@ def bench(data_path, query_path): end = time.time() time_millis = (end - start) * 1000 total_time_millis += time_millis - print("setup,{}".format(round(time_millis, 1))) - results.write("setup,{}\n".format(round(time_millis, 1))) + print(f"setup,{round(time_millis, 1)}") + results.write(f"setup,{round(time_millis, 1)}\n") results.flush() # run queries for query in range(1, 23): - with open("{}/q{}.sql".format(query_path, query)) as f: + with open(f"{query_path}/q{query}.sql") as f: text = f.read() tmp = text.split(";") queries = [] @@ -83,14 +83,14 @@ def bench(data_path, query_path): end = time.time() time_millis = (end - start) * 1000 total_time_millis += time_millis - print("q{},{}".format(query, round(time_millis, 1))) - results.write("q{},{}\n".format(query, round(time_millis, 1))) + print(f"q{query},{round(time_millis, 1)}") + results.write(f"q{query},{round(time_millis, 1)}\n") results.flush() except Exception as e: print("query", query, "failed", e) - print("total,{}".format(round(total_time_millis, 1))) - results.write("total,{}\n".format(round(total_time_millis, 1))) + print(f"total,{round(total_time_millis, 1)}") + results.write(f"total,{round(total_time_millis, 1)}\n") if __name__ == "__main__": diff --git a/dev/release/check-rat-report.py b/dev/release/check-rat-report.py index d3dd7c5dd..0c9f4c326 100644 --- a/dev/release/check-rat-report.py +++ b/dev/release/check-rat-report.py @@ -29,7 +29,7 @@ exclude_globs_filename = sys.argv[1] xml_filename = sys.argv[2] -globs = [line.strip() for line in open(exclude_globs_filename, "r")] +globs = [line.strip() for line in open(exclude_globs_filename)] tree = ET.parse(xml_filename) root = tree.getroot() diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index 2564eea86..e30e2def2 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -26,15 +26,11 @@ def print_pulls(repo_name, title, pulls): if len(pulls) > 0: - print("**{}:**".format(title)) + print(f"**{title}:**") print() for pull, commit in pulls: - url = "https://github.com/{}/pull/{}".format(repo_name, pull.number) - print( - "- {} [#{}]({}) ({})".format( - pull.title, pull.number, url, commit.author.login - ) - ) + url = f"https://github.com/{repo_name}/pull/{pull.number}" + print(f"- {pull.title} [#{pull.number}]({url}) ({commit.author.login})") print() diff --git a/docs/source/conf.py b/docs/source/conf.py index 2e5a41339..c82a189e0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -73,7 +73,7 @@ autoapi_python_class_content = "both" -def autoapi_skip_member_fn(app, what, name, obj, skip, options): +def autoapi_skip_member_fn(app, what, name, obj, skip, options): # noqa: ARG001 
skip_contents = [ # Re-exports ("class", "datafusion.DataFrame"), diff --git a/examples/python-udwf.py b/examples/python-udwf.py index 7d39dc1b8..98d118bf2 100644 --- a/examples/python-udwf.py +++ b/examples/python-udwf.py @@ -59,7 +59,7 @@ def __init__(self, alpha: float) -> None: def supports_bounded_execution(self) -> bool: return True - def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: + def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002 # Override the default range of current row since uses_window_frame is False # So for the purpose of this test we just smooth from the previous row to # current. diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py index c4d872085..2be4dfabd 100644 --- a/examples/tpch/_tests.py +++ b/examples/tpch/_tests.py @@ -27,28 +27,25 @@ def df_selection(col_name, col_type): if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type): return F.round(col(col_name), lit(2)).alias(col_name) - elif col_type == pa.string() or col_type == pa.string_view(): + if col_type == pa.string() or col_type == pa.string_view(): return F.trim(col(col_name)).alias(col_name) - else: - return col(col_name) + return col(col_name) def load_schema(col_name, col_type): if col_type == pa.int64() or col_type == pa.int32(): return col_name, pa.string() - elif isinstance(col_type, pa.Decimal128Type): + if isinstance(col_type, pa.Decimal128Type): return col_name, pa.float64() - else: - return col_name, col_type + return col_name, col_type def expected_selection(col_name, col_type): if col_type == pa.int64() or col_type == pa.int32(): return F.trim(col(col_name)).cast(col_type).alias(col_name) - elif col_type == pa.string() or col_type == pa.string_view(): + if col_type == pa.string() or col_type == pa.string_view(): return F.trim(col(col_name)).alias(col_name) - else: - return col(col_name) + return col(col_name) def selections_and_schema(original_schema): diff --git a/pyproject.toml b/pyproject.toml index d16a18aa6..060e3b80a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ name = "datafusion" description = "Build and run queries against data" readme = "README.md" license = { file = "LICENSE.txt" } -requires-python = ">=3.8" +requires-python = ">=3.9" keywords = ["datafusion", "dataframe", "rust", "query-engine"] classifiers = [ "Development Status :: 2 - Pre-Alpha", @@ -35,7 +35,6 @@ classifiers = [ "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -66,7 +65,57 @@ features = ["substrait"] # Enable docstring linting using the google style guide [tool.ruff.lint] -select = ["E4", "E7", "E9", "F", "FA", "D", "W", "I"] +select = ["ALL" ] +ignore = [ + "A001", # Allow using words like min as variable names + "A002", # Allow using words like filter as variable names + "ANN401", # Allow Any for wrapper classes + "COM812", # Recommended to ignore these rules when using with ruff-format + "FIX002", # Allow TODO lines - consider removing at some point + "FBT001", # Allow boolean positional args + "FBT002", # Allow boolean positional args + "ISC001", # Recommended to ignore these rules when using with ruff-format + "SLF001", # Allow accessing private members + "TD002", + "TD003", # Allow TODO lines + "UP007", # Disallowing Union is pedantic + # TODO: Enable all of the following, but 
this PR is getting too large already + "PT001", + "ANN204", + "B008", + "EM101", + "PLR0913", + "PLR1714", + "ANN201", + "C400", + "TRY003", + "B904", + "UP006", + "RUF012", + "FBT003", + "C416", + "SIM102", + "PGH003", + "PLR2004", + "PERF401", + "PD901", + "EM102", + "ERA001", + "SIM108", + "ICN001", + "ANN001", + "ANN202", + "PTH", + "N812", + "INP001", + "DTZ007", + "PLW2901", + "RET503", + "RUF015", + "A005", + "TC001", + "UP035", +] [tool.ruff.lint.pydocstyle] convention = "google" @@ -76,16 +125,30 @@ max-doc-length = 88 # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] -"python/tests/*" = ["D"] -"examples/*" = ["D", "W505"] -"dev/*" = ["D"] -"benchmarks/*" = ["D", "F"] +"python/tests/*" = [ + "ANN", + "ARG", + "BLE001", + "D", + "S101", + "SLF", + "PD", + "PLR2004", + "PT011", + "RUF015", + "S608", + "PLR0913", + "PT004", +] +"examples/*" = ["D", "W505", "E501", "T201", "S101"] +"dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817"] +"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S", "SIM", "EXE", "UP"] "docs/*" = ["D"] [dependency-groups] dev = [ "maturin>=1.8.1", - "numpy>1.24.4 ; python_full_version >= '3.10'", + "numpy>1.25.0", "pytest>=7.4.4", "ruff>=0.9.1", "toml>=0.10.2", diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index f11ce54a6..286e5dc31 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -48,44 +48,47 @@ from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream -from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF +from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF, udaf, udf, udwf __version__ = importlib_metadata.version(__name__) __all__ = [ "Accumulator", + "AggregateUDF", + "Catalog", "Config", - "DataFrame", - "SessionContext", - "SessionConfig", - "SQLOptions", - "RuntimeEnvBuilder", - "Expr", - "ScalarUDF", - "WindowFrame", - "column", - "col", - "literal", - "lit", "DFSchema", - "Catalog", + "DataFrame", "Database", - "Table", - "AggregateUDF", - "WindowUDF", - "LogicalPlan", "ExecutionPlan", + "Expr", + "LogicalPlan", "RecordBatch", "RecordBatchStream", + "RuntimeEnvBuilder", + "SQLOptions", + "ScalarUDF", + "SessionConfig", + "SessionContext", + "Table", + "WindowFrame", + "WindowUDF", + "col", + "column", "common", "expr", "functions", + "lit", + "literal", "object_store", - "substrait", - "read_parquet", "read_avro", "read_csv", "read_json", + "read_parquet", + "substrait", + "udaf", + "udf", + "udwf", ] @@ -120,10 +123,3 @@ def str_lit(value): def lit(value): """Create a literal expression.""" return Expr.literal(value) - - -udf = ScalarUDF.udf - -udaf = AggregateUDF.udaf - -udwf = WindowUDF.udwf diff --git a/python/datafusion/common.py b/python/datafusion/common.py index a2298c634..e762a993b 100644 --- a/python/datafusion/common.py +++ b/python/datafusion/common.py @@ -20,7 +20,7 @@ from ._internal import common as common_internal -# TODO these should all have proper wrapper classes +# TODO: these should all have proper wrapper classes DFSchema = common_internal.DFSchema DataType = common_internal.DataType @@ -38,15 +38,15 @@ "DFSchema", "DataType", "DataTypeMap", - "RexType", - "PythonType", - "SqlType", "NullTreatment", - "SqlTable", + "PythonType", + "RexType", + "SqlFunction", "SqlSchema", - "SqlView", "SqlStatistics", - "SqlFunction", + "SqlTable", + "SqlType", + "SqlView", ] diff --git 
a/python/datafusion/context.py b/python/datafusion/context.py index 282b2a477..0ab1a908a 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -393,8 +393,6 @@ def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeEnvBuilder: class RuntimeConfig(RuntimeEnvBuilder): """See `RuntimeEnvBuilder`.""" - pass - class SQLOptions: """Options to be used when performing SQL queries.""" @@ -498,7 +496,7 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) - def enable_url_table(self) -> "SessionContext": + def enable_url_table(self) -> SessionContext: """Control if local files can be queried as tables. Returns: diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index de5d8376e..d1c71c2bb 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -29,6 +29,7 @@ List, Literal, Optional, + Type, Union, overload, ) @@ -49,10 +50,11 @@ import polars as pl import pyarrow as pa + from datafusion._internal import DataFrame as DataFrameInternal + from datafusion._internal import expr as expr_internal + from enum import Enum -from datafusion._internal import DataFrame as DataFrameInternal -from datafusion._internal import expr as expr_internal from datafusion.expr import Expr, SortExpr, sort_or_default @@ -73,7 +75,7 @@ class Compression(Enum): LZ4_RAW = "lz4_raw" @classmethod - def from_str(cls, value: str) -> "Compression": + def from_str(cls: Type[Compression], value: str) -> Compression: """Convert a string to a Compression enum value. Args: @@ -88,8 +90,9 @@ def from_str(cls, value: str) -> "Compression": try: return cls(value.lower()) except ValueError: + valid_values = str([item.value for item in Compression]) raise ValueError( - f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}" + f"{value} is not a valid Compression. 
Valid values are: {valid_values}" ) def get_default_level(self) -> Optional[int]: @@ -104,9 +107,9 @@ def get_default_level(self) -> Optional[int]: # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223 if self == Compression.GZIP: return 6 - elif self == Compression.BROTLI: + if self == Compression.BROTLI: return 1 - elif self == Compression.ZSTD: + if self == Compression.ZSTD: return 4 return None diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 3639abec6..702f75aed 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -101,63 +101,63 @@ WindowExpr = expr_internal.WindowExpr __all__ = [ - "Expr", - "Column", - "Literal", - "BinaryExpr", - "Literal", + "Aggregate", "AggregateFunction", - "Not", - "IsNotNull", - "IsNull", - "IsTrue", - "IsFalse", - "IsUnknown", - "IsNotTrue", - "IsNotFalse", - "IsNotUnknown", - "Negative", - "Like", - "ILike", - "SimilarTo", - "ScalarVariable", "Alias", - "InList", - "Exists", - "Subquery", - "InSubquery", - "ScalarSubquery", - "Placeholder", - "GroupingSet", + "Analyze", + "Between", + "BinaryExpr", "Case", "CaseBuilder", "Cast", - "TryCast", - "Between", + "Column", + "CreateMemoryTable", + "CreateView", + "Distinct", + "DropTable", + "EmptyRelation", + "Exists", "Explain", + "Expr", + "Extension", + "Filter", + "GroupingSet", + "ILike", + "InList", + "InSubquery", + "IsFalse", + "IsNotFalse", + "IsNotNull", + "IsNotTrue", + "IsNotUnknown", + "IsNull", + "IsTrue", + "IsUnknown", + "Join", + "JoinConstraint", + "JoinType", + "Like", "Limit", - "Aggregate", + "Literal", + "Literal", + "Negative", + "Not", + "Partitioning", + "Placeholder", + "Projection", + "Repartition", + "ScalarSubquery", + "ScalarVariable", + "SimilarTo", "Sort", "SortExpr", - "Analyze", - "EmptyRelation", - "Join", - "JoinType", - "JoinConstraint", + "Subquery", + "SubqueryAlias", + "TableScan", + "TryCast", "Union", "Unnest", "UnnestExpr", - "Extension", - "Filter", - "Projection", - "TableScan", - "CreateMemoryTable", - "CreateView", - "Distinct", - "SubqueryAlias", - "DropTable", - "Partitioning", - "Repartition", "Window", "WindowExpr", "WindowFrame", @@ -311,7 +311,7 @@ def __getitem__(self, key: str | int) -> Expr: ) return Expr(self.expr.__getitem__(key)) - def __eq__(self, rhs: Any) -> Expr: + def __eq__(self, rhs: object) -> Expr: """Equal to. Accepts either an expression or any valid PyArrow scalar literal value. @@ -320,7 +320,7 @@ def __eq__(self, rhs: Any) -> Expr: rhs = Expr.literal(rhs) return Expr(self.expr.__eq__(rhs.expr)) - def __ne__(self, rhs: Any) -> Expr: + def __ne__(self, rhs: object) -> Expr: """Not equal to. Accepts either an expression or any valid PyArrow scalar literal value. 
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index b449c4868..0cc7434cf 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -18,13 +18,12 @@ from __future__ import annotations -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import pyarrow as pa from datafusion._internal import functions as f from datafusion.common import NullTreatment -from datafusion.context import SessionContext from datafusion.expr import ( CaseBuilder, Expr, @@ -34,6 +33,9 @@ sort_list_to_raw_sort_list, ) +if TYPE_CHECKING: + from datafusion.context import SessionContext + __all__ = [ "abs", "acos", @@ -81,8 +83,8 @@ "array_sort", "array_to_string", "array_union", - "arrow_typeof", "arrow_cast", + "arrow_typeof", "ascii", "asin", "asinh", @@ -97,6 +99,7 @@ "bool_and", "bool_or", "btrim", + "cardinality", "case", "cbrt", "ceil", @@ -116,6 +119,7 @@ "covar", "covar_pop", "covar_samp", + "cume_dist", "current_date", "current_time", "date_bin", @@ -125,17 +129,17 @@ "datetrunc", "decode", "degrees", + "dense_rank", "digest", "empty", "encode", "ends_with", - "extract", "exp", + "extract", "factorial", "find_in_set", "first_value", "flatten", - "cardinality", "floor", "from_unixtime", "gcd", @@ -143,8 +147,10 @@ "initcap", "isnan", "iszero", + "lag", "last_value", "lcm", + "lead", "left", "length", "levenshtein", @@ -166,10 +172,10 @@ "list_prepend", "list_push_back", "list_push_front", - "list_repeat", "list_remove", "list_remove_all", "list_remove_n", + "list_repeat", "list_replace", "list_replace_all", "list_replace_n", @@ -180,14 +186,14 @@ "list_union", "ln", "log", - "log10", "log2", + "log10", "lower", "lpad", "ltrim", "make_array", - "make_list", "make_date", + "make_list", "max", "md5", "mean", @@ -195,19 +201,22 @@ "min", "named_struct", "nanvl", - "nvl", "now", "nth_value", + "ntile", "nullif", + "nvl", "octet_length", "order_by", "overlay", + "percent_rank", "pi", "pow", "power", "radians", "random", "range", + "rank", "regexp_like", "regexp_match", "regexp_replace", @@ -225,6 +234,7 @@ "reverse", "right", "round", + "row_number", "rpad", "rtrim", "sha224", @@ -252,8 +262,8 @@ "to_hex", "to_timestamp", "to_timestamp_micros", - "to_timestamp_nanos", "to_timestamp_millis", + "to_timestamp_nanos", "to_timestamp_seconds", "to_unixtime", "translate", @@ -268,14 +278,6 @@ "when", # Window Functions "window", - "lead", - "lag", - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", ] @@ -292,14 +294,14 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr: return Expr(f.nullif(expr1.expr, expr2.expr)) -def encode(input: Expr, encoding: Expr) -> Expr: +def encode(expr: Expr, encoding: Expr) -> Expr: """Encode the ``input``, using the ``encoding``. encoding can be base64 or hex.""" - return Expr(f.encode(input.expr, encoding.expr)) + return Expr(f.encode(expr.expr, encoding.expr)) -def decode(input: Expr, encoding: Expr) -> Expr: +def decode(expr: Expr, encoding: Expr) -> Expr: """Decode the ``input``, using the ``encoding``. 
encoding can be base64 or hex.""" - return Expr(f.decode(input.expr, encoding.expr)) + return Expr(f.decode(expr.expr, encoding.expr)) def array_to_string(expr: Expr, delimiter: Expr) -> Expr: diff --git a/python/datafusion/input/__init__.py b/python/datafusion/input/__init__.py index f85ce21f0..f0c1f42b4 100644 --- a/python/datafusion/input/__init__.py +++ b/python/datafusion/input/__init__.py @@ -23,5 +23,5 @@ from .location import LocationInputPlugin __all__ = [ - LocationInputPlugin, + "LocationInputPlugin", ] diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py index 4eba19784..f67dde2a1 100644 --- a/python/datafusion/input/base.py +++ b/python/datafusion/input/base.py @@ -38,11 +38,9 @@ class BaseInputSource(ABC): """ @abstractmethod - def is_correct_input(self, input_item: Any, table_name: str, **kwargs) -> bool: + def is_correct_input(self, input_item: Any, table_name: str, **kwargs: Any) -> bool: """Returns `True` if the input is valid.""" - pass @abstractmethod - def build_table(self, input_item: Any, table_name: str, **kwarg) -> SqlTable: + def build_table(self, input_item: Any, table_name: str, **kwarg: Any) -> SqlTable: # type: ignore[invalid-type-form] """Create a table from the input source.""" - pass diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py index 517cd1578..08d98d115 100644 --- a/python/datafusion/input/location.py +++ b/python/datafusion/input/location.py @@ -18,7 +18,7 @@ """The default input source for DataFusion.""" import glob -import os +from pathlib import Path from typing import Any from datafusion.common import DataTypeMap, SqlTable @@ -31,7 +31,7 @@ class LocationInputPlugin(BaseInputSource): This can be read in from a file (on disk, remote etc.). """ - def is_correct_input(self, input_item: Any, table_name: str, **kwargs): + def is_correct_input(self, input_item: Any, table_name: str, **kwargs: Any) -> bool: # noqa: ARG002 """Returns `True` if the input is valid.""" return isinstance(input_item, str) @@ -39,27 +39,28 @@ def build_table( self, input_item: str, table_name: str, - **kwargs, - ) -> SqlTable: + **kwargs: Any, # noqa: ARG002 + ) -> SqlTable: # type: ignore[invalid-type-form] """Create a table from the input source.""" - _, extension = os.path.splitext(input_item) - format = extension.lstrip(".").lower() + extension = Path(input_item).suffix + file_format = extension.lstrip(".").lower() num_rows = 0 # Total number of rows in the file. Used for statistics columns = [] - if format == "parquet": + if file_format == "parquet": import pyarrow.parquet as pq # Read the Parquet metadata metadata = pq.read_metadata(input_item) num_rows = metadata.num_rows # Iterate through the schema and build the SqlTable - for col in metadata.schema: - columns.append( - ( - col.name, - DataTypeMap.from_parquet_type_str(col.physical_type), - ) + columns = [ + ( + col.name, + DataTypeMap.from_parquet_type_str(col.physical_type), ) + for col in metadata.schema + ] + elif format == "csv": import csv @@ -69,19 +70,18 @@ def build_table( # to get that information. However, this should only be occurring # at table creation time and therefore shouldn't # slow down query performance. 
- with open(input_item, "r") as file: + with Path(input_item).open() as file: reader = csv.reader(file) - header_row = next(reader) - print(header_row) + _header_row = next(reader) for _ in reader: num_rows += 1 # TODO: Need to actually consume this row into reasonable columns - raise RuntimeError("TODO: Currently unable to support CSV input files.") + msg = "TODO: Currently unable to support CSV input files." + raise RuntimeError(msg) else: - raise RuntimeError( - f"Input of format: `{format}` is currently not supported.\ + msg = f"Input of format: `{format}` is currently not supported.\ Only Parquet and CSV." - ) + raise RuntimeError(msg) # Input could possibly be multiple files. Create a list if so input_files = glob.glob(input_item) diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 3b6264948..3e39703e3 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -19,15 +19,19 @@ from __future__ import annotations -import pathlib - -import pyarrow +from typing import TYPE_CHECKING from datafusion.dataframe import DataFrame -from datafusion.expr import Expr from ._internal import SessionContext as SessionContextInternal +if TYPE_CHECKING: + import pathlib + + import pyarrow as pa + + from datafusion.expr import Expr + def read_parquet( path: str | pathlib.Path, @@ -35,7 +39,7 @@ def read_parquet( parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, - schema: pyarrow.Schema | None = None, + schema: pa.Schema | None = None, file_sort_order: list[list[Expr]] | None = None, ) -> DataFrame: """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. @@ -79,7 +83,7 @@ def read_parquet( def read_json( path: str | pathlib.Path, - schema: pyarrow.Schema | None = None, + schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", table_partition_cols: list[tuple[str, str]] | None = None, @@ -120,7 +124,7 @@ def read_json( def read_csv( path: str | pathlib.Path | list[str] | list[pathlib.Path], - schema: pyarrow.Schema | None = None, + schema: pa.Schema | None = None, has_header: bool = True, delimiter: str = ",", schema_infer_max_records: int = 1000, @@ -173,7 +177,7 @@ def read_csv( def read_avro( path: str | pathlib.Path, - schema: pyarrow.Schema | None = None, + schema: pa.Schema | None = None, file_partition_cols: list[tuple[str, str]] | None = None, file_extension: str = ".avro", ) -> DataFrame: diff --git a/python/datafusion/object_store.py b/python/datafusion/object_store.py index 7cc17506f..6298526f5 100644 --- a/python/datafusion/object_store.py +++ b/python/datafusion/object_store.py @@ -24,4 +24,4 @@ MicrosoftAzure = object_store.MicrosoftAzure Http = object_store.Http -__all__ = ["AmazonS3", "GoogleCloud", "LocalFileSystem", "MicrosoftAzure", "Http"] +__all__ = ["AmazonS3", "GoogleCloud", "Http", "LocalFileSystem", "MicrosoftAzure"] diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index 133fc446d..0b7bebcb3 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -19,7 +19,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any import datafusion._internal as df_internal @@ -27,8 +27,8 @@ from datafusion.context import SessionContext __all__ = [ - "LogicalPlan", "ExecutionPlan", + "LogicalPlan", ] @@ -54,7 +54,7 @@ def to_variant(self) -> Any: """Convert the logical plan into its specific variant.""" return self._raw_plan.to_variant() - def inputs(self) -> 
List[LogicalPlan]: + def inputs(self) -> list[LogicalPlan]: """Returns the list of inputs to the logical plan.""" return [LogicalPlan(p) for p in self._raw_plan.inputs()] @@ -106,7 +106,7 @@ def __init__(self, plan: df_internal.ExecutionPlan) -> None: """This constructor should not be called by the end user.""" self._raw_plan = plan - def children(self) -> List[ExecutionPlan]: + def children(self) -> list[ExecutionPlan]: """Get a list of children `ExecutionPlan` that act as inputs to this plan. The returned list will be empty for leaf nodes such as scans, will contain a diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 772cd9089..556eaa786 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -26,14 +26,14 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - import pyarrow + import pyarrow as pa import typing_extensions import datafusion._internal as df_internal class RecordBatch: - """This class is essentially a wrapper for :py:class:`pyarrow.RecordBatch`.""" + """This class is essentially a wrapper for :py:class:`pa.RecordBatch`.""" def __init__(self, record_batch: df_internal.RecordBatch) -> None: """This constructor is generally not called by the end user. @@ -42,8 +42,8 @@ def __init__(self, record_batch: df_internal.RecordBatch) -> None: """ self.record_batch = record_batch - def to_pyarrow(self) -> pyarrow.RecordBatch: - """Convert to :py:class:`pyarrow.RecordBatch`.""" + def to_pyarrow(self) -> pa.RecordBatch: + """Convert to :py:class:`pa.RecordBatch`.""" return self.record_batch.to_pyarrow() diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index 06302fe38..f10adfb0c 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -23,7 +23,6 @@ from __future__ import annotations -import pathlib from typing import TYPE_CHECKING try: @@ -36,11 +35,13 @@ from ._internal import substrait as substrait_internal if TYPE_CHECKING: + import pathlib + from datafusion.context import SessionContext __all__ = [ - "Plan", "Consumer", + "Plan", "Producer", "Serde", ] @@ -68,11 +69,9 @@ def encode(self) -> bytes: @deprecated("Use `Plan` instead.") -class plan(Plan): +class plan(Plan): # noqa: N801 """See `Plan`.""" - pass - class Serde: """Provides the ``Substrait`` serialization and deserialization.""" @@ -140,11 +139,9 @@ def deserialize_bytes(proto_bytes: bytes) -> Plan: @deprecated("Use `Serde` instead.") -class serde(Serde): +class serde(Serde): # noqa: N801 """See `Serde` instead.""" - pass - class Producer: """Generates substrait plans from a logical plan.""" @@ -168,11 +165,9 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: @deprecated("Use `Producer` instead.") -class producer(Producer): +class producer(Producer): # noqa: N801 """Use `Producer` instead.""" - pass - class Consumer: """Generates a logical plan from a substrait plan.""" @@ -194,7 +189,5 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: @deprecated("Use `Consumer` instead.") -class consumer(Consumer): +class consumer(Consumer): # noqa: N801 """Use `Consumer` instead.""" - - pass diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index af7bcf2ed..603b7063d 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,15 +22,15 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, 
Callable, Optional, TypeVar, overload -import pyarrow +import pyarrow as pa import datafusion._internal as df_internal from datafusion.expr import Expr if TYPE_CHECKING: - _R = TypeVar("_R", bound=pyarrow.DataType) + _R = TypeVar("_R", bound=pa.DataType) class Volatility(Enum): @@ -72,7 +72,7 @@ class Volatility(Enum): for each output row, resulting in a unique random value for each row. """ - def __str__(self): + def __str__(self) -> str: """Returns the string equivalent.""" return self.name.lower() @@ -88,7 +88,7 @@ def __init__( self, name: str, func: Callable[..., _R], - input_types: pyarrow.DataType | list[pyarrow.DataType], + input_types: pa.DataType | list[pa.DataType], return_type: _R, volatility: Volatility | str, ) -> None: @@ -96,7 +96,7 @@ def __init__( See helper method :py:func:`udf` for argument details. """ - if isinstance(input_types, pyarrow.DataType): + if isinstance(input_types, pa.DataType): input_types = [input_types] self._udf = df_internal.ScalarUDF( name, func, input_types, return_type, str(volatility) @@ -111,7 +111,27 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udf.__call__(*args_raw)) - class udf: + @overload + @staticmethod + def udf( + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., ScalarUDF]: ... + + @overload + @staticmethod + def udf( + func: Callable[..., _R], + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> ScalarUDF: ... + + @staticmethod + def udf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Function (UDF). This class can be used both as a **function** and as a **decorator**. @@ -125,7 +145,7 @@ class udf: Args: func (Callable, optional): **Only needed when calling as a function.** Skip this argument when using `udf` as a decorator. - input_types (list[pyarrow.DataType]): The data types of the arguments + input_types (list[pa.DataType]): The data types of the arguments to `func`. This list must be of the same length as the number of arguments. return_type (_R): The data type of the return value from the function. @@ -141,40 +161,28 @@ class udf: ``` def double_func(x): return x * 2 - double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), + double_udf = udf(double_func, [pa.int32()], pa.int32(), "volatile", "double_it") ``` **Using `udf` as a decorator:** ``` - @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") + @udf([pa.int32()], pa.int32(), "volatile", "double_it") def double_udf(x): return x * 2 ``` """ - def __new__(cls, *args, **kwargs): - """Create a new UDF. 
- - Trigger UDF function or decorator depending on if the first args is callable - """ - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return cls._function(*args, **kwargs) - else: - # Case 2: Used as a decorator with parameters - return cls._decorator(*args, **kwargs) - - @staticmethod def _function( func: Callable[..., _R], - input_types: list[pyarrow.DataType], + input_types: list[pa.DataType], return_type: _R, volatility: Volatility | str, name: Optional[str] = None, ) -> ScalarUDF: if not callable(func): - raise TypeError("`func` argument must be callable") + msg = "`func` argument must be callable" + raise TypeError(msg) if name is None: if hasattr(func, "__qualname__"): name = func.__qualname__.lower() @@ -188,49 +196,50 @@ def _function( volatility=volatility, ) - @staticmethod def _decorator( - input_types: list[pyarrow.DataType], + input_types: list[pa.DataType], return_type: _R, volatility: Volatility | str, name: Optional[str] = None, - ): - def decorator(func): + ) -> Callable: + def decorator(func: Callable): udf_caller = ScalarUDF.udf( func, input_types, return_type, volatility, name ) @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any): return udf_caller(*args, **kwargs) return wrapper return decorator + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return _function(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return _decorator(*args, **kwargs) + class Accumulator(metaclass=ABCMeta): """Defines how an :py:class:`AggregateUDF` accumulates values.""" @abstractmethod - def state(self) -> List[pyarrow.Scalar]: + def state(self) -> list[pa.Scalar]: """Return the current state.""" - pass @abstractmethod - def update(self, *values: pyarrow.Array) -> None: + def update(self, *values: pa.Array) -> None: """Evaluate an array of values and update state.""" - pass @abstractmethod - def merge(self, states: List[pyarrow.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: """Merge a set of states.""" - pass @abstractmethod - def evaluate(self) -> pyarrow.Scalar: + def evaluate(self) -> pa.Scalar: """Return the resultant value.""" - pass class AggregateUDF: @@ -244,9 +253,9 @@ def __init__( self, name: str, accumulator: Callable[[], Accumulator], - input_types: list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], + input_types: list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], volatility: Volatility | str, ) -> None: """Instantiate a user-defined aggregate function (UDAF). @@ -272,7 +281,29 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udaf.__call__(*args_raw)) - class udaf: + @overload + @staticmethod + def udaf( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., AggregateUDF]: ... + + @overload + @staticmethod + def udaf( + accum: Callable[[], Accumulator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> AggregateUDF: ... + + @staticmethod + def udaf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Aggregate Function (UDAF). 
This class allows you to define an **aggregate function** that can be used in @@ -300,13 +331,13 @@ class Summarize(Accumulator): def __init__(self, bias: float = 0.0): self._sum = pa.scalar(bias) - def state(self) -> List[pa.Scalar]: + def state(self) -> list[pa.Scalar]: return [self._sum] def update(self, values: pa.Array) -> None: self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - def merge(self, states: List[pa.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) def evaluate(self) -> pa.Scalar: @@ -344,37 +375,23 @@ def udf4() -> Summarize: aggregation or window function calls. """ - def __new__(cls, *args, **kwargs): - """Create a new UDAF. - - Trigger UDAF function or decorator depending on if the first args is - callable - """ - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return cls._function(*args, **kwargs) - else: - # Case 2: Used as a decorator with parameters - return cls._decorator(*args, **kwargs) - - @staticmethod def _function( accum: Callable[[], Accumulator], - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], volatility: Volatility | str, name: Optional[str] = None, ) -> AggregateUDF: if not callable(accum): - raise TypeError("`func` must be callable.") - if not isinstance(accum.__call__(), Accumulator): - raise TypeError( - "Accumulator must implement the abstract base class Accumulator" - ) + msg = "`func` must be callable." + raise TypeError(msg) + if not isinstance(accum(), Accumulator): + msg = "Accumulator must implement the abstract base class Accumulator" + raise TypeError(msg) if name is None: - name = accum.__call__().__class__.__qualname__.lower() - if isinstance(input_types, pyarrow.DataType): + name = accum().__class__.__qualname__.lower() + if isinstance(input_types, pa.DataType): input_types = [input_types] return AggregateUDF( name=name, @@ -385,29 +402,34 @@ def _function( volatility=volatility, ) - @staticmethod def _decorator( - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], volatility: Volatility | str, name: Optional[str] = None, - ): - def decorator(accum: Callable[[], Accumulator]): + ) -> Callable[..., Callable[..., Expr]]: + def decorator(accum: Callable[[], Accumulator]) -> Callable[..., Expr]: udaf_caller = AggregateUDF.udaf( accum, input_types, return_type, state_type, volatility, name ) @functools.wraps(accum) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Expr: return udaf_caller(*args, **kwargs) return wrapper return decorator + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return _function(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return _decorator(*args, **kwargs) + -class WindowEvaluator(metaclass=ABCMeta): +class WindowEvaluator: """Evaluator class for user-defined window functions (UDWF). It is up to the user to decide which evaluate function is appropriate. 
@@ -423,7 +445,7 @@ class WindowEvaluator(metaclass=ABCMeta): +------------------------+--------------------------------+------------------+---------------------------+ | True | True/False | True/False | ``evaluate`` | +------------------------+--------------------------------+------------------+---------------------------+ - """ # noqa: W505 + """ # noqa: W505, E501 def memoize(self) -> None: """Perform a memoize operation to improve performance. @@ -436,9 +458,8 @@ def memoize(self) -> None: `memoize` is called after each input batch is processed, and such functions can save whatever they need """ - pass - def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: + def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002 """Return the range for the window fuction. If `uses_window_frame` flag is `false`. This method is used to @@ -460,14 +481,17 @@ def is_causal(self) -> bool: """Get whether evaluator needs future data for its result.""" return False - def evaluate_all(self, values: list[pyarrow.Array], num_rows: int) -> pyarrow.Array: + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: """Evaluate a window function on an entire input partition. This function is called once per input *partition* for window functions that *do not use* values from the window frame, such as - :py:func:`~datafusion.functions.row_number`, :py:func:`~datafusion.functions.rank`, - :py:func:`~datafusion.functions.dense_rank`, :py:func:`~datafusion.functions.percent_rank`, - :py:func:`~datafusion.functions.cume_dist`, :py:func:`~datafusion.functions.lead`, + :py:func:`~datafusion.functions.row_number`, + :py:func:`~datafusion.functions.rank`, + :py:func:`~datafusion.functions.dense_rank`, + :py:func:`~datafusion.functions.percent_rank`, + :py:func:`~datafusion.functions.cume_dist`, + :py:func:`~datafusion.functions.lead`, and :py:func:`~datafusion.functions.lag`. It produces the result of all rows in a single pass. It @@ -499,12 +523,11 @@ def evaluate_all(self, values: list[pyarrow.Array], num_rows: int) -> pyarrow.Ar .. code-block:: text avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) - """ # noqa: W505 - pass + """ # noqa: W505, E501 def evaluate( - self, values: list[pyarrow.Array], eval_range: tuple[int, int] - ) -> pyarrow.Scalar: + self, values: list[pa.Array], eval_range: tuple[int, int] + ) -> pa.Scalar: """Evaluate window function on a range of rows in an input partition. This is the simplest and most general function to implement @@ -519,11 +542,10 @@ def evaluate( and evaluation results of ORDER BY expressions. If function has a single argument, `values[1..]` will contain ORDER BY expression results. """ - pass def evaluate_all_with_rank( self, num_rows: int, ranks_in_partition: list[tuple[int, int]] - ) -> pyarrow.Array: + ) -> pa.Array: """Called for window functions that only need the rank of a row. Evaluate the partition evaluator against the partition using @@ -552,7 +574,6 @@ def evaluate_all_with_rank( The user must implement this method if ``include_rank`` returns True. """ - pass def supports_bounded_execution(self) -> bool: """Can the window function be incrementally computed using bounded memory?""" @@ -567,10 +588,6 @@ def include_rank(self) -> bool: return False -if TYPE_CHECKING: - _W = TypeVar("_W", bound=WindowEvaluator) - - class WindowUDF: """Class for performing window user-defined functions (UDF). 
@@ -582,8 +599,8 @@ def __init__( self, name: str, func: Callable[[], WindowEvaluator], - input_types: list[pyarrow.DataType], - return_type: pyarrow.DataType, + input_types: list[pa.DataType], + return_type: pa.DataType, volatility: Volatility | str, ) -> None: """Instantiate a user-defined window function (UDWF). @@ -607,8 +624,8 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udwf( func: Callable[[], WindowEvaluator], - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, volatility: Volatility | str, name: Optional[str] = None, ) -> WindowUDF: @@ -648,16 +665,16 @@ def bias_10() -> BiasedNumbers: Returns: A user-defined window function. - """ # noqa W505 + """ # noqa: W505, E501 if not callable(func): - raise TypeError("`func` must be callable.") - if not isinstance(func.__call__(), WindowEvaluator): - raise TypeError( - "`func` must implement the abstract base class WindowEvaluator" - ) + msg = "`func` must be callable." + raise TypeError(msg) + if not isinstance(func(), WindowEvaluator): + msg = "`func` must implement the abstract base class WindowEvaluator" + raise TypeError(msg) if name is None: - name = func.__call__().__class__.__qualname__.lower() - if isinstance(input_types, pyarrow.DataType): + name = func().__class__.__qualname__.lower() + if isinstance(input_types, pa.DataType): input_types = [input_types] return WindowUDF( name=name, @@ -666,3 +683,10 @@ def bias_10() -> BiasedNumbers: return_type=return_type, volatility=volatility, ) + + +# Convenience exports so we can import instead of treating as +# variables at the package root +udf = ScalarUDF.udf +udaf = AggregateUDF.udaf +udwf = WindowUDF.udwf diff --git a/python/tests/generic.py b/python/tests/generic.py index 0177e2df0..1b98fdf9e 100644 --- a/python/tests/generic.py +++ b/python/tests/generic.py @@ -16,6 +16,7 @@ # under the License. 
import datetime +from datetime import timezone import numpy as np import pyarrow as pa @@ -26,29 +27,29 @@ def data(): - np.random.seed(1) + rng = np.random.default_rng(1) data = np.concatenate( [ - np.random.normal(0, 0.01, size=50), - np.random.normal(50, 0.01, size=50), + rng.normal(0, 0.01, size=50), + rng.normal(50, 0.01, size=50), ] ) return pa.array(data) def data_with_nans(): - np.random.seed(0) - data = np.random.normal(0, 0.01, size=50) - mask = np.random.randint(0, 2, size=50) + rng = np.random.default_rng(0) + data = rng.normal(0, 0.01, size=50) + mask = rng.normal(0, 2, size=50) data[mask == 0] = np.nan return data def data_datetime(f): data = [ - datetime.datetime.now(), - datetime.datetime.now() - datetime.timedelta(days=1), - datetime.datetime.now() + datetime.timedelta(days=1), + datetime.datetime.now(tz=timezone.utc), + datetime.datetime.now(tz=timezone.utc) - datetime.timedelta(days=1), + datetime.datetime.now(tz=timezone.utc) + datetime.timedelta(days=1), ] return pa.array(data, type=pa.timestamp(f), mask=np.array([False, True, False])) diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py index 5ef46131b..61b1c7d80 100644 --- a/python/tests/test_aggregation.py +++ b/python/tests/test_aggregation.py @@ -66,7 +66,7 @@ def df_aggregate_100(): @pytest.mark.parametrize( - "agg_expr, calc_expected", + ("agg_expr", "calc_expected"), [ (f.avg(column("a")), lambda a, b, c, d: np.array(np.average(a))), ( @@ -114,7 +114,7 @@ def test_aggregation_stats(df, agg_expr, calc_expected): @pytest.mark.parametrize( - "agg_expr, expected, array_sort", + ("agg_expr", "expected", "array_sort"), [ (f.approx_distinct(column("b")), pa.array([2], type=pa.uint64()), False), ( @@ -182,12 +182,11 @@ def test_aggregation(df, agg_expr, expected, array_sort): agg_df.show() result = agg_df.collect()[0] - print(result) assert result.column(0) == expected @pytest.mark.parametrize( - "name,expr,expected", + ("name", "expr", "expected"), [ ( "approx_percentile_cont", @@ -299,7 +298,9 @@ def test_aggregate_100(df_aggregate_100, name, expr, expected): ] -@pytest.mark.parametrize("name,expr,result", data_test_bitwise_and_boolean_functions) +@pytest.mark.parametrize( + ("name", "expr", "result"), data_test_bitwise_and_boolean_functions +) def test_bit_and_bool_fns(df, name, expr, result): df = df.aggregate([], [expr.alias(name)]) @@ -311,7 +312,7 @@ def test_bit_and_bool_fns(df, name, expr, result): @pytest.mark.parametrize( - "name,expr,result", + ("name", "expr", "result"), [ ("first_value", f.first_value(column("a")), [0, 4]), ( @@ -361,7 +362,6 @@ def test_bit_and_bool_fns(df, name, expr, result): ), [8, 9], ), - ("first_value", f.first_value(column("a")), [0, 4]), ( "nth_value_ordered", f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]), @@ -401,7 +401,7 @@ def test_first_last_value(df_partitioned, name, expr, result) -> None: @pytest.mark.parametrize( - "name,expr,result", + ("name", "expr", "result"), [ ("string_agg", f.string_agg(column("a"), ","), "one,two,three,two"), ("string_agg", f.string_agg(column("b"), ""), "03124"), diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 214f6b165..23b328458 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -19,6 +19,9 @@ import pytest +# Note we take in `database` as a variable even though we don't use +# it because that will cause the fixture to set up the context with +# the tables we need. 
def test_basic(ctx, database): with pytest.raises(KeyError): ctx.catalog("non-existent") @@ -26,10 +29,10 @@ def test_basic(ctx, database): default = ctx.catalog() assert default.names() == ["public"] - for database in [default.database("public"), default.database()]: - assert database.names() == {"csv1", "csv", "csv2"} + for db in [default.database("public"), default.database()]: + assert db.names() == {"csv1", "csv", "csv2"} - table = database.table("csv") + table = db.table("csv") assert table.kind == "physical" assert table.schema == pa.schema( [ diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 91046e6b8..7a0a7aa08 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -16,7 +16,6 @@ # under the License. import datetime as dt import gzip -import os import pathlib import pyarrow as pa @@ -45,7 +44,7 @@ def test_create_context_runtime_config_only(): SessionContext(runtime=RuntimeEnvBuilder()) -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_runtime_configs(tmp_path, path_to_str): path1 = tmp_path / "dir1" path2 = tmp_path / "dir2" @@ -62,7 +61,7 @@ def test_runtime_configs(tmp_path, path_to_str): assert db is not None -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_temporary_files(tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -79,14 +78,14 @@ def test_create_context_with_all_valid_args(): runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000) config = ( SessionConfig() - .with_create_default_catalog_and_schema(True) + .with_create_default_catalog_and_schema(enabled=True) .with_default_catalog_and_schema("foo", "bar") .with_target_partitions(1) - .with_information_schema(True) - .with_repartition_joins(False) - .with_repartition_aggregations(False) - .with_repartition_windows(False) - .with_parquet_pruning(False) + .with_information_schema(enabled=True) + .with_repartition_joins(enabled=False) + .with_repartition_aggregations(enabled=False) + .with_repartition_windows(enabled=False) + .with_parquet_pruning(enabled=False) ) ctx = SessionContext(config, runtime) @@ -167,7 +166,7 @@ def test_from_arrow_table(ctx): def record_batch_generator(num_batches: int): schema = pa.schema([("a", pa.int64()), ("b", pa.int64())]) - for i in range(num_batches): + for _i in range(num_batches): yield pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], schema=schema ) @@ -492,10 +491,10 @@ def test_table_not_found(ctx): def test_read_json(ctx): - path = os.path.dirname(os.path.abspath(__file__)) + path = pathlib.Path(__file__).parent.resolve() # Default - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = path / "data_test_context" / "data.json" df = ctx.read_json(test_data_path) result = df.collect() @@ -515,7 +514,7 @@ def test_read_json(ctx): assert result[0].schema == schema # File extension - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = path / "data_test_context" / "data.json" df = ctx.read_json(test_data_path, file_extension=".json") result = df.collect() @@ -524,15 +523,17 @@ def test_read_json(ctx): def test_read_json_compressed(ctx, tmp_path): - path = os.path.dirname(os.path.abspath(__file__)) - test_data_path = os.path.join(path, "data_test_context", "data.json") + path = pathlib.Path(__file__).parent.resolve() + test_data_path = path / 
"data_test_context" / "data.json" # File compression type gzip_path = tmp_path / "data.json.gz" - with open(test_data_path, "rb") as csv_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(csv_file) + with ( + pathlib.Path.open(test_data_path, "rb") as csv_file, + gzip.open(gzip_path, "wb") as gzipped_file, + ): + gzipped_file.writelines(csv_file) df = ctx.read_json(gzip_path, file_extension=".gz", file_compression_type="gz") result = df.collect() @@ -563,14 +564,16 @@ def test_read_csv_list(ctx): def test_read_csv_compressed(ctx, tmp_path): - test_data_path = "testing/data/csv/aggregate_test_100.csv" + test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv") # File compression type gzip_path = tmp_path / "aggregate_test_100.csv.gz" - with open(test_data_path, "rb") as csv_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(csv_file) + with ( + pathlib.Path.open(test_data_path, "rb") as csv_file, + gzip.open(gzip_path, "wb") as gzipped_file, + ): + gzipped_file.writelines(csv_file) csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz") csv_df.select(column("c1")).show() @@ -603,7 +606,7 @@ def test_create_sql_options(): def test_sql_with_options_no_ddl(ctx): sql = "CREATE TABLE IF NOT EXISTS valuetable AS VALUES(1,'HELLO'),(12,'DATAFUSION')" ctx.sql(sql) - options = SQLOptions().with_allow_ddl(False) + options = SQLOptions().with_allow_ddl(allow=False) with pytest.raises(Exception, match="DDL"): ctx.sql_with_options(sql, options=options) @@ -618,7 +621,7 @@ def test_sql_with_options_no_dml(ctx): ctx.register_dataset(table_name, dataset) sql = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);' ctx.sql(sql) - options = SQLOptions().with_allow_dml(False) + options = SQLOptions().with_allow_dml(allow=False) with pytest.raises(Exception, match="DML"): ctx.sql_with_options(sql, options=options) @@ -626,6 +629,6 @@ def test_sql_with_options_no_dml(ctx): def test_sql_with_options_no_statements(ctx): sql = "SET time zone = 1;" ctx.sql(sql) - options = SQLOptions().with_allow_statements(False) + options = SQLOptions().with_allow_statements(allow=False) with pytest.raises(Exception, match="SetVariable"): ctx.sql_with_options(sql, options=options) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index c636e896a..d084f12dd 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -339,7 +339,7 @@ def test_join(): # Verify we don't make a breaking change to pre-43.0.0 # where users would pass join_keys as a positional argument - df2 = df.join(df1, (["a"], ["a"]), how="inner") # type: ignore + df2 = df.join(df1, (["a"], ["a"]), how="inner") df2.show() df2 = df2.sort(column("l.a")) table = pa.Table.from_batches(df2.collect()) @@ -375,17 +375,17 @@ def test_join_invalid_params(): with pytest.raises( ValueError, match=r"`left_on` or `right_on` should not provided with `on`" ): - df2 = df.join(df1, on="a", how="inner", right_on="test") # type: ignore + df2 = df.join(df1, on="a", how="inner", right_on="test") with pytest.raises( ValueError, match=r"`left_on` and `right_on` should both be provided." ): - df2 = df.join(df1, left_on="a", how="inner") # type: ignore + df2 = df.join(df1, left_on="a", how="inner") with pytest.raises( ValueError, match=r"either `on` or `left_on` and `right_on` should be provided." 
): - df2 = df.join(df1, how="inner") # type: ignore + df2 = df.join(df1, how="inner") def test_join_on(): @@ -567,7 +567,7 @@ def test_distinct(): ] -@pytest.mark.parametrize("name,expr,result", data_test_window_functions) +@pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions) def test_window_functions(partitioned_df, name, expr, result): df = partitioned_df.select( column("a"), column("b"), column("c"), f.alias(expr, name) @@ -731,7 +731,7 @@ def test_execution_plan(aggregate_df): plan = aggregate_df.execution_plan() expected = ( - "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" # noqa: E501 + "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" ) assert expected == plan.display() @@ -756,7 +756,7 @@ def test_execution_plan(aggregate_df): ctx = SessionContext() rows_returned = 0 - for idx in range(0, plan.partition_count): + for idx in range(plan.partition_count): stream = ctx.execute(plan, idx) try: batch = stream.next() @@ -885,7 +885,7 @@ def test_union_distinct(ctx): ) df_c = ctx.create_dataframe([[batch]]).sort(column("a")) - df_a_u_b = df_a.union(df_b, True).sort(column("a")) + df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a")) assert df_c.collect() == df_a_u_b.collect() assert df_c.collect() == df_a_u_b.collect() @@ -954,8 +954,6 @@ def test_to_arrow_table(df): def test_execute_stream(df): stream = df.execute_stream() - for s in stream: - print(type(s)) assert all(batch is not None for batch in stream) assert not list(stream) # after one iteration the generator must be exhausted @@ -969,7 +967,7 @@ def test_execute_stream_to_arrow_table(df, schema): (batch.to_pyarrow() for batch in stream), schema=df.schema() ) else: - pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream)) + pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream) assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (3, 3) @@ -1033,7 +1031,7 @@ def test_describe(df): } -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_write_csv(ctx, df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -1046,7 +1044,7 @@ def test_write_csv(ctx, df, tmp_path, path_to_str): assert result == expected -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_write_json(ctx, df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -1059,7 +1057,7 @@ def test_write_json(ctx, df, tmp_path, path_to_str): assert result == expected -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_write_parquet(df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -1071,7 +1069,7 @@ def test_write_parquet(df, tmp_path, path_to_str): @pytest.mark.parametrize( - "compression, compression_level", + ("compression", "compression_level"), [("gzip", 6), ("brotli", 7), ("zstd", 15)], ) def test_write_compressed_parquet(df, tmp_path, compression, compression_level): @@ -1082,7 +1080,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level): ) # test that the actual compression scheme is the one written - for root, dirs, files in os.walk(path): + for _root, _dirs, files in os.walk(path): for file in files: if file.endswith(".parquet"): metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() @@ 
-1097,7 +1095,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level): @pytest.mark.parametrize( - "compression, compression_level", + ("compression", "compression_level"), [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)], ) def test_write_compressed_parquet_wrong_compression_level( @@ -1152,7 +1150,7 @@ def test_dataframe_export(df) -> None: table = pa.table(df, schema=desired_schema) assert table.num_columns == 1 assert table.num_rows == 3 - for i in range(0, 3): + for i in range(3): assert table[0][i].as_py() is None # Expect an error when we cannot convert schema @@ -1186,8 +1184,8 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame: result = df.to_pydict() assert result["a"] == [1, 2, 3] - assert result["string_col"] == ["string data" for _i in range(0, 3)] - assert result["new_col"] == [3 for _i in range(0, 3)] + assert result["string_col"] == ["string data" for _i in range(3)] + assert result["new_col"] == [3 for _i in range(3)] def test_dataframe_repr_html(df) -> None: diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 354c7e180..926e69845 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -85,18 +85,14 @@ def test_limit(test_ctx): plan = plan.to_variant() assert isinstance(plan, Limit) - # TODO: Upstream now has expressions for skip and fetch - # REF: https://github.com/apache/datafusion/pull/12836 - # assert plan.skip() == 0 + assert "Skip: None" in str(plan) df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5") plan = df.logical_plan() plan = plan.to_variant() assert isinstance(plan, Limit) - # TODO: Upstream now has expressions for skip and fetch - # REF: https://github.com/apache/datafusion/pull/12836 - # assert plan.skip() == 5 + assert "Skip: Some(Literal(Int64(5)))" in str(plan) def test_aggregate_query(test_ctx): @@ -165,6 +161,7 @@ def traverse_logical_plan(plan): res = traverse_logical_plan(input_plan) if res is not None: return res + return None ctx = SessionContext() data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]} @@ -176,7 +173,7 @@ def traverse_logical_plan(plan): assert variant.expr().to_variant().qualified_name() == "table1.name" assert ( str(variant.list()) - == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' + == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' # noqa: E501 ) assert not variant.negated() diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index fca05bb8f..ed88a16e3 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
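# [Editor's note] Several parametrize hunks above, e.g. ("name", "expr", "result")
# and ("compression", "compression_level"), switch the argnames to a tuple and
# each parameter set to a tuple or list entry, in line with ruff's
# flake8-pytest-style checks. A self-contained sketch; the test body is a
# placeholder, not one of the real tests:
import pytest


@pytest.mark.parametrize(
    ("compression", "level"),      # argnames as a tuple rather than "a, b"
    [("gzip", 6), ("zstd", 15)],   # one tuple per parameter set
)
def test_compression_levels(compression: str, level: int) -> None:
    assert compression and level > 0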
import math -from datetime import datetime +from datetime import datetime, timezone import numpy as np import pyarrow as pa @@ -25,6 +25,8 @@ np.seterr(invalid="ignore") +DEFAULT_TZ = timezone.utc + @pytest.fixture def df(): @@ -37,9 +39,9 @@ def df(): pa.array(["hello ", " world ", " !"], type=pa.string_view()), pa.array( [ - datetime(2022, 12, 31), - datetime(2027, 6, 26), - datetime(2020, 7, 2), + datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 26, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), ] ), pa.array([False, True, True]), @@ -221,12 +223,12 @@ def py_indexof(arr, v): def py_arr_remove(arr, v, n=None): new_arr = arr[:] found = 0 - while found != n: - try: + try: + while found != n: new_arr.remove(v) found += 1 - except ValueError: - break + except ValueError: + pass return new_arr @@ -234,13 +236,13 @@ def py_arr_remove(arr, v, n=None): def py_arr_replace(arr, from_, to, n=None): new_arr = arr[:] found = 0 - while found != n: - try: + try: + while found != n: idx = new_arr.index(from_) new_arr[idx] = to found += 1 - except ValueError: - break + except ValueError: + pass return new_arr @@ -268,266 +270,266 @@ def py_flatten(arr): @pytest.mark.parametrize( ("stmt", "py_expr"), [ - [ + ( lambda col: f.array_append(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_push_back(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_append(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_push_back(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_concat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.array_cat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.list_cat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.list_concat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.array_dims(col), lambda data: [[len(r)] for r in data], - ], - [ + ), + ( lambda col: f.array_distinct(col), lambda data: [list(set(r)) for r in data], - ], - [ + ), + ( lambda col: f.list_distinct(col), lambda data: [list(set(r)) for r in data], - ], - [ + ), + ( lambda col: f.list_dims(col), lambda data: [[len(r)] for r in data], - ], - [ + ), + ( lambda col: f.array_element(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.array_empty(col), lambda data: [len(r) == 0 for r in data], - ], - [ + ), + ( lambda col: f.empty(col), lambda data: [len(r) == 0 for r in data], - ], - [ + ), + ( lambda col: f.array_extract(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.list_element(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.list_extract(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.array_length(col), lambda data: [len(r) for r in data], - ], - [ + ), + ( lambda col: f.list_length(col), lambda data: [len(r) for r in data], - ], - [ + ), + ( lambda col: f.array_has(col, literal(1.0)), lambda data: [1.0 in r for r in data], - ], - [ + ), + ( lambda col: f.array_has_all( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], - 
], - [ + ), + ( lambda col: f.array_has_any( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], - ], - [ + ), + ( lambda col: f.array_position(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.array_indexof(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.list_position(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.list_indexof(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.array_positions(col, literal(1.0)), lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], - ], - [ + ), + ( lambda col: f.list_positions(col, literal(1.0)), lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], - ], - [ + ), + ( lambda col: f.array_ndims(col), lambda data: [np.array(r).ndim for r in data], - ], - [ + ), + ( lambda col: f.list_ndims(col), lambda data: [np.array(r).ndim for r in data], - ], - [ + ), + ( lambda col: f.array_prepend(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_push_front(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_prepend(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_push_front(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_pop_back(col), lambda data: [arr[:-1] for arr in data], - ], - [ + ), + ( lambda col: f.array_pop_front(col), lambda data: [arr[1:] for arr in data], - ], - [ + ), + ( lambda col: f.array_remove(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.list_remove(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.array_remove_n(col, literal(3.0), literal(2)), lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data], - ], - [ + ), + ( lambda col: f.list_remove_n(col, literal(3.0), literal(2)), lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data], - ], - [ + ), + ( lambda col: f.array_remove_all(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_remove_all(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_repeat(col, literal(2)), lambda data: [[arr] * 2 for arr in data], - ], - [ + ), + ( lambda col: f.list_repeat(col, literal(2)), lambda data: [[arr] * 2 for arr in data], - ], - [ + ), + ( lambda col: f.array_replace(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.list_replace(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data], - ], - [ + ), + ( lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data], - ], - [ + ), + ( 
lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_sort(col, descending=True, null_first=True), lambda data: [np.sort(arr)[::-1] for arr in data], - ], - [ + ), + ( lambda col: f.list_sort(col, descending=False, null_first=False), lambda data: [np.sort(arr) for arr in data], - ], - [ + ), + ( lambda col: f.array_slice(col, literal(2), literal(4)), lambda data: [arr[1:4] for arr in data], - ], + ), pytest.param( lambda col: f.list_slice(col, literal(-1), literal(2)), lambda data: [arr[-1:2] for arr in data], ), - [ + ( lambda col: f.array_intersect(col, literal([3.0, 4.0])), lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], - ], - [ + ), + ( lambda col: f.list_intersect(col, literal([3.0, 4.0])), lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], - ], - [ + ), + ( lambda col: f.array_union(col, literal([12.0, 999.0])), lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data], - ], - [ + ), + ( lambda col: f.list_union(col, literal([12.0, 999.0])), lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data], - ], - [ + ), + ( lambda col: f.array_except(col, literal([3.0])), lambda data: [np.setdiff1d(arr, [3.0]) for arr in data], - ], - [ + ), + ( lambda col: f.list_except(col, literal([3.0])), lambda data: [np.setdiff1d(arr, [3.0]) for arr in data], - ], - [ + ), + ( lambda col: f.array_resize(col, literal(10), literal(0.0)), lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_resize(col, literal(10), literal(0.0)), lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data], - ], - [ + ), + ( lambda col: f.range(literal(1), literal(5), literal(2)), lambda data: [np.arange(1, 5, 2)], - ], + ), ], ) def test_array_functions(stmt, py_expr): @@ -611,22 +613,22 @@ def test_make_array_functions(make_func): @pytest.mark.parametrize( ("stmt", "py_expr"), [ - [ + ( f.array_to_string(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], - [ + ), + ( f.array_join(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], - [ + ), + ( f.list_to_string(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], - [ + ), + ( f.list_join(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], + ), ], ) def test_array_function_obj_tests(stmt, py_expr): @@ -640,7 +642,7 @@ def test_array_function_obj_tests(stmt, py_expr): @pytest.mark.parametrize( - "function, expected_result", + ("function", "expected_result"), [ ( f.ascii(column("a")), @@ -894,54 +896,72 @@ def test_temporal_functions(df): assert result.column(0) == pa.array([12, 6, 7], type=pa.int32()) assert result.column(1) == pa.array([2022, 2027, 2020], type=pa.int32()) assert result.column(2) == pa.array( - [datetime(2022, 12, 1), datetime(2027, 6, 1), datetime(2020, 7, 1)], - type=pa.timestamp("us"), + [ + datetime(2022, 12, 1, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 1, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 1, tzinfo=DEFAULT_TZ), + ], + type=pa.timestamp("ns", tz=DEFAULT_TZ), ) assert result.column(3) == pa.array( - [datetime(2022, 12, 31), datetime(2027, 6, 26), datetime(2020, 7, 2)], - type=pa.timestamp("us"), + [ + datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 26, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), + ], + type=pa.timestamp("ns", 
tz=DEFAULT_TZ), ) assert result.column(4) == pa.array( [ - datetime(2022, 12, 30, 23, 47, 30), - datetime(2027, 6, 25, 23, 47, 30), - datetime(2020, 7, 1, 23, 47, 30), + datetime(2022, 12, 30, 23, 47, 30, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 25, 23, 47, 30, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 1, 23, 47, 30, tzinfo=DEFAULT_TZ), ], - type=pa.timestamp("ns"), + type=pa.timestamp("ns", tz=DEFAULT_TZ), ) assert result.column(5) == pa.array( - [datetime(2023, 1, 10, 20, 52, 54)] * 3, type=pa.timestamp("s") + [datetime(2023, 1, 10, 20, 52, 54, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("s"), ) assert result.column(6) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) assert result.column(7) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s") + [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s") ) assert result.column(8) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms") + [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ms"), ) assert result.column(9) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("us"), ) assert result.column(10) == pa.array([31, 26, 2], type=pa.int32()) assert result.column(11) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) assert result.column(12) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s") + [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("s"), ) assert result.column(13) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms") + [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ms"), ) assert result.column(14) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("us"), ) assert result.column(15) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) assert result.column(16) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) @@ -1057,7 +1077,7 @@ def test_regr_funcs_sql_2(): @pytest.mark.parametrize( - "func, expected", + ("func", "expected"), [ pytest.param(f.regr_slope(column("c2"), column("c1")), [4.6], id="regr_slope"), pytest.param( @@ -1160,7 +1180,7 @@ def test_binary_string_functions(df): @pytest.mark.parametrize( - "python_datatype, name, expected", + ("python_datatype", "name", "expected"), [ pytest.param(bool, "e", pa.bool_(), id="bool"), pytest.param(int, "b", pa.int64(), id="int"), @@ -1179,7 +1199,7 @@ def test_cast(df, python_datatype, name: str, expected): @pytest.mark.parametrize( - "negated, low, high, expected", + ("negated", "low", "high", "expected"), [ pytest.param(False, 3, 5, {"filtered": [4, 5]}), pytest.param(False, 4, 5, {"filtered": [4, 5]}), diff --git a/python/tests/test_imports.py b/python/tests/test_imports.py index 0c155cbde..9ef7ed89a 100644 
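# [Editor's note] The test_functions.py hunks above make every fixture datetime
# timezone-aware (tzinfo=DEFAULT_TZ) and compare against timestamp columns that
# carry an explicit tz. A small sketch of the same idea; the single value is
# illustrative and assumes pyarrow is installed:
from datetime import datetime, timezone

import pyarrow as pa

DEFAULT_TZ = timezone.utc

aware = pa.array(
    [datetime(2022, 12, 31, tzinfo=DEFAULT_TZ)],
    type=pa.timestamp("ns", tz=DEFAULT_TZ),  # the Arrow type records the zone
)
assert aware.type.tz is not None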
--- a/python/tests/test_imports.py +++ b/python/tests/test_imports.py @@ -169,14 +169,15 @@ def test_class_module_is_datafusion(): def test_import_from_functions_submodule(): - from datafusion.functions import abs, sin # noqa + from datafusion.functions import abs as df_abs + from datafusion.functions import sin - assert functions.abs is abs + assert functions.abs is df_abs assert functions.sin is sin msg = "cannot import name 'foobar' from 'datafusion.functions'" with pytest.raises(ImportError, match=msg): - from datafusion.functions import foobar # noqa + from datafusion.functions import foobar # noqa: F401 def test_classes_are_inheritable(): diff --git a/python/tests/test_input.py b/python/tests/test_input.py index 806471357..4663f6148 100644 --- a/python/tests/test_input.py +++ b/python/tests/test_input.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -import os +import pathlib from datafusion.input.location import LocationInputPlugin @@ -23,10 +23,10 @@ def test_location_input(): location_input = LocationInputPlugin() - cwd = os.getcwd() - input_file = cwd + "/testing/data/parquet/generated_simple_numerics/blogs.parquet" + cwd = pathlib.Path.cwd() + input_file = cwd / "testing/data/parquet/generated_simple_numerics/blogs.parquet" table_name = "blog" - tbl = location_input.build_table(input_file, table_name) - assert "blog" == tbl.name - assert 3 == len(tbl.columns) + tbl = location_input.build_table(str(input_file), table_name) + assert tbl.name == "blog" + assert len(tbl.columns) == 3 assert "blogs.parquet" in tbl.filepaths[0] diff --git a/python/tests/test_io.py b/python/tests/test_io.py index 21ad188ee..7ca509689 100644 --- a/python/tests/test_io.py +++ b/python/tests/test_io.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -import pathlib +from pathlib import Path import pyarrow as pa from datafusion import column @@ -23,10 +22,10 @@ def test_read_json_global_ctx(ctx): - path = os.path.dirname(os.path.abspath(__file__)) + path = Path(__file__).parent.resolve() # Default - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = Path(path) / "data_test_context" / "data.json" df = read_json(test_data_path) result = df.collect() @@ -46,7 +45,7 @@ def test_read_json_global_ctx(ctx): assert result[0].schema == schema # File extension - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = Path(path) / "data_test_context" / "data.json" df = read_json(test_data_path, file_extension=".json") result = df.collect() @@ -59,7 +58,7 @@ def test_read_parquet_global(): parquet_df.show() assert parquet_df is not None - path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet" + path = Path.cwd() / "parquet/data/alltypes_plain.parquet" parquet_df = read_parquet(path=path) assert parquet_df is not None @@ -90,6 +89,6 @@ def test_read_avro(): avro_df.show() assert avro_df is not None - path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro" + path = Path.cwd() / "testing/data/avro/alltypes_plain.avro" avro_df = read_avro(path=path) assert avro_df is not None diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 862f745bf..b6348e3a0 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
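# [Editor's note] test_imports.py, test_input.py, and test_io.py above trade
# os.path/os.getcwd string handling for pathlib. A minimal sketch of the
# equivalent constructions; the data path mirrors the one used in the tests:
from pathlib import Path

here = Path(__file__).parent.resolve()  # directory of the current file
data_file = Path.cwd() / "testing/data/parquet/generated_simple_numerics/blogs.parquet"

# APIs that still expect a plain string get one explicitly at the boundary,
# as build_table(str(input_file), ...) does above.
data_file_str = str(data_file)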
import gzip -import os +from pathlib import Path import numpy as np import pyarrow as pa @@ -47,9 +47,8 @@ def test_register_csv(ctx, tmp_path): ) write_csv(table, path) - with open(path, "rb") as csv_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(csv_file) + with Path.open(path, "rb") as csv_file, gzip.open(gzip_path, "wb") as gzipped_file: + gzipped_file.writelines(csv_file) ctx.register_csv("csv", path) ctx.register_csv("csv1", str(path)) @@ -158,7 +157,7 @@ def test_register_parquet(ctx, tmp_path): assert result.to_pydict() == {"cnt": [100]} -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) @@ -194,7 +193,7 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1} -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_register_dataset(ctx, tmp_path, path_to_str): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) path = str(path) if path_to_str else path @@ -209,13 +208,15 @@ def test_register_dataset(ctx, tmp_path, path_to_str): def test_register_json(ctx, tmp_path): - path = os.path.dirname(os.path.abspath(__file__)) - test_data_path = os.path.join(path, "data_test_context", "data.json") + path = Path(__file__).parent.resolve() + test_data_path = Path(path) / "data_test_context" / "data.json" gzip_path = tmp_path / "data.json.gz" - with open(test_data_path, "rb") as json_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(json_file) + with ( + Path.open(test_data_path, "rb") as json_file, + gzip.open(gzip_path, "wb") as gzipped_file, + ): + gzipped_file.writelines(json_file) ctx.register_json("json", test_data_path) ctx.register_json("json1", str(test_data_path)) @@ -470,16 +471,18 @@ def test_simple_select(ctx, tmp_path, arr): # In DF 43.0.0 we now default to having BinaryView and StringView # so the array that is saved to the parquet is slightly different # than the array read. Convert to values for comparison. 
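# [Editor's note] The register_csv/register_json hunks above merge two nested
# `with` blocks into a single parenthesized `with` statement while gzipping
# fixture data. A standalone sketch of the pattern; src and dest are
# hypothetical paths, not files from the test suite:
import gzip
from pathlib import Path


def gzip_copy(src: Path, dest: Path) -> None:
    # One statement manages both handles; they are closed in reverse order
    # when the block exits, exactly as with nested `with` blocks.
    with (
        Path.open(src, "rb") as plain_file,
        gzip.open(dest, "wb") as gzipped_file,
    ):
        gzipped_file.writelines(plain_file)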
- if isinstance(result, pa.BinaryViewArray) or isinstance(result, pa.StringViewArray): + if isinstance(result, (pa.BinaryViewArray, pa.StringViewArray)): arr = arr.tolist() result = result.tolist() np.testing.assert_equal(result, arr) -@pytest.mark.parametrize("file_sort_order", (None, [[col("int").sort(True, True)]])) -@pytest.mark.parametrize("pass_schema", (True, False)) -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize( + "file_sort_order", [None, [[col("int").sort(ascending=True, nulls_first=True)]]] +) +@pytest.mark.parametrize("pass_schema", [True, False]) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_register_listing_table( ctx, tmp_path, pass_schema, file_sort_order, path_to_str ): @@ -528,7 +531,7 @@ def test_register_listing_table( assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2} result = ctx.sql( - "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" + "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" # noqa: E501 ).collect() result = pa.Table.from_batches(result) diff --git a/python/tests/test_store.py b/python/tests/test_store.py index 53ffc3acf..ac9af98f3 100644 --- a/python/tests/test_store.py +++ b/python/tests/test_store.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -import os +from pathlib import Path import pytest from datafusion import SessionContext @@ -23,17 +23,16 @@ @pytest.fixture def ctx(): - ctx = SessionContext() - return ctx + return SessionContext() def test_read_parquet(ctx): ctx.register_parquet( "test", - f"file://{os.getcwd()}/parquet/data/alltypes_plain.parquet", - [], - True, - ".parquet", + f"file://{Path.cwd()}/parquet/data/alltypes_plain.parquet", + table_partition_cols=[], + parquet_pruning=True, + file_extension=".parquet", ) df = ctx.sql("SELECT * FROM test") assert isinstance(df.collect(), list) diff --git a/python/tests/test_substrait.py b/python/tests/test_substrait.py index feada7cde..f367a447d 100644 --- a/python/tests/test_substrait.py +++ b/python/tests/test_substrait.py @@ -50,7 +50,7 @@ def test_substrait_serialization(ctx): substrait_plan = ss.Producer.to_substrait_plan(df.logical_plan(), ctx) -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_substrait_file_serialization(ctx, tmp_path, path_to_str): batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 97cf81f3c..453ff6f4f 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -17,8 +17,6 @@ from __future__ import annotations -from typing import List - import pyarrow as pa import pyarrow.compute as pc import pytest @@ -31,7 +29,7 @@ class Summarize(Accumulator): def __init__(self, initial_value: float = 0.0): self._sum = pa.scalar(initial_value) - def state(self) -> List[pa.Scalar]: + def state(self) -> list[pa.Scalar]: return [self._sum] def update(self, values: pa.Array) -> None: @@ -39,7 +37,7 @@ def update(self, values: pa.Array) -> None: # This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - def merge(self, states: List[pa.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: # Not nice since pyarrow scalars can't be summed yet. 
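# [Editor's note] The test_udaf.py hunks above drop `from typing import List`
# and annotate with the built-in list[...] generic instead (PEP 585), which is
# fine on the Python 3.9+ floor this change set targets. A tiny sketch; the
# Summing class is a stand-in, not the project's Accumulator:
from __future__ import annotations

import pyarrow as pa


class Summing:
    def __init__(self) -> None:
        self._sum = pa.scalar(0.0)

    def state(self) -> list[pa.Scalar]:
        # Built-in generics need no typing import.
        return [self._sum]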
# This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) @@ -56,7 +54,7 @@ class MissingMethods(Accumulator): def __init__(self): self._sum = pa.scalar(0) - def state(self) -> List[pa.Scalar]: + def state(self) -> list[pa.Scalar]: return [self._sum] @@ -86,7 +84,7 @@ def test_errors(df): "evaluate, merge, update)" ) with pytest.raises(Exception, match=msg): - accum = udaf( # noqa F841 + accum = udaf( # noqa: F841 MissingMethods, pa.int64(), pa.int64(), diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 2fea34aa3..3d6dcf9d8 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -298,7 +298,7 @@ def test_udwf_errors(df): ] -@pytest.mark.parametrize("name,expr,expected", data_test_udwf_functions) +@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions) def test_udwf_functions(df, name, expr, expected): df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index ac064ba95..d7f6f6e35 100644 --- a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -19,6 +19,7 @@ import datafusion.functions import datafusion.object_store import datafusion.substrait +import pytest # EnumType introduced in 3.11. 3.10 and prior it was called EnumMeta. try: @@ -41,10 +42,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None: internal_attr = getattr(internal_obj, attr) wrapped_attr = getattr(wrapped_obj, attr) - if internal_attr is not None: - if wrapped_attr is None: - print("Missing attribute: ", attr) - assert False + if internal_attr is not None and wrapped_attr is None: + pytest.fail(f"Missing attribute: {attr}") if attr in ["__self__", "__class__"]: continue diff --git a/uv.lock b/uv.lock index 587ddc8b7..8958c3086 100644 --- a/uv.lock +++ b/uv.lock @@ -351,7 +351,6 @@ wheels = [ [[package]] name = "datafusion" -version = "44.0.0" source = { editable = "." } dependencies = [ { name = "pyarrow", version = "17.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },
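# [Editor's note] test_wrapper_coverage.py above swaps a print plus
# `assert False` for pytest.fail, which raises with the message attached and
# still fails under `python -O` (where assert statements are stripped).
# A minimal sketch with a simplified attribute check, not the real walker:
import pytest


def check_export(internal_attr: object, wrapped_attr: object, name: str) -> None:
    if internal_attr is not None and wrapped_attr is None:
        pytest.fail(f"Missing attribute: {name}")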