diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml new file mode 100644 index 00000000..719e54ca --- /dev/null +++ b/.github/workflows/code_quality.yml @@ -0,0 +1,45 @@ +name: Code Quality Checks +on: + workflow_dispatch: + inputs: + git_ref: + type: string + description: Git ref of the DuckDB python package + required: false + workflow_call: + inputs: + git_ref: + type: string + description: Git ref of the DuckDB python package + required: false + +defaults: + run: + shell: bash + +jobs: + run_checks: + name: Run linting, formatting and static type checker + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.git_ref }} + fetch-depth: 0 + persist-credentials: false + + - name: Install Astral UV + uses: astral-sh/setup-uv@v6 + with: + version: "0.7.14" + python-version: 3.9 + + - name: pre-commit (cache) + uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} + + - name: pre-commit (--all-files) + run: | + uvx pre-commit run --show-diff-on-failure --color=always --all-files diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index ab696897..ce78df1c 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -151,4 +151,4 @@ jobs: echo "### C++ Coverage Summary" >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY echo "$SUMMARY_CPP" >> $GITHUB_STEP_SUMMARY - echo '```' >> $GITHUB_STEP_SUMMARY \ No newline at end of file + echo '```' >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 7a4669cb..85c8904a 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -23,9 +23,14 @@ jobs: name: Make sure submodule is in a sane state uses: ./.github/workflows/submodule_sanity.yml + code_quality: + name: Code-quality checks + needs: submodule_sanity_guard + uses: ./.github/workflows/code_quality.yml + packaging_test: name: Build a minimal set of packages and run all tests on them - needs: submodule_sanity_guard + needs: code_quality # Skip packaging tests for draft PRs if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }} uses: ./.github/workflows/packaging.yml @@ -36,7 +41,7 @@ jobs: coverage_test: name: Run coverage tests - needs: submodule_sanity_guard + needs: code_quality # Only run coverage test for draft PRs if: ${{ github.event_name == 'pull_request' && github.event.pull_request.draft == true }} uses: ./.github/workflows/coverage.yml diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index 1a282d69..706f8789 100644 --- a/.github/workflows/on_push.yml +++ b/.github/workflows/on_push.yml @@ -18,8 +18,13 @@ concurrency: cancel-in-progress: true jobs: + code_quality: + name: Code-quality checks + uses: ./.github/workflows/code_quality.yml + test: name: Run coverage tests + needs: code_quality uses: ./.github/workflows/coverage.yml with: git_ref: ${{ github.ref }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..0010a4fa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,37 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.13.3 + hooks: + # Run the linter. + - id: ruff-check + # Run the formatter. 
+ - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v21.1.2 # pick the version of clang-format you want + hooks: + - id: clang-format + files: \.(c|cpp|cc|h|hpp|cxx|hxx)$ + + - repo: https://github.com/cheshirekow/cmake-format-precommit + rev: v0.6.13 + hooks: + - id: cmake-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.18.2 + hooks: + - id: mypy + entry: mypy -v + files: ^(duckdb/|_duckdb-stubs/) + exclude: ^duckdb/(experimental|query_graph)/ + additional_dependencies: [ numpy, polars ] + + - repo: local + hooks: + - id: post-checkout-submodules + name: Update submodule post-checkout + entry: .githooks/post-checkout + language: script + stages: [ post-checkout ] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..d4f4b61b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,19 @@ +# Changelog + +## v1.4.1 +**DuckDB Core**: v1.4.1 + +### Bug Fixes +- **ADBC Driver**: Fixed ADBC driver implementation (#81) +- **SQLAlchemy compatibility**: Added `__hash__` method overload (#61) +- **Error Handling**: Reset PyErr before throwing Python exceptions (#69) +- **Polars Lazyframes**: Fixed Polars expression pushdown (#102) + +### Code Quality Improvements & Developer Experience +- **MyPy Support**: MyPy is functional again and better integrated with the dev workflow +- **Stubs**: Re-created and manually curated stubs for the binary extension +- **Type Shadowing**: Deprecated `typing` and `functional` modules +- **Linting & Formatting**: Comprehensive code quality improvements with Ruff +- **Type Annotations**: Added missing overloads and improved type coverage +- **Pre-commit Integration**: Added ruff, clang-format, cmake-format and mypy configs +- **CI/CD**: Added code quality workflow diff --git a/CMakeLists.txt b/CMakeLists.txt index a9bc047d..ab9e1cee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,13 +48,12 @@ duckdb_add_library(duckdb_target) # Bundle in INTERFACE library add_library(_duckdb_dependencies INTERFACE) -target_link_libraries(_duckdb_dependencies INTERFACE - pybind11::pybind11 - duckdb_target -) +target_link_libraries(_duckdb_dependencies INTERFACE pybind11::pybind11 + duckdb_target) # Also add include directory -target_include_directories(_duckdb_dependencies INTERFACE - $ +target_include_directories( + _duckdb_dependencies + INTERFACE $ ) # ──────────────────────────────────────────── @@ -62,36 +61,71 @@ target_include_directories(_duckdb_dependencies INTERFACE # ──────────────────────────────────────────── add_subdirectory(src/duckdb_py) -pybind11_add_module(_duckdb - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ -) +pybind11_add_module( + _duckdb + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $) # add _duckdb_dependencies target_link_libraries(_duckdb PRIVATE _duckdb_dependencies) +# ──────────────────────────────────────────── +# Controlling symbol export +# +# We want to export exactly two symbols: - PyInit__duckdb: this allows CPython +# to load the module - duckdb_adbc_init: the DuckDB ADBC driver +# +# The export of symbols on OSX and Linux is controlled by: - Visibility +# annotations in the code (for this lib we use the PYBIND11_EXPORT macro) - +# Telling the linker which symbols we want exported, which we do below +# +# For Windows, we rely on just the visbility annotations. 
+# ──────────────────────────────────────────── +set_target_properties( + _duckdb + PROPERTIES CXX_VISIBILITY_PRESET hidden + C_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON) + +if(APPLE) + target_link_options( + _duckdb PRIVATE "LINKER:-exported_symbol,_duckdb_adbc_init" + "LINKER:-exported_symbol,_PyInit__duckdb") +elseif(UNIX AND NOT APPLE) + target_link_options( + _duckdb PRIVATE "LINKER:--export-dynamic-symbol=duckdb_adbc_init" + "LINKER:--export-dynamic-symbol=PyInit__duckdb") +elseif(WIN32) + target_link_options(_duckdb PRIVATE "/EXPORT:duckdb_adbc_init" + "/EXPORT:PyInit__duckdb") +endif() + # ──────────────────────────────────────────── # Put the object file in the correct place # ──────────────────────────────────────────── -# If we're not building through scikit-build-core then we have to set a different dest dir +# If we're not building through scikit-build-core then we have to set a +# different dest dir include(GNUInstallDirs) if(DEFINED SKBUILD_PLATLIB_DIR) set(_DUCKDB_PY_INSTALL_DIR "${SKBUILD_PLATLIB_DIR}") elseif(DEFINED Python_SITEARCH) set(_DUCKDB_PY_INSTALL_DIR "${Python_SITEARCH}") else() - message(WARNING "Could not determine Python install dir. Falling back to CMAKE_INSTALL_LIBDIR.") + message( + WARNING + "Could not determine Python install dir. Falling back to CMAKE_INSTALL_LIBDIR." + ) set(_DUCKDB_PY_INSTALL_DIR "${CMAKE_INSTALL_LIBDIR}") endif() diff --git a/README.md b/README.md index 627349b2..4ad94403 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,14 @@ API Docs (Python)

-# The [DuckDB](https://github.com/duckdb/duckdb) Python Package +# DuckDB: A Fast, In-Process, Portable, Open Source, Analytical Database System + +* **Simple**: DuckDB is easy to install and deploy. It has zero external dependencies and runs in-process in its host application or as a single binary. +* **Portable**: DuckDB runs on Linux, macOS, Windows, Android, iOS and all popular hardware architectures. It has idiomatic client APIs for major programming languages. +* **Feature-rich**: DuckDB offers a rich SQL dialect. It can read and write file formats such as CSV, Parquet, and JSON, to and from the local file system and remote endpoints such as S3 buckets. +* **Fast**: DuckDB runs analytical queries at blazing speed thanks to its columnar engine, which supports parallel execution and can process larger-than-memory workloads. +* **Extensible**: DuckDB is extensible by third-party features such as new data types, functions, file formats and new SQL syntax. User contributions are available as community extensions. +* **Free**: DuckDB and its core extensions are open-source under the permissive MIT License. The intellectual property of the project is held by the DuckDB Foundation. ## Installation diff --git a/_duckdb-stubs/__init__.pyi b/_duckdb-stubs/__init__.pyi new file mode 100644 index 00000000..6c36d7be --- /dev/null +++ b/_duckdb-stubs/__init__.pyi @@ -0,0 +1,1443 @@ +import os +import pathlib +import typing as pytyping +from typing_extensions import Self + +if pytyping.TYPE_CHECKING: + import fsspec + import numpy as np + import polars + import pandas + import pyarrow.lib + import torch as pytorch + import tensorflow + from collections.abc import Callable, Sequence, Mapping + from duckdb import sqltypes, func + + # the field_ids argument to to_parquet and write_parquet has a recursive structure + ParquetFieldIdsType = Mapping[str, pytyping.Union[int, "ParquetFieldIdsType"]] + +__all__: list[str] = [ + "BinderException", + "CSVLineTerminator", + "CaseExpression", + "CatalogException", + "CoalesceOperator", + "ColumnExpression", + "ConnectionException", + "ConstantExpression", + "ConstraintException", + "ConversionException", + "DataError", + "DatabaseError", + "DefaultExpression", + "DependencyException", + "DuckDBPyConnection", + "DuckDBPyRelation", + "Error", + "ExpectedResultType", + "ExplainType", + "Expression", + "FatalException", + "FunctionExpression", + "HTTPException", + "IOException", + "IntegrityError", + "InternalError", + "InternalException", + "InterruptException", + "InvalidInputException", + "InvalidTypeException", + "LambdaExpression", + "NotImplementedException", + "NotSupportedError", + "OperationalError", + "OutOfMemoryException", + "OutOfRangeException", + "ParserException", + "PermissionException", + "ProgrammingError", + "PythonExceptionHandling", + "RenderMode", + "SQLExpression", + "SequenceException", + "SerializationException", + "StarExpression", + "Statement", + "StatementType", + "SyntaxException", + "TransactionException", + "TypeMismatchException", + "Warning", + "aggregate", + "alias", + "apilevel", + "append", + "array_type", + "arrow", + "begin", + "checkpoint", + "close", + "commit", + "connect", + "create_function", + "cursor", + "decimal_type", + "default_connection", + "description", + "df", + "distinct", + "dtype", + "duplicate", + "enum_type", + "execute", + "executemany", + "extract_statements", + "fetch_arrow_table", + "fetch_df", + "fetch_df_chunk", + "fetch_record_batch", + "fetchall", + "fetchdf", + "fetchmany", + "fetchnumpy", + 
"fetchone", + "filesystem_is_registered", + "filter", + "from_arrow", + "from_csv_auto", + "from_df", + "from_parquet", + "from_query", + "get_table_names", + "install_extension", + "interrupt", + "limit", + "list_filesystems", + "list_type", + "load_extension", + "map_type", + "order", + "paramstyle", + "pl", + "project", + "query", + "query_df", + "query_progress", + "read_csv", + "read_json", + "read_parquet", + "register", + "register_filesystem", + "remove_function", + "rollback", + "row_type", + "rowcount", + "set_default_connection", + "sql", + "sqltype", + "string_type", + "struct_type", + "table", + "table_function", + "tf", + "threadsafety", + "token_type", + "tokenize", + "torch", + "type", + "union_type", + "unregister", + "unregister_filesystem", + "values", + "view", + "write_csv", +] + +class BinderException(ProgrammingError): ... + +class CSVLineTerminator: + CARRIAGE_RETURN_LINE_FEED: pytyping.ClassVar[ + CSVLineTerminator + ] # value = + LINE_FEED: pytyping.ClassVar[CSVLineTerminator] # value = + __members__: pytyping.ClassVar[ + dict[str, CSVLineTerminator] + ] # value = {'LINE_FEED': , 'CARRIAGE_RETURN_LINE_FEED': } # noqa: E501 + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class CatalogException(ProgrammingError): ... +class ConnectionException(OperationalError): ... +class ConstraintException(IntegrityError): ... +class ConversionException(DataError): ... +class DataError(DatabaseError): ... +class DatabaseError(Error): ... +class DependencyException(DatabaseError): ... + +class DuckDBPyConnection: + def __del__(self) -> None: ... + def __enter__(self) -> Self: ... + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: ... + def append(self, table_name: str, df: pandas.DataFrame, *, by_name: bool = False) -> DuckDBPyConnection: ... + def array_type(self, type: sqltypes.DuckDBPyType, size: pytyping.SupportsInt) -> sqltypes.DuckDBPyType: ... + def arrow(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def begin(self) -> DuckDBPyConnection: ... + def checkpoint(self) -> DuckDBPyConnection: ... + def close(self) -> None: ... + def commit(self) -> DuckDBPyConnection: ... + def create_function( + self, + name: str, + function: Callable[..., pytyping.Any], + parameters: list[sqltypes.DuckDBPyType] | None = None, + return_type: sqltypes.DuckDBPyType | None = None, + *, + type: func.PythonUDFType = ..., + null_handling: func.FunctionNullHandling = ..., + exception_handling: PythonExceptionHandling = ..., + side_effects: bool = False, + ) -> DuckDBPyConnection: ... + def cursor(self) -> DuckDBPyConnection: ... + def decimal_type(self, width: pytyping.SupportsInt, scale: pytyping.SupportsInt) -> sqltypes.DuckDBPyType: ... + def df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def dtype(self, type_str: str) -> sqltypes.DuckDBPyType: ... + def duplicate(self) -> DuckDBPyConnection: ... + def enum_type( + self, name: str, type: sqltypes.DuckDBPyType, values: list[pytyping.Any] + ) -> sqltypes.DuckDBPyType: ... 
+ def execute(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... + def executemany(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... + def extract_statements(self, query: str) -> list[Statement]: ... + def fetch_arrow_table(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.Table: ... + def fetch_df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def fetch_df_chunk( + self, vectors_per_chunk: pytyping.SupportsInt = 1, *, date_as_object: bool = False + ) -> pandas.DataFrame: ... + def fetch_record_batch(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def fetchall(self) -> list[tuple[pytyping.Any, ...]]: ... + def fetchdf(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def fetchmany(self, size: pytyping.SupportsInt = 1) -> list[tuple[pytyping.Any, ...]]: ... + def fetchnumpy(self) -> dict[str, np.typing.NDArray[pytyping.Any] | pandas.Categorical]: ... + def fetchone(self) -> tuple[pytyping.Any, ...] | None: ... + def filesystem_is_registered(self, name: str) -> bool: ... + def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... + def from_csv_auto( + self, + path_or_buffer: str | bytes | os.PathLike[str] | os.PathLike[bytes], + header: bool | int | None = None, + compression: str | None = None, + sep: str | None = None, + delimiter: str | None = None, + files_to_sniff: int | None = None, + comment: str | None = None, + thousands: str | None = None, + dtype: dict[str, str] | list[str] | None = None, + na_values: str | list[str] | None = None, + skiprows: int | None = None, + quotechar: str | None = None, + escapechar: str | None = None, + encoding: str | None = None, + parallel: bool | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + sample_size: int | None = None, + auto_detect: bool | int | None = None, + all_varchar: bool | None = None, + normalize_names: bool | None = None, + null_padding: bool | None = None, + names: list[str] | None = None, + lineterminator: str | None = None, + columns: dict[str, str] | None = None, + auto_type_candidates: list[str] | None = None, + max_line_size: int | None = None, + ignore_errors: bool | None = None, + store_rejects: bool | None = None, + rejects_table: str | None = None, + rejects_scan: str | None = None, + rejects_limit: int | None = None, + force_not_null: list[str] | None = None, + buffer_size: int | None = None, + decimal: str | None = None, + allow_quoted_nulls: bool | None = None, + filename: bool | str | None = None, + hive_partitioning: bool | None = None, + union_by_name: bool | None = None, + hive_types: dict[str, str] | None = None, + hive_types_autocast: bool | None = None, + strict_mode: bool | None = None, + ) -> DuckDBPyRelation: ... + def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... + @pytyping.overload + def from_parquet( + self, + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: str | None = None, + ) -> DuckDBPyRelation: ... + @pytyping.overload + def from_parquet( + self, + file_globs: Sequence[str], + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: str | None = None, + ) -> DuckDBPyRelation: ... 
+ def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def get_table_names(self, query: str, *, qualified: bool = False) -> set[str]: ... + def install_extension( + self, + extension: str, + *, + force_install: bool = False, + repository: str | None = None, + repository_url: str | None = None, + version: str | None = None, + ) -> None: ... + def interrupt(self) -> None: ... + def list_filesystems(self) -> list[str]: ... + def list_type(self, type: sqltypes.DuckDBPyType) -> sqltypes.DuckDBPyType: ... + def load_extension(self, extension: str) -> None: ... + def map_type(self, key: sqltypes.DuckDBPyType, value: sqltypes.DuckDBPyType) -> sqltypes.DuckDBPyType: ... + def pl(self, rows_per_batch: pytyping.SupportsInt = 1000000, *, lazy: bool = False) -> polars.DataFrame: ... + def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def query_progress(self) -> float: ... + def read_csv( + self, + path_or_buffer: str | bytes | os.PathLike[str], + header: bool | int | None = None, + compression: str | None = None, + sep: str | None = None, + delimiter: str | None = None, + files_to_sniff: int | None = None, + comment: str | None = None, + thousands: str | None = None, + dtype: dict[str, str] | list[str] | None = None, + na_values: str | list[str] | None = None, + skiprows: int | None = None, + quotechar: str | None = None, + escapechar: str | None = None, + encoding: str | None = None, + parallel: bool | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + sample_size: int | None = None, + auto_detect: bool | int | None = None, + all_varchar: bool | None = None, + normalize_names: bool | None = None, + null_padding: bool | None = None, + names: list[str] | None = None, + lineterminator: str | None = None, + columns: dict[str, str] | None = None, + auto_type_candidates: list[str] | None = None, + max_line_size: int | None = None, + ignore_errors: bool | None = None, + store_rejects: bool | None = None, + rejects_table: str | None = None, + rejects_scan: str | None = None, + rejects_limit: int | None = None, + force_not_null: list[str] | None = None, + buffer_size: int | None = None, + decimal: str | None = None, + allow_quoted_nulls: bool | None = None, + filename: bool | str | None = None, + hive_partitioning: bool | None = None, + union_by_name: bool | None = None, + hive_types: dict[str, str] | None = None, + hive_types_autocast: bool | None = None, + strict_mode: bool | None = None, + ) -> DuckDBPyRelation: ... + def read_json( + self, + path_or_buffer: str | bytes | os.PathLike[str], + *, + columns: dict[str, str] | None = None, + sample_size: int | None = None, + maximum_depth: int | None = None, + records: str | None = None, + format: str | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + compression: str | None = None, + maximum_object_size: int | None = None, + ignore_errors: bool | None = None, + convert_strings_to_integers: bool | None = None, + field_appearance_threshold: float | None = None, + map_inference_threshold: int | None = None, + maximum_sample_files: int | None = None, + filename: bool | str | None = None, + hive_partitioning: bool | None = None, + union_by_name: bool | None = None, + hive_types: dict[str, str] | None = None, + hive_types_autocast: bool | None = None, + ) -> DuckDBPyRelation: ... 
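
The connection-level readers declared around here (`read_csv`, `read_json`, and the `read_parquet`/`from_parquet` overloads that follow) all return a `DuckDBPyRelation`. A minimal sketch of how they compose, assuming hypothetical input files `orders.csv` and `events.json`:

```python
import duckdb

con = duckdb.connect()  # in-memory database

# "orders.csv" and "events.json" are hypothetical example files.
orders = con.read_csv("orders.csv", header=True)
events = con.read_json("events.json", ignore_errors=True)

# The readers return relations that are only executed when materialized.
print(orders.limit(5).fetchall())
print(events.count("*").fetchall())
```
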
+ @pytyping.overload + def read_parquet( + self, + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: str | None = None, + ) -> DuckDBPyRelation: ... + @pytyping.overload + def read_parquet( + self, + file_globs: Sequence[str], + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: pytyping.Any = None, + ) -> DuckDBPyRelation: ... + def register(self, view_name: str, python_object: object) -> DuckDBPyConnection: ... + def register_filesystem(self, filesystem: fsspec.AbstractFileSystem) -> None: ... + def remove_function(self, name: str) -> DuckDBPyConnection: ... + def rollback(self) -> DuckDBPyConnection: ... + def row_type( + self, fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType] + ) -> sqltypes.DuckDBPyType: ... + def sql(self, query: Statement | str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def sqltype(self, type_str: str) -> sqltypes.DuckDBPyType: ... + def string_type(self, collation: str = "") -> sqltypes.DuckDBPyType: ... + def struct_type( + self, fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType] + ) -> sqltypes.DuckDBPyType: ... + def table(self, table_name: str) -> DuckDBPyRelation: ... + def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ... + def tf(self) -> dict[str, tensorflow.Tensor]: ... + def torch(self) -> dict[str, pytorch.Tensor]: ... + def type(self, type_str: str) -> sqltypes.DuckDBPyType: ... + def union_type( + self, members: list[sqltypes.DuckDBPyType] | dict[str, sqltypes.DuckDBPyType] + ) -> sqltypes.DuckDBPyType: ... + def unregister(self, view_name: str) -> DuckDBPyConnection: ... + def unregister_filesystem(self, name: str) -> None: ... + def values(self, *args: list[pytyping.Any] | tuple[Expression, ...] | Expression) -> DuckDBPyRelation: ... + def view(self, view_name: str) -> DuckDBPyRelation: ... + @property + def description(self) -> list[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]]: ... + @property + def rowcount(self) -> int: ... + +class DuckDBPyRelation: + def __arrow_c_stream__(self, requested_schema: object | None = None) -> pytyping.Any: ... + def __contains__(self, name: str) -> bool: ... + def __getattr__(self, name: str) -> DuckDBPyRelation: ... + def __getitem__(self, name: str) -> DuckDBPyRelation: ... + def __len__(self) -> int: ... + def aggregate(self, aggr_expr: Expression | str, group_expr: Expression | str = "") -> DuckDBPyRelation: ... + def any_value( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def apply( + self, + function_name: str, + function_aggr: str, + group_expr: str = "", + function_parameter: str = "", + projected_columns: str = "", + ) -> DuckDBPyRelation: ... + def arg_max( + self, arg_column: str, value_column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def arg_min( + self, arg_column: str, value_column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def arrow(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... 
+ def avg( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def bit_and( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def bit_or( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def bit_xor( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def bitstring_agg( + self, + column: str, + min: int | None = None, + max: int | None = None, + groups: str = "", + window_spec: str = "", + projected_columns: str = "", + ) -> DuckDBPyRelation: ... + def bool_and( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def bool_or( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def close(self) -> None: ... + def count( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def create(self, table_name: str) -> None: ... + def create_view(self, view_name: str, replace: bool = True) -> DuckDBPyRelation: ... + def cross(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... + def cume_dist(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... + def dense_rank(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... + def describe(self) -> DuckDBPyRelation: ... + def df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def distinct(self) -> DuckDBPyRelation: ... + def except_(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... + def execute(self) -> DuckDBPyRelation: ... + def explain(self, type: ExplainType = ExplainType.STANDARD) -> str: ... + def favg( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def fetch_arrow_reader(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def fetch_arrow_table(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.Table: ... + def fetch_df_chunk( + self, vectors_per_chunk: pytyping.SupportsInt = 1, *, date_as_object: bool = False + ) -> pandas.DataFrame: ... + def fetch_record_batch(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def fetchall(self) -> list[tuple[pytyping.Any, ...]]: ... + def fetchdf(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def fetchmany(self, size: pytyping.SupportsInt = 1) -> list[tuple[pytyping.Any, ...]]: ... + def fetchnumpy(self) -> dict[str, np.typing.NDArray[pytyping.Any] | pandas.Categorical]: ... + def fetchone(self) -> tuple[pytyping.Any, ...] | None: ... + def filter(self, filter_expr: Expression | str) -> DuckDBPyRelation: ... + def first(self, column: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def first_value(self, column: str, window_spec: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def fsum( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def geomean(self, column: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... 
+ def histogram( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def insert(self, values: pytyping.List[object]) -> None: ... + def insert_into(self, table_name: str) -> None: ... + def intersect(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... + def join( + self, other_rel: DuckDBPyRelation, condition: Expression | str, how: str = "inner" + ) -> DuckDBPyRelation: ... + def lag( + self, + column: str, + window_spec: str, + offset: pytyping.SupportsInt = 1, + default_value: str = "NULL", + ignore_nulls: bool = False, + projected_columns: str = "", + ) -> DuckDBPyRelation: ... + def last(self, column: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def last_value(self, column: str, window_spec: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def lead( + self, + column: str, + window_spec: str, + offset: pytyping.SupportsInt = 1, + default_value: str = "NULL", + ignore_nulls: bool = False, + projected_columns: str = "", + ) -> DuckDBPyRelation: ... + def limit(self, n: pytyping.SupportsInt, offset: pytyping.SupportsInt = 0) -> DuckDBPyRelation: ... + def list( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def map( + self, map_function: Callable[..., pytyping.Any], *, schema: dict[str, sqltypes.DuckDBPyType] | None = None + ) -> DuckDBPyRelation: ... + def max( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def mean( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def median( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def min( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def mode( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def n_tile( + self, window_spec: str, num_buckets: pytyping.SupportsInt, projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def nth_value( + self, + column: str, + window_spec: str, + offset: pytyping.SupportsInt, + ignore_nulls: bool = False, + projected_columns: str = "", + ) -> DuckDBPyRelation: ... + def order(self, order_expr: str) -> DuckDBPyRelation: ... + def percent_rank(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... + def pl(self, batch_size: pytyping.SupportsInt = 1000000, *, lazy: bool = False) -> polars.DataFrame: ... + def product( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def project(self, *args: str | Expression, groups: str = "") -> DuckDBPyRelation: ... + def quantile( + self, + column: str, + q: float | pytyping.List[float] = 0.5, + groups: str = "", + window_spec: str = "", + projected_columns: str = "", + ) -> DuckDBPyRelation: ... + def quantile_cont( + self, + column: str, + q: float | pytyping.List[float] = 0.5, + groups: str = "", + window_spec: str = "", + projected_columns: str = "", + ) -> DuckDBPyRelation: ... + def quantile_disc( + self, + column: str, + q: float | pytyping.List[float] = 0.5, + groups: str = "", + window_spec: str = "", + projected_columns: str = "", + ) -> DuckDBPyRelation: ... 
+ def query(self, virtual_table_name: str, sql_query: str) -> DuckDBPyRelation: ... + def rank(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... + def rank_dense(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... + def record_batch(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.RecordBatchReader: ... + def row_number(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... + def select(self, *args: str | Expression, groups: str = "") -> DuckDBPyRelation: ... + def select_dtypes(self, types: pytyping.List[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ... + def select_types(self, types: pytyping.List[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ... + def set_alias(self, alias: str) -> DuckDBPyRelation: ... + def show( + self, + *, + max_width: pytyping.SupportsInt | None = None, + max_rows: pytyping.SupportsInt | None = None, + max_col_width: pytyping.SupportsInt | None = None, + null_value: str | None = None, + render_mode: RenderMode | None = None, + ) -> None: ... + def sort(self, *args: Expression) -> DuckDBPyRelation: ... + def sql_query(self) -> str: ... + def std( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def stddev( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def stddev_pop( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def stddev_samp( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def string_agg( + self, column: str, sep: str = ",", groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def sum( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def tf(self) -> dict[str, tensorflow.Tensor]: ... + def to_arrow_table(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.Table: ... + def to_csv( + self, + file_name: str, + *, + sep: str | None = None, + na_rep: str | None = None, + header: bool | None = None, + quotechar: str | None = None, + escapechar: str | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + quoting: str | int | None = None, + encoding: str | None = None, + compression: str | None = None, + overwrite: bool | None = None, + per_thread_output: bool | None = None, + use_tmp_file: bool | None = None, + partition_by: pytyping.List[str] | None = None, + write_partition_columns: bool | None = None, + ) -> None: ... + def to_df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def to_parquet( + self, + file_name: str, + *, + compression: str | None = None, + field_ids: ParquetFieldIdsType | pytyping.Literal["auto"] | None = None, + row_group_size_bytes: int | str | None = None, + row_group_size: int | None = None, + overwrite: bool | None = None, + per_thread_output: bool | None = None, + use_tmp_file: bool | None = None, + partition_by: pytyping.List[str] | None = None, + write_partition_columns: bool | None = None, + append: bool | None = None, + ) -> None: ... + def to_table(self, table_name: str) -> None: ... + def to_view(self, view_name: str, replace: bool = True) -> DuckDBPyRelation: ... + def torch(self) -> dict[str, pytorch.Tensor]: ... 
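
The conversion methods declared on `DuckDBPyRelation` (`to_df`/`df`, `to_arrow_table`/`fetch_arrow_table`, `pl`, `tf`, `torch`, plus the plain `fetch*` methods) materialize the same relation into different client formats. A small sketch, assuming the optional pandas, pyarrow, and polars dependencies are installed:

```python
import duckdb

rel = duckdb.sql("SELECT range AS i, range * 2 AS j FROM range(5)")

rows = rel.fetchall()        # list of tuples
pdf = rel.to_df()            # pandas.DataFrame (requires pandas)
tbl = rel.to_arrow_table()   # pyarrow.Table (requires pyarrow)
pldf = rel.pl()              # polars.DataFrame (requires polars)

print(rows)
print(pdf.shape, tbl.num_rows, pldf.height)
```
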
+ def union(self, union_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... + def unique(self, unique_aggr: str) -> DuckDBPyRelation: ... + def update(self, set: Expression | str, *, condition: Expression | str | None = None) -> None: ... + def value_counts(self, column: str, groups: str = "") -> DuckDBPyRelation: ... + def var( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def var_pop( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def var_samp( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def variance( + self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + ) -> DuckDBPyRelation: ... + def write_csv( + self, + file_name: str, + sep: str | None = None, + na_rep: str | None = None, + header: bool | None = None, + quotechar: str | None = None, + escapechar: str | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + quoting: str | int | None = None, + encoding: str | None = None, + compression: str | None = None, + overwrite: bool | None = None, + per_thread_output: bool | None = None, + use_tmp_file: bool | None = None, + partition_by: pytyping.List[str] | None = None, + write_partition_columns: bool | None = None, + ) -> None: ... + def write_parquet( + self, + file_name: str, + compression: str | None = None, + field_ids: ParquetFieldIdsType | pytyping.Literal["auto"] | None = None, + row_group_size_bytes: str | int | None = None, + row_group_size: int | None = None, + overwrite: bool | None = None, + per_thread_output: bool | None = None, + use_tmp_file: bool | None = None, + partition_by: pytyping.List[str] | None = None, + write_partition_columns: bool | None = None, + append: bool | None = None, + ) -> None: ... + @property + def alias(self) -> str: ... + @property + def columns(self) -> pytyping.List[str]: ... + @property + def description(self) -> pytyping.List[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]]: ... + @property + def dtypes(self) -> pytyping.List[str]: ... + @property + def shape(self) -> tuple[int, int]: ... + @property + def type(self) -> str: ... + @property + def types(self) -> pytyping.List[sqltypes.DuckDBPyType]: ... + +class Error(Exception): ... + +class ExpectedResultType: + CHANGED_ROWS: pytyping.ClassVar[ExpectedResultType] # value = + NOTHING: pytyping.ClassVar[ExpectedResultType] # value = + QUERY_RESULT: pytyping.ClassVar[ExpectedResultType] # value = + __members__: pytyping.ClassVar[ + dict[str, ExpectedResultType] + ] # value = {'QUERY_RESULT': , 'CHANGED_ROWS': , 'NOTHING': } # noqa: E501 + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class ExplainType: + ANALYZE: pytyping.ClassVar[ExplainType] # value = + STANDARD: pytyping.ClassVar[ExplainType] # value = + __members__: pytyping.ClassVar[ + dict[str, ExplainType] + ] # value = {'STANDARD': , 'ANALYZE': } + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... 
+ def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class Expression: + def __add__(self, other: Expression) -> Expression: ... + def __and__(self, other: Expression) -> Expression: ... + def __div__(self, other: Expression) -> Expression: ... + def __eq__(self, other: Expression) -> Expression: ... # type: ignore[override] + def __floordiv__(self, other: Expression) -> Expression: ... + def __ge__(self, other: Expression) -> Expression: ... + def __gt__(self, other: Expression) -> Expression: ... + @pytyping.overload + def __init__(self, arg0: str) -> None: ... + @pytyping.overload + def __init__(self, arg0: pytyping.Any) -> None: ... + def __invert__(self) -> Expression: ... + def __le__(self, other: Expression) -> Expression: ... + def __lt__(self, other: Expression) -> Expression: ... + def __mod__(self, other: Expression) -> Expression: ... + def __mul__(self, other: Expression) -> Expression: ... + def __ne__(self, other: Expression) -> Expression: ... # type: ignore[override] + def __neg__(self) -> Expression: ... + def __or__(self, other: Expression) -> Expression: ... + def __pow__(self, other: Expression) -> Expression: ... + def __radd__(self, other: Expression) -> Expression: ... + def __rand__(self, other: Expression) -> Expression: ... + def __rdiv__(self, other: Expression) -> Expression: ... + def __rfloordiv__(self, other: Expression) -> Expression: ... + def __rmod__(self, other: Expression) -> Expression: ... + def __rmul__(self, other: Expression) -> Expression: ... + def __ror__(self, other: Expression) -> Expression: ... + def __rpow__(self, other: Expression) -> Expression: ... + def __rsub__(self, other: Expression) -> Expression: ... + def __rtruediv__(self, other: Expression) -> Expression: ... + def __sub__(self, other: Expression) -> Expression: ... + def __truediv__(self, other: Expression) -> Expression: ... + def alias(self, name: str) -> Expression: ... + def asc(self) -> Expression: ... + def between(self, lower: Expression, upper: Expression) -> Expression: ... + def cast(self, type: sqltypes.DuckDBPyType) -> Expression: ... + def collate(self, collation: str) -> Expression: ... + def desc(self) -> Expression: ... + def get_name(self) -> str: ... + def isin(self, *args: Expression) -> Expression: ... + def isnotin(self, *args: Expression) -> Expression: ... + def isnotnull(self) -> Expression: ... + def isnull(self) -> Expression: ... + def nulls_first(self) -> Expression: ... + def nulls_last(self) -> Expression: ... + def otherwise(self, value: Expression) -> Expression: ... + def show(self) -> None: ... + def when(self, condition: Expression, value: Expression) -> Expression: ... + +class FatalException(DatabaseError): ... + +class HTTPException(IOException): + status_code: int + body: str + reason: str + headers: dict[str, str] + +class IOException(OperationalError): ... +class IntegrityError(DatabaseError): ... +class InternalError(DatabaseError): ... +class InternalException(InternalError): ... +class InterruptException(DatabaseError): ... +class InvalidInputException(ProgrammingError): ... +class InvalidTypeException(ProgrammingError): ... +class NotImplementedException(NotSupportedError): ... +class NotSupportedError(DatabaseError): ... 
+class OperationalError(DatabaseError): ... +class OutOfMemoryException(OperationalError): ... +class OutOfRangeException(DataError): ... +class ParserException(ProgrammingError): ... +class PermissionException(DatabaseError): ... +class ProgrammingError(DatabaseError): ... + +class PythonExceptionHandling: + DEFAULT: pytyping.ClassVar[PythonExceptionHandling] # value = + RETURN_NULL: pytyping.ClassVar[PythonExceptionHandling] # value = + __members__: pytyping.ClassVar[ + dict[str, PythonExceptionHandling] + ] # value = {'DEFAULT': , 'RETURN_NULL': } # noqa: E501 + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class RenderMode: + COLUMNS: pytyping.ClassVar[RenderMode] # value = + ROWS: pytyping.ClassVar[RenderMode] # value = + __members__: pytyping.ClassVar[ + dict[str, RenderMode] + ] # value = {'ROWS': , 'COLUMNS': } + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class SequenceException(DatabaseError): ... +class SerializationException(OperationalError): ... + +class Statement: + @property + def expected_result_type(self) -> list[StatementType]: ... + @property + def named_parameters(self) -> set[str]: ... + @property + def query(self) -> str: ... + @property + def type(self) -> StatementType: ... 
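
`Statement` objects come from `extract_statements` and can be passed back into `execute`/`sql`; their `type`, `query`, and `named_parameters` properties allow inspecting a script before running it. A minimal sketch, assuming only the API declared in these stubs:

```python
import duckdb

con = duckdb.connect()
script = "CREATE TABLE t(i INTEGER); INSERT INTO t VALUES (1), (2); SELECT * FROM t WHERE i >= $lo"
statements = con.extract_statements(script)

# Run everything except the final parameterized SELECT.
for stmt in statements[:-1]:
    con.execute(stmt)

last = statements[-1]
print(last.type, last.named_parameters)         # SELECT statement, {'lo'}
print(con.execute(last, {"lo": 2}).fetchall())  # [(2,)]
```
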
+ +class StatementType: + ALTER_STATEMENT: pytyping.ClassVar[StatementType] # value = + ANALYZE_STATEMENT: pytyping.ClassVar[StatementType] # value = + ATTACH_STATEMENT: pytyping.ClassVar[StatementType] # value = + CALL_STATEMENT: pytyping.ClassVar[StatementType] # value = + COPY_DATABASE_STATEMENT: pytyping.ClassVar[StatementType] # value = + COPY_STATEMENT: pytyping.ClassVar[StatementType] # value = + CREATE_FUNC_STATEMENT: pytyping.ClassVar[StatementType] # value = + CREATE_STATEMENT: pytyping.ClassVar[StatementType] # value = + DELETE_STATEMENT: pytyping.ClassVar[StatementType] # value = + DETACH_STATEMENT: pytyping.ClassVar[StatementType] # value = + DROP_STATEMENT: pytyping.ClassVar[StatementType] # value = + EXECUTE_STATEMENT: pytyping.ClassVar[StatementType] # value = + EXPLAIN_STATEMENT: pytyping.ClassVar[StatementType] # value = + EXPORT_STATEMENT: pytyping.ClassVar[StatementType] # value = + EXTENSION_STATEMENT: pytyping.ClassVar[StatementType] # value = + INSERT_STATEMENT: pytyping.ClassVar[StatementType] # value = + INVALID_STATEMENT: pytyping.ClassVar[StatementType] # value = + LOAD_STATEMENT: pytyping.ClassVar[StatementType] # value = + LOGICAL_PLAN_STATEMENT: pytyping.ClassVar[StatementType] # value = + MERGE_INTO_STATEMENT: pytyping.ClassVar[StatementType] # value = + MULTI_STATEMENT: pytyping.ClassVar[StatementType] # value = + PRAGMA_STATEMENT: pytyping.ClassVar[StatementType] # value = + PREPARE_STATEMENT: pytyping.ClassVar[StatementType] # value = + RELATION_STATEMENT: pytyping.ClassVar[StatementType] # value = + SELECT_STATEMENT: pytyping.ClassVar[StatementType] # value = + SET_STATEMENT: pytyping.ClassVar[StatementType] # value = + TRANSACTION_STATEMENT: pytyping.ClassVar[StatementType] # value = + UPDATE_STATEMENT: pytyping.ClassVar[StatementType] # value = + VACUUM_STATEMENT: pytyping.ClassVar[StatementType] # value = + VARIABLE_SET_STATEMENT: pytyping.ClassVar[StatementType] # value = + __members__: pytyping.ClassVar[ + dict[str, StatementType] + ] # value = {'INVALID_STATEMENT': , 'SELECT_STATEMENT': , 'INSERT_STATEMENT': , 'UPDATE_STATEMENT': , 'CREATE_STATEMENT': , 'DELETE_STATEMENT': , 'PREPARE_STATEMENT': , 'EXECUTE_STATEMENT': , 'ALTER_STATEMENT': , 'TRANSACTION_STATEMENT': , 'COPY_STATEMENT': , 'ANALYZE_STATEMENT': , 'VARIABLE_SET_STATEMENT': , 'CREATE_FUNC_STATEMENT': , 'EXPLAIN_STATEMENT': , 'DROP_STATEMENT': , 'EXPORT_STATEMENT': , 'PRAGMA_STATEMENT': , 'VACUUM_STATEMENT': , 'CALL_STATEMENT': , 'SET_STATEMENT': , 'LOAD_STATEMENT': , 'RELATION_STATEMENT': , 'EXTENSION_STATEMENT': , 'LOGICAL_PLAN_STATEMENT': , 'ATTACH_STATEMENT': , 'DETACH_STATEMENT': , 'MULTI_STATEMENT': , 'COPY_DATABASE_STATEMENT': , 'MERGE_INTO_STATEMENT': } # noqa: E501 + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class SyntaxException(ProgrammingError): ... +class TransactionException(OperationalError): ... +class TypeMismatchException(DataError): ... +class Warning(Exception): ... 
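
The exception classes declared above follow the DB-API 2.0 hierarchy (`Error` → `DatabaseError` → `ProgrammingError`, `OperationalError`, and so on), so callers can catch either the specific DuckDB exception or its broader category. A short sketch:

```python
import duckdb

con = duckdb.connect()
try:
    con.execute("SELECT * FROM table_that_does_not_exist")
except duckdb.CatalogException as exc:   # most specific DuckDB exception
    print("catalog error:", exc)
except duckdb.ProgrammingError:          # broader DB-API 2.0 category
    raise
```
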
+ +class token_type: + __members__: pytyping.ClassVar[ + dict[str, token_type] + ] # value = {'identifier': , 'numeric_const': , 'string_const': , 'operator': , 'keyword': , 'comment': } # noqa: E501 + comment: pytyping.ClassVar[token_type] # value = + identifier: pytyping.ClassVar[token_type] # value = + keyword: pytyping.ClassVar[token_type] # value = + numeric_const: pytyping.ClassVar[token_type] # value = + operator: pytyping.ClassVar[token_type] # value = + string_const: pytyping.ClassVar[token_type] # value = + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +def CaseExpression(condition: Expression, value: Expression) -> Expression: ... +def CoalesceOperator(*args: Expression) -> Expression: ... +def ColumnExpression(*args: str) -> Expression: ... +def ConstantExpression(value: Expression | str) -> Expression: ... +def DefaultExpression() -> Expression: ... +def FunctionExpression(function_name: str, *args: Expression) -> Expression: ... +def LambdaExpression(lhs: Expression | str | tuple[str], rhs: Expression) -> Expression: ... +def SQLExpression(expression: str) -> Expression: ... +@pytyping.overload +def StarExpression(*, exclude: Expression | str | tuple[str]) -> Expression: ... +@pytyping.overload +def StarExpression() -> Expression: ... +def aggregate( + df: pandas.DataFrame, + aggr_expr: Expression | list[Expression] | str | list[str], + group_expr: str = "", + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def alias(df: pandas.DataFrame, alias: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +def append( + table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection | None = None +) -> DuckDBPyConnection: ... +def array_type( + type: sqltypes.DuckDBPyType, size: pytyping.SupportsInt, *, connection: DuckDBPyConnection | None = None +) -> sqltypes.DuckDBPyType: ... +@pytyping.overload +def arrow( + rows_per_batch: pytyping.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None +) -> pyarrow.lib.RecordBatchReader: ... +@pytyping.overload +def arrow(arrow_object: pytyping.Any, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +def begin(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def checkpoint(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def close(*, connection: DuckDBPyConnection | None = None) -> None: ... +def commit(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def connect( + database: str | pathlib.Path = ":memory:", + read_only: bool = False, + config: dict[str, str] | None = None, +) -> DuckDBPyConnection: ... +def create_function( + name: str, + function: Callable[..., pytyping.Any], + parameters: list[sqltypes.DuckDBPyType] | None = None, + return_type: sqltypes.DuckDBPyType | None = None, + *, + type: func.PythonUDFType = ..., + null_handling: func.FunctionNullHandling = ..., + exception_handling: PythonExceptionHandling = ..., + side_effects: bool = False, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyConnection: ... 
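
`create_function` (declared here at module level and earlier on `DuckDBPyConnection`) registers a Python callable as a scalar SQL function. A minimal sketch that builds the parameter and return types with the connection's `sqltype` helper from these stubs:

```python
import duckdb

con = duckdb.connect()

def repeat_twice(s: str) -> str:
    return s + s

# Explicit parameter and return types; the type/null_handling/exception_handling
# keyword-only arguments keep their defaults.
con.create_function("repeat_twice", repeat_twice, [con.sqltype("VARCHAR")], con.sqltype("VARCHAR"))
print(con.sql("SELECT repeat_twice('ab') AS r").fetchall())  # [('abab',)]
con.remove_function("repeat_twice")
```
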
+def cursor(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def decimal_type( + width: pytyping.SupportsInt, scale: pytyping.SupportsInt, *, connection: DuckDBPyConnection | None = None +) -> sqltypes.DuckDBPyType: ... +def default_connection() -> DuckDBPyConnection: ... +def description( + *, connection: DuckDBPyConnection | None = None +) -> list[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]] | None: ... +@pytyping.overload +def df(*, date_as_object: bool = False, connection: DuckDBPyConnection | None = None) -> pandas.DataFrame: ... +@pytyping.overload +def df(df: pandas.DataFrame, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +def distinct(df: pandas.DataFrame, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +def dtype(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... +def duplicate(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def enum_type( + name: str, + type: sqltypes.DuckDBPyType, + values: list[pytyping.Any], + *, + connection: DuckDBPyConnection | None = None, +) -> sqltypes.DuckDBPyType: ... +def execute( + query: Statement | str, + parameters: object = None, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyConnection: ... +def executemany( + query: Statement | str, + parameters: object = None, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyConnection: ... +def extract_statements(query: str, *, connection: DuckDBPyConnection | None = None) -> list[Statement]: ... +def fetch_arrow_table( + rows_per_batch: pytyping.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None +) -> pyarrow.lib.Table: ... +def fetch_df(*, date_as_object: bool = False, connection: DuckDBPyConnection | None = None) -> pandas.DataFrame: ... +def fetch_df_chunk( + vectors_per_chunk: pytyping.SupportsInt = 1, + *, + date_as_object: bool = False, + connection: DuckDBPyConnection | None = None, +) -> pandas.DataFrame: ... +def fetch_record_batch( + rows_per_batch: pytyping.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None +) -> pyarrow.lib.RecordBatchReader: ... +def fetchall(*, connection: DuckDBPyConnection | None = None) -> list[tuple[pytyping.Any, ...]]: ... +def fetchdf(*, date_as_object: bool = False, connection: DuckDBPyConnection | None = None) -> pandas.DataFrame: ... +def fetchmany( + size: pytyping.SupportsInt = 1, *, connection: DuckDBPyConnection | None = None +) -> list[tuple[pytyping.Any, ...]]: ... +def fetchnumpy( + *, connection: DuckDBPyConnection | None = None +) -> dict[str, np.typing.NDArray[pytyping.Any] | pandas.Categorical]: ... +def fetchone(*, connection: DuckDBPyConnection | None = None) -> tuple[pytyping.Any, ...] | None: ... +def filesystem_is_registered(name: str, *, connection: DuckDBPyConnection | None = None) -> bool: ... +def filter( + df: pandas.DataFrame, + filter_expr: Expression | str, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def from_arrow( + arrow_object: object, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... 
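
The module-level functions mirror the `DuckDBPyConnection` methods and, per these stubs, operate on the implicit default connection unless a connection is supplied through the keyword-only `connection` argument. A minimal sketch, assuming that keyword is honored as declared:

```python
import duckdb

# Uses the implicit default connection.
print(duckdb.execute("SELECT 42").fetchall())

# Routes the same call through an explicitly created connection.
con = duckdb.connect()
print(duckdb.sql("SELECT count(*) FROM range(10)", connection=con).fetchone())
```
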
+def from_csv_auto( + path_or_buffer: str | bytes | os.PathLike[str], + header: bool | int | None = None, + compression: str | None = None, + sep: str | None = None, + delimiter: str | None = None, + files_to_sniff: int | None = None, + comment: str | None = None, + thousands: str | None = None, + dtype: dict[str, str] | list[str] | None = None, + na_values: str | list[str] | None = None, + skiprows: int | None = None, + quotechar: str | None = None, + escapechar: str | None = None, + encoding: str | None = None, + parallel: bool | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + sample_size: int | None = None, + auto_detect: bool | int | None = None, + all_varchar: bool | None = None, + normalize_names: bool | None = None, + null_padding: bool | None = None, + names: list[str] | None = None, + lineterminator: str | None = None, + columns: dict[str, str] | None = None, + auto_type_candidates: list[str] | None = None, + max_line_size: int | None = None, + ignore_errors: bool | None = None, + store_rejects: bool | None = None, + rejects_table: str | None = None, + rejects_scan: str | None = None, + rejects_limit: int | None = None, + force_not_null: list[str] | None = None, + buffer_size: int | None = None, + decimal: str | None = None, + allow_quoted_nulls: bool | None = None, + filename: bool | str | None = None, + hive_partitioning: bool | None = None, + union_by_name: bool | None = None, + hive_types: dict[str, str] | None = None, + hive_types_autocast: bool | None = None, + strict_mode: bool | None = None, +) -> DuckDBPyRelation: ... +def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +@pytyping.overload +def from_parquet( + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: str | None = None, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +@pytyping.overload +def from_parquet( + file_globs: Sequence[str], + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: pytyping.Any = None, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def from_query( + query: Statement | str, + *, + alias: str = "", + params: object = None, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def get_table_names( + query: str, *, qualified: bool = False, connection: DuckDBPyConnection | None = None +) -> set[str]: ... +def install_extension( + extension: str, + *, + force_install: bool = False, + repository: str | None = None, + repository_url: str | None = None, + version: str | None = None, + connection: DuckDBPyConnection | None = None, +) -> None: ... +def interrupt(*, connection: DuckDBPyConnection | None = None) -> None: ... +def limit( + df: pandas.DataFrame, + n: pytyping.SupportsInt, + offset: pytyping.SupportsInt = 0, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def list_filesystems(*, connection: DuckDBPyConnection | None = None) -> list[str]: ... +def list_type( + type: sqltypes.DuckDBPyType, *, connection: DuckDBPyConnection | None = None +) -> sqltypes.DuckDBPyType: ... +def load_extension(extension: str, *, connection: DuckDBPyConnection | None = None) -> None: ... 
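
`install_extension` and `load_extension` manage DuckDB extensions from Python; the connection methods behave the same as the module-level wrappers. A short sketch using the core `httpfs` extension, assuming network access for the initial download:

```python
import duckdb

con = duckdb.connect()
con.install_extension("httpfs")  # downloads the extension if not already installed
con.load_extension("httpfs")

loaded = con.sql(
    "SELECT loaded FROM duckdb_extensions() WHERE extension_name = 'httpfs'"
).fetchone()
print(loaded)  # (True,)
```
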
+def map_type( + key: sqltypes.DuckDBPyType, + value: sqltypes.DuckDBPyType, + *, + connection: DuckDBPyConnection | None = None, +) -> sqltypes.DuckDBPyType: ... +def order( + df: pandas.DataFrame, order_expr: str, *, connection: DuckDBPyConnection | None = None +) -> DuckDBPyRelation: ... +def pl( + rows_per_batch: pytyping.SupportsInt = 1000000, + *, + lazy: bool = False, + connection: DuckDBPyConnection | None = None, +) -> polars.DataFrame: ... +def project( + df: pandas.DataFrame, *args: str | Expression, groups: str = "", connection: DuckDBPyConnection | None = None +) -> DuckDBPyRelation: ... +def query( + query: Statement | str, + *, + alias: str = "", + params: object = None, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def query_df( + df: pandas.DataFrame, + virtual_table_name: str, + sql_query: str, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def query_progress(*, connection: DuckDBPyConnection | None = None) -> float: ... +def read_csv( + path_or_buffer: str | bytes | os.PathLike[str], + header: bool | int | None = None, + compression: str | None = None, + sep: str | None = None, + delimiter: str | None = None, + files_to_sniff: int | None = None, + comment: str | None = None, + thousands: str | None = None, + dtype: dict[str, str] | list[str] | None = None, + na_values: str | list[str] | None = None, + skiprows: int | None = None, + quotechar: str | None = None, + escapechar: str | None = None, + encoding: str | None = None, + parallel: bool | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + sample_size: int | None = None, + auto_detect: bool | int | None = None, + all_varchar: bool | None = None, + normalize_names: bool | None = None, + null_padding: bool | None = None, + names: list[str] | None = None, + lineterminator: str | None = None, + columns: dict[str, str] | None = None, + auto_type_candidates: list[str] | None = None, + max_line_size: int | None = None, + ignore_errors: bool | None = None, + store_rejects: bool | None = None, + rejects_table: str | None = None, + rejects_scan: str | None = None, + rejects_limit: int | None = None, + force_not_null: list[str] | None = None, + buffer_size: int | None = None, + decimal: str | None = None, + allow_quoted_nulls: bool | None = None, + filename: bool | str | None = None, + hive_partitioning: bool | None = None, + union_by_name: bool | None = None, + hive_types: dict[str, str] | None = None, + hive_types_autocast: bool | None = None, + strict_mode: bool | None = None, +) -> DuckDBPyRelation: ... +def read_json( + path_or_buffer: str | bytes | os.PathLike[str], + *, + columns: dict[str, str] | None = None, + sample_size: int | None = None, + maximum_depth: int | None = None, + records: str | None = None, + format: str | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + compression: str | None = None, + maximum_object_size: int | None = None, + ignore_errors: bool | None = None, + convert_strings_to_integers: bool | None = None, + field_appearance_threshold: float | None = None, + map_inference_threshold: int | None = None, + maximum_sample_files: int | None = None, + filename: bool | str | None = None, + hive_partitioning: bool | None = None, + union_by_name: bool | None = None, + hive_types: dict[str, str] | None = None, + hive_types_autocast: bool | None = None, +) -> DuckDBPyRelation: ... 
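`query_df` exposes a pandas DataFrame to SQL under the virtual table name passed as its second argument, and `read_json` likewise yields a lazy relation. A hedged sketch with made-up data and a hypothetical file name:

    import duckdb
    import pandas as pd

    df = pd.DataFrame({"id": [1, 2, 3], "score": [0.5, 0.9, 0.1]})

    # Run SQL over the DataFrame under the virtual table name "scores".
    top = duckdb.query_df(df, "scores", "SELECT id FROM scores WHERE score > 0.4 ORDER BY id")
    print(top.fetchall())                         # [(1,), (2,)]

    # read_json also returns a lazy relation; "records.json" is a hypothetical file.
    logs = duckdb.read_json("records.json", format="array")
    print(logs.columns)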
+@pytyping.overload +def read_parquet( + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: str | None = None, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +@pytyping.overload +def read_parquet( + file_globs: Sequence[str], + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: pytyping.Any = None, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def register( + view_name: str, + python_object: object, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyConnection: ... +def register_filesystem( + filesystem: fsspec.AbstractFileSystem, *, connection: DuckDBPyConnection | None = None +) -> None: ... +def remove_function(name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def rollback(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def row_type( + fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType], + *, + connection: DuckDBPyConnection | None = None, +) -> sqltypes.DuckDBPyType: ... +def rowcount(*, connection: DuckDBPyConnection | None = None) -> int: ... +def set_default_connection(connection: DuckDBPyConnection) -> None: ... +def sql( + query: Statement | str, + *, + alias: str = "", + params: object = None, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def sqltype(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... +def string_type(collation: str = "", *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... +def struct_type( + fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType], + *, + connection: DuckDBPyConnection | None = None, +) -> sqltypes.DuckDBPyType: ... +def table(table_name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +def table_function( + name: str, + parameters: object = None, + *, + connection: DuckDBPyConnection | None = None, +) -> DuckDBPyRelation: ... +def tf(*, connection: DuckDBPyConnection | None = None) -> dict[str, tensorflow.Tensor]: ... +def tokenize(query: str) -> list[tuple[int, token_type]]: ... +def torch(*, connection: DuckDBPyConnection | None = None) -> dict[str, pytorch.Tensor]: ... +def type(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... +def union_type( + members: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType], + *, + connection: DuckDBPyConnection | None = None, +) -> sqltypes.DuckDBPyType: ... +def unregister(view_name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... +def unregister_filesystem(name: str, *, connection: DuckDBPyConnection | None = None) -> None: ... +def values( + *args: list[pytyping.Any] | tuple[Expression, ...] | Expression, connection: DuckDBPyConnection | None = None +) -> DuckDBPyRelation: ... +def view(view_name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... 
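The remaining wrappers cover type construction and view registration against the default connection. A minimal sketch; the view name and sample data are invented for illustration:

    import duckdb
    import pandas as pd

    # Compose DuckDB types from the module-level helpers.
    int_t = duckdb.sqltype("INTEGER")
    str_t = duckdb.sqltype("VARCHAR")
    person_t = duckdb.struct_type({"id": int_t, "name": str_t})
    tags_t = duckdb.list_type(str_t)
    print(person_t, tags_t)                       # STRUCT(id INTEGER, name VARCHAR) VARCHAR[]

    # Register a DataFrame as a view, query it with sql(), then drop the view.
    people = pd.DataFrame({"id": [1, 2], "name": ["ada", "grace"]})
    duckdb.register("people", people)
    print(duckdb.sql("SELECT name FROM people ORDER BY id").fetchall())
    duckdb.unregister("people")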
+def write_csv( + df: pandas.DataFrame, + filename: str, + *, + sep: str | None = None, + na_rep: str | None = None, + header: bool | None = None, + quotechar: str | None = None, + escapechar: str | None = None, + date_format: str | None = None, + timestamp_format: str | None = None, + quoting: str | int | None = None, + encoding: str | None = None, + compression: str | None = None, + overwrite: bool | None = None, + per_thread_output: bool | None = None, + use_tmp_file: bool | None = None, + partition_by: list[str] | None = None, + write_partition_columns: bool | None = None, +) -> None: ... + +__formatted_python_version__: str +__git_revision__: str +__interactive__: bool +__jupyter__: bool +__standard_vector_size__: int +__version__: str +_clean_default_connection: pytyping.Any # value = +apilevel: str +paramstyle: str +threadsafety: int diff --git a/_duckdb-stubs/_func.pyi b/_duckdb-stubs/_func.pyi new file mode 100644 index 00000000..68484499 --- /dev/null +++ b/_duckdb-stubs/_func.pyi @@ -0,0 +1,46 @@ +import typing as pytyping + +__all__: list[str] = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] + +class FunctionNullHandling: + DEFAULT: pytyping.ClassVar[FunctionNullHandling] # value = + SPECIAL: pytyping.ClassVar[FunctionNullHandling] # value = + __members__: pytyping.ClassVar[ + dict[str, FunctionNullHandling] + ] # value = {'DEFAULT': , 'SPECIAL': } + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class PythonUDFType: + ARROW: pytyping.ClassVar[PythonUDFType] # value = + NATIVE: pytyping.ClassVar[PythonUDFType] # value = + __members__: pytyping.ClassVar[ + dict[str, PythonUDFType] + ] # value = {'NATIVE': , 'ARROW': } + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +ARROW: PythonUDFType # value = +DEFAULT: FunctionNullHandling # value = +NATIVE: PythonUDFType # value = +SPECIAL: FunctionNullHandling # value = diff --git a/_duckdb-stubs/_sqltypes.pyi b/_duckdb-stubs/_sqltypes.pyi new file mode 100644 index 00000000..88abb977 --- /dev/null +++ b/_duckdb-stubs/_sqltypes.pyi @@ -0,0 +1,75 @@ +import duckdb +import typing as pytyping + +__all__: list[str] = [ + "BIGINT", + "BIT", + "BLOB", + "BOOLEAN", + "DATE", + "DOUBLE", + "FLOAT", + "HUGEINT", + "INTEGER", + "INTERVAL", + "SMALLINT", + "SQLNULL", + "TIME", + "TIMESTAMP", + "TIMESTAMP_MS", + "TIMESTAMP_NS", + "TIMESTAMP_S", + "TIMESTAMP_TZ", + "TIME_TZ", + "TINYINT", + "UBIGINT", + "UHUGEINT", + "UINTEGER", + "USMALLINT", + "UTINYINT", + "UUID", + "VARCHAR", + "DuckDBPyType", +] + +class DuckDBPyType: + def __eq__(self, other: object) -> bool: ... + def __getattr__(self, name: str) -> DuckDBPyType: ... + def __getitem__(self, name: str) -> DuckDBPyType: ... + def __hash__(self) -> int: ... 
+ @pytyping.overload + def __init__(self, type_str: str, connection: duckdb.DuckDBPyConnection) -> None: ... + @pytyping.overload + def __init__(self, obj: object) -> None: ... + @property + def children(self) -> list[tuple[str, object]]: ... + @property + def id(self) -> str: ... + +BIGINT: DuckDBPyType # value = BIGINT +BIT: DuckDBPyType # value = BIT +BLOB: DuckDBPyType # value = BLOB +BOOLEAN: DuckDBPyType # value = BOOLEAN +DATE: DuckDBPyType # value = DATE +DOUBLE: DuckDBPyType # value = DOUBLE +FLOAT: DuckDBPyType # value = FLOAT +HUGEINT: DuckDBPyType # value = HUGEINT +INTEGER: DuckDBPyType # value = INTEGER +INTERVAL: DuckDBPyType # value = INTERVAL +SMALLINT: DuckDBPyType # value = SMALLINT +SQLNULL: DuckDBPyType # value = "NULL" +TIME: DuckDBPyType # value = TIME +TIMESTAMP: DuckDBPyType # value = TIMESTAMP +TIMESTAMP_MS: DuckDBPyType # value = TIMESTAMP_MS +TIMESTAMP_NS: DuckDBPyType # value = TIMESTAMP_NS +TIMESTAMP_S: DuckDBPyType # value = TIMESTAMP_S +TIMESTAMP_TZ: DuckDBPyType # value = TIMESTAMP WITH TIME ZONE +TIME_TZ: DuckDBPyType # value = TIME WITH TIME ZONE +TINYINT: DuckDBPyType # value = TINYINT +UBIGINT: DuckDBPyType # value = UBIGINT +UHUGEINT: DuckDBPyType # value = UHUGEINT +UINTEGER: DuckDBPyType # value = UINTEGER +USMALLINT: DuckDBPyType # value = USMALLINT +UTINYINT: DuckDBPyType # value = UTINYINT +UUID: DuckDBPyType # value = UUID +VARCHAR: DuckDBPyType # value = VARCHAR diff --git a/adbc_driver_duckdb/__init__.py b/adbc_driver_duckdb/__init__.py index 528be73f..e81f5090 100644 --- a/adbc_driver_duckdb/__init__.py +++ b/adbc_driver_duckdb/__init__.py @@ -19,12 +19,11 @@ import enum import functools +import importlib import typing import adbc_driver_manager -__all__ = ["StatementOptions", "connect"] - class StatementOptions(enum.Enum): """Statement options specific to the DuckDB driver.""" @@ -36,12 +35,16 @@ class StatementOptions(enum.Enum): def connect(path: typing.Optional[str] = None) -> adbc_driver_manager.AdbcDatabase: """Create a low level ADBC connection to DuckDB.""" if path is None: - return adbc_driver_manager.AdbcDatabase(driver=_driver_path(), entrypoint="duckdb_adbc_init") - return adbc_driver_manager.AdbcDatabase(driver=_driver_path(), entrypoint="duckdb_adbc_init", path=path) + return adbc_driver_manager.AdbcDatabase(driver=driver_path(), entrypoint="duckdb_adbc_init") + return adbc_driver_manager.AdbcDatabase(driver=driver_path(), entrypoint="duckdb_adbc_init", path=path) @functools.cache -def _driver_path() -> str: - import duckdb - - return duckdb.duckdb.__file__ +def driver_path() -> str: + """Get the path to the DuckDB ADBC driver.""" + duckdb_module_spec = importlib.util.find_spec("_duckdb") + if duckdb_module_spec is None: + msg = "Could not find duckdb shared library. Did you pip install duckdb?" + raise ImportError(msg) + print(f"Found duckdb shared library at {duckdb_module_spec.origin}") + return duckdb_module_spec.origin diff --git a/adbc_driver_duckdb/dbapi.py b/adbc_driver_duckdb/dbapi.py index 793c4242..5d0a8702 100644 --- a/adbc_driver_duckdb/dbapi.py +++ b/adbc_driver_duckdb/dbapi.py @@ -15,14 +15,13 @@ # specific language governing permissions and limitations # under the License. -""" -DBAPI 2.0-compatible facade for the ADBC DuckDB driver. 
-""" +"""DBAPI 2.0-compatible facade for the ADBC DuckDB driver.""" import typing import adbc_driver_manager import adbc_driver_manager.dbapi + import adbc_driver_duckdb __all__ = [ diff --git a/cmake/compiler_launcher.cmake b/cmake/compiler_launcher.cmake index d8d1598a..8f77da86 100644 --- a/cmake/compiler_launcher.cmake +++ b/cmake/compiler_launcher.cmake @@ -8,19 +8,25 @@ include(CMakeParseArguments) # Function to look for ccache and sccache to speed up builds, if available # ──────────────────────────────────────────── function(setup_compiler_launcher_if_available) - if(NOT DEFINED CMAKE_C_COMPILER_LAUNCHER AND NOT DEFINED ENV{CMAKE_C_COMPILER_LAUNCHER}) - find_program(COMPILER_LAUNCHER NAMES ccache sccache) - if(COMPILER_LAUNCHER) - message(STATUS "Using ${COMPILER_LAUNCHER} as C compiler launcher") - set(CMAKE_C_COMPILER_LAUNCHER "${COMPILER_LAUNCHER}" CACHE STRING "" FORCE) - endif() + if(NOT DEFINED CMAKE_C_COMPILER_LAUNCHER AND NOT DEFINED + ENV{CMAKE_C_COMPILER_LAUNCHER}) + find_program(COMPILER_LAUNCHER NAMES ccache sccache) + if(COMPILER_LAUNCHER) + message(STATUS "Using ${COMPILER_LAUNCHER} as C compiler launcher") + set(CMAKE_C_COMPILER_LAUNCHER + "${COMPILER_LAUNCHER}" + CACHE STRING "" FORCE) endif() + endif() - if(NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND NOT DEFINED ENV{CMAKE_CXX_COMPILER_LAUNCHER}) - find_program(COMPILER_LAUNCHER NAMES ccache sccache) - if(COMPILER_LAUNCHER) - message(STATUS "Using ${COMPILER_LAUNCHER} as C++ compiler launcher") - set(CMAKE_CXX_COMPILER_LAUNCHER "${COMPILER_LAUNCHER}" CACHE STRING "" FORCE) - endif() + if(NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER + AND NOT DEFINED ENV{CMAKE_CXX_COMPILER_LAUNCHER}) + find_program(COMPILER_LAUNCHER NAMES ccache sccache) + if(COMPILER_LAUNCHER) + message(STATUS "Using ${COMPILER_LAUNCHER} as C++ compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER + "${COMPILER_LAUNCHER}" + CACHE STRING "" FORCE) endif() + endif() endfunction() diff --git a/cmake/duckdb_loader.cmake b/cmake/duckdb_loader.cmake index 80ec9a4d..2fc738fb 100644 --- a/cmake/duckdb_loader.cmake +++ b/cmake/duckdb_loader.cmake @@ -2,25 +2,20 @@ # # Simple DuckDB Build Configuration Module # -# Sets sensible defaults for DuckDB Python extension builds and provides -# a clean interface for adding DuckDB as a library target. Adds jemalloc -# option for debugging but will never allow jemalloc in a release build if -# not on Linux. +# Sets sensible defaults for DuckDB Python extension builds and provides a clean +# interface for adding DuckDB as a library target. Adds jemalloc option for +# debugging but will never allow jemalloc in a release build if not on Linux. 
# -# Usage: -# include(cmake/duckdb_loader.cmake) -# # Optionally load extensions -# set(CORE_EXTENSIONS "json;parquet;icu") +# Usage: include(cmake/duckdb_loader.cmake) # Optionally load extensions +# set(CORE_EXTENSIONS "json;parquet;icu") # -# # set sensible defaults for a debug build: -# duckdb_configure_for_debug() +# # set sensible defaults for a debug build: duckdb_configure_for_debug() # -# # ...or, set sensible defaults for a release build: -# duckdb_configure_for_release() +# # ...or, set sensible defaults for a release build: +# duckdb_configure_for_release() # -# # Link to your target -# duckdb_add_library(duckdb_target) -# target_link_libraries(my_lib PRIVATE ${duckdb_target}) +# # Link to your target duckdb_add_library(duckdb_target) +# target_link_libraries(my_lib PRIVATE ${duckdb_target}) include_guard(GLOBAL) @@ -30,13 +25,14 @@ include_guard(GLOBAL) # Helper macro to set default values that can be overridden from command line macro(_duckdb_set_default var_name default_value) - if(NOT DEFINED ${var_name}) - set(${var_name} ${default_value}) - endif() + if(NOT DEFINED ${var_name}) + set(${var_name} ${default_value}) + endif() endmacro() # Source configuration -_duckdb_set_default(DUCKDB_SOURCE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/external/duckdb") +_duckdb_set_default(DUCKDB_SOURCE_PATH + "${CMAKE_CURRENT_SOURCE_DIR}/external/duckdb") # Extension list - commonly used extensions for Python _duckdb_set_default(CORE_EXTENSIONS "core_functions;parquet;icu;json") @@ -55,7 +51,8 @@ _duckdb_set_default(ENABLE_EXTENSION_AUTOLOADING ON) # Performance options - enable optimizations by default _duckdb_set_default(NATIVE_ARCH OFF) -# Sanitizers are off for Python by default. Enabling might result in "symbol not found" for '___ubsan_vptr_type_cache' +# Sanitizers are off for Python by default. 
Enabling might result in "symbol not +# found" for '___ubsan_vptr_type_cache' _duckdb_set_default(ENABLE_SANITIZER OFF) _duckdb_set_default(ENABLE_UBSAN OFF) @@ -64,141 +61,185 @@ _duckdb_set_default(FORCE_ASSERT OFF) _duckdb_set_default(DEBUG_STACKTRACE OFF) # Convert to cache variables for CMake GUI/ccmake compatibility -set(DUCKDB_SOURCE_PATH "${DUCKDB_SOURCE_PATH}" CACHE PATH "Path to DuckDB source directory") -set(CORE_EXTENSIONS "${CORE_EXTENSIONS}" CACHE STRING "Semicolon-separated list of extensions to enable") -set(BUILD_SHELL "${BUILD_SHELL}" CACHE BOOL "Build the DuckDB shell executable") -set(BUILD_UNITTESTS "${BUILD_UNITTESTS}" CACHE BOOL "Build DuckDB unit tests") -set(BUILD_BENCHMARKS "${BUILD_BENCHMARKS}" CACHE BOOL "Build DuckDB benchmarks") -set(DISABLE_UNITY "${DISABLE_UNITY}" CACHE BOOL "Disable unity builds (slower compilation)") -set(DISABLE_BUILTIN_EXTENSIONS "${DISABLE_BUILTIN_EXTENSIONS}" CACHE BOOL "Disable all built-in extensions") -set(ENABLE_EXTENSION_AUTOINSTALL "${ENABLE_EXTENSION_AUTOINSTALL}" CACHE BOOL "Enable extension auto-installing by default.") -set(ENABLE_EXTENSION_AUTOLOADING "${ENABLE_EXTENSION_AUTOLOADING}" CACHE BOOL "Enable extension auto-loading by default.") -set(NATIVE_ARCH "${NATIVE_ARCH}" CACHE BOOL "Optimize for native architecture") -set(ENABLE_SANITIZER "${ENABLE_SANITIZER}" CACHE BOOL "Enable address sanitizer") -set(ENABLE_UBSAN "${ENABLE_UBSAN}" CACHE BOOL "Enable undefined behavior sanitizer") -set(FORCE_ASSERT "${FORCE_ASSERT}" CACHE BOOL "Enable assertions in release builds") -set(DEBUG_STACKTRACE "${DEBUG_STACKTRACE}" CACHE BOOL "Print a stracktrace on asserts and when testing crashes") +set(DUCKDB_SOURCE_PATH + "${DUCKDB_SOURCE_PATH}" + CACHE PATH "Path to DuckDB source directory") +set(CORE_EXTENSIONS + "${CORE_EXTENSIONS}" + CACHE STRING "Semicolon-separated list of extensions to enable") +set(BUILD_SHELL + "${BUILD_SHELL}" + CACHE BOOL "Build the DuckDB shell executable") +set(BUILD_UNITTESTS + "${BUILD_UNITTESTS}" + CACHE BOOL "Build DuckDB unit tests") +set(BUILD_BENCHMARKS + "${BUILD_BENCHMARKS}" + CACHE BOOL "Build DuckDB benchmarks") +set(DISABLE_UNITY + "${DISABLE_UNITY}" + CACHE BOOL "Disable unity builds (slower compilation)") +set(DISABLE_BUILTIN_EXTENSIONS + "${DISABLE_BUILTIN_EXTENSIONS}" + CACHE BOOL "Disable all built-in extensions") +set(ENABLE_EXTENSION_AUTOINSTALL + "${ENABLE_EXTENSION_AUTOINSTALL}" + CACHE BOOL "Enable extension auto-installing by default.") +set(ENABLE_EXTENSION_AUTOLOADING + "${ENABLE_EXTENSION_AUTOLOADING}" + CACHE BOOL "Enable extension auto-loading by default.") +set(NATIVE_ARCH + "${NATIVE_ARCH}" + CACHE BOOL "Optimize for native architecture") +set(ENABLE_SANITIZER + "${ENABLE_SANITIZER}" + CACHE BOOL "Enable address sanitizer") +set(ENABLE_UBSAN + "${ENABLE_UBSAN}" + CACHE BOOL "Enable undefined behavior sanitizer") +set(FORCE_ASSERT + "${FORCE_ASSERT}" + CACHE BOOL "Enable assertions in release builds") +set(DEBUG_STACKTRACE + "${DEBUG_STACKTRACE}" + CACHE BOOL "Print a stracktrace on asserts and when testing crashes") # ════════════════════════════════════════════════════════════════════════════════ # Internal Functions # ════════════════════════════════════════════════════════════════════════════════ function(_duckdb_validate_jemalloc_config) - # Check if jemalloc is in the extension list - if(NOT CORE_EXTENSIONS MATCHES "jemalloc") - return() + # Check if jemalloc is in the extension list + if(NOT CORE_EXTENSIONS MATCHES "jemalloc") + return() + endif() + + # If we're on Linux then 
using jemalloc is fine, otherwise we only allow it in + # debug builds + if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux") + set(is_debug_build FALSE) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(is_debug_build TRUE) endif() - - # If we're on Linux then using jemalloc is fine, otherwise we only allow it in debug builds - if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux") - set(is_debug_build FALSE) - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - set(is_debug_build TRUE) - endif() - if(is_debug_build) - message(WARNING - "jemalloc extension enabled on ${CMAKE_SYSTEM_NAME} in Debug build.\n" - "This is only recommended for debugging purposes.\n" - "jemalloc is officially supported only on Linux.") - else() - message(WARNING - "jemalloc extension is only supported on ${CMAKE_SYSTEM_NAME} in Debug builds.\n" - "Removing jemalloc from extension list.\n" - "In non-debug builds, jemalloc is only supported on Linux.") - # Remove jemalloc from the extension list - string(REPLACE "jemalloc" "" CORE_EXTENSIONS_FILTERED "${CORE_EXTENSIONS}") - string(REGEX REPLACE ";+" ";" CORE_EXTENSIONS_FILTERED "${CORE_EXTENSIONS_FILTERED}") - string(REGEX REPLACE "^;|;$" "" CORE_EXTENSIONS_FILTERED "${CORE_EXTENSIONS_FILTERED}") - set(CORE_EXTENSIONS "${CORE_EXTENSIONS_FILTERED}" PARENT_SCOPE) - endif() + if(is_debug_build) + message( + WARNING + "jemalloc extension enabled on ${CMAKE_SYSTEM_NAME} in Debug build.\n" + "This is only recommended for debugging purposes.\n" + "jemalloc is officially supported only on Linux.") + else() + message( + WARNING + "jemalloc extension is only supported on ${CMAKE_SYSTEM_NAME} in Debug builds.\n" + "Removing jemalloc from extension list.\n" + "In non-debug builds, jemalloc is only supported on Linux.") + # Remove jemalloc from the extension list + string(REPLACE "jemalloc" "" CORE_EXTENSIONS_FILTERED + "${CORE_EXTENSIONS}") + string(REGEX REPLACE ";+" ";" CORE_EXTENSIONS_FILTERED + "${CORE_EXTENSIONS_FILTERED}") + string(REGEX REPLACE "^;|;$" "" CORE_EXTENSIONS_FILTERED + "${CORE_EXTENSIONS_FILTERED}") + set(CORE_EXTENSIONS + "${CORE_EXTENSIONS_FILTERED}" + PARENT_SCOPE) endif() + endif() endfunction() function(_duckdb_validate_source_path) - if(NOT EXISTS "${DUCKDB_SOURCE_PATH}") - message(FATAL_ERROR - "DuckDB source path does not exist: ${DUCKDB_SOURCE_PATH}\n" - "Please set DUCKDB_SOURCE_PATH to the correct location.") - endif() - - if(NOT EXISTS "${DUCKDB_SOURCE_PATH}/CMakeLists.txt") - message(FATAL_ERROR - "DuckDB source path does not contain CMakeLists.txt: ${DUCKDB_SOURCE_PATH}\n" - "Please ensure this points to the root of DuckDB source tree.") - endif() + if(NOT EXISTS "${DUCKDB_SOURCE_PATH}") + message( + FATAL_ERROR "DuckDB source path does not exist: ${DUCKDB_SOURCE_PATH}\n" + "Please set DUCKDB_SOURCE_PATH to the correct location.") + endif() + + if(NOT EXISTS "${DUCKDB_SOURCE_PATH}/CMakeLists.txt") + message( + FATAL_ERROR + "DuckDB source path does not contain CMakeLists.txt: ${DUCKDB_SOURCE_PATH}\n" + "Please ensure this points to the root of DuckDB source tree.") + endif() endfunction() function(_duckdb_create_interface_target target_name) - add_library(${target_name} INTERFACE) - - # Include directories to deal with leaking 3rd party headers in duckdb headers - # See https://github.com/duckdblabs/duckdb-internal/issues/5084 - target_include_directories(${target_name} INTERFACE - # Main DuckDB headers - $ - # Third-party headers that leak through DuckDB's API - $ - $ - $ - $ - $ - $ + add_library(${target_name} INTERFACE) + + # Include directories to deal with leaking 3rd party 
headers in duckdb headers + # See https://github.com/duckdblabs/duckdb-internal/issues/5084 + target_include_directories( + ${target_name} + INTERFACE + # Main DuckDB headers + $ + # Third-party headers that leak through DuckDB's API + $ + $ + $ + $ + $ + $) + + # Compile definitions based on configuration + target_compile_definitions( + ${target_name} INTERFACE $<$:DUCKDB_FORCE_ASSERT> + $<$:DUCKDB_DEBUG_MODE>) + + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + target_compile_options( + ${target_name} + INTERFACE + /wd4244 # suppress Conversion from 'type1' to 'type2', possible loss of + # data + /wd4267 # suppress Conversion from ‘size_t’ to ‘type’, possible loss of + # data + /wd4200 # suppress Nonstandard extension used: zero-sized array in + # struct/union + /wd26451 + /wd26495 # suppress Code Analysis + /D_CRT_SECURE_NO_WARNINGS # suppress warnings about unsafe functions + /D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR # see + # https://github.com/duckdblabs/duckdb-internal/issues/5151 + /utf-8 # treat source files as UTF-8 encoded ) - - # Compile definitions based on configuration - target_compile_definitions(${target_name} INTERFACE - $<$:DUCKDB_FORCE_ASSERT> - $<$:DUCKDB_DEBUG_MODE> - ) - - if(CMAKE_SYSTEM_NAME STREQUAL "Windows") - target_compile_options(${target_name} INTERFACE - /wd4244 # suppress Conversion from 'type1' to 'type2', possible loss of data - /wd4267 # suppress Conversion from ‘size_t’ to ‘type’, possible loss of data - /wd4200 # suppress Nonstandard extension used: zero-sized array in struct/union - /wd26451 /wd26495 # suppress Code Analysis - /D_CRT_SECURE_NO_WARNINGS # suppress warnings about unsafe functions - /D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR # see https://github.com/duckdblabs/duckdb-internal/issues/5151 - /utf-8 # treat source files as UTF-8 encoded - ) - elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") - target_compile_options(${target_name} INTERFACE - -stdlib=libc++ # for libc++ in favor of older libstdc++ + elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + target_compile_options( + ${target_name} + INTERFACE -stdlib=libc++ # for libc++ in favor of older libstdc++ -mmacosx-version-min=10.7 # minimum osx version compatibility - ) - endif() + ) + endif() - # Link to the DuckDB static library - target_link_libraries(${target_name} INTERFACE duckdb_static) + # Link to the DuckDB static library + target_link_libraries(${target_name} INTERFACE duckdb_static) - # Enable position independent code for shared library builds - set_target_properties(${target_name} PROPERTIES - INTERFACE_POSITION_INDEPENDENT_CODE ON - ) + # Enable position independent code for shared library builds + set_target_properties(${target_name} + PROPERTIES INTERFACE_POSITION_INDEPENDENT_CODE ON) endfunction() function(_duckdb_print_summary) - message(STATUS "DuckDB Configuration:") - message(STATUS " Source: ${DUCKDB_SOURCE_PATH}") - message(STATUS " Build Type: ${CMAKE_BUILD_TYPE}") - message(STATUS " Native Arch: ${NATIVE_ARCH}") - message(STATUS " Unity Build Disabled: ${DISABLE_UNITY}") - - if(CORE_EXTENSIONS) - message(STATUS " Extensions: ${CORE_EXTENSIONS}") - endif() - - set(debug_opts) - if(FORCE_ASSERT) - list(APPEND debug_opts "FORCE_ASSERT") - endif() - if(DEBUG_STACKTRACE) - list(APPEND debug_opts "DEBUG_STACKTRACE") - endif() - - if(debug_opts) - message(STATUS " Debug Options: ${debug_opts}") - endif() + message(STATUS "DuckDB Configuration:") + message(STATUS " Source: ${DUCKDB_SOURCE_PATH}") + message(STATUS " Build Type: ${CMAKE_BUILD_TYPE}") + message(STATUS " Native Arch: 
${NATIVE_ARCH}") + message(STATUS " Unity Build Disabled: ${DISABLE_UNITY}") + + if(CORE_EXTENSIONS) + message(STATUS " Extensions: ${CORE_EXTENSIONS}") + endif() + + set(debug_opts) + if(FORCE_ASSERT) + list(APPEND debug_opts "FORCE_ASSERT") + endif() + if(DEBUG_STACKTRACE) + list(APPEND debug_opts "DEBUG_STACKTRACE") + endif() + + if(debug_opts) + message(STATUS " Debug Options: ${debug_opts}") + endif() endfunction() # ════════════════════════════════════════════════════════════════════════════════ @@ -206,15 +247,15 @@ endfunction() # ════════════════════════════════════════════════════════════════════════════════ function(duckdb_add_library target_name) - _duckdb_validate_source_path() - _duckdb_validate_jemalloc_config() - _duckdb_print_summary() + _duckdb_validate_source_path() + _duckdb_validate_jemalloc_config() + _duckdb_print_summary() - # Add DuckDB subdirectory - it will use our variables - add_subdirectory("${DUCKDB_SOURCE_PATH}" duckdb EXCLUDE_FROM_ALL) + # Add DuckDB subdirectory - it will use our variables + add_subdirectory("${DUCKDB_SOURCE_PATH}" duckdb EXCLUDE_FROM_ALL) - # Create clean interface target - _duckdb_create_interface_target(${target_name}) + # Create clean interface target + _duckdb_create_interface_target(${target_name}) endfunction() # ════════════════════════════════════════════════════════════════════════════════ @@ -222,16 +263,20 @@ endfunction() # ════════════════════════════════════════════════════════════════════════════════ function(duckdb_configure_for_debug) - # Only set if not already defined (allows override from command line) - if(NOT DEFINED FORCE_ASSERT) - set(FORCE_ASSERT ON PARENT_SCOPE) - endif() - if(NOT DEFINED DEBUG_STACKTRACE) - set(DEBUG_STACKTRACE ON PARENT_SCOPE) - endif() - message(STATUS "DuckDB: Configured for debug build") + # Only set if not already defined (allows override from command line) + if(NOT DEFINED FORCE_ASSERT) + set(FORCE_ASSERT + ON + PARENT_SCOPE) + endif() + if(NOT DEFINED DEBUG_STACKTRACE) + set(DEBUG_STACKTRACE + ON + PARENT_SCOPE) + endif() + message(STATUS "DuckDB: Configured for debug build") endfunction() function(duckdb_configure_for_release) - message(STATUS "DuckDB: Configured for release build") -endfunction() \ No newline at end of file + message(STATUS "DuckDB: Configured for release build") +endfunction() diff --git a/duckdb/__init__.py b/duckdb/__init__.py index b5e994fa..e1a4aa9a 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -1,475 +1,381 @@ -# Modules -import duckdb.functional as functional -import duckdb.typing as typing -from _duckdb import __version__ as duckdb_version -from importlib.metadata import version +# ruff: noqa: F401 +"""The DuckDB Python Package. -# duckdb.__version__ returns the version of the distribution package, i.e. the pypi version -__version__ = version("duckdb") +This module re-exports the DuckDB C++ extension (`_duckdb`) and provides DuckDB's public API. -# version() is a more human friendly formatted version string of both the distribution package and the bundled duckdb -def version(): - return f"{__version__} (with duckdb {duckdb_version})" +Note: +- Some symbols exposed here are implementation details of DuckDB's C++ engine. +- They are kept for backwards compatibility but are not considered stable API. +- Future versions may move them into submodules with deprecation warnings. 
+""" -_exported_symbols = ['__version__', 'version'] - -_exported_symbols.extend([ - "typing", - "functional" -]) - -class DBAPITypeObject: - def __init__(self, types: list[typing.DuckDBPyType]) -> None: - self.types = types - - def __eq__(self, other): - if isinstance(other, typing.DuckDBPyType): - return other in self.types - return False - - def __repr__(self): - return f"" - -# Define the standard DBAPI sentinels -STRING = DBAPITypeObject([typing.VARCHAR]) -NUMBER = DBAPITypeObject([ - typing.TINYINT, - typing.UTINYINT, - typing.SMALLINT, - typing.USMALLINT, - typing.INTEGER, - typing.UINTEGER, - typing.BIGINT, - typing.UBIGINT, - typing.HUGEINT, - typing.UHUGEINT, - typing.DuckDBPyType("BIGNUM"), - typing.DuckDBPyType("DECIMAL"), - typing.FLOAT, - typing.DOUBLE -]) -DATETIME = DBAPITypeObject([ - typing.DATE, - typing.TIME, - typing.TIME_TZ, - typing.TIMESTAMP, - typing.TIMESTAMP_TZ, - typing.TIMESTAMP_NS, - typing.TIMESTAMP_MS, - typing.TIMESTAMP_S -]) -BINARY = DBAPITypeObject([typing.BLOB]) -ROWID = None - -# Classes from _duckdb import ( - DuckDBPyRelation, - DuckDBPyConnection, - Statement, - ExplainType, - StatementType, - ExpectedResultType, - CSVLineTerminator, - PythonExceptionHandling, - RenderMode, - Expression, - ConstantExpression, + BinderException, + CaseExpression, + CatalogException, + CoalesceOperator, ColumnExpression, + ConnectionException, + ConstantExpression, + ConstraintException, + ConversionException, + CSVLineTerminator, + DatabaseError, + DataError, DefaultExpression, - CoalesceOperator, - LambdaExpression, - StarExpression, - FunctionExpression, - CaseExpression, - SQLExpression -) -_exported_symbols.extend([ - "DuckDBPyRelation", - "DuckDBPyConnection", - "ExplainType", - "PythonExceptionHandling", - "Expression", - "ConstantExpression", - "ColumnExpression", - "DefaultExpression", - "CoalesceOperator", - "LambdaExpression", - "StarExpression", - "FunctionExpression", - "CaseExpression", - "SQLExpression" -]) - -# These are overloaded twice, we define them inside of C++ so pybind can deal with it -_exported_symbols.extend([ - 'df', - 'arrow' -]) -from _duckdb import ( - df, - arrow -) - -# NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_methods.py. -# Do not edit this section manually, your changes will be overwritten! 
- -# START OF CONNECTION WRAPPER - -from _duckdb import ( - cursor, - register_filesystem, - unregister_filesystem, - list_filesystems, - filesystem_is_registered, - create_function, - remove_function, - sqltype, - dtype, - type, - array_type, - list_type, - union_type, - string_type, - enum_type, - decimal_type, - struct_type, - row_type, - map_type, - duplicate, - execute, - executemany, - close, - interrupt, - query_progress, - fetchone, - fetchmany, - fetchall, - fetchnumpy, - fetchdf, - fetch_df, - df, - fetch_df_chunk, - pl, - fetch_arrow_table, - arrow, - fetch_record_batch, - torch, - tf, - begin, - commit, - rollback, - checkpoint, - append, - register, - unregister, - table, - view, - values, - table_function, - read_json, - extract_statements, - sql, - query, - from_query, - read_csv, - from_csv_auto, - from_df, - from_arrow, - from_parquet, - read_parquet, - from_parquet, - read_parquet, - get_table_names, - install_extension, - load_extension, - project, - distinct, - write_csv, - aggregate, - alias, - filter, - limit, - order, - query_df, - description, - rowcount, -) - -_exported_symbols.extend([ - 'cursor', - 'register_filesystem', - 'unregister_filesystem', - 'list_filesystems', - 'filesystem_is_registered', - 'create_function', - 'remove_function', - 'sqltype', - 'dtype', - 'type', - 'array_type', - 'list_type', - 'union_type', - 'string_type', - 'enum_type', - 'decimal_type', - 'struct_type', - 'row_type', - 'map_type', - 'duplicate', - 'execute', - 'executemany', - 'close', - 'interrupt', - 'query_progress', - 'fetchone', - 'fetchmany', - 'fetchall', - 'fetchnumpy', - 'fetchdf', - 'fetch_df', - 'df', - 'fetch_df_chunk', - 'pl', - 'fetch_arrow_table', - 'arrow', - 'fetch_record_batch', - 'torch', - 'tf', - 'begin', - 'commit', - 'rollback', - 'checkpoint', - 'append', - 'register', - 'unregister', - 'table', - 'view', - 'values', - 'table_function', - 'read_json', - 'extract_statements', - 'sql', - 'query', - 'from_query', - 'read_csv', - 'from_csv_auto', - 'from_df', - 'from_arrow', - 'from_parquet', - 'read_parquet', - 'from_parquet', - 'read_parquet', - 'get_table_names', - 'install_extension', - 'load_extension', - 'project', - 'distinct', - 'write_csv', - 'aggregate', - 'alias', - 'filter', - 'limit', - 'order', - 'query_df', - 'description', - 'rowcount', -]) - -# END OF CONNECTION WRAPPER - -# Enums -from _duckdb import ( - ANALYZE, - DEFAULT, - RETURN_NULL, - STANDARD, - COLUMNS, - ROWS -) -_exported_symbols.extend([ - "ANALYZE", - "DEFAULT", - "RETURN_NULL", - "STANDARD" -]) - - -# read-only properties -from _duckdb import ( - __standard_vector_size__, - __interactive__, - __jupyter__, - __formatted_python_version__, - apilevel, - comment, - identifier, - keyword, - numeric_const, - operator, - paramstyle, - string_const, - threadsafety, - token_type, - tokenize -) -_exported_symbols.extend([ - "__standard_vector_size__", - "__interactive__", - "__jupyter__", - "__formatted_python_version__", - "apilevel", - "comment", - "identifier", - "keyword", - "numeric_const", - "operator", - "paramstyle", - "string_const", - "threadsafety", - "token_type", - "tokenize" -]) - - -from _duckdb import ( - connect, - default_connection, - set_default_connection, -) - -_exported_symbols.extend([ - "connect", - "default_connection", - "set_default_connection", -]) - -# Exceptions -from _duckdb import ( + DependencyException, + DuckDBPyConnection, + DuckDBPyRelation, Error, - DataError, - ConversionException, - OutOfRangeException, - TypeMismatchException, + ExpectedResultType, + 
ExplainType, + Expression, FatalException, + FunctionExpression, + HTTPException, IntegrityError, - ConstraintException, InternalError, InternalException, InterruptException, - NotSupportedError, + InvalidInputException, + InvalidTypeException, + IOException, + LambdaExpression, NotImplementedException, + NotSupportedError, OperationalError, - ConnectionException, - IOException, - HTTPException, OutOfMemoryException, - SerializationException, - TransactionException, + OutOfRangeException, + ParserException, PermissionException, ProgrammingError, - BinderException, - CatalogException, - InvalidInputException, - InvalidTypeException, - ParserException, - SyntaxException, + PythonExceptionHandling, + RenderMode, SequenceException, - Warning + SerializationException, + SQLExpression, + StarExpression, + Statement, + StatementType, + SyntaxException, + TransactionException, + TypeMismatchException, + Warning, + __formatted_python_version__, + __git_revision__, + __interactive__, + __jupyter__, + __standard_vector_size__, + _clean_default_connection, + aggregate, + alias, + apilevel, + append, + array_type, + arrow, + begin, + checkpoint, + close, + commit, + connect, + create_function, + cursor, + decimal_type, + default_connection, + description, + df, + distinct, + dtype, + duplicate, + enum_type, + execute, + executemany, + extract_statements, + fetch_arrow_table, + fetch_df, + fetch_df_chunk, + fetch_record_batch, + fetchall, + fetchdf, + fetchmany, + fetchnumpy, + fetchone, + filesystem_is_registered, + filter, + from_arrow, + from_csv_auto, + from_df, + from_parquet, + from_query, + get_table_names, + install_extension, + interrupt, + limit, + list_filesystems, + list_type, + load_extension, + map_type, + order, + paramstyle, + pl, + project, + query, + query_df, + query_progress, + read_csv, + read_json, + read_parquet, + register, + register_filesystem, + remove_function, + rollback, + row_type, + rowcount, + set_default_connection, + sql, + sqltype, + string_type, + struct_type, + table, + table_function, + tf, + threadsafety, + token_type, + tokenize, + torch, + type, + union_type, + unregister, + unregister_filesystem, + values, + view, + write_csv, ) -_exported_symbols.extend([ - "Error", - "DataError", - "ConversionException", - "OutOfRangeException", - "TypeMismatchException", - "FatalException", - "IntegrityError", - "ConstraintException", - "InternalError", - "InternalException", - "InterruptException", - "NotSupportedError", - "NotImplementedException", - "OperationalError", - "ConnectionException", - "IOException", - "HTTPException", - "OutOfMemoryException", - "SerializationException", - "TransactionException", - "PermissionException", - "ProgrammingError", - "BinderException", - "CatalogException", - "InvalidInputException", - "InvalidTypeException", - "ParserException", - "SyntaxException", - "SequenceException", - "Warning" -]) -# Value +from duckdb._dbapi_type_object import ( + BINARY, + DATETIME, + NUMBER, + ROWID, + STRING, + DBAPITypeObject, +) +from duckdb._version import ( + __duckdb_version__, + __version__, + version, +) from duckdb.value.constant import ( - Value, - NullValue, - BooleanValue, - UnsignedBinaryValue, - UnsignedShortValue, - UnsignedIntegerValue, - UnsignedLongValue, BinaryValue, - ShortValue, - IntegerValue, - LongValue, - HugeIntegerValue, - FloatValue, - DoubleValue, - DecimalValue, - StringValue, - UUIDValue, BitValue, BlobValue, + BooleanValue, DateValue, + DecimalValue, + DoubleValue, + FloatValue, + HugeIntegerValue, + IntegerValue, 
IntervalValue, - TimestampValue, - TimestampSecondValue, + ListValue, + LongValue, + MapValue, + NullValue, + ShortValue, + StringValue, + StructValue, TimestampMilisecondValue, TimestampNanosecondValue, + TimestampSecondValue, TimestampTimeZoneValue, - TimeValue, + TimestampValue, TimeTimeZoneValue, + TimeValue, + UnionType, + UnsignedBinaryValue, + UnsignedHugeIntegerValue, + UnsignedIntegerValue, + UnsignedLongValue, + UnsignedShortValue, + UUIDValue, + Value, ) -_exported_symbols.extend([ - "Value", - "NullValue", - "BooleanValue", - "UnsignedBinaryValue", - "UnsignedShortValue", - "UnsignedIntegerValue", - "UnsignedLongValue", +__all__: list[str] = [ "BinaryValue", - "ShortValue", - "IntegerValue", - "LongValue", - "HugeIntegerValue", - "FloatValue", - "DoubleValue", - "DecimalValue", - "StringValue", - "UUIDValue", + "BinderException", "BitValue", "BlobValue", + "BooleanValue", + "CSVLineTerminator", + "CaseExpression", + "CatalogException", + "CoalesceOperator", + "ColumnExpression", + "ConnectionException", + "ConstantExpression", + "ConstraintException", + "ConversionException", + "DataError", + "DatabaseError", "DateValue", + "DecimalValue", + "DefaultExpression", + "DependencyException", + "DoubleValue", + "DuckDBPyConnection", + "DuckDBPyRelation", + "Error", + "ExpectedResultType", + "ExplainType", + "Expression", + "FatalException", + "FloatValue", + "FunctionExpression", + "HTTPException", + "HugeIntegerValue", + "IOException", + "IntegerValue", + "IntegrityError", + "InternalError", + "InternalException", + "InterruptException", "IntervalValue", - "TimestampValue", - "TimestampSecondValue", + "InvalidInputException", + "InvalidTypeException", + "LambdaExpression", + "ListValue", + "LongValue", + "MapValue", + "NotImplementedException", + "NotSupportedError", + "NullValue", + "OperationalError", + "OutOfMemoryException", + "OutOfRangeException", + "ParserException", + "PermissionException", + "ProgrammingError", + "PythonExceptionHandling", + "RenderMode", + "SQLExpression", + "SequenceException", + "SerializationException", + "ShortValue", + "StarExpression", + "Statement", + "StatementType", + "StringValue", + "StructValue", + "SyntaxException", + "TimeTimeZoneValue", + "TimeValue", "TimestampMilisecondValue", "TimestampNanosecondValue", + "TimestampSecondValue", "TimestampTimeZoneValue", - "TimeValue", - "TimeTimeZoneValue", -]) - -__all__ = _exported_symbols + "TimestampValue", + "TransactionException", + "TypeMismatchException", + "UUIDValue", + "UnionType", + "UnsignedBinaryValue", + "UnsignedHugeIntegerValue", + "UnsignedIntegerValue", + "UnsignedLongValue", + "UnsignedShortValue", + "Value", + "Warning", + "__formatted_python_version__", + "__git_revision__", + "__interactive__", + "__jupyter__", + "__standard_vector_size__", + "__version__", + "_clean_default_connection", + "aggregate", + "alias", + "apilevel", + "append", + "array_type", + "arrow", + "begin", + "checkpoint", + "close", + "commit", + "connect", + "create_function", + "cursor", + "decimal_type", + "default_connection", + "description", + "df", + "distinct", + "dtype", + "duplicate", + "enum_type", + "execute", + "executemany", + "extract_statements", + "fetch_arrow_table", + "fetch_df", + "fetch_df_chunk", + "fetch_record_batch", + "fetchall", + "fetchdf", + "fetchmany", + "fetchnumpy", + "fetchone", + "filesystem_is_registered", + "filter", + "from_arrow", + "from_csv_auto", + "from_df", + "from_parquet", + "from_query", + "get_table_names", + "install_extension", + "interrupt", + "limit", + 
"list_filesystems", + "list_type", + "load_extension", + "map_type", + "order", + "paramstyle", + "paramstyle", + "pl", + "project", + "query", + "query_df", + "query_progress", + "read_csv", + "read_json", + "read_parquet", + "register", + "register_filesystem", + "remove_function", + "rollback", + "row_type", + "rowcount", + "set_default_connection", + "sql", + "sqltype", + "string_type", + "struct_type", + "table", + "table_function", + "tf", + "threadsafety", + "threadsafety", + "token_type", + "tokenize", + "torch", + "type", + "union_type", + "unregister", + "unregister_filesystem", + "values", + "view", + "write_csv", +] diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi deleted file mode 100644 index 8f27e5e3..00000000 --- a/duckdb/__init__.pyi +++ /dev/null @@ -1,713 +0,0 @@ -# to regenerate this from scratch, run scripts/regenerate_python_stubs.sh . -# be warned - currently there are still tweaks needed after this file is -# generated. These should be annotated with a comment like -# # stubgen override -# to help the sanity of maintainers. - -import duckdb.typing as typing -import duckdb.functional as functional -from duckdb.typing import DuckDBPyType -from duckdb.functional import FunctionNullHandling, PythonUDFType -from duckdb.value.constant import ( - Value, - NullValue, - BooleanValue, - UnsignedBinaryValue, - UnsignedShortValue, - UnsignedIntegerValue, - UnsignedLongValue, - BinaryValue, - ShortValue, - IntegerValue, - LongValue, - HugeIntegerValue, - FloatValue, - DoubleValue, - DecimalValue, - StringValue, - UUIDValue, - BitValue, - BlobValue, - DateValue, - IntervalValue, - TimestampValue, - TimestampSecondValue, - TimestampMilisecondValue, - TimestampNanosecondValue, - TimestampTimeZoneValue, - TimeValue, - TimeTimeZoneValue, -) - -# We also run this in python3.7, where this is needed -from typing_extensions import Literal -# stubgen override - missing import of Set -from typing import Any, ClassVar, Set, Optional, Callable -from io import StringIO, TextIOBase -from pathlib import Path - -from typing import overload, Dict, List, Union, Tuple -import pandas -# stubgen override - unfortunately we need this for version checks -import sys -import fsspec -import pyarrow.lib -import polars -# stubgen override - This should probably not be exposed -apilevel: str -comment: token_type -identifier: token_type -keyword: token_type -numeric_const: token_type -operator: token_type -paramstyle: str -string_const: token_type -threadsafety: int -__standard_vector_size__: int -STANDARD: ExplainType -ANALYZE: ExplainType -DEFAULT: PythonExceptionHandling -RETURN_NULL: PythonExceptionHandling -ROWS: RenderMode -COLUMNS: RenderMode - -__version__: str - -__interactive__: bool -__jupyter__: bool -__formatted_python_version__: str - -class BinderException(ProgrammingError): ... - -class CatalogException(ProgrammingError): ... - -class ConnectionException(OperationalError): ... - -class ConstraintException(IntegrityError): ... - -class ConversionException(DataError): ... - -class DataError(Error): ... - -class ExplainType: - STANDARD: ExplainType - ANALYZE: ExplainType - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, ExplainType]: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - -class RenderMode: - ROWS: RenderMode - COLUMNS: RenderMode - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, RenderMode]: ... 
- @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - -class PythonExceptionHandling: - DEFAULT: PythonExceptionHandling - RETURN_NULL: PythonExceptionHandling - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, PythonExceptionHandling]: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - -class CSVLineTerminator: - LINE_FEED: CSVLineTerminator - CARRIAGE_RETURN_LINE_FEED: CSVLineTerminator - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, CSVLineTerminator]: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - -class ExpectedResultType: - QUERY_RESULT: ExpectedResultType - CHANGED_ROWS: ExpectedResultType - NOTHING: ExpectedResultType - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, ExpectedResultType]: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - -class StatementType: - INVALID: StatementType - SELECT: StatementType - INSERT: StatementType - UPDATE: StatementType - CREATE: StatementType - DELETE: StatementType - PREPARE: StatementType - EXECUTE: StatementType - ALTER: StatementType - TRANSACTION: StatementType - COPY: StatementType - ANALYZE: StatementType - VARIABLE_SET: StatementType - CREATE_FUNC: StatementType - EXPLAIN: StatementType - DROP: StatementType - EXPORT: StatementType - PRAGMA: StatementType - VACUUM: StatementType - CALL: StatementType - SET: StatementType - LOAD: StatementType - RELATION: StatementType - EXTENSION: StatementType - LOGICAL_PLAN: StatementType - ATTACH: StatementType - DETACH: StatementType - MULTI: StatementType - COPY_DATABASE: StatementType - MERGE_INTO: StatementType - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, StatementType]: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - -class Statement: - def __init__(self, *args, **kwargs) -> None: ... - @property - def query(self) -> str: ... - @property - def named_parameters(self) -> Set[str]: ... - @property - def expected_result_type(self) -> List[ExpectedResultType]: ... - @property - def type(self) -> StatementType: ... - -class Expression: - def __init__(self, *args, **kwargs) -> None: ... - def __neg__(self) -> "Expression": ... - - def __add__(self, expr: "Expression") -> "Expression": ... - def __radd__(self, expr: "Expression") -> "Expression": ... - - def __sub__(self, expr: "Expression") -> "Expression": ... - def __rsub__(self, expr: "Expression") -> "Expression": ... - - def __mul__(self, expr: "Expression") -> "Expression": ... - def __rmul__(self, expr: "Expression") -> "Expression": ... - - def __div__(self, expr: "Expression") -> "Expression": ... - def __rdiv__(self, expr: "Expression") -> "Expression": ... - - def __truediv__(self, expr: "Expression") -> "Expression": ... - def __rtruediv__(self, expr: "Expression") -> "Expression": ... - - def __floordiv__(self, expr: "Expression") -> "Expression": ... - def __rfloordiv__(self, expr: "Expression") -> "Expression": ... - - def __mod__(self, expr: "Expression") -> "Expression": ... - def __rmod__(self, expr: "Expression") -> "Expression": ... - - def __pow__(self, expr: "Expression") -> "Expression": ... - def __rpow__(self, expr: "Expression") -> "Expression": ... 
- - def __and__(self, expr: "Expression") -> "Expression": ... - def __rand__(self, expr: "Expression") -> "Expression": ... - def __or__(self, expr: "Expression") -> "Expression": ... - def __ror__(self, expr: "Expression") -> "Expression": ... - def __invert__(self) -> "Expression": ... - - def __eq__(# type: ignore[override] - self, expr: "Expression") -> "Expression": ... - def __ne__(# type: ignore[override] - self, expr: "Expression") -> "Expression": ... - def __gt__(self, expr: "Expression") -> "Expression": ... - def __ge__(self, expr: "Expression") -> "Expression": ... - def __lt__(self, expr: "Expression") -> "Expression": ... - def __le__(self, expr: "Expression") -> "Expression": ... - - def show(self) -> None: ... - def __repr__(self) -> str: ... - def get_name(self) -> str: ... - def alias(self, alias: str) -> "Expression": ... - def when(self, condition: "Expression", value: "Expression") -> "Expression": ... - def otherwise(self, value: "Expression") -> "Expression": ... - def cast(self, type: DuckDBPyType) -> "Expression": ... - def between(self, lower: "Expression", upper: "Expression") -> "Expression": ... - def collate(self, collation: str) -> "Expression": ... - def asc(self) -> "Expression": ... - def desc(self) -> "Expression": ... - def nulls_first(self) -> "Expression": ... - def nulls_last(self) -> "Expression": ... - def isnull(self) -> "Expression": ... - def isnotnull(self) -> "Expression": ... - def isin(self, *cols: "Expression") -> "Expression": ... - def isnotin(self, *cols: "Expression") -> "Expression": ... - -def StarExpression(exclude: Optional[List[str]] = None) -> Expression: ... -def ColumnExpression(column: str) -> Expression: ... -def DefaultExpression() -> Expression: ... -def ConstantExpression(val: Any) -> Expression: ... -def CaseExpression(condition: Expression, value: Expression) -> Expression: ... -def FunctionExpression(function: str, *cols: Expression) -> Expression: ... -def CoalesceOperator(*cols: Expression) -> Expression: ... -def LambdaExpression(lhs: Union[Tuple["Expression", ...], str], rhs: Expression) -> Expression: ... -def SQLExpression(expr: str) -> Expression: ... - -class DuckDBPyConnection: - def __init__(self, *args, **kwargs) -> None: ... - def __enter__(self) -> DuckDBPyConnection: ... - def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: ... - def __del__(self) -> None: ... - @property - def description(self) -> Optional[List[Any]]: ... - @property - def rowcount(self) -> int: ... - - # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_stubs.py. - # Do not edit this section manually, your changes will be overwritten! - - # START OF CONNECTION METHODS - def cursor(self) -> DuckDBPyConnection: ... - def register_filesystem(self, filesystem: fsspec.AbstractFileSystem) -> None: ... - def unregister_filesystem(self, name: str) -> None: ... - def list_filesystems(self) -> list: ... - def filesystem_is_registered(self, name: str) -> bool: ... - def create_function(self, name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False) -> DuckDBPyConnection: ... - def remove_function(self, name: str) -> DuckDBPyConnection: ... 
- def sqltype(self, type_str: str) -> DuckDBPyType: ... - def dtype(self, type_str: str) -> DuckDBPyType: ... - def type(self, type_str: str) -> DuckDBPyType: ... - def array_type(self, type: DuckDBPyType, size: int) -> DuckDBPyType: ... - def list_type(self, type: DuckDBPyType) -> DuckDBPyType: ... - def union_type(self, members: DuckDBPyType) -> DuckDBPyType: ... - def string_type(self, collation: str = "") -> DuckDBPyType: ... - def enum_type(self, name: str, type: DuckDBPyType, values: List[Any]) -> DuckDBPyType: ... - def decimal_type(self, width: int, scale: int) -> DuckDBPyType: ... - def struct_type(self, fields: Union[Dict[str, DuckDBPyType], List[str]]) -> DuckDBPyType: ... - def row_type(self, fields: Union[Dict[str, DuckDBPyType], List[str]]) -> DuckDBPyType: ... - def map_type(self, key: DuckDBPyType, value: DuckDBPyType) -> DuckDBPyType: ... - def duplicate(self) -> DuckDBPyConnection: ... - def execute(self, query: object, parameters: object = None) -> DuckDBPyConnection: ... - def executemany(self, query: object, parameters: object = None) -> DuckDBPyConnection: ... - def close(self) -> None: ... - def interrupt(self) -> None: ... - def query_progress(self) -> float: ... - def fetchone(self) -> Optional[tuple]: ... - def fetchmany(self, size: int = 1) -> List[Any]: ... - def fetchall(self) -> List[Any]: ... - def fetchnumpy(self) -> dict: ... - def fetchdf(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def fetch_df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def fetch_df_chunk(self, vectors_per_chunk: int = 1, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def pl(self, rows_per_batch: int = 1000000, *, lazy: bool = False) -> polars.DataFrame: ... - def fetch_arrow_table(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... - def fetch_record_batch(self, rows_per_batch: int = 1000000) -> pyarrow.lib.RecordBatchReader: ... - def arrow(self, rows_per_batch: int = 1000000) -> pyarrow.lib.RecordBatchReader: ... - def torch(self) -> dict: ... - def tf(self) -> dict: ... - def begin(self) -> DuckDBPyConnection: ... - def commit(self) -> DuckDBPyConnection: ... - def rollback(self) -> DuckDBPyConnection: ... - def checkpoint(self) -> DuckDBPyConnection: ... - def append(self, table_name: str, df: pandas.DataFrame, *, by_name: bool = False) -> DuckDBPyConnection: ... - def register(self, view_name: str, python_object: object) -> DuckDBPyConnection: ... - def unregister(self, view_name: str) -> DuckDBPyConnection: ... - def table(self, table_name: str) -> DuckDBPyRelation: ... - def view(self, view_name: str) -> DuckDBPyRelation: ... - def values(self, *args: Union[List[Any],Expression, Tuple[Expression]]) -> DuckDBPyRelation: ... - def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ... 
- def read_json(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, compression: Optional[str] = None, maximum_object_size: Optional[int] = None, ignore_errors: Optional[bool] = None, convert_strings_to_integers: Optional[bool] = None, field_appearance_threshold: Optional[float] = None, map_inference_threshold: Optional[int] = None, maximum_sample_files: Optional[int] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... - def extract_statements(self, query: str) -> List[Statement]: ... - def sql(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... - def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... - def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... - def read_csv(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... 
- def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... - def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... - def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... - def from_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... - def read_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... - def get_table_names(self, query: str, *, qualified: bool = False) -> Set[str]: ... - def install_extension(self, extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None) -> None: ... - def load_extension(self, extension: str) -> None: ... - # END OF CONNECTION METHODS - -class DuckDBPyRelation: - def close(self) -> None: ... - def __getattr__(self, name: str) -> DuckDBPyRelation: ... - def __getitem__(self, name: str) -> DuckDBPyRelation: ... - def __init__(self, *args, **kwargs) -> None: ... - def __contains__(self, name: str) -> bool: ... - def aggregate(self, aggr_expr: str, group_expr: str = ...) -> DuckDBPyRelation: ... - def apply(self, function_name: str, function_aggr: str, group_expr: str = ..., function_parameter: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - - def cume_dist(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... - def dense_rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... - def percent_rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... - def rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... - def rank_dense(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... - def row_number(self, window_spec: str, projected_columns: str = ...) 
-> DuckDBPyRelation: ... - - def lag(self, column: str, window_spec: str, offset: int, default_value: str, ignore_nulls: bool, projected_columns: str = ...) -> DuckDBPyRelation: ... - def lead(self, column: str, window_spec: str, offset: int, default_value: str, ignore_nulls: bool, projected_columns: str = ...) -> DuckDBPyRelation: ... - def nth_value(self, column: str, window_spec: str, offset: int, ignore_nulls: bool = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - - def value_counts(self, column: str, groups: str = ...) -> DuckDBPyRelation: ... - def geomean(self, column: str, groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def first(self, column: str, groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def first_value(self, column: str, window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def last(self, column: str, groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def last_value(self, column: str, window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def mode(self, aggregation_columns: str, group_columns: str = ...) -> DuckDBPyRelation: ... - def n_tile(self, window_spec: str, num_buckets: int, projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile_cont(self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile_disc(self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def sum(self, sum_aggr: str, group_expr: str = ...) -> DuckDBPyRelation: ... - - def any_value(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arg_max(self, arg_column: str, value_column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arg_min(self, arg_column: str, value_column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def avg(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_and(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_or(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_xor(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bitstring_agg(self, column: str, min: Optional[int], max: Optional[int], groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bool_and(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bool_or(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def count(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def favg(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def fsum(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... 
- def histogram(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def max(self, max_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def min(self, min_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def mean(self, mean_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def median(self, median_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def product(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile(self, q: str, quantile_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def std(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev_pop(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev_samp(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def string_agg(self, column: str, sep: str = ..., groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var_pop(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var_samp(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def variance(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def list(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - - def arrow(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... - def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: ... - def create(self, table_name: str) -> None: ... - def create_view(self, view_name: str, replace: bool = ...) -> DuckDBPyRelation: ... - def describe(self) -> DuckDBPyRelation: ... - def df(self, *args, **kwargs) -> pandas.DataFrame: ... - def distinct(self) -> DuckDBPyRelation: ... - def except_(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... - def execute(self, *args, **kwargs) -> DuckDBPyRelation: ... - def explain(self, type: Optional[Literal['standard', 'analyze'] | int] = 'standard') -> str: ... - def fetchall(self) -> List[Any]: ... - def fetchmany(self, size: int = ...) -> List[Any]: ... - def fetchnumpy(self) -> dict: ... - def fetchone(self) -> Optional[tuple]: ... - def fetchdf(self, *args, **kwargs) -> Any: ... - def fetch_arrow_reader(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... - def fetch_arrow_table(self, rows_per_batch: int = ...) -> pyarrow.lib.Table: ... - def filter(self, filter_expr: Union[Expression, str]) -> DuckDBPyRelation: ... - def insert(self, values: List[Any]) -> None: ... 
- def update(self, set: Dict[str, Expression], condition: Optional[Expression] = None) -> None: ... - def insert_into(self, table_name: str) -> None: ... - def intersect(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... - def join(self, other_rel: DuckDBPyRelation, condition: Union[str, Expression], how: str = ...) -> DuckDBPyRelation: ... - def cross(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... - def limit(self, n: int, offset: int = ...) -> DuckDBPyRelation: ... - def map(self, map_function: function, schema: Optional[Dict[str, DuckDBPyType]] = None) -> DuckDBPyRelation: ... - def order(self, order_expr: str) -> DuckDBPyRelation: ... - def sort(self, *cols: Expression) -> DuckDBPyRelation: ... - def project(self, *cols: Union[str, Expression]) -> DuckDBPyRelation: ... - def select(self, *cols: Union[str, Expression]) -> DuckDBPyRelation: ... - def pl(self, rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... - def query(self, virtual_table_name: str, sql_query: str) -> DuckDBPyRelation: ... - def record_batch(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... - def fetch_record_batch(self, rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... - def select_types(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... - def select_dtypes(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... - def set_alias(self, alias: str) -> DuckDBPyRelation: ... - def show(self, max_width: Optional[int] = None, max_rows: Optional[int] = None, max_col_width: Optional[int] = None, null_value: Optional[str] = None, render_mode: Optional[RenderMode] = None) -> None: ... - def sql_query(self) -> str: ... - def to_arrow_table(self, batch_size: int = ...) -> pyarrow.lib.Table: ... - def to_csv( - self, - file_name: str, - sep: Optional[str] = None, - na_rep: Optional[str] = None, - header: Optional[bool] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - quoting: Optional[str | int] = None, - encoding: Optional[str] = None, - compression: Optional[str] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - partition_by: Optional[List[str]] = None - ) -> None: ... - def to_df(self, *args, **kwargs) -> pandas.DataFrame: ... - def to_parquet( - self, - file_name: str, - compression: Optional[str] = None, - field_ids: Optional[dict | str] = None, - row_group_size_bytes: Optional[int | str] = None, - row_group_size: Optional[int] = None, - partition_by: Optional[List[str]] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - append: Optional[bool] = None - ) -> None: ... - def fetch_df_chunk(self, vectors_per_chunk: int = 1, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def to_table(self, table_name: str) -> None: ... - def to_view(self, view_name: str, replace: bool = ...) -> DuckDBPyRelation: ... - def torch(self, connection: DuckDBPyConnection = ...) -> dict: ... - def tf(self, connection: DuckDBPyConnection = ...) -> dict: ... - def union(self, union_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... - def unique(self, unique_aggr: str) -> DuckDBPyRelation: ... 
- def write_csv( - self, - file_name: str, - sep: Optional[str] = None, - na_rep: Optional[str] = None, - header: Optional[bool] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - quoting: Optional[str | int] = None, - encoding: Optional[str] = None, - compression: Optional[str] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - partition_by: Optional[List[str]] = None - ) -> None: ... - def write_parquet( - self, - file_name: str, - compression: Optional[str] = None, - field_ids: Optional[dict | str] = None, - row_group_size_bytes: Optional[int | str] = None, - row_group_size: Optional[int] = None, - partition_by: Optional[List[str]] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - append: Optional[bool] = None - ) -> None: ... - def __len__(self) -> int: ... - @property - def alias(self) -> str: ... - @property - def columns(self) -> List[str]: ... - @property - def dtypes(self) -> List[DuckDBPyType]: ... - @property - def description(self) -> List[Any]: ... - @property - def shape(self) -> tuple[int, int]: ... - @property - def type(self) -> str: ... - @property - def types(self) -> List[DuckDBPyType]: ... - -class Error(Exception): ... - -class FatalException(Error): ... - -class HTTPException(IOException): - status_code: int - body: str - reason: str - headers: Dict[str, str] - -class IOException(OperationalError): ... - -class IntegrityError(Error): ... - -class InternalError(Error): ... - -class InternalException(InternalError): ... - -class InterruptException(Error): ... - -class InvalidInputException(ProgrammingError): ... - -class InvalidTypeException(ProgrammingError): ... - -class NotImplementedException(NotSupportedError): ... - -class NotSupportedError(Error): ... - -class OperationalError(Error): ... - -class OutOfMemoryException(OperationalError): ... - -class OutOfRangeException(DataError): ... - -class ParserException(ProgrammingError): ... - -class PermissionException(Error): ... - -class ProgrammingError(Error): ... - -class SequenceException(Error): ... - -class SerializationException(OperationalError): ... - -class SyntaxException(ProgrammingError): ... - -class TransactionException(OperationalError): ... - -class TypeMismatchException(DataError): ... - -class Warning(Exception): ... - -class token_type: - # stubgen override - these make mypy sad - #__doc__: ClassVar[str] = ... # read-only - #__members__: ClassVar[dict] = ... # read-only - __entries: ClassVar[dict] = ... - comment: ClassVar[token_type] = ... - identifier: ClassVar[token_type] = ... - keyword: ClassVar[token_type] = ... - numeric_const: ClassVar[token_type] = ... - operator: ClassVar[token_type] = ... - string_const: ClassVar[token_type] = ... - def __init__(self, value: int) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... - def __hash__(self) -> int: ... - # stubgen override - pybind only puts index in python >= 3.8: https://github.com/EricCousineau-TRI/pybind11/blob/54430436/include/pybind11/pybind11.h#L1789 - if sys.version_info >= (3, 7): - def __index__(self) -> int: ... - def __int__(self) -> int: ... - def __ne__(self, other: object) -> bool: ... 
- def __setstate__(self, state: int) -> None: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - @property - # stubgen override - this gets removed by stubgen but it shouldn't - def __members__(self) -> object: ... - -def connect(database: Union[str, Path] = ..., read_only: bool = ..., config: dict = ...) -> DuckDBPyConnection: ... -def default_connection() -> DuckDBPyConnection: ... -def set_default_connection(connection: DuckDBPyConnection) -> None: ... -def tokenize(query: str) -> List[Any]: ... - -# NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py. -# Do not edit this section manually, your changes will be overwritten! - -# START OF CONNECTION WRAPPER -def cursor(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def register_filesystem(filesystem: fsspec.AbstractFileSystem, *, connection: DuckDBPyConnection = ...) -> None: ... -def unregister_filesystem(name: str, *, connection: DuckDBPyConnection = ...) -> None: ... -def list_filesystems(*, connection: DuckDBPyConnection = ...) -> list: ... -def filesystem_is_registered(name: str, *, connection: DuckDBPyConnection = ...) -> bool: ... -def create_function(name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def remove_function(name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def sqltype(type_str: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def dtype(type_str: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def type(type_str: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def array_type(type: DuckDBPyType, size: int, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def list_type(type: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def union_type(members: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def string_type(collation: str = "", *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def enum_type(name: str, type: DuckDBPyType, values: List[Any], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def decimal_type(width: int, scale: int, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def struct_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def row_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def map_type(key: DuckDBPyType, value: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def duplicate(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def execute(query: object, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def executemany(query: object, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def close(*, connection: DuckDBPyConnection = ...) -> None: ... -def interrupt(*, connection: DuckDBPyConnection = ...) -> None: ... 
-def query_progress(*, connection: DuckDBPyConnection = ...) -> float: ... -def fetchone(*, connection: DuckDBPyConnection = ...) -> Optional[tuple]: ... -def fetchmany(size: int = 1, *, connection: DuckDBPyConnection = ...) -> List[Any]: ... -def fetchall(*, connection: DuckDBPyConnection = ...) -> List[Any]: ... -def fetchnumpy(*, connection: DuckDBPyConnection = ...) -> dict: ... -def fetchdf(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def fetch_df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def fetch_df_chunk(vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def pl(rows_per_batch: int = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... -def fetch_arrow_table(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... -def fetch_record_batch(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... -def arrow(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... -def torch(*, connection: DuckDBPyConnection = ...) -> dict: ... -def tf(*, connection: DuckDBPyConnection = ...) -> dict: ... -def begin(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def commit(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def rollback(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def checkpoint(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def append(table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def register(view_name: str, python_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def unregister(view_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def table(table_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def view(view_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def values(*args: Union[List[Any],Expression, Tuple[Expression]], connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def table_function(name: str, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_json(path_or_buffer: Union[str, StringIO, TextIOBase], *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, compression: Optional[str] = None, maximum_object_size: Optional[int] = None, ignore_errors: Optional[bool] = None, convert_strings_to_integers: Optional[bool] = None, field_appearance_threshold: Optional[float] = None, map_inference_threshold: Optional[int] = None, maximum_sample_files: Optional[int] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... 
-def extract_statements(query: str, *, connection: DuckDBPyConnection = ...) -> List[Statement]: ... -def sql(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... 
-def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_arrow(arrow_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def get_table_names(query: str, *, qualified: bool = False, connection: DuckDBPyConnection = ...) -> Set[str]: ... -def install_extension(extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None, connection: DuckDBPyConnection = ...) -> None: ... -def load_extension(extension: str, *, connection: DuckDBPyConnection = ...) -> None: ... -def project(df: pandas.DataFrame, *args: str, groups: str = "", connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def distinct(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def write_csv(df: pandas.DataFrame, filename: str, *, sep: Optional[str] = None, na_rep: Optional[str] = None, header: Optional[bool] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, quoting: Optional[str | int] = None, encoding: Optional[str] = None, compression: Optional[str] = None, overwrite: Optional[bool] = None, per_thread_output: Optional[bool] = None, use_tmp_file: Optional[bool] = None, partition_by: Optional[List[str]] = None, write_partition_columns: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> None: ... -def aggregate(df: pandas.DataFrame, aggr_expr: str | List[Expression], group_expr: str = "", *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def alias(df: pandas.DataFrame, alias: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def filter(df: pandas.DataFrame, filter_expr: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def limit(df: pandas.DataFrame, n: int, offset: int = 0, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def order(df: pandas.DataFrame, order_expr: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def query_df(df: pandas.DataFrame, virtual_table_name: str, sql_query: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def description(*, connection: DuckDBPyConnection = ...) -> Optional[List[Any]]: ... -def rowcount(*, connection: DuckDBPyConnection = ...) -> int: ... -# END OF CONNECTION WRAPPER diff --git a/duckdb/_dbapi_type_object.py b/duckdb/_dbapi_type_object.py new file mode 100644 index 00000000..ed73760d --- /dev/null +++ b/duckdb/_dbapi_type_object.py @@ -0,0 +1,231 @@ +"""DuckDB DB API 2.0 Type Objects Module. + +This module provides DB API 2.0 compliant type objects for DuckDB, allowing applications +to check column types returned by queries against standard database API categories. 
+
+Example:
+    >>> import duckdb
+    >>>
+    >>> conn = duckdb.connect()
+    >>> cursor = conn.cursor()
+    >>> cursor.execute("SELECT 'hello' as text_col, 42 as num_col, CURRENT_DATE as date_col")
+    >>>
+    >>> # Check column types using DB API type objects
+    >>> for i, desc in enumerate(cursor.description):
+    >>>     col_name, col_type = desc[0], desc[1]
+    >>>     if col_type == duckdb.STRING:
+    >>>         print(f"{col_name} is a string type")
+    >>>     elif col_type == duckdb.NUMBER:
+    >>>         print(f"{col_name} is a numeric type")
+    >>>     elif col_type == duckdb.DATETIME:
+    >>>         print(f"{col_name} is a date/time type")
+
+See Also:
+    - PEP 249: https://peps.python.org/pep-0249/
+    - DuckDB Type System: https://duckdb.org/docs/sql/data_types/overview
+"""
+
+from duckdb import sqltypes
+
+
+class DBAPITypeObject:
+    """DB API 2.0 type object for categorizing database column types.
+
+    This class implements the type objects defined in PEP 249 (DB API 2.0).
+    It allows checking whether a specific DuckDB type belongs to a broader
+    category like STRING, NUMBER, DATETIME, etc.
+
+    The type object supports equality comparison with DuckDBPyType instances,
+    returning True if the type belongs to this category.
+
+    Args:
+        types: A list of DuckDBPyType instances that belong to this type category.
+
+    Example:
+        >>> string_types = DBAPITypeObject([sqltypes.VARCHAR, sqltypes.CHAR])
+        >>> result = sqltypes.VARCHAR == string_types  # True
+        >>> result = sqltypes.INTEGER == string_types  # False
+
+    Note:
+        This follows the DB API 2.0 specification where type objects are compared
+        using equality operators rather than isinstance() checks.
+    """
+
+    def __init__(self, types: list[sqltypes.DuckDBPyType]) -> None:
+        """Initialize a DB API type object.
+
+        Args:
+            types: List of DuckDB types that belong to this category.
+        """
+        self.types = types
+
+    def __eq__(self, other: object) -> bool:
+        """Check if a DuckDB type belongs to this type category.
+
+        This method implements the DB API 2.0 type checking mechanism.
+        It returns True if the other object is a DuckDBPyType that
+        is contained in this type category.
+
+        Args:
+            other: The object to compare, typically a DuckDBPyType instance.
+
+        Returns:
+            True if other is a DuckDBPyType in this category, False otherwise.
+
+        Example:
+            >>> NUMBER == sqltypes.INTEGER  # True
+            >>> NUMBER == sqltypes.VARCHAR  # False
+        """
+        if isinstance(other, sqltypes.DuckDBPyType):
+            return other in self.types
+        return False
+
+    def __repr__(self) -> str:
+        """Return a string representation of this type object.
+
+        Returns:
+            A string showing the type object and its contained DuckDB types.
+
+        Example:
+            >>> repr(STRING)
+            '<DBAPITypeObject: [VARCHAR]>'
+        """
+        return f"<DBAPITypeObject: {self.types}>"
+
+
+# Define the standard DB API 2.0 type objects for DuckDB
+
+STRING = DBAPITypeObject([sqltypes.VARCHAR])
+"""
+STRING type object for text-based database columns.
+
+This type object represents all string/text types in DuckDB. Currently includes:
+- VARCHAR: Variable-length character strings
+
+Use this to check if a column contains textual data that should be handled
+as Python strings.
+ +DB API 2.0 Reference: + https://peps.python.org/pep-0249/#string + +Example: + >>> cursor.description[0][1] == STRING # Check if first column is text +""" + +NUMBER = DBAPITypeObject( + [ + sqltypes.TINYINT, + sqltypes.UTINYINT, + sqltypes.SMALLINT, + sqltypes.USMALLINT, + sqltypes.INTEGER, + sqltypes.UINTEGER, + sqltypes.BIGINT, + sqltypes.UBIGINT, + sqltypes.HUGEINT, + sqltypes.UHUGEINT, + sqltypes.DuckDBPyType("BIGNUM"), + sqltypes.DuckDBPyType("DECIMAL"), + sqltypes.FLOAT, + sqltypes.DOUBLE, + ] +) +""" +NUMBER type object for numeric database columns. + +This type object represents all numeric types in DuckDB, including: + +Integer Types: +- TINYINT, UTINYINT: 8-bit signed/unsigned integers +- SMALLINT, USMALLINT: 16-bit signed/unsigned integers +- INTEGER, UINTEGER: 32-bit signed/unsigned integers +- BIGINT, UBIGINT: 64-bit signed/unsigned integers +- HUGEINT, UHUGEINT: 128-bit signed/unsigned integers + +Decimal Types: +- BIGNUM: Arbitrary precision integers +- DECIMAL: Fixed-point decimal numbers + +Floating Point Types: +- FLOAT: 32-bit floating point +- DOUBLE: 64-bit floating point + +Use this to check if a column contains numeric data that should be handled +as Python int, float, or Decimal objects. + +DB API 2.0 Reference: + https://peps.python.org/pep-0249/#number + +Example: + >>> cursor.description[1][1] == NUMBER # Check if second column is numeric +""" + +DATETIME = DBAPITypeObject( + [ + sqltypes.DATE, + sqltypes.TIME, + sqltypes.TIME_TZ, + sqltypes.TIMESTAMP, + sqltypes.TIMESTAMP_TZ, + sqltypes.TIMESTAMP_NS, + sqltypes.TIMESTAMP_MS, + sqltypes.TIMESTAMP_S, + ] +) +""" +DATETIME type object for date and time database columns. + +This type object represents all date/time types in DuckDB, including: + +Date Types: +- DATE: Calendar dates (year, month, day) + +Time Types: +- TIME: Time of day without timezone +- TIME_TZ: Time of day with timezone + +Timestamp Types: +- TIMESTAMP: Date and time without timezone (microsecond precision) +- TIMESTAMP_TZ: Date and time with timezone +- TIMESTAMP_NS: Nanosecond precision timestamps +- TIMESTAMP_MS: Millisecond precision timestamps +- TIMESTAMP_S: Second precision timestamps + +Use this to check if a column contains temporal data that should be handled +as Python datetime, date, or time objects. + +DB API 2.0 Reference: + https://peps.python.org/pep-0249/#datetime + +Example: + >>> cursor.description[2][1] == DATETIME # Check if third column is date/time +""" + +BINARY = DBAPITypeObject([sqltypes.BLOB]) +""" +BINARY type object for binary data database columns. + +This type object represents binary data types in DuckDB: +- BLOB: Binary Large Objects for storing arbitrary binary data + +Use this to check if a column contains binary data that should be handled +as Python bytes objects. + +DB API 2.0 Reference: + https://peps.python.org/pep-0249/#binary + +Example: + >>> cursor.description[3][1] == BINARY # Check if fourth column is binary +""" + +ROWID = None +""" +ROWID type object for row identifier columns. + +DB API 2.0 Reference: + https://peps.python.org/pep-0249/#rowid + +Note: + This will always be None for DuckDB connections. Applications should not + rely on ROWID functionality when using DuckDB. 
+""" diff --git a/duckdb/_version.py b/duckdb/_version.py new file mode 100644 index 00000000..165bdef2 --- /dev/null +++ b/duckdb/_version.py @@ -0,0 +1,22 @@ +# ---------------------------------------------------------------------- +# Version API +# +# We provide three symbols: +# - duckdb.__version__: The version of this package +# - duckdb.__duckdb_version__: The version of duckdb that is bundled +# - duckdb.version(): A human-readable version string containing both of the above +# ---------------------------------------------------------------------- +from importlib.metadata import version as _dist_version + +import _duckdb + +__version__: str = _dist_version("duckdb") +"""Version of the DuckDB Python Package.""" + +__duckdb_version__: str = _duckdb.__version__ +"""Version of DuckDB that is bundled.""" + + +def version() -> str: + """Human-friendly formatted version string of both the distribution package and the bundled DuckDB engine.""" + return f"{__version__} (with duckdb {_duckdb.__version__})" diff --git a/duckdb/bytes_io_wrapper.py b/duckdb/bytes_io_wrapper.py index 829b69cd..722c7cb4 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -1,7 +1,5 @@ -from io import StringIO, TextIOBase -from typing import Union +"""StringIO buffer wrapper. -""" BSD 3-Clause License Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team @@ -35,11 +33,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +from io import StringIO, TextIOBase +from typing import Any, Union + class BytesIOWrapper: - # Wrapper that wraps a StringIO buffer and reads bytes from it - # Created for compat with pyarrow read_csv - def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: + """Wrapper that wraps a StringIO buffer and reads bytes from it. + + Created for compat with pyarrow read_csv. + """ + + def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: # noqa: D107 self.buffer = buffer self.encoding = encoding # Because a character can be represented by more than 1 byte, @@ -48,10 +52,10 @@ def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") # overflow to the front of the bytestring the next time reading is performed self.overflow = b"" - def __getattr__(self, attr: str): + def __getattr__(self, attr: str) -> Any: # noqa: D105, ANN401 return getattr(self.buffer, attr) - def read(self, n: Union[int, None] = -1) -> bytes: + def read(self, n: Union[int, None] = -1) -> bytes: # noqa: D102 assert self.buffer is not None bytestring = self.buffer.read(n).encode(self.encoding) # When n=-1/n greater than remaining bytes: Read entire file/rest of file @@ -63,4 +67,3 @@ def read(self, n: Union[int, None] = -1) -> bytes: to_return = combined_bytestring[:n] self.overflow = combined_bytestring[n:] return to_return - diff --git a/duckdb/experimental/__init__.py b/duckdb/experimental/__init__.py index 0ab3305b..1b5ee51b 100644 --- a/duckdb/experimental/__init__.py +++ b/duckdb/experimental/__init__.py @@ -1,2 +1,3 @@ -from . import spark +from . 
import spark  # noqa: D104
+
 __all__ = spark.__all__
diff --git a/duckdb/experimental/spark/__init__.py b/duckdb/experimental/spark/__init__.py
index 66895dcb..f9db73ef 100644
--- a/duckdb/experimental/spark/__init__.py
+++ b/duckdb/experimental/spark/__init__.py
@@ -1,7 +1,6 @@
-from .sql import SparkSession, DataFrame
-from .conf import SparkConf
+from .conf import SparkConf  # noqa: D104
 from .context import SparkContext
-from ._globals import _NoValue
 from .exception import ContributionsAcceptedError
+from .sql import DataFrame, SparkSession

-__all__ = ["SparkSession", "DataFrame", "SparkConf", "SparkContext", "ContributionsAcceptedError"]
+__all__ = ["ContributionsAcceptedError", "DataFrame", "SparkConf", "SparkContext", "SparkSession"]
diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py
index c43287e6..0625a140 100644
--- a/duckdb/experimental/spark/_globals.py
+++ b/duckdb/experimental/spark/_globals.py
@@ -15,8 +15,7 @@
 # limitations under the License.
 #

-"""
-Module defining global singleton classes.
+"""Module defining global singleton classes.

 This module raises a RuntimeError if an attempt to reload it is made. In that
 way the identities of the classes defined here are fixed and will remain so
@@ -38,7 +37,8 @@ def foo(arg=pyducdkb.spark._NoValue):
 # Disallow reloading this module so as to preserve the identities of the
 # classes defined here.
 if "_is_loaded" in globals():
-    raise RuntimeError("Reloading duckdb.experimental.spark._globals is not allowed")
+    msg = "Reloading duckdb.experimental.spark._globals is not allowed"
+    raise RuntimeError(msg)

 _is_loaded = True

@@ -54,23 +54,23 @@ class _NoValueType:
     __instance = None

-    def __new__(cls):
+    def __new__(cls) -> "_NoValueType":
         # ensure that only one instance exists
         if not cls.__instance:
-            cls.__instance = super(_NoValueType, cls).__new__(cls)
+            cls.__instance = super().__new__(cls)
         return cls.__instance

     # Make the _NoValue instance falsey
-    def __nonzero__(self):
+    def __nonzero__(self) -> bool:
         return False

     __bool__ = __nonzero__

     # needed for python 2 to preserve identity through a pickle
-    def __reduce__(self):
+    def __reduce__(self) -> tuple[type, tuple]:
         return (self.__class__, ())

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<no value>"
diff --git a/duckdb/experimental/spark/_typing.py b/duckdb/experimental/spark/_typing.py
index 0c06fed5..1ed78ea8 100644
--- a/duckdb/experimental/spark/_typing.py
+++ b/duckdb/experimental/spark/_typing.py
@@ -16,10 +16,11 @@
 # specific language governing permissions and limitations
 # under the License.

-from typing import Callable, Iterable, Sized, TypeVar, Union
-from typing_extensions import Literal, Protocol
+from collections.abc import Iterable, Sized
+from typing import Callable, TypeVar, Union

-from numpy import int32, int64, float32, float64, ndarray
+from numpy import float32, float64, int32, int64, ndarray
+from typing_extensions import Literal, Protocol, Self

 F = TypeVar("F", bound=Callable)
 T_co = TypeVar("T_co", covariant=True)
@@ -30,17 +31,14 @@
 class SupportsIAdd(Protocol):
-    def __iadd__(self, other: "SupportsIAdd") -> "SupportsIAdd":
-        ...
+    def __iadd__(self, other: "SupportsIAdd") -> Self: ...


 class SupportsOrdering(Protocol):
-    def __lt__(self, other: "SupportsOrdering") -> bool:
-        ...
+    def __lt__(self, other: "SupportsOrdering") -> bool: ...


-class SizedIterable(Protocol, Sized, Iterable[T_co]):
-    ...
+class SizedIterable(Protocol, Sized, Iterable[T_co]): ...
S = TypeVar("S", bound=SupportsOrdering) diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index 11680a9a..974115d6 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -1,44 +1,45 @@ -from typing import Optional, List, Tuple +from typing import Optional # noqa: D100 + from duckdb.experimental.spark.exception import ContributionsAcceptedError -class SparkConf: - def __init__(self): +class SparkConf: # noqa: D101 + def __init__(self) -> None: # noqa: D107 raise NotImplementedError - def contains(self, key: str) -> bool: + def contains(self, key: str) -> bool: # noqa: D102 raise ContributionsAcceptedError - def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: + def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError - def getAll(self) -> List[Tuple[str, str]]: + def getAll(self) -> list[tuple[str, str]]: # noqa: D102 raise ContributionsAcceptedError - def set(self, key: str, value: str) -> "SparkConf": + def set(self, key: str, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf": + def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setAppName(self, value: str) -> "SparkConf": + def setAppName(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setExecutorEnv( - self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[List[Tuple[str, str]]] = None + def setExecutorEnv( # noqa: D102 + self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[list[tuple[str, str]]] = None ) -> "SparkConf": raise ContributionsAcceptedError - def setIfMissing(self, key: str, value: str) -> "SparkConf": + def setIfMissing(self, key: str, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setMaster(self, value: str) -> "SparkConf": + def setMaster(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setSparkHome(self, value: str) -> "SparkConf": + def setSparkHome(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def toDebugString(self) -> str: + def toDebugString(self) -> str: # noqa: D102 raise ContributionsAcceptedError diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index a2e7c78f..c78bde65 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -1,42 +1,42 @@ -from typing import Optional +from typing import Optional # noqa: D100 + import duckdb from duckdb import DuckDBPyConnection - -from duckdb.experimental.spark.exception import ContributionsAcceptedError from duckdb.experimental.spark.conf import SparkConf +from duckdb.experimental.spark.exception import ContributionsAcceptedError -class SparkContext: - def __init__(self, master: str): - self._connection = duckdb.connect(':memory:') +class SparkContext: # noqa: D101 + def __init__(self, master: str) -> None: # noqa: D107 + self._connection = duckdb.connect(":memory:") # This aligns the null ordering with Spark. 
self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'") @property - def connection(self) -> DuckDBPyConnection: + def connection(self) -> DuckDBPyConnection: # noqa: D102 return self._connection - def stop(self) -> None: + def stop(self) -> None: # noqa: D102 self._connection.close() @classmethod - def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": + def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": # noqa: D102 raise ContributionsAcceptedError @classmethod - def setSystemProperty(cls, key: str, value: str) -> None: + def setSystemProperty(cls, key: str, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError @property - def applicationId(self) -> str: + def applicationId(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def defaultMinPartitions(self) -> int: + def defaultMinPartitions(self) -> int: # noqa: D102 raise ContributionsAcceptedError @property - def defaultParallelism(self) -> int: + def defaultParallelism(self) -> int: # noqa: D102 raise ContributionsAcceptedError # @property @@ -44,33 +44,35 @@ def defaultParallelism(self) -> int: # raise ContributionsAcceptedError @property - def startTime(self) -> str: + def startTime(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def uiWebUrl(self) -> str: + def uiWebUrl(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def version(self) -> str: + def version(self) -> str: # noqa: D102 raise ContributionsAcceptedError - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 raise ContributionsAcceptedError - # def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None) -> 'Accumulator[T]': + # def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None + # ) -> 'Accumulator[T]': # pass - def addArchive(self, path: str) -> None: + def addArchive(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def addFile(self, path: str, recursive: bool = False) -> None: + def addFile(self, path: str, recursive: bool = False) -> None: # noqa: D102 raise ContributionsAcceptedError - def addPyFile(self, path: str) -> None: + def addPyFile(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError - # def binaryFiles(self, path: str, minPartitions: Optional[int] = None) -> duckdb.experimental.spark.rdd.RDD[typing.Tuple[str, bytes]]: + # def binaryFiles(self, path: str, minPartitions: Optional[int] = None + # ) -> duckdb.experimental.spark.rdd.RDD[typing.Tuple[str, bytes]]: # pass # def binaryRecords(self, path: str, recordLength: int) -> duckdb.experimental.spark.rdd.RDD[bytes]: @@ -79,37 +81,45 @@ def addPyFile(self, path: str) -> None: # def broadcast(self, value: ~T) -> 'Broadcast[T]': # pass - def cancelAllJobs(self) -> None: + def cancelAllJobs(self) -> None: # noqa: D102 raise ContributionsAcceptedError - def cancelJobGroup(self, groupId: str) -> None: + def cancelJobGroup(self, groupId: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def dump_profiles(self, path: str) -> None: + def dump_profiles(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError # def emptyRDD(self) -> duckdb.experimental.spark.rdd.RDD[typing.Any]: # pass - def getCheckpointDir(self) -> Optional[str]: + def getCheckpointDir(self) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError - def getConf(self) -> SparkConf: + def getConf(self) -> SparkConf: # noqa: D102 
raise ContributionsAcceptedError - def getLocalProperty(self, key: str) -> Optional[str]: + def getLocalProperty(self, key: str) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError - # def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, + # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, + # conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: # pass - # def hadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # def hadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, + # valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0 + # ) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: # pass - # def newAPIHadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # def newAPIHadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, + # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, + # conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: # pass - # def newAPIHadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # def newAPIHadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, + # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, + # conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: # pass # def parallelize(self, c: Iterable[~T], numSlices: Optional[int] = None) -> pyspark.rdd.RDD[~T]: @@ -118,46 +128,52 @@ def getLocalProperty(self, key: str) -> Optional[str]: # def pickleFile(self, name: str, minPartitions: Optional[int] = None) -> pyspark.rdd.RDD[typing.Any]: # pass - # def range(self, start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None) -> pyspark.rdd.RDD[int]: + # def range(self, start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None + # ) -> pyspark.rdd.RDD[int]: # pass - # def runJob(self, rdd: pyspark.rdd.RDD[~T], partitionFunc: Callable[[Iterable[~T]], Iterable[~U]], partitions: Optional[Sequence[int]] = None, allowLocal: bool = False) -> List[~U]: + # def runJob(self, rdd: pyspark.rdd.RDD[~T], partitionFunc: Callable[[Iterable[~T]], Iterable[~U]], + # partitions: Optional[Sequence[int]] = None, allowLocal: bool = False) -> List[~U]: # pass - # def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None, batchSize: int = 0) -> 
pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None, + # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None, + # batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: # pass - def setCheckpointDir(self, dirName: str) -> None: + def setCheckpointDir(self, dirName: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setJobDescription(self, value: str) -> None: + def setJobDescription(self, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: + def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: # noqa: D102 raise ContributionsAcceptedError - def setLocalProperty(self, key: str, value: str) -> None: + def setLocalProperty(self, key: str, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setLogLevel(self, logLevel: str) -> None: + def setLogLevel(self, logLevel: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def show_profiles(self) -> None: + def show_profiles(self) -> None: # noqa: D102 raise ContributionsAcceptedError - def sparkUser(self) -> str: + def sparkUser(self) -> str: # noqa: D102 raise ContributionsAcceptedError # def statusTracker(self) -> duckdb.experimental.spark.status.StatusTracker: # raise ContributionsAcceptedError - # def textFile(self, name: str, minPartitions: Optional[int] = None, use_unicode: bool = True) -> pyspark.rdd.RDD[str]: + # def textFile(self, name: str, minPartitions: Optional[int] = None, use_unicode: bool = True + # ) -> pyspark.rdd.RDD[str]: # pass # def union(self, rdds: List[pyspark.rdd.RDD[~T]]) -> pyspark.rdd.RDD[~T]: # pass - # def wholeTextFiles(self, path: str, minPartitions: Optional[int] = None, use_unicode: bool = True) -> pyspark.rdd.RDD[typing.Tuple[str, str]]: + # def wholeTextFiles(self, path: str, minPartitions: Optional[int] = None, use_unicode: bool = True + # ) -> pyspark.rdd.RDD[typing.Tuple[str, str]]: # pass diff --git a/duckdb/experimental/spark/errors/__init__.py b/duckdb/experimental/spark/errors/__init__.py index 5f2af443..ee7688ea 100644 --- a/duckdb/experimental/spark/errors/__init__.py +++ b/duckdb/experimental/spark/errors/__init__.py @@ -15,58 +15,56 @@ # limitations under the License. # -""" -PySpark exceptions. 
-""" -from .exceptions.base import ( # noqa: F401 - PySparkException, +"""PySpark exceptions.""" + +from .exceptions.base import ( AnalysisException, - TempTableAlreadyExistsException, - ParseException, - IllegalArgumentException, ArithmeticException, - UnsupportedOperationException, ArrayIndexOutOfBoundsException, DateTimeException, + IllegalArgumentException, NumberFormatException, - StreamingQueryException, - QueryExecutionException, + ParseException, + PySparkAssertionError, + PySparkAttributeError, + PySparkException, + PySparkIndexError, + PySparkNotImplementedError, + PySparkRuntimeError, + PySparkTypeError, + PySparkValueError, PythonException, - UnknownException, + QueryExecutionException, SparkRuntimeException, SparkUpgradeException, - PySparkTypeError, - PySparkValueError, - PySparkIndexError, - PySparkAttributeError, - PySparkRuntimeError, - PySparkAssertionError, - PySparkNotImplementedError, + StreamingQueryException, + TempTableAlreadyExistsException, + UnknownException, + UnsupportedOperationException, ) - __all__ = [ - "PySparkException", "AnalysisException", - "TempTableAlreadyExistsException", - "ParseException", - "IllegalArgumentException", "ArithmeticException", - "UnsupportedOperationException", "ArrayIndexOutOfBoundsException", "DateTimeException", + "IllegalArgumentException", "NumberFormatException", - "StreamingQueryException", - "QueryExecutionException", + "ParseException", + "PySparkAssertionError", + "PySparkAttributeError", + "PySparkException", + "PySparkIndexError", + "PySparkNotImplementedError", + "PySparkRuntimeError", + "PySparkTypeError", + "PySparkValueError", "PythonException", - "UnknownException", + "QueryExecutionException", "SparkRuntimeException", "SparkUpgradeException", - "PySparkTypeError", - "PySparkValueError", - "PySparkIndexError", - "PySparkAttributeError", - "PySparkRuntimeError", - "PySparkAssertionError", - "PySparkNotImplementedError", + "StreamingQueryException", + "TempTableAlreadyExistsException", + "UnknownException", + "UnsupportedOperationException", ] diff --git a/duckdb/experimental/spark/errors/error_classes.py b/duckdb/experimental/spark/errors/error_classes.py index 256fb644..22055cbf 100644 --- a/duckdb/experimental/spark/errors/error_classes.py +++ b/duckdb/experimental/spark/errors/error_classes.py @@ -1,4 +1,4 @@ -# +# ruff: noqa: D100, E501 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. diff --git a/duckdb/experimental/spark/errors/exceptions/__init__.py b/duckdb/experimental/spark/errors/exceptions/__init__.py index cce3acad..edd0e7e1 100644 --- a/duckdb/experimental/spark/errors/exceptions/__init__.py +++ b/duckdb/experimental/spark/errors/exceptions/__init__.py @@ -1,4 +1,4 @@ -# +# # noqa: D104 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. 
diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 21dba03b..2eae2a19 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -1,20 +1,19 @@ -from typing import Dict, Optional, cast +from typing import Optional, cast # noqa: D100 from ..utils import ErrorClassesReader + class PySparkException(Exception): - """ - Base Exception for handling errors generated from PySpark. - """ + """Base Exception for handling errors generated from PySpark.""" - def __init__( + def __init__( # noqa: D107 self, message: Optional[str] = None, # The error class, decides the message format, must be one of the valid options listed in 'error_classes.py' error_class: Optional[str] = None, # The dictionary listing the arguments specified in the message (or the error_class) - message_parameters: Optional[Dict[str, str]] = None, - ): + message_parameters: Optional[dict[str, str]] = None, + ) -> None: # `message` vs `error_class` & `message_parameters` are mutually exclusive. assert (message is not None and (error_class is None and message_parameters is None)) or ( message is None and (error_class is not None and message_parameters is not None) @@ -24,7 +23,7 @@ def __init__( if message is None: self.message = self.error_reader.get_error_message( - cast(str, error_class), cast(Dict[str, str], message_parameters) + cast("str", error_class), cast("dict[str, str]", message_parameters) ) else: self.message = message @@ -33,25 +32,23 @@ def __init__( self.message_parameters = message_parameters def getErrorClass(self) -> Optional[str]: - """ - Returns an error class as a string. + """Returns an error class as a string. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getMessageParameters` :meth:`PySparkException.getSqlState` """ return self.error_class - def getMessageParameters(self) -> Optional[Dict[str, str]]: - """ - Returns a message parameters as a dictionary. + def getMessageParameters(self) -> Optional[dict[str, str]]: + """Returns a message parameters as a dictionary. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getErrorClass` :meth:`PySparkException.getSqlState` @@ -59,159 +56,113 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]: return self.message_parameters def getSqlState(self) -> None: - """ - Returns an SQLSTATE as a string. + """Returns an SQLSTATE as a string. Errors generated in Python have no SQLSTATE, so it always returns None. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getErrorClass` :meth:`PySparkException.getMessageParameters` """ return None - def __str__(self) -> str: + def __str__(self) -> str: # noqa: D105 if self.getErrorClass() is not None: return f"[{self.getErrorClass()}] {self.message}" else: return self.message + class AnalysisException(PySparkException): - """ - Failed to analyze a SQL query plan. - """ + """Failed to analyze a SQL query plan.""" class SessionNotSameException(PySparkException): - """ - Performed the same operation on different SparkSession. - """ + """Performed the same operation on different SparkSession.""" class TempTableAlreadyExistsException(AnalysisException): - """ - Failed to create temp view since it is already exists. - """ + """Failed to create temp view since it is already exists.""" class ParseException(AnalysisException): - """ - Failed to parse a SQL command. 
- """ + """Failed to parse a SQL command.""" class IllegalArgumentException(PySparkException): - """ - Passed an illegal or inappropriate argument. - """ + """Passed an illegal or inappropriate argument.""" class ArithmeticException(PySparkException): - """ - Arithmetic exception thrown from Spark with an error class. - """ + """Arithmetic exception thrown from Spark with an error class.""" class UnsupportedOperationException(PySparkException): - """ - Unsupported operation exception thrown from Spark with an error class. - """ + """Unsupported operation exception thrown from Spark with an error class.""" class ArrayIndexOutOfBoundsException(PySparkException): - """ - Array index out of bounds exception thrown from Spark with an error class. - """ + """Array index out of bounds exception thrown from Spark with an error class.""" class DateTimeException(PySparkException): - """ - Datetime exception thrown from Spark with an error class. - """ + """Datetime exception thrown from Spark with an error class.""" class NumberFormatException(IllegalArgumentException): - """ - Number format exception thrown from Spark with an error class. - """ + """Number format exception thrown from Spark with an error class.""" class StreamingQueryException(PySparkException): - """ - Exception that stopped a :class:`StreamingQuery`. - """ + """Exception that stopped a :class:`StreamingQuery`.""" class QueryExecutionException(PySparkException): - """ - Failed to execute a query. - """ + """Failed to execute a query.""" class PythonException(PySparkException): - """ - Exceptions thrown from Python workers. - """ + """Exceptions thrown from Python workers.""" class SparkRuntimeException(PySparkException): - """ - Runtime exception thrown from Spark with an error class. - """ + """Runtime exception thrown from Spark with an error class.""" class SparkUpgradeException(PySparkException): - """ - Exception thrown because of Spark upgrade. - """ + """Exception thrown because of Spark upgrade.""" class UnknownException(PySparkException): - """ - None of the above exceptions. - """ + """None of the above exceptions.""" class PySparkValueError(PySparkException, ValueError): - """ - Wrapper class for ValueError to support error classes. - """ + """Wrapper class for ValueError to support error classes.""" class PySparkIndexError(PySparkException, IndexError): - """ - Wrapper class for IndexError to support error classes. - """ + """Wrapper class for IndexError to support error classes.""" class PySparkTypeError(PySparkException, TypeError): - """ - Wrapper class for TypeError to support error classes. - """ + """Wrapper class for TypeError to support error classes.""" class PySparkAttributeError(PySparkException, AttributeError): - """ - Wrapper class for AttributeError to support error classes. - """ + """Wrapper class for AttributeError to support error classes.""" class PySparkRuntimeError(PySparkException, RuntimeError): - """ - Wrapper class for RuntimeError to support error classes. - """ + """Wrapper class for RuntimeError to support error classes.""" class PySparkAssertionError(PySparkException, AssertionError): - """ - Wrapper class for AssertionError to support error classes. - """ + """Wrapper class for AssertionError to support error classes.""" class PySparkNotImplementedError(PySparkException, NotImplementedError): - """ - Wrapper class for NotImplementedError to support error classes. 
- """ + """Wrapper class for NotImplementedError to support error classes.""" diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index a375c0c7..8a71f3b0 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -16,37 +16,30 @@ # import re -from typing import Dict from .error_classes import ERROR_CLASSES_MAP class ErrorClassesReader: - """ - A reader to load error information from error_classes.py. - """ + """A reader to load error information from error_classes.py.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 self.error_info_map = ERROR_CLASSES_MAP - def get_error_message(self, error_class: str, message_parameters: Dict[str, str]) -> str: - """ - Returns the completed error message by applying message parameters to the message template. - """ + def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: + """Returns the completed error message by applying message parameters to the message template.""" message_template = self.get_message_template(error_class) # Verify message parameters. message_parameters_from_template = re.findall("<([a-zA-Z0-9_-]+)>", message_template) assert set(message_parameters_from_template) == set(message_parameters), ( - f"Undefined error message parameter for error class: {error_class}. " - f"Parameters: {message_parameters}" + f"Undefined error message parameter for error class: {error_class}. Parameters: {message_parameters}" ) table = str.maketrans("<>", "{}") return message_template.translate(table).format(**message_parameters) def get_message_template(self, error_class: str) -> str: - """ - Returns the message template for corresponding error class from error_classes.py. + """Returns the message template for corresponding error class from error_classes.py. 
For example, when given `error_class` is "EXAMPLE_ERROR_CLASS", @@ -93,7 +86,8 @@ def get_message_template(self, error_class: str) -> str: if main_error_class in self.error_info_map: main_error_class_info_map = self.error_info_map[main_error_class] else: - raise ValueError(f"Cannot find main error class '{main_error_class}'") + msg = f"Cannot find main error class '{main_error_class}'" + raise ValueError(msg) main_message_template = "\n".join(main_error_class_info_map["message"]) @@ -108,7 +102,8 @@ def get_message_template(self, error_class: str) -> str: if sub_error_class in main_error_class_subclass_info_map: sub_error_class_info_map = main_error_class_subclass_info_map[sub_error_class] else: - raise ValueError(f"Cannot find sub error class '{sub_error_class}'") + msg = f"Cannot find sub error class '{sub_error_class}'" + raise ValueError(msg) sub_message_template = "\n".join(sub_error_class_info_map["message"]) message_template = main_message_template + " " + sub_message_template diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 7cb47650..c3a7c1b6 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,14 +1,17 @@ +# ruff: noqa: D100 +from typing import Optional + + class ContributionsAcceptedError(NotImplementedError): - """ - This method is not planned to be implemented, if you would like to implement this method + """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, - feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb - """ + feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. + """ # noqa: D205 - def __init__(self, message=None): + def __init__(self, message: Optional[str] = None) -> None: # noqa: D107 doc = self.__class__.__doc__ if message: - doc = message + '\n' + doc + doc = message + "\n" + doc super().__init__(doc) diff --git a/duckdb/experimental/spark/sql/__init__.py b/duckdb/experimental/spark/sql/__init__.py index 2312ee50..418273f0 100644 --- a/duckdb/experimental/spark/sql/__init__.py +++ b/duckdb/experimental/spark/sql/__init__.py @@ -1,7 +1,7 @@ -from .session import SparkSession -from .readwriter import DataFrameWriter -from .dataframe import DataFrame +from .catalog import Catalog # noqa: D104 from .conf import RuntimeConfig -from .catalog import Catalog +from .dataframe import DataFrame +from .readwriter import DataFrameWriter +from .session import SparkSession -__all__ = ["SparkSession", "DataFrame", "RuntimeConfig", "DataFrameWriter", "Catalog"] +__all__ = ["Catalog", "DataFrame", "DataFrameWriter", "RuntimeConfig", "SparkSession"] diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index 7b1f9ad1..caf0058c 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -19,12 +19,11 @@ from typing import ( Any, Callable, - List, Optional, - Tuple, TypeVar, Union, ) + try: from typing import Literal, Protocol except ImportError: @@ -57,24 +56,21 @@ float, ) -RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], types.Row) +RowLike = TypeVar("RowLike", list[Any], tuple[Any, ...], types.Row) SQLBatchedUDFType = Literal[100] class SupportsOpen(Protocol): - def open(self, partition_id: int, epoch_id: int) -> bool: - ... + def open(self, partition_id: int, epoch_id: int) -> bool: ... 
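# A small, self-contained sketch of the message-template mechanism shown in
# ErrorClassesReader.get_error_message above: placeholders spelled <name> are checked
# against the supplied parameters, rewritten to {name} with str.translate, and filled in
# with str.format. The template text and parameter values here are made up for
# illustration; only the transformation itself mirrors the code above.
import re

message_template = "Cannot find sub error class '<sub_class>' under '<main_class>'."
message_parameters = {"sub_class": "WITHOUT_SUGGESTION", "main_class": "NOT_IMPLEMENTED"}

# Same placeholder verification the reader performs before formatting.
placeholders = re.findall("<([a-zA-Z0-9_-]+)>", message_template)
assert set(placeholders) == set(message_parameters)

table = str.maketrans("<>", "{}")
print(message_template.translate(table).format(**message_parameters))
# Cannot find sub error class 'WITHOUT_SUGGESTION' under 'NOT_IMPLEMENTED'.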
class SupportsProcess(Protocol): - def process(self, row: types.Row) -> None: - ... + def process(self, row: types.Row) -> None: ... class SupportsClose(Protocol): - def close(self, error: Exception) -> None: - ... + def close(self, error: Exception) -> None: ... class UserDefinedFunctionLike(Protocol): @@ -83,11 +79,8 @@ class UserDefinedFunctionLike(Protocol): deterministic: bool @property - def returnType(self) -> types.DataType: - ... + def returnType(self) -> types.DataType: ... - def __call__(self, *args: ColumnOrName) -> Column: - ... + def __call__(self, *args: ColumnOrName) -> Column: ... - def asNondeterministic(self) -> "UserDefinedFunctionLike": - ... + def asNondeterministic(self) -> "UserDefinedFunctionLike": ... diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index ebedb1a1..70fc7b18 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -1,14 +1,15 @@ -from typing import List, NamedTuple, Optional +from typing import NamedTuple, Optional, Union # noqa: D100 + from .session import SparkSession -class Database(NamedTuple): +class Database(NamedTuple): # noqa: D101 name: str description: Optional[str] locationUri: str -class Table(NamedTuple): +class Table(NamedTuple): # noqa: D101 name: str database: Optional[str] description: Optional[str] @@ -16,7 +17,7 @@ class Table(NamedTuple): isTemporary: bool -class Column(NamedTuple): +class Column(NamedTuple): # noqa: D101 name: str description: Optional[str] dataType: str @@ -25,36 +26,36 @@ class Column(NamedTuple): isBucket: bool -class Function(NamedTuple): +class Function(NamedTuple): # noqa: D101 name: str description: Optional[str] className: str isTemporary: bool -class Catalog: - def __init__(self, session: SparkSession): +class Catalog: # noqa: D101 + def __init__(self, session: SparkSession) -> None: # noqa: D107 self._session = session - def listDatabases(self) -> List[Database]: - res = self._session.conn.sql('select database_name from duckdb_databases()').fetchall() + def listDatabases(self) -> list[Database]: # noqa: D102 + res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall() - def transform_to_database(x) -> Database: - return Database(name=x[0], description=None, locationUri='') + def transform_to_database(x: list[str]) -> Database: + return Database(name=x[0], description=None, locationUri="") databases = [transform_to_database(x) for x in res] return databases - def listTables(self) -> List[Table]: - res = self._session.conn.sql('select table_name, database_name, sql, temporary from duckdb_tables()').fetchall() + def listTables(self) -> list[Table]: # noqa: D102 + res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall() - def transform_to_table(x) -> Table: - return Table(name=x[0], database=x[1], description=x[2], tableType='', isTemporary=x[3]) + def transform_to_table(x: list[str]) -> Table: + return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3]) tables = [transform_to_table(x) for x in res] return tables - def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Column]: + def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: # noqa: D102 query = f""" select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}' """ @@ -62,17 +63,17 @@ def listColumns(self, tableName: str, dbName: Optional[str] = None) -> 
List[Colu query += f" and database_name = '{dbName}'" res = self._session.conn.sql(query).fetchall() - def transform_to_column(x) -> Column: + def transform_to_column(x: list[Union[str, bool]]) -> Column: return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False) columns = [transform_to_column(x) for x in res] return columns - def listFunctions(self, dbName: Optional[str] = None) -> List[Function]: + def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: # noqa: D102 raise NotImplementedError - def setCurrentDatabase(self, dbName: str) -> None: + def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102 raise NotImplementedError -__all__ = ["Catalog", "Table", "Column", "Function", "Database"] +__all__ = ["Catalog", "Column", "Database", "Function", "Table"] diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 5f0b2b99..661e4da7 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -1,19 +1,19 @@ -from typing import Union, TYPE_CHECKING, Any, cast, Callable, Tuple -from ..exception import ContributionsAcceptedError +from collections.abc import Iterable # noqa: D100 +from typing import TYPE_CHECKING, Any, Callable, Union, cast +from ..exception import ContributionsAcceptedError from .types import DataType if TYPE_CHECKING: - from ._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral + from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType -from duckdb import ConstantExpression, ColumnExpression, FunctionExpression, Expression - -from duckdb.typing import DuckDBPyType +from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression +from duckdb.sqltypes import DuckDBPyType __all__ = ["Column"] -def _get_expr(x) -> Expression: +def _get_expr(x: Union["Column", str]) -> Expression: return x.expr if isinstance(x, Column) else ConstantExpression(x) @@ -30,7 +30,7 @@ def _unary_op( name: str, doc: str = "unary operator", ) -> Callable[["Column"], "Column"]: - """Create a method for given unary operator""" + """Create a method for given unary operator.""" def _(self: "Column") -> "Column": # Call the function identified by 'name' on the internal Expression object @@ -45,7 +45,7 @@ def _bin_op( name: str, doc: str = "binary operator", ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]: - """Create a method for given binary operator""" + """Create a method for given binary operator.""" def _( self: "Column", @@ -63,7 +63,7 @@ def _bin_func( name: str, doc: str = "binary function", ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]: - """Create a function expression for the given binary function""" + """Create a function expression for the given binary function.""" def _( self: "Column", @@ -78,8 +78,7 @@ def _( class Column: - """ - A column in a DataFrame. + """A column in a DataFrame. :class:`Column` instances can be created by:: @@ -95,11 +94,11 @@ class Column: .. 
versionadded:: 1.3.0 """ - def __init__(self, expr: Expression): + def __init__(self, expr: Expression) -> None: # noqa: D107 self.expr = expr # arithmetic operators - def __neg__(self): + def __neg__(self) -> "Column": # noqa: D105 return Column(-self.expr) # `and`, `or`, `not` cannot be overloaded in Python, @@ -138,9 +137,8 @@ def __neg__(self): __rpow__ = _bin_op("__rpow__") - def __getitem__(self, k: Any) -> "Column": - """ - An expression that gets an item at position ``ordinal`` out of a list, + def __getitem__(self, k: Any) -> "Column": # noqa: ANN401 + """An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. .. versionadded:: 1.3.0 @@ -153,35 +151,34 @@ def __getitem__(self, k: Any) -> "Column": k a literal value, or a slice object without step. - Returns + Returns: ------- :class:`Column` Column representing the item got by key out of a dict, or substrings sliced by the given slice object. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('abcedfg', {"key": "value"})], ["l", "d"]) - >>> df.select(df.l[slice(1, 3)], df.d['key']).show() + >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) + >>> df.select(df.l[slice(1, 3)], df.d["key"]).show() +------------------+------+ |substring(l, 1, 3)|d[key]| +------------------+------+ | abc| value| +------------------+------+ - """ + """ # noqa: D205 if isinstance(k, slice): raise ContributionsAcceptedError # if k.step is not None: # raise ValueError("Using a slice with a step value is not supported") # return self.substr(k.start, k.stop) else: - # FIXME: this is super hacky + # TODO: this is super hacky # noqa: TD002, TD003 expr_str = str(self.expr) + "." + str(k) return Column(ColumnExpression(expr_str)) - def __getattr__(self, item: Any) -> "Column": - """ - An expression that gets an item at position ``ordinal`` out of a list, + def __getattr__(self, item: Any) -> "Column": # noqa: ANN401 + """An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. Parameters @@ -189,55 +186,53 @@ def __getattr__(self, item: Any) -> "Column": item a literal value. - Returns + Returns: ------- :class:`Column` Column representing the item got by key out of a dict. 
- Examples + Examples: -------- - >>> df = spark.createDataFrame([('abcedfg', {"key": "value"})], ["l", "d"]) + >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) >>> df.select(df.d.key).show() +------+ |d[key]| +------+ | value| +------+ - """ + """ # noqa: D205 if item.startswith("__"): - raise AttributeError("Can not access __ (dunder) method") + msg = "Can not access __ (dunder) method" + raise AttributeError(msg) return self[item] - def alias(self, alias: str): + def alias(self, alias: str) -> "Column": # noqa: D102 return Column(self.expr.alias(alias)) - def when(self, condition: "Column", value: Any): + def when(self, condition: "Column", value: Union["Column", str]) -> "Column": # noqa: D102 if not isinstance(condition, Column): - raise TypeError("condition should be a Column") + msg = "condition should be a Column" + raise TypeError(msg) v = _get_expr(value) expr = self.expr.when(condition.expr, v) return Column(expr) - def otherwise(self, value: Any): + def otherwise(self, value: Union["Column", str]) -> "Column": # noqa: D102 v = _get_expr(value) expr = self.expr.otherwise(v) return Column(expr) - def cast(self, dataType: Union[DataType, str]) -> "Column": - if isinstance(dataType, str): - # Try to construct a default DuckDBPyType from it - internal_type = DuckDBPyType(dataType) - else: - internal_type = dataType.duckdb_type + def cast(self, dataType: Union[DataType, str]) -> "Column": # noqa: D102 + internal_type = DuckDBPyType(dataType) if isinstance(dataType, str) else dataType.duckdb_type return Column(self.expr.cast(internal_type)) - def isin(self, *cols: Any) -> "Column": + def isin(self, *cols: Union[Iterable[Union["Column", str]], Union["Column", str]]) -> "Column": # noqa: D102 if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list - cols = cast(Tuple, cols[0]) + cols = cast("tuple", cols[0]) cols = cast( - Tuple, + "tuple", [_get_expr(c) for c in cols], ) return Column(self.expr.isin(*cols)) @@ -247,14 +242,14 @@ def __eq__( # type: ignore[override] self, other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"], ) -> "Column": - """binary function""" + """Binary function.""" return Column(self.expr == (_get_expr(other))) def __ne__( # type: ignore[override] self, - other: Any, + other: object, ) -> "Column": - """binary function""" + """Binary function.""" return Column(self.expr != (_get_expr(other))) __lt__ = _bin_op("__lt__") @@ -347,22 +342,20 @@ def __ne__( # type: ignore[override] nulls_first = _unary_op("nulls_first") nulls_last = _unary_op("nulls_last") - - def asc_nulls_first(self) -> "Column": + def asc_nulls_first(self) -> "Column": # noqa: D102 return self.asc().nulls_first() - def asc_nulls_last(self) -> "Column": + def asc_nulls_last(self) -> "Column": # noqa: D102 return self.asc().nulls_last() - def desc_nulls_first(self) -> "Column": + def desc_nulls_first(self) -> "Column": # noqa: D102 return self.desc().nulls_first() - def desc_nulls_last(self) -> "Column": + def desc_nulls_last(self) -> "Column": # noqa: D102 return self.desc().nulls_last() - def isNull(self) -> "Column": + def isNull(self) -> "Column": # noqa: D102 return Column(self.expr.isnull()) - def isNotNull(self) -> "Column": + def isNotNull(self) -> "Column": # noqa: D102 return Column(self.expr.isnotnull()) - diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index 98b773fb..e44f2566 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ 
b/duckdb/experimental/spark/sql/conf.py @@ -1,22 +1,23 @@ -from typing import Optional, Union -from duckdb.experimental.spark._globals import _NoValueType, _NoValue +from typing import Optional, Union # noqa: D100 + from duckdb import DuckDBPyConnection +from duckdb.experimental.spark._globals import _NoValue, _NoValueType -class RuntimeConfig: - def __init__(self, connection: DuckDBPyConnection): +class RuntimeConfig: # noqa: D101 + def __init__(self, connection: DuckDBPyConnection) -> None: # noqa: D107 self._connection = connection - def set(self, key: str, value: str) -> None: + def set(self, key: str, value: str) -> None: # noqa: D102 raise NotImplementedError - def isModifiable(self, key: str) -> bool: + def isModifiable(self, key: str) -> bool: # noqa: D102 raise NotImplementedError - def unset(self, key: str) -> None: + def unset(self, key: str) -> None: # noqa: D102 raise NotImplementedError - def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: + def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: # noqa: D102 raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index a81a423b..9dba64e4 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,24 +1,20 @@ +import uuid # noqa: D100 from functools import reduce +from keyword import iskeyword from typing import ( TYPE_CHECKING, Any, Callable, - List, - Dict, Optional, - Tuple, Union, cast, overload, ) -import uuid -from keyword import iskeyword import duckdb from duckdb import ColumnExpression, Expression, StarExpression -from ._typing import ColumnOrName -from ..errors import PySparkTypeError, PySparkValueError, PySparkIndexError +from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError from ..exception import ContributionsAcceptedError from .column import Column from .readwriter import DataFrameWriter @@ -29,43 +25,42 @@ import pyarrow as pa from pandas.core.frame import DataFrame as PandasDataFrame - from .group import GroupedData, Grouping + from ._typing import ColumnOrName + from .group import GroupedData from .session import SparkSession -from ..errors import PySparkValueError -from .functions import _to_column_expr, col, lit +from duckdb.experimental.spark.sql import functions as spark_sql_functions -class DataFrame: - def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession"): +class DataFrame: # noqa: D101 + def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: # noqa: D107 self.relation = relation self.session = session self._schema = None if self.relation is not None: self._schema = duckdb_to_spark_schema(self.relation.columns, self.relation.types) - def show(self, **kwargs) -> None: + def show(self, **kwargs) -> None: # noqa: D102 self.relation.show() - def toPandas(self) -> "PandasDataFrame": + def toPandas(self) -> "PandasDataFrame": # noqa: D102 return self.relation.df() def toArrow(self) -> "pa.Table": - """ - Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. + """Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. This is only available if PyArrow is installed and available. .. versionadded:: 4.0.0 - Notes + Notes: ----- This method should only be used if the resulting PyArrow ``pyarrow.Table`` is expected to be small, as all the data is loaded into the driver's memory. 
This API is a developer API. - Examples + Examples: -------- >>> df.toArrow() # doctest: +SKIP pyarrow.Table @@ -88,7 +83,7 @@ def createOrReplaceTempView(self, name: str) -> None: name : str Name of the view. - Examples + Examples: -------- Create a local temporary view named 'people'. @@ -108,12 +103,13 @@ def createOrReplaceTempView(self, name: str) -> None: """ self.relation.create_view(name, True) - def createGlobalTempView(self, name: str) -> None: + def createGlobalTempView(self, name: str) -> None: # noqa: D102 raise NotImplementedError - def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": + def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": # noqa: D102 if columnName not in self.relation: - raise ValueError(f"DataFrame does not contain a column named {columnName}") + msg = f"DataFrame does not contain a column named {columnName}" + raise ValueError(msg) cols = [] for x in self.relation.columns: col = ColumnExpression(x) @@ -123,7 +119,7 @@ def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumn(self, columnName: str, col: Column) -> "DataFrame": + def withColumn(self, columnName: str, col: Column) -> "DataFrame": # noqa: D102 if not isinstance(col, Column): raise PySparkTypeError( error_class="NOT_COLUMN", @@ -143,9 +139,8 @@ def withColumn(self, columnName: str, col: Column) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": - """ - Returns a new :class:`DataFrame` by adding multiple columns or replacing the + def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": + """Returns a new :class:`DataFrame` by adding multiple columns or replacing the existing columns that have the same names. The colsMap is a map of column name and column, the column must only refer to attributes @@ -162,22 +157,22 @@ def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": colsMap : dict a dict of column name and :class:`Column`. Currently, only a single map is supported. - Returns + Returns: ------- :class:`DataFrame` DataFrame with new or replaced columns. - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - >>> df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}).show() + >>> df.withColumns({"age2": df.age + 2, "age3": df.age + 3}).show() +---+-----+----+----+ |age| name|age2|age3| +---+-----+----+----+ | 2|Alice| 4| 5| | 5| Bob| 7| 8| +---+-----+----+----+ - """ + """ # noqa: D205 # Below code is to help enable kwargs in future. assert len(colsMap) == 1 colsMap = colsMap[0] # type: ignore[assignment] @@ -218,9 +213,8 @@ def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame": - """ - Returns a new :class:`DataFrame` by renaming multiple columns. + def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": + """Returns a new :class:`DataFrame` by renaming multiple columns. This is a no-op if the schema doesn't contain the given column names. .. versionadded:: 3.4.0 @@ -232,31 +226,31 @@ def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame": a dict of existing column names and corresponding desired column names. Currently, only a single map is supported. 
- Returns + Returns: ------- :class:`DataFrame` DataFrame with renamed columns. - See Also + See Also: -------- :meth:`withColumnRenamed` - Notes + Notes: ----- Support Spark Connect - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - >>> df = df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}) - >>> df.withColumnsRenamed({'age2': 'age4', 'age3': 'age5'}).show() + >>> df = df.withColumns({"age2": df.age + 2, "age3": df.age + 3}) + >>> df.withColumnsRenamed({"age2": "age4", "age3": "age5"}).show() +---+-----+----+----+ |age| name|age4|age5| +---+-----+----+----+ | 2|Alice| 4| 5| | 5| Bob| 7| 8| +---+-----+----+----+ - """ + """ # noqa: D205 if not isinstance(colsMap, dict): raise PySparkTypeError( error_class="NOT_DICT", @@ -265,9 +259,8 @@ def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame": unknown_columns = set(colsMap.keys()) - set(self.relation.columns) if unknown_columns: - raise ValueError( - f"DataFrame does not contain column(s): {', '.join(unknown_columns)}" - ) + msg = f"DataFrame does not contain column(s): {', '.join(unknown_columns)}" + raise ValueError(msg) # Compute this only once old_column_names = list(colsMap.keys()) @@ -289,11 +282,7 @@ def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - - - def transform( - self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any - ) -> "DataFrame": + def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame": # noqa: ANN401 """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations. .. versionadded:: 3.0.0 @@ -314,21 +303,19 @@ def transform( .. versionadded:: 3.3.0 - Returns + Returns: ------- :class:`DataFrame` Transformed DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import col >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"]) >>> def cast_all_to_int(input_df): ... return input_df.select([col(col_name).cast("int") for col_name in input_df.columns]) - ... >>> def sort_columns_asc(input_df): ... return input_df.select(*sorted(input_df.columns)) - ... >>> df.transform(cast_all_to_int).transform(sort_columns_asc).show() +-----+---+ |float|int| @@ -338,8 +325,9 @@ def transform( +-----+---+ >>> def add_n(input_df, n): - ... return input_df.select([(col(col_name) + n).alias(col_name) - ... for col_name in input_df.columns]) + ... return input_df.select( + ... [(col(col_name) + n).alias(col_name) for col_name in input_df.columns] + ... ) >>> df.transform(add_n, 1).transform(add_n, n=10).show() +---+-----+ |int|float| @@ -350,14 +338,11 @@ def transform( """ result = func(self, *args, **kwargs) assert isinstance(result, DataFrame), ( - "Func returned an instance of type [%s], " - "should have been DataFrame." % type(result) + f"Func returned an instance of type [{type(result)}], should have been DataFrame." ) return result - def sort( - self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any - ) -> "DataFrame": + def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any) -> "DataFrame": # noqa: ANN401 """Returns a new :class:`DataFrame` sorted by the specified column(s). Parameters @@ -372,16 +357,15 @@ def sort( Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the `cols`. 
- Returns + Returns: ------- :class:`DataFrame` Sorted DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import desc, asc - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) Sort the DataFrame in ascending order. @@ -419,8 +403,9 @@ def sort( Specify multiple columns - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] + ... ) >>> df.orderBy(desc("age"), "name").show() +---+-----+ |age| name| @@ -453,7 +438,7 @@ def sort( for c in cols: _c = c if isinstance(c, str): - _c = col(c) + _c = spark_sql_functions.col(c) elif isinstance(c, int) and not isinstance(c, bool): # ordinal is 1-based if c > 0: @@ -481,13 +466,13 @@ def sort( message_parameters={"arg_name": "ascending", "arg_type": type(ascending).__name__}, ) - columns = [_to_column_expr(c) for c in columns] + columns = [spark_sql_functions._to_column_expr(c) for c in columns] rel = self.relation.sort(*columns) return DataFrame(rel, self.session) orderBy = sort - def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: + def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: # noqa: D102 if n is None: rs = self.head(1) return rs[0] if rs else None @@ -495,7 +480,7 @@ def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: first = head - def take(self, num: int) -> List[Row]: + def take(self, num: int) -> list[Row]: # noqa: D102 return self.limit(num).collect() def filter(self, condition: "ColumnOrName") -> "DataFrame": @@ -509,15 +494,14 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": a :class:`Column` of :class:`types.BooleanType` or a string of SQL expressions. - Returns + Returns: ------- :class:`DataFrame` Filtered DataFrame. - Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) Filter by :class:`Column` instances. @@ -563,38 +547,34 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": where = filter - def select(self, *cols) -> "DataFrame": + def select(self, *cols) -> "DataFrame": # noqa: D102 cols = list(cols) if len(cols) == 1: cols = cols[0] if isinstance(cols, list): - projections = [ - x.expr if isinstance(x, Column) else ColumnExpression(x) for x in cols - ] + projections = [x.expr if isinstance(x, Column) else ColumnExpression(x) for x in cols] else: - projections = [ - cols.expr if isinstance(cols, Column) else ColumnExpression(cols) - ] + projections = [cols.expr if isinstance(cols, Column) else ColumnExpression(cols)] rel = self.relation.select(*projections) return DataFrame(rel, self.session) @property - def columns(self) -> List[str]: + def columns(self) -> list[str]: """Returns all column names as a list. - Examples + Examples: -------- >>> df.columns ['age', 'name'] """ return [f.name for f in self.schema.fields] - def _ipython_key_completions_(self) -> List[str]: + def _ipython_key_completions_(self) -> list[str]: # Provides tab-completion for column names in PySpark DataFrame # when accessed in bracket notation, e.g. 
df['] return self.columns - def __dir__(self) -> List[str]: + def __dir__(self) -> list[str]: # noqa: D105 out = set(super().__dir__()) out.update(c for c in self.columns if c.isidentifier() and not iskeyword(c)) return sorted(out) @@ -602,7 +582,7 @@ def __dir__(self) -> List[str]: def join( self, other: "DataFrame", - on: Optional[Union[str, List[str], Column, List[Column]]] = None, + on: Optional[Union[str, list[str], Column, list[Column]]] = None, how: Optional[str] = None, ) -> "DataFrame": """Joins with another :class:`DataFrame`, using the given join expression. @@ -622,12 +602,12 @@ def join( ``right``, ``rightouter``, ``right_outer``, ``semi``, ``leftsemi``, ``left_semi``, ``anti``, ``leftanti`` and ``left_anti``. - Returns + Returns: ------- :class:`DataFrame` Joined DataFrame. - Examples + Examples: -------- The following performs a full outer join between ``df1`` and ``df2``. @@ -636,22 +616,24 @@ def join( >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")]).toDF("age", "name") >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) >>> df3 = spark.createDataFrame([Row(age=2, name="Alice"), Row(age=5, name="Bob")]) - >>> df4 = spark.createDataFrame([ - ... Row(age=10, height=80, name="Alice"), - ... Row(age=5, height=None, name="Bob"), - ... Row(age=None, height=None, name="Tom"), - ... Row(age=None, height=None, name=None), - ... ]) + >>> df4 = spark.createDataFrame( + ... [ + ... Row(age=10, height=80, name="Alice"), + ... Row(age=5, height=None, name="Bob"), + ... Row(age=None, height=None, name="Tom"), + ... Row(age=None, height=None, name=None), + ... ] + ... ) Inner join on columns (default) - >>> df.join(df2, 'name').select(df.name, df2.height).show() + >>> df.join(df2, "name").select(df.name, df2.height).show() +----+------+ |name|height| +----+------+ | Bob| 85| +----+------+ - >>> df.join(df4, ['name', 'age']).select(df.name, df.age).show() + >>> df.join(df4, ["name", "age"]).select(df.name, df.age).show() +----+---+ |name|age| +----+---+ @@ -660,8 +642,9 @@ def join( Outer join for both DataFrames on the 'name' column. - >>> df.join(df2, df.name == df2.name, 'outer').select( - ... df.name, df2.height).sort(desc("name")).show() + >>> df.join(df2, df.name == df2.name, "outer").select(df.name, df2.height).sort( + ... desc("name") + ... ).show() +-----+------+ | name|height| +-----+------+ @@ -669,7 +652,7 @@ def join( |Alice| NULL| | NULL| 80| +-----+------+ - >>> df.join(df2, 'name', 'outer').select('name', 'height').sort(desc("name")).show() + >>> df.join(df2, "name", "outer").select("name", "height").sort(desc("name")).show() +-----+------+ | name|height| +-----+------+ @@ -680,11 +663,9 @@ def join( Outer join for both DataFrams with multiple columns. - >>> df.join( - ... df3, - ... [df.name == df3.name, df.age == df3.age], - ... 'outer' - ... ).select(df.name, df3.age).show() + >>> df.join(df3, [df.name == df3.name, df.age == df3.age], "outer").select( + ... df.name, df3.age + ... 
).show() +-----+---+ | name|age| +-----+---+ @@ -692,20 +673,16 @@ def join( | Bob| 5| +-----+---+ """ - if on is not None and not isinstance(on, list): on = [on] # type: ignore[assignment] - if on is not None and not all([isinstance(x, str) for x in on]): + if on is not None and not all(isinstance(x, str) for x in on): assert isinstance(on, list) # Get (or create) the Expressions from the list of Columns - on = [_to_column_expr(x) for x in on] + on = [spark_sql_functions._to_column_expr(x) for x in on] # & all the Expressions together to form one Expression - assert isinstance( - on[0], Expression - ), "on should be Column or list of Column" - on = reduce(lambda x, y: x.__and__(y), cast(List[Expression], on)) - + assert isinstance(on[0], Expression), "on should be Column or list of Column" + on = reduce(lambda x, y: x.__and__(y), cast("list[Expression]", on)) if on is None and how is None: result = self.relation.join(other.relation) @@ -714,14 +691,14 @@ def join( how = "inner" if on is None: on = "true" - elif isinstance(on, list) and all([isinstance(x, str) for x in on]): + elif isinstance(on, list) and all(isinstance(x, str) for x in on): # Passed directly through as a list of strings on = on else: on = str(on) assert isinstance(how, str), "how should be a string" - def map_to_recognized_jointype(how): + def map_to_recognized_jointype(how: str) -> str: known_aliases = { "inner": [], "outer": ["full", "fullouter", "full_outer"], @@ -730,15 +707,10 @@ def map_to_recognized_jointype(how): "anti": ["leftanti", "left_anti"], "semi": ["leftsemi", "left_semi"], } - mapped_type = None for type, aliases in known_aliases.items(): if how == type or how in aliases: - mapped_type = type - break - - if not mapped_type: - mapped_type = how - return mapped_type + return type + return how how = map_to_recognized_jointype(how) result = self.relation.join(other.relation, on, how) @@ -757,18 +729,16 @@ def crossJoin(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Right side of the cartesian product. - Returns + Returns: ------- :class:`DataFrame` Joined DataFrame. - Examples + Examples: -------- >>> from pyspark.sql import Row - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) - >>> df2 = spark.createDataFrame( - ... [Row(height=80, name="Tom"), Row(height=85, name="Bob")]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) >>> df.crossJoin(df2.select("height")).select("age", "name", "height").show() +---+-----+------+ |age| name|height| @@ -791,21 +761,21 @@ def alias(self, alias: str) -> "DataFrame": alias : str an alias name to be set for the :class:`DataFrame`. - Returns + Returns: ------- :class:`DataFrame` Aliased DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import col, desc - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") - >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') - >>> joined_df.select( - ... 
"df_as1.name", "df_as2.name", "df_as2.age").sort(desc("df_as1.name")).show() + >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), "inner") + >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").sort( + ... desc("df_as1.name") + ... ).show() +-----+-----+---+ | name| name|age| +-----+-----+---+ @@ -817,7 +787,7 @@ def alias(self, alias: str) -> "DataFrame": assert isinstance(alias, str), "alias should be a string" return DataFrame(self.relation.set_alias(alias), self.session) - def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] + def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] # noqa: D102 exclude = [] for col in cols: if isinstance(col, str): @@ -834,7 +804,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] expr = StarExpression(exclude=exclude) return DataFrame(self.relation.select(expr), self.session) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return str(self.relation) def limit(self, num: int) -> "DataFrame": @@ -846,15 +816,14 @@ def limit(self, num: int) -> "DataFrame": Number of records to return. Will return this number of records or all records if the DataFrame contains less than this number of records. - Returns + Returns: ------- :class:`DataFrame` Subset of the records - Examples + Examples: -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df.limit(1).show() +---+----+ |age|name| @@ -870,17 +839,15 @@ def limit(self, num: int) -> "DataFrame": rel = self.relation.limit(num) return DataFrame(rel, self.session) - def __contains__(self, item: str): - """ - Check if the :class:`DataFrame` contains a column by the name of `item` - """ + def __contains__(self, item: str) -> bool: + """Check if the :class:`DataFrame` contains a column by the name of `item`.""" return item in self.relation @property def schema(self) -> StructType: """Returns the schema of this :class:`DataFrame` as a :class:`duckdb.experimental.spark.sql.types.StructType`. - Examples + Examples: -------- >>> df.schema StructType([StructField('age', IntegerType(), True), @@ -889,25 +856,21 @@ def schema(self) -> StructType: return self._schema @overload - def __getitem__(self, item: Union[int, str]) -> Column: - ... + def __getitem__(self, item: Union[int, str]) -> Column: ... @overload - def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame": - ... + def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... - def __getitem__( - self, item: Union[int, str, Column, List, Tuple] - ) -> Union[Column, "DataFrame"]: + def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. 
- Examples + Examples: -------- - >>> df.select(df['age']).collect() + >>> df.select(df["age"]).collect() [Row(age=2), Row(age=5)] - >>> df[ ["name", "age"]].collect() + >>> df[["name", "age"]].collect() [Row(name='Alice', age=2), Row(name='Bob', age=5)] - >>> df[ df.age > 3 ].collect() + >>> df[df.age > 3].collect() [Row(age=5, name='Bob')] >>> df[df[0] > 3].collect() [Row(age=5, name='Bob')] @@ -919,31 +882,29 @@ def __getitem__( elif isinstance(item, (list, tuple)): return self.select(*item) elif isinstance(item, int): - return col(self._schema[item].name) + return spark_sql_functions.col(self._schema[item].name) else: - raise TypeError(f"Unexpected item type: {type(item)}") + msg = f"Unexpected item type: {type(item)}" + raise TypeError(msg) def __getattr__(self, name: str) -> Column: """Returns the :class:`Column` denoted by ``name``. - Examples + Examples: -------- >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] """ if name not in self.relation.columns: - raise AttributeError( - "'%s' object has no attribute '%s'" % (self.__class__.__name__, name) - ) + msg = f"'{self.__class__.__name__}' object has no attribute '{name}'" + raise AttributeError(msg) return Column(duckdb.ColumnExpression(self.relation.alias, name)) @overload - def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": - ... + def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload - def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": - ... + def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": ... # noqa: PYI063 def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] """Groups the :class:`DataFrame` using the specified columns, @@ -959,15 +920,16 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] Each element should be a column name (string) or an expression (:class:`Column`) or list of them. - Returns + Returns: ------- :class:`GroupedData` Grouped data by given columns. - Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] + ... ) Empty grouping columns triggers a global aggregation. @@ -1008,22 +970,19 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] | Bob| 2| 2| | Bob| 5| 1| +-----+---+-----+ - """ + """ # noqa: D205 from .group import GroupedData, Grouping - if len(cols) == 1 and isinstance(cols[0], list): - columns = cols[0] - else: - columns = cols + columns = cols[0] if len(cols) == 1 and isinstance(cols[0], list) else cols return GroupedData(Grouping(*columns), self) groupby = groupBy @property - def write(self) -> DataFrameWriter: + def write(self) -> DataFrameWriter: # noqa: D102 return DataFrameWriter(self) - def printSchema(self): + def printSchema(self) -> None: # noqa: D102 raise ContributionsAcceptedError def union(self, other: "DataFrame") -> "DataFrame": @@ -1035,22 +994,22 @@ def union(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be unioned - Returns + Returns: ------- :class:`DataFrame` - See Also + See Also: -------- DataFrame.unionAll - Notes + Notes: ----- This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by :func:`distinct`. 
Also as standard in SQL, this function resolves columns by position (not by name). - Examples + Examples: -------- >>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"]) >>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col0"]) @@ -1068,14 +1027,12 @@ def union(self, other: "DataFrame") -> "DataFrame": | 1| 2| 3| | 1| 2| 3| +----+----+----+ - """ + """ # noqa: D205 return DataFrame(self.relation.union(other.relation), self.session) unionAll = union - def unionByName( - self, other: "DataFrame", allowMissingColumns: bool = False - ) -> "DataFrame": + def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame": """Returns a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`. @@ -1096,12 +1053,12 @@ def unionByName( .. versionadded:: 3.1.0 - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Examples + Examples: -------- The difference between this function and :func:`union` is that this function resolves columns by name (not by position): @@ -1130,14 +1087,14 @@ def unionByName( | 1| 2| 3|NULL| |NULL| 4| 5| 6| +----+----+----+----+ - """ + """ # noqa: D205 if allowMissingColumns: cols = [] for col in self.relation.columns: if col in other.relation.columns: cols.append(col) else: - cols.append(lit(None)) + cols.append(spark_sql_functions.lit(None)) other = other.select(*cols) else: other = other.select(*self.relation.columns) @@ -1160,16 +1117,16 @@ def intersect(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be combined. - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Notes + Notes: ----- This is equivalent to `INTERSECT` in SQL. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) @@ -1180,7 +1137,7 @@ def intersect(self, other: "DataFrame") -> "DataFrame": | b| 3| | a| 1| +---+---+ - """ + """ # noqa: D205 return self.intersectAll(other).drop_duplicates() def intersectAll(self, other: "DataFrame") -> "DataFrame": @@ -1200,12 +1157,12 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be combined. - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) @@ -1217,7 +1174,7 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame": | a| 1| | b| 3| +---+---+ - """ + """ # noqa: D205 return DataFrame(self.relation.intersect(other.relation), self.session) def exceptAll(self, other: "DataFrame") -> "DataFrame": @@ -1237,14 +1194,15 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` The other :class:`DataFrame` to compare to. - Returns + Returns: ------- :class:`DataFrame` - Examples + Examples: -------- >>> df1 = spark.createDataFrame( - ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"] + ... 
) >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) >>> df1.exceptAll(df2).show() +---+---+ @@ -1256,10 +1214,10 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": | c| 4| +---+---+ - """ + """ # noqa: D205 return DataFrame(self.relation.except_(other.relation), self.session) - def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": + def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -1276,19 +1234,21 @@ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": subset : List of column names, optional List of columns to use for duplicate comparison (default All columns). - Returns + Returns: ------- :class:`DataFrame` DataFrame without duplicates. - Examples + Examples: -------- >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([ - ... Row(name='Alice', age=5, height=80), - ... Row(name='Alice', age=5, height=80), - ... Row(name='Alice', age=10, height=80) - ... ]) + >>> df = spark.createDataFrame( + ... [ + ... Row(name="Alice", age=5, height=80), + ... Row(name="Alice", age=5, height=80), + ... Row(name="Alice", age=10, height=80), + ... ] + ... ) Deduplicate the same rows. @@ -1302,16 +1262,16 @@ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": Deduplicate values on 'name' and 'height' columns. - >>> df.dropDuplicates(['name', 'height']).show() + >>> df.dropDuplicates(["name", "height"]).show() +-----+---+------+ | name|age|height| +-----+---+------+ |Alice| 5| 80| +-----+---+------+ - """ + """ # noqa: D205 if subset: rn_col = f"tmp_col_{uuid.uuid1().hex}" - subset_str = ', '.join([f'"{c}"' for c in subset]) + subset_str = ", ".join([f'"{c}"' for c in subset]) window_spec = f"OVER(PARTITION BY {subset_str}) AS {rn_col}" df = DataFrame(self.relation.row_number(window_spec, "*"), self.session) return df.filter(f"{rn_col} = 1").drop(rn_col) @@ -1320,19 +1280,17 @@ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": drop_duplicates = dropDuplicates - def distinct(self) -> "DataFrame": """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. - Returns + Returns: ------- :class:`DataFrame` DataFrame with distinct records. - Examples + Examples: -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) Return the number of distinct rows in the :class:`DataFrame` @@ -1345,15 +1303,14 @@ def distinct(self) -> "DataFrame": def count(self) -> int: """Returns the number of rows in this :class:`DataFrame`. - Returns + Returns: ------- int Number of rows. - Examples + Examples: -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) Return the number of rows in the :class:`DataFrame`. 
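For context on the `dropDuplicates` hunk above: when a `subset` is supplied, deduplication is implemented with a `row_number()` window partitioned by the quoted subset columns; only rows whose row number equals 1 are kept and the temporary column is then dropped, while the no-subset form falls back to a plain `DISTINCT`. The sketch below shows how that surfaces to users; the session bootstrap and the tuple-based `createDataFrame` call mirror the doctests in this file and are illustrative assumptions, not part of the patch.

    from duckdb.experimental.spark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [("Alice", 5, 80), ("Alice", 5, 80), ("Alice", 10, 80)],
        ["name", "age", "height"],
    )

    # Subset form: conceptually ROW_NUMBER() OVER (PARTITION BY "name", "height")
    # is projected into a temporary column, only rows where it equals 1 survive,
    # and the temporary column is dropped again. All three rows share the same
    # ("name", "height") pair here, so they collapse to a single row.
    df.dropDuplicates(["name", "height"]).show()

    # No subset: equivalent to DISTINCT over all columns, leaving two rows.
    df.dropDuplicates().show()

Driving the subset case through a window function rather than a GROUP BY keeps every original column in the result without aggregating them, which matches Spark's row-preserving semantics (which of the tied rows is kept is not deterministic, since the window has no ORDER BY).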
@@ -1369,33 +1326,28 @@ def _cast_types(self, *types) -> "DataFrame": assert types_count == len(existing_columns) cast_expressions = [ - f"{existing}::{target_type} as {existing}" - for existing, target_type in zip(existing_columns, types) + f"{existing}::{target_type} as {existing}" for existing, target_type in zip(existing_columns, types) ] cast_expressions = ", ".join(cast_expressions) new_rel = self.relation.project(cast_expressions) return DataFrame(new_rel, self.session) - def toDF(self, *cols) -> "DataFrame": + def toDF(self, *cols) -> "DataFrame": # noqa: D102 existing_columns = self.relation.columns column_count = len(cols) if column_count != len(existing_columns): - raise PySparkValueError( - message="Provided column names and number of columns in the DataFrame don't match" - ) + raise PySparkValueError(message="Provided column names and number of columns in the DataFrame don't match") existing_columns = [ColumnExpression(x) for x in existing_columns] - projections = [ - existing.alias(new) for existing, new in zip(existing_columns, cols) - ] + projections = [existing.alias(new) for existing, new in zip(existing_columns, cols)] new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) - def collect(self) -> List[Row]: + def collect(self) -> list[Row]: # noqa: D102 columns = self.relation.columns result = self.relation.fetchall() - def construct_row(values, names) -> Row: + def construct_row(values: list, names: list[str]) -> Row: row = tuple.__new__(Row, list(values)) row.__fields__ = list(names) return row @@ -1411,16 +1363,16 @@ def cache(self) -> "DataFrame": .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0. - Returns + Returns: ------- :class:`DataFrame` Cached DataFrame. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.cache() diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index fecada95..79a2a8e2 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -1,5 +1,5 @@ -import warnings -from typing import Any, Callable, Union, overload, Optional, List, Tuple, TYPE_CHECKING +import warnings # noqa: D100 +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload from duckdb import ( CaseExpression, @@ -11,32 +11,31 @@ LambdaExpression, SQLExpression, ) + if TYPE_CHECKING: from .dataframe import DataFrame from ..errors import PySparkTypeError from ..exception import ContributionsAcceptedError +from . import types as _types from ._typing import ColumnOrName from .column import Column, _get_expr -from . import types as _types def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: - """ - Invokes n-ary JVM function identified by name + """Invokes n-ary JVM function identified by name and wraps the result with :class:`~pyspark.sql.Column`. - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return _invoke_function(name, *cols) -def col(column: str): +def col(column: str) -> Column: # noqa: D103 return Column(ColumnExpression(column)) def upper(col: "ColumnOrName") -> Column: - """ - Converts a string expression to upper case. + """Converts a string expression to upper case. .. versionadded:: 1.5.0 @@ -48,12 +47,12 @@ def upper(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` upper case values. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(upper("value")).show() @@ -69,8 +68,7 @@ def upper(col: "ColumnOrName") -> Column: def ucase(str: "ColumnOrName") -> Column: - """ - Returns `str` with all characters changed to uppercase. + """Returns `str` with all characters changed to uppercase. .. versionadded:: 3.5.0 @@ -79,7 +77,7 @@ def ucase(str: "ColumnOrName") -> Column: str : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.ucase(sf.lit("Spark"))).show() @@ -92,28 +90,25 @@ def ucase(str: "ColumnOrName") -> Column: return upper(str) -def when(condition: "Column", value: Any) -> Column: +def when(condition: "Column", value: Union[Column, str]) -> Column: # noqa: D103 if not isinstance(condition, Column): - raise TypeError("condition should be a Column") + msg = "condition should be a Column" + raise TypeError(msg) v = _get_expr(value) expr = CaseExpression(condition.expr, v) return Column(expr) -def _inner_expr_or_val(val): +def _inner_expr_or_val(val: Union[Column, str]) -> Union[Column, str]: return val.expr if isinstance(val, Column) else val -def struct(*cols: Column) -> Column: - return Column( - FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols]) - ) +def struct(*cols: Column) -> Column: # noqa: D103 + return Column(FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols])) -def array( - *cols: Union["ColumnOrName", Union[List["ColumnOrName"], Tuple["ColumnOrName", ...]]] -) -> Column: - """Creates a new array column. +def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]]) -> Column: + r"""Creates a new array column. .. versionadded:: 1.4.0 @@ -126,19 +121,19 @@ def array( column names or :class:`~pyspark.sql.Column`\\s that have the same data type. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of array type. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) - >>> df.select(array('age', 'age').alias("arr")).collect() + >>> df.select(array("age", "age").alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] >>> df.select(array([df.age, df.age]).alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] - >>> df.select(array('age', 'age').alias("col")).printSchema() + >>> df.select(array("age", "age").alias("col")).printSchema() root |-- col: array (nullable = false) | |-- element: long (containsNull = true) @@ -148,11 +143,11 @@ def array( return _invoke_function_over_columns("list_value", *cols) -def lit(col: Any) -> Column: +def lit(col: Any) -> Column: # noqa: D103, ANN401 return col if isinstance(col, Column) else Column(ConstantExpression(col)) -def _invoke_function(function: str, *arguments): +def _invoke_function(function: str, *arguments) -> Column: return Column(FunctionExpression(function, *arguments)) @@ -167,15 +162,16 @@ def _to_column_expr(col: ColumnOrName) -> Expression: message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, ) + def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: r"""Replace all substrings of the specified string value that match regexp with rep. .. 
versionadded:: 1.5.0 - Examples + Examples: -------- - >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() + >>> df = spark.createDataFrame([("100-200",)], ["str"]) + >>> df.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() [Row(d='-----')] """ return _invoke_function( @@ -187,11 +183,8 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum ) -def slice( - x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] -) -> Column: - """ - Collection function: returns an array containing all the elements in `x` from index `start` +def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]) -> Column: + """Collection function: returns an array containing all the elements in `x` from index `start` (array indices start at 1, or from the end if `start` is negative) with the specified `length`. .. versionadded:: 2.4.0 @@ -208,17 +201,17 @@ def slice( length : :class:`~pyspark.sql.Column` or str or int column name, column, or int containing the length of the slice - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of array type. Subset of array. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() [Row(sliced=[2, 3]), Row(sliced=[5])] - """ + """ # noqa: D205 start = ConstantExpression(start) if isinstance(start, int) else _to_column_expr(start) length = ConstantExpression(length) if isinstance(length, int) else _to_column_expr(length) @@ -227,61 +220,8 @@ def slice( return _invoke_function("list_slice", _to_column_expr(x), start, end) -def asc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given column name. - - .. versionadded:: 1.3.0 - - .. versionchanged:: 3.4.0 - Supports Spark Connect. - - Parameters - ---------- - col : :class:`~pyspark.sql.Column` or str - target column to sort by in the ascending order. - - Returns - ------- - :class:`~pyspark.sql.Column` - the column specifying the order. - - Examples - -------- - Sort by the column 'id' in the descending order. - - >>> df = spark.range(5) - >>> df = df.sort(desc("id")) - >>> df.show() - +---+ - | id| - +---+ - | 4| - | 3| - | 2| - | 1| - | 0| - +---+ - - Sort by the column 'id' in the ascending order. - - >>> df.orderBy(asc("id")).show() - +---+ - | id| - +---+ - | 0| - | 1| - | 2| - | 3| - | 4| - +---+ - """ - return Column(_to_column_expr(col)).asc() - - def asc_nulls_first(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given + """Returns a sort expression based on the ascending order of the given column name, and null values return before non-null values. .. versionadded:: 2.4.0 @@ -294,16 +234,14 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- - >>> df1 = spark.createDataFrame([(1, "Bob"), - ... (0, None), - ... 
(2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(1, "Bob"), (0, None), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_first(df1.name)).show() +---+-----+ |age| name| @@ -313,13 +251,12 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: | 1| Bob| +---+-----+ - """ + """ # noqa: D205 return asc(col).nulls_first() def asc_nulls_last(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given + """Returns a sort expression based on the ascending order of the given column name, and null values appear after non-null values. .. versionadded:: 2.4.0 @@ -332,16 +269,14 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_last(df1.name)).show() +---+-----+ |age| name| @@ -351,50 +286,12 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: | 0| NULL| +---+-----+ - """ + """ # noqa: D205 return asc(col).nulls_last() -def desc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given column name. - - .. versionadded:: 1.3.0 - - .. versionchanged:: 3.4.0 - Supports Spark Connect. - - Parameters - ---------- - col : :class:`~pyspark.sql.Column` or str - target column to sort by in the descending order. - - Returns - ------- - :class:`~pyspark.sql.Column` - the column specifying the order. - - Examples - -------- - Sort by the column 'id' in the descending order. - - >>> spark.range(5).orderBy(desc("id")).show() - +---+ - | id| - +---+ - | 4| - | 3| - | 2| - | 1| - | 0| - +---+ - """ - return Column(_to_column_expr(col)).desc() - - def desc_nulls_first(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given + """Returns a sort expression based on the descending order of the given column name, and null values appear before non-null values. .. versionadded:: 2.4.0 @@ -407,16 +304,14 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_first(df1.name)).show() +---+-----+ |age| name| @@ -426,13 +321,12 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: | 2|Alice| +---+-----+ - """ + """ # noqa: D205 return desc(col).nulls_first() def desc_nulls_last(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given + """Returns a sort expression based on the descending order of the given column name, and null values appear after non-null values. .. versionadded:: 2.4.0 @@ -445,16 +339,14 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. 
- Examples + Examples: -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_last(df1.name)).show() +---+-----+ |age| name| @@ -464,13 +356,12 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: | 0| NULL| +---+-----+ - """ + """ # noqa: D205 return desc(col).nulls_last() def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: - """ - Returns the leftmost `len`(`len` can be string type) characters from the string `str`, + """Returns the leftmost `len`(`len` can be string type) characters from the string `str`, if `len` is less or equal than 0 the result is an empty string. .. versionadded:: 3.5.0 @@ -482,25 +373,30 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: len : :class:`~pyspark.sql.Column` or str Input column or strings, the leftmost `len`. - Examples + Examples: -------- - >>> df = spark.createDataFrame([("Spark SQL", 3,)], ['a', 'b']) - >>> df.select(left(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... 3, + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(left(df.a, df.b).alias("r")).collect() [Row(r='Spa')] - """ + """ # noqa: D205 len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( - FunctionExpression( - "array_slice", _to_column_expr(str), ConstantExpression(0), len - ) + FunctionExpression("array_slice", _to_column_expr(str), ConstantExpression(0), len) ) ) def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: - """ - Returns the rightmost `len`(`len` can be string type) characters from the string `str`, + """Returns the rightmost `len`(`len` can be string type) characters from the string `str`, if `len` is less or equal than 0 the result is an empty string. .. versionadded:: 3.5.0 @@ -512,25 +408,29 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: len : :class:`~pyspark.sql.Column` or str Input column or strings, the rightmost `len`. - Examples + Examples: -------- - >>> df = spark.createDataFrame([("Spark SQL", 3,)], ['a', 'b']) - >>> df.select(right(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... 3, + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(right(df.a, df.b).alias("r")).collect() [Row(r='SQL')] - """ + """ # noqa: D205 len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( - FunctionExpression( - "array_slice", _to_column_expr(str), -len, ConstantExpression(-1) - ) + FunctionExpression("array_slice", _to_column_expr(str), -len, ConstantExpression(-1)) ) ) -def levenshtein( - left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None -) -> Column: +def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None) -> Column: """Computes the Levenshtein distance of the two given strings. .. versionadded:: 1.5.0 @@ -551,17 +451,25 @@ def levenshtein( .. versionchanged: 3.5.0 Added ``threshold`` argument. - Returns + Returns: ------- :class:`~pyspark.sql.Column` Levenshtein distance as integer value. - Examples + Examples: -------- - >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) - >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + >>> df0 = spark.createDataFrame( + ... [ + ... ( + ... "kitten", + ... 
"sitting", + ... ) + ... ], + ... ["l", "r"], + ... ) + >>> df0.select(levenshtein("l", "r").alias("d")).collect() [Row(d=3)] - >>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect() + >>> df0.select(levenshtein("l", "r", 2).alias("d")).collect() [Row(d=-1)] """ distance = _invoke_function_over_columns("levenshtein", left, right) @@ -569,12 +477,13 @@ def levenshtein( return distance else: distance = _to_column_expr(distance) - return Column(CaseExpression(distance <= ConstantExpression(threshold), distance).otherwise(ConstantExpression(-1))) + return Column( + CaseExpression(distance <= ConstantExpression(threshold), distance).otherwise(ConstantExpression(-1)) + ) def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: - """ - Left-pad the string column to width `len` with `pad`. + """Left-pad the string column to width `len` with `pad`. .. versionadded:: 1.5.0 @@ -590,23 +499,27 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: pad : str chars to prepend. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left padded result. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(lpad(df.s, 6, "#").alias("s")).collect() [Row(s='##abcd')] """ return _invoke_function("lpad", _to_column_expr(col), ConstantExpression(len), ConstantExpression(pad)) def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: - """ - Right-pad the string column to width `len` with `pad`. + """Right-pad the string column to width `len` with `pad`. .. versionadded:: 1.5.0 @@ -622,23 +535,27 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: pad : str chars to append. - Returns + Returns: ------- :class:`~pyspark.sql.Column` right padded result. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(rpad(df.s, 6, "#").alias("s")).collect() [Row(s='abcd##')] """ return _invoke_function("rpad", _to_column_expr(col), ConstantExpression(len), ConstantExpression(pad)) def ascii(col: "ColumnOrName") -> Column: - """ - Computes the numeric value of the first character of the string column. + """Computes the numeric value of the first character of the string column. .. versionadded:: 1.5.0 @@ -650,12 +567,12 @@ def ascii(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` numeric value. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(ascii("value")).show() @@ -671,8 +588,7 @@ def ascii(col: "ColumnOrName") -> Column: def asin(col: "ColumnOrName") -> Column: - """ - Computes inverse sine of the input column. + """Computes inverse sine of the input column. .. versionadded:: 1.4.0 @@ -684,12 +600,12 @@ def asin(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse sine of `col`, as if computed by `java.lang.Math.asin()` - Examples + Examples: -------- >>> df = spark.createDataFrame([(0,), (2,)]) >>> df.select(asin(df.schema.fieldNames()[0])).show() @@ -701,15 +617,16 @@ def asin(col: "ColumnOrName") -> Column: +--------+ """ col = _to_column_expr(col) - # FIXME: ConstantExpression(float("nan")) gives NULL and not NaN - return Column(CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise(FunctionExpression("asin", col))) + # TODO: ConstantExpression(float("nan")) gives NULL and not NaN # noqa: TD002, TD003 + return Column( + CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise( + FunctionExpression("asin", col) + ) + ) -def like( - str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None -) -> Column: - """ - Returns true if str matches `pattern` with `escape`, +def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: + r"""Returns true if str matches `pattern` with `escape`, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -726,31 +643,24 @@ def like( If an escape character precedes a special symbol or another escape character, the following character is matched literally. It is invalid to escape any other character. - Examples + Examples: -------- - >>> df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - >>> df.select(like(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + >>> df.select(like(df.a, df.b).alias("r")).collect() [Row(r=True)] >>> df = spark.createDataFrame( - ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], - ... ['a', 'b'] + ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"] ... ) - >>> df.select(like(df.a, df.b, lit('/')).alias('r')).collect() + >>> df.select(like(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] - """ - if escapeChar is None: - escapeChar = ConstantExpression("\\") - else: - escapeChar = _to_column_expr(escapeChar) + """ # noqa: D205 + escapeChar = ConstantExpression("\\") if escapeChar is None else _to_column_expr(escapeChar) return _invoke_function("like_escape", _to_column_expr(str), _to_column_expr(pattern), escapeChar) -def ilike( - str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None -) -> Column: - """ - Returns true if str matches `pattern` with `escape` case-insensitively, +def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: + r"""Returns true if str matches `pattern` with `escape` case-insensitively, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -767,29 +677,24 @@ def ilike( If an escape character precedes a special symbol or another escape character, the following character is matched literally. It is invalid to escape any other character. - Examples + Examples: -------- - >>> df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - >>> df.select(ilike(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + >>> df.select(ilike(df.a, df.b).alias("r")).collect() [Row(r=True)] >>> df = spark.createDataFrame( - ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], - ... ['a', 'b'] + ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"] ... 
) - >>> df.select(ilike(df.a, df.b, lit('/')).alias('r')).collect() + >>> df.select(ilike(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] - """ - if escapeChar is None: - escapeChar = ConstantExpression("\\") - else: - escapeChar = _to_column_expr(escapeChar) + """ # noqa: D205 + escapeChar = ConstantExpression("\\") if escapeChar is None else _to_column_expr(escapeChar) return _invoke_function("ilike_escape", _to_column_expr(str), _to_column_expr(pattern), escapeChar) def array_agg(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns a list of objects with duplicates. + """Aggregate function: returns a list of objects with duplicates. .. versionadded:: 3.5.0 @@ -798,30 +703,29 @@ def array_agg(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` list of objects with duplicates. - Examples + Examples: -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) - >>> df.agg(array_agg('c').alias('r')).collect() + >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) + >>> df.agg(array_agg("c").alias("r")).collect() [Row(r=[1, 1, 2])] """ return _invoke_function_over_columns("list", col) def collect_list(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns a list of objects with duplicates. + """Aggregate function: returns a list of objects with duplicates. .. versionadded:: 1.6.0 .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because the order of collected results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -831,23 +735,22 @@ def collect_list(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` list of objects with duplicates. - Examples + Examples: -------- - >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) - >>> df2.agg(collect_list('age')).collect() + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ("age",)) + >>> df2.agg(collect_list("age")).collect() [Row(collect_list(age)=[2, 5, 5])] """ return array_agg(col) -def array_append(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: returns an array of the elements in col1 along +def array_append(col: "ColumnOrName", value: Union[Column, str]) -> Column: + """Collection function: returns an array of the elements in col1 along with the added element in col2 at the last of the array. .. versionadded:: 3.4.0 @@ -859,32 +762,29 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values from first array along with the element. - Notes + Notes: ----- Supports Spark Connect. 
- Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")]) >>> df.select(array_append(df.c1, df.c2)).collect() [Row(array_append(c1, c2)=['b', 'a', 'c', 'c'])] - >>> df.select(array_append(df.c1, 'x')).collect() + >>> df.select(array_append(df.c1, "x")).collect() [Row(array_append(c1, x)=['b', 'a', 'c', 'x'])] - """ + """ # noqa: D205 return _invoke_function("list_append", _to_column_expr(col), _get_expr(value)) -def array_insert( - arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any -) -> Column: - """ - Collection function: adds an item into a given array at a specified array index. +def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Union[Column, str]) -> Column: + """Collection function: adds an item into a given array at a specified array index. Array indices start at 1, or start from the end if index is negative. Index above array size appends the array, or prepends the array if index is negative, with 'null' elements. @@ -901,26 +801,25 @@ def array_insert( value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values, including the new specified value - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame( - ... [(['a', 'b', 'c'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')], - ... ['data', 'pos', 'val'] + ... [(["a", "b", "c"], 2, "d"), (["c", "b", "a"], -2, "d")], ["data", "pos", "val"] ... ) - >>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() + >>> df.select(array_insert(df.data, df.pos.cast("integer"), df.val).alias("data")).collect() [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'b', 'd', 'a'])] - >>> df.select(array_insert(df.data, 5, 'hello').alias('data')).collect() + >>> df.select(array_insert(df.data, 5, "hello").alias("data")).collect() [Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])] - """ + """ # noqa: D205 pos = _get_expr(pos) arr = _to_column_expr(arr) # Depending on if the position is positive or not, we need to interpret it differently. @@ -944,9 +843,7 @@ def array_insert( FunctionExpression( "list_resize", FunctionExpression("list_value", None), - FunctionExpression( - "subtract", FunctionExpression("abs", pos), list_length_plus_1 - ), + FunctionExpression("subtract", FunctionExpression("abs", pos), list_length_plus_1), ), arr, ), @@ -964,9 +861,7 @@ def array_insert( "list_slice", list_, 1, - CaseExpression( - pos_is_positive, FunctionExpression("subtract", pos, 1) - ).otherwise(pos), + CaseExpression(pos_is_positive, FunctionExpression("subtract", pos, 1)).otherwise(pos), ), # Here we insert the value at the specified position FunctionExpression("list_value", _get_expr(value)), @@ -975,17 +870,14 @@ def array_insert( FunctionExpression( "list_slice", list_, - CaseExpression(pos_is_positive, pos).otherwise( - FunctionExpression("add", pos, 1) - ), + CaseExpression(pos_is_positive, pos).otherwise(FunctionExpression("add", pos, 1)), -1, ), ) -def array_contains(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: returns null if the array is null, true if the array contains the +def array_contains(col: "ColumnOrName", value: Union[Column, str]) -> Column: + """Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise. 
Parameters @@ -995,26 +887,25 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: value : value or column to check for in array - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of Boolean type. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ["data"]) >>> df.select(array_contains(df.data, "a")).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] >>> df.select(array_contains(df.data, lit("a"))).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] - """ + """ # noqa: D205 value = _get_expr(value) return _invoke_function("array_contains", _to_column_expr(col), value) def array_distinct(col: "ColumnOrName") -> Column: - """ - Collection function: removes duplicate values from the array. + """Collection function: removes duplicate values from the array. .. versionadded:: 2.4.0 @@ -1026,14 +917,14 @@ def array_distinct(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of unique values. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ["data"]) >>> df.select(array_distinct(df.data)).collect() [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] """ @@ -1041,8 +932,7 @@ def array_distinct(col: "ColumnOrName") -> Column: def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Collection function: returns an array of the elements in the intersection of col1 and col2, + """Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates. .. versionadded:: 2.4.0 @@ -1057,24 +947,23 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str name of column containing array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values in the intersection of two arrays. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_intersect(df.c1, df.c2)).collect() [Row(array_intersect(c1, c2)=['a', 'c'])] - """ + """ # noqa: D205 return _invoke_function_over_columns("array_intersect", col1, col2) def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Collection function: returns an array of the elements in the union of col1 and col2, + """Collection function: returns an array of the elements in the union of col1 and col2, without duplicates. .. versionadded:: 2.4.0 @@ -1089,24 +978,23 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str name of column containing array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values in union of two arrays. 
- Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])] - """ + """ # noqa: D205 return _invoke_function_over_columns("array_distinct", _invoke_function_over_columns("array_concat", col1, col2)) def array_max(col: "ColumnOrName") -> Column: - """ - Collection function: returns the maximum value of the array. + """Collection function: returns the maximum value of the array. .. versionadded:: 2.4.0 @@ -1118,23 +1006,24 @@ def array_max(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` maximum value of an array. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_max(df.data).alias('max')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) + >>> df.select(array_max(df.data).alias("max")).collect() [Row(max=3), Row(max=10)] """ - return _invoke_function("array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(-1)) + return _invoke_function( + "array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(-1) + ) def array_min(col: "ColumnOrName") -> Column: - """ - Collection function: returns the minimum value of the array. + """Collection function: returns the minimum value of the array. .. versionadded:: 2.4.0 @@ -1146,23 +1035,24 @@ def array_min(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` minimum value of array. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_min(df.data).alias('min')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) + >>> df.select(array_min(df.data).alias("min")).collect() [Row(min=1), Row(min=-1)] """ - return _invoke_function("array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(1)) + return _invoke_function( + "array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(1) + ) def avg(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the average of the values in a group. + """Aggregate function: returns the average of the values in a group. .. versionadded:: 1.3.0 @@ -1174,12 +1064,12 @@ def avg(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(avg(col("id"))).show() @@ -1193,8 +1083,7 @@ def avg(col: "ColumnOrName") -> Column: def sum(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the sum of all values in the expression. + """Aggregate function: returns the sum of all values in the expression. .. versionadded:: 1.3.0 @@ -1206,12 +1095,12 @@ def sum(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. 
- Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(sum(df["id"])).show() @@ -1225,8 +1114,7 @@ def sum(col: "ColumnOrName") -> Column: def max(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the maximum value of the expression in a group. + """Aggregate function: returns the maximum value of the expression in a group. .. versionadded:: 1.3.0 @@ -1238,12 +1126,12 @@ def max(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(max(col("id"))).show() @@ -1257,8 +1145,7 @@ def max(col: "ColumnOrName") -> Column: def mean(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the average of the values in a group. + """Aggregate function: returns the average of the values in a group. An alias of :func:`avg`. .. versionadded:: 1.4.0 @@ -1271,12 +1158,12 @@ def mean(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(mean(df.id)).show() @@ -1285,13 +1172,12 @@ def mean(col: "ColumnOrName") -> Column: +-------+ | 4.5| +-------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("mean", col) def median(col: "ColumnOrName") -> Column: - """ - Returns the median of the values in a group. + """Returns the median of the values in a group. .. versionadded:: 3.4.0 @@ -1300,22 +1186,28 @@ def median(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the median of the values in a group. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("Java", 2012, 22000), ("dotNET", 2012, 10000), - ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], - ... schema=("course", "year", "earnings")) + >>> df = spark.createDataFrame( + ... [ + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("Java", 2012, 22000), + ... ("dotNET", 2012, 10000), + ... ("dotNET", 2013, 48000), + ... ("Java", 2013, 30000), + ... ], + ... schema=("course", "year", "earnings"), + ... ) >>> df.groupby("course").agg(median("earnings")).show() +------+----------------+ |course|median(earnings)| @@ -1328,8 +1220,7 @@ def median(col: "ColumnOrName") -> Column: def mode(col: "ColumnOrName") -> Column: - """ - Returns the most frequent value in a group. + """Returns the most frequent value in a group. .. versionadded:: 3.4.0 @@ -1338,22 +1229,28 @@ def mode(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the most frequent value in a group. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], - ... schema=("course", "year", "earnings")) + >>> df = spark.createDataFrame( + ... [ + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... 
("dotNET", 2013, 48000), + ... ("Java", 2013, 30000), + ... ], + ... schema=("course", "year", "earnings"), + ... ) >>> df.groupby("course").agg(mode("year")).show() +------+----------+ |course|mode(year)| @@ -1366,8 +1263,7 @@ def mode(col: "ColumnOrName") -> Column: def min(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the minimum value of the expression in a group. + """Aggregate function: returns the minimum value of the expression in a group. .. versionadded:: 1.3.0 @@ -1379,12 +1275,12 @@ def min(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(min(df.id)).show() @@ -1409,29 +1305,26 @@ def any_value(col: "ColumnOrName") -> Column: ignorenulls : :class:`~pyspark.sql.Column` or bool if first value is null then look for first non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` some value of `col` for a group of rows. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(None, 1), - ... ("a", 2), - ... ("a", 3), - ... ("b", 8), - ... ("b", 2)], ["c1", "c2"]) - >>> df.select(any_value('c1'), any_value('c2')).collect() + >>> df = spark.createDataFrame( + ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] + ... ) + >>> df.select(any_value("c1"), any_value("c2")).collect() [Row(any_value(c1)=None, any_value(c2)=1)] - >>> df.select(any_value('c1', True), any_value('c2', True)).collect() + >>> df.select(any_value("c1", True), any_value("c2", True)).collect() [Row(any_value(c1)='a', any_value(c2)=1)] """ return _invoke_function_over_columns("any_value", col) def count(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the number of items in a group. + """Aggregate function: returns the number of items in a group. .. versionadded:: 1.3.0 @@ -1443,12 +1336,12 @@ def count(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- Count by all columns (start), and by a column that does not count ``None``. @@ -1479,29 +1372,29 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C maximum relative standard deviation allowed (default = 0.05). For rsd < 0.01, it is more efficient to use :func:`count_distinct` - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column of computed results. - Examples + Examples: -------- - >>> df = spark.createDataFrame([1,2,2,3], "INT") - >>> df.agg(approx_count_distinct("value").alias('distinct_values')).show() + >>> df = spark.createDataFrame([1, 2, 2, 3], "INT") + >>> df.agg(approx_count_distinct("value").alias("distinct_values")).show() +---------------+ |distinct_values| +---------------+ | 3| +---------------+ - """ + """ # noqa: D205 if rsd is not None: - raise ValueError("rsd is not supported by DuckDB") + msg = "rsd is not supported by DuckDB" + raise ValueError(msg) return _invoke_function_over_columns("approx_count_distinct", col) def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: - """ - .. versionadded:: 1.3.0 + """.. versionadded:: 1.3.0. .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -1509,7 +1402,7 @@ def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Col .. 
deprecated:: 2.1.0 Use :func:`approx_count_distinct` instead. """ - warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", FutureWarning) + warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", FutureWarning, stacklevel=3) return approx_count_distinct(col, rsd) @@ -1525,8 +1418,7 @@ def transform( col: "ColumnOrName", f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], ) -> Column: - """ - Returns an array of elements after applying a transformation to each element in the input array. + """Returns an array of elements after applying a transformation to each element in the input array. .. versionadded:: 3.1.0 @@ -1550,12 +1442,12 @@ def transform( Python ``UserDefinedFunctions`` are not supported (`SPARK-27052 `__). - Returns + Returns: ------- :class:`~pyspark.sql.Column` a new array of transformed elements. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values")) >>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show() @@ -1567,7 +1459,6 @@ def transform( >>> def alternate(x, i): ... return when(i % 2 == 0, x).otherwise(-x) - ... >>> df.select(transform("values", alternate).alias("alternated")).show() +--------------+ | alternated| @@ -1579,8 +1470,7 @@ def transform( def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": - """ - Concatenates multiple input string columns together into a single string column, + """Concatenates multiple input string columns together into a single string column, using the given separator. .. versionadded:: 1.5.0 @@ -1595,24 +1485,23 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": cols : :class:`~pyspark.sql.Column` or str list of columns to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string of concatenated words. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() + >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) + >>> df.select(concat_ws("-", df.s, df.d).alias("s")).collect() [Row(s='abcd-123')] - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("concat_ws", ConstantExpression(sep), *cols) def lower(col: "ColumnOrName") -> Column: - """ - Converts a string expression to lower case. + """Converts a string expression to lower case. .. versionadded:: 1.5.0 @@ -1624,12 +1513,12 @@ def lower(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` lower case values. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(lower("value")).show() @@ -1645,8 +1534,7 @@ def lower(col: "ColumnOrName") -> Column: def lcase(str: "ColumnOrName") -> Column: - """ - Returns `str` with all characters changed to lowercase. + """Returns `str` with all characters changed to lowercase. .. versionadded:: 3.5.0 @@ -1655,7 +1543,7 @@ def lcase(str: "ColumnOrName") -> Column: str : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.lcase(sf.lit("Spark"))).show() @@ -1669,8 +1557,7 @@ def lcase(str: "ColumnOrName") -> Column: def ceil(col: "ColumnOrName") -> Column: - """ - Computes the ceiling of the given value. + """Computes the ceiling of the given value. .. 
versionadded:: 1.4.0 @@ -1682,12 +1569,12 @@ def ceil(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(ceil(lit(-0.1))).show() @@ -1700,13 +1587,12 @@ def ceil(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ceil", col) -def ceiling(col: "ColumnOrName") -> Column: +def ceiling(col: "ColumnOrName") -> Column: # noqa: D103 return ceil(col) def floor(col: "ColumnOrName") -> Column: - """ - Computes the floor of the given value. + """Computes the floor of the given value. .. versionadded:: 1.4.0 @@ -1718,12 +1604,12 @@ def floor(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str column to find floor for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` nearest integer that is less than or equal to given value. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(floor(lit(2.5))).show() @@ -1737,8 +1623,7 @@ def floor(col: "ColumnOrName") -> Column: def abs(col: "ColumnOrName") -> Column: - """ - Computes the absolute value. + """Computes the absolute value. .. versionadded:: 1.3.0 @@ -1750,12 +1635,12 @@ def abs(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(abs(lit(-1))).show() @@ -1781,14 +1666,14 @@ def isnan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` True if value is NaN and False otherwise. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df = spark.createDataFrame([(1.0, float("nan")), (float("nan"), 2.0)], ("a", "b")) >>> df.select("a", "b", isnan("a").alias("r1"), isnan(df.b).alias("r2")).show() +---+---+-----+-----+ | a| b| r1| r2| @@ -1813,12 +1698,12 @@ def isnull(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` True if value is null and False otherwise. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b")) >>> df.select("a", "b", isnull("a").alias("r1"), isnull(df.b).alias("r2")).show() @@ -1833,8 +1718,7 @@ def isnull(col: "ColumnOrName") -> Column: def isnotnull(col: "ColumnOrName") -> Column: - """ - Returns true if `col` is not null, or false otherwise. + """Returns true if `col` is not null, or false otherwise. .. versionadded:: 3.5.0 @@ -1842,42 +1726,53 @@ def isnotnull(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) - >>> df.select(isnotnull(df.e).alias('r')).collect() + >>> df.select(isnotnull(df.e).alias("r")).collect() [Row(r=False), Row(r=True)] """ return Column(_to_column_expr(col).isnotnull()) def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns same result as the EQUAL(=) operator for non-null operands, + """Returns same result as the EQUAL(=) operator for non-null operands, but returns true if both are null, false if one of the them is null. .. 
versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- - >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"]) - >>> df.select(equal_null(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... None, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(equal_null(df.a, df.b).alias("r")).collect() [Row(r=True), Row(r=False)] - """ + """ # noqa: D205, D415 if isinstance(col1, str): col1 = col(col1) if isinstance(col2, str): col2 = col(col2) - return nvl((col1 == col2) | ((col1.isNull() & col2.isNull())), lit(False)) + return nvl((col1 == col2) | (col1.isNull() & col2.isNull()), lit(False)) def flatten(col: "ColumnOrName") -> Column: - """ - Collection function: creates a single array from an array of arrays. + """Collection function: creates a single array from an array of arrays. If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -1891,14 +1786,14 @@ def flatten(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` flattened array. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) >>> df.show(truncate=False) +------------------------+ |data | @@ -1906,26 +1801,21 @@ def flatten(col: "ColumnOrName") -> Column: |[[1, 2, 3], [4, 5], [6]]| |[NULL, [4, 5]] | +------------------------+ - >>> df.select(flatten(df.data).alias('r')).show() + >>> df.select(flatten(df.data).alias("r")).show() +------------------+ | r| +------------------+ |[1, 2, 3, 4, 5, 6]| | NULL| +------------------+ - """ + """ # noqa: D205 col = _to_column_expr(col) contains_null = _list_contains_null(col) - return Column( - CaseExpression(contains_null, None).otherwise( - FunctionExpression("flatten", col) - ) - ) + return Column(CaseExpression(contains_null, None).otherwise(FunctionExpression("flatten", col))) def array_compact(col: "ColumnOrName") -> Column: - """ - Collection function: removes null values from the array. + """Collection function: removes null values from the array. .. versionadded:: 3.4.0 @@ -1934,18 +1824,18 @@ def array_compact(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array by excluding the null values. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) + >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) >>> df.select(array_compact(df.data)).collect() [Row(array_compact(data)=[1, 2, 3]), Row(array_compact(data)=[4, 5, 4])] """ @@ -1954,9 +1844,8 @@ def array_compact(col: "ColumnOrName") -> Column: ) -def array_remove(col: "ColumnOrName", element: Any) -> Column: - """ - Collection function: Remove all elements that equal to element from the given array. +def array_remove(col: "ColumnOrName", element: Any) -> Column: # noqa: ANN401 + """Collection function: Remove all elements that equal to element from the given array. .. 
versionadded:: 2.4.0 @@ -1970,23 +1859,24 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: element : element to be removed from the array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array excluding given value. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) >>> df.select(array_remove(df.data, 1)).collect() [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] """ - return _invoke_function("list_filter", _to_column_expr(col), LambdaExpression("x", ColumnExpression("x") != ConstantExpression(element))) + return _invoke_function( + "list_filter", _to_column_expr(col), LambdaExpression("x", ColumnExpression("x") != ConstantExpression(element)) + ) def last_day(date: "ColumnOrName") -> Column: - """ - Returns the last day of the month which the given date belongs to. + """Returns the last day of the month which the given date belongs to. .. versionadded:: 1.5.0 @@ -1998,24 +1888,22 @@ def last_day(date: "ColumnOrName") -> Column: date : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` last day of the month. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('1997-02-10',)], ['d']) - >>> df.select(last_day(df.d).alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-10",)], ["d"]) + >>> df.select(last_day(df.d).alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] """ return _invoke_function("last_day", _to_column_expr(date)) - def sqrt(col: "ColumnOrName") -> Column: - """ - Computes the square root of the specified float value. + """Computes the square root of the specified float value. .. versionadded:: 1.3.0 @@ -2027,12 +1915,12 @@ def sqrt(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(sqrt(lit(4))).show() @@ -2046,8 +1934,7 @@ def sqrt(col: "ColumnOrName") -> Column: def cbrt(col: "ColumnOrName") -> Column: - """ - Computes the cube-root of the given value. + """Computes the cube-root of the given value. .. versionadded:: 1.4.0 @@ -2059,12 +1946,12 @@ def cbrt(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(cbrt(lit(27))).show() @@ -2078,9 +1965,8 @@ def cbrt(col: "ColumnOrName") -> Column: def char(col: "ColumnOrName") -> Column: - """ - Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the - result is equivalent to char(col % 256) + """Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the + result is equivalent to char(col % 256). .. versionadded:: 3.5.0 @@ -2089,7 +1975,7 @@ def char(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str Input column or strings. 
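Aside for readers of this hunk (not part of the patch itself): the array_remove() shim above builds a DuckDB list_filter call whose lambda drops elements equal to the removed value. A minimal sketch of the equivalent SQL, assuming only that the duckdb Python package is installed locally:

import duckdb  # assumption: local duckdb installation, used purely for illustration

# Same shape as the LambdaExpression built in array_remove(): keep elements != 1
duckdb.sql("SELECT list_filter([1, 2, 3, 1, 1], x -> x != 1) AS r").show()
# expected: r = [2, 3], matching the array_remove doctest above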
- Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.char(sf.lit(65))).show() @@ -2098,7 +1984,7 @@ def char(col: "ColumnOrName") -> Column: +--------+ | A| +--------+ - """ + """ # noqa: D205 col = _to_column_expr(col) return Column(FunctionExpression("chr", CaseExpression(col > 256, col % 256).otherwise(col))) @@ -2119,25 +2005,24 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate correlation. - Returns + Returns: ------- :class:`~pyspark.sql.Column` Pearson Correlation Coefficient of these two column values. - Examples + Examples: -------- >>> a = range(20) >>> b = [2 * x for x in range(20)] >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(corr("a", "b").alias('c')).collect() + >>> df.agg(corr("a", "b").alias("c")).collect() [Row(c=1.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("corr", col1, col2) def cot(col: "ColumnOrName") -> Column: - """ - Computes cotangent of the input column. + """Computes cotangent of the input column. .. versionadded:: 3.3.0 @@ -2149,12 +2034,12 @@ def cot(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians. - Returns + Returns: ------- :class:`~pyspark.sql.Column` cotangent of the angle. - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -2169,7 +2054,7 @@ def e() -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.range(1).select(e()).show() +-----------------+ @@ -2182,18 +2067,19 @@ def e() -> Column: def negative(col: "ColumnOrName") -> Column: - """ - Returns the negative value. + """Returns the negative value. .. versionadded:: 3.5.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str column to calculate negative value for. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` negative value. - Examples + + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(3).select(sf.negative("id")).show() @@ -2204,7 +2090,7 @@ def negative(col: "ColumnOrName") -> Column: | -1| | -2| +------------+ - """ + """ # noqa: D205, D415 return abs(col) * -1 @@ -2213,7 +2099,7 @@ def pi() -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.range(1).select(pi()).show() +-----------------+ @@ -2226,8 +2112,7 @@ def pi() -> Column: def positive(col: "ColumnOrName") -> Column: - """ - Returns the value. + """Returns the value. .. versionadded:: 3.5.0 @@ -2236,14 +2121,14 @@ def positive(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str input value column. - Returns + Returns: ------- :class:`~pyspark.sql.Column` value. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ['v']) + >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) >>> df.select(positive("v").alias("p")).show() +---+ | p| @@ -2257,8 +2142,7 @@ def positive(col: "ColumnOrName") -> Column: def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """ - Returns the value of the first argument raised to the power of the second argument. + """Returns the value of the first argument raised to the power of the second argument. .. versionadded:: 1.4.0 @@ -2272,12 +2156,12 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) col2 : str, :class:`~pyspark.sql.Column` or float the exponent number. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` the base rased to the power the argument. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(pow(lit(3), lit(2))).first() @@ -2287,8 +2171,7 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: - """ - Formats the arguments in printf-style and returns the result as a string column. + r"""Formats the arguments in printf-style and returns the result as a string column. .. versionadded:: 3.5.0 @@ -2299,11 +2182,18 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or :class:`~pyspark.sql.Column`\\s to be used in formatting - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("aa%d%s", 123, "cc",)], ["a", "b", "c"] + ... [ + ... ( + ... "aa%d%s", + ... 123, + ... "cc", + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.printf("a", "b", "c")).show() +---------------+ |printf(a, b, c)| @@ -2315,8 +2205,7 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: def product(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the product of the values in a group. + """Aggregate function: returns the product of the values in a group. .. versionadded:: 3.2.0 @@ -2328,16 +2217,16 @@ def product(col: "ColumnOrName") -> Column: col : str, :class:`Column` column containing values to be multiplied together - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- - >>> df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3) - >>> prods = df.groupBy('mod3').agg(product('x').alias('product')) - >>> prods.orderBy('mod3').show() + >>> df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) + >>> prods = df.groupBy("mod3").agg(product("x").alias("product")) + >>> prods.orderBy("mod3").show() +----+-------+ |mod3|product| +----+-------+ @@ -2358,7 +2247,7 @@ def rand(seed: Optional[int] = None) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic in general case. @@ -2367,25 +2256,26 @@ def rand(seed: Optional[int] = None) -> Column: seed : int (default: None) seed value for random generator. - Returns + Returns: ------- :class:`~pyspark.sql.Column` random values. - Examples + Examples: -------- >>> from pyspark.sql import functions as sf - >>> spark.range(0, 2, 1, 1).withColumn('rand', sf.rand(seed=42) * 3).show() + >>> spark.range(0, 2, 1, 1).withColumn("rand", sf.rand(seed=42) * 3).show() +---+------------------+ | id| rand| +---+------------------+ | 0|1.8575681106759028| | 1|1.5288056527339444| +---+------------------+ - """ + """ # noqa: D205 if seed is not None: # Maybe call setseed just before but how do we know when it is executed? - raise ContributionsAcceptedError("Seed is not yet implemented") + msg = "Seed is not yet implemented" + raise ContributionsAcceptedError(msg) return _invoke_function("random") @@ -2401,17 +2291,17 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` true if `str` matches a Java regex, or false otherwise. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... 
[("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.lit(r'(\d+)'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.lit(r"(\d+)")) + ... ).show() +------------------+ |REGEXP(str, (\d+))| +------------------+ @@ -2419,9 +2309,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.lit(r'\d{2}b'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.lit(r"\d{2}b")) + ... ).show() +-------------------+ |REGEXP(str, \d{2}b)| +-------------------+ @@ -2429,9 +2319,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +-------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.col("regexp"))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.col("regexp")) + ... ).show() +-------------------+ |REGEXP(str, regexp)| +-------------------+ @@ -2454,21 +2344,21 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the number of times that a Java regex pattern is matched in the string. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - >>> df.select(regexp_count('str', lit(r'\d+')).alias('d')).collect() + >>> df.select(regexp_count("str", lit(r"\d+")).alias("d")).collect() [Row(d=3)] - >>> df.select(regexp_count('str', lit(r'mmm')).alias('d')).collect() + >>> df.select(regexp_count("str", lit(r"mmm")).alias("d")).collect() [Row(d=0)] - >>> df.select(regexp_count("str", col("regexp")).alias('d')).collect() + >>> df.select(regexp_count("str", col("regexp")).alias("d")).collect() [Row(d=3)] - """ + """ # noqa: D205 return _invoke_function_over_columns("len", _invoke_function_over_columns("regexp_extract_all", str, regexp)) @@ -2490,29 +2380,29 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: idx : int matched group id. - Returns + Returns: ------- :class:`~pyspark.sql.Column` matched value specified by `idx` group id. 
- Examples + Examples: -------- - >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() + >>> df = spark.createDataFrame([("100-200",)], ["str"]) + >>> df.select(regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() [Row(d='100')] - >>> df = spark.createDataFrame([('foo',)], ['str']) - >>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect() + >>> df = spark.createDataFrame([("foo",)], ["str"]) + >>> df.select(regexp_extract("str", r"(\d+)", 1).alias("d")).collect() [Row(d='')] - >>> df = spark.createDataFrame([('aaaac',)], ['str']) - >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() + >>> df = spark.createDataFrame([("aaaac",)], ["str"]) + >>> df.select(regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() [Row(d='')] - """ - return _invoke_function("regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx)) + """ # noqa: D205 + return _invoke_function( + "regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx) + ) -def regexp_extract_all( - str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None -) -> Column: +def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: r"""Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. @@ -2527,26 +2417,28 @@ def regexp_extract_all( idx : int matched group id. - Returns + Returns: ------- :class:`~pyspark.sql.Column` all strings in the `str` that match a Java regex and corresponding to the regex group index. - Examples + Examples: -------- >>> df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)')).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)")).alias("d")).collect() [Row(d=['100', '300'])] - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)'), 1).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)"), 1).alias("d")).collect() [Row(d=['100', '300'])] - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)'), 2).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)"), 2).alias("d")).collect() [Row(d=['200', '400'])] - >>> df.select(regexp_extract_all('str', col("regexp")).alias('d')).collect() + >>> df.select(regexp_extract_all("str", col("regexp")).alias("d")).collect() [Row(d=['100', '300'])] - """ + """ # noqa: D205 if idx is None: idx = 1 - return _invoke_function("regexp_extract_all", _to_column_expr(str), _to_column_expr(regexp), ConstantExpression(idx)) + return _invoke_function( + "regexp_extract_all", _to_column_expr(str), _to_column_expr(regexp), ConstantExpression(idx) + ) def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: @@ -2561,17 +2453,17 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` true if `str` matches a Java regex, or false otherwise. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... 
).select(sf.regexp_like('str', sf.lit(r'(\d+)'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.lit(r"(\d+)")) + ... ).show() +-----------------------+ |REGEXP_LIKE(str, (\d+))| +-----------------------+ @@ -2579,9 +2471,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +-----------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.lit(r'\d{2}b'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.lit(r"\d{2}b")) + ... ).show() +------------------------+ |REGEXP_LIKE(str, \d{2}b)| +------------------------+ @@ -2589,9 +2481,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +------------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.col("regexp"))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.col("regexp")) + ... ).show() +------------------------+ |REGEXP_LIKE(str, regexp)| +------------------------+ @@ -2614,27 +2506,32 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the substring that matches a Java regex within the string `str`. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - >>> df.select(regexp_substr('str', lit(r'\d+')).alias('d')).collect() + >>> df.select(regexp_substr("str", lit(r"\d+")).alias("d")).collect() [Row(d='1')] - >>> df.select(regexp_substr('str', lit(r'mmm')).alias('d')).collect() + >>> df.select(regexp_substr("str", lit(r"mmm")).alias("d")).collect() [Row(d=None)] - >>> df.select(regexp_substr("str", col("regexp")).alias('d')).collect() + >>> df.select(regexp_substr("str", col("regexp")).alias("d")).collect() [Row(d='1')] - """ - return Column(FunctionExpression("nullif", FunctionExpression("regexp_extract", _to_column_expr(str), _to_column_expr(regexp)), ConstantExpression(""))) + """ # noqa: D205 + return Column( + FunctionExpression( + "nullif", + FunctionExpression("regexp_extract", _to_column_expr(str), _to_column_expr(regexp)), + ConstantExpression(""), + ) + ) def repeat(col: "ColumnOrName", n: int) -> Column: - """ - Repeats a string column n times, and returns it as a new string column. + """Repeats a string column n times, and returns it as a new string column. .. versionadded:: 1.5.0 @@ -2648,25 +2545,27 @@ def repeat(col: "ColumnOrName", n: int) -> Column: n : int number of times to repeat value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string with repeated values. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('ab',)], ['s',]) - >>> df.select(repeat(df.s, 3).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("ab",)], + ... [ + ... "s", + ... ], + ... 
) + >>> df.select(repeat(df.s, 3).alias("s")).collect() [Row(s='ababab')] """ return _invoke_function("repeat", _to_column_expr(col), ConstantExpression(n)) -def sequence( - start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None -) -> Column: - """ - Generate a sequence of integers from `start` to `stop`, incrementing by `step`. +def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None) -> Column: + """Generate a sequence of integers from `start` to `stop`, incrementing by `step`. If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, otherwise -1. @@ -2684,20 +2583,20 @@ def sequence( step : :class:`~pyspark.sql.Column` or str, optional value to add to current to get next element (default is 1) - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of sequence values - Examples + Examples: -------- - >>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) - >>> df1.select(sequence('C1', 'C2').alias('r')).collect() + >>> df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) + >>> df1.select(sequence("C1", "C2").alias("r")).collect() [Row(r=[-2, -1, 0, 1, 2])] - >>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) - >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect() + >>> df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) + >>> df2.select(sequence("C1", "C2", "C3").alias("r")).collect() [Row(r=[4, 2, 0, -2, -4])] - """ + """ # noqa: D205 if step is None: return _invoke_function_over_columns("generate_series", start, stop) else: @@ -2705,8 +2604,7 @@ def sequence( def sign(col: "ColumnOrName") -> Column: - """ - Computes the signum of the given value. + """Computes the signum of the given value. .. versionadded:: 1.4.0 @@ -2718,18 +2616,15 @@ def sign(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select( - ... sf.sign(sf.lit(-5)), - ... sf.sign(sf.lit(6)) - ... ).show() + >>> spark.range(1).select(sf.sign(sf.lit(-5)), sf.sign(sf.lit(6))).show() +--------+-------+ |sign(-5)|sign(6)| +--------+-------+ @@ -2740,8 +2635,7 @@ def sign(col: "ColumnOrName") -> Column: def signum(col: "ColumnOrName") -> Column: - """ - Computes the signum of the given value. + """Computes the signum of the given value. .. versionadded:: 1.4.0 @@ -2753,18 +2647,15 @@ def signum(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select( - ... sf.signum(sf.lit(-5)), - ... sf.signum(sf.lit(6)) - ... ).show() + >>> spark.range(1).select(sf.signum(sf.lit(-5)), sf.signum(sf.lit(6))).show() +----------+---------+ |SIGNUM(-5)|SIGNUM(6)| +----------+---------+ @@ -2775,8 +2666,7 @@ def signum(col: "ColumnOrName") -> Column: def sin(col: "ColumnOrName") -> Column: - """ - Computes sine of the input column. + """Computes sine of the input column. .. versionadded:: 1.4.0 @@ -2788,12 +2678,12 @@ def sin(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. 
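Illustrative aside (outside the patch): sequence() above forwards to DuckDB's generate_series, whose stop value is inclusive; the optional step argument is passed through unchanged. A minimal sketch, assuming a local duckdb installation:

import duckdb  # assumption: duckdb installed locally

duckdb.sql("SELECT generate_series(-2, 2) AS r").show()      # expected: [-2, -1, 0, 1, 2]
duckdb.sql("SELECT generate_series(4, -4, -2) AS r").show()  # expected: [4, 2, 0, -2, -4]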
- Returns + Returns: ------- :class:`~pyspark.sql.Column` sine of the angle, as if computed by `java.lang.Math.sin()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -2804,8 +2694,7 @@ def sin(col: "ColumnOrName") -> Column: def skewness(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the skewness of the values in a group. + """Aggregate function: returns the skewness of the values in a group. .. versionadded:: 1.6.0 @@ -2817,14 +2706,14 @@ def skewness(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` skewness of given column. - Examples + Examples: -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) + >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) >>> df.select(skewness(df.c)).first() Row(skewness(c)=0.70710...) """ @@ -2832,8 +2721,7 @@ def skewness(col: "ColumnOrName") -> Column: def encode(col: "ColumnOrName", charset: str) -> Column: - """ - Computes the first argument into a binary from a string using the provided character set + """Computes the first argument into a binary from a string using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). .. versionadded:: 1.5.0 @@ -2848,29 +2736,29 @@ def encode(col: "ColumnOrName", charset: str) -> Column: charset : str charset to use to encode. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('abcd',)], ['c']) + >>> df = spark.createDataFrame([("abcd",)], ["c"]) >>> df.select(encode("c", "UTF-8")).show() +----------------+ |encode(c, UTF-8)| +----------------+ | [61 62 63 64]| +----------------+ - """ + """ # noqa: D205 if charset != "UTF-8": - raise ContributionsAcceptedError("Only UTF-8 charset is supported right now") + msg = "Only UTF-8 charset is supported right now" + raise ContributionsAcceptedError(msg) return _invoke_function("encode", _to_column_expr(col)) def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: - """ - Returns the index (1-based) of the given string (`str`) in the comma-delimited + """Returns the index (1-based) of the given string (`str`) in the comma-delimited list (`strArray`). Returns 0, if the string was not found or if the given string (`str`) contains a comma. @@ -2883,26 +2771,22 @@ def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: str_array : :class:`~pyspark.sql.Column` or str The comma-delimited list. 
- Examples + Examples: -------- - >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ['a', 'b']) - >>> df.select(find_in_set(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ["a", "b"]) + >>> df.select(find_in_set(df.a, df.b).alias("r")).collect() [Row(r=3)] - """ + """ # noqa: D205 str_array = _to_column_expr(str_array) str = _to_column_expr(str) return Column( - CaseExpression( - FunctionExpression("contains", str, ConstantExpression(",")), 0 - ).otherwise( + CaseExpression(FunctionExpression("contains", str, ConstantExpression(",")), 0).otherwise( CoalesceOperator( FunctionExpression( - "list_position", - FunctionExpression("string_split", str_array, ConstantExpression(",")), - str + "list_position", FunctionExpression("string_split", str_array, ConstantExpression(",")), str ), # If the element cannot be found, list_position returns null but we want to return 0 - ConstantExpression(0) + ConstantExpression(0), ) ) ) @@ -2919,7 +2803,7 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -2931,12 +2815,12 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: ignorenulls : :class:`~pyspark.sql.Column` or str if first value is null then look for first non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` first value of the group. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) >>> df = df.orderBy(df.age) @@ -2974,7 +2858,7 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -2986,12 +2870,12 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: ignorenulls : :class:`~pyspark.sql.Column` or str if last value is null then look for non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` last value of the group. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) >>> df = df.orderBy(df.age.desc()) @@ -3018,10 +2902,8 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: return _invoke_function_over_columns("last", col) - def greatest(*cols: "ColumnOrName") -> Column: - """ - Returns the greatest value of the list of column names, skipping null values. + """Returns the greatest value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 @@ -3034,28 +2916,27 @@ def greatest(*cols: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str columns to check for gratest value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` gratest value. 
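Aside, not part of the patch: find_in_set() above is assembled from contains(), string_split(), list_position() and coalesce(). The same expression can be written in plain DuckDB SQL; a sketch, assuming a local duckdb installation:

import duckdb  # assumption: duckdb installed locally

# comma in the needle -> 0; otherwise 1-based position in the split list, 0 if absent
duckdb.sql("""
    SELECT CASE WHEN contains('ab', ',') THEN 0
           ELSE coalesce(list_position(string_split('abc,b,ab,c,def', ','), 'ab'), 0)
           END AS r
""").show()
# expected: r = 3, matching the find_in_set doctest above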
- Examples + Examples: -------- - >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] - """ - + """ # noqa: D205 if len(cols) < 2: - raise ValueError("greatest should take at least 2 columns") + msg = "greatest should take at least 2 columns" + raise ValueError(msg) cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("greatest", *cols) def least(*cols: "ColumnOrName") -> Column: - """ - Returns the least value of the list of column names, skipping null values. + """Returns the least value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 @@ -3068,27 +2949,27 @@ def least(*cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or columns to be compared - Returns + Returns: ------- :class:`~pyspark.sql.Column` least value. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() [Row(least=1)] - """ + """ # noqa: D205 if len(cols) < 2: - raise ValueError("least should take at least 2 columns") + msg = "least should take at least 2 columns" + raise ValueError(msg) cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("least", *cols) def trim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from left end for the specified string value. + """Trim the spaces from left end for the specified string value. .. versionadded:: 1.5.0 @@ -3100,12 +2981,12 @@ def trim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3121,8 +3002,7 @@ def trim(col: "ColumnOrName") -> Column: def rtrim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from right end for the specified string value. + """Trim the spaces from right end for the specified string value. .. versionadded:: 1.5.0 @@ -3134,12 +3014,12 @@ def rtrim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` right trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3155,8 +3035,7 @@ def rtrim(col: "ColumnOrName") -> Column: def ltrim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from left end for the specified string value. + """Trim the spaces from left end for the specified string value. .. versionadded:: 1.5.0 @@ -3168,12 +3047,12 @@ def ltrim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left trimmed values. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3189,8 +3068,7 @@ def ltrim(col: "ColumnOrName") -> Column: def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: - """ - Remove the leading and trailing `trim` characters from `str`. + """Remove the leading and trailing `trim` characters from `str`. .. versionadded:: 3.5.0 @@ -3201,14 +3079,22 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: trim : :class:`~pyspark.sql.Column` or str The trim string characters to trim, the default value is a single space - Examples + Examples: -------- - >>> df = spark.createDataFrame([("SSparkSQLS", "SL", )], ['a', 'b']) - >>> df.select(btrim(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "SSparkSQLS", + ... "SL", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(btrim(df.a, df.b).alias("r")).collect() [Row(r='parkSQ')] - >>> df = spark.createDataFrame([(" SparkSQL ",)], ['a']) - >>> df.select(btrim(df.a).alias('r')).collect() + >>> df = spark.createDataFrame([(" SparkSQL ",)], ["a"]) + >>> df.select(btrim(df.a).alias("r")).collect() [Row(r='SparkSQL')] """ if trim is not None: @@ -3218,8 +3104,7 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if str ends with suffix. + """Returns a boolean. The value is True if str ends with suffix. Returns NULL if either input expression is NULL. Otherwise, returns False. Both str or suffix must be of STRING or BINARY type. @@ -3232,13 +3117,29 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: suffix : :class:`~pyspark.sql.Column` or str A column of string, the suffix. - Examples + Examples: -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark",)], ["a", "b"]) - >>> df.select(endswith(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... "Spark", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(endswith(df.a, df.b).alias("r")).collect() [Row(r=False)] - >>> df = spark.createDataFrame([("414243", "4243",)], ["e", "f"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4243", + ... ) + ... ], + ... ["e", "f"], + ... ) >>> df = df.select(to_binary("e").alias("e"), to_binary("f").alias("f")) >>> df.printSchema() root @@ -3250,13 +3151,12 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("suffix", str, suffix) def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if str starts with prefix. + """Returns a boolean. The value is True if str starts with prefix. Returns NULL if either input expression is NULL. Otherwise, returns False. Both str or prefix must be of STRING or BINARY type. @@ -3269,13 +3169,29 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: prefix : :class:`~pyspark.sql.Column` or str A column of string, the prefix. 
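Aside, not part of the patch: endswith() above delegates to DuckDB's suffix() text function, and startswith() (continued just below) to starts_with(). A minimal sketch of both, assuming a local duckdb installation:

import duckdb  # assumption: duckdb installed locally

duckdb.sql("SELECT starts_with('Spark SQL', 'Spark') AS s, suffix('Spark SQL', 'Spark') AS e").show()
# expected: s = true, e = false, matching the startswith/endswith doctests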
- Examples + Examples: -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark",)], ["a", "b"]) - >>> df.select(startswith(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... "Spark", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(startswith(df.a, df.b).alias("r")).collect() [Row(r=True)] - >>> df = spark.createDataFrame([("414243", "4142",)], ["e", "f"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4142", + ... ) + ... ], + ... ["e", "f"], + ... ) >>> df = df.select(to_binary("e").alias("e"), to_binary("f").alias("f")) >>> df.printSchema() root @@ -3287,7 +3203,7 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: +----------------+----------------+ | true| false| +----------------+----------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("starts_with", str, prefix) @@ -3306,16 +3222,16 @@ def length(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` length of the value. - Examples + Examples: -------- - >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() + >>> spark.createDataFrame([("ABC ",)], ["a"]).select(length("a").alias("length")).collect() [Row(length=4)] - """ + """ # noqa: D205 return _invoke_function_over_columns("length", col) @@ -3328,11 +3244,13 @@ def coalesce(*cols: "ColumnOrName") -> Column: ---------- cols : :class:`~pyspark.sql.Column` or str list of columns to work on. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` value of the first column that is not null. - Examples + + Examples: -------- >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) >>> cDf.show() @@ -3351,7 +3269,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1| | 2| +--------------+ - >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() + >>> cDf.select("*", coalesce(cDf["a"], lit(0.0))).show() +----+----+----------------+ | a| b|coalesce(a, 0.0)| +----+----+----------------+ @@ -3359,33 +3277,42 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1|NULL| 1.0| |NULL| 2| 0.0| +----+----+----------------+ - """ - + """ # noqa: D205, D415 cols = [_to_column_expr(expr) for expr in cols] return Column(CoalesceOperator(*cols)) def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is null, or `col1` otherwise. + """Returns `col2` if `col1` is null, or `col1` otherwise. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- - >>> df = spark.createDataFrame([(None, 8,), (1, 9,)], ["a", "b"]) - >>> df.select(nvl(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... 8, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] - """ - + """ # noqa: D205, D415 return coalesce(col1, col2) def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is not null, or `col3` otherwise. + """Returns `col2` if `col1` is not null, or `col3` otherwise. .. 
versionadded:: 3.5.0 @@ -3395,10 +3322,24 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co col2 : :class:`~pyspark.sql.Column` or str col3 : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- - >>> df = spark.createDataFrame([(None, 8, 6,), (1, 9, 9,)], ["a", "b", "c"]) - >>> df.select(nvl2(df.a, df.b, df.c).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... 8, + ... 6, + ... ), + ... ( + ... 1, + ... 9, + ... 9, + ... ), + ... ], + ... ["a", "b", "c"], + ... ) + >>> df.select(nvl2(df.a, df.b, df.c).alias("r")).collect() [Row(r=6), Row(r=9)] """ col1 = _to_column_expr(col1) @@ -3408,14 +3349,14 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is null, or `col1` otherwise. + """Returns `col2` if `col1` is null, or `col1` otherwise. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) @@ -3426,13 +3367,12 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: | 8| | 1| +------------+ - """ + """ # noqa: D205, D415 return coalesce(col1, col2) def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns null if `col1` equals to `col2`, or `col1` otherwise. + """Returns null if `col1` equals to `col2`, or `col1` otherwise. .. versionadded:: 3.5.0 @@ -3441,10 +3381,22 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- - >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"]) - >>> df.select(nullif(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... None, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(nullif(df.a, df.b).alias("r")).collect() [Row(r=None), Row(r=1)] """ return _invoke_function_over_columns("nullif", col1, col2) @@ -3463,14 +3415,14 @@ def md5(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- - >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + >>> spark.createDataFrame([("ABC",)], ["a"]).select(md5("a").alias("hash")).collect() [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] """ return _invoke_function_over_columns("md5", col) @@ -3494,12 +3446,12 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) >>> df.withColumn("sha2", sha2(df.name, 256)).show(truncate=False) @@ -3509,47 +3461,44 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: |Alice|3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043| |Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961| +-----+----------------------------------------------------------------+ - """ - + """ # noqa: D205 if numBits not in {224, 256, 384, 512, 0}: - raise ValueError("numBits should be one of {224, 256, 384, 512, 0}") + msg = "numBits should be one of {224, 256, 384, 512, 0}" + raise ValueError(msg) if numBits == 256: return _invoke_function_over_columns("sha256", col) - raise ContributionsAcceptedError( - "SHA-224, SHA-384, and SHA-512 are not supported yet." - ) + msg = "SHA-224, SHA-384, and SHA-512 are not supported yet." + raise ContributionsAcceptedError(msg) def curdate() -> Column: - """ - Returns the current date at the start of query evaluation as a :class:`DateType` column. + """Returns the current date at the start of query evaluation as a :class:`DateType` column. All calls of current_date within the same query return the same value. .. versionadded:: 3.5.0 - Returns + Returns: ------- :class:`~pyspark.sql.Column` current date. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP + >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ | 2022-08-26| +--------------+ - """ + """ # noqa: D205 return _invoke_function("today") def current_date() -> Column: - """ - Returns the current date at the start of query evaluation as a :class:`DateType` column. + """Returns the current date at the start of query evaluation as a :class:`DateType` column. All calls of current_date within the same query return the same value. .. versionadded:: 1.5.0 @@ -3557,39 +3506,38 @@ def current_date() -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Returns + Returns: ------- :class:`~pyspark.sql.Column` current date. - Examples + Examples: -------- >>> df = spark.range(1) - >>> df.select(current_date()).show() # doctest: +SKIP + >>> df.select(current_date()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ | 2022-08-26| +--------------+ - """ + """ # noqa: D205 return curdate() def now() -> Column: - """ - Returns the current timestamp at the start of query evaluation. + """Returns the current timestamp at the start of query evaluation. .. versionadded:: 3.5.0 - Returns + Returns: ------- :class:`~pyspark.sql.Column` current timestamp at the start of query evaluation. - Examples + Examples: -------- >>> df = spark.range(1) - >>> df.select(now()).show(truncate=False) # doctest: +SKIP + >>> df.select(now()).show(truncate=False) # doctest: +SKIP +-----------------------+ |now() | +-----------------------+ @@ -3598,9 +3546,9 @@ def now() -> Column: """ return _invoke_function("now") + def desc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given column name. + """Returns a sort expression based on the descending order of the given column name. .. versionadded:: 1.3.0 @@ -3612,12 +3560,12 @@ def desc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. 
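Aside, not part of the patch: sha2() above only accepts numBits == 256, which it maps onto DuckDB's sha256(); the other bit lengths raise ContributionsAcceptedError. A minimal sketch, assuming a local duckdb installation:

import duckdb  # assumption: duckdb installed locally

duckdb.sql("SELECT sha256('Alice') AS hash").show()
# expected: the same 64-character digest shown for 'Alice' in the sha2() doctest above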
- Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -3634,9 +3582,9 @@ def desc(col: "ColumnOrName") -> Column: """ return Column(_to_column_expr(col).desc()) + def asc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given column name. + """Returns a sort expression based on the ascending order of the given column name. .. versionadded:: 1.3.0 @@ -3648,12 +3596,12 @@ def asc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -3685,9 +3633,9 @@ def asc(col: "ColumnOrName") -> Column: """ return Column(_to_column_expr(col).asc()) + def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: - """ - Returns timestamp truncated to the unit specified by the format. + """Returns timestamp truncated to the unit specified by the format. .. versionadded:: 2.3.0 @@ -3698,12 +3646,12 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: 'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter' timestamp : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- - >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t']) - >>> df.select(date_trunc('year', df.t).alias('year')).collect() + >>> df = spark.createDataFrame([("1997-02-28 05:02:11",)], ["t"]) + >>> df.select(date_trunc("year", df.t).alias("year")).collect() [Row(year=datetime.datetime(1997, 1, 1, 0, 0))] - >>> df.select(date_trunc('mon', df.t).alias('month')).collect() + >>> df.select(date_trunc("mon", df.t).alias("month")).collect() [Row(month=datetime.datetime(1997, 2, 1, 0, 0))] """ format = format.lower() @@ -3719,8 +3667,7 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3732,22 +3679,22 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... date_part(lit('YEAR'), 'ts').alias('year'), - ... date_part(lit('month'), 'ts').alias('month'), - ... date_part(lit('WEEK'), 'ts').alias('week'), - ... date_part(lit('D'), 'ts').alias('day'), - ... date_part(lit('M'), 'ts').alias('minute'), - ... date_part(lit('S'), 'ts').alias('second') + ... date_part(lit("YEAR"), "ts").alias("year"), + ... date_part(lit("month"), "ts").alias("month"), + ... date_part(lit("WEEK"), "ts").alias("week"), + ... date_part(lit("D"), "ts").alias("day"), + ... date_part(lit("M"), "ts").alias("minute"), + ... date_part(lit("S"), "ts").alias("second"), ... 
).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3755,8 +3702,7 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3767,22 +3713,22 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... extract(lit('YEAR'), 'ts').alias('year'), - ... extract(lit('month'), 'ts').alias('month'), - ... extract(lit('WEEK'), 'ts').alias('week'), - ... extract(lit('D'), 'ts').alias('day'), - ... extract(lit('M'), 'ts').alias('minute'), - ... extract(lit('S'), 'ts').alias('second') + ... extract(lit("YEAR"), "ts").alias("year"), + ... extract(lit("month"), "ts").alias("month"), + ... extract(lit("WEEK"), "ts").alias("week"), + ... extract(lit("D"), "ts").alias("day"), + ... extract(lit("M"), "ts").alias("minute"), + ... extract(lit("S"), "ts").alias("second"), ... ).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3790,8 +3736,7 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3803,22 +3748,22 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... datepart(lit('YEAR'), 'ts').alias('year'), - ... datepart(lit('month'), 'ts').alias('month'), - ... datepart(lit('WEEK'), 'ts').alias('week'), - ... datepart(lit('D'), 'ts').alias('day'), - ... datepart(lit('M'), 'ts').alias('minute'), - ... datepart(lit('S'), 'ts').alias('second') + ... datepart(lit("YEAR"), "ts").alias("year"), + ... datepart(lit("month"), "ts").alias("month"), + ... datepart(lit("WEEK"), "ts").alias("week"), + ... datepart(lit("D"), "ts").alias("day"), + ... datepart(lit("M"), "ts").alias("minute"), + ... datepart(lit("S"), "ts").alias("second"), ... ).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3826,8 +3771,7 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: - """ - Returns the number of days from `start` to `end`. + """Returns the number of days from `start` to `end`. .. 
versionadded:: 3.5.0 @@ -3838,12 +3782,12 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: start : :class:`~pyspark.sql.Column` or column name from date column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` difference in days between two dates. - See Also + See Also: -------- :meth:`pyspark.sql.functions.dateadd` :meth:`pyspark.sql.functions.date_add` @@ -3851,18 +3795,22 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: :meth:`pyspark.sql.functions.datediff` :meth:`pyspark.sql.functions.timestamp_diff` - Examples + Examples: -------- >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) - >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + >>> df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) + >>> df.select( + ... "*", sf.date_diff(sf.col("d1").cast("DATE"), sf.col("d2").cast("DATE")) + ... ).show() +----------+----------+-----------------+ | d1| d2|date_diff(d1, d2)| +----------+----------+-----------------+ |2015-04-08|2015-05-10| -32| +----------+----------+-----------------+ - >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + >>> df.select( + ... "*", sf.date_diff(sf.col("d1").cast("DATE"), sf.col("d2").cast("DATE")) + ... ).show() +----------+----------+-----------------+ | d1| d2|date_diff(d2, d1)| +----------+----------+-----------------+ @@ -3873,8 +3821,7 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: def year(col: "ColumnOrName") -> Column: - """ - Extract the year of a given date/timestamp as integer. + """Extract the year of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -3886,23 +3833,22 @@ def year(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` year part of the date/timestamp as integer. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(year('dt').alias('year')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(year("dt").alias("year")).collect() [Row(year=2015)] """ return _invoke_function_over_columns("year", col) def quarter(col: "ColumnOrName") -> Column: - """ - Extract the quarter of a given date/timestamp as integer. + """Extract the quarter of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -3914,23 +3860,22 @@ def quarter(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` quarter of the date/timestamp as integer. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(quarter('dt').alias('quarter')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(quarter("dt").alias("quarter")).collect() [Row(quarter=2)] """ return _invoke_function_over_columns("quarter", col) def month(col: "ColumnOrName") -> Column: - """ - Extract the month of a given date/timestamp as integer. + """Extract the month of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -3942,24 +3887,23 @@ def month(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` month part of the date/timestamp as integer. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(month('dt').alias('month')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(month("dt").alias("month")).collect() [Row(month=4)] """ return _invoke_function_over_columns("month", col) def dayofweek(col: "ColumnOrName") -> Column: - """ - Extract the day of the week of a given date/timestamp as integer. - Ranges from 1 for a Sunday through to 7 for a Saturday + """Extract the day of the week of a given date/timestamp as integer. + Ranges from 1 for a Sunday through to 7 for a Saturday. .. versionadded:: 2.3.0 @@ -3971,23 +3915,22 @@ def dayofweek(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the week for given date/timestamp as integer. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofweek('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofweek("dt").alias("day")).collect() [Row(day=4)] - """ + """ # noqa: D205 return _invoke_function_over_columns("dayofweek", col) + lit(1) def day(col: "ColumnOrName") -> Column: - """ - Extract the day of the month of a given date/timestamp as integer. + """Extract the day of the month of a given date/timestamp as integer. .. versionadded:: 3.5.0 @@ -3996,23 +3939,22 @@ def day(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the month for given date/timestamp as integer. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(day('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(day("dt").alias("day")).collect() [Row(day=8)] """ return _invoke_function_over_columns("day", col) def dayofmonth(col: "ColumnOrName") -> Column: - """ - Extract the day of the month of a given date/timestamp as integer. + """Extract the day of the month of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4024,23 +3966,22 @@ def dayofmonth(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the month for given date/timestamp as integer. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofmonth('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofmonth("dt").alias("day")).collect() [Row(day=8)] """ return day(col) def dayofyear(col: "ColumnOrName") -> Column: - """ - Extract the day of the year of a given date/timestamp as integer. + """Extract the day of the year of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4052,23 +3993,22 @@ def dayofyear(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the year for given date/timestamp as integer. 
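# Editor's aside (not part of the patch): the dayofweek() wrapper above adds lit(1) because Spark
# numbers the week 1 (Sunday) through 7 (Saturday), while DuckDB's dayofweek() numbers it
# 0 (Sunday) through 6 (Saturday). A minimal check against plain DuckDB (output is the expected value):
import duckdb

print(duckdb.sql("SELECT dayofweek(DATE '2015-04-08')").fetchone())
# expected: (3,) -> Wednesday; the wrapper returns 3 + 1 = 4, matching the doctest above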
- Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofyear('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofyear("dt").alias("day")).collect() [Row(day=98)] """ return _invoke_function_over_columns("dayofyear", col) def hour(col: "ColumnOrName") -> Column: - """ - Extract the hours of a given timestamp as integer. + """Extract the hours of a given timestamp as integer. .. versionadded:: 1.5.0 @@ -4080,24 +4020,23 @@ def hour(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` hour part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(hour('ts').alias('hour')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(hour("ts").alias("hour")).collect() [Row(hour=13)] """ return _invoke_function_over_columns("hour", col) def minute(col: "ColumnOrName") -> Column: - """ - Extract the minutes of a given timestamp as integer. + """Extract the minutes of a given timestamp as integer. .. versionadded:: 1.5.0 @@ -4109,24 +4048,23 @@ def minute(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` minutes part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(minute('ts').alias('minute')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(minute("ts").alias("minute")).collect() [Row(minute=8)] """ return _invoke_function_over_columns("minute", col) def second(col: "ColumnOrName") -> Column: - """ - Extract the seconds of a given date as integer. + """Extract the seconds of a given date as integer. .. versionadded:: 1.5.0 @@ -4138,26 +4076,25 @@ def second(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` `seconds` part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(second('ts').alias('second')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(second("ts").alias("second")).collect() [Row(second=15)] """ return _invoke_function_over_columns("second", col) def weekofyear(col: "ColumnOrName") -> Column: - """ - Extract the week number of a given date as integer. + """Extract the week number of a given date as integer. A week is considered to start on a Monday and week 1 is the first week with more than 3 days, - as defined by ISO 8601 + as defined by ISO 8601. .. versionadded:: 1.5.0 @@ -4169,23 +4106,22 @@ def weekofyear(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` `week` of the year for given date as integer. 
- Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(weekofyear(df.dt).alias('week')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(weekofyear(df.dt).alias("week")).collect() [Row(week=15)] - """ + """ # noqa: D205 return _invoke_function_over_columns("weekofyear", col) def cos(col: "ColumnOrName") -> Column: - """ - Computes cosine of the input column. + """Computes cosine of the input column. .. versionadded:: 1.4.0 @@ -4197,12 +4133,12 @@ def cos(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` cosine of the angle, as if computed by `java.lang.Math.cos()`. - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -4213,8 +4149,7 @@ def cos(col: "ColumnOrName") -> Column: def acos(col: "ColumnOrName") -> Column: - """ - Computes inverse cosine of the input column. + """Computes inverse cosine of the input column. .. versionadded:: 1.4.0 @@ -4226,12 +4161,12 @@ def acos(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse cosine of `col`, as if computed by `java.lang.Math.acos()` - Examples + Examples: -------- >>> df = spark.range(1, 3) >>> df.select(acos(df.id)).show() @@ -4246,8 +4181,7 @@ def acos(col: "ColumnOrName") -> Column: def call_function(funcName: str, *cols: "ColumnOrName") -> Column: - """ - Call a SQL function. + r"""Call a SQL function. .. versionadded:: 3.5.0 @@ -4258,16 +4192,16 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or :class:`~pyspark.sql.Column`\\s to be used in the function - Returns + Returns: ------- :class:`~pyspark.sql.Column` result of executed function. - Examples + Examples: -------- >>> from pyspark.sql.functions import call_udf, col >>> from pyspark.sql.types import IntegerType, StringType - >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "c")],["id", "name"]) + >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "name"]) >>> _ = spark.udf.register("intX2", lambda i: i * 2, IntegerType()) >>> df.select(call_function("intX2", "id")).show() +---------+ @@ -4328,19 +4262,19 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate covariance. - Returns + Returns: ------- :class:`~pyspark.sql.Column` covariance of these two column values. - Examples + Examples: -------- >>> a = [1] * 10 >>> b = [1] * 10 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(covar_pop("a", "b").alias('c')).collect() + >>> df.agg(covar_pop("a", "b").alias("c")).collect() [Row(c=0.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("covar_pop", col1, col2) @@ -4360,25 +4294,24 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate covariance. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sample covariance of these two column values. 
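# Editor's aside (not part of the patch): covar_pop() above (and covar_samp() just below) forward
# one-to-one to DuckDB aggregates of the same name. A minimal sketch; outputs are expected values:
import duckdb

print(duckdb.sql(
    "SELECT covar_samp(a, b), covar_pop(a, b) FROM (VALUES (1, 1), (2, 2), (3, 3)) t(a, b)"
).fetchone())
# expected: (1.0, ~0.667) -- sample vs. population covariance of two perfectly correlated columns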
- Examples + Examples: -------- >>> a = [1] * 10 >>> b = [1] * 10 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(covar_samp("a", "b").alias('c')).collect() + >>> df.agg(covar_samp("a", "b").alias("c")).collect() [Row(c=0.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("covar_samp", col1, col2) def exp(col: "ColumnOrName") -> Column: - """ - Computes the exponential of the given value. + """Computes the exponential of the given value. .. versionadded:: 1.4.0 @@ -4390,12 +4323,12 @@ def exp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str column to calculate exponential for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` exponential of the given value. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(exp(lit(0))).show() @@ -4409,8 +4342,7 @@ def exp(col: "ColumnOrName") -> Column: def factorial(col: "ColumnOrName") -> Column: - """ - Computes the factorial of the given value. + """Computes the factorial of the given value. .. versionadded:: 1.5.0 @@ -4422,15 +4354,15 @@ def factorial(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str a column to calculate factorial for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` factorial of given value. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(5,)], ['n']) - >>> df.select(factorial(df.n).alias('f')).collect() + >>> df = spark.createDataFrame([(5,)], ["n"]) + >>> df.select(factorial(df.n).alias("f")).collect() [Row(f=120)] """ return _invoke_function_over_columns("factorial", col) @@ -4449,15 +4381,15 @@ def log2(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str a column to calculate logariphm for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` logariphm of given value. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(log2('a').alias('log2')).show() + >>> df = spark.createDataFrame([(4,)], ["a"]) + >>> df.select(log2("a").alias("log2")).show() +----+ |log2| +----+ @@ -4477,15 +4409,15 @@ def ln(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str a column to calculate logariphm for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` natural logarithm of given value. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(ln('a')).show() + >>> df = spark.createDataFrame([(4,)], ["a"]) + >>> df.select(ln("a")).show() +------------------+ | ln(a)| +------------------+ @@ -4496,8 +4428,7 @@ def ln(col: "ColumnOrName") -> Column: def degrees(col: "ColumnOrName") -> Column: - """ - Converts an angle measured in radians to an approximately equivalent angle + """Converts an angle measured in radians to an approximately equivalent angle measured in degrees. .. 
versionadded:: 2.1.0 @@ -4510,25 +4441,23 @@ def degrees(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` angle in degrees, as if computed by `java.lang.Math.toDegrees()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) >>> df.select(degrees(lit(math.pi))).first() Row(DEGREES(3.14159...)=180.0) - """ + """ # noqa: D205 return _invoke_function_over_columns("degrees", col) - def radians(col: "ColumnOrName") -> Column: - """ - Converts an angle measured in degrees to an approximately equivalent angle + """Converts an angle measured in degrees to an approximately equivalent angle measured in radians. .. versionadded:: 2.1.0 @@ -4541,23 +4470,22 @@ def radians(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in degrees - Returns + Returns: ------- :class:`~pyspark.sql.Column` angle in radians, as if computed by `java.lang.Math.toRadians()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(radians(lit(180))).first() Row(RADIANS(180)=3.14159...) - """ + """ # noqa: D205 return _invoke_function_over_columns("radians", col) def atan(col: "ColumnOrName") -> Column: - """ - Compute inverse tangent of the input column. + """Compute inverse tangent of the input column. .. versionadded:: 1.4.0 @@ -4569,12 +4497,12 @@ def atan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse tangent of `col`, as if computed by `java.lang.Math.atan()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(atan(df.id)).show() @@ -4588,8 +4516,7 @@ def atan(col: "ColumnOrName") -> Column: def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """ - .. versionadded:: 1.4.0 + """.. versionadded:: 1.4.0. .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -4601,7 +4528,7 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] col2 : str, :class:`~pyspark.sql.Column` or float coordinate on x-axis - Returns + Returns: ------- :class:`~pyspark.sql.Column` the `theta` component of the point @@ -4610,22 +4537,23 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] (`x`, `y`) in Cartesian coordinates, as if computed by `java.lang.Math.atan2()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(atan2(lit(1), lit(2))).first() Row(ATAN2(1, 2)=0.46364...) """ + def lit_or_column(x: Union["ColumnOrName", float]) -> Column: if isinstance(x, (int, float)): return lit(x) return x + return _invoke_function_over_columns("atan2", lit_or_column(col1), lit_or_column(col2)) def tan(col: "ColumnOrName") -> Column: - """ - Computes tangent of the input column. + """Computes tangent of the input column. .. 
versionadded:: 1.4.0 @@ -4637,12 +4565,12 @@ def tan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` tangent of the given value, as if computed by `java.lang.Math.tan()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -4653,8 +4581,7 @@ def tan(col: "ColumnOrName") -> Column: def round(col: "ColumnOrName", scale: int = 0) -> Column: - """ - Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 + """Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 or at integral part when `scale` < 0. .. versionadded:: 1.5.0 @@ -4669,22 +4596,21 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: scale : int optional default 0 scale value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` rounded values. - Examples + Examples: -------- - >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ["a"]).select(round("a", 0).alias("r")).collect() [Row(r=3.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("round", col, lit(scale)) def bround(col: "ColumnOrName", scale: int = 0) -> Column: - """ - Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 + """Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 or at integral part when `scale` < 0. .. versionadded:: 2.0.0 @@ -4699,22 +4625,21 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: scale : int optional default 0 scale value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` rounded values. - Examples + Examples: -------- - >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ["a"]).select(bround("a", 0).alias("r")).collect() [Row(r=2.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("round_even", col, lit(scale)) def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: - """ - Collection function: Returns element of array at given (0-based) index. + """Collection function: Returns element of array at given (0-based) index. If the index points outside of the array boundaries, then this function returns NULL. @@ -4727,23 +4652,23 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: index : :class:`~pyspark.sql.Column` or str or int index to check for in array - Returns + Returns: ------- :class:`~pyspark.sql.Column` value at given position. - Notes + Notes: ----- The position is not 1 based, but 0 based index. Supports Spark Connect. - See Also + See Also: -------- :meth:`element_at` - Examples + Examples: -------- - >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) >>> df.select(get(df.data, 1)).show() +------------+ |get(data, 1)| @@ -4778,7 +4703,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: +----------------------+ | a| +----------------------+ - """ + """ # noqa: D205 index = ConstantExpression(index) if isinstance(index, int) else _to_column_expr(index) # Spark uses 0-indexing, DuckDB 1-indexing index = index + 1 @@ -4799,14 +4724,14 @@ def initcap(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. 
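# Editor's aside (not part of the patch): the `index = index + 1` adjustment in get() above exists
# because Spark's get() is 0-based while DuckDB list indexing and list_extract() are 1-based.
# A minimal check against plain DuckDB (outputs are expected values):
import duckdb

print(duckdb.sql("SELECT (['a', 'b', 'c'])[1], list_extract(['a', 'b', 'c'], 2)").fetchone())
# expected: ('a', 'b') -- index 1 is the first element, so Spark's index 0 maps to DuckDB's 1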
- Returns + Returns: ------- :class:`~pyspark.sql.Column` string with all first letters are uppercase in each word. - Examples + Examples: -------- - >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() + >>> spark.createDataFrame([("ab cd",)], ["a"]).select(initcap("a").alias("v")).collect() [Row(v='Ab Cd')] """ return Column( @@ -4814,18 +4739,14 @@ def initcap(col: "ColumnOrName") -> Column: "array_to_string", FunctionExpression( "list_transform", - FunctionExpression( - "string_split", _to_column_expr(col), ConstantExpression(" ") - ), + FunctionExpression("string_split", _to_column_expr(col), ConstantExpression(" ")), LambdaExpression( "x", FunctionExpression( "concat", FunctionExpression( "upper", - FunctionExpression( - "array_extract", ColumnExpression("x"), 1 - ), + FunctionExpression("array_extract", ColumnExpression("x"), 1), ), FunctionExpression("array_slice", ColumnExpression("x"), 2, -1), ), @@ -4837,8 +4758,7 @@ def initcap(col: "ColumnOrName") -> Column: def octet_length(col: "ColumnOrName") -> Column: - """ - Calculates the byte length for the specified string column. + r"""Calculates the byte length for the specified string column. .. versionadded:: 3.3.0 @@ -4850,15 +4770,15 @@ def octet_length(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str Source column or strings - Returns + Returns: ------- :class:`~pyspark.sql.Column` Byte length of the col - Examples + Examples: -------- >>> from pyspark.sql.functions import octet_length - >>> spark.createDataFrame([('cat',), ( '\U0001F408',)], ['cat']) \\ + >>> spark.createDataFrame([('cat',), ( '\U0001f408',)], ['cat']) \\ ... .select(octet_length('cat')).collect() [Row(octet_length(cat)=3), Row(octet_length(cat)=4)] """ @@ -4866,8 +4786,10 @@ def octet_length(col: "ColumnOrName") -> Column: def hex(col: "ColumnOrName") -> Column: - """ - Computes hex value of the given column, which could be :class:`~pyspark.sql.types.StringType`, :class:`~pyspark.sql.types.BinaryType`, :class:`~pyspark.sql.types.IntegerType` or :class:`~pyspark.sql.types.LongType`. + """Computes hex value of the given column. + + The column can be :class:`~pyspark.sql.types.StringType`, :class:`~pyspark.sql.types.BinaryType`, + :class:`~pyspark.sql.types.IntegerType` or :class:`~pyspark.sql.types.LongType`. .. versionadded:: 1.5.0 @@ -4879,22 +4801,24 @@ def hex(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` hexadecimal representation of given value as string. - Examples + Examples: -------- - >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + >>> spark.createDataFrame([("ABC", 3)], ["a", "b"]).select(hex("a"), hex("b")).collect() [Row(hex(a)='414243', hex(b)='3')] """ return _invoke_function_over_columns("hex", col) def unhex(col: "ColumnOrName") -> Column: - """ - Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. column and returns it as a binary column. + """Inverse of hex. + + Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number column + and returns it as a binary column. .. versionadded:: 1.5.0 @@ -4906,22 +4830,21 @@ def unhex(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. 
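# Editor's aside (not part of the patch): hex() above and unhex() just below pass straight through
# to DuckDB functions of the same name. A minimal check (outputs are expected values):
import duckdb

print(duckdb.sql("SELECT hex('ABC'), unhex('414243')").fetchone())
# expected: ('414243', b'ABC') -- unhex returns a BLOB, surfaced as bytes in Python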
- Returns + Returns: ------- :class:`~pyspark.sql.Column` string representation of given hexadecimal value. - Examples + Examples: -------- - >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + >>> spark.createDataFrame([("414243",)], ["a"]).select(unhex("a")).collect() [Row(unhex(a)=bytearray(b'ABC'))] """ return _invoke_function_over_columns("unhex", col) def base64(col: "ColumnOrName") -> Column: - """ - Computes the BASE64 encoding of a binary column and returns it as a string column. + """Computes the BASE64 encoding of a binary column and returns it as a string column. .. versionadded:: 1.5.0 @@ -4933,12 +4856,12 @@ def base64(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` BASE64 encoding of string value. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(base64("value")).show() @@ -4950,14 +4873,13 @@ def base64(col: "ColumnOrName") -> Column: |UGFuZGFzIEFQSQ==| +----------------+ """ - if isinstance(col,str): + if isinstance(col, str): col = Column(ColumnExpression(col)) return _invoke_function_over_columns("base64", col.cast("BLOB")) def unbase64(col: "ColumnOrName") -> Column: - """ - Decodes a BASE64 encoded string column and returns it as a binary column. + """Decodes a BASE64 encoded string column and returns it as a binary column. .. versionadded:: 1.5.0 @@ -4969,16 +4891,14 @@ def unbase64(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` encoded string value. - Examples + Examples: -------- - >>> df = spark.createDataFrame(["U3Bhcms=", - ... "UHlTcGFyaw==", - ... "UGFuZGFzIEFQSQ=="], "STRING") + >>> df = spark.createDataFrame(["U3Bhcms=", "UHlTcGFyaw==", "UGFuZGFzIEFQSQ=="], "STRING") >>> df.select(unbase64("value")).show() +--------------------+ | unbase64(value)| @@ -4992,8 +4912,7 @@ def unbase64(col: "ColumnOrName") -> Column: def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column: - """ - Returns the date that is `months` months after `start`. If `months` is a negative value + """Returns the date that is `months` months after `start`. If `months` is a negative value then these amount of months will be deducted from the `start`. .. versionadded:: 1.5.0 @@ -5009,30 +4928,27 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col how many months after the given date to calculate. Accepts negative value as well to calculate backwards. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a date after/before given number of months. 
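# Editor's aside (not part of the patch): base64() above casts its input to BLOB because DuckDB's
# base64() is defined on blobs rather than on VARCHAR. A minimal check (output is the expected value):
import duckdb

print(duckdb.sql("SELECT base64('Spark'::BLOB)").fetchone())
# expected: ('U3Bhcms=',) -- the base64 encoding of 'Spark'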
- Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08', 2)], ['dt', 'add']) - >>> df.select(add_months(df.dt, 1).alias('next_month')).collect() + >>> df = spark.createDataFrame([("2015-04-08", 2)], ["dt", "add"]) + >>> df.select(add_months(df.dt, 1).alias("next_month")).collect() [Row(next_month=datetime.date(2015, 5, 8))] - >>> df.select(add_months(df.dt, df.add.cast('integer')).alias('next_month')).collect() + >>> df.select(add_months(df.dt, df.add.cast("integer")).alias("next_month")).collect() [Row(next_month=datetime.date(2015, 6, 8))] - >>> df.select(add_months('dt', -2).alias('prev_month')).collect() + >>> df.select(add_months("dt", -2).alias("prev_month")).collect() [Row(prev_month=datetime.date(2015, 2, 8))] - """ + """ # noqa: D205 months = ConstantExpression(months) if isinstance(months, int) else _to_column_expr(months) return _invoke_function("date_add", _to_column_expr(start), FunctionExpression("to_months", months)).cast("date") -def array_join( - col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None -) -> Column: - """ - Concatenates the elements of `column` using the `delimiter`. Null values are replaced with +def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None) -> Column: + """Concatenates the elements of `column` using the `delimiter`. Null values are replaced with `null_replacement` if set, otherwise they are ignored. .. versionadded:: 2.4.0 @@ -5049,30 +4965,36 @@ def array_join( null_replacement : str, optional if set then null values will be replaced by this value - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of string type. Concatenated values. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) >>> df.select(array_join(df.data, ",").alias("joined")).collect() [Row(joined='a,b,c'), Row(joined='a')] >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() [Row(joined='a,b,c'), Row(joined='a,NULL')] - """ + """ # noqa: D205 col = _to_column_expr(col) if null_replacement is not None: col = FunctionExpression( - "list_transform", col, LambdaExpression("x", CaseExpression(ColumnExpression("x").isnull(), ConstantExpression(null_replacement)).otherwise(ColumnExpression("x"))) + "list_transform", + col, + LambdaExpression( + "x", + CaseExpression(ColumnExpression("x").isnull(), ConstantExpression(null_replacement)).otherwise( + ColumnExpression("x") + ), + ), ) return _invoke_function("array_to_string", col, ConstantExpression(delimiter)) -def array_position(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: Locates the position of the first occurrence of the given value +def array_position(col: "ColumnOrName", value: Any) -> Column: # noqa: ANN401 + """Collection function: Locates the position of the first occurrence of the given value in the given array. Returns null if either of the arguments are null. .. versionadded:: 2.4.0 @@ -5080,7 +5002,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The position is not zero based, but 1 based index. Returns 0 if the given value could not be found in the array. @@ -5092,23 +5014,26 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: value : Any value to look for. 
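# Editor's aside (not part of the patch): the add_months() wrapper above builds
# date_add(start, to_months(n)) and casts the result back to DATE, as shown in its return statement.
# A minimal check against plain DuckDB (output is the expected value):
import duckdb

print(duckdb.sql("SELECT CAST(date_add(DATE '2015-04-08', to_months(1)) AS DATE)").fetchone())
# expected: (datetime.date(2015, 5, 8),), matching the first add_months doctest above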
- Returns + Returns: ------- :class:`~pyspark.sql.Column` position of the value in the given array if found and 0 otherwise. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) >>> df.select(array_position(df.data, "a")).collect() [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] - """ - return Column(CoalesceOperator(_to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0))) + """ # noqa: D205 + return Column( + CoalesceOperator( + _to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0) + ) + ) -def array_prepend(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: Returns an array containing element as +def array_prepend(col: "ColumnOrName", value: Any) -> Column: # noqa: ANN401 + """Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. @@ -5121,23 +5046,22 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array excluding given value. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ # noqa: D205 return _invoke_function_over_columns("list_prepend", lit(value), col) def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: - """ - Collection function: creates an array containing a column repeated count times. + """Collection function: creates an array containing a column repeated count times. .. versionadded:: 2.4.0 @@ -5151,15 +5075,15 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu count : :class:`~pyspark.sql.Column` or str or int column name, column, or int containing the number of times to repeat the first argument - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of repeated elements. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('ab',)], ['data']) - >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + >>> df = spark.createDataFrame([("ab",)], ["data"]) + >>> df.select(array_repeat(df.data, 3).alias("r")).collect() [Row(r=['ab', 'ab', 'ab'])] """ count = lit(count) if isinstance(count, int) else count @@ -5168,8 +5092,7 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu def array_size(col: "ColumnOrName") -> Column: - """ - Returns the total number of elements in the array. The function returns null for null input. + """Returns the total number of elements in the array. The function returns null for null input. .. versionadded:: 3.5.0 @@ -5178,24 +5101,22 @@ def array_size(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` total number of elements in the array. 
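# Editor's aside (not part of the patch): DuckDB's list_position() is 1-based and, as the coalesce
# in the array_position() wrapper above implies, can yield NULL when the element is absent; the
# wrapper maps that to Spark's 0. A minimal check (output is the expected value):
import duckdb

print(duckdb.sql(
    "SELECT coalesce(list_position(['c', 'b', 'a'], 'a'), 0), coalesce(list_position(['c', 'b', 'a'], 'z'), 0)"
).fetchone())
# expected: (3, 0), consistent with the array_position doctest above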
- Examples + Examples: -------- - >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) - >>> df.select(array_size(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) + >>> df.select(array_size(df.data).alias("r")).collect() [Row(r=3), Row(r=None)] """ return _invoke_function_over_columns("len", col) -def array_sort( - col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None -) -> Column: - """ - Collection function: sorts the input array in ascending order. The elements of the input array + +def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None) -> Column: + """Collection function: sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. .. versionadded:: 2.4.0 @@ -5217,32 +5138,38 @@ def array_sort( positive integer as the first element is less than, equal to, or greater than the second element. If the comparator function returns null, the function will fail and raise an error. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sorted array. - Examples + Examples: -------- - >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) - >>> df.select(array_sort(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) + >>> df.select(array_sort(df.data).alias("r")).collect() [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] - >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data']) - >>> df.select(array_sort( - ... "data", - ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x)) - ... ).alias("r")).collect() + >>> df = spark.createDataFrame( + ... [(["foo", "foobar", None, "bar"],), (["foo"],), ([],)], ["data"] + ... ) + >>> df.select( + ... array_sort( + ... "data", + ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise( + ... length(y) - length(x) + ... ), + ... ).alias("r") + ... ).collect() [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] - """ + """ # noqa: D205 if comparator is not None: - raise ContributionsAcceptedError("comparator is not yet supported") + msg = "comparator is not yet supported" + raise ContributionsAcceptedError(msg) else: return _invoke_function_over_columns("list_sort", col, lit("ASC"), lit("NULLS LAST")) def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: - """ - Collection function: sorts the input array in ascending or descending order according + """Collection function: sorts the input array in ascending or descending order according to the natural ordering of the array elements. Null elements will be placed at the beginning of the returned array in ascending order or at the end of the returned array in descending order. @@ -5260,19 +5187,19 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: whether to sort in ascending or descending order. If `asc` is True (default) then ascending and if False then descending. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sorted array. 
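# Editor's aside (not part of the patch): array_sort() above maps onto DuckDB's list_sort() with an
# explicit 'ASC' / 'NULLS LAST' ordering, which is what produces Spark's "nulls at the end" result.
# A minimal check (output is the expected value):
import duckdb

print(duckdb.sql("SELECT list_sort([2, 1, NULL, 3], 'ASC', 'NULLS LAST')").fetchone())
# expected: ([1, 2, 3, None],), matching the first array_sort doctest above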
- Examples + Examples: -------- - >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) - >>> df.select(sort_array(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) + >>> df.select(sort_array(df.data).alias("r")).collect() [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] - >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() + >>> df.select(sort_array(df.data, asc=False).alias("r")).collect() [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] - """ + """ # noqa: D205 if asc: order = "ASC" null_order = "NULLS FIRST" @@ -5283,8 +5210,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: - """ - Splits str around matches of the given pattern. + """Splits str around matches of the given pattern. .. versionadded:: 1.5.0 @@ -5310,29 +5236,34 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: .. versionchanged:: 3.0 `split` now takes an optional `limit` field. If not provided, default limit value is -1. - Returns + Returns: ------- :class:`~pyspark.sql.Column` array of separated strings. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',]) - >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("oneAtwoBthreeC",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(split(df.s, "[ABC]", 2).alias("s")).collect() [Row(s=['one', 'twoBthreeC'])] - >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect() + >>> df.select(split(df.s, "[ABC]", -1).alias("s")).collect() [Row(s=['one', 'two', 'three', ''])] """ if limit > 0: # Unclear how to implement this in DuckDB as we'd need to map back from the split array # to the original array which is tricky with regular expressions. - raise ContributionsAcceptedError("limit is not yet supported") + msg = "limit is not yet supported" + raise ContributionsAcceptedError(msg) return _invoke_function_over_columns("regexp_split_to_array", str, lit(pattern)) def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnOrName") -> Column: - """ - Splits `str` by delimiter and return requested part of the split (1-based). + """Splits `str` by delimiter and return requested part of the split (1-based). If any input is null, returns null. if `partNum` is out of range of split parts, returns empty string. If `partNum` is 0, throws an error. If `partNum` is negative, the parts are counted backward from the end of the string. @@ -5349,23 +5280,35 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO partNum : :class:`~pyspark.sql.Column` or str A column of string, requested part of the split (1-based). - Examples + Examples: -------- - >>> df = spark.createDataFrame([("11.12.13", ".", 3,)], ["a", "b", "c"]) - >>> df.select(split_part(df.a, df.b, df.c).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "11.12.13", + ... ".", + ... 3, + ... ) + ... ], + ... ["a", "b", "c"], + ... 
) + >>> df.select(split_part(df.a, df.b, df.c).alias("r")).collect() [Row(r='13')] - """ + """ # noqa: D205 src = _to_column_expr(src) delimiter = _to_column_expr(delimiter) partNum = _to_column_expr(partNum) part = FunctionExpression("split_part", src, delimiter, partNum) - return Column(CaseExpression(src.isnull() | delimiter.isnull() | partNum.isnull(), ConstantExpression(None)).otherwise(CaseExpression(delimiter == ConstantExpression(""), ConstantExpression("")).otherwise(part))) + return Column( + CaseExpression(src.isnull() | delimiter.isnull() | partNum.isnull(), ConstantExpression(None)).otherwise( + CaseExpression(delimiter == ConstantExpression(""), ConstantExpression("")).otherwise(part) + ) + ) def stddev_samp(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the unbiased sample standard deviation of + """Aggregate function: returns the unbiased sample standard deviation of the expression in a group. .. versionadded:: 1.6.0 @@ -5378,12 +5321,12 @@ def stddev_samp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev_samp("id")).show() @@ -5392,13 +5335,12 @@ def stddev_samp(col: "ColumnOrName") -> Column: +------------------+ |1.8708286933869...| +------------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("stddev_samp", col) def stddev(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for stddev_samp. + """Aggregate function: alias for stddev_samp. .. versionadded:: 1.6.0 @@ -5410,12 +5352,12 @@ def stddev(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev("id")).show() @@ -5427,9 +5369,9 @@ def stddev(col: "ColumnOrName") -> Column: """ return stddev_samp(col) + def std(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for stddev_samp. + """Aggregate function: alias for stddev_samp. .. versionadded:: 3.5.0 @@ -5438,12 +5380,12 @@ def std(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.std("id")).show() @@ -5457,8 +5399,7 @@ def std(col: "ColumnOrName") -> Column: def stddev_pop(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns population standard deviation of + """Aggregate function: returns population standard deviation of the expression in a group. .. versionadded:: 1.6.0 @@ -5471,12 +5412,12 @@ def stddev_pop(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. 
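# Editor's aside (not part of the patch): DuckDB also ships a split_part(string, delimiter, n)
# function; the split_part() wrapper above only adds CASE expressions so that NULL inputs propagate
# and an empty delimiter yields '' as Spark requires. A minimal check (output is the expected value):
import duckdb

print(duckdb.sql("SELECT split_part('11.12.13', '.', 3)").fetchone())
# expected: ('13',), matching the split_part doctest above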
- Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev_pop("id")).show() @@ -5485,13 +5426,12 @@ def stddev_pop(col: "ColumnOrName") -> Column: +-----------------+ |1.707825127659...| +-----------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("stddev_pop", col) def var_pop(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the population variance of the values in a group. + """Aggregate function: returns the population variance of the values in a group. .. versionadded:: 1.6.0 @@ -5503,12 +5443,12 @@ def var_pop(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(var_pop(df.id)).first() @@ -5518,8 +5458,7 @@ def var_pop(col: "ColumnOrName") -> Column: def var_samp(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the unbiased sample variance of + """Aggregate function: returns the unbiased sample variance of the values in a group. .. versionadded:: 1.6.0 @@ -5532,12 +5471,12 @@ def var_samp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(var_samp(df.id)).show() @@ -5546,13 +5485,12 @@ def var_samp(col: "ColumnOrName") -> Column: +------------+ | 3.5| +------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("var_samp", col) def variance(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for var_samp + """Aggregate function: alias for var_samp. .. versionadded:: 1.6.0 @@ -5564,12 +5502,12 @@ def variance(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(variance(df.id)).show() @@ -5583,8 +5521,7 @@ def variance(col: "ColumnOrName") -> Column: def weekday(col: "ColumnOrName") -> Column: - """ - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). + """Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). .. versionadded:: 3.5.0 @@ -5593,15 +5530,15 @@ def weekday(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). - Examples + Examples: -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(weekday('dt').alias('day')).show() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(weekday("dt").alias("day")).show() +---+ |day| +---+ @@ -5612,8 +5549,7 @@ def weekday(col: "ColumnOrName") -> Column: def zeroifnull(col: "ColumnOrName") -> Column: - """ - Returns zero if `col` is null, or `col` otherwise. + """Returns zero if `col` is null, or `col` otherwise. .. 
versionadded:: 4.0.0 @@ -5621,7 +5557,7 @@ def zeroifnull(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([(None,), (1,)], ["a"]) >>> df.select(zeroifnull(df.a).alias("result")).show() @@ -5634,6 +5570,7 @@ def zeroifnull(col: "ColumnOrName") -> Column: """ return coalesce(col, lit(0)) + def _to_date_or_timestamp(col: "ColumnOrName", spark_datatype: _types.DataType, format: Optional[str] = None) -> Column: if format is not None: raise ContributionsAcceptedError( @@ -5663,21 +5600,21 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: format: str, optional format to use to convert date values. - Returns + Returns: ------- :class:`~pyspark.sql.Column` date value as :class:`pyspark.sql.types.DateType` type. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_date(df.t).alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_date(df.t).alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_date(df.t, "yyyy-MM-dd HH:mm:ss").alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] - """ + """ # noqa: D205 return _to_date_or_timestamp(col, _types.DateType(), format) @@ -5701,21 +5638,21 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: format: str, optional format to use to convert timestamp values. - Returns + Returns: ------- :class:`~pyspark.sql.Column` timestamp value as :class:`pyspark.sql.types.TimestampType` type. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_timestamp(df.t).alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_timestamp(df.t).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_timestamp(df.t, "yyyy-MM-dd HH:mm:ss").alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ + """ # noqa: D205 return _to_date_or_timestamp(col, _types.TimestampNTZType(), format) @@ -5723,8 +5660,7 @@ def to_timestamp_ltz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, ) -> Column: - """ - Parses the `timestamp` with the `format` to a timestamp without time zone. + """Parses the `timestamp` with the `format` to a timestamp without time zone. Returns null with invalid input. .. versionadded:: 3.5.0 @@ -5736,18 +5672,18 @@ def to_timestamp_ltz( format : :class:`~pyspark.sql.Column` or str, optional format to use to convert type `TimestampType` timestamp values. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) - >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias('r')).collect() + >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() ... 
# doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) - >>> df.select(to_timestamp_ltz(df.e).alias('r')).collect() + >>> df.select(to_timestamp_ltz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] - """ + """ # noqa: D205 return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format) @@ -5755,8 +5691,7 @@ def to_timestamp_ntz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, ) -> Column: - """ - Parses the `timestamp` with the `format` to a timestamp without time zone. + """Parses the `timestamp` with the `format` to a timestamp without time zone. Returns null with invalid input. .. versionadded:: 3.5.0 @@ -5768,24 +5703,23 @@ def to_timestamp_ntz( format : :class:`~pyspark.sql.Column` or str, optional format to use to convert type `TimestampNTZType` timestamp values. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias('r')).collect() + >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_timestamp_ntz(df.e).alias('r')).collect() + >>> df.select(to_timestamp_ntz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] - """ + """ # noqa: D205 return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format) def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column: - """ - Parses the `col` with the `format` to a timestamp. The function always + """Parses the `col` with the `format` to a timestamp. The function always returns null on an invalid input with/without ANSI SQL mode enabled. The result data type is consistent with the value of configuration `spark.sql.timestampType`. .. versionadded:: 3.5.0 @@ -5795,24 +5729,23 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non column values to convert. format: str, optional format to use to convert timestamp values. - Examples + + Examples: -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(try_to_timestamp(df.t).alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(try_to_timestamp(df.t).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - >>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect() + >>> df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ + """ # noqa: D205, D415 if format is None: - format = lit(['%Y-%m-%d', '%Y-%m-%d %H:%M:%S']) + format = lit(["%Y-%m-%d", "%Y-%m-%d %H:%M:%S"]) return _invoke_function_over_columns("try_strptime", col, format) -def substr( - str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None -) -> Column: - """ - Returns the substring of `str` that starts at `pos` and is of length `len`, + +def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None) -> Column: + """Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of byte array that starts at `pos` and is of length `len`. .. 
versionadded:: 3.5.0 @@ -5826,11 +5759,18 @@ def substr( len : :class:`~pyspark.sql.Column` or str, optional A column of string, the substring of `str` is of length `len`. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... [ + ... ( + ... "Spark SQL", + ... 5, + ... 1, + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.substr("a", "b", "c")).show() +---------------+ |substr(a, b, c)| @@ -5840,14 +5780,21 @@ def substr( >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... [ + ... ( + ... "Spark SQL", + ... 5, + ... 1, + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.substr("a", "b")).show() +------------------------+ |substr(a, b, 2147483647)| +------------------------+ | k SQL| +------------------------+ - """ + """ # noqa: D205 if len is not None: return _invoke_function_over_columns("substring", str, pos, len) else: @@ -5855,18 +5802,21 @@ def substr( def _unix_diff(col: "ColumnOrName", part: str) -> Column: - return _invoke_function_over_columns("date_diff", lit(part), lit("1970-01-01 00:00:00+00:00").cast("timestamp"), col) + return _invoke_function_over_columns( + "date_diff", lit(part), lit("1970-01-01 00:00:00+00:00").cast("timestamp"), col + ) + def unix_date(col: "ColumnOrName") -> Column: """Returns the number of days since 1970-01-01. .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('1970-01-02',)], ['t']) - >>> df.select(unix_date(to_date(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("1970-01-02",)], ["t"]) + >>> df.select(unix_date(to_date(df.t)).alias("n")).collect() [Row(n=1)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5878,11 +5828,11 @@ def unix_micros(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_micros(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_micros(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000000)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5895,14 +5845,14 @@ def unix_millis(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_millis(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_millis(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000)] >>> spark.conf.unset("spark.sql.session.timeZone") - """ + """ # noqa: D205 return _unix_diff(col, "milliseconds") @@ -5912,20 +5862,19 @@ def unix_seconds(col: "ColumnOrName") -> Column: .. 
versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_seconds(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_seconds(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400)] >>> spark.conf.unset("spark.sql.session.timeZone") - """ + """ # noqa: D205 return _unix_diff(col, "seconds") def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: - """ - Collection function: returns true if the arrays contain any common non-null element; if not, + """Collection function: returns true if the arrays contain any common non-null element; if not, returns null if both the arrays are non-empty and any of them contains a null element; returns false otherwise. @@ -5934,17 +5883,17 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of Boolean type. - Examples + Examples: -------- - >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ["x", "y"]) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() [Row(overlap=True), Row(overlap=False)] - """ + """ # noqa: D205 a1 = _to_column_expr(a1) a2 = _to_column_expr(a2) @@ -5952,28 +5901,25 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: a2_has_null = _list_contains_null(a2) return Column( - CaseExpression( - FunctionExpression("list_has_any", a1, a2), ConstantExpression(True) - ).otherwise( + CaseExpression(FunctionExpression("list_has_any", a1, a2), ConstantExpression(True)).otherwise( CaseExpression( - (FunctionExpression("len", a1) > 0) & (FunctionExpression("len", a2) > 0) & (a1_has_null | a2_has_null), ConstantExpression(None) - ).otherwise(ConstantExpression(False))) + (FunctionExpression("len", a1) > 0) & (FunctionExpression("len", a2) > 0) & (a1_has_null | a2_has_null), + ConstantExpression(None), + ).otherwise(ConstantExpression(False)) + ) ) def _list_contains_null(c: ColumnExpression) -> Expression: return FunctionExpression( "list_contains", - FunctionExpression( - "list_transform", c, LambdaExpression("x", ColumnExpression("x").isnull()) - ), + FunctionExpression("list_transform", c, LambdaExpression("x", ColumnExpression("x").isnull())), True, ) def arrays_zip(*cols: "ColumnOrName") -> Column: - """ - Collection function: Returns a merged array of structs in which the N-th struct contains all + """Collection function: Returns a merged array of structs in which the N-th struct contains all N-th values of input arrays. If one of the arrays is shorter than others then resulting struct type value will be a `null` for missing elements. @@ -5987,16 +5933,18 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str columns of arrays to be merged. - Returns + Returns: ------- :class:`~pyspark.sql.Column` merged array of entries. - Examples + Examples: -------- >>> from pyspark.sql.functions import arrays_zip - >>> df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3']) - >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')) + >>> df = spark.createDataFrame( + ... 
[([1, 2, 3], [2, 4, 6], [3, 6])], ["vals1", "vals2", "vals3"] + ... ) + >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias("zipped")) >>> df.show(truncate=False) +------------------------------------+ |zipped | @@ -6010,19 +5958,19 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: | | |-- vals1: long (nullable = true) | | |-- vals2: long (nullable = true) | | |-- vals3: long (nullable = true) - """ + """ # noqa: D205 return _invoke_function_over_columns("list_zip", *cols) def substring(str: "ColumnOrName", pos: int, len: int) -> Column: - """ - Substring starts at `pos` and is of length `len` when str is String type or + """Substring starts at `pos` and is of length `len` when str is String type or returns the slice of byte array that starts at `pos` in byte and is of length `len` when str is Binary type. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + + Notes: ----- The position is not zero based, but 1 based index. Parameters @@ -6033,16 +5981,23 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: starting position in str. len : int length of chars. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` substring of given value. - Examples + + Examples: -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(substring(df.s, 1, 2).alias("s")).collect() [Row(s='ab')] - """ + """ # noqa: D205 return _invoke_function( "substring", _to_column_expr(str), @@ -6052,8 +6007,7 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if right is found inside left. + """Returns a boolean. The value is True if right is found inside left. Returns NULL if either input expression is NULL. Otherwise, returns False. Both left or right must be of STRING or BINARY type. .. versionadded:: 3.5.0 @@ -6063,12 +6017,21 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: The input column or strings to check, may be NULL. right : :class:`~pyspark.sql.Column` or str The input column or strings to find, may be NULL. - Examples + + Examples: -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ['a', 'b']) - >>> df.select(contains(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ["a", "b"]) + >>> df.select(contains(df.a, df.b).alias("r")).collect() [Row(r=True)] - >>> df = spark.createDataFrame([("414243", "4243",)], ["c", "d"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4243", + ... ) + ... ], + ... ["c", "d"], + ... ) >>> df = df.select(to_binary("c").alias("c"), to_binary("d").alias("d")) >>> df.printSchema() root @@ -6080,13 +6043,12 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ + """ # noqa: D205, D415 return _invoke_function_over_columns("contains", left, right) def reverse(col: "ColumnOrName") -> Column: - """ - Collection function: returns a reversed string or an array with reverse order of elements. + """Collection function: returns a reversed string or an array with reverse order of elements. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 Supports Spark Connect. 
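# Illustrative sketch only, not part of the patch: the arrays_overlap hunk above
# assembles a CASE expression from list_has_any, len, list_transform and
# list_contains to reproduce Spark's three-valued overlap semantics (TRUE on a
# shared element, NULL when both inputs are non-empty and either contains a
# NULL, FALSE otherwise). Written directly against the plain duckdb SQL API it
# looks roughly like this; the sample rows and the alias t are invented for the
# example.
import duckdb

overlap_sql = """
    SELECT CASE
             WHEN list_has_any(a, b) THEN TRUE
             WHEN len(a) > 0 AND len(b) > 0
                  AND (list_contains(list_transform(a, x -> x IS NULL), TRUE)
                       OR list_contains(list_transform(b, x -> x IS NULL), TRUE)) THEN NULL
             ELSE FALSE
           END AS overlap
    FROM (VALUES (['a', 'b'], ['b', 'c']),
                 (['a'], ['b', 'c']),
                 (['a', NULL], ['b', 'c'])) t(a, b)
"""
print(duckdb.sql(overlap_sql).fetchall())  # expected roughly: [(True,), (False,), (None,)]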
@@ -6094,24 +6056,26 @@ def reverse(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + + Returns: ------- :class:`~pyspark.sql.Column` array of elements in reverse order. - Examples + + Examples: -------- - >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) - >>> df.select(reverse(df.data).alias('s')).collect() + >>> df = spark.createDataFrame([("Spark SQL",)], ["data"]) + >>> df.select(reverse(df.data).alias("s")).collect() [Row(s='LQS krapS')] - >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) - >>> df.select(reverse(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ["data"]) + >>> df.select(reverse(df.data).alias("r")).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] - """ + """ # noqa: D205, D415 return _invoke_function("reverse", _to_column_expr(col)) + def concat(*cols: "ColumnOrName") -> Column: - """ - Concatenates multiple input columns together into a single column. + """Concatenates multiple input columns together into a single column. The function works with strings, numeric, binary and compatible array columns. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 @@ -6120,34 +6084,38 @@ def concat(*cols: "ColumnOrName") -> Column: ---------- cols : :class:`~pyspark.sql.Column` or str target column or columns to work on. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` concatenated values. Type of the `Column` depends on input columns' type. - See Also + + See Also: -------- :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter - Examples + + Examples: -------- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df = df.select(concat(df.s, df.d).alias('s')) + >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) + >>> df = df.select(concat(df.s, df.d).alias("s")) >>> df.collect() [Row(s='abcd123')] >>> df DataFrame[s: string] - >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df = spark.createDataFrame( + ... [([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ["a", "b", "c"] + ... ) >>> df = df.select(concat(df.a, df.b, df.c).alias("arr")) >>> df.collect() [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] >>> df DataFrame[arr: array] - """ + """ # noqa: D205, D415 return _invoke_function_over_columns("concat", *cols) def instr(str: "ColumnOrName", substr: str) -> Column: - """ - Locate the position of the first occurrence of substr column in the given string. + """Locate the position of the first occurrence of substr column in the given string. Returns null if either of the arguments are null. .. versionadded:: 1.5.0 @@ -6155,7 +6123,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The position is not zero based, but 1 based index. Returns 0 if substr could not be found in str. @@ -6167,21 +6135,27 @@ def instr(str: "ColumnOrName", substr: str) -> Column: substr : str substring to look for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` location of the first occurrence of the substring as integer. - Examples + Examples: -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(instr(df.s, 'b').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... 
) + >>> df.select(instr(df.s, "b").alias("s")).collect() [Row(s=2)] - """ + """ # noqa: D205 return _invoke_function("instr", _to_column_expr(str), ConstantExpression(substr)) + def expr(str: str) -> Column: - """Parses the expression string into the column that it represents + """Parses the expression string into the column that it represents. .. versionadded:: 1.5.0 @@ -6193,12 +6167,12 @@ def expr(str: str) -> Column: str : str expression defined in string. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column representing the expression. - Examples + Examples: -------- >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) >>> df.select("name", expr("length(name)")).show() @@ -6211,11 +6185,11 @@ def expr(str: str) -> Column: """ return Column(SQLExpression(str)) + def broadcast(df: "DataFrame") -> "DataFrame": - """ - The broadcast function in Spark is used to optimize joins by broadcasting a smaller + """The broadcast function in Spark is used to optimize joins by broadcasting a smaller dataset to all the worker nodes. However, DuckDB operates on a single-node architecture . As a result, the function simply returns the input DataFrame without applying any modifications or optimizations, since broadcasting is not applicable in the DuckDB context. - """ + """ # noqa: D205 return df diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index e6e99beb..aa3e56d6 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -15,26 +15,27 @@ # limitations under the License. 
# -from ..exception import ContributionsAcceptedError -from typing import Callable, TYPE_CHECKING, overload, Dict, Union, List +from typing import TYPE_CHECKING, Callable, Union, overload +from ..exception import ContributionsAcceptedError from .column import Column -from .session import SparkSession from .dataframe import DataFrame from .functions import _to_column_expr -from ._typing import ColumnOrName from .types import NumericType +# Only import symbols needed for type checking if something is type checking if TYPE_CHECKING: - from ._typing import LiteralType + from ._typing import ColumnOrName + from .session import SparkSession __all__ = ["GroupedData", "Grouping"] + def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: expressions = ",".join(list(cols)) group_by = str(self._grouping) if self._grouping else "" projections = self._grouping.get_columns() - jdf = getattr(self._df.relation, "apply")( + jdf = self._df.relation.apply( function_name=name, # aggregate function function_aggr=expressions, # inputs to aggregate group_expr=group_by, # groups @@ -42,6 +43,7 @@ def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: ) return DataFrame(jdf, self.session) + def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]: def _api(self: "GroupedData", *cols: str) -> DataFrame: name = f.__name__ @@ -52,49 +54,49 @@ def _api(self: "GroupedData", *cols: str) -> DataFrame: return _api -class Grouping: - def __init__(self, *cols: "ColumnOrName", **kwargs): +class Grouping: # noqa: D101 + def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: # noqa: D107 self._type = "" self._cols = [_to_column_expr(x) for x in cols] - if 'special' in kwargs: - special = kwargs['special'] + if "special" in kwargs: + special = kwargs["special"] accepted_special = ["cube", "rollup"] assert special in accepted_special self._type = special - def get_columns(self) -> str: + def get_columns(self) -> str: # noqa: D102 columns = ",".join([str(x) for x in self._cols]) return columns - def __str__(self): + def __str__(self) -> str: # noqa: D105 columns = self.get_columns() if self._type: - return self._type + '(' + columns + ')' + return self._type + "(" + columns + ")" return columns class GroupedData: - """ - A set of methods for aggregations on a :class:`DataFrame`, + """A set of methods for aggregations on a :class:`DataFrame`, created by :func:`DataFrame.groupBy`. - """ + """ # noqa: D205 - def __init__(self, grouping: Grouping, df: DataFrame): + def __init__(self, grouping: Grouping, df: DataFrame) -> None: # noqa: D107 self._grouping = grouping self._df = df self.session: SparkSession = df.session - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return str(self._df) def count(self) -> DataFrame: """Counts the number of records for each group. - Examples + Examples: -------- >>> df = spark.createDataFrame( - ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] + ... ) >>> df.show() +---+-----+ |age| name| @@ -115,7 +117,7 @@ def count(self) -> DataFrame: | Bob| 2| +-----+-----+ """ - return _api_internal(self, "count").withColumnRenamed('count_star()', 'count') + return _api_internal(self, "count").withColumnRenamed("count_star()", "count") @df_varargs_api def mean(self, *cols: str) -> DataFrame: @@ -139,11 +141,12 @@ def avg(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. 
- Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -156,7 +159,7 @@ def avg(self, *cols: str) -> DataFrame: Group-by name, and calculate the mean of the age in each group. - >>> df.groupBy("name").avg('age').sort("name").show() + >>> df.groupBy("name").avg("age").sort("name").show() +-----+--------+ | name|avg(age)| +-----+--------+ @@ -166,7 +169,7 @@ def avg(self, *cols: str) -> DataFrame: Calculate the mean of the age and height in all data. - >>> df.groupBy().avg('age', 'height').show() + >>> df.groupBy().avg("age", "height").show() +--------+-----------+ |avg(age)|avg(height)| +--------+-----------+ @@ -177,18 +180,19 @@ def avg(self, *cols: str) -> DataFrame: if len(columns) == 0: schema = self._df.schema # Take only the numeric types of the relation - columns: List[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)] + columns: list[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)] return _api_internal(self, "avg", *columns) @df_varargs_api def max(self, *cols: str) -> DataFrame: """Computes the max value for each numeric columns for each group. - Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -228,11 +232,12 @@ def min(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -272,11 +277,12 @@ def sum(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -308,14 +314,12 @@ def sum(self, *cols: str) -> DataFrame: """ @overload - def agg(self, *exprs: Column) -> DataFrame: - ... + def agg(self, *exprs: Column) -> DataFrame: ... @overload - def agg(self, __exprs: Dict[str, str]) -> DataFrame: - ... + def agg(self, __exprs: dict[str, str]) -> DataFrame: ... # noqa: PYI063 - def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: + def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: """Compute aggregates and returns the result as a :class:`DataFrame`. 
The available aggregate functions can be: @@ -347,17 +351,18 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. - Notes + Notes: ----- Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed in a single call to this function. - Examples + Examples: -------- >>> from pyspark.sql import functions as F >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( - ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] + ... ) >>> df.show() +---+-----+ |age| name| @@ -393,10 +398,9 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: Same as above but uses pandas UDF. - >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP + >>> @pandas_udf("int", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() - ... >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show() # doctest: +SKIP +-----+------------+ | name|min_udf(age)| @@ -417,4 +421,4 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: rel = self._df.relation.select(*expressions, groups=group_by) return DataFrame(rel, self.session) - # TODO: add 'pivot' + # TODO: add 'pivot' # noqa: TD002, TD003 diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 990201cf..b3d08561 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -1,11 +1,9 @@ -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100 +from ..errors import PySparkNotImplementedError, PySparkTypeError from ..exception import ContributionsAcceptedError from .types import StructType - -from ..errors import PySparkNotImplementedError, PySparkTypeError - PrimitiveType = Union[bool, float, int, str] OptionalPrimitiveType = Optional[PrimitiveType] @@ -14,19 +12,19 @@ from duckdb.experimental.spark.sql.session import SparkSession -class DataFrameWriter: - def __init__(self, dataframe: "DataFrame"): +class DataFrameWriter: # noqa: D101 + def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107 self.dataframe = dataframe - def saveAsTable(self, table_name: str) -> None: + def saveAsTable(self, table_name: str) -> None: # noqa: D102 relation = self.dataframe.relation relation.create(table_name) - def parquet( + def parquet( # noqa: D102 self, path: str, mode: Optional[str] = None, - partitionBy: Union[str, List[str], None] = None, + partitionBy: Union[str, list[str], None] = None, compression: Optional[str] = None, ) -> None: relation = self.dataframe.relation @@ -37,7 +35,7 @@ def parquet( relation.write_parquet(path, compression=compression) - def csv( + def csv( # noqa: D102 self, path: str, mode: Optional[str] = None, @@ -57,7 +55,7 @@ def csv( encoding: Optional[str] = None, emptyValue: Optional[str] = None, lineSep: Optional[str] = None, - ): + ) -> None: if mode not in (None, "overwrite"): raise NotImplementedError if escapeQuotes: @@ -88,13 +86,13 @@ def csv( ) -class DataFrameReader: - def __init__(self, session: "SparkSession"): +class DataFrameReader: # noqa: D101 + def __init__(self, session: "SparkSession") -> None: # noqa: D107 self.session = session - def load( + def load( # noqa: D102 self, - path: Optional[Union[str, List[str]]] = 
None, + path: Optional[Union[str, list[str]]] = None, format: Optional[str] = None, schema: Optional[Union[StructType, str]] = None, **options: OptionalPrimitiveType, @@ -102,7 +100,7 @@ def load( from duckdb.experimental.spark.sql.dataframe import DataFrame if not isinstance(path, str): - raise ImportError + raise TypeError if options: raise ContributionsAcceptedError @@ -123,15 +121,15 @@ def load( if schema: if not isinstance(schema, StructType): raise ContributionsAcceptedError - schema = cast(StructType, schema) + schema = cast("StructType", schema) types, names = schema.extract_types_and_names() df = df._cast_types(types) df = df.toDF(names) raise NotImplementedError - def csv( + def csv( # noqa: D102 self, - path: Union[str, List[str]], + path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, sep: Optional[str] = None, encoding: Optional[str] = None, @@ -225,7 +223,7 @@ def csv( dtype = None names = None if schema: - schema = cast(StructType, schema) + schema = cast("StructType", schema) dtype, names = schema.extract_types_and_names() rel = self.session.conn.read_csv( @@ -247,13 +245,15 @@ def csv( df = df.toDF(*names) return df - def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": + def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": # noqa: D102 input = list(paths) if len(input) != 1: - raise NotImplementedError("Only single paths are supported for now") + msg = "Only single paths are supported for now" + raise NotImplementedError(msg) option_amount = len(options.keys()) if option_amount != 0: - raise ContributionsAcceptedError("Options are not supported") + msg = "Options are not supported" + raise ContributionsAcceptedError(msg) path = input[0] rel = self.session.conn.read_parquet(path) from ..sql.dataframe import DataFrame @@ -263,7 +263,7 @@ def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame def json( self, - path: Union[str, List[str]], + path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, primitivesAsString: Optional[Union[bool, str]] = None, prefersDecimal: Optional[Union[bool, str]] = None, @@ -289,8 +289,7 @@ def json( modifiedAfter: Optional[Union[bool, str]] = None, allowNonNumericNumbers: Optional[Union[bool, str]] = None, ) -> "DataFrame": - """ - Loads JSON files and returns the results as a :class:`DataFrame`. + """Loads JSON files and returns the results as a :class:`DataFrame`. `JSON Lines `_ (newline-delimited JSON) is supported by default. For JSON (one record per file), set the ``multiLine`` parameter to ``true``. @@ -321,16 +320,16 @@ def json( .. # noqa - Examples + Examples: -------- Write a DataFrame into a JSON file and read it back. >>> import tempfile >>> with tempfile.TemporaryDirectory() as d: ... # Write a DataFrame into a JSON file - ... spark.createDataFrame( - ... [{"age": 100, "name": "Hyukjin Kwon"}] - ... ).write.mode("overwrite").format("json").save(d) + ... spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode( + ... "overwrite" + ... ).format("json").save(d) ... ... # Read the JSON file as a DataFrame. ... 
spark.read.json(d).show() @@ -340,102 +339,89 @@ def json( |100|Hyukjin Kwon| +---+------------+ """ - if schema is not None: - raise ContributionsAcceptedError("The 'schema' option is not supported") + msg = "The 'schema' option is not supported" + raise ContributionsAcceptedError(msg) if primitivesAsString is not None: - raise ContributionsAcceptedError( - "The 'primitivesAsString' option is not supported" - ) + msg = "The 'primitivesAsString' option is not supported" + raise ContributionsAcceptedError(msg) if prefersDecimal is not None: - raise ContributionsAcceptedError( - "The 'prefersDecimal' option is not supported" - ) + msg = "The 'prefersDecimal' option is not supported" + raise ContributionsAcceptedError(msg) if allowComments is not None: - raise ContributionsAcceptedError( - "The 'allowComments' option is not supported" - ) + msg = "The 'allowComments' option is not supported" + raise ContributionsAcceptedError(msg) if allowUnquotedFieldNames is not None: - raise ContributionsAcceptedError( - "The 'allowUnquotedFieldNames' option is not supported" - ) + msg = "The 'allowUnquotedFieldNames' option is not supported" + raise ContributionsAcceptedError(msg) if allowSingleQuotes is not None: - raise ContributionsAcceptedError( - "The 'allowSingleQuotes' option is not supported" - ) + msg = "The 'allowSingleQuotes' option is not supported" + raise ContributionsAcceptedError(msg) if allowNumericLeadingZero is not None: - raise ContributionsAcceptedError( - "The 'allowNumericLeadingZero' option is not supported" - ) + msg = "The 'allowNumericLeadingZero' option is not supported" + raise ContributionsAcceptedError(msg) if allowBackslashEscapingAnyCharacter is not None: - raise ContributionsAcceptedError( - "The 'allowBackslashEscapingAnyCharacter' option is not supported" - ) + msg = "The 'allowBackslashEscapingAnyCharacter' option is not supported" + raise ContributionsAcceptedError(msg) if mode is not None: - raise ContributionsAcceptedError("The 'mode' option is not supported") + msg = "The 'mode' option is not supported" + raise ContributionsAcceptedError(msg) if columnNameOfCorruptRecord is not None: - raise ContributionsAcceptedError( - "The 'columnNameOfCorruptRecord' option is not supported" - ) + msg = "The 'columnNameOfCorruptRecord' option is not supported" + raise ContributionsAcceptedError(msg) if dateFormat is not None: - raise ContributionsAcceptedError("The 'dateFormat' option is not supported") + msg = "The 'dateFormat' option is not supported" + raise ContributionsAcceptedError(msg) if timestampFormat is not None: - raise ContributionsAcceptedError( - "The 'timestampFormat' option is not supported" - ) + msg = "The 'timestampFormat' option is not supported" + raise ContributionsAcceptedError(msg) if multiLine is not None: - raise ContributionsAcceptedError("The 'multiLine' option is not supported") + msg = "The 'multiLine' option is not supported" + raise ContributionsAcceptedError(msg) if allowUnquotedControlChars is not None: - raise ContributionsAcceptedError( - "The 'allowUnquotedControlChars' option is not supported" - ) + msg = "The 'allowUnquotedControlChars' option is not supported" + raise ContributionsAcceptedError(msg) if lineSep is not None: - raise ContributionsAcceptedError("The 'lineSep' option is not supported") + msg = "The 'lineSep' option is not supported" + raise ContributionsAcceptedError(msg) if samplingRatio is not None: - raise ContributionsAcceptedError( - "The 'samplingRatio' option is not supported" - ) + msg = "The 'samplingRatio' option is 
not supported" + raise ContributionsAcceptedError(msg) if dropFieldIfAllNull is not None: - raise ContributionsAcceptedError( - "The 'dropFieldIfAllNull' option is not supported" - ) + msg = "The 'dropFieldIfAllNull' option is not supported" + raise ContributionsAcceptedError(msg) if encoding is not None: - raise ContributionsAcceptedError("The 'encoding' option is not supported") + msg = "The 'encoding' option is not supported" + raise ContributionsAcceptedError(msg) if locale is not None: - raise ContributionsAcceptedError("The 'locale' option is not supported") + msg = "The 'locale' option is not supported" + raise ContributionsAcceptedError(msg) if pathGlobFilter is not None: - raise ContributionsAcceptedError( - "The 'pathGlobFilter' option is not supported" - ) + msg = "The 'pathGlobFilter' option is not supported" + raise ContributionsAcceptedError(msg) if recursiveFileLookup is not None: - raise ContributionsAcceptedError( - "The 'recursiveFileLookup' option is not supported" - ) + msg = "The 'recursiveFileLookup' option is not supported" + raise ContributionsAcceptedError(msg) if modifiedBefore is not None: - raise ContributionsAcceptedError( - "The 'modifiedBefore' option is not supported" - ) + msg = "The 'modifiedBefore' option is not supported" + raise ContributionsAcceptedError(msg) if modifiedAfter is not None: - raise ContributionsAcceptedError( - "The 'modifiedAfter' option is not supported" - ) + msg = "The 'modifiedAfter' option is not supported" + raise ContributionsAcceptedError(msg) if allowNonNumericNumbers is not None: - raise ContributionsAcceptedError( - "The 'allowNonNumericNumbers' option is not supported" - ) + msg = "The 'allowNonNumericNumbers' option is not supported" + raise ContributionsAcceptedError(msg) if isinstance(path, str): path = [path] - if isinstance(path, list): + if isinstance(path, list): if len(path) == 1: rel = self.session.conn.read_json(path[0]) from .dataframe import DataFrame df = DataFrame(rel, self.session) return df - raise PySparkNotImplementedError( - message="Only a single path is supported for now" - ) + raise PySparkNotImplementedError(message="Only a single path is supported for now") else: raise PySparkTypeError( error_class="NOT_STR_OR_LIST_OF_RDD", @@ -446,4 +432,4 @@ def json( ) -__all__ = ["DataFrameWriter", "DataFrameReader"] +__all__ = ["DataFrameReader", "DataFrameWriter"] diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index d3cfaa68..b05b9705 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -1,32 +1,31 @@ -from typing import Optional, List, Any, Union, Iterable, TYPE_CHECKING -import uuid +import uuid # noqa: D100 +from collections.abc import Iterable, Sized +from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union + +import duckdb if TYPE_CHECKING: - from .catalog import Catalog from pandas.core.frame import DataFrame as PandasDataFrame -from ..exception import ContributionsAcceptedError -from .types import StructType, AtomicType, DataType + from .catalog import Catalog + + from ..conf import SparkConf -from .dataframe import DataFrame +from ..context import SparkContext +from ..errors import PySparkTypeError +from ..exception import ContributionsAcceptedError from .conf import RuntimeConfig +from .dataframe import DataFrame from .readwriter import DataFrameReader -from ..context import SparkContext -from .udf import UDFRegistration from .streaming import DataStreamReader -import duckdb - -from ..errors 
import ( - PySparkTypeError, - PySparkValueError -) - -from ..errors.error_classes import * +from .types import StructType +from .udf import UDFRegistration # In spark: # SparkSession holds a SparkContext # SparkContext gets created from SparkConf -# At this level the check is made to determine whether the instance already exists and just needs to be retrieved or it needs to be created +# At this level the check is made to determine whether the instance already exists and just needs +# to be retrieved or it needs to be created. # For us this is done inside of `duckdb.connect`, based on the passed in path + configuration # SparkContext can be compared to our Connection class, and SparkConf to our ClientContext class @@ -34,7 +33,7 @@ # data is a List of rows # every value in each row needs to be turned into a Value -def _combine_data_and_schema(data: Iterable[Any], schema: StructType): +def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[duckdb.Value]: from duckdb import Value new_data = [] @@ -44,8 +43,8 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType): return new_data -class SparkSession: - def __init__(self, context: SparkContext): +class SparkSession: # noqa: D101 + def __init__(self, context: SparkContext) -> None: # noqa: D107 self.conn = context.connection self._context = context self._conf = RuntimeConfig(self.conn) @@ -53,15 +52,16 @@ def __init__(self, context: SparkContext): def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame: try: import pandas + has_pandas = True except ImportError: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - unique_name = f'pyspark_pandas_df_{uuid.uuid1()}' + unique_name = f"pyspark_pandas_df_{uuid.uuid1()}" self.conn.register(unique_name, data) return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self) - def verify_tuple_integrity(tuples): + def verify_tuple_integrity(tuples: list[tuple]) -> None: if len(tuples) <= 1: return expected_length = len(tuples[0]) @@ -73,9 +73,9 @@ def verify_tuple_integrity(tuples): error_class="LENGTH_SHOULD_BE_THE_SAME", message_parameters={ "arg1": f"data{i}", - "arg2": f"data{i+1}", + "arg2": f"data{i + 1}", "arg1_length": str(expected_length), - "arg2_length": str(actual_length) + "arg2_length": str(actual_length), }, ) @@ -83,16 +83,16 @@ def verify_tuple_integrity(tuples): data = list(data) verify_tuple_integrity(data) - def construct_query(tuples) -> str: - def construct_values_list(row, start_param_idx): + def construct_query(tuples: Iterable) -> str: + def construct_values_list(row: Sized, start_param_idx: int) -> str: parameter_count = len(row) - parameters = [f'${x+start_param_idx}' for x in range(parameter_count)] - parameters = '(' + ', '.join(parameters) + ')' + parameters = [f"${x + start_param_idx}" for x in range(parameter_count)] + parameters = "(" + ", ".join(parameters) + ")" return parameters row_size = len(tuples[0]) values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)] - values_list = ', '.join(values_list) + values_list = ", ".join(values_list) query = f""" select * from (values {values_list}) @@ -101,7 +101,7 @@ def construct_values_list(row, start_param_idx): query = construct_query(data) - def construct_parameters(tuples): + def construct_parameters(tuples: Iterable) -> list[list]: parameters = [] for row in tuples: parameters.extend(list(row)) @@ -112,7 +112,9 @@ def construct_parameters(tuples): rel = self.conn.sql(query, 
params=parameters) return DataFrame(rel, self) - def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> DataFrame: + def _createDataFrameFromPandas( + self, data: "PandasDataFrame", types: Union[list[str], None], names: Union[list[str], None] + ) -> DataFrame: df = self._create_dataframe(data) # Cast to types @@ -123,10 +125,10 @@ def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> D df = df.toDF(*names) return df - def createDataFrame( + def createDataFrame( # noqa: D102 self, data: Union["PandasDataFrame", Iterable[Any]], - schema: Optional[Union[StructType, List[str]]] = None, + schema: Optional[Union[StructType, list[str]]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, ) -> DataFrame: @@ -175,7 +177,7 @@ def createDataFrame( if is_empty: rel = df.relation # Add impossible where clause - rel = rel.filter('1=0') + rel = rel.filter("1=0") df = DataFrame(rel, self) # Cast to types @@ -186,10 +188,10 @@ def createDataFrame( df = df.toDF(*names) return df - def newSession(self) -> "SparkSession": + def newSession(self) -> "SparkSession": # noqa: D102 return SparkSession(self._context) - def range( + def range( # noqa: D102 self, start: int, end: Optional[int] = None, @@ -203,26 +205,26 @@ def range( end = start start = 0 - return DataFrame(self.conn.table_function("range", parameters=[start, end, step]),self) + return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self) - def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: + def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: # noqa: D102, ANN401 if kwargs: raise NotImplementedError relation = self.conn.sql(sqlQuery) return DataFrame(relation, self) - def stop(self) -> None: + def stop(self) -> None: # noqa: D102 self._context.stop() - def table(self, tableName: str) -> DataFrame: + def table(self, tableName: str) -> DataFrame: # noqa: D102 relation = self.conn.table(tableName) return DataFrame(relation, self) - def getActiveSession(self) -> "SparkSession": + def getActiveSession(self) -> "SparkSession": # noqa: D102 return self @property - def catalog(self) -> "Catalog": + def catalog(self) -> "Catalog": # noqa: D102 if not hasattr(self, "_catalog"): from duckdb.experimental.spark.sql.catalog import Catalog @@ -230,59 +232,62 @@ def catalog(self) -> "Catalog": return self._catalog @property - def conf(self) -> RuntimeConfig: + def conf(self) -> RuntimeConfig: # noqa: D102 return self._conf @property - def read(self) -> DataFrameReader: + def read(self) -> DataFrameReader: # noqa: D102 return DataFrameReader(self) @property - def readStream(self) -> DataStreamReader: + def readStream(self) -> DataStreamReader: # noqa: D102 return DataStreamReader(self) @property - def sparkContext(self) -> SparkContext: + def sparkContext(self) -> SparkContext: # noqa: D102 return self._context @property - def streams(self) -> Any: + def streams(self) -> NoReturn: # noqa: D102 raise ContributionsAcceptedError @property - def udf(self) -> UDFRegistration: + def udf(self) -> UDFRegistration: # noqa: D102 return UDFRegistration(self) @property - def version(self) -> str: - return '1.0.0' + def version(self) -> str: # noqa: D102 + return "1.0.0" - class Builder: - def __init__(self): + class Builder: # noqa: D106 + def __init__(self) -> None: # noqa: D107 pass - def master(self, name: str) -> "SparkSession.Builder": + def master(self, name: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def appName(self, name: str) -> 
"SparkSession.Builder": + def appName(self, name: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def remote(self, url: str) -> "SparkSession.Builder": + def remote(self, url: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def getOrCreate(self) -> "SparkSession": + def getOrCreate(self) -> "SparkSession": # noqa: D102 context = SparkContext("__ignored__") return SparkSession(context) - def config( - self, key: Optional[str] = None, value: Optional[Any] = None, conf: Optional[SparkConf] = None + def config( # noqa: D102 + self, + key: Optional[str] = None, + value: Optional[Any] = None, # noqa: ANN401 + conf: Optional[SparkConf] = None, ) -> "SparkSession.Builder": return self - def enableHiveSupport(self) -> "SparkSession.Builder": + def enableHiveSupport(self) -> "SparkSession.Builder": # noqa: D102 # no-op return self diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 5414344f..08b7cc30 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union # noqa: D100 + from .types import StructType if TYPE_CHECKING: @@ -9,28 +10,26 @@ OptionalPrimitiveType = Optional[PrimitiveType] -class DataStreamWriter: - def __init__(self, dataframe: "DataFrame"): +class DataStreamWriter: # noqa: D101 + def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107 self.dataframe = dataframe - def toTable(self, table_name: str) -> None: + def toTable(self, table_name: str) -> None: # noqa: D102 # Should we register the dataframe or create a table from the contents? raise NotImplementedError -class DataStreamReader: - def __init__(self, session: "SparkSession"): +class DataStreamReader: # noqa: D101 + def __init__(self, session: "SparkSession") -> None: # noqa: D107 self.session = session - def load( + def load( # noqa: D102 self, path: Optional[str] = None, format: Optional[str] = None, schema: Union[StructType, str, None] = None, - **options: OptionalPrimitiveType + **options: OptionalPrimitiveType, ) -> "DataFrame": - from duckdb.experimental.spark.sql.dataframe import DataFrame - raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index a17d0f53..90dac658 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -1,105 +1,107 @@ -from duckdb.typing import DuckDBPyType -from typing import List, Tuple, cast +from typing import cast # noqa: D100 + +from duckdb.sqltypes import DuckDBPyType + from .types import ( - DataType, - StringType, + ArrayType, BinaryType, BitstringType, - UUIDType, BooleanType, + ByteType, + DataType, DateType, - TimestampType, - TimestampNTZType, - TimeType, - TimeNTZType, - TimestampNanosecondNTZType, - TimestampMilisecondNTZType, - TimestampSecondNTZType, + DayTimeIntervalType, DecimalType, DoubleType, FloatType, - ByteType, - UnsignedByteType, - ShortType, - UnsignedShortType, + HugeIntegerType, IntegerType, - UnsignedIntegerType, LongType, - UnsignedLongType, - HugeIntegerType, - UnsignedHugeIntegerType, - DayTimeIntervalType, - ArrayType, MapType, + ShortType, + StringType, StructField, StructType, + TimeNTZType, + TimestampMilisecondNTZType, + TimestampNanosecondNTZType, + TimestampNTZType, + TimestampSecondNTZType, + TimestampType, + TimeType, + UnsignedByteType, + UnsignedHugeIntegerType, + 
UnsignedIntegerType, + UnsignedLongType, + UnsignedShortType, + UUIDType, ) _sqltype_to_spark_class = { - 'boolean': BooleanType, - 'utinyint': UnsignedByteType, - 'tinyint': ByteType, - 'usmallint': UnsignedShortType, - 'smallint': ShortType, - 'uinteger': UnsignedIntegerType, - 'integer': IntegerType, - 'ubigint': UnsignedLongType, - 'bigint': LongType, - 'hugeint': HugeIntegerType, - 'uhugeint': UnsignedHugeIntegerType, - 'varchar': StringType, - 'blob': BinaryType, - 'bit': BitstringType, - 'uuid': UUIDType, - 'date': DateType, - 'time': TimeNTZType, - 'time with time zone': TimeType, - 'timestamp': TimestampNTZType, - 'timestamp with time zone': TimestampType, - 'timestamp_ms': TimestampNanosecondNTZType, - 'timestamp_ns': TimestampMilisecondNTZType, - 'timestamp_s': TimestampSecondNTZType, - 'interval': DayTimeIntervalType, - 'list': ArrayType, - 'struct': StructType, - 'map': MapType, + "boolean": BooleanType, + "utinyint": UnsignedByteType, + "tinyint": ByteType, + "usmallint": UnsignedShortType, + "smallint": ShortType, + "uinteger": UnsignedIntegerType, + "integer": IntegerType, + "ubigint": UnsignedLongType, + "bigint": LongType, + "hugeint": HugeIntegerType, + "uhugeint": UnsignedHugeIntegerType, + "varchar": StringType, + "blob": BinaryType, + "bit": BitstringType, + "uuid": UUIDType, + "date": DateType, + "time": TimeNTZType, + "time with time zone": TimeType, + "timestamp": TimestampNTZType, + "timestamp with time zone": TimestampType, + "timestamp_ms": TimestampNanosecondNTZType, + "timestamp_ns": TimestampMilisecondNTZType, + "timestamp_s": TimestampSecondNTZType, + "interval": DayTimeIntervalType, + "list": ArrayType, + "struct": StructType, + "map": MapType, # union # enum # null (???) - 'float': FloatType, - 'double': DoubleType, - 'decimal': DecimalType, + "float": FloatType, + "double": DoubleType, + "decimal": DecimalType, } -def convert_nested_type(dtype: DuckDBPyType) -> DataType: +def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 id = dtype.id - if id == 'list' or id == 'array': + if id == "list" or id == "array": children = dtype.children return ArrayType(convert_type(children[0][1])) - # TODO: add support for 'union' - if id == 'struct': - children: List[Tuple[str, DuckDBPyType]] = dtype.children + # TODO: add support for 'union' # noqa: TD002, TD003 + if id == "struct": + children: list[tuple[str, DuckDBPyType]] = dtype.children fields = [StructField(x[0], convert_type(x[1])) for x in children] return StructType(fields) - if id == 'map': + if id == "map": return MapType(convert_type(dtype.key), convert_type(dtype.value)) raise NotImplementedError -def convert_type(dtype: DuckDBPyType) -> DataType: +def convert_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 id = dtype.id - if id in ['list', 'struct', 'map', 'array']: + if id in ["list", "struct", "map", "array"]: return convert_nested_type(dtype) - if id == 'decimal': - children: List[Tuple[str, DuckDBPyType]] = dtype.children - precision = cast(int, children[0][1]) - scale = cast(int, children[1][1]) + if id == "decimal": + children: list[tuple[str, DuckDBPyType]] = dtype.children + precision = cast("int", children[0][1]) + scale = cast("int", children[1][1]) return DecimalType(precision, scale) spark_type = _sqltype_to_spark_class[id] return spark_type() -def duckdb_to_spark_schema(names: List[str], types: List[DuckDBPyType]) -> StructType: +def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: # noqa: D103 fields = [StructField(name, dtype) for name, 
dtype in zip(names, [convert_type(x) for x in types])] return StructType(fields) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 13cd8480..856885e9 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -1,28 +1,28 @@ -# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'spark' folder. +# ruff: noqa: D100 +# This code is based on code from Apache Spark under the license found in the LICENSE +# file located in the 'spark' folder. +import calendar +import datetime +import math +import re +import time +from builtins import tuple +from collections.abc import Iterator, Mapping +from types import MappingProxyType from typing import ( - cast, - overload, - Dict, - Optional, - List, - Tuple, Any, - Union, - Type, - TypeVar, ClassVar, - Iterator, + NoReturn, + Optional, + TypeVar, + Union, + cast, + overload, ) -from builtins import tuple -import datetime -import calendar -import time -import math -import re import duckdb -from duckdb.typing import DuckDBPyType +from duckdb.sqltypes import DuckDBPyType from ..exception import ContributionsAcceptedError @@ -30,105 +30,100 @@ U = TypeVar("U") __all__ = [ - "DataType", - "NullType", - "StringType", + "ArrayType", "BinaryType", - "UUIDType", "BitstringType", "BooleanType", + "ByteType", + "DataType", "DateType", - "TimestampType", - "TimestampNTZType", - "TimestampNanosecondNTZType", - "TimestampMilisecondNTZType", - "TimestampSecondNTZType", - "TimeType", - "TimeNTZType", + "DayTimeIntervalType", "DecimalType", "DoubleType", "FloatType", - "ByteType", - "UnsignedByteType", - "ShortType", - "UnsignedShortType", + "HugeIntegerType", "IntegerType", - "UnsignedIntegerType", "LongType", - "UnsignedLongType", - "HugeIntegerType", - "UnsignedHugeIntegerType", - "DayTimeIntervalType", - "Row", - "ArrayType", "MapType", + "NullType", + "Row", + "ShortType", + "StringType", "StructField", "StructType", + "TimeNTZType", + "TimeType", + "TimestampMilisecondNTZType", + "TimestampNTZType", + "TimestampNanosecondNTZType", + "TimestampSecondNTZType", + "TimestampType", + "UUIDType", + "UnsignedByteType", + "UnsignedHugeIntegerType", + "UnsignedIntegerType", + "UnsignedLongType", + "UnsignedShortType", ] class DataType: """Base class for data types.""" - def __init__(self, duckdb_type): + def __init__(self, duckdb_type: DuckDBPyType) -> None: # noqa: D107 self.duckdb_type = duckdb_type - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return self.__class__.__name__ + "()" - def __hash__(self) -> int: + def __hash__(self) -> int: # noqa: D105 return hash(str(self)) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: # noqa: D105 return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: # noqa: D105 return not self.__eq__(other) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return cls.__name__[:-4].lower() - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return self.typeName() - def jsonValue(self) -> Union[str, Dict[str, Any]]: + def jsonValue(self) -> Union[str, dict[str, Any]]: # noqa: D102 raise ContributionsAcceptedError - def json(self) -> str: + def json(self) -> str: # noqa: D102 raise ContributionsAcceptedError def needConversion(self) -> bool: - """ - Does this type needs conversion between Python 
object and internal SQL object. + """Does this type needs conversion between Python object and internal SQL object. This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. """ return False - def toInternal(self, obj: Any) -> Any: - """ - Converts a Python object into an internal SQL object. - """ + def toInternal(self, obj: Any) -> Any: # noqa: ANN401 + """Converts a Python object into an internal SQL object.""" return obj - def fromInternal(self, obj: Any) -> Any: - """ - Converts an internal SQL object into a native Python object. - """ + def fromInternal(self, obj: Any) -> Any: # noqa: ANN401 + """Converts an internal SQL object into a native Python object.""" return obj # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle class DataTypeSingleton(type): - """Metaclass for DataType""" + """Metaclass for DataType.""" - _instances: ClassVar[Dict[Type["DataTypeSingleton"], "DataTypeSingleton"]] = {} + _instances: ClassVar[dict[type["DataTypeSingleton"], "DataTypeSingleton"]] = {} - def __call__(cls: Type[T]) -> T: # type: ignore[override] + def __call__(cls: type[T]) -> T: # type: ignore[override] if cls not in cls._instances: # type: ignore[attr-defined] - cls._instances[cls] = super(DataTypeSingleton, cls).__call__() # type: ignore[misc, attr-defined] + cls._instances[cls] = super().__call__() # type: ignore[misc, attr-defined] return cls._instances[cls] # type: ignore[attr-defined] @@ -138,17 +133,18 @@ class NullType(DataType, metaclass=DataTypeSingleton): The data type representing None, used for the types that cannot be inferred. """ - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("NULL")) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "void" class AtomicType(DataType): """An internal type used to represent everything that is not - null, UDTs, arrays, structs, and maps.""" + null, UDTs, arrays, structs, and maps. 
+ """ # noqa: D205 class NumericType(AtomicType): @@ -166,54 +162,54 @@ class FractionalType(NumericType): class StringType(AtomicType, metaclass=DataTypeSingleton): """String data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("VARCHAR")) class BitstringType(AtomicType, metaclass=DataTypeSingleton): """Bitstring data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BIT")) class UUIDType(AtomicType, metaclass=DataTypeSingleton): """UUID data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UUID")) class BinaryType(AtomicType, metaclass=DataTypeSingleton): """Binary (byte array) data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BLOB")) class BooleanType(AtomicType, metaclass=DataTypeSingleton): """Boolean data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BOOLEAN")) class DateType(AtomicType, metaclass=DataTypeSingleton): """Date (datetime.date) data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("DATE")) EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, d: datetime.date) -> int: + def toInternal(self, d: datetime.date) -> int: # noqa: D102 if d is not None: return d.toordinal() - self.EPOCH_ORDINAL - def fromInternal(self, v: int) -> datetime.date: + def fromInternal(self, v: int) -> datetime.date: # noqa: D102 if v is not None: return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) @@ -221,22 +217,22 @@ def fromInternal(self, v: int) -> datetime.date: class TimestampType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMPTZ")) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamptz" - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 if dt is not None: seconds = calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple()) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 if ts is not None: # using int to avoid precision loss in float return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) @@ -245,22 +241,22 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with microsecond precision.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 if dt is not None: seconds = 
calendar.timegm(dt.timetuple()) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 if ts is not None: # using int to avoid precision loss in float return datetime.datetime.utcfromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) @@ -269,60 +265,60 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampSecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with second precision.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_S")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_s" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError class TimestampMilisecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with milisecond precision.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_MS")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_ms" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError class TimestampNanosecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with nanosecond precision.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_NS")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_ns" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError @@ -346,90 +342,90 @@ class DecimalType(FractionalType): the number of digits on right side of dot. 
(default: 0) """ - def __init__(self, precision: int = 10, scale: int = 0): + def __init__(self, precision: int = 10, scale: int = 0) -> None: # noqa: D107 super().__init__(duckdb.decimal_type(precision, scale)) self.precision = precision self.scale = scale self.hasPrecisionInfo = True # this is a public API - def simpleString(self) -> str: - return "decimal(%d,%d)" % (self.precision, self.scale) + def simpleString(self) -> str: # noqa: D102 + return f"decimal({int(self.precision):d},{int(self.scale):d})" - def __repr__(self) -> str: - return "DecimalType(%d,%d)" % (self.precision, self.scale) + def __repr__(self) -> str: # noqa: D105 + return f"DecimalType({int(self.precision):d},{int(self.scale):d})" class DoubleType(FractionalType, metaclass=DataTypeSingleton): """Double data type, representing double precision floats.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("DOUBLE")) class FloatType(FractionalType, metaclass=DataTypeSingleton): """Float data type, representing single precision floats.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("FLOAT")) class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TINYINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "tinyint" class UnsignedByteType(IntegralType): """Unsigned byte data type, i.e. a unsigned integer in a single byte.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UTINYINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "utinyint" class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("SMALLINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "smallint" class UnsignedShortType(IntegralType): """Unsigned short data type, i.e. a unsigned 16-bit integer.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("USMALLINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "usmallint" class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("INTEGER")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "integer" class UnsignedIntegerType(IntegralType): """Unsigned int data type, i.e. a unsigned 32-bit integer.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UINTEGER")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "uinteger" @@ -440,10 +436,10 @@ class LongType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BIGINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "bigint" @@ -454,24 +450,24 @@ class UnsignedLongType(IntegralType): please use :class:`HugeIntegerType`. 
""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UBIGINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "ubigint" class HugeIntegerType(IntegralType): """Huge integer data type, i.e. a signed 128-bit integer. - If the values are beyond the range of [-170141183460469231731687303715884105728, 170141183460469231731687303715884105727], - please use :class:`DecimalType`. + If the values are beyond the range of [-170141183460469231731687303715884105728, + 170141183460469231731687303715884105727], please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("HUGEINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "hugeint" @@ -482,30 +478,30 @@ class UnsignedHugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UHUGEINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "uhugeint" class TimeType(IntegralType): """Time (datetime.time) data type.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMETZ")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "timetz" class TimeNTZType(IntegralType): """Time (datetime.time) data type without timezone information.""" - def __init__(self): + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIME")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "time" @@ -517,16 +513,18 @@ class DayTimeIntervalType(AtomicType): MINUTE = 2 SECOND = 3 - _fields = { - DAY: "day", - HOUR: "hour", - MINUTE: "minute", - SECOND: "second", - } + _fields: Mapping[str, int] = MappingProxyType( + { + DAY: "day", + HOUR: "hour", + MINUTE: "minute", + SECOND: "second", + } + ) - _inverted_fields = dict(zip(_fields.values(), _fields.keys())) + _inverted_fields: Mapping[int, str] = MappingProxyType(dict(zip(_fields.values(), _fields.keys()))) - def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): + def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: # noqa: D107 super().__init__(DuckDBPyType("INTERVAL")) if startField is None and endField is None: # Default matched to scala side. 
@@ -536,33 +534,34 @@ def __init__(self, startField: Optional[int] = None, endField: Optional[int] = N endField = startField fields = DayTimeIntervalType._fields - if startField not in fields.keys() or endField not in fields.keys(): - raise RuntimeError("interval %s to %s is invalid" % (startField, endField)) - self.startField = cast(int, startField) - self.endField = cast(int, endField) + if startField not in fields or endField not in fields: + msg = f"interval {startField} to {endField} is invalid" + raise RuntimeError(msg) + self.startField = cast("int", startField) + self.endField = cast("int", endField) def _str_repr(self) -> str: fields = DayTimeIntervalType._fields start_field_name = fields[self.startField] end_field_name = fields[self.endField] if start_field_name == end_field_name: - return "interval %s" % start_field_name + return f"interval {start_field_name}" else: - return "interval %s to %s" % (start_field_name, end_field_name) + return f"interval {start_field_name} to {end_field_name}" simpleString = _str_repr - def __repr__(self) -> str: - return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) + def __repr__(self) -> str: # noqa: D105 + return f"{type(self).__name__}({int(self.startField):d}, {int(self.endField):d})" - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, dt: datetime.timedelta) -> Optional[int]: + def toInternal(self, dt: datetime.timedelta) -> Optional[int]: # noqa: D102 if dt is not None: return (math.floor(dt.total_seconds()) * 1000000) + dt.microseconds - def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: + def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: # noqa: D102 if micros is not None: return datetime.timedelta(microseconds=micros) @@ -577,7 +576,7 @@ class ArrayType(DataType): containsNull : bool, optional whether the array can contain null (None) values. 
- Examples + Examples: -------- >>> ArrayType(StringType()) == ArrayType(StringType(), True) True @@ -585,30 +584,27 @@ class ArrayType(DataType): False """ - def __init__(self, elementType: DataType, containsNull: bool = True): + def __init__(self, elementType: DataType, containsNull: bool = True) -> None: # noqa: D107 super().__init__(duckdb.list_type(elementType.duckdb_type)) - assert isinstance(elementType, DataType), "elementType %s should be an instance of %s" % ( - elementType, - DataType, - ) + assert isinstance(elementType, DataType), f"elementType {elementType} should be an instance of {DataType}" self.elementType = elementType self.containsNull = containsNull - def simpleString(self) -> str: - return "array<%s>" % self.elementType.simpleString() + def simpleString(self) -> str: # noqa: D102 + return f"array<{self.elementType.simpleString()}>" - def __repr__(self) -> str: - return "ArrayType(%s, %s)" % (self.elementType, str(self.containsNull)) + def __repr__(self) -> str: # noqa: D105 + return f"ArrayType({self.elementType}, {self.containsNull!s})" - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.elementType.needConversion() - def toInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: + def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.toInternal(v) for v in obj] - def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: + def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.fromInternal(v) for v in obj] @@ -626,59 +622,44 @@ class MapType(DataType): valueContainsNull : bool, optional indicates whether values can contain null (None) values. - Notes + Notes: ----- Keys in a map data type are not allowed to be null (None). - Examples + Examples: -------- - >>> (MapType(StringType(), IntegerType()) - ... == MapType(StringType(), IntegerType(), True)) + >>> (MapType(StringType(), IntegerType()) == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType(), IntegerType(), False) - ... 
== MapType(StringType(), FloatType())) + >>> (MapType(StringType(), IntegerType(), False) == MapType(StringType(), FloatType())) False """ - def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True): + def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True) -> None: # noqa: D107 super().__init__(duckdb.map_type(keyType.duckdb_type, valueType.duckdb_type)) - assert isinstance(keyType, DataType), "keyType %s should be an instance of %s" % ( - keyType, - DataType, - ) - assert isinstance(valueType, DataType), "valueType %s should be an instance of %s" % ( - valueType, - DataType, - ) + assert isinstance(keyType, DataType), f"keyType {keyType} should be an instance of {DataType}" + assert isinstance(valueType, DataType), f"valueType {valueType} should be an instance of {DataType}" self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull - def simpleString(self) -> str: - return "map<%s,%s>" % ( - self.keyType.simpleString(), - self.valueType.simpleString(), - ) + def simpleString(self) -> str: # noqa: D102 + return f"map<{self.keyType.simpleString()},{self.valueType.simpleString()}>" - def __repr__(self) -> str: - return "MapType(%s, %s, %s)" % ( - self.keyType, - self.valueType, - str(self.valueContainsNull), - ) + def __repr__(self) -> str: # noqa: D105 + return f"MapType({self.keyType}, {self.valueType}, {self.valueContainsNull!s})" - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.keyType.needConversion() or self.valueType.needConversion() - def toInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: + def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 if not self.needConversion(): return obj - return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) for k, v in obj.items()) + return obj and {self.keyType.toInternal(k): self.valueType.toInternal(v) for k, v in obj.items()} - def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: + def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 if not self.needConversion(): return obj - return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()) + return obj and {self.keyType.fromInternal(k): self.valueType.fromInternal(v) for k, v in obj.items()} class StructField(DataType): @@ -695,66 +676,58 @@ class StructField(DataType): metadata : dict, optional a dict from string to simple type that can be toInternald to JSON automatically - Examples + Examples: -------- - >>> (StructField("f1", StringType(), True) - ... == StructField("f1", StringType(), True)) + >>> (StructField("f1", StringType(), True) == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType(), True) - ... 
== StructField("f2", StringType(), True)) + >>> (StructField("f1", StringType(), True) == StructField("f2", StringType(), True)) False """ - def __init__( + def __init__( # noqa: D107 self, name: str, dataType: DataType, nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, - ): + metadata: Optional[dict[str, Any]] = None, + ) -> None: super().__init__(dataType.duckdb_type) - assert isinstance(dataType, DataType), "dataType %s should be an instance of %s" % ( - dataType, - DataType, - ) - assert isinstance(name, str), "field name %s should be a string" % (name) + assert isinstance(dataType, DataType), f"dataType {dataType} should be an instance of {DataType}" + assert isinstance(name, str), f"field name {name} should be a string" self.name = name self.dataType = dataType self.nullable = nullable self.metadata = metadata or {} - def simpleString(self) -> str: - return "%s:%s" % (self.name, self.dataType.simpleString()) + def simpleString(self) -> str: # noqa: D102 + return f"{self.name}:{self.dataType.simpleString()}" - def __repr__(self) -> str: - return "StructField('%s', %s, %s)" % ( - self.name, - self.dataType, - str(self.nullable), - ) + def __repr__(self) -> str: # noqa: D105 + return f"StructField('{self.name}', {self.dataType}, {self.nullable!s})" - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.dataType.needConversion() - def toInternal(self, obj: T) -> T: + def toInternal(self, obj: T) -> T: # noqa: D102 return self.dataType.toInternal(obj) - def fromInternal(self, obj: T) -> T: + def fromInternal(self, obj: T) -> T: # noqa: D102 return self.dataType.fromInternal(obj) - def typeName(self) -> str: # type: ignore[override] - raise TypeError("StructField does not have typeName. " "Use typeName on its type explicitly instead.") + def typeName(self) -> str: # type: ignore[override] # noqa: D102 + msg = "StructField does not have typeName. Use typeName on its type explicitly instead." + raise TypeError(msg) class StructType(DataType): - """Struct type, consisting of a list of :class:`StructField`. + r"""Struct type, consisting of a list of :class:`StructField`. This is the data type representing a :class:`Row`. Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by its name or position. - Examples + Examples: -------- >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] @@ -767,16 +740,17 @@ class StructType(DataType): >>> struct1 == struct2 True >>> struct1 = StructType([StructField("f1", StringType(), True)]) - >>> struct2 = StructType([StructField("f1", StringType(), True), - ... StructField("f2", IntegerType(), False)]) + >>> struct2 = StructType( + ... [StructField("f1", StringType(), True), StructField("f2", IntegerType(), False)] + ... ) >>> struct1 == struct2 False """ - def _update_internal_duckdb_type(self): + def _update_internal_duckdb_type(self) -> None: self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) - def __init__(self, fields: Optional[List[StructField]] = None): + def __init__(self, fields: Optional[list[StructField]] = None) -> None: # noqa: D107 if not fields: self.fields = [] self.names = [] @@ -795,23 +769,20 @@ def add( field: str, data_type: Union[str, DataType], nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, - ) -> "StructType": - ... + metadata: Optional[dict[str, Any]] = None, + ) -> "StructType": ... 
@overload - def add(self, field: StructField) -> "StructType": - ... + def add(self, field: StructField) -> "StructType": ... def add( self, field: Union[str, StructField], data_type: Optional[Union[str, DataType]] = None, nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> "StructType": - """ - Construct a :class:`StructType` by adding new elements to it, to define the schema. + r"""Construct a :class:`StructType` by adding new elements to it, to define the schema. The method accepts either: a) A single parameter which is a :class:`StructField` object. @@ -830,11 +801,11 @@ def add( metadata : dict, optional Any additional metadata (default None) - Returns + Returns: ------- :class:`StructType` - Examples + Examples: -------- >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) >>> struct2 = StructType([StructField("f1", StringType(), True), \\ @@ -849,13 +820,14 @@ def add( >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - """ + """ # noqa: D205, D415 if isinstance(field, StructField): self.fields.append(field) self.names.append(field.name) else: if isinstance(field, str) and data_type is None: - raise ValueError("Must specify DataType if passing name of struct_field to create.") + msg = "Must specify DataType if passing name of struct_field to create." + raise ValueError(msg) else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) @@ -867,7 +839,7 @@ def add( return self def __iter__(self) -> Iterator[StructField]: - """Iterate the fields""" + """Iterate the fields.""" return iter(self.fields) def __len__(self) -> int: @@ -880,27 +852,30 @@ def __getitem__(self, key: Union[str, int]) -> StructField: for field in self: if field.name == key: return field - raise KeyError("No StructField named {0}".format(key)) + msg = f"No StructField named {key}" + raise KeyError(msg) elif isinstance(key, int): try: return self.fields[key] except IndexError: - raise IndexError("StructType index out of range") + msg = "StructType index out of range" + raise IndexError(msg) # noqa: B904 elif isinstance(key, slice): return StructType(self.fields[key]) else: - raise TypeError("StructType keys should be strings, integers or slices") + msg = "StructType keys should be strings, integers or slices" + raise TypeError(msg) - def simpleString(self) -> str: - return "struct<%s>" % (",".join(f.simpleString() for f in self)) + def simpleString(self) -> str: # noqa: D102 + return "struct<{}>".format(",".join(f.simpleString() for f in self)) - def __repr__(self) -> str: - return "StructType([%s])" % ", ".join(str(field) for field in self) + def __repr__(self) -> str: # noqa: D105 + return "StructType([{}])".format(", ".join(str(field) for field in self)) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: str) -> bool: # noqa: D105 return item in self.names - def extract_types_and_names(self) -> Tuple[List[str], List[str]]: + def extract_types_and_names(self) -> tuple[list[str], list[str]]: # noqa: D102 names = [] types = [] for f in self.fields: @@ -908,11 +883,10 @@ def extract_types_and_names(self) -> Tuple[List[str], List[str]]: names.append(f.name) return (types, names) - def fieldNames(self) -> List[str]: - """ - Returns all field names in a list. + def fieldNames(self) -> list[str]: + """Returns all field names in a list. 
- Examples + Examples: -------- >>> struct = StructType([StructField("f1", StringType(), True)]) >>> struct.fieldNames() @@ -920,11 +894,11 @@ def fieldNames(self) -> List[str]: """ return list(self.names) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 # We need convert Row()/namedtuple into tuple() return True - def toInternal(self, obj: Tuple) -> Tuple: + def toInternal(self, obj: tuple) -> tuple: # noqa: D102 if obj is None: return @@ -944,7 +918,8 @@ def toInternal(self, obj: Tuple) -> Tuple: for n, f, c in zip(self.names, self.fields, self._needConversion) ) else: - raise ValueError("Unexpected tuple %r with StructType" % obj) + msg = f"Unexpected tuple {obj!r} with StructType" + raise ValueError(msg) else: if isinstance(obj, dict): return tuple(obj.get(n) for n in self.names) @@ -954,16 +929,17 @@ def toInternal(self, obj: Tuple) -> Tuple: d = obj.__dict__ return tuple(d.get(n) for n in self.names) else: - raise ValueError("Unexpected tuple %r with StructType" % obj) + msg = f"Unexpected tuple {obj!r} with StructType" + raise ValueError(msg) - def fromInternal(self, obj: Tuple) -> "Row": + def fromInternal(self, obj: tuple) -> "Row": # noqa: D102 if obj is None: return if isinstance(obj, Row): # it's already converted by pickler return obj - values: Union[Tuple, List] + values: Union[tuple, list] if self._needSerializeAnyField: # Only calling fromInternal function for fields that need conversion values = [f.fromInternal(v) if c else v for f, v, c in zip(self.fields, obj, self._needConversion)] @@ -973,7 +949,7 @@ def fromInternal(self, obj: Tuple) -> "Row": class UnionType(DataType): - def __init__(self): + def __init__(self) -> None: raise ContributionsAcceptedError @@ -983,7 +959,7 @@ class UserDefinedType(DataType): .. note:: WARN: Spark Internal Use Only """ - def __init__(self): + def __init__(self) -> None: raise ContributionsAcceptedError @classmethod @@ -992,24 +968,21 @@ def typeName(cls) -> str: @classmethod def sqlType(cls) -> DataType: - """ - Underlying SQL storage type for this UDT. - """ - raise NotImplementedError("UDT must implement sqlType().") + """Underlying SQL storage type for this UDT.""" + msg = "UDT must implement sqlType()." + raise NotImplementedError(msg) @classmethod def module(cls) -> str: - """ - The Python module of the UDT. - """ - raise NotImplementedError("UDT must implement module().") + """The Python module of the UDT.""" + msg = "UDT must implement module()." + raise NotImplementedError(msg) @classmethod def scalaUDT(cls) -> str: - """ - The class name of the paired Scala UDT (could be '', if there + """The class name of the paired Scala UDT (could be '', if there is no corresponding one). - """ + """ # noqa: D205 return "" def needConversion(self) -> bool: @@ -1017,42 +990,38 @@ def needConversion(self) -> bool: @classmethod def _cachedSqlType(cls) -> DataType: - """ - Cache the sqlType() into class, because it's heavily used in `toInternal`. 
- """ + """Cache the sqlType() into class, because it's heavily used in `toInternal`.""" if not hasattr(cls, "_cached_sql_type"): cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined] return cls._cached_sql_type # type: ignore[attr-defined] - def toInternal(self, obj: Any) -> Any: + def toInternal(self, obj: Any) -> Any: # noqa: ANN401 if obj is not None: return self._cachedSqlType().toInternal(self.serialize(obj)) - def fromInternal(self, obj: Any) -> Any: + def fromInternal(self, obj: Any) -> Any: # noqa: ANN401 v = self._cachedSqlType().fromInternal(obj) if v is not None: return self.deserialize(v) - def serialize(self, obj: Any) -> Any: - """ - Converts a user-type object into a SQL datum. - """ - raise NotImplementedError("UDT must implement toInternal().") + def serialize(self, obj: Any) -> NoReturn: # noqa: ANN401 + """Converts a user-type object into a SQL datum.""" + msg = "UDT must implement toInternal()." + raise NotImplementedError(msg) - def deserialize(self, datum: Any) -> Any: - """ - Converts a SQL datum into a user-type object. - """ - raise NotImplementedError("UDT must implement fromInternal().") + def deserialize(self, datum: Any) -> NoReturn: # noqa: ANN401 + """Converts a SQL datum into a user-type object.""" + msg = "UDT must implement fromInternal()." + raise NotImplementedError(msg) def simpleString(self) -> str: return "udt" - def __eq__(self, other: Any) -> bool: - return type(self) == type(other) + def __eq__(self, other: object) -> bool: + return type(self) is type(other) -_atomic_types: List[Type[DataType]] = [ +_atomic_types: list[type[DataType]] = [ StringType, BinaryType, BooleanType, @@ -1068,32 +1037,28 @@ def __eq__(self, other: Any) -> bool: TimestampNTZType, NullType, ] -_all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) +_all_atomic_types: dict[str, type[DataType]] = {t.typeName(): t for t in _atomic_types} -_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [ +_complex_types: list[type[Union[ArrayType, MapType, StructType]]] = [ ArrayType, MapType, StructType, ] -_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict( - (v.typeName(), v) for v in _complex_types -) +_all_complex_types: dict[str, type[Union[ArrayType, MapType, StructType]]] = {v.typeName(): v for v in _complex_types} _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)") _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") -def _create_row(fields: Union["Row", List[str]], values: Union[Tuple[Any, ...], List[Any]]) -> "Row": +def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], list[Any]]) -> "Row": row = Row(*values) row.__fields__ = fields return row class Row(tuple): - - """ - A row in :class:`DataFrame`. + """A row in :class:`DataFrame`. The fields in it can be accessed: * like attributes (``row.key``) @@ -1110,18 +1075,18 @@ class Row(tuple): field names sorted alphabetically and will be ordered in the position as entered. 
- Examples + Examples: -------- >>> row = Row(name="Alice", age=11) >>> row Row(name='Alice', age=11) - >>> row['name'], row['age'] + >>> row["name"], row["age"] ('Alice', 11) >>> row.name, row.age ('Alice', 11) - >>> 'name' in row + >>> "name" in row True - >>> 'wrong_key' in row + >>> "wrong_key" in row False Row also can be used to create another Row like class, then it @@ -1130,9 +1095,9 @@ class Row(tuple): >>> Person = Row("name", "age") >>> Person - >>> 'name' in Person + >>> "name" in Person True - >>> 'wrong_key' in Person + >>> "wrong_key" in Person False >>> Person("Alice", 11) Row(name='Alice', age=11) @@ -1144,19 +1109,18 @@ class Row(tuple): >>> row2 = Row(name="Alice", age=11) >>> row1 == row2 True - """ + """ # noqa: D205, D415 @overload - def __new__(cls, *args: str) -> "Row": - ... + def __new__(cls, *args: str) -> "Row": ... @overload - def __new__(cls, **kwargs: Any) -> "Row": - ... + def __new__(cls, **kwargs: Any) -> "Row": ... # noqa: ANN401 - def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": + def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # noqa: D102 if args and kwargs: - raise ValueError("Can not use both args " "and kwargs to create Row") + msg = "Can not use both args and kwargs to create Row" + raise ValueError(msg) if kwargs: # create row objects row = tuple.__new__(cls, list(kwargs.values())) @@ -1166,16 +1130,15 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # create row class or objects return tuple.__new__(cls, args) - def asDict(self, recursive: bool = False) -> Dict[str, Any]: - """ - Return as a dict + def asDict(self, recursive: bool = False) -> dict[str, Any]: + """Return as a dict. Parameters ---------- recursive : bool, optional turns the nested Rows to dict (default: False). - Notes + Notes: ----- If a row contains duplicate field names, e.g., the rows of a join between two :class:`DataFrame` that both have the fields of same names, @@ -1183,28 +1146,29 @@ def asDict(self, recursive: bool = False) -> Dict[str, Any]: will also return one of the duplicate fields, however returned value might be different to ``asDict``. 
- Examples + Examples: -------- - >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} + >>> Row(name="Alice", age=11).asDict() == {"name": "Alice", "age": 11} True - >>> row = Row(key=1, value=Row(name='a', age=2)) - >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)} + >>> row = Row(key=1, value=Row(name="a", age=2)) + >>> row.asDict() == {"key": 1, "value": Row(name="a", age=2)} True - >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} + >>> row.asDict(True) == {"key": 1, "value": {"name": "a", "age": 2}} True """ if not hasattr(self, "__fields__"): - raise TypeError("Cannot convert a Row class into dict") + msg = "Cannot convert a Row class into dict" + raise TypeError(msg) if recursive: - def conv(obj: Any) -> Any: + def conv(obj: Union[Row, list, dict, object]) -> Union[list, dict, object]: if isinstance(obj, Row): return obj.asDict(True) elif isinstance(obj, list): return [conv(o) for o in obj] elif isinstance(obj, dict): - return dict((k, conv(v)) for k, v in obj.items()) + return {k: conv(v) for k, v in obj.items()} else: return obj @@ -1212,35 +1176,34 @@ def conv(obj: Any) -> Any: else: return dict(zip(self.__fields__, self)) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: Any) -> bool: # noqa: D105, ANN401 if hasattr(self, "__fields__"): return item in self.__fields__ else: - return super(Row, self).__contains__(item) + return super().__contains__(item) # let object acts like class - def __call__(self, *args: Any) -> "Row": - """create new Row object""" + def __call__(self, *args: Any) -> "Row": # noqa: ANN401 + """Create new Row object.""" if len(args) > len(self): - raise ValueError( - "Can not create Row with fields %s, expected %d values " "but got %s" % (self, len(self), args) - ) + msg = f"Can not create Row with fields {self}, expected {len(self):d} values but got {args}" + raise ValueError(msg) return _create_row(self, args) - def __getitem__(self, item: Any) -> Any: + def __getitem__(self, item: Any) -> Any: # noqa: D105, ANN401 if isinstance(item, (int, slice)): - return super(Row, self).__getitem__(item) + return super().__getitem__(item) try: # it will be slow when it has many fields, # but this will not be used in normal cases idx = self.__fields__.index(item) - return super(Row, self).__getitem__(idx) + return super().__getitem__(idx) except IndexError: - raise KeyError(item) + raise KeyError(item) # noqa: B904 except ValueError: - raise ValueError(item) + raise ValueError(item) # noqa: B904 - def __getattr__(self, item: str) -> Any: + def __getattr__(self, item: str) -> Any: # noqa: D105, ANN401 if item.startswith("__"): raise AttributeError(item) try: @@ -1249,18 +1212,19 @@ def __getattr__(self, item: str) -> Any: idx = self.__fields__.index(item) return self[idx] except IndexError: - raise AttributeError(item) + raise AttributeError(item) # noqa: B904 except ValueError: - raise AttributeError(item) + raise AttributeError(item) # noqa: B904 - def __setattr__(self, key: Any, value: Any) -> None: + def __setattr__(self, key: Any, value: Any) -> None: # noqa: D105, ANN401 if key != "__fields__": - raise RuntimeError("Row is read-only") + msg = "Row is read-only" + raise RuntimeError(msg) self.__dict__[key] = value def __reduce__( self, - ) -> Union[str, Tuple[Any, ...]]: + ) -> Union[str, tuple[Any, ...]]: """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) @@ -1270,6 +1234,6 @@ def __reduce__( def 
__repr__(self) -> str: """Printable representation of Row used in Python REPL.""" if hasattr(self, "__fields__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))) + return "Row({})".format(", ".join(f"{k}={v!r}" for k, v in zip(self.__fields__, tuple(self)))) else: - return "" % ", ".join("%r" % field for field in self) + return "".format(", ".join(f"{field!r}" for field in self)) diff --git a/duckdb/experimental/spark/sql/udf.py b/duckdb/experimental/spark/sql/udf.py index 61d3bee9..7437ed6b 100644 --- a/duckdb/experimental/spark/sql/udf.py +++ b/duckdb/experimental/spark/sql/udf.py @@ -1,4 +1,4 @@ -# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ +# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100 from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union from .types import DataType @@ -10,11 +10,11 @@ UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike") -class UDFRegistration: - def __init__(self, sparkSession: "SparkSession"): +class UDFRegistration: # noqa: D101 + def __init__(self, sparkSession: "SparkSession") -> None: # noqa: D107 self.sparkSession = sparkSession - def register( + def register( # noqa: D102 self, name: str, f: Union[Callable[..., Any], "UserDefinedFunctionLike"], @@ -22,7 +22,7 @@ def register( ) -> "UserDefinedFunctionLike": self.sparkSession.conn.create_function(name, f, return_type=returnType) - def registerJavaFunction( + def registerJavaFunction( # noqa: D102 self, name: str, javaClassName: str, @@ -30,7 +30,7 @@ def registerJavaFunction( ) -> None: raise NotImplementedError - def registerJavaUDAF(self, name: str, javaClassName: str) -> None: + def registerJavaUDAF(self, name: str, javaClassName: str) -> None: # noqa: D102 raise NotImplementedError diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index fbef757d..cc082efb 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -1,23 +1,33 @@ -from fsspec import filesystem, AbstractFileSystem -from fsspec.implementations.memory import MemoryFileSystem, MemoryFile -from .bytes_io_wrapper import BytesIOWrapper -from io import TextIOBase +"""In-memory filesystem to store ephemeral dependencies. + +Warning: Not for external use. May change at any moment. Likely to be made internal. 
+""" + +from __future__ import annotations -def is_file_like(obj): - # We only care that we can read from the file - return hasattr(obj, "read") and hasattr(obj, "seek") +import io +import typing + +from fsspec import AbstractFileSystem +from fsspec.implementations.memory import MemoryFile, MemoryFileSystem + +from .bytes_io_wrapper import BytesIOWrapper class ModifiedMemoryFileSystem(MemoryFileSystem): - protocol = ('DUCKDB_INTERNAL_OBJECTSTORE',) + """In-memory filesystem implementation that uses its own protocol.""" + + protocol = ("DUCKDB_INTERNAL_OBJECTSTORE",) # defer to the original implementation that doesn't hardcode the protocol - _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) + _strip_protocol: typing.Callable[[str], str] = classmethod(AbstractFileSystem._strip_protocol.__func__) # type: ignore[assignment] - def add_file(self, object, path): - if not is_file_like(object): - raise ValueError("Can not read from a non file-like object") - path = self._strip_protocol(path) - if isinstance(object, TextIOBase): + def add_file(self, obj: io.IOBase | BytesIOWrapper | object, path: str) -> None: + """Add a file to the filesystem.""" + if not (hasattr(obj, "read") and hasattr(obj, "seek")): + msg = "Can not read from a non file-like object" + raise TypeError(msg) + if isinstance(obj, io.TextIOBase): # Wrap this so that we can return a bytes object from 'read' - object = BytesIOWrapper(object) - self.store[path] = MemoryFile(self, path, object.read()) + obj = BytesIOWrapper(obj) + path = self._strip_protocol(path) + self.store[path] = MemoryFile(self, path, obj.read()) diff --git a/duckdb/func/__init__.py b/duckdb/func/__init__.py new file mode 100644 index 00000000..5d73f490 --- /dev/null +++ b/duckdb/func/__init__.py @@ -0,0 +1,3 @@ +from _duckdb._func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType # noqa: D104 + +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index ac4a6495..5114629b 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,17 +1,13 @@ -from _duckdb.functional import ( - FunctionNullHandling, - PythonUDFType, - SPECIAL, - DEFAULT, - NATIVE, - ARROW -) +"""DuckDB function constants and types. DEPRECATED: please use `duckdb.func` instead.""" + +import warnings + +from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType -__all__ = [ - "FunctionNullHandling", - "PythonUDFType", - "SPECIAL", - "DEFAULT", - "NATIVE", - "ARROW" -] +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] + +warnings.warn( + "`duckdb.functional` is deprecated and will be removed in a future version. Please use `duckdb.func` instead.", + DeprecationWarning, + stacklevel=2, +) diff --git a/duckdb/functional/__init__.pyi b/duckdb/functional/__init__.pyi deleted file mode 100644 index 33ea33fa..00000000 --- a/duckdb/functional/__init__.pyi +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Dict - -SPECIAL: FunctionNullHandling -DEFAULT: FunctionNullHandling - -NATIVE: PythonUDFType -ARROW: PythonUDFType - -class FunctionNullHandling: - DEFAULT: FunctionNullHandling - SPECIAL: FunctionNullHandling - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, FunctionNullHandling]: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... 
- -class PythonUDFType: - NATIVE: PythonUDFType - ARROW: PythonUDFType - def __int__(self) -> int: ... - def __index__(self) -> int: ... - @property - def __members__(self) -> Dict[str, PythonUDFType]: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index d8d4cfe9..2c075baf 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -1,20 +1,30 @@ -import duckdb -import polars as pl -from typing import Iterator, Optional +from __future__ import annotations # noqa: D100 -from polars.io.plugins import register_io_source -from duckdb import SQLExpression +import contextlib +import datetime import json +import typing from decimal import Decimal -import datetime -def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: - """ - Convert a Polars predicate expression to a DuckDB-compatible SQL expression. - +import polars as pl +from polars.io.plugins import register_io_source + +import duckdb + +if typing.TYPE_CHECKING: + from collections.abc import Iterator + + import typing_extensions + +_ExpressionTree: typing_extensions.TypeAlias = typing.Dict[str, typing.Union[str, int, "_ExpressionTree", typing.Any]] # noqa: UP006 + + +def _predicate_to_expression(predicate: pl.Expr) -> duckdb.Expression | None: + """Convert a Polars predicate expression to a DuckDB-compatible SQL expression. + Parameters: predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5) - + Returns: SQLExpression: A DuckDB SQL expression string equivalent. None: If conversion fails. @@ -25,20 +35,19 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: """ # Serialize the Polars expression tree to JSON tree = json.loads(predicate.meta.serialize(format="json")) - + try: # Convert the tree to SQL sql_filter = _pl_tree_to_sql(tree) - return SQLExpression(sql_filter) - except: + return duckdb.SQLExpression(sql_filter) + except Exception: # If the conversion fails, we return None return None def _pl_operation_to_sql(op: str) -> str: - """ - Map Polars binary operation strings to SQL equivalents. - + """Map Polars binary operation strings to SQL equivalents. + Example: >>> _pl_operation_to_sql("Eq") '=' @@ -55,12 +64,11 @@ def _pl_operation_to_sql(op: str) -> str: "Or": "OR", }[op] except KeyError: - raise NotImplementedError(op) + raise NotImplementedError(op) # noqa: B904 def _escape_sql_identifier(identifier: str) -> str: - """ - Escape SQL identifiers by doubling any double quotes and wrapping in double quotes. + """Escape SQL identifiers by doubling any double quotes and wrapping in double quotes. Example: >>> _escape_sql_identifier('column"name') @@ -70,16 +78,15 @@ def _escape_sql_identifier(identifier: str) -> str: return f'"{escaped}"' -def _pl_tree_to_sql(tree: dict) -> str: - """ - Recursively convert a Polars expression tree (as JSON) to a SQL string. - +def _pl_tree_to_sql(tree: _ExpressionTree) -> str: + """Recursively convert a Polars expression tree (as JSON) to a SQL string. 
+ Parameters: tree (dict): JSON-deserialized expression tree from Polars - + Returns: str: SQL expression string - + Example: Input tree: { @@ -92,36 +99,51 @@ def _pl_tree_to_sql(tree: dict) -> str: Output: "(foo > 5)" """ [node_type] = tree.keys() - subtree = tree[node_type] if node_type == "BinaryExpr": # Binary expressions: left OP right - return ( - "(" + - " ".join(( - _pl_tree_to_sql(subtree['left']), - _pl_operation_to_sql(subtree['op']), - _pl_tree_to_sql(subtree['right']) - )) + - ")" - ) + bin_expr_tree = tree[node_type] + assert isinstance(bin_expr_tree, dict), f"A {node_type} should be a dict but got {type(bin_expr_tree)}" + lhs, op, rhs = bin_expr_tree["left"], bin_expr_tree["op"], bin_expr_tree["right"] + assert isinstance(lhs, dict), f"LHS of a {node_type} should be a dict but got {type(lhs)}" + assert isinstance(op, str), f"The op of a {node_type} should be a str but got {type(op)}" + assert isinstance(rhs, dict), f"RHS of a {node_type} should be a dict but got {type(rhs)}" + return f"({_pl_tree_to_sql(lhs)} {_pl_operation_to_sql(op)} {_pl_tree_to_sql(rhs)})" if node_type == "Column": # A reference to a column name # Wrap in quotes to handle special characters - return _escape_sql_identifier(subtree) + col_name = tree[node_type] + assert isinstance(col_name, str), f"The col name of a {node_type} should be a str but got {type(col_name)}" + return _escape_sql_identifier(col_name) if node_type in ("Literal", "Dyn"): # Recursively process dynamic or literal values - return _pl_tree_to_sql(subtree) + val_tree = tree[node_type] + assert isinstance(val_tree, dict), f"A {node_type} should be a dict but got {type(val_tree)}" + return _pl_tree_to_sql(val_tree) if node_type == "Int": # Direct integer literals - return str(subtree) + int_literal = tree[node_type] + assert isinstance(int_literal, (int, str)), ( + f"The value of an Int should be an int or str but got {type(int_literal)}" + ) + return str(int_literal) if node_type == "Function": # Handle boolean functions like IsNull, IsNotNull - inputs = subtree["input"] - func_dict = subtree["function"] + func_tree = tree[node_type] + assert isinstance(func_tree, dict), f"A {node_type} should be a dict but got {type(func_tree)}" + inputs = func_tree["input"] + assert isinstance(inputs, list), f"A {node_type} should have a list of dicts as input but got {type(inputs)}" + input_tree = inputs[0] + assert isinstance(input_tree, dict), ( + f"A {node_type} should have a list of dicts as input but got {type(input_tree)}" + ) + func_dict = func_tree["function"] + assert isinstance(func_dict, dict), ( + f"A {node_type} should have a function dict as input but got {type(func_dict)}" + ) if "Boolean" in func_dict: func = func_dict["Boolean"] @@ -131,80 +153,107 @@ def _pl_tree_to_sql(tree: dict) -> str: return f"({arg_sql} IS NULL)" if func == "IsNotNull": return f"({arg_sql} IS NOT NULL)" - raise NotImplementedError(f"Boolean function not supported: {func}") + msg = f"Boolean function not supported: {func}" + raise NotImplementedError(msg) - raise NotImplementedError(f"Unsupported function type: {func_dict}") + msg = f"Unsupported function type: {func_dict}" + raise NotImplementedError(msg) if node_type == "Scalar": # Detect format: old style (dtype/value) or new style (direct type key) - if "dtype" in subtree and "value" in subtree: - dtype = str(subtree["dtype"]) - value = subtree["value"] + scalar_tree = tree[node_type] + assert isinstance(scalar_tree, dict), f"A {node_type} should be a dict but got {type(scalar_tree)}" + if "dtype" in 
scalar_tree and "value" in scalar_tree: + dtype = str(scalar_tree["dtype"]) + value = scalar_tree["value"] else: # New style: dtype is the single key in the dict - dtype = next(iter(subtree.keys())) - value = subtree + dtype = next(iter(scalar_tree.keys())) + value = scalar_tree + assert isinstance(dtype, str), f"A {node_type} should have a str dtype but got {type(dtype)}" + assert isinstance(value, dict), f"A {node_type} should have a dict value but got {type(value)}" # Decimal support if dtype.startswith("{'Decimal'") or dtype == "Decimal": - decimal_value = value['Decimal'] - decimal_value = Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]) - return str(decimal_value) + decimal_value = value["Decimal"] + assert isinstance(decimal_value, list), ( + f"A {dtype} should be a two or three member list but got {type(decimal_value)}" + ) + assert 2 <= len(decimal_value) <= 3, ( + f"A {dtype} should be a two or three member list but got {len(decimal_value)} member list" + ) + return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[-1])) # Datetime with microseconds since epoch if dtype.startswith("{'Datetime'") or dtype == "Datetime": - micros = value['Datetime'][0] - dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC) - return f"'{str(dt_timestamp)}'::TIMESTAMP" + micros = value["Datetime"] + assert isinstance(micros, list), f"A {dtype} should be a one member list but got {type(micros)}" + dt_timestamp = datetime.datetime.fromtimestamp(micros[0] / 1_000_000, tz=datetime.timezone.utc) + return f"'{dt_timestamp!s}'::TIMESTAMP" # Match simple numeric/boolean types - if dtype in ("Int8", "Int16", "Int32", "Int64", - "UInt8", "UInt16", "UInt32", "UInt64", - "Float32", "Float64", "Boolean"): + if dtype in ( + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "Float32", + "Float64", + "Boolean", + ): return str(value[dtype]) # Time type if dtype == "Time": nanoseconds = value["Time"] + assert isinstance(nanoseconds, int), f"A {dtype} should be an int but got {type(nanoseconds)}" seconds = nanoseconds // 1_000_000_000 microseconds = (nanoseconds % 1_000_000_000) // 1_000 - dt_time = (datetime.datetime.min + datetime.timedelta( - seconds=seconds, microseconds=microseconds - )).time() + dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time() return f"'{dt_time}'::TIME" # Date type if dtype == "Date": days_since_epoch = value["Date"] + assert isinstance(days_since_epoch, (float, int)), ( + f"A {dtype} should be a number but got {type(days_since_epoch)}" + ) date = datetime.date(1970, 1, 1) + datetime.timedelta(days=days_since_epoch) return f"'{date}'::DATE" # Binary type if dtype == "Binary": - binary_data = bytes(value["Binary"]) - escaped = ''.join(f'\\x{b:02x}' for b in binary_data) + bin_value = value["Binary"] + assert isinstance(bin_value, list), f"A {dtype} should be a list but got {type(bin_value)}" + binary_data = bytes(bin_value) + escaped = "".join(f"\\x{b:02x}" for b in binary_data) return f"'{escaped}'::BLOB" # String type if dtype == "String" or dtype == "StringOwned": # Some new formats may store directly under StringOwned - string_val = value.get("StringOwned", value.get("String", None)) + string_val: object | None = value.get("StringOwned", value.get("String", None)) return f"'{string_val}'" + msg = f"Unsupported scalar type {dtype!s}, with value {value}" + raise NotImplementedError(msg) - raise NotImplementedError(f"Unsupported scalar type 
{str(dtype)}, with value {value}") + msg = f"Node type: {node_type} is not implemented. {tree[node_type]}" + raise NotImplementedError(msg) - raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: - """ - A polars IO plugin for DuckDB. - """ + """A polars IO plugin for DuckDB.""" + def source_generator( - with_columns: Optional[list[str]], - predicate: Optional[pl.Expr], - n_rows: Optional[int], - batch_size: Optional[int], + with_columns: list[str] | None, + predicate: pl.Expr | None, + n_rows: int | None, + batch_size: int | None, ) -> Iterator[pl.DataFrame]: duck_predicate = None relation_final = relation @@ -215,7 +264,8 @@ def source_generator( relation_final = relation_final.limit(n_rows) if predicate is not None: # We have a predicate, if possible, we push it down to DuckDB - duck_predicate = _predicate_to_expression(predicate) + with contextlib.suppress(AssertionError, KeyError): + duck_predicate = _predicate_to_expression(predicate) # Try to pushdown filter, if one exists if duck_predicate is not None: relation_final = relation_final.filter(duck_predicate) @@ -223,15 +273,12 @@ def source_generator( results = relation_final.fetch_arrow_reader() else: results = relation_final.fetch_arrow_reader(batch_size) - while True: - try: - record_batch = results.read_next_batch() - if predicate is not None and duck_predicate is None: - # We have a predicate, but did not manage to push it down, we fallback here - yield pl.from_arrow(record_batch).filter(predicate) - else: - yield pl.from_arrow(record_batch) - except StopIteration: - break + + for record_batch in iter(results.read_next_batch, None): + if predicate is not None and duck_predicate is None: + # We have a predicate, but did not manage to push it down, we fallback here + yield pl.from_arrow(record_batch).filter(predicate) # type: ignore[arg-type,misc,unused-ignore] + else: + yield pl.from_arrow(record_batch) # type: ignore[misc,unused-ignore] return register_io_source(source_generator, schema=schema) diff --git a/duckdb/value/__init__.pyi b/duckdb/py.typed similarity index 100% rename from duckdb/value/__init__.pyi rename to duckdb/py.typed diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index 26038a6f..d4851694 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -1,10 +1,9 @@ +import argparse # noqa: D100 import json -import os -import sys import re import webbrowser from functools import reduce -import argparse +from pathlib import Path qgraph_css = """ .styled-table { @@ -57,7 +56,7 @@ text-align: center; padding: 0px; border-radius: 1px; - + /* Positioning */ position: absolute; z-index: 1; @@ -65,7 +64,7 @@ left: 50%; transform: translateX(-50%); margin-bottom: 8px; - + /* Tooltip Arrow */ width: 400px; } @@ -76,124 +75,128 @@ """ -class NodeTiming: - - def __init__(self, phase: str, time: float) -> object: +class NodeTiming: # noqa: D101 + def __init__(self, phase: str, time: float) -> None: # noqa: D107 self.phase = phase self.time = time # percentage is determined later. 
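Stepping back to the duckdb/polars_io.py hunks above: `_pl_tree_to_sql` walks the JSON-serialized Polars expression tree and produces a SQL filter string, and `duckdb_source` tries to push that filter into the DuckDB relation, falling back to a Polars-side `filter` when the conversion fails. A rough usage sketch under the module layout shown in this diff (the connection and data are made up for illustration):

    import duckdb
    import polars as pl

    from duckdb.polars_io import _predicate_to_expression, duckdb_source

    con = duckdb.connect()
    rel = con.sql("SELECT * FROM range(10) t(id)")

    # Predicate conversion in isolation: a Polars expression becomes a DuckDB SQL
    # expression, or None when the tree contains something unsupported.
    print(_predicate_to_expression(pl.col("id") > 5))

    # Register the relation as a lazy Polars source; the same predicate is pushed
    # down to DuckDB when possible, otherwise applied by Polars after the scan.
    schema = rel.pl().schema  # materialize once just to grab the schema
    lf = duckdb_source(rel, schema)
    print(lf.filter(pl.col("id") > 5).collect())
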
self.percentage = 0 - def calculate_percentage(self, total_time: float) -> None: + def calculate_percentage(self, total_time: float) -> None: # noqa: D102 self.percentage = self.time / total_time - def combine_timing(l: object, r: object) -> object: - # TODO: can only add timings for same-phase nodes - total_time = l.time + r.time - return NodeTiming(l.phase, total_time) + def combine_timing(self, r: "NodeTiming") -> "NodeTiming": # noqa: D102 + # TODO: can only add timings for same-phase nodes # noqa: TD002, TD003 + total_time = self.time + r.time + return NodeTiming(self.phase, total_time) -class AllTimings: - - def __init__(self): +class AllTimings: # noqa: D101 + def __init__(self) -> None: # noqa: D107 self.phase_to_timings = {} - def add_node_timing(self, node_timing: NodeTiming): + def add_node_timing(self, node_timing: NodeTiming) -> None: # noqa: D102 if node_timing.phase in self.phase_to_timings: self.phase_to_timings[node_timing.phase].append(node_timing) - return - self.phase_to_timings[node_timing.phase] = [node_timing] + else: + self.phase_to_timings[node_timing.phase] = [node_timing] - def get_phase_timings(self, phase: str): + def get_phase_timings(self, phase: str) -> list[NodeTiming]: # noqa: D102 return self.phase_to_timings[phase] - def get_summary_phase_timings(self, phase: str): + def get_summary_phase_timings(self, phase: str) -> NodeTiming: # noqa: D102 return reduce(NodeTiming.combine_timing, self.phase_to_timings[phase]) - def get_phases(self): + def get_phases(self) -> list[NodeTiming]: # noqa: D102 phases = list(self.phase_to_timings.keys()) phases.sort(key=lambda x: (self.get_summary_phase_timings(x)).time) phases.reverse() return phases - def get_sum_of_all_timings(self): + def get_sum_of_all_timings(self) -> float: # noqa: D102 total_timing_sum = 0 - for phase in self.phase_to_timings.keys(): + for phase in self.phase_to_timings: total_timing_sum += self.get_summary_phase_timings(phase).time return total_timing_sum -def open_utf8(fpath: str, flags: str) -> object: - return open(fpath, flags, encoding="utf8") +def open_utf8(fpath: str, flags: str) -> object: # noqa: D103 + return Path(fpath).open(mode=flags, encoding="utf8") -def get_child_timings(top_node: object, query_timings: object) -> str: - node_timing = NodeTiming(top_node['operator_type'], float(top_node['operator_timing'])) +def get_child_timings(top_node: object, query_timings: object) -> str: # noqa: D103 + node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"])) query_timings.add_node_timing(node_timing) - for child in top_node['children']: + for child in top_node["children"]: get_child_timings(child, query_timings) -def get_pink_shade_hex(fraction: float): +def get_pink_shade_hex(fraction: float) -> str: # noqa: D103 fraction = max(0, min(1, fraction)) - + # Define the RGB values for very light pink (almost white) and dark pink light_pink = (255, 250, 250) # Very light pink - dark_pink = (255, 20, 147) # Dark pink - + dark_pink = (255, 20, 147) # Dark pink + # Calculate the RGB values for the given fraction r = int(light_pink[0] + (dark_pink[0] - light_pink[0]) * fraction) g = int(light_pink[1] + (dark_pink[1] - light_pink[1]) * fraction) b = int(light_pink[2] + (dark_pink[2] - light_pink[2]) * fraction) - + # Return as hexadecimal color code return f"#{r:02x}{g:02x}{b:02x}" -def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: - node_style = f"background-color: 
{get_pink_shade_hex(float(result)/cpu_time)};" - body = f"" - body += "
" +def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: # noqa: D103 + node_style = f"background-color: {get_pink_shade_hex(float(result) / cpu_time)};" + + body = f'' + body += '
' new_name = "BRIDGE" if (name == "INVALID") else name.replace("_", " ") formatted_num = f"{float(result):.4f}" body += f"

{new_name}

time: {formatted_num} seconds

" - body += f" {extra_info} " - if (width > 0): + body += f' {extra_info} ' + if width > 0: body += f"

cardinality: {card}

" body += f"

estimate: {est}

" body += f"

width: {width} bytes

" - # TODO: Expand on timing. Usually available from a detailed profiling + # TODO: Expand on timing. Usually available from a detailed profiling # noqa: TD002, TD003 body += "
" body += "
" return body -def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: +def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: # noqa: D103 node_prefix_html = "
  • " node_suffix_html = "
  • " extra_info = "" estimate = 0 - for key in json_graph['extra_info']: - value = json_graph['extra_info'][key] - if (key == "Estimated Cardinality"): + for key in json_graph["extra_info"]: + value = json_graph["extra_info"][key] + if key == "Estimated Cardinality": estimate = int(value) else: extra_info += f"{key}: {value}
    " cardinality = json_graph["operator_cardinality"] - width = int(json_graph["result_set_size"]/max(1,cardinality)) + width = int(json_graph["result_set_size"] / max(1, cardinality)) # get rid of some typically long names extra_info = re.sub(r"__internal_\s*", "__", extra_info) extra_info = re.sub(r"compress_integral\s*", "compress", extra_info) - node_body = get_node_body(json_graph["operator_type"], - json_graph["operator_timing"], - cpu_time, cardinality, estimate, width, - re.sub(r",\s*", ", ", extra_info)) + node_body = get_node_body( + json_graph["operator_type"], + json_graph["operator_timing"], + cpu_time, + cardinality, + estimate, + width, + re.sub(r",\s*", ", ", extra_info), + ) children_html = "" - if len(json_graph['children']) >= 1: + if len(json_graph["children"]) >= 1: children_html += "
      " for child in json_graph["children"]: children_html += generate_tree_recursive(child, cpu_time) @@ -202,12 +205,12 @@ def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: # For generating the table in the top left. -def generate_timing_html(graph_json: object, query_timings: object) -> object: +def generate_timing_html(graph_json: object, query_timings: object) -> object: # noqa: D103 json_graph = json.loads(graph_json) gather_timing_information(json_graph, query_timings) - total_time = float(json_graph.get('operator_timing') or json_graph.get('latency')) + total_time = float(json_graph.get("operator_timing") or json_graph.get("latency")) table_head = """ - +
      @@ -224,7 +227,7 @@ def generate_timing_html(graph_json: object, query_timings: object) -> object: all_phases = query_timings.get_phases() query_timings.add_node_timing(NodeTiming("TOTAL TIME", total_time)) query_timings.add_node_timing(NodeTiming("Execution Time", execution_time)) - all_phases = ["TOTAL TIME", "Execution Time"] + all_phases + all_phases = ["TOTAL TIME", "Execution Time", *all_phases] for phase in all_phases: summarized_phase = query_timings.get_summary_phase_timings(phase) summarized_phase.calculate_percentage(total_time) @@ -240,55 +243,48 @@ def generate_timing_html(graph_json: object, query_timings: object) -> object: return table_head + table_body -def generate_tree_html(graph_json: object) -> str: +def generate_tree_html(graph_json: object) -> str: # noqa: D103 json_graph = json.loads(graph_json) - cpu_time = float(json_graph['cpu_time']) - tree_prefix = "
      \n
        " + cpu_time = float(json_graph["cpu_time"]) + tree_prefix = '
        \n
          ' tree_suffix = "
        " # first level of json is general overview - # FIXME: make sure json output first level always has only 1 level - tree_body = generate_tree_recursive(json_graph['children'][0], cpu_time) + # TODO: make sure json output first level always has only 1 level # noqa: TD002, TD003 + tree_body = generate_tree_recursive(json_graph["children"][0], cpu_time) return tree_prefix + tree_body + tree_suffix -def generate_ipython(json_input: str) -> str: +def generate_ipython(json_input: str) -> str: # noqa: D103 from IPython.core.display import HTML - html_output = generate_html(json_input, False) + html_output = generate_html(json_input, False) # noqa: F821 - return HTML(("\n" - " ${CSS}\n" - " ${LIBRARIES}\n" - "
        \n" - " ${CHART_SCRIPT}\n" - " ").replace("${CSS}", html_output['css']).replace('${CHART_SCRIPT}', - html_output['chart_script']).replace( - '${LIBRARIES}', html_output['libraries'])) + return HTML( + ('\n ${CSS}\n ${LIBRARIES}\n
        \n ${CHART_SCRIPT}\n ') + .replace("${CSS}", html_output["css"]) + .replace("${CHART_SCRIPT}", html_output["chart_script"]) + .replace("${LIBRARIES}", html_output["libraries"]) + ) -def generate_style_html(graph_json: str, include_meta_info: bool) -> None: - treeflex_css = "\n" +def generate_style_html(graph_json: str, include_meta_info: bool) -> None: # noqa: D103, FBT001 + treeflex_css = '\n' css = "\n" - return { - 'treeflex_css': treeflex_css, - 'duckdb_css': css, - 'libraries': '', - 'chart_script': '' - } + return {"treeflex_css": treeflex_css, "duckdb_css": css, "libraries": "", "chart_script": ""} -def gather_timing_information(json: str, query_timings: object) -> None: +def gather_timing_information(json: str, query_timings: object) -> None: # noqa: D103 # add up all of the times # measure each time as a percentage of the total time. # then you can return a list of [phase, time, percentage] - get_child_timings(json['children'][0], query_timings) + get_child_timings(json["children"][0], query_timings) -def translate_json_to_html(input_file: str, output_file: str) -> None: +def translate_json_to_html(input_file: str, output_file: str) -> None: # noqa: D103 query_timings = AllTimings() - with open_utf8(input_file, 'r') as f: + with open_utf8(input_file, "r") as f: text = f.read() html_output = generate_style_html(text, True) @@ -317,23 +313,22 @@ def translate_json_to_html(input_file: str, output_file: str) -> None: """ - html = html.replace("${TREEFLEX_CSS}", html_output['treeflex_css']) - html = html.replace("${DUCKDB_CSS}", html_output['duckdb_css']) + html = html.replace("${TREEFLEX_CSS}", html_output["treeflex_css"]) + html = html.replace("${DUCKDB_CSS}", html_output["duckdb_css"]) html = html.replace("${TIMING_TABLE}", timing_table) - html = html.replace('${TREE}', tree_output) + html = html.replace("${TREE}", tree_output) f.write(html) -def main() -> None: - if sys.version_info[0] < 3: - print("Please use python3") - exit(1) +def main() -> None: # noqa: D103 parser = argparse.ArgumentParser( - prog='Query Graph Generator', - description='Given a json profile output, generate a html file showing the query graph and timings of operators') - parser.add_argument('profile_input', help='profile input in json') - parser.add_argument('--out', required=False, default=False) - parser.add_argument('--open', required=False, action='store_true', default=True) + prog="Query Graph Generator", + description="""Given a json profile output, generate a html file showing the query graph and + timings of operators""", + ) + parser.add_argument("profile_input", help="profile input in json") + parser.add_argument("--out", required=False, default=False) + parser.add_argument("--open", required=False, action="store_true", default=True) args = parser.parse_args() input = args.profile_input @@ -356,8 +351,8 @@ def main() -> None: translate_json_to_html(input, output) if open_output: - webbrowser.open('file://' + os.path.abspath(output), new=2) + webbrowser.open(f"file://{Path(output).resolve()}", new=2) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/duckdb/sqltypes/__init__.py b/duckdb/sqltypes/__init__.py new file mode 100644 index 00000000..38917ce3 --- /dev/null +++ b/duckdb/sqltypes/__init__.py @@ -0,0 +1,63 @@ +"""DuckDB's SQL types.""" + +from _duckdb._sqltypes import ( + BIGINT, + BIT, + BLOB, + BOOLEAN, + DATE, + DOUBLE, + FLOAT, + HUGEINT, + INTEGER, + INTERVAL, + SMALLINT, + SQLNULL, + TIME, + TIME_TZ, + TIMESTAMP, + TIMESTAMP_MS, + TIMESTAMP_NS, + 
TIMESTAMP_S, + TIMESTAMP_TZ, + TINYINT, + UBIGINT, + UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, + UUID, + VARCHAR, + DuckDBPyType, +) + +__all__ = [ + "BIGINT", + "BIT", + "BLOB", + "BOOLEAN", + "DATE", + "DOUBLE", + "FLOAT", + "HUGEINT", + "INTEGER", + "INTERVAL", + "SMALLINT", + "SQLNULL", + "TIME", + "TIMESTAMP", + "TIMESTAMP_MS", + "TIMESTAMP_NS", + "TIMESTAMP_S", + "TIMESTAMP_TZ", + "TIME_TZ", + "TINYINT", + "UBIGINT", + "UHUGEINT", + "UINTEGER", + "USMALLINT", + "UTINYINT", + "UUID", + "VARCHAR", + "DuckDBPyType", +] diff --git a/duckdb/typing/__init__.py b/duckdb/typing/__init__.py index d0e95b50..4c29047b 100644 --- a/duckdb/typing/__init__.py +++ b/duckdb/typing/__init__.py @@ -1,5 +1,8 @@ -from _duckdb.typing import ( - DuckDBPyType, +"""DuckDB's SQL types. DEPRECATED. Please use `duckdb.sqltypes` instead.""" + +import warnings + +from duckdb.sqltypes import ( BIGINT, BIT, BLOB, @@ -8,29 +11,29 @@ DOUBLE, FLOAT, HUGEINT, - UHUGEINT, INTEGER, INTERVAL, SMALLINT, SQLNULL, TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, TIMESTAMP_TZ, - TIME_TZ, TINYINT, UBIGINT, + UHUGEINT, UINTEGER, USMALLINT, UTINYINT, UUID, - VARCHAR + VARCHAR, + DuckDBPyType, ) __all__ = [ - "DuckDBPyType", "BIGINT", "BIT", "BLOB", @@ -39,7 +42,6 @@ "DOUBLE", "FLOAT", "HUGEINT", - "UHUGEINT", "INTEGER", "INTERVAL", "SMALLINT", @@ -53,9 +55,17 @@ "TIME_TZ", "TINYINT", "UBIGINT", + "UHUGEINT", "UINTEGER", "USMALLINT", "UTINYINT", "UUID", - "VARCHAR" + "VARCHAR", + "DuckDBPyType", ] + +warnings.warn( + "`duckdb.typing` is deprecated and will be removed in a future version. Please use `duckdb.sqltypes` instead.", + DeprecationWarning, + stacklevel=2, +) diff --git a/duckdb/typing/__init__.pyi b/duckdb/typing/__init__.pyi deleted file mode 100644 index 69435c05..00000000 --- a/duckdb/typing/__init__.pyi +++ /dev/null @@ -1,36 +0,0 @@ -from duckdb import DuckDBPyConnection - -SQLNULL: DuckDBPyType -BOOLEAN: DuckDBPyType -TINYINT: DuckDBPyType -UTINYINT: DuckDBPyType -SMALLINT: DuckDBPyType -USMALLINT: DuckDBPyType -INTEGER: DuckDBPyType -UINTEGER: DuckDBPyType -BIGINT: DuckDBPyType -UBIGINT: DuckDBPyType -HUGEINT: DuckDBPyType -UHUGEINT: DuckDBPyType -UUID: DuckDBPyType -FLOAT: DuckDBPyType -DOUBLE: DuckDBPyType -DATE: DuckDBPyType -TIMESTAMP: DuckDBPyType -TIMESTAMP_MS: DuckDBPyType -TIMESTAMP_NS: DuckDBPyType -TIMESTAMP_S: DuckDBPyType -TIME: DuckDBPyType -TIME_TZ: DuckDBPyType -TIMESTAMP_TZ: DuckDBPyType -VARCHAR: DuckDBPyType -BLOB: DuckDBPyType -BIT: DuckDBPyType -INTERVAL: DuckDBPyType - -class DuckDBPyType: - def __init__(self, type_str: str, connection: DuckDBPyConnection = ...) -> None: ... - def __repr__(self) -> str: ... - def __eq__(self, other) -> bool: ... - def __getattr__(self, name: str): DuckDBPyType - def __getitem__(self, name: str): DuckDBPyType \ No newline at end of file diff --git a/duckdb/udf.py b/duckdb/udf.py index bbf05c7d..b15ba709 100644 --- a/duckdb/udf.py +++ b/duckdb/udf.py @@ -1,9 +1,15 @@ -def vectorized(func): - """ - Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output +# ruff: noqa: D100 +import typing + + +def vectorized(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]: + """Decorate a function with annotated function parameters. + + This allows DuckDB to infer that the function should be provided with pyarrow arrays and should expect + pyarrow array(s) as output. 
""" - from inspect import signature import types + from inspect import signature new_func = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) # Construct the annotations: @@ -11,7 +17,6 @@ def vectorized(func): new_annotations = {} sig = signature(func) - sig.parameters for param in sig.parameters: new_annotations[param] = pa.lib.ChunkedArray diff --git a/duckdb/value/__init__.py b/duckdb/value/__init__.py index e69de29b..6e031999 100644 --- a/duckdb/value/__init__.py +++ b/duckdb/value/__init__.py @@ -0,0 +1 @@ +# noqa: D104 diff --git a/duckdb/value/constant/__init__.py b/duckdb/value/constant/__init__.py index da2004b9..530c6bdc 100644 --- a/duckdb/value/constant/__init__.py +++ b/duckdb/value/constant/__init__.py @@ -1,6 +1,7 @@ -from typing import Any, Dict -from duckdb.typing import DuckDBPyType -from duckdb.typing import ( +# ruff: noqa: D101, D104, D105, D107, ANN401 +from typing import Any + +from duckdb.sqltypes import ( BIGINT, BIT, BLOB, @@ -9,30 +10,31 @@ DOUBLE, FLOAT, HUGEINT, - UHUGEINT, INTEGER, INTERVAL, SMALLINT, SQLNULL, TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, TIMESTAMP_TZ, - TIME_TZ, TINYINT, UBIGINT, + UHUGEINT, UINTEGER, USMALLINT, UTINYINT, UUID, VARCHAR, + DuckDBPyType, ) class Value: - def __init__(self, object: Any, type: DuckDBPyType): + def __init__(self, object: Any, type: DuckDBPyType) -> None: self.object = object self.type = type @@ -44,12 +46,12 @@ def __repr__(self) -> str: class NullValue(Value): - def __init__(self): + def __init__(self) -> None: super().__init__(None, SQLNULL) class BooleanValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BOOLEAN) @@ -57,22 +59,22 @@ def __init__(self, object: Any): class UnsignedBinaryValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UTINYINT) class UnsignedShortValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, USMALLINT) class UnsignedIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UINTEGER) class UnsignedLongValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UBIGINT) @@ -80,32 +82,32 @@ def __init__(self, object: Any): class BinaryValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TINYINT) class ShortValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, SMALLINT) class IntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, INTEGER) class LongValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BIGINT) class HugeIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, HUGEINT) class UnsignedHugeIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UHUGEINT) @@ -113,17 +115,17 @@ def __init__(self, object: Any): class FloatValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, FLOAT) class DoubleValue(Value): - def __init__(self, object: Any): + 
def __init__(self, object: Any) -> None: super().__init__(object, DOUBLE) class DecimalValue(Value): - def __init__(self, object: Any, width: int, scale: int): + def __init__(self, object: Any, width: int, scale: int) -> None: import duckdb decimal_type = duckdb.decimal_type(width, scale) @@ -134,22 +136,22 @@ def __init__(self, object: Any, width: int, scale: int): class StringValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, VARCHAR) class UUIDValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UUID) class BitValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BIT) class BlobValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BLOB) @@ -157,52 +159,52 @@ def __init__(self, object: Any): class DateValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, DATE) class IntervalValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, INTERVAL) class TimestampValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP) class TimestampSecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_S) class TimestampMilisecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_MS) class TimestampNanosecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_NS) class TimestampTimeZoneValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_TZ) class TimeValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIME) class TimeTimeZoneValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIME_TZ) class ListValue(Value): - def __init__(self, object: Any, child_type: DuckDBPyType): + def __init__(self, object: Any, child_type: DuckDBPyType) -> None: import duckdb list_type = duckdb.list_type(child_type) @@ -210,7 +212,7 @@ def __init__(self, object: Any, child_type: DuckDBPyType): class StructValue(Value): - def __init__(self, object: Any, children: Dict[str, DuckDBPyType]): + def __init__(self, object: Any, children: dict[str, DuckDBPyType]) -> None: import duckdb struct_type = duckdb.struct_type(children) @@ -218,7 +220,7 @@ def __init__(self, object: Any, children: Dict[str, DuckDBPyType]): class MapValue(Value): - def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType): + def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType) -> None: import duckdb map_type = duckdb.map_type(key_type, value_type) @@ -226,43 +228,43 @@ def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType class UnionType(Value): - def __init__(self, object: Any, members: Dict[str, DuckDBPyType]): + def __init__(self, object: Any, members: dict[str, DuckDBPyType]) -> None: import duckdb union_type = duckdb.union_type(members) super().__init__(object, union_type) -# TODO: add EnumValue once `duckdb.enum_type` is 
added +# TODO: add EnumValue once `duckdb.enum_type` is added # noqa: TD002, TD003 __all__ = [ - "Value", - "NullValue", - "BooleanValue", - "UnsignedBinaryValue", - "UnsignedShortValue", - "UnsignedIntegerValue", - "UnsignedLongValue", "BinaryValue", - "ShortValue", - "IntegerValue", - "LongValue", - "HugeIntegerValue", - "UnsignedHugeIntegerValue", - "FloatValue", - "DoubleValue", - "DecimalValue", - "StringValue", - "UUIDValue", "BitValue", "BlobValue", + "BooleanValue", "DateValue", + "DecimalValue", + "DoubleValue", + "FloatValue", + "HugeIntegerValue", + "IntegerValue", "IntervalValue", - "TimestampValue", - "TimestampSecondValue", + "LongValue", + "NullValue", + "ShortValue", + "StringValue", + "TimeTimeZoneValue", + "TimeValue", "TimestampMilisecondValue", "TimestampNanosecondValue", + "TimestampSecondValue", "TimestampTimeZoneValue", - "TimeValue", - "TimeTimeZoneValue", + "TimestampValue", + "UUIDValue", + "UnsignedBinaryValue", + "UnsignedHugeIntegerValue", + "UnsignedIntegerValue", + "UnsignedLongValue", + "UnsignedShortValue", + "Value", ] diff --git a/duckdb/value/constant/__init__.pyi b/duckdb/value/constant/__init__.pyi deleted file mode 100644 index 8cea58cf..00000000 --- a/duckdb/value/constant/__init__.pyi +++ /dev/null @@ -1,115 +0,0 @@ -from duckdb.typing import DuckDBPyType -from typing import Any - -class NullValue(Value): - def __init__(self) -> None: ... - def __repr__(self) -> str: ... - -class BooleanValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class UnsignedBinaryValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class UnsignedShortValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class UnsignedIntegerValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class UnsignedLongValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class BinaryValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class ShortValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class IntegerValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class LongValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class HugeIntegerValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class FloatValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class DoubleValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class DecimalValue(Value): - def __init__(self, object: Any, width: int, scale: int) -> None: ... - def __repr__(self) -> str: ... - -class StringValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class UUIDValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class BitValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class BlobValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class DateValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class IntervalValue(Value): - def __init__(self, object: Any) -> None: ... 
- def __repr__(self) -> str: ... - -class TimestampValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class TimestampSecondValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class TimestampMilisecondValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class TimestampNanosecondValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class TimestampTimeZoneValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class TimeValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - -class TimeTimeZoneValue(Value): - def __init__(self, object: Any) -> None: ... - def __repr__(self) -> str: ... - - -class Value: - def __init__(self, object: Any, type: DuckDBPyType) -> None: ... - def __repr__(self) -> str: ... diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index ca8e7716..0a5eb66b 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -5,13 +5,15 @@ - Git tag creation and management - Version parsing and validation """ + import pathlib +import re import subprocess from typing import Optional -import re - -VERSION_RE = re.compile(r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$") +VERSION_RE = re.compile( + r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$" +) def parse_version(version: str) -> tuple[int, int, int, int, int]: @@ -28,7 +30,8 @@ def parse_version(version: str) -> tuple[int, int, int, int, int]: """ match = VERSION_RE.match(version) if not match: - raise ValueError(f"Invalid version format: {version} (expected X.Y.Z, X.Y.Z.rcM or X.Y.Z.postN)") + msg = f"Invalid version format: {version} (expected X.Y.Z, X.Y.Z.rcM or X.Y.Z.postN)" + raise ValueError(msg) major, minor, patch, rc, post = match.groups() return int(major), int(minor), int(patch), int(post or 0), int(rc or 0) @@ -49,7 +52,8 @@ def format_version(major: int, minor: int, patch: int, post: int = 0, rc: int = """ version = f"{major}.{minor}.{patch}" if post != 0 and rc != 0: - raise ValueError("post and rc are mutually exclusive") + msg = "post and rc are mutually exclusive" + raise ValueError(msg) if post != 0: version += f".post{post}" if rc != 0: @@ -67,12 +71,12 @@ def git_tag_to_pep440(git_tag: str) -> str: PEP440 version string (e.g., "1.3.1", "1.3.1.post1") """ # Remove 'v' prefix if present - version = git_tag[1:] if git_tag.startswith('v') else git_tag + version = git_tag[1:] if git_tag.startswith("v") else git_tag if "-post" in version: - assert 'rc' not in version + assert "rc" not in version version = version.replace("-post", ".post") - elif '-rc' in version: + elif "-rc" in version: version = version.replace("-rc", "rc") return version @@ -87,10 +91,10 @@ def pep440_to_git_tag(version: str) -> str: Returns: Git tag format (e.g., "v1.3.1-post1") """ - if '.post' in version: - assert 'rc' not in version + if ".post" in version: + assert "rc" not in version version = version.replace(".post", "-post") - elif 'rc' in version: + elif "rc" in version: version = version.replace("rc", "-rc") return f"v{version}" @@ -104,12 +108,7 @@ def get_current_version() -> Optional[str]: """ try: # Get the latest tag - result = subprocess.run( - ["git", "describe", "--tags", "--abbrev=0"], - capture_output=True, - text=True, - check=True - ) + result = subprocess.run(["git", 
"describe", "--tags", "--abbrev=0"], capture_output=True, text=True, check=True) tag = result.stdout.strip() return git_tag_to_pep440(tag) except subprocess.CalledProcessError: @@ -141,35 +140,39 @@ def create_git_tag(version: str, message: Optional[str] = None, repo_path: Optio def strip_post_from_version(version: str) -> str: - """ - Removing post-release suffixes from the given version. + """Removing post-release suffixes from the given version. DuckDB doesn't allow post-release versions, so .post* suffixes are stripped. """ return re.sub(r"[\.-]post[0-9]+", "", version) -def get_git_describe(repo_path: Optional[pathlib.Path] = None, since_major=False, since_minor=False) -> Optional[str]: +def get_git_describe( + repo_path: Optional[pathlib.Path] = None, + since_major: bool = False, # noqa: FBT001 + since_minor: bool = False, # noqa: FBT001 +) -> Optional[str]: """Get git describe output for version determination. Returns: Git describe output or None if no tags exist """ cwd = repo_path if repo_path is not None else None - pattern="v*.*.*" + pattern = "v*.*.*" if since_major: - pattern="v*.0.0" + pattern = "v*.0.0" elif since_minor: - pattern="v*.*.0" + pattern = "v*.*.0" try: result = subprocess.run( ["git", "describe", "--tags", "--long", "--match", pattern], capture_output=True, text=True, check=True, - cwd=cwd + cwd=cwd, ) result.check_returncode() return result.stdout.strip() - except FileNotFoundError: - raise RuntimeError("git executable can't be found") + except FileNotFoundError as e: + msg = "git executable can't be found" + raise RuntimeError(msg) from e diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index d96a4847..799a43c9 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -12,25 +12,29 @@ Also see https://peps.python.org/pep-0517/#in-tree-build-backends. """ -import sys -import os + import subprocess +import sys from pathlib import Path -from typing import Optional, Dict, List, Union +from typing import Optional, Union + from scikit_build_core.build import ( - build_wheel as skbuild_build_wheel, build_editable, - build_sdist as skbuild_build_sdist, - get_requires_for_build_wheel, - get_requires_for_build_sdist, get_requires_for_build_editable, - prepare_metadata_for_build_wheel, + get_requires_for_build_sdist, + get_requires_for_build_wheel, prepare_metadata_for_build_editable, + prepare_metadata_for_build_wheel, +) +from scikit_build_core.build import ( + build_sdist as skbuild_build_sdist, +) +from scikit_build_core.build import ( + build_wheel as skbuild_build_wheel, ) -from duckdb_packaging._versioning import create_git_tag, pep440_to_git_tag, get_git_describe, strip_post_from_version -from duckdb_packaging.setuptools_scm_version import forced_version_from_env, MAIN_BRANCH_VERSIONING - +from duckdb_packaging._versioning import get_git_describe, pep440_to_git_tag, strip_post_from_version +from duckdb_packaging.setuptools_scm_version import MAIN_BRANCH_VERSIONING, forced_version_from_env _DUCKDB_VERSION_FILENAME = "duckdb_version.txt" _LOGGING_FORMAT = "[duckdb_pytooling.build_backend] {}" @@ -39,14 +43,13 @@ _FORCED_PEP440_VERSION = forced_version_from_env() -def _log(msg: str, is_error: bool=False) -> None: +def _log(msg: str) -> None: """Log a message with build backend prefix. Args: msg: The message to log. - is_error: If True, log to stderr; otherwise log to stdout. 
""" - print(_LOGGING_FORMAT.format(msg), flush=True, file=sys.stderr if is_error else sys.stdout) + print(_LOGGING_FORMAT.format(msg), flush=True, file=sys.stderr) def _in_git_repository() -> bool: @@ -70,10 +73,11 @@ def _in_sdist() -> bool: def _duckdb_submodule_path() -> Path: """Verify that the duckdb submodule is checked out and usable and return its path.""" if not _in_git_repository(): - raise RuntimeError("Not in a git repository, no duckdb submodule present") + msg = "Not in a git repository, no duckdb submodule present" + raise RuntimeError(msg) # search the duckdb submodule gitmodules_path = Path(".gitmodules") - modules = dict() + modules = {} with gitmodules_path.open("r") as f: cur_module_path = None cur_module_reponame = None @@ -84,15 +88,16 @@ def _duckdb_submodule_path() -> Path: cur_module_reponame = None cur_module_path = None elif line.strip().startswith("path"): - cur_module_path = line.split('=')[-1].strip() + cur_module_path = line.split("=")[-1].strip() elif line.strip().startswith("url"): - basename = os.path.basename(line.split('=')[-1].strip()) + basename = Path(line.split("=")[-1].strip()).name cur_module_reponame = basename[:-4] if basename.endswith(".git") else basename if cur_module_reponame is not None and cur_module_path is not None: modules[cur_module_reponame] = cur_module_path if "duckdb" not in modules: - raise RuntimeError("DuckDB submodule missing") + msg = "DuckDB submodule missing" + raise RuntimeError(msg) duckdb_path = modules["duckdb"] # now check that the submodule is usable @@ -101,9 +106,11 @@ def _duckdb_submodule_path() -> Path: status = status.decode("ascii", "replace") for line in status.splitlines(): if line.startswith("-"): - raise RuntimeError(f"Duckdb submodule not initialized: {line}") + msg = f"Duckdb submodule not initialized: {line}" + raise RuntimeError(msg) if line.startswith("U"): - raise RuntimeError(f"Duckdb submodule has merge conflicts: {line}") + msg = f"Duckdb submodule has merge conflicts: {line}" + raise RuntimeError(msg) if line.startswith("+"): _log(f"WARNING: Duckdb submodule not clean: {line}") # all good @@ -115,7 +122,7 @@ def _version_file_path() -> Path: return package_dir / _DUCKDB_VERSION_FILENAME -def _write_duckdb_long_version(long_version: str)-> None: +def _write_duckdb_long_version(long_version: str) -> None: """Write the given version string to a file in the same directory as this module.""" _version_file_path().write_text(long_version, encoding="utf-8") @@ -125,9 +132,7 @@ def _read_duckdb_long_version() -> str: return _version_file_path().read_text(encoding="utf-8").strip() -def _skbuild_config_add( - key: str, value: Union[List, str], config_settings: Dict[str, Union[List[str],str]], fail_if_exists: bool=False -): +def _skbuild_config_add(key: str, value: Union[list, str], config_settings: dict[str, Union[list[str], str]]) -> None: """Add or modify a configuration setting for scikit-build-core. This function handles adding values to scikit-build-core configuration settings, @@ -137,10 +142,9 @@ def _skbuild_config_add( key: The configuration key to set (will be prefixed with 'skbuild.' if needed). value: The value to add (string or list). config_settings: The configuration dictionary to modify. - fail_if_exists: If True, raise an error if the key already exists. Raises: - RuntimeError: If fail_if_exists is True and key exists, or on type mismatches. + RuntimeError: If this would overwrite an existing value, or on type mismatches. AssertionError: If config_settings is None. 
Behavior Rules: @@ -163,22 +167,19 @@ def _skbuild_config_add( val_is_list = isinstance(value, list) if not key_exists: config_settings[store_key] = value - elif fail_if_exists: - raise RuntimeError(f"{key} already present in config and may not be overridden") elif key_exists_as_list and val_is_list: config_settings[store_key].extend(value) elif key_exists_as_list and val_is_str: config_settings[store_key].append(value) elif key_exists_as_str and val_is_str: - _log(f"WARNING: overriding existing value in {store_key}") - config_settings[store_key] = value + msg = f"{key} already present in config and may not be overridden" + raise RuntimeError(msg) else: - raise RuntimeError( - f"Type mismatch: cannot set {store_key} ({type(config_settings[store_key])}) to `{value}` ({type(value)})" - ) + msg = f"Type mismatch: cannot set {store_key} ({type(config_settings[store_key])}) to `{value}` ({type(value)})" + raise RuntimeError(msg) -def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[List[str],str]]] = None) -> str: +def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str], str]]] = None) -> str: """Build a source distribution using the DuckDB submodule. This function extracts the DuckDB version from either the git submodule and saves it @@ -196,7 +197,8 @@ def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[ RuntimeError: If not in a git repository or DuckDB submodule issues. """ if not _in_git_repository(): - raise RuntimeError("Not in a git repository, can't create an sdist") + msg = "Not in a git repository, can't create an sdist" + raise RuntimeError(msg) submodule_path = _duckdb_submodule_path() if _FORCED_PEP440_VERSION is not None: duckdb_version = pep440_to_git_tag(strip_post_from_version(_FORCED_PEP440_VERSION)) @@ -207,9 +209,9 @@ def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[ def build_wheel( - wheel_directory: str, - config_settings: Optional[Dict[str, Union[List[str],str]]] = None, - metadata_directory: Optional[str] = None, + wheel_directory: str, + config_settings: Optional[dict[str, Union[list[str], str]]] = None, + metadata_directory: Optional[str] = None, ) -> str: """Build a wheel from either git submodule or extracted sdist sources. @@ -232,7 +234,8 @@ def build_wheel( duckdb_version = None if not _in_git_repository(): if not _in_sdist(): - raise RuntimeError("Not in a git repository nor in an sdist, can't build a wheel") + msg = "Not in a git repository nor in an sdist, can't build a wheel" + raise RuntimeError(msg) _log("Building duckdb wheel from sdist. Reading duckdb version from file.") config_settings = config_settings or {} duckdb_version = _read_duckdb_long_version() @@ -241,22 +244,21 @@ def build_wheel( # We add the found version to the OVERRIDE_GIT_DESCRIBE cmake var if duckdb_version is not None: - _skbuild_config_add(_SKBUILD_CMAKE_OVERRIDE_GIT_DESCRIBE, duckdb_version, config_settings, fail_if_exists=True) + _skbuild_config_add(_SKBUILD_CMAKE_OVERRIDE_GIT_DESCRIBE, duckdb_version, config_settings) _log(f"{_SKBUILD_CMAKE_OVERRIDE_GIT_DESCRIBE} set to {duckdb_version}") else: _log("No explicit DuckDB submodule version provided. 
Letting CMake figure it out.") - return skbuild_build_wheel(wheel_directory, config_settings=config_settings, metadata_directory=metadata_directory) __all__ = [ - "build_wheel", - "build_sdist", "build_editable", - "get_requires_for_build_wheel", - "get_requires_for_build_sdist", + "build_sdist", + "build_wheel", "get_requires_for_build_editable", - "prepare_metadata_for_build_wheel", + "get_requires_for_build_sdist", + "get_requires_for_build_wheel", "prepare_metadata_for_build_editable", + "prepare_metadata_for_build_wheel", ] diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 81d4c8e0..d67db691 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -1,5 +1,4 @@ -""" -!!HERE BE DRAGONS!! Use this script with care! +"""!!HERE BE DRAGONS!! Use this script with care! PyPI package cleanup tool. This script will: * Never remove a stable version (including a post release version) @@ -17,8 +16,10 @@ import sys import time from collections import defaultdict +from collections.abc import Generator +from enum import Enum from html.parser import HTMLParser -from typing import Optional, Set, Generator +from typing import Optional, Union from urllib.parse import urlparse import pyotp @@ -28,8 +29,8 @@ from requests.exceptions import RequestException from urllib3 import Retry -_PYPI_URL_PROD = 'https://pypi.org/' -_PYPI_URL_TEST = 'https://test.pypi.org/' +_PYPI_URL_PROD = "https://pypi.org/" +_PYPI_URL_TEST = "https://test.pypi.org/" _DEFAULT_MAX_NIGHTLIES = 2 _LOGIN_RETRY_ATTEMPTS = 3 _LOGIN_RETRY_DELAY = 5 @@ -50,88 +51,66 @@ def create_argument_parser() -> argparse.ArgumentParser: * Keep the configured amount of dev releases per version, and remove older dev releases """, epilog="Environment variables required (unless --dry-run): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP", - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument( - "--dry-run", - action="store_true", - help="Show what would be deleted but don't actually do it" - ) + parser.add_argument("--dry-run", action="store_true", help="Show what would be deleted but don't actually do it") host_group = parser.add_mutually_exclusive_group(required=True) - host_group.add_argument( - "--prod", - action="store_true", - help="Use production PyPI (pypi.org)" - ) - host_group.add_argument( - "--test", - action="store_true", - help="Use test PyPI (test.pypi.org)" - ) + host_group.add_argument("--prod", action="store_true", help="Use production PyPI (pypi.org)") + host_group.add_argument("--test", action="store_true", help="Use test PyPI (test.pypi.org)") parser.add_argument( - "-m", "--max-nightlies", + "-m", + "--max-nightlies", type=int, default=_DEFAULT_MAX_NIGHTLIES, - help=f"Max number of nightlies of unreleased versions (default={_DEFAULT_MAX_NIGHTLIES})" + help=f"Max number of nightlies of unreleased versions (default={_DEFAULT_MAX_NIGHTLIES})", ) - parser.add_argument( - "-u", "--username", - type=validate_username, - help="PyPI username (required unless --dry-run)" - ) + parser.add_argument("-u", "--username", type=validate_username, help="PyPI username (required unless --dry-run)") - parser.add_argument( - "-v", "--verbose", - action="store_true", - help="Enable verbose debug logging" - ) + parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose debug logging") return parser + class PyPICleanupError(Exception): """Base exception for PyPI cleanup operations.""" - pass class 
AuthenticationError(PyPICleanupError): """Raised when authentication fails.""" - pass class ValidationError(PyPICleanupError): """Raised when input validation fails.""" - pass -def setup_logging(verbose: bool = False) -> None: +def setup_logging(level: int = logging.INFO) -> None: """Configure logging with appropriate level and format.""" - level = logging.DEBUG if verbose else logging.INFO - logging.basicConfig( - level=level, - format='%(asctime)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) + logging.basicConfig(level=level, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S") def validate_username(value: str) -> str: """Validate and sanitize username input.""" if not value or not value.strip(): - raise argparse.ArgumentTypeError("Username cannot be empty") - + msg = "Username cannot be empty" + raise argparse.ArgumentTypeError(msg) + username = value.strip() if len(username) > 100: # Reasonable limit - raise argparse.ArgumentTypeError("Username too long (max 100 characters)") - + msg = "Username too long (max 100 characters)" + raise argparse.ArgumentTypeError(msg) + # Basic validation - PyPI usernames are alphanumeric with limited special chars - if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$', username): - raise argparse.ArgumentTypeError("Invalid username format") - + if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$", username): + msg = "Invalid username format" + raise argparse.ArgumentTypeError(msg) + return username + @contextlib.contextmanager def session_with_retries() -> Generator[Session, None, None]: """Create a requests session with retry strategy for ephemeral errors.""" @@ -143,7 +122,7 @@ def session_with_retries() -> Generator[Session, None, None]: connect=3, # try 3 times before giving up on connection errors read=3, # try 3 times before giving up on read errors status=3, # try 3 times before giving up on status errors (see forcelist below) - status_forcelist=[429] + [status for status in range(500, 512)], + status_forcelist=[429, *list(range(500, 512))], other=0, # whatever else may cause an error should break backoff_factor=0.1, # [0.0s, 0.2s, 0.4s] raise_on_redirect=True, # raise exception when redirect error retries are exhausted @@ -154,107 +133,115 @@ def session_with_retries() -> Generator[Session, None, None]: session.mount("https://", adapter) yield session -def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: + +def load_credentials() -> tuple[Optional[str], Optional[str]]: """Load credentials from environment variables.""" - if dry_run: - return None, None - - password = os.getenv('PYPI_CLEANUP_PASSWORD') - otp = os.getenv('PYPI_CLEANUP_OTP') - + password = os.getenv("PYPI_CLEANUP_PASSWORD") + otp = os.getenv("PYPI_CLEANUP_OTP") + if not password: - raise ValidationError("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") + msg = "PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode" + raise ValidationError(msg) if not otp: - raise ValidationError("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") - + msg = "PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode" + raise ValidationError(msg) + return password, otp def validate_arguments(args: argparse.Namespace) -> None: """Validate parsed arguments.""" if not args.dry_run and not args.username: - raise ValidationError("--username is required when not in dry-run mode") - + msg = 
"--username is required when not in dry-run mode" + raise ValidationError(msg) + if args.max_nightlies < 0: - raise ValidationError("--max-nightlies must be non-negative") + msg = "--max-nightlies must be non-negative" + raise ValidationError(msg) + class CsrfParser(HTMLParser): """HTML parser to extract CSRF tokens from PyPI forms. - + Based on pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) """ - def __init__(self, target, contains_input=None): + + def __init__(self, target: str) -> None: # noqa: D107 super().__init__() self._target = target - self._contains_input = contains_input self.csrf = None # Result value from all forms on page - self._csrf = None # Temp value from current form self._in_form = False # Currently parsing a form with an action we're interested in - self._input_contained = False # Input field requested is contained in the current form - - def handle_starttag(self, tag, attrs): - if tag == "form": - attrs = dict(attrs) - action = attrs.get("action") # Might be None. - if action and (action == self._target or action.startswith(self._target)): - self._in_form = True - return - if self._in_form and tag == "input": - attrs = dict(attrs) - if attrs.get("name") == "csrf_token": - self._csrf = attrs["value"] + def handle_starttag(self, tag: str, attrs: list[tuple[str, Union[str, None]]]) -> None: # noqa: D102 + if not self.csrf: + if tag == "form": + attrs = dict(attrs) + action = attrs.get("action") # Might be None. + if action and (action == self._target or action.startswith(self._target)): + self._in_form = True + elif self._in_form and tag == "input": + attrs = dict(attrs) + if attrs.get("name") == "csrf_token" and not self.csrf: + self.csrf = attrs["value"] + + def handle_endtag(self, tag: str) -> None: # noqa: D102 + if tag == "form" and self._in_form: + self._in_form = False - if self._contains_input and attrs.get("name") == self._contains_input: - self._input_contained = True - return +class CleanMode(Enum): + """Supported clean-up modes.""" - def handle_endtag(self, tag): - if tag == "form": - self._in_form = False - # If we're in a right form that contains the requested input and csrf is not set - if (not self._contains_input or self._input_contained) and not self.csrf: - self.csrf = self._csrf - return + LIST_ONLY = 1 + DELETE = 2 class PyPICleanup: """Main class for performing PyPI package cleanup operations.""" - def __init__(self, index_url: str, do_delete: bool, max_dev_releases: int=_DEFAULT_MAX_NIGHTLIES, - username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None): + def __init__( # noqa: D107 + self, + index_url: str, + mode: CleanMode, + max_dev_releases: int = _DEFAULT_MAX_NIGHTLIES, + username: Optional[str] = None, + password: Optional[str] = None, + otp: Optional[str] = None, + ) -> None: parsed_url = urlparse(index_url) - self._index_url = parsed_url.geturl().rstrip('/') + self._index_url = parsed_url.geturl().rstrip("/") self._index_host = parsed_url.hostname - self._do_delete = do_delete + self._mode = mode self._max_dev_releases = max_dev_releases self._username = username self._password = password self._otp = otp - self._package = 'duckdb' + self._package = "duckdb" self._dev_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.dev(?P\d+)$") self._rc_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.rc\d+$") self._stable_version_pattern = re.compile(r"^\d+\.\d+\.\d+(\.post\d+)?$") def run(self) -> int: """Execute the cleanup process. 
- + Returns: int: Exit code (0 for success, non-zero for failure) """ - if self._do_delete: - logging.warning(f"NOT A DRILL: WILL DELETE PACKAGES") - else: + if self._mode == CleanMode.DELETE: + logging.warning("NOT A DRILL: WILL DELETE PACKAGES") + elif self._mode == CleanMode.LIST_ONLY: logging.info("Running in DRY RUN mode, nothing will be deleted") + else: + msg = "Unexpected mode" + raise RuntimeError(msg) logging.info(f"Max development releases to keep per unreleased version: {self._max_dev_releases}") try: with session_with_retries() as http_session: return self._execute_cleanup(http_session) - except PyPICleanupError as e: - logging.error(f"Cleanup failed: {e}") + except PyPICleanupError: + logging.exception("Cleanup failed") return 1 except Exception as e: logging.error(f"Unexpected error: {e}", exc_info=True) @@ -262,47 +249,48 @@ def run(self) -> int: def _execute_cleanup(self, http_session: Session) -> int: """Execute the main cleanup logic.""" - # Get released versions versions = self._fetch_released_versions(http_session) if not versions: logging.info(f"No releases found for {self._package}") return 0 - + # Determine versions to delete versions_to_delete = self._determine_versions_to_delete(versions) if not versions_to_delete: logging.info("No versions to delete (no stale rc's or dev releases)") return 0 - + logging.warning(f"Found {len(versions_to_delete)} versions to clean up:") for version in sorted(versions_to_delete): logging.warning(version) - - if not self._do_delete: + + if self._mode != CleanMode.DELETE: logging.info("Dry run complete - no packages were deleted") return 0 # Perform authentication and deletion self._authenticate(http_session) self._delete_versions(http_session, versions_to_delete) - + logging.info(f"Successfully cleaned up {len(versions_to_delete)} development versions") return 0 - - def _fetch_released_versions(self, http_session: Session) -> Set[str]: + + def _fetch_released_versions(self, http_session: Session) -> set[str]: """Fetch package release information from PyPI API.""" logging.debug(f"Fetching package information for '{self._package}'") - + try: req = http_session.get(f"{self._index_url}/pypi/{self._package}/json") req.raise_for_status() - data = req.json() - versions = {v for v, files in data["releases"].items() if len(files) > 0} - logging.debug(f"Found {len(versions)} releases with files") - return versions except RequestException as e: - raise PyPICleanupError(f"Failed to fetch package information for '{self._package}': {e}") from e + msg = f"Failed to fetch package information for '{self._package}': {e}" + raise PyPICleanupError(msg) from e + + data = req.json() + versions = {v for v, files in data["releases"].items() if len(files) > 0} + logging.debug(f"Found {len(versions)} releases with files") + return versions def _is_stable_release_version(self, version: str) -> bool: """Determine whether a version string denotes a stable release.""" @@ -320,17 +308,19 @@ def _parse_rc_version(self, version: str) -> str: """Parse a rc version string to determine the base version.""" match = self._rc_version_pattern.match(version) if not match: - raise PyPICleanupError(f"Invalid rc version '{version}'") + msg = f"Invalid rc version '{version}'" + raise PyPICleanupError(msg) return match.group("version") if match else None def _parse_dev_version(self, version: str) -> tuple[str, int]: """Parse a dev version string to determine the base version and dev version id.""" match = self._dev_version_pattern.match(version) if not match: - raise 
PyPICleanupError(f"Invalid dev version '{version}'") + msg = f"Invalid dev version '{version}'" + raise PyPICleanupError(msg) return match.group("version"), int(match.group("dev_id")) - def _determine_versions_to_delete(self, versions: Set[str]) -> Set[str]: + def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: """Determine which package versions should be deleted.""" logging.debug("Analyzing versions to determine cleanup candidates") @@ -378,26 +368,29 @@ def _determine_versions_to_delete(self, versions: Set[str]) -> Set[str]: # Final safety checks if versions_to_delete == versions: - raise PyPICleanupError( + msg = ( f"Safety check failed: cleanup would delete ALL versions of '{self._package}'. " "This would make the package permanently inaccessible. Aborting." ) + raise PyPICleanupError(msg) if len(versions_to_delete.intersection(stable_versions)) > 0: - raise PyPICleanupError( + msg = ( f"Safety check failed: cleanup would delete one or more stable versions of '{self._package}'. " f"A regexp might be broken? (would delete {versions_to_delete.intersection(stable_versions)})" ) + raise PyPICleanupError(msg) unknown_versions = versions.difference(stable_versions).difference(rc_versions).difference(dev_versions) if unknown_versions: logging.warning(f"Found version string(s) in an unsupported format: {unknown_versions}") return versions_to_delete - + def _authenticate(self, http_session: Session) -> None: """Authenticate with PyPI.""" if not self._username or not self._password: - raise AuthenticationError("Username and password are required for authentication") - + msg = "Username and password are required for authentication" + raise AuthenticationError(msg) + logging.info(f"Authenticating user '{self._username}' with PyPI") try: @@ -408,12 +401,13 @@ def _authenticate(self, http_session: Session) -> None: if login_response.url.startswith(f"{self._index_url}/account/two-factor/"): logging.debug("Two-factor authentication required") self._handle_two_factor_auth(http_session, login_response) - + logging.info("Authentication successful") except RequestException as e: - raise AuthenticationError(f"Network error during authentication: {e}") from e - + msg = f"Network error during authentication: {e}" + raise AuthenticationError(msg) from e + def _get_csrf_token(self, http_session: Session, form_action: str) -> str: """Extract CSRF token from a form page.""" resp = http_session.get(f"{self._index_url}{form_action}") @@ -421,43 +415,41 @@ def _get_csrf_token(self, http_session: Session, form_action: str) -> str: parser = CsrfParser(form_action) parser.feed(resp.text) if not parser.csrf: - raise AuthenticationError(f"No CSRF token found in {form_action}") + msg = f"No CSRF token found in {form_action}" + raise AuthenticationError(msg) return parser.csrf - + def _perform_login(self, http_session: Session) -> requests.Response: """Perform the initial login with username/password.""" - # Get login form and CSRF token csrf_token = self._get_csrf_token(http_session, "/account/login/") - login_data = { - "csrf_token": csrf_token, - "username": self._username, - "password": self._password - } + login_data = {"csrf_token": csrf_token, "username": self._username, "password": self._password} response = http_session.post( f"{self._index_url}/account/login/", data=login_data, - headers={"referer": f"{self._index_url}/account/login/"} + headers={"referer": f"{self._index_url}/account/login/"}, ) response.raise_for_status() # Check if login failed (redirected back to login page) if 
response.url == f"{self._index_url}/account/login/": - raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") + msg = f"Login failed for user '{self._username}' - check credentials" + raise AuthenticationError(msg) return response - + def _handle_two_factor_auth(self, http_session: Session, response: requests.Response) -> None: """Handle two-factor authentication.""" if not self._otp: - raise AuthenticationError("Two-factor authentication required but no OTP secret provided") - + msg = "Two-factor authentication required but no OTP secret provided" + raise AuthenticationError(msg) + two_factor_url = response.url - form_action = two_factor_url[len(self._index_url):] + form_action = two_factor_url[len(self._index_url) :] csrf_token = self._get_csrf_token(http_session, form_action) - + # Try authentication with retries for attempt in range(_LOGIN_RETRY_ATTEMPTS): try: @@ -467,7 +459,7 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp auth_response = http_session.post( two_factor_url, data={"csrf_token": csrf_token, "method": "totp", "totp_value": auth_code}, - headers={"referer": two_factor_url} + headers={"referer": two_factor_url}, ) auth_response.raise_for_status() @@ -479,46 +471,48 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp if attempt < _LOGIN_RETRY_ATTEMPTS - 1: logging.debug(f"2FA code rejected, retrying in {_LOGIN_RETRY_DELAY} seconds...") time.sleep(_LOGIN_RETRY_DELAY) - + except RequestException as e: if attempt == _LOGIN_RETRY_ATTEMPTS - 1: - raise AuthenticationError(f"Network error during 2FA: {e}") from e + msg = f"Network error during 2FA: {e}" + raise AuthenticationError(msg) from e logging.debug(f"Network error during 2FA attempt {attempt + 1}, retrying...") time.sleep(_LOGIN_RETRY_DELAY) - - raise AuthenticationError("Two-factor authentication failed after all attempts") - - def _delete_versions(self, http_session: Session, versions_to_delete: Set[str]) -> None: + + msg = "Two-factor authentication failed after all attempts" + raise AuthenticationError(msg) + + def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) -> None: """Delete the specified package versions.""" logging.info(f"Starting deletion of {len(versions_to_delete)} development versions") - - failed_deletions = list() + + failed_deletions = [] for version in sorted(versions_to_delete): try: self._delete_single_version(http_session, version) logging.info(f"Successfully deleted {self._package} version {version}") - except Exception as e: + except Exception: # Continue with other versions rather than failing completely - logging.error(f"Failed to delete version {version}: {e}") + logging.exception(f"Failed to delete version {version}") failed_deletions.append(version) - + if failed_deletions: - raise PyPICleanupError( - f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" - ) - + msg = f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" + raise PyPICleanupError(msg) + def _delete_single_version(self, http_session: Session, version: str) -> None: """Delete a single package version.""" # Safety check if not self._is_dev_version(version) or self._is_rc_version(version): - raise PyPICleanupError(f"Refusing to delete non-[dev|rc] version: {version}") - + msg = f"Refusing to delete non-[dev|rc] version: {version}" + raise PyPICleanupError(msg) + logging.debug(f"Deleting {self._package} version 
{version}") - + # Get deletion form and CSRF token form_action = f"/manage/project/{self._package}/release/{version}/" form_url = f"{self._index_url}{form_action}" - + csrf_token = self._get_csrf_token(http_session, form_action) # Submit deletion request @@ -528,7 +522,7 @@ def _delete_single_version(self, http_session: Session, version: str) -> None: "csrf_token": csrf_token, "confirm_delete_version": version, }, - headers={"referer": form_url} + headers={"referer": form_url}, ) delete_response.raise_for_status() @@ -537,34 +531,36 @@ def main() -> int: """Main entry point for the script.""" parser = create_argument_parser() args = parser.parse_args() - + # Setup logging - setup_logging(args.verbose) - + setup_logging((args.verbose and logging.DEBUG) or logging.INFO) + try: # Validate arguments validate_arguments(args) - - # Load credentials - password, otp = load_credentials(args.dry_run) - + + # Dry run vs delete + password, otp, mode = None, None, CleanMode.LIST_ONLY + if args.dry_run: + password, otp = load_credentials() + mode = CleanMode.DELETE + # Determine PyPI URL pypi_url = _PYPI_URL_PROD if args.prod else _PYPI_URL_TEST - + # Create and run cleanup - cleanup = PyPICleanup(pypi_url, not args.dry_run, args.max_nightlies, username=args.username, - password=password, otp=otp) - + cleanup = PyPICleanup(pypi_url, mode, args.max_nightlies, username=args.username, password=password, otp=otp) + return cleanup.run() - - except ValidationError as e: - logging.error(f"Configuration error: {e}") + + except ValidationError: + logging.exception("Configuration error") return 2 except KeyboardInterrupt: logging.info("Operation cancelled by user") return 130 - except Exception as e: - logging.error(f"Unexpected error: {e}", exc_info=args.verbose) + except Exception: + logging.exception("Unexpected error", exc_info=args.verbose) return 1 diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 8381e1e2..630f2493 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -1,5 +1,4 @@ -""" -setuptools_scm integration for DuckDB Python versioning. +"""setuptools_scm integration for DuckDB Python versioning. This module provides the setuptools_scm version scheme and handles environment variable overrides to match the exact behavior of the original DuckDB Python package. @@ -7,10 +6,10 @@ import os import re -from typing import Any +from typing import Protocol # Import from our own versioning module to avoid duplication -from ._versioning import parse_version, format_version +from ._versioning import format_version, parse_version # MAIN_BRANCH_VERSIONING should be 'True' on main branch only MAIN_BRANCH_VERSIONING = True @@ -20,13 +19,19 @@ OVERRIDE_GIT_DESCRIBE_ENV_VAR = "OVERRIDE_GIT_DESCRIBE" -def _main_branch_versioning(): - from_env = os.getenv('MAIN_BRANCH_VERSIONING') +class _VersionObject(Protocol): + tag: object + distance: int + dirty: bool + + +def _main_branch_versioning() -> bool: + from_env = os.getenv("MAIN_BRANCH_VERSIONING") return from_env == "1" if from_env is not None else MAIN_BRANCH_VERSIONING -def version_scheme(version: Any) -> str: - """ - setuptools_scm version scheme that matches DuckDB's original behavior. + +def version_scheme(version: _VersionObject) -> str: + """setuptools_scm version scheme that matches DuckDB's original behavior. 
Args: version: setuptools_scm version object @@ -41,42 +46,45 @@ def version_scheme(version: Any) -> str: # Handle case where tag is None if version.tag is None: - raise ValueError("Need a valid version. Did you set a fallback_version in pyproject.toml?") + msg = "Need a valid version. Did you set a fallback_version in pyproject.toml?" + raise ValueError(msg) + distance = int(version.distance or 0) try: - return _bump_version(str(version.tag), version.distance, version.dirty) + if distance == 0 and not version.dirty: + return _tag_to_version(str(version.tag)) + return _bump_dev_version(str(version.tag), distance) except Exception as e: - raise RuntimeError(f"Failed to bump version: {e}") + msg = f"Failed to bump version: {e}" + raise RuntimeError(msg) from e -def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: - """Bump the version if needed.""" - # Validate the base version (this should never include anything else than X.Y.Z or X.Y.Z.[rc|post]N) - try: - major, minor, patch, post, rc = parse_version(base_version) - except ValueError as e: - raise ValueError(f"Incorrect version format: {base_version} (expected X.Y.Z or X.Y.Z.postN)") +def _tag_to_version(tag: str) -> str: + """Bump the version when we're on a tag.""" + major, minor, patch, post, rc = parse_version(tag) + return format_version(major, minor, patch, post=post, rc=rc) - # If we're exactly on a tag (distance = 0, dirty=False) - distance = int(distance or 0) - if distance == 0 and not dirty: - return format_version(major, minor, patch, post=post, rc=rc) - # Otherwise we're at a distance and / or dirty, and need to bump +def _bump_dev_version(base_version: str, distance: int) -> str: + """Bump the given version.""" + if distance == 0: + msg = "Dev distance is 0, cannot bump version." + raise ValueError(msg) + major, minor, patch, post, rc = parse_version(base_version) + if post != 0: # We're developing on top of a post-release - return f"{format_version(major, minor, patch, post=post+1)}.dev{distance}" + return f"{format_version(major, minor, patch, post=post + 1)}.dev{distance}" elif rc != 0: # We're developing on top of an rc - return f"{format_version(major, minor, patch, rc=rc+1)}.dev{distance}" + return f"{format_version(major, minor, patch, rc=rc + 1)}.dev{distance}" elif _main_branch_versioning(): - return f"{format_version(major, minor+1, 0)}.dev{distance}" - return f"{format_version(major, minor, patch+1)}.dev{distance}" + return f"{format_version(major, minor + 1, 0)}.dev{distance}" + return f"{format_version(major, minor, patch + 1)}.dev{distance}" -def forced_version_from_env(): - """ - Handle getting versions from environment variables. +def forced_version_from_env() -> str: + """Handle getting versions from environment variables. Only supports a single way of manually overriding the version through OVERRIDE_GIT_DESCRIBE. If SETUPTOOLS_SCM_PRETEND_VERSION* is set, it gets unset. 
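Concretely, the scheme above resolves to the tag's own version when the build sits exactly on a clean tag, and otherwise to a `.devN` pre-release of the next minor version (main-branch versioning) or of the next patch version, where N is the commit distance from the tag. The following is a minimal, self-contained sketch of that mapping for plain X.Y.Z tags only; it uses a SimpleNamespace as a stand-in for setuptools_scm's version object and leaves out the rc/post and dirty-tree handling that parse_version, format_version and _bump_dev_version perform in the real code.

# Illustrative sketch (not part of the patch): how (tag, distance, dirty) maps to a PEP 440 version
# for plain X.Y.Z tags under the scheme shown above.
from types import SimpleNamespace


def sketch_version_scheme(v, main_branch_versioning=True):
    major, minor, patch = (int(part) for part in str(v.tag).lstrip("v").split("."))
    if v.distance == 0 and not v.dirty:
        # Exactly on a clean tag: the tag is the version.
        return f"{major}.{minor}.{patch}"
    if main_branch_versioning:
        # On main: dev pre-release of the next minor version.
        return f"{major}.{minor + 1}.0.dev{v.distance}"
    # On a release branch: dev pre-release of the next patch version.
    return f"{major}.{minor}.{patch + 1}.dev{v.distance}"


print(sketch_version_scheme(SimpleNamespace(tag="1.4.1", distance=0, dirty=False)))  # 1.4.1
print(sketch_version_scheme(SimpleNamespace(tag="1.4.1", distance=7, dirty=False)))  # 1.5.0.dev7
print(sketch_version_scheme(SimpleNamespace(tag="1.4.1", distance=7, dirty=False), main_branch_versioning=False))  # 1.4.2.dev7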
@@ -112,25 +120,27 @@ def _git_describe_override_to_pep_440(override_value: str) -> str: match = describe_pattern.match(override_value) if not match: - raise ValueError(f"Invalid git describe override: {override_value}") + msg = f"Invalid git describe override: {override_value}" + raise ValueError(msg) version, distance, commit_hash = match.groups() # Convert version format to PEP440 format (v1.3.1-post1 -> 1.3.1.post1) - if '-post' in version: + if "-post" in version: version = version.replace("-post", ".post") - elif '-rc' in version: + elif "-rc" in version: version = version.replace("-rc", "rc") # Bump version and format according to PEP440 - pep440_version = _bump_version(version, int(distance or 0)) + distance = int(distance or 0) + pep440_version = _tag_to_version(str(version)) if distance == 0 else _bump_dev_version(str(version), distance) if commit_hash: pep440_version += f"+g{commit_hash.lower()}" return pep440_version -def _remove_unsupported_env_var(env_var): +def _remove_unsupported_env_var(env_var: str) -> None: """Remove an unsupported environment variable with a warning.""" print(f"[versioning] WARNING: We do not support {env_var}! Removing.") del os.environ[env_var] diff --git a/external/duckdb b/external/duckdb index 4bfb6e2f..b3c8acdc 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 4bfb6e2f8c74c7a02e25ca80bad68456270e545b +Subproject commit b3c8acdc0e4e43478671955a069590b3e7b76478 diff --git a/pyproject.toml b/pyproject.toml index bcbb24f6..1b81884e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ all = [ # users can install duckdb with 'duckdb[all]', which will install this l "numpy", # used in duckdb.experimental.spark and in duckdb.fetchnumpy() "pandas; python_version < '3.14'", # used for pandas dataframes all over the place "pyarrow; python_version < '3.14'", # used for pyarrow support - "adbc_driver_manager; python_version < '3.14'", # for the adbc driver (TODO: this should live under the duckdb package) + "adbc_driver_manager; python_version < '3.14'", # for the adbc driver ] ###################################################################################################### @@ -77,6 +77,8 @@ metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" [tool.scikit-build.wheel] cmake = true packages.duckdb = "duckdb" +packages.adbc_driver_duckdb = "adbc_driver_duckdb" +packages._duckdb-stubs = "_duckdb-stubs" [tool.scikit-build.cmake.define] CORE_EXTENSIONS = "core_functions;json;parquet;icu;jemalloc" @@ -133,11 +135,12 @@ include = [ "CMakeLists.txt", "cmake/**", - # Source code + # Source code and stubs "src/**", "duckdb/**", "duckdb_packaging/**", "adbc_driver_duckdb/**", + "_duckdb-stubs/*.pyi", # Generated during sdist build, contains git describe string for duckdb "duckdb_packaging/duckdb_version.txt", @@ -218,12 +221,15 @@ torchvision = [ { index = "pytorch-cpu" } ] [dependency-groups] # used for development only, requires pip >=25.1.0 stubdeps = [ # dependencies used for typehints in the stubs + "pybind11-stubgen", + "mypy", "fsspec", "pandas; python_version < '3.14'", "polars; python_version < '3.14'", "pyarrow; python_version < '3.14'", ] test = [ # dependencies used for running tests + "adbc-driver-manager", "pytest", "pytest-reraise", "pytest-timeout", @@ -274,7 +280,8 @@ build = [ "setuptools_scm>=8.0", ] dev = [ # tooling like uv will install this automatically when syncing the environment - "ruff>=0.11.13", + "pre-commit", + "ruff>=0.13.0", {include-group = "stubdeps"}, {include-group 
= "build"}, {include-group = "test"}, @@ -307,19 +314,47 @@ filterwarnings = [ "ignore:is_datetime64tz_dtype is deprecated:DeprecationWarning", ] +[tool.mypy] +packages = ["duckdb", "_duckdb"] +strict = true +warn_unreachable = true +pretty = true +python_version = "3.9" +exclude = [ + "duckdb/experimental/", # not checking the pyspark API + "duckdb/query_graph/", # old and unmaintained (should probably remove) + "tests", "scripts", +] + +[[tool.mypy.overrides]] +module = [ + "fsspec.*", + "pandas", + "polars", + "pyarrow.*", + "torch", + "tensorflow", +] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "duckdb.filesystem" +disallow_subclassing_any = false + [tool.coverage.run] branch = true source = ["duckdb"] [tool.ruff] -line-length = 88 +line-length = 120 indent-width = 4 target-version = "py39" fix = true -fixable = ["ALL"] -exclude = ['external/duckdb'] +exclude = ['external/duckdb', 'sqllogic'] [tool.ruff.lint] +fixable = ["ALL"] +exclude = ['*.pyi'] select = [ "ANN", # flake8-annotations "B", # flake8-bugbear @@ -328,10 +363,9 @@ select = [ "E", # pycodestyle "EM", # flake8-errmsg "F", # pyflakes - "FA", # flake8-future-annotations "FBT001", # flake8-boolean-trap "I", # isort - "ICN", # flake8-import-conventions + #"ICN", # flake8-import-conventions "INT", # flake8-gettext "PERF", # perflint "PIE", # flake8-pie @@ -342,15 +376,15 @@ select = [ "SIM", # flake8-simplify "TCH", # flake8-type-checking "TD", # flake8-todos - "TID", # flake8-tidy-imports + #"TID", # flake8-tidy-imports "TRY", # tryceratops "UP", # pyupgrade "W", # pycodestyle ] -ignore = [] - -[tool.ruff.lint.pycodestyle] -max-doc-length = 88 +ignore = [ + "C901", # ignore function complexity + "ANN002", "ANN003", # don't require type annotations for *args and **kwargs +] [tool.ruff.lint.pydocstyle] convention = "google" @@ -361,6 +395,35 @@ ban-relative-imports = "all" [tool.ruff.lint.flake8-type-checking] strict = true +[tool.ruff.lint.per-file-ignores] +"duckdb/experimental/spark/**.py" = [ + # Ignore boolean positional args in the Spark API + 'FBT001' +] +"duckdb_packaging/**.py" = [ + # ignore all performance-related rules for duckdb_packaging + 'PERF' +] +"tests/**.py" = [ + # No need for package, module, class, function, init etc docstrings in tests + 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107', + # No need for type hinting in tests + 'ANN001', 'ANN201', 'ANN202' +] +"tests/fast/spark/**.py" = [ + "E402" +] +"tests/spark_namespace/**.py" = [ + # we need * imports for Spark + 'F403' +] +"scripts/**.py" = [ + # No need for package, module, class, function, init etc docstrings in scripts + 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107', 'D205', + # No need for type hinting in scripts + 'ANN001', 'ANN201' +] + [tool.ruff.format] docstring-code-format = true docstring-code-line-length = 88 diff --git a/scripts/cache_data.json b/scripts/cache_data.json index 640052cd..3dd9a1f1 100644 --- a/scripts/cache_data.json +++ b/scripts/cache_data.json @@ -7,7 +7,19 @@ "pyarrow.dataset", "pyarrow.Table", "pyarrow.RecordBatchReader", - "pyarrow.ipc" + "pyarrow.ipc", + "pyarrow.scalar", + "pyarrow.date32", + "pyarrow.time64", + "pyarrow.timestamp", + "pyarrow.uint8", + "pyarrow.uint16", + "pyarrow.uint32", + "pyarrow.uint64", + "pyarrow.binary_view", + "pyarrow.decimal32", + "pyarrow.decimal64", + "pyarrow.decimal128" ] }, "pyarrow.dataset": { @@ -709,5 +721,77 @@ "name": "duckdb_source", "children": [], "required": false + }, + "pyarrow.scalar": { + "type": "attribute", + "full_path": 
"pyarrow.scalar", + "name": "scalar", + "children": [] + }, + "pyarrow.date32": { + "type": "attribute", + "full_path": "pyarrow.date32", + "name": "date32", + "children": [] + }, + "pyarrow.time64": { + "type": "attribute", + "full_path": "pyarrow.time64", + "name": "time64", + "children": [] + }, + "pyarrow.timestamp": { + "type": "attribute", + "full_path": "pyarrow.timestamp", + "name": "timestamp", + "children": [] + }, + "pyarrow.uint8": { + "type": "attribute", + "full_path": "pyarrow.uint8", + "name": "uint8", + "children": [] + }, + "pyarrow.uint16": { + "type": "attribute", + "full_path": "pyarrow.uint16", + "name": "uint16", + "children": [] + }, + "pyarrow.uint32": { + "type": "attribute", + "full_path": "pyarrow.uint32", + "name": "uint32", + "children": [] + }, + "pyarrow.uint64": { + "type": "attribute", + "full_path": "pyarrow.uint64", + "name": "uint64", + "children": [] + }, + "pyarrow.binary_view": { + "type": "attribute", + "full_path": "pyarrow.binary_view", + "name": "binary_view", + "children": [] + }, + "pyarrow.decimal32": { + "type": "attribute", + "full_path": "pyarrow.decimal32", + "name": "decimal32", + "children": [] + }, + "pyarrow.decimal64": { + "type": "attribute", + "full_path": "pyarrow.decimal64", + "name": "decimal64", + "children": [] + }, + "pyarrow.decimal128": { + "type": "attribute", + "full_path": "pyarrow.decimal128", + "name": "decimal128", + "children": [] } } \ No newline at end of file diff --git a/scripts/generate_connection_code.py b/scripts/generate_connection_code.py index 3737f83a..8e2bace9 100644 --- a/scripts/generate_connection_code.py +++ b/scripts/generate_connection_code.py @@ -3,7 +3,7 @@ import generate_connection_wrapper_methods import generate_connection_wrapper_stubs -if __name__ == '__main__': +if __name__ == "__main__": generate_connection_methods.generate() generate_connection_stubs.generate() generate_connection_wrapper_methods.generate() diff --git a/scripts/generate_connection_methods.py b/scripts/generate_connection_methods.py index c1f01e54..13cb7dce 100644 --- a/scripts/generate_connection_methods.py +++ b/scripts/generate_connection_methods.py @@ -1,10 +1,11 @@ -import os import json +import os +from pathlib import Path -os.chdir(os.path.dirname(__file__)) +os.chdir(Path(__file__).parent) -JSON_PATH = os.path.join("connection_methods.json") -PYCONNECTION_SOURCE = os.path.join("..", "src", "pyconnection.cpp") +JSON_PATH = "connection_methods.json" +PYCONNECTION_SOURCE = Path("..") / "src" / "duckdb_py" / "pyconnection.cpp" INITIALIZE_METHOD = ( "static void InitializeConnectionMethods(py::class_> &m) {" @@ -13,58 +14,57 @@ def is_py_kwargs(method): - return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True + return "kwargs_as_dict" in method and method["kwargs_as_dict"] def is_py_args(method): - if 'args' not in method: + if "args" not in method: return False - args = method['args'] + args = method["args"] if len(args) == 0: return False - if args[0]['name'] != '*args': - return False - return True + return args[0]["name"] == "*args" def generate(): # Read the PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, 'r') as source_file: - source_code = source_file.readlines() + source_code = Path(PYCONNECTION_SOURCE).read_text().splitlines() start_index = -1 end_index = -1 for i, line in enumerate(source_code): if line.startswith(INITIALIZE_METHOD): if start_index != -1: - raise ValueError("Encountered the INITIALIZE_METHOD a second time, quitting!") + msg = "Encountered the INITIALIZE_METHOD a second 
time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) + connection_methods = json.loads(Path(JSON_PATH).read_text()) DEFAULT_ARGUMENT_MAP = { - 'True': 'true', - 'False': 'false', - 'None': 'py::none()', - 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', - 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', - 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', + "True": "true", + "False": "false", + "None": "py::none()", + "PythonUDFType.NATIVE": "PythonUDFType::NATIVE", + "PythonExceptionHandling.DEFAULT": "PythonExceptionHandling::FORWARD_ERROR", + "FunctionNullHandling.DEFAULT": "FunctionNullHandling::DEFAULT_NULL_HANDLING", } - def map_default(val): + def map_default(val) -> str: if val in DEFAULT_ARGUMENT_MAP: return DEFAULT_ARGUMENT_MAP[val] return val @@ -72,61 +72,57 @@ def map_default(val): def create_arguments(arguments) -> list: result = [] for arg in arguments: - if arg['name'] == '*args': + if arg["name"] == "*args": break - argument = f"py::arg(\"{arg['name']}\")" - if 'allow_none' in arg: - value = str(arg['allow_none']).lower() + argument = f'py::arg("{arg["name"]}")' + if "allow_none" in arg: + value = str(arg["allow_none"]).lower() argument += f".none({value})" # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) + if "default" in arg: + default = map_default(arg["default"]) argument += f" = {default}" result.append(argument) return result def create_definition(name, method) -> str: - definition = f"m.def(\"{name}\"" + definition = f'm.def("{name}"' definition += ", " - definition += f"""&DuckDBPyConnection::{method['function']}""" + definition += f"""&DuckDBPyConnection::{method["function"]}""" definition += ", " - definition += f"\"{method['docs']}\"" - if 'args' in method and not is_py_args(method): + definition += f'"{method["docs"]}"' + if "args" in method and not is_py_args(method): definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: + arguments = create_arguments(method["args"]) + definition += ", ".join(arguments) + if "kwargs" in method: definition += ", " if is_py_kwargs(method): definition += "py::kw_only()" else: definition += "py::kw_only(), " - arguments = create_arguments(method['kwargs']) - definition += ', '.join(arguments) + arguments = create_arguments(method["kwargs"]) + definition += ", ".join(arguments) definition += ");" return definition body = [] for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] - for name in names: - body.append(create_definition(name, method)) + names = method["name"] if isinstance(method["name"], list) else [method["name"]] + body.extend(create_definition(name, method) for name in names) # ---- End of generation 
code ---- - with_newlines = ['\t' + x + '\n' for x in body] + with_newlines = ["\t" + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, 'w') as source_file: - source_file.write("".join(new_content)) + Path(PYCONNECTION_SOURCE).write_text("".join(new_content)) -if __name__ == '__main__': - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") +if __name__ == "__main__": + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index fbb66c21..d542a047 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -1,10 +1,11 @@ -import os import json +import os +from pathlib import Path -os.chdir(os.path.dirname(__file__)) +os.chdir(Path(__file__).parent) -JSON_PATH = os.path.join("connection_methods.json") -DUCKDB_STUBS_FILE = os.path.join("..", "duckdb-stubs", "__init__.pyi") +JSON_PATH = "connection_methods.json" +DUCKDB_STUBS_FILE = Path("..") / "duckdb" / "__init__.pyi" START_MARKER = " # START OF CONNECTION METHODS" END_MARKER = " # END OF CONNECTION METHODS" @@ -12,31 +13,32 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'r') as source_file: - source_code = source_file.readlines() + source_code = Path(DUCKDB_STUBS_FILE).read_text().splitlines() start_index = -1 end_index = -1 for i, line in enumerate(source_code): if line.startswith(START_MARKER): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" 
+ raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) + connection_methods = json.loads(Path(JSON_PATH).read_text()) body = [] @@ -45,54 +47,53 @@ def create_arguments(arguments) -> list: for arg in arguments: argument = f"{arg['name']}: {arg['type']}" # Add the default argument if present - if 'default' in arg: - default = arg['default'] + if "default" in arg: + default = arg["default"] argument += f" = {default}" result.append(argument) return result - def create_definition(name, method, overloaded: bool) -> str: - if overloaded: - definition: str = "@overload\n" - else: - definition: str = "" - definition += f"def {name}(" - arguments = ['self'] - if 'args' in method: - arguments.extend(create_arguments(method['args'])) - if 'kwargs' in method: - if not any(x.startswith('*') for x in arguments): + def create_definition(name, method) -> str: + definition = f"def {name}(" + arguments = ["self"] + if "args" in method: + arguments.extend(create_arguments(method["args"])) + if "kwargs" in method: + if not any(x.startswith("*") for x in arguments): arguments.append("*") - arguments.extend(create_arguments(method['kwargs'])) + arguments.extend(create_arguments(method["kwargs"])) definition += ", ".join(arguments) definition += ")" definition += f" -> {method['return']}: ..." return definition + def create_overloaded_definition(name, method) -> str: + return f"@overload\n{create_definition(name, method)}" + # We have "duplicate" methods, which are overloaded. # We keep note of them to add the @overload decorator. 
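For reference, a minimal standalone sketch of the stub lines this generator emits for one entry of connection_methods.json. The sample entry below is hypothetical, but the schema keys it uses (name, args, kwargs, return, default) are the ones create_definition reads; the real generator also avoids adding "*" when an argument already starts with one.

```python
# Simplified sketch of the .pyi line generation in generate_connection_stubs.py.
# The sample entry is hypothetical.
sample = {
    "name": ["execute", "sql"],  # a list of names marks an overloaded method
    "args": [{"name": "query", "type": "str"}],
    "kwargs": [{"name": "parameters", "type": "object", "default": "None"}],
    "return": "DuckDBPyConnection",
}


def stub_line(name: str, method: dict, overloaded: bool) -> str:
    """Render one stub definition, mirroring create_definition above."""
    params = ["self"] + [f"{a['name']}: {a['type']}" for a in method.get("args", [])]
    if method.get("kwargs"):
        params.append("*")
        params += [
            f"{a['name']}: {a['type']}" + (f" = {a['default']}" if "default" in a else "")
            for a in method["kwargs"]
        ]
    line = f"def {name}({', '.join(params)}) -> {method['return']}: ..."
    return f"@overload\n{line}" if overloaded else line


for alias in sample["name"]:
    print(stub_line(alias, sample, overloaded=True))
# @overload
# def execute(self, query: str, *, parameters: object = None) -> DuckDBPyConnection: ...
# @overload
# def sql(self, query: str, *, parameters: object = None) -> DuckDBPyConnection: ...
```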
- overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m["name"], list)} for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] + names = method["name"] if isinstance(method["name"], list) else [method["name"]] for name in names: - body.append(create_definition(name, method, name in overloaded_methods)) + if name in overloaded_methods: + body.append(create_overloaded_definition(name, method)) + else: + body.append(create_definition(name, method)) # ---- End of generation code ---- - with_newlines = [' ' + x + '\n' for x in body] + with_newlines = [" " + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'w') as source_file: - source_file.write("".join(new_content)) + Path(DUCKDB_STUBS_FILE).write_text("".join(new_content)) -if __name__ == '__main__': - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") +if __name__ == "__main__": + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index 7be7256c..c8478602 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -1,15 +1,14 @@ -import os -import sys import json +import os +from pathlib import Path # Requires `python3 -m pip install cxxheaderparser pcpp` -from get_cpp_methods import get_methods, FunctionParam, ConnectionMethod -from typing import List, Tuple +from get_cpp_methods import ConnectionMethod, ReturnType, get_methods -os.chdir(os.path.dirname(__file__)) +os.chdir(Path(__file__).parent) -JSON_PATH = os.path.join("connection_methods.json") -DUCKDB_PYTHON_SOURCE = os.path.join("..", "duckdb_python.cpp") +JSON_PATH = "connection_methods.json" +DUCKDB_PYTHON_SOURCE = Path("..") / "src" / "duckdb_py" / "duckdb_python.cpp" START_MARKER = "\t// START_OF_CONNECTION_METHODS" END_MARKER = "\t// END_OF_CONNECTION_METHODS" @@ -33,23 +32,22 @@ ]) """ -WRAPPER_JSON_PATH = os.path.join("connection_wrapper_methods.json") +WRAPPER_JSON_PATH = "connection_wrapper_methods.json" -DUCKDB_INIT_FILE = os.path.join("..", "duckdb", "__init__.py") +DUCKDB_INIT_FILE = Path("..") / "duckdb" / "__init__.py" INIT_PY_START = "# START OF CONNECTION WRAPPER" INIT_PY_END = "# END OF CONNECTION WRAPPER" # Read the JSON file -with open(WRAPPER_JSON_PATH, 'r') as json_file: - wrapper_methods = json.load(json_file) +wrapper_methods = json.loads(Path(WRAPPER_JSON_PATH).read_text()) # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) -READONLY_PROPERTY_NAMES = ['description', 'rowcount'] +READONLY_PROPERTY_NAMES = ["description", "rowcount"] # These methods are not directly DuckDBPyConnection methods, # they first call 'FromDF' and then call a method on the created DuckDBPyRelation -SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] +SPECIAL_METHOD_NAMES = [x["name"] for x in 
wrapper_methods if x["name"] not in READONLY_PROPERTY_NAMES] RETRIEVE_CONN_FROM_DICT = """auto connection_arg = kwargs.contains("conn") ? kwargs["conn"] : py::none(); auto conn = py::cast>(connection_arg); @@ -57,35 +55,36 @@ def is_py_args(method): - if 'args' not in method: + if "args" not in method: return False - args = method['args'] + args = method["args"] if len(args) == 0: return False - if args[0]['name'] != '*args': - return False - return True + return args[0]["name"] == "*args" def is_py_kwargs(method): - return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True + return "kwargs_as_dict" in method and method["kwargs_as_dict"] -def remove_section(content, start_marker, end_marker) -> Tuple[List[str], List[str]]: +def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[str]]: start_index = -1 end_index = -1 for i, line in enumerate(content): if line.startswith(start_marker): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(end_marker): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = content[: start_index + 1] end_section = content[end_index:] @@ -94,36 +93,33 @@ def remove_section(content, start_marker, end_marker) -> Tuple[List[str], List[s def generate(): # Read the DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, 'r') as source_file: - source_code = source_file.readlines() + source_code = Path(DUCKDB_PYTHON_SOURCE).read_text().splitlines() start_section, end_section = remove_section(source_code, START_MARKER, END_MARKER) # Read the DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, 'r') as source_file: - source_code = source_file.readlines() + source_code = Path(DUCKDB_INIT_FILE).read_text().splitlines() py_start, py_end = remove_section(source_code, INIT_PY_START, INIT_PY_END) # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) + connection_methods = json.loads(Path(JSON_PATH).read_text()) # Collect the definitions from the pyconnection.hpp header - cpp_connection_defs = get_methods('DuckDBPyConnection') - cpp_relation_defs = get_methods('DuckDBPyRelation') + cpp_connection_defs = get_methods("DuckDBPyConnection") + cpp_relation_defs = get_methods("DuckDBPyRelation") DEFAULT_ARGUMENT_MAP = { - 'True': 'true', - 'False': 'false', - 'None': 'py::none()', - 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', - 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', - 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', + "True": "true", + "False": "false", + "None": "py::none()", + "PythonUDFType.NATIVE": "PythonUDFType::NATIVE", + "PythonExceptionHandling.DEFAULT": "PythonExceptionHandling::FORWARD_ERROR", + "FunctionNullHandling.DEFAULT": "FunctionNullHandling::DEFAULT_NULL_HANDLING", } - def map_default(val): + def map_default(val) -> str: if val in DEFAULT_ARGUMENT_MAP: return DEFAULT_ARGUMENT_MAP[val] return val @@ -131,16 +127,16 @@ def map_default(val): 
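A quick sketch of the py::arg(...) strings that map_default and create_arguments assemble for the pybind11 definitions; the argument spec passed in below is a hypothetical example in the same JSON schema, and the mapping table is abbreviated.

```python
# Sketch of the py::arg(...) string construction used by the binding generators.
# The sample argument spec is hypothetical; DEFAULT_ARGUMENT_MAP is abbreviated.
DEFAULT_ARGUMENT_MAP = {
    "True": "true",
    "False": "false",
    "None": "py::none()",
}


def py_arg(arg: dict) -> str:
    out = f'py::arg("{arg["name"]}")'
    if "allow_none" in arg:
        out += f".none({str(arg['allow_none']).lower()})"
    if "default" in arg:
        out += f" = {DEFAULT_ARGUMENT_MAP.get(arg['default'], arg['default'])}"
    return out


print(py_arg({"name": "parameters", "allow_none": True, "default": "None"}))
# py::arg("parameters").none(true) = py::none()
```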
def create_arguments(arguments) -> list: result = [] for arg in arguments: - if arg['name'] == '*args': + if arg["name"] == "*args": # py::args() should not have a corresponding py::arg() continue - argument = f"py::arg(\"{arg['name']}\")" - if 'allow_none' in arg: - value = str(arg['allow_none']).lower() + argument = f'py::arg("{arg["name"]}")' + if "allow_none" in arg: + value = str(arg["allow_none"]).lower() argument += f".none({value})" # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) + if "default" in arg: + default = map_default(arg["default"]) argument += f" = {default}" result.append(argument) return result @@ -148,11 +144,11 @@ def create_arguments(arguments) -> list: def get_lambda_definition(name, method, definition: ConnectionMethod) -> str: param_definitions = [] if name in SPECIAL_METHOD_NAMES: - param_definitions.append('const PandasDataFrame &df') + param_definitions.append("const PandasDataFrame &df") param_definitions.extend([x.proto for x in definition.params]) if not is_py_kwargs(method): - param_definitions.append('shared_ptr conn = nullptr') + param_definitions.append("shared_ptr conn = nullptr") param_definitions = ", ".join(param_definitions) param_names = [x.name for x in definition.params] @@ -160,73 +156,67 @@ def get_lambda_definition(name, method, definition: ConnectionMethod) -> str: function_name = definition.name if name in SPECIAL_METHOD_NAMES: - function_name = 'FromDF(df)->' + function_name + function_name = "FromDF(df)->" + function_name format_dict = { - 'param_definitions': param_definitions, - 'opt_retrieval': '', - 'opt_return': '' if definition.is_void else 'return ', - 'function_name': function_name, - 'parameter_names': param_names, + "param_definitions": param_definitions, + "opt_retrieval": "", + "opt_return": "" if definition.return_type == ReturnType.VOID else "return", + "function_name": function_name, + "parameter_names": param_names, } if is_py_kwargs(method): - format_dict['opt_retrieval'] += RETRIEVE_CONN_FROM_DICT + format_dict["opt_retrieval"] += RETRIEVE_CONN_FROM_DICT return LAMBDA_FORMAT.format_map(format_dict) def create_definition(name, method, lambda_def) -> str: - definition = f"m.def(\"{name}\"" + definition = f'm.def("{name}"' definition += ", " definition += lambda_def definition += ", " - definition += f"\"{method['docs']}\"" - if 'args' in method and not is_py_args(method): + definition += f'"{method["docs"]}"' + if "args" in method and not is_py_args(method): definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: + arguments = create_arguments(method["args"]) + definition += ", ".join(arguments) + if "kwargs" in method: definition += ", " if is_py_kwargs(method): definition += "py::kw_only()" else: definition += "py::kw_only(), " - arguments = create_arguments(method['kwargs']) - definition += ', '.join(arguments) + arguments = create_arguments(method["kwargs"]) + definition += ", ".join(arguments) definition += ");" return definition body = [] all_names = [] for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'Optional[DuckDBPyConnection]', 'default': 'None'}) + names = method["name"] if isinstance(method["name"], list) else [method["name"]] + if "kwargs" not in method: + method["kwargs"] = [] + 
method["kwargs"].append({"name": "connection", "type": "Optional[DuckDBPyConnection]", "default": "None"}) for name in names: - function_name = method['function'] + function_name = method["function"] cpp_definition = cpp_connection_defs[function_name] lambda_def = get_lambda_definition(name, method, cpp_definition) body.append(create_definition(name, method, lambda_def)) all_names.append(name) for method in wrapper_methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'Optional[DuckDBPyConnection]', 'default': 'None'}) + names = method["name"] if isinstance(method["name"], list) else [method["name"]] + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "Optional[DuckDBPyConnection]", "default": "None"}) for name in names: - function_name = method['function'] + function_name = method["function"] if name in SPECIAL_METHOD_NAMES: cpp_definition = cpp_relation_defs[function_name] - if 'args' not in method: - method['args'] = [] - method['args'].insert(0, {'name': 'df', 'type': 'DataFrame'}) + if "args" not in method: + method["args"] = [] + method["args"].insert(0, {"name": "df", "type": "DataFrame"}) else: cpp_definition = cpp_connection_defs[function_name] lambda_def = get_lambda_definition(name, method, cpp_definition) @@ -235,24 +225,22 @@ def create_definition(name, method, lambda_def) -> str: # ---- End of generation code ---- - with_newlines = ['\t' + x + '\n' for x in body] + with_newlines = ["\t" + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, 'w') as source_file: - source_file.write("".join(new_content)) + Path(DUCKDB_PYTHON_SOURCE).write_text("".join(new_content)) - item_list = '\n'.join([f'\t{name},' for name in all_names]) - str_item_list = '\n'.join([f"\t'{name}'," for name in all_names]) - imports = PY_INIT_FORMAT.format(item_list=item_list, str_item_list=str_item_list).split('\n') - imports = [x + '\n' for x in imports] + item_list = "\n".join([f"\t{name}," for name in all_names]) + str_item_list = "\n".join([f"\t'{name}'," for name in all_names]) + imports = PY_INIT_FORMAT.format(item_list=item_list, str_item_list=str_item_list).split("\n") + imports = [x + "\n" for x in imports] init_py_content = py_start + imports + py_end # Write out the modified DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, 'w') as source_file: - source_file.write("".join(init_py_content)) + Path(DUCKDB_INIT_FILE).write_text("".join(init_py_content)) -if __name__ == '__main__': +if __name__ == "__main__": # raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") generate() diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 62c60a84..78c1768c 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -1,11 +1,12 @@ -import os import json +import os +from pathlib import Path -os.chdir(os.path.dirname(__file__)) +os.chdir(Path(__file__).parent) -JSON_PATH = os.path.join("connection_methods.json") -WRAPPER_JSON_PATH = os.path.join("connection_wrapper_methods.json") -DUCKDB_STUBS_FILE = os.path.join("..", "duckdb-stubs", "__init__.pyi") 
+JSON_PATH = "connection_methods.json" +WRAPPER_JSON_PATH = "connection_wrapper_methods.json" +DUCKDB_STUBS_FILE = Path("..") / "duckdb" / "__init__.pyi" START_MARKER = "# START OF CONNECTION WRAPPER" END_MARKER = "# END OF CONNECTION WRAPPER" @@ -13,23 +14,25 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'r') as source_file: - source_code = source_file.readlines() + source_code = Path(DUCKDB_STUBS_FILE).read_text().splitlines() start_index = -1 end_index = -1 for i, line in enumerate(source_code): if line.startswith(START_MARKER): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] @@ -37,86 +40,82 @@ def generate(): methods = [] - # Read the JSON file - with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) - - with open(WRAPPER_JSON_PATH, 'r') as json_file: - wrapper_methods = json.load(json_file) + # Read the JSON files + connection_methods = json.loads(Path(JSON_PATH).read_text()) + wrapper_methods = json.loads(Path(WRAPPER_JSON_PATH).read_text()) methods.extend(connection_methods) methods.extend(wrapper_methods) # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) - READONLY_PROPERTY_NAMES = ['description', 'rowcount'] + READONLY_PROPERTY_NAMES = ["description", "rowcount"] # These methods are not directly DuckDBPyConnection methods, # they first call 'from_df' and then call a method on the created DuckDBPyRelation - SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] + SPECIAL_METHOD_NAMES = [x["name"] for x in wrapper_methods if x["name"] not in READONLY_PROPERTY_NAMES] def create_arguments(arguments) -> list: result = [] for arg in arguments: argument = f"{arg['name']}: {arg['type']}" # Add the default argument if present - if 'default' in arg: - default = arg['default'] + if "default" in arg: + default = arg["default"] argument += f" = {default}" result.append(argument) return result - def create_definition(name, method, overloaded: bool) -> str: - if overloaded: - definition: str = "@overload\n" - else: - definition: str = "" - definition += f"def {name}(" + def create_definition(name, method) -> str: + definition = f"def {name}(" arguments = [] if name in SPECIAL_METHOD_NAMES: - arguments.append('df: pandas.DataFrame') - if 'args' in method: - arguments.extend(create_arguments(method['args'])) - if 'kwargs' in method: - if not any(x.startswith('*') for x in arguments): + arguments.append("df: pandas.DataFrame") + if "args" in method: + arguments.extend(create_arguments(method["args"])) + if "kwargs" in method: + if not any(x.startswith("*") for x in arguments): arguments.append("*") - arguments.extend(create_arguments(method['kwargs'])) - 
definition += ', '.join(arguments) + arguments.extend(create_arguments(method["kwargs"])) + definition += ", ".join(arguments) definition += ")" definition += f" -> {method['return']}: ..." return definition + def create_overloaded_definition(name, method) -> str: + return f"@overload\n{create_definition(name, method)}" + # We have "duplicate" methods, which are overloaded. # We keep note of them to add the @overload decorator. - overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m["name"], list)} body = [] for method in methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] + names = method["name"] if isinstance(method["name"], list) else [method["name"]] # Artificially add 'connection' keyword argument - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'}) + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "DuckDBPyConnection", "default": "..."}) for name in names: - body.append(create_definition(name, method, name in overloaded_methods)) + if name in overloaded_methods: + body.append(create_overloaded_definition(name, method)) + else: + body.append(create_definition(name, method)) # ---- End of generation code ---- - with_newlines = [x + '\n' for x in body] + with_newlines = [x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'w') as source_file: - source_file.write("".join(new_content)) + Path(DUCKDB_STUBS_FILE).write_text("".join(new_content)) -if __name__ == '__main__': - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") +if __name__ == "__main__": + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index f902c5a5..d10dde0c 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -1,112 +1,104 @@ -import os - -script_dir = os.path.dirname(__file__) -from typing import List, Dict import json +from pathlib import Path + +script_dir = Path(__file__).parent # Load existing JSON data from a file if it exists json_data = {} -json_cache_path = os.path.join(script_dir, "cache_data.json") +json_cache_path = Path(script_dir) / "cache_data.json" try: - with open(json_cache_path, "r") as file: - json_data = json.load(file) + json_data = json.loads(Path(json_cache_path).read_text()) except FileNotFoundError: print("Please first use 'generate_import_cache_json.py' first to generate json") # deal with leaf nodes?? 
Those are just PythonImportCacheItem def get_class_name(path: str) -> str: - parts: List[str] = path.replace('_', '').split('.') + parts: list[str] = path.replace("_", "").split(".") parts = [x.title() for x in parts] - return ''.join(parts) + 'CacheItem' + return "".join(parts) + "CacheItem" def get_filename(name: str) -> str: - return name.replace('_', '').lower() + '_module.hpp' + return name.replace("_", "").lower() + "_module.hpp" def get_variable_name(name: str) -> str: - if name in ['short', 'ushort']: - return name + '_' + if name in ["short", "ushort"]: + return name + "_" return name -def collect_items_of_module(module: dict, collection: Dict): +def collect_items_of_module(module: dict, collection: dict): global json_data - children = module['children'] - collection[module['full_path']] = module + children = module["children"] + collection[module["full_path"]] = module for child in children: collect_items_of_module(json_data[child], collection) class CacheItem: - def __init__(self, module: dict, items): - self.name = module['name'] + def __init__(self, module: dict, items) -> None: + self.name = module["name"] self.module = module self.items = items - self.class_name = get_class_name(module['full_path']) + self.class_name = get_class_name(module["full_path"]) def get_full_module_path(self): - if self.module['type'] != 'module': - return '' - full_path = self.module['full_path'] + if self.module["type"] != "module": + return "" + full_path = self.module["full_path"] return f""" public: \tstatic constexpr const char *Name = "{full_path}"; """ def get_optionally_required(self): - if 'required' not in self.module: - return '' + if "required" not in self.module: + return "" string = f""" protected: \tbool IsRequired() const override final {{ -\t\treturn {str(self.module['required']).lower()}; +\t\treturn {str(self.module["required"]).lower()}; \t}} """ return string def get_variables(self): variables = [] - for key in self.module['children']: + for key in self.module["children"]: item = self.items[key] - name = item['name'] + name = item["name"] var_name = get_variable_name(name) - if item['children'] == []: - class_name = 'PythonImportCacheItem' - else: - class_name = get_class_name(item['full_path']) - variables.append(f'\t{class_name} {var_name};') - return '\n'.join(variables) + class_name = "PythonImportCacheItem" if item["children"] == [] else get_class_name(item["full_path"]) + variables.append(f"\t{class_name} {var_name};") + return "\n".join(variables) def get_initializer(self): variables = [] - for key in self.module['children']: + for key in self.module["children"]: item = self.items[key] - name = item['name'] + name = item["name"] var_name = get_variable_name(name) - if item['children'] == []: + if item["children"] == []: initialization = f'{var_name}("{name}", this)' variables.append(initialization) else: - if item['type'] == 'module': - arguments = '' - else: - arguments = 'this' - initialization = f'{var_name}({arguments})' + arguments = "" if item["type"] == "module" else "this" + initialization = f"{var_name}({arguments})" variables.append(initialization) - if self.module['type'] != 'module': + if self.module["type"] != "module": constructor_params = f'"{self.name}"' - constructor_params += ', parent' + constructor_params += ", parent" else: - full_path = self.module['full_path'] + full_path = self.module["full_path"] constructor_params = f'"{full_path}"' - return f'PythonImportCacheItem({constructor_params}), ' + ', '.join(variables) + '{}' + return 
f"PythonImportCacheItem({constructor_params}), " + ", ".join(variables) + "{}" def get_constructor(self): - if self.module['type'] == 'module': - return f'{self.class_name}()' - return f'{self.class_name}(optional_ptr parent)' + if self.module["type"] == "module": + return f"{self.class_name}()" + return f"{self.class_name}(optional_ptr parent)" def to_string(self): return f""" @@ -122,29 +114,26 @@ def to_string(self): """ -def collect_classes(items: Dict) -> List: - output: List = [] +def collect_classes(items: dict) -> list: + output: list = [] for item in items.values(): - if item['children'] == []: + if item["children"] == []: continue output.append(CacheItem(item, items)) return output class ModuleFile: - def __init__(self, module: dict): + def __init__(self, module: dict) -> None: self.module = module - self.file_name = get_filename(module['name']) + self.file_name = get_filename(module["name"]) self.items = {} collect_items_of_module(module, self.items) self.classes = collect_classes(self.items) self.classes.reverse() def get_classes(self): - classes = [] - for item in self.classes: - classes.append(item.to_string()) - return ''.join(classes) + return "".join(item.to_string() for item in self.classes) def to_string(self): string = f""" @@ -174,27 +163,26 @@ def to_string(self): return string -files: List[ModuleFile] = [] -for name, value in json_data.items(): - if value['full_path'] != value['name']: +files: list[ModuleFile] = [] +for value in json_data.values(): + if value["full_path"] != value["name"]: continue files.append(ModuleFile(value)) for file in files: content = file.to_string() - path = f'src/include/duckdb_python/import_cache/modules/{file.file_name}' - import_cache_path = os.path.join(script_dir, '..', path) - with open(import_cache_path, "w") as f: - f.write(content) + path = f"src/duckdb_py/include/duckdb_python/import_cache/modules/{file.file_name}" + import_cache_path = Path(script_dir) / ".." / path + import_cache_path.write_text(content) -def get_root_modules(files: List[ModuleFile]): +def get_root_modules(files: list[ModuleFile]): modules = [] for file in files: - name = file.module['name'] + name = file.module["name"] class_name = get_class_name(name) - modules.append(f'\t{class_name} {name};') - return '\n'.join(modules) + modules.append(f"\t{class_name} {name};") + return "\n".join(modules) # Generate the python_import_cache.hpp file @@ -237,25 +225,21 @@ def get_root_modules(files: List[ModuleFile]): """ -import_cache_path = os.path.join(script_dir, '..', 'src/include/duckdb_python/import_cache/python_import_cache.hpp') -with open(import_cache_path, "w") as f: - f.write(import_cache_file) +import_cache_path = Path(script_dir) / ".." / "src/duckdb_py/include/duckdb_python/import_cache/python_import_cache.hpp" +import_cache_path.write_text(import_cache_file) -def get_module_file_path_includes(files: List[ModuleFile]): - includes = [] - for file in files: - includes.append(f'#include "duckdb_python/import_cache/modules/{file.file_name}"') - return '\n'.join(includes) +def get_module_file_path_includes(files: list[ModuleFile]): + template = '#include "duckdb_python/import_cache/modules/{}' + return "\n".join(template.format(f.file_name) for f in files) module_includes = get_module_file_path_includes(files) -modules_header = os.path.join( - script_dir, '..', 'src/include/duckdb_python/import_cache/python_import_cache_modules.hpp' +modules_header = ( + Path(script_dir) / ".." 
/ ("src/duckdb_py/include/duckdb_python/import_cache/python_import_cache_modules.hpp") ) -with open(modules_header, "w") as f: - f.write(module_includes) +modules_header.write_text(module_includes) # Generate the python_import_cache_modules.hpp file # listing all the generated header files diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 7a59e6b7..dd8c3d5c 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -1,19 +1,19 @@ -import os - -script_dir = os.path.dirname(__file__) -from typing import List, Dict, Union +import contextlib import json +from pathlib import Path +from typing import Union -lines: List[str] = [file for file in open(f'{script_dir}/imports.py').read().split('\n') if file != ''] +script_dir = Path(__file__).parent +lines = [line for line in (script_dir / "imports.py").read_text().splitlines() if line] class ImportCacheAttribute: - def __init__(self, full_path: str): - parts = full_path.split('.') + def __init__(self, full_path: str) -> None: + parts = full_path.split(".") self.type = "attribute" self.name = parts[-1] self.full_path = full_path - self.children: Dict[str, "ImportCacheAttribute"] = {} + self.children: dict[str, ImportCacheAttribute] = {} def has_item(self, item_name: str) -> bool: return item_name in self.children @@ -41,12 +41,12 @@ def populate_json(self, json_data: dict): class ImportCacheModule: - def __init__(self, full_path): - parts = full_path.split('.') + def __init__(self, full_path) -> None: + parts = full_path.split(".") self.type = "module" self.name = parts[-1] self.full_path = full_path - self.items: Dict[str, Union[ImportCacheAttribute, "ImportCacheModule"]] = {} + self.items: dict[str, Union[ImportCacheAttribute, ImportCacheModule]] = {} def add_item(self, item: Union[ImportCacheAttribute, "ImportCacheModule"]): assert self.full_path != item.full_path @@ -78,46 +78,47 @@ def root_module(self) -> bool: class ImportCacheGenerator: - def __init__(self): - self.modules: Dict[str, ImportCacheModule] = {} + def __init__(self) -> None: + self.modules: dict[str, ImportCacheModule] = {} def add_module(self, path: str): - assert path.startswith('import') + assert path.startswith("import") path = path[7:] module = ImportCacheModule(path) self.modules[module.full_path] = module # Add it to the parent module if present - parts = path.split('.') + parts = path.split(".") if len(parts) == 1: return # This works back from the furthest child module to the top level module child_module = module for i in range(1, len(parts)): - parent_path = '.'.join(parts[: len(parts) - i]) + parent_path = ".".join(parts[: len(parts) - i]) parent_module = self.add_or_get_module(parent_path) parent_module.add_item(child_module) child_module = parent_module def add_or_get_module(self, module_name: str) -> ImportCacheModule: if module_name not in self.modules: - self.add_module(f'import {module_name}') + self.add_module(f"import {module_name}") return self.get_module(module_name) def get_module(self, module_name: str) -> ImportCacheModule: if module_name not in self.modules: - raise ValueError("Import the module before registering its attributes!") + msg = "Import the module before registering its attributes!" 
+ raise ValueError(msg) return self.modules[module_name] def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttribute]: - parts = item_name.split('.') + parts = item_name.split(".") if len(parts) == 1: return self.get_module(item_name) parent = self.get_module(parts[0]) for i in range(1, len(parts)): - child_path = '.'.join(parts[: i + 1]) + child_path = ".".join(parts[: i + 1]) if parent.has_item(child_path): parent = parent.get_item(child_path) else: @@ -127,8 +128,8 @@ def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttrib return parent def add_attribute(self, path: str): - assert not path.startswith('import') - parts = path.split('.') + assert not path.startswith("import") + parts = path.split(".") assert len(parts) >= 2 self.get_item(path) @@ -145,32 +146,28 @@ def to_json(self): generator = ImportCacheGenerator() for line in lines: - if line.startswith('#'): + if line.startswith("#"): continue - if line.startswith('import'): + if line.startswith("import"): generator.add_module(line) else: generator.add_attribute(line) # Load existing JSON data from a file if it exists existing_json_data = {} -json_cache_path = os.path.join(script_dir, "cache_data.json") -try: - with open(json_cache_path, "r") as file: - existing_json_data = json.load(file) -except FileNotFoundError: - pass +json_cache_path = Path(script_dir) / "cache_data.json" +with contextlib.suppress(FileNotFoundError): + existing_json_data = json.loads(json_cache_path.read_text()) def update_json(existing: dict, new: dict) -> dict: # Iterate over keys in the new dictionary. for key in new: new_value = new[key] - old_value = existing[key] if key in existing else None + old_value = existing.get(key) # If both values are dictionaries, update recursively. 
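For clarity, a small self-contained example of the recursive merge that update_json performs: nested dictionaries are merged key by key, non-dict values from the new data overwrite the old ones, and keys only present in the existing data are kept. The dictionaries below are hypothetical but follow the same type/children schema as cache_data.json.

```python
# Standalone illustration of the recursive merge behaviour; sample data is hypothetical.
existing = {
    "pyarrow": {"type": "module", "children": ["pyarrow.ipc"]},
    "pandas": {"type": "module", "children": []},
}
new = {
    "pyarrow": {"type": "module", "children": ["pyarrow.ipc", "pyarrow.scalar"]},
}


def merge(old: dict, fresh: dict) -> dict:
    for key, value in fresh.items():
        if isinstance(value, dict) and isinstance(old.get(key), dict):
            merge(old[key], value)  # recurse into nested dicts
        else:
            old[key] = value  # new or changed leaves overwrite
    return old


print(merge(existing, new))
# {'pyarrow': {'type': 'module', 'children': ['pyarrow.ipc', 'pyarrow.scalar']},
#  'pandas': {'type': 'module', 'children': []}}
```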
if isinstance(new_value, dict) and isinstance(old_value, dict): - print(key) updated = update_json(old_value, new_value) existing[key] = updated else: @@ -184,5 +181,4 @@ def update_json(existing: dict, new: dict) -> dict: json_data = update_json(existing_json_data, json_data) # Save the merged JSON data back to the file -with open(json_cache_path, "w") as file: - json.dump(json_data, file, indent=4) +json_cache_path.write_text(json.dumps(json_data, indent=4)) diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index e784d054..a86b609e 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -1,33 +1,39 @@ # Requires `python3 -m pip install cxxheaderparser pcpp` -import os +from enum import Enum +from pathlib import Path +from typing import Callable import cxxheaderparser.parser -import cxxheaderparser.visitor import cxxheaderparser.preprocessor -from typing import List, Dict +import cxxheaderparser.visitor -scripts_folder = os.path.dirname(os.path.abspath(__file__)) +scripts_folder = Path(__file__).parent class FunctionParam: - def __init__(self, name: str, proto: str): + def __init__(self, name: str, proto: str) -> None: self.proto = proto self.name = name +class ReturnType(Enum): + VOID = 0 + OTHER = 1 + + class ConnectionMethod: - def __init__(self, name: str, params: List[FunctionParam], is_void: bool): + def __init__(self, name: str, params: list[FunctionParam], return_type: ReturnType) -> None: self.name = name self.params = params - self.is_void = is_void + self.return_type = return_type class Visitor: - def __init__(self, class_name: str): + def __init__(self, class_name: str) -> None: self.methods_dict = {} self.class_name = class_name - def __getattr__(self, name): + def __getattr__(self, name) -> Callable[[...], bool]: return lambda *state: True def on_class_start(self, state): @@ -36,8 +42,9 @@ def on_class_start(self, state): def on_class_method(self, state, node): name = node.name.format() - return_type = node.return_type - is_void = return_type and return_type.format() == "void" + return_type = ReturnType.VOID + if node.return_type and node.return_type.format() == "void": + return_type = ReturnType.OTHER params = [ FunctionParam( x.name, @@ -46,24 +53,27 @@ def on_class_method(self, state, node): for x in node.parameters ] - self.methods_dict[name] = ConnectionMethod(name, params, is_void) + self.methods_dict[name] = ConnectionMethod(name, params, return_type) -def get_methods(class_name: str) -> Dict[str, ConnectionMethod]: +def get_methods(class_name: str) -> dict[str, ConnectionMethod]: CLASSES = { - "DuckDBPyConnection": os.path.join( - scripts_folder, - "..", - "src", - "include", - "duckdb_python", - "pyconnection", - "pyconnection.hpp", - ), - "DuckDBPyRelation": os.path.join(scripts_folder, "..", "src", "include", "duckdb_python", "pyrelation.hpp"), + "DuckDBPyConnection": Path(scripts_folder) + / ".." + / "src" + / "duckdb_py" + / "include" + / "duckdb_python" + / "pyconnection" + / "pyconnection.hpp", + "DuckDBPyRelation": Path(scripts_folder) + / ".." 
+ / "src" + / "duckdb_py" + / "include" + / "duckdb_python" + / "pyrelation.hpp", } - # Create a dictionary to store method names and prototypes - methods_dict = {} path = CLASSES[class_name] diff --git a/scripts/imports.py b/scripts/imports.py index 6b035768..d7e38750 100644 --- a/scripts/imports.py +++ b/scripts/imports.py @@ -1,3 +1,4 @@ +# ruff: noqa: E402, B018 import pyarrow import pyarrow.dataset @@ -6,6 +7,22 @@ pyarrow.Table pyarrow.RecordBatchReader pyarrow.ipc.MessageReader +pyarrow.scalar +pyarrow.date32 +pyarrow.time64 +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.uint8 +pyarrow.uint16 +pyarrow.uint32 +pyarrow.uint64 +pyarrow.binary_view +pyarrow.decimal32 +pyarrow.decimal64 +pyarrow.decimal128 import pandas diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 73219e0d..48315109 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -1,16 +1,17 @@ import itertools import pathlib -import pytest import random import re import typing import warnings -import glob + +import pytest + from .skipped_tests import SKIPPED_TESTS SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" SQLLOGIC_TEST_PARAMETER = "test_script_path" -DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / 'external' / 'duckdb').resolve() +DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / "external" / "duckdb").resolve() def pytest_addoption(parser: pytest.Parser): @@ -65,8 +66,8 @@ def pytest_keyboard_interrupt(excinfo: pytest.ExceptionInfo): # Ensure all tests are properly cleaned up on keyboard interrupt from .test_sqllogic import test_sqllogic - if hasattr(test_sqllogic, 'executor') and test_sqllogic.executor: - if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, 'connection'): + if hasattr(test_sqllogic, "executor") and test_sqllogic.executor: + if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, "connection"): test_sqllogic.executor.database.connection.interrupt() test_sqllogic.executor.cleanup() test_sqllogic.executor = None @@ -90,7 +91,7 @@ def get_test_id(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Confi return str(path.relative_to(root_dir.parent)) -def get_test_marks(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Config) -> typing.List[typing.Any]: +def get_test_marks(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Config) -> list[typing.Any]: # Tests are tagged with the their category (i.e., name of their parent directory) category = path.parent.name @@ -126,11 +127,9 @@ def create_parameters_from_paths(paths, root_dir: pathlib.Path, config: pytest.C def scan_for_test_scripts(root_dir: pathlib.Path, config: pytest.Config) -> typing.Iterator[typing.Any]: - """ - Scans for .test files in the given directory and its subdirectories. + """Scans for .test files in the given directory and its subdirectories. Returns an iterator of pytest parameters (argument, id and marks). 
- """ - + """ # noqa: D205 # TODO: Add tests from extensions test_script_extensions = [".test", ".test_slow", ".test_coverage"] it = itertools.chain.from_iterable(root_dir.rglob(f"*{ext}") for ext in test_script_extensions) @@ -142,7 +141,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): if metafunc.definition.name != SQLLOGIC_TEST_CASE_NAME: return - test_dirs: typing.List[pathlib.Path] = metafunc.config.getoption("test_dirs") + test_dirs: list[pathlib.Path] = metafunc.config.getoption("test_dirs") test_glob: typing.Optional[pathlib.Path] = metafunc.config.getoption("path") parameters = [] @@ -165,14 +164,12 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): metafunc.parametrize(SQLLOGIC_TEST_PARAMETER, parameters) -def determine_test_offsets(config: pytest.Config, num_tests: int) -> typing.Tuple[int, int]: - """ - If start_offset and end_offset are specified, then these are used. +def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, int]: + """If start_offset and end_offset are specified, then these are used. start_offset defaults to 0. end_offset defaults to and is capped to the last test index. start_offset_percentage and end_offset_percentage are used to calculate the start and end offsets based on the total number of tests. This is done in a way that a test run to 25% and another test run starting at 25% do not overlap by excluding the 25th percent test. - """ - + """ # noqa: D205 start_offset = config.getoption("start_offset") end_offset = config.getoption("end_offset") start_offset_percentage = config.getoption("start_offset_percentage") @@ -182,16 +179,20 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> typing.Tupl percentage_specified = start_offset_percentage is not None or end_offset_percentage is not None if index_specified and percentage_specified: - raise ValueError("You can only specify either start/end offsets or start/end offset percentages, not both") + msg = "You can only specify either start/end offsets or start/end offset percentages, not both" + raise ValueError(msg) if start_offset is not None and start_offset < 0: - raise ValueError("--start-offset must be a non-negative integer") + msg = "--start-offset must be a non-negative integer" + raise ValueError(msg) if start_offset_percentage is not None and (start_offset_percentage < 0 or start_offset_percentage > 100): - raise ValueError("--start-offset-percentage must be between 0 and 100") + msg = "--start-offset-percentage must be between 0 and 100" + raise ValueError(msg) if end_offset_percentage is not None and (end_offset_percentage < 0 or end_offset_percentage > 100): - raise ValueError("--end-offset-percentage must be between 0 and 100") + msg = "--end-offset-percentage must be between 0 and 100" + raise ValueError(msg) if start_offset is None: if start_offset_percentage is not None: @@ -200,9 +201,8 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> typing.Tupl start_offset = 0 if end_offset is not None and end_offset < start_offset: - raise ValueError( - f"--end-offset ({end_offset}) must be greater than or equal to the start offset ({start_offset})" - ) + msg = f"--end-offset ({end_offset}) must be greater than or equal to the start offset ({start_offset})" + raise ValueError(msg) if end_offset is None: if end_offset_percentage is not None: @@ -271,9 +271,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config def pytest_runtest_setup(item: pytest.Item): - """ - Show the test index after the 
test name - """ + """Show the test index after the test name.""" def get_from_tuple_list(tuples, key): for t in tuples: diff --git a/sqllogic/skipped_tests.py b/sqllogic/skipped_tests.py index 39269c42..485ed9b9 100644 --- a/sqllogic/skipped_tests.py +++ b/sqllogic/skipped_tests.py @@ -1,42 +1,42 @@ SKIPPED_TESTS = set( [ - 'test/sql/timezone/disable_timestamptz_casts.test', # <-- ICU extension is always loaded - 'test/sql/copy/return_stats_truncate.test', # <-- handling was changed - 'test/sql/copy/return_stats.test', # <-- handling was changed - 'test/sql/copy/parquet/writer/skip_empty_write.test', # <-- handling was changed - 'test/sql/types/map/map_empty.test', - 'test/extension/wrong_function_type.test', # <-- JSON is always loaded - 'test/sql/insert/test_insert_invalid.test', # <-- doesn't parse properly - 'test/sql/cast/cast_error_location.test', # <-- python exception doesn't contain error location yet - 'test/sql/pragma/test_query_log.test', # <-- query_log gets filled with NULL when con.query(...) is used - 'test/sql/json/table/read_json_objects.test', # <-- Python client is always loaded with JSON available - 'test/sql/copy/csv/zstd_crash.test', # <-- Python client is always loaded with Parquet available - 'test/sql/error/extension_function_error.test', # <-- Python client is always loaded with TPCH available - 'test/optimizer/joins/tpcds_nofail.test', # <-- Python client is always loaded with TPCDS available - 'test/sql/settings/errors_as_json.test', # <-- errors_as_json not currently supported in Python - 'test/sql/parallelism/intraquery/depth_first_evaluation_union_and_join.test', # <-- Python client is always loaded with TPCDS available - 'test/sql/types/timestamp/test_timestamp_tz.test', # <-- Python client is always loaded wih ICU available - making the TIMESTAMPTZ::DATE cast pass - 'test/sql/parser/invisible_spaces.test', # <-- Parser is getting tripped up on the invisible spaces - 'test/sql/copy/csv/code_cov/csv_state_machine_invalid_utf.test', # <-- ConversionException is empty, see Python Mega Issue (duckdb-internal #1488) - 'test/sql/copy/csv/test_csv_timestamp_tz.test', # <-- ICU is always loaded - 'test/fuzzer/duckfuzz/duck_fuzz_column_binding_tests.test', # <-- ICU is always loaded - 'test/sql/pragma/test_custom_optimizer_profiling.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement - 'test/sql/pragma/test_custom_profiling_settings.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement - 'test/sql/copy/csv/test_copy.test', # JSON is always loaded - 'test/sql/copy/csv/test_timestamptz_12926.test', # ICU is always loaded - 'test/fuzzer/pedro/in_clause_optimization_error.test', # error message differs due to a different execution path - 'test/sql/order/test_limit_parameter.test', # error message differs due to a different execution path - 'test/sql/catalog/test_set_search_path.test', # current_query() is not the same - 'test/sql/catalog/table/create_table_parameters.test', # prepared statement error quirks - 'test/sql/pragma/profiling/test_custom_profiling_rows_scanned.test', # we perform additional queries that mess with the expected metrics - 'test/sql/pragma/profiling/test_custom_profiling_disable_metrics.test', # we perform additional queries that mess with the expected metrics - 'test/sql/pragma/profiling/test_custom_profiling_result_set_size.test', # we perform additional queries that mess with the expected metrics - 
'test/sql/pragma/profiling/test_custom_profiling_result_set_size.test', # we perform additional queries that mess with the expected metrics - 'test/sql/cte/materialized/materialized_cte_modifiers.test', # problems connected to auto installing tpcds from remote - 'test/sql/tpcds/dsdgen_readonly.test', # problems connected to auto installing tpcds from remote - 'test/sql/tpcds/tpcds_sf0.test', # problems connected to auto installing tpcds from remote - 'test/sql/optimizer/plan/test_filter_pushdown_materialized_cte.test', # problems connected to auto installing tpcds from remote - 'test/sql/explain/test_explain_analyze.test', # unknown problem with changes in API - 'test/sql/pragma/profiling/test_profiling_all.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/timezone/disable_timestamptz_casts.test", # <-- ICU extension is always loaded + "test/sql/copy/return_stats_truncate.test", # <-- handling was changed + "test/sql/copy/return_stats.test", # <-- handling was changed + "test/sql/copy/parquet/writer/skip_empty_write.test", # <-- handling was changed + "test/sql/types/map/map_empty.test", + "test/extension/wrong_function_type.test", # <-- JSON is always loaded + "test/sql/insert/test_insert_invalid.test", # <-- doesn't parse properly + "test/sql/cast/cast_error_location.test", # <-- python exception doesn't contain error location yet + "test/sql/pragma/test_query_log.test", # <-- query_log gets filled with NULL when con.query(...) is used + "test/sql/json/table/read_json_objects.test", # <-- Python client is always loaded with JSON available + "test/sql/copy/csv/zstd_crash.test", # <-- Python client is always loaded with Parquet available + "test/sql/error/extension_function_error.test", # <-- Python client is always loaded with TPCH available + "test/optimizer/joins/tpcds_nofail.test", # <-- Python client is always loaded with TPCDS available + "test/sql/settings/errors_as_json.test", # <-- errors_as_json not currently supported in Python + "test/sql/parallelism/intraquery/depth_first_evaluation_union_and_join.test", # <-- Python client is always loaded with TPCDS available + "test/sql/types/timestamp/test_timestamp_tz.test", # <-- Python client is always loaded wih ICU available - making the TIMESTAMPTZ::DATE cast pass + "test/sql/parser/invisible_spaces.test", # <-- Parser is getting tripped up on the invisible spaces + "test/sql/copy/csv/code_cov/csv_state_machine_invalid_utf.test", # <-- ConversionException is empty, see Python Mega Issue (duckdb-internal #1488) + "test/sql/copy/csv/test_csv_timestamp_tz.test", # <-- ICU is always loaded + "test/fuzzer/duckfuzz/duck_fuzz_column_binding_tests.test", # <-- ICU is always loaded + "test/sql/pragma/test_custom_optimizer_profiling.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/pragma/test_custom_profiling_settings.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/copy/csv/test_copy.test", # JSON is always loaded + "test/sql/copy/csv/test_timestamptz_12926.test", # ICU is always loaded + "test/fuzzer/pedro/in_clause_optimization_error.test", # error message differs due to a different execution path + "test/sql/order/test_limit_parameter.test", # error message differs due to a different execution path + "test/sql/catalog/test_set_search_path.test", # current_query() is not the same + 
"test/sql/catalog/table/create_table_parameters.test", # prepared statement error quirks + "test/sql/pragma/profiling/test_custom_profiling_rows_scanned.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_disable_metrics.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_result_set_size.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_result_set_size.test", # we perform additional queries that mess with the expected metrics + "test/sql/cte/materialized/materialized_cte_modifiers.test", # problems connected to auto installing tpcds from remote + "test/sql/tpcds/dsdgen_readonly.test", # problems connected to auto installing tpcds from remote + "test/sql/tpcds/tpcds_sf0.test", # problems connected to auto installing tpcds from remote + "test/sql/optimizer/plan/test_filter_pushdown_materialized_cte.test", # problems connected to auto installing tpcds from remote + "test/sql/explain/test_explain_analyze.test", # unknown problem with changes in API + "test/sql/pragma/profiling/test_profiling_all.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement ] ) diff --git a/sqllogic/test_sqllogic.py b/sqllogic/test_sqllogic.py index ee7426cd..35736015 100644 --- a/sqllogic/test_sqllogic.py +++ b/sqllogic/test_sqllogic.py @@ -1,31 +1,32 @@ import gc import os import pathlib -import pytest import signal import sys -from typing import Any, Generator, Optional +from collections.abc import Generator +from typing import Any, Optional + +import pytest -sys.path.append(str(pathlib.Path(__file__).parent.parent / 'external' / 'duckdb' / 'scripts')) +sys.path.append(str(pathlib.Path(__file__).parent.parent / "external" / "duckdb" / "scripts")) from sqllogictest import ( - SQLParserException, SQLLogicParser, SQLLogicTest, + SQLParserException, ) - from sqllogictest.result import ( - TestException, - SQLLogicRunner, - SQLLogicDatabase, - SQLLogicContext, ExecuteResult, + SQLLogicContext, + SQLLogicDatabase, + SQLLogicRunner, + TestException, ) def sigquit_handler(signum, frame): # Access the executor from the test_sqllogic function - if hasattr(test_sqllogic, 'executor') and test_sqllogic.executor: - if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, 'connection'): + if hasattr(test_sqllogic, "executor") and test_sqllogic.executor: + if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, "connection"): test_sqllogic.executor.database.connection.interrupt() test_sqllogic.executor.cleanup() test_sqllogic.executor = None @@ -39,7 +40,7 @@ def sigquit_handler(signum, frame): class SQLLogicTestExecutor(SQLLogicRunner): - def __init__(self, test_directory: str, build_directory: Optional[str] = None): + def __init__(self, test_directory: str, build_directory: Optional[str] = None) -> None: super().__init__(build_directory) self.test_directory = test_directory # TODO: get this from the `duckdb` package @@ -85,13 +86,13 @@ def execute_test(self, test: SQLLogicTest) -> ExecuteResult: self.original_sqlite_test = self.test.is_sqlite_test() # Top level keywords - keywords = {'__TEST_DIR__': self.get_test_directory(), '__WORKING_DIRECTORY__': os.getcwd()} + keywords = {"__TEST_DIR__": self.get_test_directory(), "__WORKING_DIRECTORY__": os.getcwd()} def update_value(_: SQLLogicContext) -> 
Generator[Any, Any, Any]: # Yield once to represent one iteration, do not touch the keywords yield None - self.database = SQLLogicDatabase(':memory:', None) + self.database = SQLLogicDatabase(":memory:", None) pool = self.database.connect() context = SQLLogicContext(pool, self, test.statements, keywords, update_value) pool.initialize_connection(context, pool.get_connection()) @@ -126,7 +127,7 @@ def update_value(_: SQLLogicContext) -> Generator[Any, Any, Any]: def cleanup(self): if self.database: - if hasattr(self.database, 'connection'): + if hasattr(self.database, "connection"): self.database.connection.interrupt() self.database.reset() self.database = None @@ -160,6 +161,6 @@ def test_sqllogic(test_script_path: pathlib.Path, pytestconfig: pytest.Config, t test_sqllogic.executor = None -if __name__ == '__main__': +if __name__ == "__main__": # Pass all arguments including the script name to pytest sys.exit(pytest.main(sys.argv)) diff --git a/src/duckdb_py/CMakeLists.txt b/src/duckdb_py/CMakeLists.txt index 2252ba29..3d06b062 100644 --- a/src/duckdb_py/CMakeLists.txt +++ b/src/duckdb_py/CMakeLists.txt @@ -12,22 +12,22 @@ add_subdirectory(common) add_subdirectory(pandas) add_subdirectory(arrow) -add_library(python_src OBJECT - dataframe.cpp - duckdb_python.cpp - importer.cpp - map.cpp - path_like.cpp - pyconnection.cpp - pyexpression.cpp - pyfilesystem.cpp - pyrelation.cpp - pyresult.cpp - pystatement.cpp - python_dependency.cpp - python_import_cache.cpp - python_replacement_scan.cpp - python_udf.cpp -) +add_library( + python_src OBJECT + dataframe.cpp + duckdb_python.cpp + importer.cpp + map.cpp + path_like.cpp + pyconnection.cpp + pyexpression.cpp + pyfilesystem.cpp + pyrelation.cpp + pyresult.cpp + pystatement.cpp + python_dependency.cpp + python_import_cache.cpp + python_replacement_scan.cpp + python_udf.cpp) target_link_libraries(python_src PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/arrow/CMakeLists.txt b/src/duckdb_py/arrow/CMakeLists.txt index 29b188c6..9a9188b8 100644 --- a/src/duckdb_py/arrow/CMakeLists.txt +++ b/src/duckdb_py/arrow/CMakeLists.txt @@ -1,4 +1,5 @@ # this is used for clang-tidy checks -add_library(python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp) +add_library(python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp + pyarrow_filter_pushdown.cpp) target_link_libraries(python_arrow PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index 533c31ed..f9cfd1bb 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -1,22 +1,15 @@ #include "duckdb_python/arrow/arrow_array_stream.hpp" +#include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/planner/filter/in_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyrelation.hpp" +#include "duckdb_python/pyresult.hpp" +#include "duckdb/function/table/arrow.hpp" #include "duckdb/common/assert.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/limits.hpp" #include "duckdb/main/client_config.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" -#include "duckdb/planner/table_filter.hpp" - -#include "duckdb_python/pyconnection/pyconnection.hpp" -#include 
"duckdb_python/pyrelation.hpp" -#include "duckdb_python/pyresult.hpp" -#include "duckdb/function/table/arrow.hpp" namespace duckdb { @@ -56,8 +49,8 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(DBConfig &config, } if (has_filter) { - auto filter = TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col, - client_properties, arrow_table); + auto filter = PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map, + filter_to_col, client_properties, arrow_table); if (!filter.is(py::none())) { kwargs["filter"] = filter; } @@ -171,323 +164,4 @@ void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowS GetSchemaInternal(arrow_obj_handle, schema); } -string ConvertTimestampUnit(ArrowDateTimeType unit) { - switch (unit) { - case ArrowDateTimeType::MICROSECONDS: - return "us"; - case ArrowDateTimeType::MILLISECONDS: - return "ms"; - case ArrowDateTimeType::NANOSECONDS: - return "ns"; - case ArrowDateTimeType::SECONDS: - return "s"; - default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); - } -} - -int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_type) { - auto input = timestamp_t(base_value); - if (!Timestamp::IsFinite(input)) { - return base_value; - } - - switch (datetime_type) { - case ArrowDateTimeType::MICROSECONDS: - return Timestamp::GetEpochMicroSeconds(input); - case ArrowDateTimeType::MILLISECONDS: - return Timestamp::GetEpochMs(input); - case ArrowDateTimeType::NANOSECONDS: - return Timestamp::GetEpochNanoSeconds(input); - case ArrowDateTimeType::SECONDS: - return Timestamp::GetEpochSeconds(input); - default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampTZValue"); - } -} - -py::object GetScalar(Value &constant, const string &timezone_config, const ArrowType &type) { - py::object scalar = py::module_::import("pyarrow").attr("scalar"); - auto &import_cache = *DuckDBPyConnection::ImportCache(); - py::object dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); - py::object scalar_value; - switch (constant.type().id()) { - case LogicalTypeId::BOOLEAN: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::TINYINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::SMALLINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::INTEGER: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::BIGINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::DATE: { - py::object date_type = py::module_::import("pyarrow").attr("date32"); - return dataset_scalar(scalar(constant.GetValue(), date_type())); - } - case LogicalTypeId::TIME: { - py::object date_type = py::module_::import("pyarrow").attr("time64"); - return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); - } - case LogicalTypeId::TIMESTAMP: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); - } - case LogicalTypeId::TIMESTAMP_MS: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("ms"))); - } - case LogicalTypeId::TIMESTAMP_NS: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("ns"))); - } - case LogicalTypeId::TIMESTAMP_SEC: { - py::object date_type = 
py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("s"))); - } - case LogicalTypeId::TIMESTAMP_TZ: { - auto &datetime_info = type.GetTypeInfo(); - auto base_value = constant.GetValue(); - auto arrow_datetime_type = datetime_info.GetDateTimeType(); - auto time_unit_string = ConvertTimestampUnit(arrow_datetime_type); - auto converted_value = ConvertTimestampTZValue(base_value, arrow_datetime_type); - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(converted_value, date_type(time_unit_string, py::arg("tz") = timezone_config))); - } - case LogicalTypeId::UTINYINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint8"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::USMALLINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint16"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::UINTEGER: { - py::object integer_type = py::module_::import("pyarrow").attr("uint32"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::UBIGINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint64"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::FLOAT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::DOUBLE: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::VARCHAR: - return dataset_scalar(constant.ToString()); - case LogicalTypeId::BLOB: { - if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { - py::object binary_view_type = py::module_::import("pyarrow").attr("binary_view"); - return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); - } - return dataset_scalar(py::bytes(constant.GetValueUnsafe())); - } - case LogicalTypeId::DECIMAL: { - py::object decimal_type; - auto &datetime_info = type.GetTypeInfo(); - auto bit_width = datetime_info.GetBitWidth(); - switch (bit_width) { - case DecimalBitWidth::DECIMAL_32: - decimal_type = py::module_::import("pyarrow").attr("decimal32"); - break; - case DecimalBitWidth::DECIMAL_64: - decimal_type = py::module_::import("pyarrow").attr("decimal64"); - break; - case DecimalBitWidth::DECIMAL_128: - decimal_type = py::module_::import("pyarrow").attr("decimal128"); - break; - default: - throw NotImplementedException("Unsupported precision for Arrow Decimal Type."); - } - - uint8_t width; - uint8_t scale; - constant.type().GetDecimalProperties(width, scale); - // pyarrow only allows 'decimal.Decimal' to be used to construct decimal scalars such as 0.05 - auto val = import_cache.decimal.Decimal()(constant.ToString()); - return dataset_scalar( - scalar(std::move(val), decimal_type(py::arg("precision") = width, py::arg("scale") = scale))); - } - default: - throw NotImplementedException("Unimplemented type \"%s\" for Arrow Filter Pushdown", - constant.type().ToString()); - } -} - -py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, - const ArrowType &type) { - auto &import_cache = *DuckDBPyConnection::ImportCache(); - py::object field = import_cache.pyarrow.dataset().attr("field"); - switch (filter.filter_type) { - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = filter.Cast(); - auto constant_field = field(py::tuple(py::cast(column_ref))); - auto constant_value = 
GetScalar(constant_filter.constant, timezone_config, type); - - bool is_nan = false; - auto &constant = constant_filter.constant; - auto &constant_type = constant.type(); - if (constant_type.id() == LogicalTypeId::FLOAT) { - is_nan = Value::IsNan(constant.GetValue()); - } else if (constant_type.id() == LogicalTypeId::DOUBLE) { - is_nan = Value::IsNan(constant.GetValue()); - } - - // Special handling for NaN comparisons (to explicitly violate IEEE-754) - if (is_nan) { - switch (constant_filter.comparison_type) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return constant_field.attr("is_nan")(); - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_NOTEQUAL: - return constant_field.attr("is_nan")().attr("__invert__")(); - case ExpressionType::COMPARE_GREATERTHAN: - // Nothing is greater than NaN - return import_cache.pyarrow.dataset().attr("scalar")(false); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // Everything is less than or equal to NaN - return import_cache.pyarrow.dataset().attr("scalar")(true); - default: - throw NotImplementedException("Unsupported comparison type (%s) for NaN values", - EnumUtil::ToString(constant_filter.comparison_type)); - } - } - - switch (constant_filter.comparison_type) { - case ExpressionType::COMPARE_EQUAL: - return constant_field.attr("__eq__")(constant_value); - case ExpressionType::COMPARE_LESSTHAN: - return constant_field.attr("__lt__")(constant_value); - case ExpressionType::COMPARE_GREATERTHAN: - return constant_field.attr("__gt__")(constant_value); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return constant_field.attr("__le__")(constant_value); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return constant_field.attr("__ge__")(constant_value); - case ExpressionType::COMPARE_NOTEQUAL: - return constant_field.attr("__ne__")(constant_value); - default: - throw NotImplementedException("Comparison Type %s can't be an Arrow Scan Pushdown Filter", - EnumUtil::ToString(constant_filter.comparison_type)); - } - } - //! We do not pushdown is null yet - case TableFilterType::IS_NULL: { - auto constant_field = field(py::tuple(py::cast(column_ref))); - return constant_field.attr("is_null")(); - } - case TableFilterType::IS_NOT_NULL: { - auto constant_field = field(py::tuple(py::cast(column_ref))); - return constant_field.attr("is_valid")(); - } - //! 
We do not pushdown or conjunctions yet - case TableFilterType::CONJUNCTION_OR: { - auto &or_filter = filter.Cast(); - py::object expression = py::none(); - for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { - auto &child_filter = *or_filter.child_filters[i]; - py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); - if (child_expression.is(py::none())) { - continue; - } - if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__or__")(child_expression); - } - } - return expression; - } - case TableFilterType::CONJUNCTION_AND: { - auto &and_filter = filter.Cast(); - py::object expression = py::none(); - for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { - auto &child_filter = *and_filter.child_filters[i]; - py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); - if (child_expression.is(py::none())) { - continue; - } - if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__and__")(child_expression); - } - } - return expression; - } - case TableFilterType::STRUCT_EXTRACT: { - auto &struct_filter = filter.Cast(); - auto &child_name = struct_filter.child_name; - auto &struct_type_info = type.GetTypeInfo(); - auto &struct_child_type = struct_type_info.GetChild(struct_filter.child_idx); - - column_ref.push_back(child_name); - auto child_expr = TransformFilterRecursive(*struct_filter.child_filter, std::move(column_ref), timezone_config, - struct_child_type); - return child_expr; - } - case TableFilterType::OPTIONAL_FILTER: { - auto &optional_filter = filter.Cast(); - if (!optional_filter.child_filter) { - return py::none(); - } - return TransformFilterRecursive(*optional_filter.child_filter, column_ref, timezone_config, type); - } - case TableFilterType::IN_FILTER: { - auto &in_filter = filter.Cast(); - ConjunctionOrFilter or_filter; - value_set_t unique_values; - for (const auto &value : in_filter.values) { - if (unique_values.find(value) == unique_values.end()) { - unique_values.insert(value); - } - } - for (const auto &value : unique_values) { - or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); - } - return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); - } - case TableFilterType::DYNAMIC_FILTER: { - //! 
Ignore dynamic filters for now, not necessary for correctness - return py::none(); - } - default: - throw NotImplementedException("Pushdown Filter Type %s is not currently supported in PyArrow Scans", - EnumUtil::ToString(filter.filter_type)); - } -} - -py::object PythonTableArrowArrayStreamFactory::TransformFilter(TableFilterSet &filter_collection, - std::unordered_map &columns, - unordered_map filter_to_col, - const ClientProperties &config, - const ArrowTableSchema &arrow_table) { - auto &filters_map = filter_collection.filters; - - py::object expression = py::none(); - for (auto &it : filters_map) { - auto column_idx = it.first; - auto &column_name = columns[column_idx]; - - vector column_ref; - column_ref.push_back(column_name); - - D_ASSERT(columns.find(column_idx) != columns.end()); - - auto &arrow_type = arrow_table.GetColumns().at(filter_to_col.at(column_idx)); - py::object child_expression = TransformFilterRecursive(*it.second, column_ref, config.time_zone, *arrow_type); - if (child_expression.is(py::none())) { - continue; - } else if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__and__")(child_expression); - } - } - return expression; -} - } // namespace duckdb diff --git a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp new file mode 100644 index 00000000..66a6e3fa --- /dev/null +++ b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp @@ -0,0 +1,336 @@ +#include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" + +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/planner/filter/in_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/table_filter.hpp" + +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyrelation.hpp" +#include "duckdb_python/pyresult.hpp" +#include "duckdb/function/table/arrow.hpp" + +namespace duckdb { + +string ConvertTimestampUnit(ArrowDateTimeType unit) { + switch (unit) { + case ArrowDateTimeType::MICROSECONDS: + return "us"; + case ArrowDateTimeType::MILLISECONDS: + return "ms"; + case ArrowDateTimeType::NANOSECONDS: + return "ns"; + case ArrowDateTimeType::SECONDS: + return "s"; + default: + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); + } +} + +int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_type) { + auto input = timestamp_t(base_value); + if (!Timestamp::IsFinite(input)) { + return base_value; + } + + switch (datetime_type) { + case ArrowDateTimeType::MICROSECONDS: + return Timestamp::GetEpochMicroSeconds(input); + case ArrowDateTimeType::MILLISECONDS: + return Timestamp::GetEpochMs(input); + case ArrowDateTimeType::NANOSECONDS: + return Timestamp::GetEpochNanoSeconds(input); + case ArrowDateTimeType::SECONDS: + return Timestamp::GetEpochSeconds(input); + default: + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampTZValue"); + } +} + +py::object GetScalar(Value &constant, const string &timezone_config, const ArrowType &type) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + auto scalar = import_cache.pyarrow.scalar(); + py::handle dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); + + switch (constant.type().id()) { + case LogicalTypeId::BOOLEAN: + 
return dataset_scalar(constant.GetValue()); + case LogicalTypeId::TINYINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::SMALLINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::INTEGER: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::BIGINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::DATE: { + py::handle date_type = import_cache.pyarrow.date32(); + return dataset_scalar(scalar(constant.GetValue(), date_type())); + } + case LogicalTypeId::TIME: { + py::handle date_type = import_cache.pyarrow.time64(); + return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); + } + case LogicalTypeId::TIMESTAMP: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); + } + case LogicalTypeId::TIMESTAMP_MS: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("ms"))); + } + case LogicalTypeId::TIMESTAMP_NS: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("ns"))); + } + case LogicalTypeId::TIMESTAMP_SEC: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("s"))); + } + case LogicalTypeId::TIMESTAMP_TZ: { + auto &datetime_info = type.GetTypeInfo(); + auto base_value = constant.GetValue(); + auto arrow_datetime_type = datetime_info.GetDateTimeType(); + auto time_unit_string = ConvertTimestampUnit(arrow_datetime_type); + auto converted_value = ConvertTimestampTZValue(base_value, arrow_datetime_type); + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(converted_value, date_type(time_unit_string, py::arg("tz") = timezone_config))); + } + case LogicalTypeId::UTINYINT: { + py::handle integer_type = import_cache.pyarrow.uint8(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::USMALLINT: { + py::handle integer_type = import_cache.pyarrow.uint16(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::UINTEGER: { + py::handle integer_type = import_cache.pyarrow.uint32(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::UBIGINT: { + py::handle integer_type = import_cache.pyarrow.uint64(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::FLOAT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::DOUBLE: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::VARCHAR: + return dataset_scalar(constant.ToString()); + case LogicalTypeId::BLOB: { + if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { + py::handle binary_view_type = import_cache.pyarrow.binary_view(); + return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); + } + return dataset_scalar(py::bytes(constant.GetValueUnsafe())); + } + case LogicalTypeId::DECIMAL: { + py::handle decimal_type; + auto &datetime_info = type.GetTypeInfo(); + auto bit_width = datetime_info.GetBitWidth(); + switch (bit_width) { + case DecimalBitWidth::DECIMAL_32: + decimal_type = import_cache.pyarrow.decimal32(); + break; + case DecimalBitWidth::DECIMAL_64: + decimal_type = import_cache.pyarrow.decimal64(); + break; + case DecimalBitWidth::DECIMAL_128: + decimal_type = import_cache.pyarrow.decimal128(); + 
break; + default: + throw NotImplementedException("Unsupported precision for Arrow Decimal Type."); + } + + uint8_t width; + uint8_t scale; + constant.type().GetDecimalProperties(width, scale); + // pyarrow only allows 'decimal.Decimal' to be used to construct decimal scalars such as 0.05 + auto val = import_cache.decimal.Decimal()(constant.ToString()); + return dataset_scalar( + scalar(std::move(val), decimal_type(py::arg("precision") = width, py::arg("scale") = scale))); + } + default: + throw NotImplementedException("Unimplemented type \"%s\" for Arrow Filter Pushdown", + constant.type().ToString()); + } +} + +py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, + const ArrowType &type) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + py::object field = import_cache.pyarrow.dataset().attr("field"); + switch (filter.filter_type) { + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + auto constant_field = field(py::tuple(py::cast(column_ref))); + auto constant_value = GetScalar(constant_filter.constant, timezone_config, type); + + bool is_nan = false; + auto &constant = constant_filter.constant; + auto &constant_type = constant.type(); + if (constant_type.id() == LogicalTypeId::FLOAT) { + is_nan = Value::IsNan(constant.GetValue()); + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { + is_nan = Value::IsNan(constant.GetValue()); + } + + // Special handling for NaN comparisons (to explicitly violate IEEE-754) + if (is_nan) { + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("is_nan")(); + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("is_nan")().attr("__invert__")(); + case ExpressionType::COMPARE_GREATERTHAN: + // Nothing is greater than NaN + return import_cache.pyarrow.dataset().attr("scalar")(false); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // Everything is less than or equal to NaN + return import_cache.pyarrow.dataset().attr("scalar")(true); + default: + throw NotImplementedException("Unsupported comparison type (%s) for NaN values", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } + + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + return constant_field.attr("__eq__")(constant_value); + case ExpressionType::COMPARE_LESSTHAN: + return constant_field.attr("__lt__")(constant_value); + case ExpressionType::COMPARE_GREATERTHAN: + return constant_field.attr("__gt__")(constant_value); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return constant_field.attr("__le__")(constant_value); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("__ge__")(constant_value); + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("__ne__")(constant_value); + default: + throw NotImplementedException("Comparison Type %s can't be an Arrow Scan Pushdown Filter", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } + //! We do not pushdown is null yet + case TableFilterType::IS_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_null")(); + } + case TableFilterType::IS_NOT_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_valid")(); + } + //! 
We do not pushdown or conjunctions yet + case TableFilterType::CONJUNCTION_OR: { + auto &or_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { + auto &child_filter = *or_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__or__")(child_expression); + } + } + return expression; + } + case TableFilterType::CONJUNCTION_AND: { + auto &and_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { + auto &child_filter = *and_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; + } + case TableFilterType::STRUCT_EXTRACT: { + auto &struct_filter = filter.Cast(); + auto &child_name = struct_filter.child_name; + auto &struct_type_info = type.GetTypeInfo(); + auto &struct_child_type = struct_type_info.GetChild(struct_filter.child_idx); + + column_ref.push_back(child_name); + auto child_expr = TransformFilterRecursive(*struct_filter.child_filter, std::move(column_ref), timezone_config, + struct_child_type); + return child_expr; + } + case TableFilterType::OPTIONAL_FILTER: { + auto &optional_filter = filter.Cast(); + if (!optional_filter.child_filter) { + return py::none(); + } + return TransformFilterRecursive(*optional_filter.child_filter, column_ref, timezone_config, type); + } + case TableFilterType::IN_FILTER: { + auto &in_filter = filter.Cast(); + ConjunctionOrFilter or_filter; + value_set_t unique_values; + for (const auto &value : in_filter.values) { + if (unique_values.find(value) == unique_values.end()) { + unique_values.insert(value); + } + } + for (const auto &value : unique_values) { + or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); + } + return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); + } + case TableFilterType::DYNAMIC_FILTER: { + //! 
Ignore dynamic filters for now, not necessary for correctness + return py::none(); + } + default: + throw NotImplementedException("Pushdown Filter Type %s is not currently supported in PyArrow Scans", + EnumUtil::ToString(filter.filter_type)); + } +} + +py::object PyArrowFilterPushdown::TransformFilter(TableFilterSet &filter_collection, + unordered_map &columns, + unordered_map filter_to_col, + const ClientProperties &config, const ArrowTableSchema &arrow_table) { + auto &filters_map = filter_collection.filters; + + py::object expression = py::none(); + for (auto &it : filters_map) { + auto column_idx = it.first; + auto &column_name = columns[column_idx]; + + vector column_ref; + column_ref.push_back(column_name); + + D_ASSERT(columns.find(column_idx) != columns.end()); + + auto &arrow_type = arrow_table.GetColumns().at(filter_to_col.at(column_idx)); + py::object child_expression = TransformFilterRecursive(*it.second, column_ref, config.time_zone, *arrow_type); + if (child_expression.is(py::none())) { + continue; + } else if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; +} + +} // namespace duckdb diff --git a/src/duckdb_py/common/exceptions.cpp b/src/duckdb_py/common/exceptions.cpp index 05ee73d7..51de2bdf 100644 --- a/src/duckdb_py/common/exceptions.cpp +++ b/src/duckdb_py/common/exceptions.cpp @@ -310,6 +310,12 @@ void PyThrowException(ErrorData &error, PyObject *http_exception) { } } +static void UnsetPythonException() { + if (PyErr_Occurred()) { + PyErr_Clear(); + } +} + /** * @see https://peps.python.org/pep-0249/#exceptions */ @@ -381,6 +387,7 @@ void RegisterExceptions(const py::module &m) { } } catch (const duckdb::Exception &ex) { duckdb::ErrorData error(ex); + UnsetPythonException(); PyThrowException(error, HTTP_EXCEPTION.ptr()); } catch (const py::builtin_exception &ex) { // These represent Python exceptions, we don't want to catch these @@ -391,6 +398,7 @@ void RegisterExceptions(const py::module &m) { // we need to pass non-DuckDB exceptions through as-is throw; } + UnsetPythonException(); PyThrowException(error, HTTP_EXCEPTION.ptr()); } }); diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp index 939fa41a..1dd3ba17 100644 --- a/src/duckdb_py/duckdb_python.cpp +++ b/src/duckdb_py/duckdb_python.cpp @@ -20,6 +20,7 @@ #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" #include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/adbc/adbc-init.hpp" #include "duckdb.hpp" @@ -1007,7 +1008,35 @@ static void RegisterExpectedResultType(py::handle &m) { expected_return_type.export_values(); } +// ###################################################################### +// Symbol exports +// +// We want to limit the symbols we export to only the absolute minimum. +// This means we compile with -fvisibility=hidden to hide all symbols, +// and then explicitly export only the symbols we want. +// +// Right now we export two symbols only: +// - duckdb_adbc_init: the entrypoint for our ADBC driver +// - PyInit__duckdb: the entrypoint for the python extension +// +// All symbols that need exporting must be added to both the list below +// AND to CMakeLists.txt. 
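+//
+// Illustrative sketch only (not the exact configuration; the authoritative
+// linker flags live in CMakeLists.txt): on Linux this kind of restriction is
+// typically expressed with a linker version script such as
+//   { global: PyInit__duckdb; duckdb_adbc_init; local: *; };
+// passed via -Wl,--version-script=<file>, while the macOS equivalent is
+//   -Wl,-exported_symbol,_PyInit__duckdb -Wl,-exported_symbol,_duckdb_adbc_init
+// (note the leading underscore in Mach-O symbol names).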
+extern "C" { +PYBIND11_EXPORT void *_force_symbol_inclusion() { + static void *symbols[] = { + // Add functions to export here + (void *)&duckdb_adbc_init, + }; + return symbols; +} +}; + PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT + // DO NOT REMOVE: the below forces that we include all symbols we want to export + volatile auto *keep_alive = _force_symbol_inclusion(); + (void)keep_alive; + // END + py::enum_(m, "ExplainType") .value("STANDARD", duckdb::ExplainType::EXPLAIN_STANDARD) .value("ANALYZE", duckdb::ExplainType::EXPLAIN_ANALYZE) diff --git a/src/duckdb_py/functional/functional.cpp b/src/duckdb_py/functional/functional.cpp index 6761a264..252634b1 100644 --- a/src/duckdb_py/functional/functional.cpp +++ b/src/duckdb_py/functional/functional.cpp @@ -3,8 +3,7 @@ namespace duckdb { void DuckDBPyFunctional::Initialize(py::module_ &parent) { - auto m = - parent.def_submodule("functional", "This module contains classes and methods related to functions and udf"); + auto m = parent.def_submodule("_func", "This module contains classes and methods related to functions and udf"); py::enum_(m, "PythonUDFType") .value("NATIVE", duckdb::PythonUDFType::NATIVE) diff --git a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp index 7eb6d20b..a5895b4a 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp @@ -86,11 +86,6 @@ class PythonTableArrowArrayStreamFactory { DBConfig &config; private: - //! We transform a TableFilterSet to an Arrow Expression Object - static py::object TransformFilter(TableFilterSet &filters, std::unordered_map &columns, - unordered_map filter_to_col, - const ClientProperties &client_properties, const ArrowTableSchema &arrow_table); - static py::object ProduceScanner(DBConfig &config, py::object &arrow_scanner, py::handle &arrow_obj_handle, ArrowStreamParameters ¶meters, const ClientProperties &client_properties); }; diff --git a/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp new file mode 100644 index 00000000..4cc85a47 --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/arrow/pyarrow_filter_pushdown.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/main/client_properties.hpp" +#include "duckdb_python/pybind11/pybind_wrapper.hpp" + +namespace duckdb { + +struct PyArrowFilterPushdown { + static py::object TransformFilter(TableFilterSet &filter_collection, unordered_map &columns, + unordered_map filter_to_col, const ClientProperties &config, + const ArrowTableSchema &arrow_table); +}; + +} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp index ccd8a16d..d3331565 100644 --- a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp +++ 
b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp @@ -56,7 +56,10 @@ struct PyarrowCacheItem : public PythonImportCacheItem { public: PyarrowCacheItem() : PythonImportCacheItem("pyarrow"), dataset(), Table("Table", this), - RecordBatchReader("RecordBatchReader", this), ipc(this) { + RecordBatchReader("RecordBatchReader", this), ipc(this), scalar("scalar", this), date32("date32", this), + time64("time64", this), timestamp("timestamp", this), uint8("uint8", this), uint16("uint16", this), + uint32("uint32", this), uint64("uint64", this), binary_view("binary_view", this), + decimal32("decimal32", this), decimal64("decimal64", this), decimal128("decimal128", this) { } ~PyarrowCacheItem() override { } @@ -65,6 +68,18 @@ struct PyarrowCacheItem : public PythonImportCacheItem { PythonImportCacheItem Table; PythonImportCacheItem RecordBatchReader; PyarrowIpcCacheItem ipc; + PythonImportCacheItem scalar; + PythonImportCacheItem date32; + PythonImportCacheItem time64; + PythonImportCacheItem timestamp; + PythonImportCacheItem uint8; + PythonImportCacheItem uint16; + PythonImportCacheItem uint32; + PythonImportCacheItem uint64; + PythonImportCacheItem binary_view; + PythonImportCacheItem decimal32; + PythonImportCacheItem decimal64; + PythonImportCacheItem decimal128; }; } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp b/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp index a492db9a..d51ddea2 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp @@ -90,7 +90,7 @@ bool try_cast(const handle &object, T &result) { } // namespace py template -void DefineMethod(std::vector aliases, T &mod, ARGS &&... 
args) { +void DefineMethod(std::vector aliases, T &mod, ARGS &&...args) { for (auto &alias : aliases) { mod.def(alias, args...); } diff --git a/src/duckdb_py/pyexpression/initialize.cpp b/src/duckdb_py/pyexpression/initialize.cpp index 2d2d6af9..11cf5dc3 100644 --- a/src/duckdb_py/pyexpression/initialize.cpp +++ b/src/duckdb_py/pyexpression/initialize.cpp @@ -28,8 +28,7 @@ void InitializeStaticMethods(py::module_ &m) { // Star Expression docs = ""; m.def("StarExpression", &DuckDBPyExpression::StarExpression, py::kw_only(), py::arg("exclude") = py::none(), docs); - m.def( - "StarExpression", []() { return DuckDBPyExpression::StarExpression(); }, docs); + m.def("StarExpression", []() { return DuckDBPyExpression::StarExpression(); }, docs); // Function Expression docs = ""; @@ -63,7 +62,8 @@ static void InitializeDunderMethods(py::class_> &m) { diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 7992cc17..cd1f042c 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -61,8 +61,8 @@ static void InitializeConsumers(py::class_ &m) { py::arg("date_as_object") = false) .def("fetch_df_chunk", &DuckDBPyRelation::FetchDFChunk, "Execute and fetch a chunk of the rows", py::arg("vectors_per_chunk") = 1, py::kw_only(), py::arg("date_as_object") = false) - .def("arrow", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", - py::arg("batch_size") = 1000000) + .def("arrow", &DuckDBPyRelation::ToRecordBatch, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) .def("fetch_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) .def("to_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", @@ -80,16 +80,16 @@ static void InitializeConsumers(py::class_ &m) { py::arg("requested_schema") = py::none()); m.def("fetch_record_batch", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) - .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, + .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) - .def("record_batch", - [](pybind11::object &self, idx_t rows_per_batch) - { - PyErr_WarnEx(PyExc_DeprecationWarning, - "record_batch() is deprecated, use fetch_record_batch() instead.", - 0); - return self.attr("fetch_record_batch")(rows_per_batch); - }, py::arg("batch_size") = 1000000); + .def( + "record_batch", + [](pybind11::object &self, idx_t rows_per_batch) { + PyErr_WarnEx(PyExc_DeprecationWarning, + "record_batch() is deprecated, use fetch_record_batch() instead.", 0); + return self.attr("fetch_record_batch")(rows_per_batch); + }, + py::arg("batch_size") = 1000000); } static void InitializeAggregates(py::class_ &m) { diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 009e3dab..eca92bed 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -326,8 +326,11 @@ void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); - type_module.def("__eq__", &DuckDBPyType::Equals, "Compare 
two types for equality", py::arg("other"), py::is_operator()); - type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), + py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), + py::is_operator()); + type_module.def("__hash__", [](const DuckDBPyType &type) { return py::hash(py::str(type.ToString())); }); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { @@ -347,7 +350,8 @@ void DuckDBPyType::Initialize(py::handle &m) { return make_shared_ptr(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); - type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), py::is_operator()); + type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), + py::is_operator()); py::implicitly_convertible(); py::implicitly_convertible(); diff --git a/src/duckdb_py/typing/typing.cpp b/src/duckdb_py/typing/typing.cpp index c0e2675e..fe990de1 100644 --- a/src/duckdb_py/typing/typing.cpp +++ b/src/duckdb_py/typing/typing.cpp @@ -39,7 +39,7 @@ static void DefineBaseTypes(py::handle &m) { } void DuckDBPyTyping::Initialize(py::module_ &parent) { - auto m = parent.def_submodule("typing", "This module contains classes and methods related to typing"); + auto m = parent.def_submodule("_sqltypes", "This module contains classes and methods related to typing"); DuckDBPyType::Initialize(m); DefineBaseTypes(m); diff --git a/tests/conftest.py b/tests/conftest.py index 5e297aee..df64f86c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,13 @@ import os import sys -import pytest -import shutil -from os.path import abspath, join, dirname, normpath -import glob -import duckdb import warnings from importlib import import_module -import sys +from pathlib import Path +from typing import Any, Union + +import pytest + +import duckdb try: # need to ignore warnings that might be thrown deep inside pandas's import tree (from dateutil in this case) @@ -22,7 +22,7 @@ # Only install mock after we've failed to import pandas for conftest.py class MockPandas: - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: pytest.skip("pandas not available", allow_module_level=True) sys.modules["pandas"] = MockPandas() @@ -56,19 +56,39 @@ def pytest_addoption(parser): @pytest.hookimpl(hookwrapper=True) def pytest_runtest_call(item): - """Convert pandas requirement exceptions to skips""" - + """Convert pandas requirement exceptions and missing pyarrow imports to skips.""" outcome = yield - # TODO: Remove skip when Pandas releases for 3.14. After, consider bumping to 3.15 - if sys.version_info[:2] == (3, 14): + # TODO: Remove skip when Pandas releases for 3.14. 
After, consider bumping to 3.15 # noqa: TD002,TD003 + if sys.version_info[:2] == (3, 14): try: outcome.get_result() - except duckdb.InvalidInputException as e: - if "'pandas' is required for this operation but it was not installed" in str(e): - pytest.skip("pandas not available - test requires pandas functionality") + except (duckdb.InvalidInputException, ImportError) as e: + if isinstance(e, ImportError) and e.name == "pyarrow": + pytest.skip(f"pyarrow not available - {item.name} requires pyarrow") + elif "'pandas' is required for this operation but it was not installed" in str(e): + pytest.skip(f"pandas not available - {item.name} requires pandas functionality") else: - raise e + raise + + +@pytest.hookimpl(hookwrapper=True) +def pytest_make_collect_report(collector): + """Wrap module collection to catch pyarrow import errors on Python 3.14. + + If we're on Python 3.14 and a test module raises ModuleNotFoundError + for 'pyarrow', mark the entire module as xfailed rather than failing collection. + """ + outcome = yield + result = outcome.get_result() + + if sys.version_info[:2] == (3, 14): + # Only handle failures from module collectors + if result.failed and collector.__class__.__name__ == "Module": + longrepr = str(result.longrepr) + if "ModuleNotFoundError: No module named 'pyarrow'" in longrepr: + result.outcome = "skipped" + result.longrepr = f"XFAIL: pyarrow not available {collector.name} ({longrepr.strip()})" def pytest_collection_modifyitems(config, items): @@ -92,19 +112,20 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_listed) -@pytest.fixture(scope="function") +@pytest.fixture def duckdb_empty_cursor(request): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() return cursor -def getTimeSeriesData(nper=None, freq: "Frequency" = "B"): - from pandas import DatetimeIndex, bdate_range, Series +def getTimeSeriesData(nper=None, freq: "Frequency" = "B"): # noqa: F821 + import string from datetime import datetime - from pandas._typing import Frequency + import numpy as np - import string + from pandas import DatetimeIndex, Series, bdate_range + from pandas._typing import Frequency _N = 30 _K = 4 @@ -128,14 +149,14 @@ def makeTimeSeries(nper=None, freq: Frequency = "B", name=None) -> Series: def pandas_2_or_higher(): from packaging.version import Version - return Version(import_pandas().__version__) >= Version('2.0.0') + return Version(import_pandas().__version__) >= Version("2.0.0") def pandas_supports_arrow_backend(): try: from pandas.compat import pa_version_under11p0 - if pa_version_under11p0 == True: + if pa_version_under11p0: return False except ImportError: return False @@ -152,12 +173,12 @@ def arrow_pandas_df(*args, **kwargs): class NumpyPandas: - def __init__(self): - self.backend = 'numpy_nullable' + def __init__(self) -> None: + self.backend = "numpy_nullable" self.DataFrame = numpy_pandas_df self.pandas = import_pandas() - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: # noqa: ANN401 return getattr(self.pandas, name) @@ -174,7 +195,7 @@ def convert_to_numpy(df): if ( pyarrow_dtypes_enabled and pyarrow_dtype is not None - and any([True for x in df.dtypes if isinstance(x, pyarrow_dtype)]) + and any(True for x in df.dtypes if isinstance(x, pyarrow_dtype)) ): return convert_arrow_to_numpy_backend(df) return df @@ -187,11 +208,11 @@ def convert_and_equal(df1, df2, **kwargs): class ArrowMockTesting: - def __init__(self): + def __init__(self) -> None: self.testing = 
import_pandas().testing self.assert_frame_equal = convert_and_equal - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: # noqa: ANN401 return getattr(self.testing, name) @@ -199,84 +220,79 @@ def __getattr__(self, name: str): # Assert equal does the opposite, turning all pyarrow backed dataframes into numpy backed ones # this is done because we don't produce pyarrow backed dataframes yet class ArrowPandas: - def __init__(self): + def __init__(self) -> None: self.pandas = import_pandas() if pandas_2_or_higher() and pyarrow_dtypes_enabled: - self.backend = 'pyarrow' + self.backend = "pyarrow" self.DataFrame = arrow_pandas_df else: # For backwards compatible reasons, just mock regular pandas - self.backend = 'numpy_nullable' + self.backend = "numpy_nullable" self.DataFrame = self.pandas.DataFrame self.testing = ArrowMockTesting() - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: # noqa: ANN401 return getattr(self.pandas, name) -@pytest.fixture(scope="function") +@pytest.fixture def require(): - def _require(extension_name, db_name=''): + def _require(extension_name, db_name="") -> Union[duckdb.DuckDBPyConnection, None]: # Paths to search for extensions - build = normpath(join(dirname(__file__), "../../../build/")) + build = Path(__file__).parent.parent / "build" extension = "extension/*/*.duckdb_extension" extension_search_patterns = [ - join(build, "release", extension), - join(build, "debug", extension), + build / "release" / extension, + build / "debug" / extension, ] # DUCKDB_PYTHON_TEST_EXTENSION_PATH can be used to add a path for the extension test to search for extensions - if 'DUCKDB_PYTHON_TEST_EXTENSION_PATH' in os.environ: - env_extension_path = os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_PATH') - env_extension_path = env_extension_path.rstrip('/') - extension_search_patterns.append(env_extension_path + '/*/*.duckdb_extension') - extension_search_patterns.append(env_extension_path + '/*.duckdb_extension') + if "DUCKDB_PYTHON_TEST_EXTENSION_PATH" in os.environ: + env_extension_path = os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_PATH") + env_extension_path = env_extension_path.rstrip("/") + extension_search_patterns.append(env_extension_path + "/*/*.duckdb_extension") + extension_search_patterns.append(env_extension_path + "/*.duckdb_extension") extension_paths_found = [] for pattern in extension_search_patterns: - extension_pattern_abs = abspath(pattern) - print(f"Searching path: {extension_pattern_abs}") - for path in glob.glob(extension_pattern_abs): - extension_paths_found.append(path) + extension_paths_found.extend(list(Path(pattern).resolve().glob("*"))) for path in extension_paths_found: print(path) if path.endswith(extension_name + ".duckdb_extension"): - conn = duckdb.connect(db_name, config={'allow_unsigned_extensions': 'true'}) + conn = duckdb.connect(db_name, config={"allow_unsigned_extensions": "true"}) conn.execute(f"LOAD '{path}'") return conn - pytest.skip(f'could not load {extension_name}') + pytest.skip(f"could not load {extension_name}") return _require # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture -@pytest.fixture(scope='function') +@pytest.fixture def spark(): - from spark_namespace import USE_ACTUAL_SPARK - - if not hasattr(spark, 'session'): + if not hasattr(spark, "session"): # Cache the import from spark_namespace.sql import SparkSession as session spark.session = session - return spark.session.builder.appName('pyspark').getOrCreate() + 
return spark.session.builder.appName("pyspark").getOrCreate() -@pytest.fixture(scope='function') +@pytest.fixture def duckdb_cursor(): - connection = duckdb.connect('') + connection = duckdb.connect("") yield connection connection.close() -@pytest.fixture(scope='function') +@pytest.fixture def integers(duckdb_cursor): cursor = duckdb_cursor - cursor.execute('CREATE TABLE integers (i integer)') + cursor.execute("CREATE TABLE integers (i integer)") cursor.execute( """ INSERT INTO integers VALUES @@ -297,42 +313,10 @@ def integers(duckdb_cursor): cursor.execute("drop table integers") -@pytest.fixture(scope='function') +@pytest.fixture def timestamps(duckdb_cursor): cursor = duckdb_cursor - cursor.execute('CREATE TABLE timestamps (t timestamp)') + cursor.execute("CREATE TABLE timestamps (t timestamp)") cursor.execute("INSERT INTO timestamps VALUES ('1992-10-03 18:34:45'), ('2010-01-01 00:00:01'), (NULL)") yield cursor.execute("drop table timestamps") - - -@pytest.fixture(scope="function") -def duckdb_cursor_autocommit(request, tmp_path): - test_dbfarm = tmp_path.resolve().as_posix() - - def finalizer(): - duckdb.shutdown() - if tmp_path.is_dir(): - shutil.rmtree(test_dbfarm) - - request.addfinalizer(finalizer) - - connection = duckdb.connect(test_dbfarm) - connection.set_autocommit(True) - cursor = connection.cursor() - return (cursor, connection, test_dbfarm) - - -@pytest.fixture(scope="function") -def initialize_duckdb(request, tmp_path): - test_dbfarm = tmp_path.resolve().as_posix() - - def finalizer(): - duckdb.shutdown() - if tmp_path.is_dir(): - shutil.rmtree(test_dbfarm) - - request.addfinalizer(finalizer) - - duckdb.connect(test_dbfarm) - return test_dbfarm diff --git a/tests/coverage/test_pandas_categorical_coverage.py b/tests/coverage/test_pandas_categorical_coverage.py index e20afa72..7b0645e0 100644 --- a/tests/coverage/test_pandas_categorical_coverage.py +++ b/tests/coverage/test_pandas_categorical_coverage.py @@ -1,7 +1,7 @@ -import duckdb -import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import NumpyPandas + +import duckdb def check_result_list(res): @@ -15,17 +15,17 @@ def check_create_table(category, pandas): conn.execute("PRAGMA enable_verification") df_in = pandas.DataFrame( { - 'x': pandas.Categorical(category, ordered=True), - 'y': pandas.Categorical(category, ordered=True), - 'z': category, + "x": pandas.Categorical(category, ordered=True), + "y": pandas.Categorical(category, ordered=True), + "z": category, } ) - category.append('bla') + category.append("bla") - df_in_diff = pandas.DataFrame( + df_in_diff = pandas.DataFrame( # noqa: F841 { - 'k': pandas.Categorical(category, ordered=True), + "k": pandas.Categorical(category, ordered=True), } ) @@ -44,7 +44,7 @@ def check_create_table(category, pandas): conn.execute("INSERT INTO t1 VALUES ('2','2','2')") res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall() - assert res == [('1',)] + assert res == [("1",)] res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x) order by t1.x").fetchall() assert res == conn.execute("SELECT x FROM t1 order by t1.x").fetchall() @@ -68,18 +68,14 @@ def check_create_table(category, pandas): conn.execute("DROP TABLE t1") -# TODO: extend tests with ArrowPandas -class TestCategory(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) +# TODO: extend tests with ArrowPandas # noqa: TD002, TD003 +class TestCategory: + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint16(self, duckdb_cursor, 
pandas): - category = [] - for i in range(300): - category.append(str(i)) + category = [str(i) for i in range(300)] check_create_table(category, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint32(self, duckdb_cursor, pandas): - category = [] - for i in range(70000): - category.append(str(i)) + category = [str(i) for i in range(70000)] check_create_table(category, pandas) diff --git a/tests/extensions/json/test_read_json.py b/tests/extensions/json/test_read_json.py index 48590175..f431906b 100644 --- a/tests/extensions/json/test_read_json.py +++ b/tests/extensions/json/test_read_json.py @@ -1,59 +1,56 @@ -import numpy -import datetime -import pandas +from io import StringIO + import pytest + import duckdb -import re -from io import StringIO def TestFile(name): - import os + from pathlib import Path - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', name) - return filename + return str(Path(__file__).parent / "data" / name) -class TestReadJSON(object): +class TestReadJSON: def test_read_json_columns(self): - rel = duckdb.read_json(TestFile('example.json'), columns={'id': 'integer', 'name': 'varchar'}) + rel = duckdb.read_json(TestFile("example.json"), columns={"id": "integer", "name": "varchar"}) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_auto(self): - rel = duckdb.read_json(TestFile('example.json')) + rel = duckdb.read_json(TestFile("example.json")) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_maximum_depth(self): - rel = duckdb.read_json(TestFile('example.json'), maximum_depth=4) + rel = duckdb.read_json(TestFile("example.json"), maximum_depth=4) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_sample_size(self): - rel = duckdb.read_json(TestFile('example.json'), sample_size=2) + rel = duckdb.read_json(TestFile("example.json"), sample_size=2) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_format(self): # Wrong option - with pytest.raises(duckdb.BinderException, match="format must be one of .* not 'test'"): - rel = duckdb.read_json(TestFile('example.json'), format='test') + with pytest.raises(duckdb.BinderException, match=r"format must be one of .* not 'test'"): + rel = duckdb.read_json(TestFile("example.json"), format="test") - rel = duckdb.read_json(TestFile('example.json'), format='unstructured') + rel = duckdb.read_json(TestFile("example.json"), format="unstructured") res = rel.fetchone() print(res) assert res == ( [ - {'id': 1, 'name': 'O Brother, Where Art Thou?'}, - {'id': 2, 'name': 'Home for the Holidays'}, - {'id': 3, 'name': 'The Firm'}, - {'id': 4, 'name': 'Broadcast News'}, - {'id': 5, 'name': 'Raising Arizona'}, + {"id": 1, "name": "O Brother, Where Art Thou?"}, + {"id": 2, "name": "Home for the Holidays"}, + {"id": 3, "name": "The Firm"}, + {"id": 4, "name": "Broadcast News"}, + {"id": 5, "name": "Raising Arizona"}, ], ) @@ -63,13 +60,13 @@ def test_read_filelike(self, duckdb_cursor): duckdb_cursor.execute("set threads=1") string = StringIO("""{"id":1,"name":"O Brother, Where Art Thou?"}\n{"id":2,"name":"Home for the Holidays"}""") res = 
duckdb_cursor.read_json(string).fetchall() - assert res == [(1, 'O Brother, Where Art Thou?'), (2, 'Home for the Holidays')] + assert res == [(1, "O Brother, Where Art Thou?"), (2, "Home for the Holidays")] string1 = StringIO("""{"id":1,"name":"O Brother, Where Art Thou?"}""") string2 = StringIO("""{"id":2,"name":"Home for the Holidays"}""") res = duckdb_cursor.read_json([string1, string2], filename=True).fetchall() - assert res[0][1] == 'O Brother, Where Art Thou?' - assert res[1][1] == 'Home for the Holidays' + assert res[0][1] == "O Brother, Where Art Thou?" + assert res[1][1] == "Home for the Holidays" # filenames are different assert res[0][2] != res[1][2] @@ -77,51 +74,51 @@ def test_read_filelike(self, duckdb_cursor): def test_read_json_records(self): # Wrong option with pytest.raises(duckdb.BinderException, match="""read_json requires "records" to be one of"""): - rel = duckdb.read_json(TestFile('example.json'), records='none') + rel = duckdb.read_json(TestFile("example.json"), records="none") - rel = duckdb.read_json(TestFile('example.json'), records='true') + rel = duckdb.read_json(TestFile("example.json"), records="true") res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") @pytest.mark.parametrize( - 'option', + "option", [ - ('filename', True), - ('filename', 'test'), - ('date_format', '%m-%d-%Y'), - ('date_format', '%m-%d-%y'), - ('date_format', '%d-%m-%Y'), - ('date_format', '%d-%m-%y'), - ('date_format', '%Y-%m-%d'), - ('date_format', '%y-%m-%d'), - ('timestamp_format', '%H:%M:%S%y-%m-%d'), - ('compression', 'AUTO_DETECT'), - ('compression', 'UNCOMPRESSED'), - ('maximum_object_size', 5), - ('ignore_errors', False), - ('ignore_errors', True), - ('convert_strings_to_integers', False), - ('convert_strings_to_integers', True), - ('field_appearance_threshold', 0.534), - ('map_inference_threshold', 34234), - ('maximum_sample_files', 5), - ('hive_partitioning', True), - ('hive_partitioning', False), - ('union_by_name', True), - ('union_by_name', False), - ('hive_types_autocast', False), - ('hive_types_autocast', True), - ('hive_types', {'id': 'INTEGER', 'name': 'VARCHAR'}), + ("filename", True), + ("filename", "test"), + ("date_format", "%m-%d-%Y"), + ("date_format", "%m-%d-%y"), + ("date_format", "%d-%m-%Y"), + ("date_format", "%d-%m-%y"), + ("date_format", "%Y-%m-%d"), + ("date_format", "%y-%m-%d"), + ("timestamp_format", "%H:%M:%S%y-%m-%d"), + ("compression", "AUTO_DETECT"), + ("compression", "UNCOMPRESSED"), + ("maximum_object_size", 5), + ("ignore_errors", False), + ("ignore_errors", True), + ("convert_strings_to_integers", False), + ("convert_strings_to_integers", True), + ("field_appearance_threshold", 0.534), + ("map_inference_threshold", 34234), + ("maximum_sample_files", 5), + ("hive_partitioning", True), + ("hive_partitioning", False), + ("union_by_name", True), + ("union_by_name", False), + ("hive_types_autocast", False), + ("hive_types_autocast", True), + ("hive_types", {"id": "INTEGER", "name": "VARCHAR"}), ], ) def test_read_json_options(self, duckdb_cursor, option): - keyword_arguments = dict() + keyword_arguments = {} option_name, option_value = option keyword_arguments[option_name] = option_value - if option_name == 'hive_types': - with pytest.raises(duckdb.InvalidInputException, match=r'Unknown hive_type:'): - rel = duckdb_cursor.read_json(TestFile('example.json'), **keyword_arguments) + if option_name == "hive_types": + with pytest.raises(duckdb.InvalidInputException, match=r"Unknown 
hive_type:"): + rel = duckdb_cursor.read_json(TestFile("example.json"), **keyword_arguments) else: - rel = duckdb_cursor.read_json(TestFile('example.json'), **keyword_arguments) - res = rel.fetchall() + rel = duckdb_cursor.read_json(TestFile("example.json"), **keyword_arguments) + rel.fetchall() diff --git a/tests/extensions/test_extensions_loading.py b/tests/extensions/test_extensions_loading.py index 2b4eab0c..8fbbd974 100644 --- a/tests/extensions/test_extensions_loading.py +++ b/tests/extensions/test_extensions_loading.py @@ -1,10 +1,9 @@ import os import platform -import duckdb -from pytest import raises import pytest +import duckdb pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", @@ -13,9 +12,9 @@ def test_extension_loading(require): - if not os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False): + if not os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED", False): return - extensions_list = ['json', 'excel', 'httpfs', 'tpch', 'tpcds', 'icu', 'fts'] + extensions_list = ["json", "excel", "httpfs", "tpch", "tpcds", "icu", "fts"] for extension in extensions_list: connection = require(extension) assert connection is not None @@ -25,17 +24,17 @@ def test_install_non_existent_extension(): conn = duckdb.connect() conn.execute("set custom_extension_repository = 'http://example.com'") - with raises(duckdb.IOException) as exc: - conn.install_extension('non-existent') + with pytest.raises(duckdb.IOException) as exc: + conn.install_extension("non-existent") - if not isinstance(exc, duckdb.HTTPException): - pytest.skip(reason='This test does not throw an HTTPException, only an IOException') - value = exc.value + if not isinstance(exc, duckdb.HTTPException): + pytest.skip(reason="This test does not throw an HTTPException, only an IOException") + value = exc.value - assert value.status_code == 404 - assert value.reason == 'Not Found' - assert 'Example Domain' in value.body - assert 'Content-Length' in value.headers + assert value.status_code == 404 + assert value.reason == "Not Found" + assert "Example Domain" in value.body + assert "Content-Length" in value.headers def test_install_misuse_errors(duckdb_cursor): @@ -43,17 +42,17 @@ def test_install_misuse_errors(duckdb_cursor): duckdb.InvalidInputException, match="Both 'repository' and 'repository_url' are set which is not allowed, please pick one or the other", ): - duckdb_cursor.install_extension('name', repository='hello', repository_url='hello.com') + duckdb_cursor.install_extension("name", repository="hello", repository_url="hello.com") with pytest.raises( duckdb.InvalidInputException, match="The provided 'repository' or 'repository_url' can not be empty!" ): - duckdb_cursor.install_extension('name', repository_url='') + duckdb_cursor.install_extension("name", repository_url="") with pytest.raises( duckdb.InvalidInputException, match="The provided 'repository' or 'repository_url' can not be empty!" 
): - duckdb_cursor.install_extension('name', repository='') + duckdb_cursor.install_extension("name", repository="") with pytest.raises(duckdb.InvalidInputException, match="The provided 'version' can not be empty!"): - duckdb_cursor.install_extension('name', version='') + duckdb_cursor.install_extension("name", version="") diff --git a/tests/extensions/test_httpfs.py b/tests/extensions/test_httpfs.py index 6366e07f..26ce917c 100644 --- a/tests/extensions/test_httpfs.py +++ b/tests/extensions/test_httpfs.py @@ -1,82 +1,83 @@ -import duckdb +import datetime import os -from pytest import raises, mark + import pytest -from conftest import NumpyPandas, ArrowPandas -import datetime +from conftest import ArrowPandas, NumpyPandas + +import duckdb # We only run this test if this env var is set -# FIXME: we can add a custom command line argument to pytest to provide an extension directory +# TODO: we can add a custom command line argument to pytest to provide an extension directory # noqa: TD002, TD003 # We can use that instead of checking this environment variable inside of conftest.py's 'require' method -pytestmark = mark.skipif( - not os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False), - reason='DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED is not set', +pytestmark = pytest.mark.skipif( + not os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED", False), + reason="DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED is not set", ) -class TestHTTPFS(object): +class TestHTTPFS: def test_read_json_httpfs(self, require): - connection = require('httpfs') + connection = require("httpfs") try: - res = connection.read_json('https://jsonplaceholder.typicode.com/todos') + res = connection.read_json("https://jsonplaceholder.typicode.com/todos") assert len(res.types) == 4 except duckdb.Error as e: - if '403' in e: + if "403" in e: pytest.skip(reason="Test is flaky, sometimes returns 403") else: pytest.fail(str(e)) def test_s3fs(self, require): - connection = require('httpfs') + connection = require("httpfs") - rel = connection.read_csv(f"s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) + rel = connection.read_csv("s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) res = rel.fetchone() assert res == (1, 0, datetime.date(1965, 2, 28), 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 6, 0, 0, 0, 0) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_httpfs(self, require, pandas): - connection = require('httpfs') + connection = require("httpfs") try: - connection.execute( - "SELECT id, first_name, last_name FROM PARQUET_SCAN('https://raw.githubusercontent.com/duckdb/duckdb/main/data/parquet-testing/userdata1.parquet') LIMIT 3;" - ) + connection.execute(""" + SELECT id, first_name, last_name FROM PARQUET_SCAN( + 'https://raw.githubusercontent.com/duckdb/duckdb/main/data/parquet-testing/userdata1.parquet' + ) LIMIT 3; + """) except RuntimeError as e: # Test will ignore result if it fails due to networking issues while running the test. 
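Note: the HTTPFS tests above and below all downgrade transient network failures to an early return or a pytest.skip so that CI does not go red on flaky connectivity. If this pattern keeps spreading it could be pulled into a tiny shared helper; a minimal sketch, where the helper name and the matched error strings are illustrative and not part of this change:

import contextlib

import pytest

import duckdb


@contextlib.contextmanager
def skip_on_network_trouble():
    """Convert transient network errors raised inside the block into pytest skips."""
    try:
        yield
    except (duckdb.Error, RuntimeError) as e:
        msg = str(e)
        if msg.startswith(("HTTP HEAD error", "Unable to connect")) or "403" in msg:
            pytest.skip(f"network unavailable or flaky: {msg}")
        raise


# usage inside a test body:
#     with skip_on_network_trouble():
#         connection.execute("SELECT * FROM PARQUET_SCAN('https://...')")
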
- if str(e).startswith("HTTP HEAD error"): - return - elif str(e).startswith("Unable to connect"): + if str(e).startswith("HTTP HEAD error") or str(e).startswith("Unable to connect"): return else: - raise e + raise result_df = connection.fetchdf() exp_result = pandas.DataFrame( { - 'id': pandas.Series([1, 2, 3], dtype="int32"), - 'first_name': ['Amanda', 'Albert', 'Evelyn'], - 'last_name': ['Jordan', 'Freeman', 'Morgan'], + "id": pandas.Series([1, 2, 3], dtype="int32"), + "first_name": ["Amanda", "Albert", "Evelyn"], + "last_name": ["Jordan", "Freeman", "Morgan"], } ) pandas.testing.assert_frame_equal(result_df, exp_result) def test_http_exception(self, require): - connection = require('httpfs') + connection = require("httpfs") # Read from a bogus HTTPS url, assert that it errors with a non-successful status code - with raises(duckdb.HTTPException) as exc: + with pytest.raises(duckdb.HTTPException) as exc: connection.execute("SELECT * FROM PARQUET_SCAN('https://example.com/userdata1.parquet')") value = exc.value assert value.status_code != 200 - assert value.body == '' - assert 'Content-Length' in value.headers + assert value.body == "" + assert "Content-Length" in value.headers def test_fsspec_priority(self, require): pytest.importorskip("fsspec") pytest.importorskip("gscfs") import fsspec - connection = require('httpfs') + connection = require("httpfs") gcs = fsspec.filesystem("gcs") connection.register_filesystem(gcs) diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py index 663563cf..80920a99 100644 --- a/tests/fast/adbc/test_adbc.py +++ b/tests/fast/adbc/test_adbc.py @@ -1,29 +1,21 @@ -import duckdb -import pytest -import sys import datetime -import os -import numpy as np - -if sys.version_info < (3, 9): - pytest.skip( - "Python Version must be higher or equal to 3.9 to run this test", - allow_module_level=True, - ) +import sys +from pathlib import Path -adbc_driver_manager = pytest.importorskip("adbc_driver_manager.dbapi") -adbc_driver_manager_lib = pytest.importorskip("adbc_driver_manager._lib") +import adbc_driver_manager.dbapi +import numpy as np +import pyarrow +import pytest -pyarrow = pytest.importorskip("pyarrow") +import adbc_driver_duckdb.dbapi -# When testing local, if you build via BUILD_PYTHON=1 make, you need to manually set up the -# dylib duckdb path. 
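Note: the rewritten ADBC tests below import adbc_driver_manager.dbapi and adbc_driver_duckdb directly instead of guarding them with pytest.importorskip, and they resolve the driver via adbc_driver_duckdb.driver_path(). For orientation, the end-to-end round trip they build on looks roughly like this sketch; the names come from the duck_conn fixture below and the printed output shape is illustrative:

import adbc_driver_manager.dbapi

import adbc_driver_duckdb

# driver_path() is assumed to resolve the shared library exporting duckdb_adbc_init,
# the same entrypoint the duck_conn fixture below passes to the driver manager.
with adbc_driver_manager.dbapi.connect(
    driver=adbc_driver_duckdb.driver_path(),
    entrypoint="duckdb_adbc_init",
) as conn, conn.cursor() as cur:
    cur.execute("SELECT 42 AS answer")
    print(cur.fetch_arrow_table().to_pydict())  # e.g. {'answer': [42]}
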
-driver_path = duckdb.duckdb.__file__ +xfail = pytest.mark.xfail +driver_path = adbc_driver_duckdb.driver_path() @pytest.fixture def duck_conn(): - with adbc_driver_manager.connect(driver=driver_path, entrypoint="duckdb_adbc_init") as conn: + with adbc_driver_manager.dbapi.connect(driver=driver_path, entrypoint="duckdb_adbc_init") as conn: yield conn @@ -37,7 +29,7 @@ def example_table(): ) -@pytest.mark.xfail +@xfail(sys.platform == "win32", reason="adbc-driver-manager.adbc_get_info() returns an empty dict on windows") def test_connection_get_info(duck_conn): assert duck_conn.adbc_get_info() != {} @@ -47,9 +39,12 @@ def test_connection_get_table_types(duck_conn): with duck_conn.cursor() as cursor: # Test Default Schema cursor.execute("CREATE TABLE tableschema (ints BIGINT)") - assert duck_conn.adbc_get_table_types() == ['BASE TABLE'] + assert duck_conn.adbc_get_table_types() == ["BASE TABLE"] +@xfail( + sys.platform == "win32", reason="adbc-driver-manager.adbc_get_objects() returns an invalid schema dict on windows" +) def test_connection_get_objects(duck_conn): with duck_conn.cursor() as cursor: cursor.execute("CREATE TABLE getobjects (ints BIGINT PRIMARY KEY)") @@ -71,6 +66,9 @@ def test_connection_get_objects(duck_conn): assert depth_all.schema == depth_catalogs.schema +@xfail( + sys.platform == "win32", reason="adbc-driver-manager.adbc_get_objects() returns an invalid schema dict on windows" +) def test_connection_get_objects_filters(duck_conn): with duck_conn.cursor() as cursor: cursor.execute("CREATE TABLE getobjects (ints BIGINT PRIMARY KEY)") @@ -97,13 +95,13 @@ def test_connection_get_objects_filters(duck_conn): def test_commit(tmp_path): - db = os.path.join(tmp_path, "tmp.db") - if os.path.exists(db): - os.remove(db) + db = Path(tmp_path) / "tmp.db" + if db.exists(): + db.unlink() table = example_table() db_kwargs = {"path": f"{db}"} # Start connection with auto-commit off - with adbc_driver_manager.connect( + with adbc_driver_manager.dbapi.connect( driver=driver_path, entrypoint="duckdb_adbc_init", db_kwargs=db_kwargs, @@ -113,7 +111,7 @@ def test_commit(tmp_path): cur.adbc_ingest("ingest", table, "create") # Check Data is not there - with adbc_driver_manager.connect( + with adbc_driver_manager.dbapi.connect( driver=driver_path, entrypoint="duckdb_adbc_init", db_kwargs=db_kwargs, @@ -123,22 +121,24 @@ def test_commit(tmp_path): with conn.cursor() as cur: # This errors because the table does not exist with pytest.raises( - adbc_driver_manager_lib.InternalError, - match=r'Table with name ingest does not exist!', + adbc_driver_manager._lib.InternalError, + match=r"Table with name ingest does not exist!", ): cur.execute("SELECT count(*) from ingest") cur.adbc_ingest("ingest", table, "create") # This now works because we enabled autocommit - with adbc_driver_manager.connect( - driver=driver_path, - entrypoint="duckdb_adbc_init", - db_kwargs=db_kwargs, - ) as conn: - with conn.cursor() as cur: - cur.execute("SELECT count(*) from ingest") - assert cur.fetch_arrow_table().to_pydict() == {'count_star()': [4]} + with ( + adbc_driver_manager.dbapi.connect( + driver=driver_path, + entrypoint="duckdb_adbc_init", + db_kwargs=db_kwargs, + ) as conn, + conn.cursor() as cur, + ): + cur.execute("SELECT count(*) from ingest") + assert cur.fetch_arrow_table().to_pydict() == {"count_star()": [4]} def test_connection_get_table_schema(duck_conn): @@ -207,6 +207,7 @@ def test_statement_query(duck_conn): assert cursor.fetch_arrow_table().to_pylist() == [{"foo": 1}] +@xfail(sys.platform == "win32", 
reason="adbc-driver-manager returns an invalid table schema on windows") def test_insertion(duck_conn): table = example_table() reader = table.to_reader() @@ -224,7 +225,7 @@ def test_insertion(duck_conn): # Test Append with duck_conn.cursor() as cursor: with pytest.raises( - adbc_driver_manager_lib.InternalError, + adbc_driver_manager.InternalError, match=r'Table with name "ingest_table" already exists!', ): cursor.adbc_ingest("ingest_table", table, "create") @@ -233,9 +234,10 @@ def test_insertion(duck_conn): assert cursor.fetch_arrow_table().to_pydict() == {"count_star()": [8]} +@xfail(sys.platform == "win32", reason="adbc-driver-manager returns an invalid table schema on windows") def test_read(duck_conn): with duck_conn.cursor() as cursor: - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "data", "category.csv") + filename = Path(__file__).parent / ".." / "data" / "category.csv" cursor.execute(f"SELECT * FROM '{filename}'") assert cursor.fetch_arrow_table().to_pydict() == { "CATEGORY_ID": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], @@ -297,46 +299,50 @@ def test_large_chunk(tmp_path): # Create the table table = pyarrow.table([col1, col2, col3], names=["ints", "floats", "strings"]) - db = os.path.join(tmp_path, "tmp.db") - if os.path.exists(db): - os.remove(db) + db = Path(tmp_path) / "tmp.db" + if db.exists(): + db.unlink() db_kwargs = {"path": f"{db}"} - with adbc_driver_manager.connect( - driver=driver_path, - entrypoint="duckdb_adbc_init", - db_kwargs=db_kwargs, - autocommit=True, - ) as conn: - with conn.cursor() as cur: - cur.adbc_ingest("ingest", table, "create") - cur.execute("SELECT count(*) from ingest") - assert cur.fetch_arrow_table().to_pydict() == {'count_star()': [30_000]} + with ( + adbc_driver_manager.dbapi.connect( + driver=driver_path, + entrypoint="duckdb_adbc_init", + db_kwargs=db_kwargs, + autocommit=True, + ) as conn, + conn.cursor() as cur, + ): + cur.adbc_ingest("ingest", table, "create") + cur.execute("SELECT count(*) from ingest") + assert cur.fetch_arrow_table().to_pydict() == {"count_star()": [30_000]} def test_dictionary_data(tmp_path): - data = ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] + data = ["apple", "banana", "apple", "orange", "banana", "banana"] dict_type = pyarrow.dictionary(index_type=pyarrow.int32(), value_type=pyarrow.string()) dict_array = pyarrow.array(data, type=dict_type) # Wrap in a table - table = pyarrow.table({'fruits': dict_array}) - db = os.path.join(tmp_path, "tmp.db") - if os.path.exists(db): - os.remove(db) + table = pyarrow.table({"fruits": dict_array}) + db = Path(tmp_path) / "tmp.db" + if db.exists(): + db.unlink() db_kwargs = {"path": f"{db}"} - with adbc_driver_manager.connect( - driver=driver_path, - entrypoint="duckdb_adbc_init", - db_kwargs=db_kwargs, - autocommit=True, - ) as conn: - with conn.cursor() as cur: - cur.adbc_ingest("ingest", table, "create") - cur.execute("from ingest") - assert cur.fetch_arrow_table().to_pydict() == { - 'fruits': ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] - } + with ( + adbc_driver_manager.dbapi.connect( + driver=driver_path, + entrypoint="duckdb_adbc_init", + db_kwargs=db_kwargs, + autocommit=True, + ) as conn, + conn.cursor() as cur, + ): + cur.adbc_ingest("ingest", table, "create") + cur.execute("from ingest") + assert cur.fetch_arrow_table().to_pydict() == { + "fruits": ["apple", "banana", "apple", "orange", "banana", "banana"] + } def test_ree_data(tmp_path): @@ -347,50 +353,52 @@ def test_ree_data(tmp_path): table 
= pyarrow.table({"fruits": ree_array}) - db = os.path.join(tmp_path, "tmp.db") - if os.path.exists(db): - os.remove(db) + db = Path(tmp_path) / "tmp.db" + if db.exists(): + db.unlink() db_kwargs = {"path": f"{db}"} - with adbc_driver_manager.connect( - driver=driver_path, - entrypoint="duckdb_adbc_init", - db_kwargs=db_kwargs, - autocommit=True, - ) as conn: - with conn.cursor() as cur: - cur.adbc_ingest("ingest", table, "create") - cur.execute("from ingest") - assert cur.fetch_arrow_table().to_pydict() == { - 'fruits': ['apple', 'apple', 'apple', 'banana', 'banana', 'orange'] - } + with ( + adbc_driver_manager.dbapi.connect( + driver=driver_path, + entrypoint="duckdb_adbc_init", + db_kwargs=db_kwargs, + autocommit=True, + ) as conn, + conn.cursor() as cur, + ): + cur.adbc_ingest("ingest", table, "create") + cur.execute("from ingest") + assert cur.fetch_arrow_table().to_pydict() == { + "fruits": ["apple", "apple", "apple", "banana", "banana", "orange"] + } def sorted_get_objects(catalogs): res = [] - for catalog in sorted(catalogs, key=lambda cat: cat['catalog_name']): + for catalog in sorted(catalogs, key=lambda cat: cat["catalog_name"]): new_catalog = { - "catalog_name": catalog['catalog_name'], + "catalog_name": catalog["catalog_name"], "catalog_db_schemas": [], } - for db_schema in sorted(catalog['catalog_db_schemas'] or [], key=lambda sch: sch['db_schema_name']): + for db_schema in sorted(catalog["catalog_db_schemas"] or [], key=lambda sch: sch["db_schema_name"]): new_db_schema = { - "db_schema_name": db_schema['db_schema_name'], + "db_schema_name": db_schema["db_schema_name"], "db_schema_tables": [], } - for table in sorted(db_schema['db_schema_tables'] or [], key=lambda tab: tab['table_name']): + for table in sorted(db_schema["db_schema_tables"] or [], key=lambda tab: tab["table_name"]): new_table = { - "table_name": table['table_name'], - "table_type": table['table_type'], + "table_name": table["table_name"], + "table_type": table["table_type"], "table_columns": [], "table_constraints": [], } - for column in sorted(table['table_columns'] or [], key=lambda col: col['ordinal_position']): + for column in sorted(table["table_columns"] or [], key=lambda col: col["ordinal_position"]): new_table["table_columns"].append(column) - for constraint in sorted(table['table_constraints'] or [], key=lambda con: con['constraint_name']): + for constraint in sorted(table["table_constraints"] or [], key=lambda con: con["constraint_name"]): new_table["table_constraints"].append(constraint) new_db_schema["db_schema_tables"].append(new_table) diff --git a/tests/fast/adbc/test_connection_get_info.py b/tests/fast/adbc/test_connection_get_info.py index 3744b7da..aa2b3d32 100644 --- a/tests/fast/adbc/test_connection_get_info.py +++ b/tests/fast/adbc/test_connection_get_info.py @@ -1,37 +1,19 @@ -import sys +import pyarrow as pa +import adbc_driver_duckdb.dbapi import duckdb -import pytest -pa = pytest.importorskip("pyarrow") -adbc_driver_manager = pytest.importorskip("adbc_driver_manager") -if sys.version_info < (3, 9): - pytest.skip( - "Python Version must be higher or equal to 3.9 to run this test", - allow_module_level=True, - ) - -try: - adbc_driver_duckdb = pytest.importorskip("adbc_driver_duckdb.dbapi") - con = adbc_driver_duckdb.connect() -except adbc_driver_manager.InternalError as e: - pytest.skip( - f"'duckdb_adbc_init' was not exported in this install, try running 'python3 setup.py install': {e}", - allow_module_level=True, - ) - - -class TestADBCConnectionGetInfo(object): +class 
TestADBCConnectionGetInfo: def test_connection_basic(self): - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() with con.cursor() as cursor: cursor.execute("select 42") res = cursor.fetchall() assert res == [(42,)] def test_connection_get_info_all(self): - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() adbc_con = con.adbc_connection res = adbc_con.get_info() reader = pa.RecordBatchReader._import_from_c(res.address) @@ -41,7 +23,7 @@ def test_connection_get_info_all(self): expected_result = pa.array( [ "duckdb", - "v" + duckdb.__version__, # don't hardcode this, as it will change every version + "v" + duckdb.__duckdb_version__, # don't hardcode this, as it will change every version "ADBC DuckDB Driver", "(unknown)", "(unknown)", @@ -55,7 +37,7 @@ def test_connection_get_info_all(self): assert string_values == expected_result def test_empty_result(self): - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() adbc_con = con.adbc_connection res = adbc_con.get_info([1337]) reader = pa.RecordBatchReader._import_from_c(res.address) @@ -66,7 +48,7 @@ def test_empty_result(self): assert values.num_chunks == 0 def test_unrecognized_codes(self): - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() adbc_con = con.adbc_connection res = adbc_con.get_info([0, 1000, 4, 2000]) reader = pa.RecordBatchReader._import_from_c(res.address) diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index 5e9d7d45..d35693ff 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -1,18 +1,12 @@ import sys +import adbc_driver_manager +import pyarrow as pa import pytest -if sys.version_info < (3, 9): - pytest.skip( - "Python Version must be higher or equal to 3.9 to run this test", - allow_module_level=True, - ) +import adbc_driver_duckdb.dbapi -pa = pytest.importorskip("pyarrow") -adbc_driver_manager = pytest.importorskip("adbc_driver_manager") - -adbc_driver_duckdb = pytest.importorskip("adbc_driver_duckdb.dbapi") -con = adbc_driver_duckdb.connect() +xfail = pytest.mark.xfail def _import(handle): @@ -21,17 +15,18 @@ def _import(handle): return pa.RecordBatchReader._import_from_c(handle.address) elif isinstance(handle, adbc_driver_manager.ArrowSchemaHandle): return pa.Schema._import_from_c(handle.address) - raise NotImplementedError(f"Importing {handle!r}") + msg = f"Importing {handle!r}" + raise NotImplementedError(msg) -def _bind(stmt, batch): +def _bind(stmt, batch) -> None: array = adbc_driver_manager.ArrowArrayHandle() schema = adbc_driver_manager.ArrowSchemaHandle() batch._export_to_c(array.address, schema.address) stmt.bind(array, schema) -class TestADBCStatementBind(object): +class TestADBCStatementBind: def test_bind_multiple_rows(self): data = pa.record_batch( [ @@ -40,7 +35,7 @@ def test_bind_multiple_rows(self): names=["ints"], ) - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() with con.cursor() as cursor: statement = cursor.adbc_statement statement.set_sql_query("select ? 
* 2 as i") @@ -50,7 +45,7 @@ def test_bind_multiple_rows(self): with pytest.raises( adbc_driver_manager.NotSupportedError, match="Binding multiple rows at once is not supported yet" ): - res, number_of_rows = statement.execute_query() + statement.execute_query() def test_bind_single_row(self): expected_result = pa.array([8], type=pa.int64()) @@ -62,7 +57,7 @@ def test_bind_single_row(self): names=["ints"], ) - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() with con.cursor() as cursor: statement = cursor.adbc_statement statement.set_sql_query("select ? * 2 as i") @@ -70,34 +65,35 @@ def test_bind_single_row(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] _bind(statement, data) res, _ = statement.execute_query() table = _import(res).read_all() - result = table['i'] + result = table["i"] assert result.num_chunks == 1 result_values = result.chunk(0) assert result_values == expected_result + @xfail(sys.platform == "win32", reason="adbc-driver-manager returns an invalid table schema on windows") def test_multiple_parameters(self): int_data = pa.array([5]) - varchar_data = pa.array(['not a short string']) + varchar_data = pa.array(["not a short string"]) bool_data = pa.array([True]) # Create the schema - schema = pa.schema([('a', pa.int64()), ('b', pa.string()), ('c', pa.bool_())]) + schema = pa.schema([("a", pa.int64()), ("b", pa.string()), ("c", pa.bool_())]) # Create the PyArrow table expected_res = pa.Table.from_arrays([int_data, varchar_data, bool_data], schema=schema) data = pa.record_batch( - [[5], ['not a short string'], [True]], + [[5], ["not a short string"], [True]], names=["ints", "strings", "bools"], ) - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() with con.cursor() as cursor: statement = cursor.adbc_statement statement.set_sql_query("select ? as a, ? as b, ? as c") @@ -105,7 +101,7 @@ def test_multiple_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0', '1', '2'] + assert schema.names == ["0", "1", "2"] _bind(statement, data) res, _ = statement.execute_query() @@ -115,19 +111,19 @@ def test_multiple_parameters(self): def test_bind_composite_type(self): data_dict = { - 'field1': pa.array([10], type=pa.int64()), - 'field2': pa.array([3.14], type=pa.float64()), - 'field3': pa.array(['example with long string'], type=pa.string()), + "field1": pa.array([10], type=pa.int64()), + "field2": pa.array([3.14], type=pa.float64()), + "field3": pa.array(["example with long string"], type=pa.string()), } # Create the StructArray struct_array = pa.StructArray.from_arrays(arrays=data_dict.values(), names=data_dict.keys()) - schema = pa.schema([(name, array.type) for name, array in zip(['a'], [struct_array])]) + schema = pa.schema([(name, array.type) for name, array in zip(["a"], [struct_array])]) # Create the RecordBatch record_batch = pa.RecordBatch.from_arrays([struct_array], schema=schema) - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() with con.cursor() as cursor: statement = cursor.adbc_statement statement.set_sql_query("select ? 
as a") @@ -135,22 +131,22 @@ def test_bind_composite_type(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] _bind(statement, record_batch) res, _ = statement.execute_query() table = _import(res).read_all() - result = table['a'] + result = table["a"] result = result.chunk(0) assert result == struct_array def test_too_many_parameters(self): data = pa.record_batch( - [[12423], ['not a short string']], + [[12423], ["not a short string"]], names=["ints", "strings"], ) - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() with con.cursor() as cursor: statement = cursor.adbc_statement statement.set_sql_query("select ? as a") @@ -158,7 +154,7 @@ def test_too_many_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] array = adbc_driver_manager.ArrowArrayHandle() schema = adbc_driver_manager.ArrowSchemaHandle() @@ -170,15 +166,16 @@ def test_too_many_parameters(self): adbc_driver_manager.ProgrammingError, match="Input data has more column than prepared statement has parameters", ): - res, _ = statement.execute_query() + statement.execute_query() + @xfail(sys.platform == "win32", reason="adbc-driver-manager returns an invalid table schema on windows") def test_not_enough_parameters(self): data = pa.record_batch( - [['not a short string']], + [["not a short string"]], names=["strings"], ) - con = adbc_driver_duckdb.connect() + con = adbc_driver_duckdb.dbapi.connect() with con.cursor() as cursor: statement = cursor.adbc_statement statement.set_sql_query("select ? as a, ? as b") @@ -186,7 +183,7 @@ def test_not_enough_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0', '1'] + assert schema.names == ["0", "1"] array = adbc_driver_manager.ArrowArrayHandle() schema = adbc_driver_manager.ArrowSchemaHandle() @@ -196,4 +193,4 @@ def test_not_enough_parameters(self): adbc_driver_manager.ProgrammingError, match="Values were not provided for the following prepared statement parameters: 2", ): - res, _ = statement.execute_query() + statement.execute_query() diff --git a/tests/fast/api/test_3324.py b/tests/fast/api/test_3324.py index e8f6085f..1e5889a5 100644 --- a/tests/fast/api/test_3324.py +++ b/tests/fast/api/test_3324.py @@ -1,30 +1,31 @@ import pytest + import duckdb -class Test3324(object): +class Test3324: def test_3324(self, duckdb_cursor): - create_output = duckdb_cursor.execute( + duckdb_cursor.execute( """ - create or replace table my_table as - select 'test1' as column1, 1 as column2, 'quack' as column3 + create or replace table my_table as + select 'test1' as column1, 1 as column2, 'quack' as column3 union all - select 'test2' as column1, 2 as column2, 'quacks' as column3 + select 'test2' as column1, 2 as column2, 'quacks' as column3 union all - select 'test3' as column1, 3 as column2, 'quacking' as column3 + select 'test3' as column1, 3 as column2, 'quacking' as column3 """ ).fetch_df() - prepare_output = duckdb_cursor.execute( + duckdb_cursor.execute( """ - prepare v1 as - select + prepare v1 as + select column1 , column2 - , column3 + , column3 from my_table - where + where column1 = $1""" ).fetch_df() with pytest.raises(duckdb.BinderException, match="Unexpected prepared parameter"): - duckdb_cursor.execute("""execute v1(?)""", ('test1',)).fetch_df() + duckdb_cursor.execute("""execute v1(?)""", 
("test1",)).fetch_df() diff --git a/tests/fast/api/test_3654.py b/tests/fast/api/test_3654.py index e63f0cd1..a6b01dd5 100644 --- a/tests/fast/api/test_3654.py +++ b/tests/fast/api/test_3654.py @@ -1,21 +1,22 @@ -import duckdb import pytest +import duckdb + try: import pyarrow as pa can_run = True -except: +except Exception: can_run = False -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas -class Test3654(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class Test3654: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_pandas(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1, 1, 2], + "id": [1, 1, 2], } ) con = duckdb.connect() @@ -24,14 +25,14 @@ def test_3654_pandas(self, duckdb_cursor, pandas): print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(1,), (1,), (2,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_arrow(self, duckdb_cursor, pandas): if not can_run: return df1 = pandas.DataFrame( { - 'id': [1, 1, 2], + "id": [1, 1, 2], } ) table = pa.Table.from_pandas(df1) diff --git a/tests/fast/api/test_3728.py b/tests/fast/api/test_3728.py index 2df3c156..bd770bf0 100644 --- a/tests/fast/api/test_3728.py +++ b/tests/fast/api/test_3728.py @@ -1,7 +1,7 @@ import duckdb -class Test3728(object): +class Test3728: def test_3728_describe_enum(self, duckdb_cursor): # Create an in-memory database, but the problem is also present in file-backed DBs cursor = duckdb.connect(":memory:") @@ -14,6 +14,6 @@ def test_3728_describe_enum(self, duckdb_cursor): # This fails with "RuntimeError: Not implemented Error: unsupported type: mood" assert cursor.table("person").execute().description == [ - ('name', 'VARCHAR', None, None, None, None, None), - ('current_mood', "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), + ("name", "VARCHAR", None, None, None, None, None), + ("current_mood", "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), ] diff --git a/tests/fast/api/test_6315.py b/tests/fast/api/test_6315.py index e8eaff59..3702831e 100644 --- a/tests/fast/api/test_6315.py +++ b/tests/fast/api/test_6315.py @@ -1,7 +1,7 @@ import duckdb -class Test6315(object): +class Test6315: def test_6315(self, duckdb_cursor): # segfault when accessing description after fetching rows c = duckdb.connect(":memory:") @@ -9,15 +9,15 @@ def test_6315(self, duckdb_cursor): rv.fetchall() desc = rv.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] # description of relation rel = c.sql("select * from sqlite_master where type = 'table'") desc = rel.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] rel.fetchall() desc = rel.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] diff --git a/tests/fast/api/test_attribute_getter.py b/tests/fast/api/test_attribute_getter.py index 958e8892..3566c5e4 100644 --- a/tests/fast/api/test_attribute_getter.py +++ b/tests/fast/api/test_attribute_getter.py @@ -1,53 +1,47 @@ -import duckdb -import tempfile -import os -import pandas as pd -import tempfile -import pandas._testing as tm 
-import datetime -import csv import pytest +import duckdb + -class TestGetAttribute(object): +class TestGetAttribute: def test_basic_getattr(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") assert rel.a.fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] assert rel.b.fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] assert rel.c.fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] def test_basic_getitem(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') - assert rel['a'].fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] - assert rel['b'].fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] - assert rel['c'].fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") + assert rel["a"].fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] + assert rel["b"].fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] + assert rel["c"].fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] def test_getitem_nonexistant(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") with pytest.raises(AttributeError): - rel['d'] + rel["d"] def test_getattr_nonexistant(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") with pytest.raises(AttributeError): - rel.d + rel.d # noqa: B018 def test_getattr_collision(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as df from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as df from range(100) tbl(i)") # 'df' also exists as a method on DuckDBPyRelation assert rel.df.__class__ != duckdb.DuckDBPyRelation def test_getitem_collision(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as df from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as df from range(100) tbl(i)") # this case is not an issue on __getitem__ - assert rel['df'].__class__ == duckdb.DuckDBPyRelation + assert rel["df"].__class__ == duckdb.DuckDBPyRelation def test_getitem_struct(self, duckdb_cursor): rel = duckdb_cursor.sql("select {'a':5, 'b':6} as a, 5 as b") - assert rel['a']['a'].fetchall()[0][0] == 5 - assert rel['a']['b'].fetchall()[0][0] == 6 + assert rel["a"]["a"].fetchall()[0][0] == 5 + assert rel["a"]["b"].fetchall()[0][0] == 6 def test_getattr_struct(self, duckdb_cursor): rel = duckdb_cursor.sql("select {'a':5, 'b':6} as a, 5 as b") @@ -56,7 +50,7 @@ def test_getattr_struct(self, duckdb_cursor): def test_getattr_spaces(self, duckdb_cursor): rel = duckdb_cursor.sql('select 42 as "hello world"') - assert rel['hello world'].fetchall()[0][0] == 42 + assert rel["hello world"].fetchall()[0][0] == 42 def test_getattr_doublequotes(self, duckdb_cursor): rel = duckdb_cursor.sql('select 1 as "tricky"", ""quotes", 2 as tricky, 3 as quotes') diff --git a/tests/fast/api/test_config.py b/tests/fast/api/test_config.py index 5db5f77b..aaec24c4 100644 --- a/tests/fast/api/test_config.py +++ b/tests/fast/api/test_config.py @@ -1,90 +1,90 @@ -# simple DB API testcase +# ruff: noqa: F841 +import os +import re -import duckdb -import numpy import 
pytest -import re -import os -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestDBConfig(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestDBConfig: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_default_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3]}) - con = duckdb.connect(':memory:', config={'default_order': 'desc'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3]}) + con = duckdb.connect(":memory:", config={"default_order": "desc"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_null_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3, None]}) - con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3, None]}) + con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(1,), (2,), (3,), (None,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_options(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3, None]}) - con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last', 'default_order': 'desc'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3, None]}) + con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last", "default_order": "desc"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,), (None,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_external_access(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3]}) + df = pandas.DataFrame({"a": [1, 2, 3]}) # this works (replacement scan) - con_regular = duckdb.connect(':memory:', config={}) - con_regular.execute('select * from df') + con_regular = duckdb.connect(":memory:", config={}) + con_regular.execute("select * from df") # disable external access: this also disables pandas replacement scans - con = duckdb.connect(':memory:', config={'enable_external_access': False}) + con = duckdb.connect(":memory:", config={"enable_external_access": False}) # this should fail query_failed = False try: - con.execute('select * from df').fetchall() - except: + con.execute("select * from df").fetchall() + except Exception: query_failed = True - assert query_failed == True + assert query_failed def test_extension_setting(self): - repository = os.environ.get('LOCAL_EXTENSION_REPO') + repository = os.environ.get("LOCAL_EXTENSION_REPO") if not repository: return - con = duckdb.connect(config={"TimeZone": "UTC", 'autoinstall_extension_repository': repository}) - assert 'UTC' == con.sql("select current_setting('TimeZone')").fetchone()[0] + con = duckdb.connect(config={"TimeZone": "UTC", "autoinstall_extension_repository": repository}) + assert con.sql("select current_setting('TimeZone')").fetchone()[0] == "UTC" def 
test_unrecognized_option(self, duckdb_cursor): success = True try: - con_regular = duckdb.connect(':memory:', config={'thisoptionisprobablynotthere': '42'}) - except: + duckdb.connect(":memory:", config={"thisoptionisprobablynotthere": "42"}) + except Exception: success = False - assert success == False + assert not success def test_incorrect_parameter(self, duckdb_cursor): success = True try: - con_regular = duckdb.connect(':memory:', config={'default_null_order': '42'}) - except: + duckdb.connect(":memory:", config={"default_null_order": "42"}) + except Exception: success = False - assert success == False + assert not success def test_user_agent_default(self, duckdb_cursor): - con_regular = duckdb.connect(':memory:') + con_regular = duckdb.connect(":memory:") regex = re.compile("duckdb/.* python/.*") # Expands to: SELECT * FROM pragma_user_agent() assert regex.match(con_regular.sql("PRAGMA user_agent").fetchone()[0]) is not None custom_user_agent = con_regular.sql("SELECT current_setting('custom_user_agent')").fetchone() - assert custom_user_agent[0] == '' + assert custom_user_agent[0] == "" def test_user_agent_custom(self, duckdb_cursor): - con_regular = duckdb.connect(':memory:', config={'custom_user_agent': 'CUSTOM_STRING'}) + con_regular = duckdb.connect(":memory:", config={"custom_user_agent": "CUSTOM_STRING"}) regex = re.compile("duckdb/.* python/.* CUSTOM_STRING") assert regex.match(con_regular.sql("PRAGMA user_agent").fetchone()[0]) is not None custom_user_agent = con_regular.sql("SELECT current_setting('custom_user_agent')").fetchone() - assert custom_user_agent[0] == 'CUSTOM_STRING' + assert custom_user_agent[0] == "CUSTOM_STRING" def test_secret_manager_option(self, duckdb_cursor): - con = duckdb.connect(':memory:', config={'allow_persistent_secrets': False}) - result = con.execute('select count(*) from duckdb_secrets()').fetchall() + con = duckdb.connect(":memory:", config={"allow_persistent_secrets": False}) + result = con.execute("select count(*) from duckdb_secrets()").fetchall() assert result == [(0,)] diff --git a/tests/fast/api/test_connection_close.py b/tests/fast/api/test_connection_close.py index e7a47404..8ec24b63 100644 --- a/tests/fast/api/test_connection_close.py +++ b/tests/fast/api/test_connection_close.py @@ -1,10 +1,10 @@ # cursor description - -import duckdb import tempfile -import os + import pytest +import duckdb + def check_exception(f): had_exception = False @@ -15,11 +15,10 @@ def check_exception(f): assert had_exception -class TestConnectionClose(object): +class TestConnectionClose: def test_connection_close(self, duckdb_cursor): - fd, db = tempfile.mkstemp() - os.close(fd) - os.remove(db) + with tempfile.NamedTemporaryFile() as tmp: + db = tmp.name con = duckdb.connect(db) cursor = con.cursor() cursor.execute("create table a (i integer)") @@ -28,16 +27,13 @@ def test_connection_close(self, duckdb_cursor): check_exception(lambda: cursor.execute("select * from a")) def test_open_and_exit(self): - with pytest.raises(TypeError): - with duckdb.connect() as connection: - connection.execute("select 42") - # This exception does not get swallowed by __exit__ - raise TypeError() + with pytest.raises(TypeError), duckdb.connect(): + # This exception does not get swallowed by DuckDBPyConnection's __exit__ + raise TypeError() def test_reopen_connection(self, duckdb_cursor): - fd, db = tempfile.mkstemp() - os.close(fd) - os.remove(db) + with tempfile.NamedTemporaryFile() as tmp: + db = tmp.name con = duckdb.connect(db) cursor = con.cursor() cursor.execute("create 
table a (i integer)") @@ -54,7 +50,7 @@ def test_get_closed_default_conn(self, duckdb_cursor): duckdb.close() # 'duckdb.close()' closes this connection, because we explicitly set it as the default - with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection Error: Connection already closed"): con.sql("select 42").fetchall() default_con = duckdb.default_connection() @@ -65,11 +61,11 @@ def test_get_closed_default_conn(self, duckdb_cursor): duckdb.sql("select 42").fetchall() # Show that the 'default_con' is still closed - with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection Error: Connection already closed"): default_con.sql("select 42").fetchall() duckdb.close() # This also does not error because we silently receive a new connection - con2 = duckdb.connect(':default:') + con2 = duckdb.connect(":default:") con2.sql("select 42").fetchall() diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index 4efd68b5..4ea63176 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -2,11 +2,12 @@ import threading import time -import duckdb import pytest +import duckdb + -class TestConnectionInterrupt(object): +class TestConnectionInterrupt: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", @@ -14,7 +15,7 @@ class TestConnectionInterrupt(object): def test_connection_interrupt(self): conn = duckdb.connect() - def interrupt(): + def interrupt() -> None: # Wait for query to start running before interrupting time.sleep(0.1) conn.interrupt() diff --git a/tests/fast/api/test_cursor.py b/tests/fast/api/test_cursor.py index 9510fbd9..f0d7d332 100644 --- a/tests/fast/api/test_cursor.py +++ b/tests/fast/api/test_cursor.py @@ -1,13 +1,14 @@ # simple DB API testcase import pytest + import duckdb -class TestDBAPICursor(object): +class TestDBAPICursor: def test_cursor_basic(self): # Create a connection - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # Then create a cursor on the connection cursor = con.cursor() # Use the cursor for queries @@ -15,14 +16,14 @@ def test_cursor_basic(self): assert res == [([1, 2, 3, None, 4],)] def test_cursor_preexisting(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.execute("create table tbl as select i a, i+1 b, i+2 c from range(5) tbl(i)") cursor = con.cursor() res = cursor.execute("select * from tbl").fetchall() assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_after_creation(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # First create the cursor cursor = con.cursor() # Then create table on the source connection @@ -31,7 +32,7 @@ def test_cursor_after_creation(self): assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_mixed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # First create the cursor cursor = con.cursor() # Then create table on the cursor @@ -43,7 +44,7 @@ def test_cursor_mixed(self): assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_temp_schema_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.execute("create 
temp table tbl as select * from range(100)") other_cursor = con.cursor() @@ -51,10 +52,10 @@ def test_cursor_temp_schema_closed(self): cursor.close() with pytest.raises(duckdb.CatalogException): # This table does not exist in this cursor - res = other_cursor.execute("select * from tbl").fetchall() + other_cursor.execute("select * from tbl").fetchall() def test_cursor_temp_schema_open(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.execute("create temp table tbl as select * from range(100)") other_cursor = con.cursor() @@ -62,10 +63,10 @@ def test_cursor_temp_schema_open(self): # cursor.close() with pytest.raises(duckdb.CatalogException): # This table does not exist in this cursor - res = other_cursor.execute("select * from tbl").fetchall() + other_cursor.execute("select * from tbl").fetchall() def test_cursor_temp_schema_both(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor1 = con.cursor() cursor2 = con.cursor() cursor3 = con.cursor() @@ -92,23 +93,23 @@ def test_cursor_timezone(self): # Because the 'timezone' setting was not explicitly set for the connection # the setting of the DBConfig is used instead res = con1.execute("SELECT make_timestamptz(2000,01,20,03,30,59)").fetchone() - assert str(res) == '(datetime.datetime(2000, 1, 20, 3, 30, 59, tzinfo=),)' + assert str(res) == "(datetime.datetime(2000, 1, 20, 3, 30, 59, tzinfo=),)" def test_cursor_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.close() with pytest.raises(duckdb.ConnectionException): - cursor = con.cursor() + con.cursor() def test_cursor_used_after_connection_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() con.close() with pytest.raises(duckdb.ConnectionException): cursor.execute("select [1,2,3,4]") def test_cursor_used_after_close(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.close() with pytest.raises(duckdb.ConnectionException): diff --git a/tests/fast/api/test_dbapi00.py b/tests/fast/api/test_dbapi00.py index 815a81b9..425cb7e1 100644 --- a/tests/fast/api/test_dbapi00.py +++ b/tests/fast/api/test_dbapi00.py @@ -2,17 +2,16 @@ import numpy import pytest -import duckdb -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas def assert_result_equal(result): assert result == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (None,)], "Incorrect result returned" -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert_result_equal(result) @@ -20,9 +19,7 @@ def test_fetchmany_default(self, duckdb_cursor, integers): # Get truth-value truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) - duckdb_cursor.execute('Select * from integers') - # by default 'size' is 1 - arraysize = 1 + duckdb_cursor.execute("Select * from integers") list_of_results = [] while True: res = duckdb_cursor.fetchmany() @@ -40,7 +37,7 @@ def test_fetchmany_default(self, duckdb_cursor, integers): def test_fetchmany(self, duckdb_cursor, integers): # Get truth value truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) - duckdb_cursor.execute('select * from integers') + duckdb_cursor.execute("select * from integers") 
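A minimal sketch of the batched-fetch pattern the fetchmany tests above exercise, using a plain in-memory connection in place of the duckdb_cursor fixture (calling fetchmany() with no argument returns a single row, as the original comment noted):

```python
import duckdb

con = duckdb.connect(":memory:")
con.execute("select * from range(10)")

rows = []
while True:
    batch = con.fetchmany(3)  # batches of at most 3 rows; with no argument, one row at a time
    if not batch:             # an empty list signals the result set is exhausted
        break
    rows.extend(batch)

assert rows == [(i,) for i in range(10)]
```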
list_of_results = [] arraysize = 3 expected_iteration_count = 1 + (int)(truth_value / arraysize) + (1 if truth_value % arraysize else 0) @@ -63,8 +60,8 @@ def test_fetchmany(self, duckdb_cursor, integers): assert len(res) == 0 def test_fetchmany_too_many(self, duckdb_cursor, integers): - truth_value = len(duckdb_cursor.execute('select * from integers').fetchall()) - duckdb_cursor.execute('select * from integers') + truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) + duckdb_cursor.execute("select * from integers") res = duckdb_cursor.fetchmany(truth_value * 5) assert len(res) == truth_value assert_result_equal(res) @@ -74,48 +71,48 @@ def test_fetchmany_too_many(self, duckdb_cursor, integers): assert len(res) == 0 def test_numpy_selection(self, duckdb_cursor, integers, timestamps): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchnumpy() arr = numpy.ma.masked_array(numpy.arange(11)) arr.mask = [False] * 10 + [True] - numpy.testing.assert_array_equal(result['i'], arr, "Incorrect result returned") - duckdb_cursor.execute('SELECT * FROM timestamps') + numpy.testing.assert_array_equal(result["i"], arr, "Incorrect result returned") + duckdb_cursor.execute("SELECT * FROM timestamps") result = duckdb_cursor.fetchnumpy() - arr = numpy.array(['1992-10-03 18:34:45', '2010-01-01 00:00:01', None], dtype="datetime64[ms]") + arr = numpy.array(["1992-10-03 18:34:45", "2010-01-01 00:00:01", None], dtype="datetime64[ms]") arr = numpy.ma.masked_array(arr) arr.mask = [False, False, True] - numpy.testing.assert_array_equal(result['t'], arr, "Incorrect result returned") + numpy.testing.assert_array_equal(result["t"], arr, "Incorrect result returned") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_selection(self, duckdb_cursor, pandas, integers, timestamps): import datetime from packaging.version import Version # I don't know when this exactly changed, but 2.0.3 does not support this, recent versions do - if Version(pandas.__version__) <= Version('2.0.3'): + if Version(pandas.__version__) <= Version("2.0.3"): pytest.skip("The resulting dtype is 'object' when given a Series with dtype Int32DType") - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchdf() array = numpy.ma.masked_array(numpy.arange(11)) array.mask = [False] * 10 + [True] - arr = {'i': pandas.Series(array.data, dtype=pandas.Int32Dtype)} - arr['i'][array.mask] = pandas.NA + arr = {"i": pandas.Series(array.data, dtype=pandas.Int32Dtype)} + arr["i"][array.mask] = pandas.NA arr = pandas.DataFrame(arr) pandas.testing.assert_frame_equal(result, arr) - duckdb_cursor.execute('SELECT * FROM timestamps') + duckdb_cursor.execute("SELECT * FROM timestamps") result = duckdb_cursor.fetchdf() df = pandas.DataFrame( { - 't': pandas.Series( + "t": pandas.Series( data=[ datetime.datetime(year=1992, month=10, day=3, hour=18, minute=34, second=45), datetime.datetime(year=2010, month=1, day=1, hour=0, minute=0, second=1), None, ], - dtype='datetime64[us]', + dtype="datetime64[us]", ) } ) diff --git a/tests/fast/api/test_dbapi01.py b/tests/fast/api/test_dbapi01.py index dd0d2b4e..4d52fd64 100644 --- a/tests/fast/api/test_dbapi01.py +++ b/tests/fast/api/test_dbapi01.py @@ -1,13 +1,14 @@ # multiple result sets import numpy + import duckdb -class TestMultipleResultSets(object): +class 
TestMultipleResultSets: def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert result == [ (0,), @@ -24,18 +25,18 @@ def test_regular_selection(self, duckdb_cursor, integers): ], "Incorrect result returned" def test_numpy_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchnumpy() expected = numpy.ma.masked_array(numpy.arange(11), mask=([False] * 10 + [True])) - numpy.testing.assert_array_equal(result['i'], expected) + numpy.testing.assert_array_equal(result["i"], expected) def test_numpy_materialized(self, duckdb_cursor, integers): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE integers (i integer)') - cursor.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + cursor.execute("CREATE TABLE integers (i integer)") + cursor.execute("INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)") rel = connection.table("integers") res = rel.aggregate("sum(i)").execute().fetchnumpy() - assert res['sum(i)'][0] == 45 + assert res["sum(i)"][0] == 45 diff --git a/tests/fast/api/test_dbapi04.py b/tests/fast/api/test_dbapi04.py index b2c9173a..2c2259ce 100644 --- a/tests/fast/api/test_dbapi04.py +++ b/tests/fast/api/test_dbapi04.py @@ -1,9 +1,9 @@ # simple DB API testcase -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert result == [ (0,), diff --git a/tests/fast/api/test_dbapi05.py b/tests/fast/api/test_dbapi05.py index 0de217f2..6c6d4fa1 100644 --- a/tests/fast/api/test_dbapi05.py +++ b/tests/fast/api/test_dbapi05.py @@ -1,9 +1,9 @@ # simple DB API testcase -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_prepare(self, duckdb_cursor): - result = duckdb_cursor.execute('SELECT CAST(? AS INTEGER), CAST(? AS INTEGER)', ['42', '84']).fetchall() + result = duckdb_cursor.execute("SELECT CAST(? AS INTEGER), CAST(? 
AS INTEGER)", ["42", "84"]).fetchall() assert result == [ ( 42, @@ -15,26 +15,26 @@ def test_prepare(self, duckdb_cursor): # from python docs c.execute( - '''CREATE TABLE stocks - (date text, trans text, symbol text, qty real, price real)''' + """CREATE TABLE stocks + (date text, trans text, symbol text, qty real, price real)""" ) c.execute("INSERT INTO stocks VALUES ('2006-01-05','BUY','RHAT',100,35.14)") - t = ('RHAT',) - result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() + t = ("RHAT",) + result = c.execute("SELECT COUNT(*) FROM stocks WHERE symbol=?", t).fetchone() assert result == (1,) - t = ['RHAT'] - result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() + t = ["RHAT"] + result = c.execute("SELECT COUNT(*) FROM stocks WHERE symbol=?", t).fetchone() assert result == (1,) # Larger example that inserts many records at a time purchases = [ - ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ("2006-03-28", "BUY", "IBM", 1000, 45.00), + ("2006-04-05", "BUY", "MSFT", 1000, 72.00), + ("2006-04-06", "SELL", "IBM", 500, 53.00), ] - c.executemany('INSERT INTO stocks VALUES (?,?,?,?,?)', purchases) + c.executemany("INSERT INTO stocks VALUES (?,?,?,?,?)", purchases) - result = c.execute('SELECT count(*) FROM stocks').fetchone() + result = c.execute("SELECT count(*) FROM stocks").fetchone() assert result == (4,) diff --git a/tests/fast/api/test_dbapi07.py b/tests/fast/api/test_dbapi07.py index 7792b8de..eab581e5 100644 --- a/tests/fast/api/test_dbapi07.py +++ b/tests/fast/api/test_dbapi07.py @@ -1,16 +1,17 @@ # timestamp ms precision -import numpy from datetime import datetime +import numpy + -class TestNumpyTimestampMilliseconds(object): +class TestNumpyTimestampMilliseconds: def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchnumpy() - assert res['test_time'] == numpy.datetime64('2019-11-26 21:11:42.501') + assert res["test_time"] == numpy.datetime64("2019-11-26 21:11:42.501") -class TestTimestampMilliseconds(object): +class TestTimestampMilliseconds: def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchone()[0] - assert res == datetime.strptime('2019-11-26 21:11:42.501', '%Y-%m-%d %H:%M:%S.%f') + assert res == datetime.strptime("2019-11-26 21:11:42.501", "%Y-%m-%d %H:%M:%S.%f") diff --git a/tests/fast/api/test_dbapi08.py b/tests/fast/api/test_dbapi08.py index a81acfd1..def4e925 100644 --- a/tests/fast/api/test_dbapi08.py +++ b/tests/fast/api/test_dbapi08.py @@ -1,12 +1,12 @@ # test fetchdf with various types -import numpy import pytest -import duckdb from conftest import NumpyPandas +import duckdb + -class TestType(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) +class TestType: + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_fetchdf(self, pandas): con = duckdb.connect() con.execute("CREATE TABLE items(item VARCHAR)") @@ -14,7 +14,7 @@ def test_fetchdf(self, pandas): res = con.execute("SELECT item FROM items").fetchdf() assert isinstance(res, pandas.core.frame.DataFrame) - df = pandas.DataFrame({'item': ['jeans', '', None]}) + df = pandas.DataFrame({"item": ["jeans", "", None]}) print(res) print(df) diff --git a/tests/fast/api/test_dbapi09.py b/tests/fast/api/test_dbapi09.py index dde8ebff..8a31e10e 100644 --- a/tests/fast/api/test_dbapi09.py +++ 
b/tests/fast/api/test_dbapi09.py @@ -1,22 +1,23 @@ # date type -import numpy import datetime + +import numpy import pandas -class TestNumpyDate(object): +class TestNumpyDate: def test_fetchall_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchall() assert res == [(datetime.date(2020, 1, 10),)] def test_fetchnumpy_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchnumpy() - arr = numpy.array(['2020-01-10'], dtype="datetime64[s]") + arr = numpy.array(["2020-01-10"], dtype="datetime64[s]") arr = numpy.ma.masked_array(arr) - numpy.testing.assert_array_equal(res['test_date'], arr) + numpy.testing.assert_array_equal(res["test_date"], arr) def test_fetchdf_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchdf() - ser = pandas.Series(numpy.array(['2020-01-10'], dtype="datetime64[us]"), name="test_date") - pandas.testing.assert_series_equal(res['test_date'], ser) + ser = pandas.Series(numpy.array(["2020-01-10"], dtype="datetime64[us]"), name="test_date") + pandas.testing.assert_series_equal(res["test_date"], ser) diff --git a/tests/fast/api/test_dbapi10.py b/tests/fast/api/test_dbapi10.py index 1fbde602..6d60b27c 100644 --- a/tests/fast/api/test_dbapi10.py +++ b/tests/fast/api/test_dbapi10.py @@ -1,22 +1,29 @@ # cursor description -from datetime import datetime, date -from pytest import mark +from datetime import date, datetime + +import pytest + import duckdb -class TestCursorDescription(object): - @mark.parametrize( - "query,column_name,string_type,real_type", +class TestCursorDescription: + @pytest.mark.parametrize( + ("query", "column_name", "string_type", "real_type"), [ - ["SELECT * FROM integers", "i", "INTEGER", int], - ["SELECT * FROM timestamps", "t", "TIMESTAMP", datetime], - ["SELECT DATE '1992-09-20' AS date_col;", "date_col", "DATE", date], - ["SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BLOB", bytes], - ["SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", "struct_col", "STRUCT(x INTEGER, y INTEGER, z INTEGER)", dict], - ["SELECT [1, 2, 3] AS list_col", "list_col", "INTEGER[]", list], - ["SELECT 'Frank' AS str_col", "str_col", "VARCHAR", str], - ["SELECT [1, 2, 3]::JSON AS json_col", "json_col", "JSON", str], - ["SELECT union_value(tag := 1) AS union_col", "union_col", "UNION(tag INTEGER)", int], + ("SELECT * FROM integers", "i", "INTEGER", int), + ("SELECT * FROM timestamps", "t", "TIMESTAMP", datetime), + ("SELECT DATE '1992-09-20' AS date_col;", "date_col", "DATE", date), + ("SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BLOB", bytes), + ( + "SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", + "struct_col", + "STRUCT(x INTEGER, y INTEGER, z INTEGER)", + dict, + ), + ("SELECT [1, 2, 3] AS list_col", "list_col", "INTEGER[]", list), + ("SELECT 'Frank' AS str_col", "str_col", "VARCHAR", str), + ("SELECT [1, 2, 3]::JSON AS json_col", "json_col", "JSON", str), + ("SELECT union_value(tag := 1) AS union_col", "union_col", "UNION(tag INTEGER)", int), ], ) def test_description(self, query, column_name, string_type, real_type, duckdb_cursor, timestamps, integers): @@ -32,20 +39,20 @@ def test_description_comparisons(self): NUMBER = duckdb.NUMBER DATETIME = duckdb.DATETIME - assert(types[1] == STRING) - assert(STRING == types[1]) - assert(types[0] != STRING) - assert((types[1] != STRING) == False) - assert((STRING != types[1]) == False) + assert types[1] == STRING + assert STRING == types[1] # noqa: SIM300 + assert types[0] != STRING + assert 
types[1] == STRING + assert STRING == types[1] # noqa: SIM300 - assert(types[1] in [STRING]) - assert(types[1] in [STRING, NUMBER]) - assert(types[1] not in [NUMBER, DATETIME]) + assert types[1] in [STRING] + assert types[1] in [STRING, NUMBER] + assert types[1] not in [NUMBER, DATETIME] def test_none_description(self, duckdb_empty_cursor): assert duckdb_empty_cursor.description is None -class TestCursorRowcount(object): +class TestCursorRowcount: def test_rowcount(self, duckdb_cursor): assert duckdb_cursor.rowcount == -1 diff --git a/tests/fast/api/test_dbapi11.py b/tests/fast/api/test_dbapi11.py index 91237b9e..c5e9fe1c 100644 --- a/tests/fast/api/test_dbapi11.py +++ b/tests/fast/api/test_dbapi11.py @@ -1,24 +1,23 @@ # cursor description -import duckdb import tempfile -import os + +import duckdb def check_exception(f): had_exception = False try: f() - except: + except Exception: had_exception = True assert had_exception -class TestReadOnly(object): +class TestReadOnly: def test_readonly(self, duckdb_cursor): - fd, db = tempfile.mkstemp() - os.close(fd) - os.remove(db) + with tempfile.NamedTemporaryFile() as tmp: + db = tmp.name # this is forbidden check_exception(lambda: duckdb.connect(":memory:", True)) diff --git a/tests/fast/api/test_dbapi12.py b/tests/fast/api/test_dbapi12.py index 78881f5e..57881144 100644 --- a/tests/fast/api/test_dbapi12.py +++ b/tests/fast/api/test_dbapi12.py @@ -1,54 +1,53 @@ -import duckdb -import tempfile -import os import pandas as pd +import duckdb + -class TestRelationApi(object): +class TestRelationApi: def test_readonly(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["one", "two", "three"]}) - def test_rel(rel, duckdb_cursor): + def test_rel(rel, duckdb_cursor) -> None: res = ( - rel.filter('i < 3') - .order('j') - .project('i') - .union(rel.filter('i > 2').project('i')) - .join(rel.set_alias('a1'), 'i') - .project('CAST(i as BIGINT) i, j') - .order('i') + rel.filter("i < 3") + .order("j") + .project("i") + .union(rel.filter("i > 2").project("i")) + .join(rel.set_alias("a1"), "i") + .project("CAST(i as BIGINT) i, j") + .order("i") ) pd.testing.assert_frame_equal(res.to_df(), test_df) res3 = duckdb_cursor.from_df(res.to_df()).to_df() pd.testing.assert_frame_equal(res3, test_df) - df_sql = res.query('x', 'select CAST(i as BIGINT) i, j from x') + df_sql = res.query("x", "select CAST(i as BIGINT) i, j from x") pd.testing.assert_frame_equal(df_sql.df(), test_df) - res2 = res.aggregate('i, count(j) as cj', 'i').order('i') + res2 = res.aggregate("i, count(j) as cj", "i").order("i") cmp_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "cj": [1, 1, 1]}) pd.testing.assert_frame_equal(res2.to_df(), cmp_df) - duckdb_cursor.execute('DROP TABLE IF EXISTS a2') - rel.create('a2') - rel_a2 = duckdb_cursor.table('a2').project('CAST(i as BIGINT) i, j').to_df() + duckdb_cursor.execute("DROP TABLE IF EXISTS a2") + rel.create("a2") + rel_a2 = duckdb_cursor.table("a2").project("CAST(i as BIGINT) i, j").to_df() pd.testing.assert_frame_equal(rel_a2, test_df) - duckdb_cursor.execute('DROP TABLE IF EXISTS a3') - duckdb_cursor.execute('CREATE TABLE a3 (i INTEGER, j STRING)') - rel.insert_into('a3') - rel_a3 = duckdb_cursor.table('a3').project('CAST(i as BIGINT) i, j').to_df() + duckdb_cursor.execute("DROP TABLE IF EXISTS a3") + duckdb_cursor.execute("CREATE TABLE a3 (i INTEGER, j STRING)") + rel.insert_into("a3") + rel_a3 = duckdb_cursor.table("a3").project("CAST(i as BIGINT) i, j").to_df() pd.testing.assert_frame_equal(rel_a3, test_df) - 
duckdb_cursor.execute('CREATE TABLE a (i INTEGER, j STRING)') + duckdb_cursor.execute("CREATE TABLE a (i INTEGER, j STRING)") duckdb_cursor.execute("INSERT INTO a VALUES (1, 'one'), (2, 'two'), (3, 'three')") - duckdb_cursor.execute('CREATE VIEW v AS SELECT * FROM a') + duckdb_cursor.execute("CREATE VIEW v AS SELECT * FROM a") - duckdb_cursor.execute('CREATE TEMPORARY TABLE at_ (i INTEGER)') - duckdb_cursor.execute('CREATE TEMPORARY VIEW vt AS SELECT * FROM at_') + duckdb_cursor.execute("CREATE TEMPORARY TABLE at_ (i INTEGER)") + duckdb_cursor.execute("CREATE TEMPORARY VIEW vt AS SELECT * FROM at_") - rel_a = duckdb_cursor.table('a') - rel_v = duckdb_cursor.view('v') + rel_a = duckdb_cursor.table("a") + rel_v = duckdb_cursor.view("v") # rel_at = duckdb_cursor.table('at') # rel_vt = duckdb_cursor.view('vt') @@ -59,8 +58,8 @@ def test_rel(rel, duckdb_cursor): test_rel(rel_df, duckdb_cursor) def test_fromquery(self, duckdb_cursor): - assert duckdb.from_query('select 42').fetchone()[0] == 42 - assert duckdb_cursor.query('select 43').fetchone()[0] == 43 + assert duckdb.from_query("select 42").fetchone()[0] == 42 + assert duckdb_cursor.query("select 43").fetchone()[0] == 43 # assert duckdb_cursor.from_query('select 44').execute().fetchone()[0] == 44 # assert duckdb_cursor.from_query('select 45').execute().fetchone()[0] == 45 diff --git a/tests/fast/api/test_dbapi13.py b/tests/fast/api/test_dbapi13.py index fb7fbaa8..c08cefb1 100644 --- a/tests/fast/api/test_dbapi13.py +++ b/tests/fast/api/test_dbapi13.py @@ -1,11 +1,12 @@ # time type -import numpy import datetime + +import numpy import pandas -class TestNumpyTime(object): +class TestNumpyTime: def test_fetchall_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchall() assert res == [(datetime.time(13, 6, 40),)] @@ -14,9 +15,9 @@ def test_fetchnumpy_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchnumpy() arr = numpy.array([datetime.time(13, 6, 40)], dtype="object") arr = numpy.ma.masked_array(arr) - numpy.testing.assert_array_equal(res['test_time'], arr) + numpy.testing.assert_array_equal(res["test_time"], arr) def test_fetchdf_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchdf() ser = pandas.Series(numpy.array([datetime.time(13, 6, 40)], dtype="object"), name="test_time") - pandas.testing.assert_series_equal(res['test_time'], ser) + pandas.testing.assert_series_equal(res["test_time"], ser) diff --git a/tests/fast/api/test_dbapi_fetch.py b/tests/fast/api/test_dbapi_fetch.py index 6eda4b9d..97ff6fe6 100644 --- a/tests/fast/api/test_dbapi_fetch.py +++ b/tests/fast/api/test_dbapi_fetch.py @@ -1,28 +1,30 @@ -import duckdb -import pytest -from uuid import UUID import datetime from decimal import Decimal +from uuid import UUID + +import pytest + +import duckdb -class TestDBApiFetch(object): +class TestDBApiFetch: def test_multiple_fetch_one(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchone() == (42,) assert c.fetchone() is None assert c.fetchone() is None def test_multiple_fetch_all(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchall() == [(42,)] assert c.fetchall() == [] assert c.fetchall() == [] def test_multiple_fetch_many(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchmany(1000) == 
[(42,)] assert c.fetchmany(1000) == [] assert c.fetchmany(1000) == [] @@ -30,45 +32,45 @@ def test_multiple_fetch_many(self, duckdb_cursor): def test_multiple_fetch_df(self, duckdb_cursor): pd = pytest.importorskip("pandas") con = duckdb.connect() - c = con.execute('SELECT 42::BIGINT AS a') - pd.testing.assert_frame_equal(c.df(), pd.DataFrame.from_dict({'a': [42]})) + c = con.execute("SELECT 42::BIGINT AS a") + pd.testing.assert_frame_equal(c.df(), pd.DataFrame.from_dict({"a": [42]})) assert c.df() is None assert c.df() is None def test_multiple_fetch_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") - arrow = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") con = duckdb.connect() - c = con.execute('SELECT 42::BIGINT AS a') + c = con.execute("SELECT 42::BIGINT AS a") table = c.fetch_arrow_table() df = table.to_pandas() - pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict({'a': [42]})) + pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict({"a": [42]})) assert c.fetch_arrow_table() is None assert c.fetch_arrow_table() is None def test_multiple_close(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") c.close() c.close() c.close() - with pytest.raises(duckdb.InvalidInputException, match='No open result set'): + with pytest.raises(duckdb.InvalidInputException, match="No open result set"): c.fetchall() def test_multiple_fetch_all_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT 42') + res = duckdb_cursor.query("SELECT 42") assert res.fetchall() == [(42,)] assert res.fetchall() == [(42,)] assert res.fetchall() == [(42,)] def test_multiple_fetch_many_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT 42') + res = duckdb_cursor.query("SELECT 42") assert res.fetchmany(10000) == [(42,)] assert res.fetchmany(10000) == [] assert res.fetchmany(10000) == [] def test_fetch_one_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT * FROM range(3)') + res = duckdb_cursor.query("SELECT * FROM range(3)") assert res.fetchone() == (0,) assert res.fetchone() == (1,) assert res.fetchone() == (2,) @@ -86,40 +88,40 @@ def test_fetch_one_relation(self, duckdb_cursor): assert res.fetchone() is None @pytest.mark.parametrize( - 'test_case', + "test_case", [ - (False, 'BOOLEAN', False), - (-128, 'TINYINT', -128), - (-32768, 'SMALLINT', -32768), - (-2147483648, 'INTEGER', -2147483648), - (-9223372036854775808, 'BIGINT', -9223372036854775808), - (-170141183460469231731687303715884105728, 'HUGEINT', -170141183460469231731687303715884105728), - (0, 'UTINYINT', 0), - (0, 'USMALLINT', 0), - (0, 'UINTEGER', 0), - (0, 'UBIGINT', 0), - (0, 'UHUGEINT', 0), - (1.3423423767089844, 'FLOAT', 1.3423424), - (1.3423424, 'DOUBLE', 1.3423424), - (Decimal('1.342342'), 'DECIMAL(10, 6)', 1.342342), - ('hello', "ENUM('world', 'hello')", 'hello'), - ('🦆🦆🦆🦆🦆🦆', 'VARCHAR', '🦆🦆🦆🦆🦆🦆'), - (b'thisisalongblob\x00withnullbytes', 'BLOB', 'thisisalongblob\\x00withnullbytes'), - ('0010001001011100010101011010111', 'BITSTRING', '0010001001011100010101011010111'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP', '290309-12-22 (BC) 00:00:00'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP_MS', '290309-12-22 (BC) 00:00:00'), - (datetime.datetime(1677, 9, 22, 0, 0), 'TIMESTAMP_NS', '1677-09-22 00:00:00'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP_S', '290309-12-22 (BC) 00:00:00'), - ('290309-12-22 (BC) 00:00:30+00', 'TIMESTAMPTZ', '290309-12-22 (BC) 00:17:30+00:17'), + (False, "BOOLEAN", False), + (-128, "TINYINT", -128), 
+ (-32768, "SMALLINT", -32768), + (-2147483648, "INTEGER", -2147483648), + (-9223372036854775808, "BIGINT", -9223372036854775808), + (-170141183460469231731687303715884105728, "HUGEINT", -170141183460469231731687303715884105728), + (0, "UTINYINT", 0), + (0, "USMALLINT", 0), + (0, "UINTEGER", 0), + (0, "UBIGINT", 0), + (0, "UHUGEINT", 0), + (1.3423423767089844, "FLOAT", 1.3423424), + (1.3423424, "DOUBLE", 1.3423424), + (Decimal("1.342342"), "DECIMAL(10, 6)", 1.342342), + ("hello", "ENUM('world', 'hello')", "hello"), + ("🦆🦆🦆🦆🦆🦆", "VARCHAR", "🦆🦆🦆🦆🦆🦆"), + (b"thisisalongblob\x00withnullbytes", "BLOB", "thisisalongblob\\x00withnullbytes"), + ("0010001001011100010101011010111", "BITSTRING", "0010001001011100010101011010111"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP", "290309-12-22 (BC) 00:00:00"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP_MS", "290309-12-22 (BC) 00:00:00"), + (datetime.datetime(1677, 9, 22, 0, 0), "TIMESTAMP_NS", "1677-09-22 00:00:00"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP_S", "290309-12-22 (BC) 00:00:00"), + ("290309-12-22 (BC) 00:00:30+00", "TIMESTAMPTZ", "290309-12-22 (BC) 00:17:30+00:17"), ( datetime.time(0, 0, tzinfo=datetime.timezone(datetime.timedelta(seconds=57599))), - 'TIMETZ', - '00:00:00+15:59:59', + "TIMETZ", + "00:00:00+15:59:59", ), - ('5877642-06-25 (BC)', 'DATE', '5877642-06-25 (BC)'), - (UUID('cd57dfbd-d65f-4e15-991e-2a92e74b9f79'), 'UUID', 'cd57dfbd-d65f-4e15-991e-2a92e74b9f79'), - (datetime.timedelta(days=90), 'INTERVAL', '3 months'), - ('🦆🦆🦆🦆🦆🦆', 'UNION(a int, b bool, c varchar)', '🦆🦆🦆🦆🦆🦆'), + ("5877642-06-25 (BC)", "DATE", "5877642-06-25 (BC)"), + (UUID("cd57dfbd-d65f-4e15-991e-2a92e74b9f79"), "UUID", "cd57dfbd-d65f-4e15-991e-2a92e74b9f79"), + (datetime.timedelta(days=90), "INTERVAL", "3 months"), + ("🦆🦆🦆🦆🦆🦆", "UNION(a int, b bool, c varchar)", "🦆🦆🦆🦆🦆🦆"), ], ) def test_fetch_dict_coverage(self, duckdb_cursor, test_case): @@ -138,7 +140,7 @@ def test_fetch_dict_coverage(self, duckdb_cursor, test_case): print(res[0].keys()) assert res[0][python_key] == -2147483648 - @pytest.mark.parametrize('test_case', ['VARCHAR[]']) + @pytest.mark.parametrize("test_case", ["VARCHAR[]"]) def test_fetch_dict_key_not_hashable(self, duckdb_cursor, test_case): key_type = test_case query = f""" @@ -153,4 +155,4 @@ def test_fetch_dict_key_not_hashable(self, duckdb_cursor, test_case): select a from map_cte; """ res = duckdb_cursor.sql(query).fetchone() - assert 'key' in res[0].keys() + assert "key" in res[0] diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 4cb565c1..d197e639 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -1,7 +1,10 @@ +import re + +import pytest +from conftest import ArrowPandas, NumpyPandas + import duckdb import duckdb.typing -import pytest -from conftest import NumpyPandas, ArrowPandas pa = pytest.importorskip("pyarrow") @@ -9,9 +12,9 @@ def is_dunder_method(method_name: str) -> bool: if len(method_name) < 4: return False - if method_name.startswith('_pybind11'): + if method_name.startswith("_pybind11"): return True - return method_name[:2] == '__' and method_name[:-3:-1] == '__' + return method_name[:2] == "__" and method_name[:-3:-1] == "__" @pytest.fixture(scope="session") @@ -22,44 +25,44 @@ def tmp_database(tmp_path_factory): # This file contains tests for DuckDBPyConnection methods, # wrapped by the 'duckdb' module, to execute with the 'default_connection' -class TestDuckDBConnection(object): - @pytest.mark.parametrize('pandas', 
[NumpyPandas(), ArrowPandas()]) +class TestDuckDBConnection: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append(self, pandas): duckdb.execute("Create table integers (i integer)") df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) - duckdb.append('integers', df_in) - assert duckdb.execute('select count(*) from integers').fetchone()[0] == 5 + duckdb.append("integers", df_in) + assert duckdb.execute("select count(*) from integers").fetchone()[0] == 5 # cleanup duckdb.execute("drop table integers") def test_default_connection_from_connect(self): - duckdb.sql('create or replace table connect_default_connect (i integer)') - con = duckdb.connect(':default:') - con.sql('select i from connect_default_connect') - duckdb.sql('drop table connect_default_connect') + duckdb.sql("create or replace table connect_default_connect (i integer)") + con = duckdb.connect(":default:") + con.sql("select i from connect_default_connect") + duckdb.sql("drop table connect_default_connect") with pytest.raises(duckdb.Error): - con.sql('select i from connect_default_connect') + con.sql("select i from connect_default_connect") # not allowed with additional options with pytest.raises( - duckdb.InvalidInputException, match='Default connection fetching is only allowed without additional options' + duckdb.InvalidInputException, match="Default connection fetching is only allowed without additional options" ): - con = duckdb.connect(':default:', read_only=True) + con = duckdb.connect(":default:", read_only=True) def test_arrow(self): - pyarrow = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") duckdb.execute("select [1,2,3]") - result = duckdb.fetch_arrow_table() + duckdb.fetch_arrow_table() def test_begin_commit(self): duckdb.begin() duckdb.execute("create table tbl as select 1") duckdb.commit() - res = duckdb.table("tbl") + duckdb.table("tbl") duckdb.execute("drop table tbl") def test_begin_rollback(self): @@ -68,7 +71,7 @@ def test_begin_rollback(self): duckdb.rollback() with pytest.raises(duckdb.CatalogException): # Table does not exist - res = duckdb.table("tbl") + duckdb.table("tbl") def test_cursor(self): duckdb.execute("create table tbl as select 3") @@ -83,13 +86,10 @@ def test_cursor(self): def test_cursor_lifetime(self): con = duckdb.connect() - def use_cursors(): - cursors = [] - for _ in range(10): - cursors.append(con.cursor()) + def use_cursors() -> None: + cursors = [con.cursor() for _ in range(10)] for cursor in cursors: - print("closing cursor") cursor.close() use_cursors() @@ -98,7 +98,7 @@ def use_cursors(): def test_df(self): ref = [([1, 2, 3],)] duckdb.execute("select [1,2,3]") - res_df = duckdb.fetch_df() + res_df = duckdb.fetch_df() # noqa: F841 res = duckdb.query("select * from res_df").fetchall() assert res == ref @@ -114,98 +114,99 @@ def test_readonly_properties(self): duckdb.execute("select 42") description = duckdb.description() rowcount = duckdb.rowcount() - assert description == [('42', 'INTEGER', None, None, None, None, None)] + assert description == [("42", "INTEGER", None, None, None, None, None)] assert rowcount == -1 def test_execute(self): - assert [([4, 2],)] == duckdb.execute("select [4,2]").fetchall() + assert duckdb.execute("select [4,2]").fetchall() == [([4, 2],)] def test_executemany(self): # executemany does not keep an open result set - # TODO: shouldn't we also have a version that executes a query multiple times with different parameters, returning all of the results? 
+ # TODO: shouldn't we also have a version that executes a query multiple times with # noqa: TD002, TD003 + # different parameters, returning all of the results? duckdb.execute("create table tbl (i integer, j varchar)") - duckdb.executemany("insert into tbl VALUES (?, ?)", [(5, 'test'), (2, 'duck'), (42, 'quack')]) + duckdb.executemany("insert into tbl VALUES (?, ?)", [(5, "test"), (2, "duck"), (42, "quack")]) res = duckdb.table("tbl").fetchall() - assert res == [(5, 'test'), (2, 'duck'), (42, 'quack')] + assert res == [(5, "test"), (2, "duck"), (42, "quack")] duckdb.execute("drop table tbl") def test_pystatement(self): - with pytest.raises(duckdb.ParserException, match='seledct'): - statements = duckdb.extract_statements('seledct 42; select 21') + with pytest.raises(duckdb.ParserException, match="seledct"): + statements = duckdb.extract_statements("seledct 42; select 21") - statements = duckdb.extract_statements('select $1; select 21') + statements = duckdb.extract_statements("select $1; select 21") assert len(statements) == 2 - assert statements[0].query == 'select $1' + assert statements[0].query == "select $1" assert statements[0].type == duckdb.StatementType.SELECT - assert statements[0].named_parameters == set('1') + assert statements[0].named_parameters == set("1") assert statements[0].expected_result_type == [duckdb.ExpectedResultType.QUERY_RESULT] - assert statements[1].query == ' select 21' + assert statements[1].query == " select 21" assert statements[1].type == duckdb.StatementType.SELECT assert statements[1].named_parameters == set() with pytest.raises( duckdb.InvalidInputException, - match='Please provide either a DuckDBPyStatement or a string representing the query', + match="Please provide either a DuckDBPyStatement or a string representing the query", ): - rel = duckdb.query(statements) + duckdb.query(statements) with pytest.raises(duckdb.BinderException, match="This type of statement can't be prepared!"): - rel = duckdb.query(statements[0]) + duckdb.query(statements[0]) assert duckdb.query(statements[1]).fetchall() == [(21,)] assert duckdb.execute(statements[1]).fetchall() == [(21,)] with pytest.raises( duckdb.InvalidInputException, - match='Values were not provided for the following prepared statement parameters: 1', + match="Values were not provided for the following prepared statement parameters: 1", ): duckdb.execute(statements[0]) - assert duckdb.execute(statements[0], {'1': 42}).fetchall() == [(42,)] + assert duckdb.execute(statements[0], {"1": 42}).fetchall() == [(42,)] duckdb.execute("create table tbl(a integer)") - statements = duckdb.extract_statements('insert into tbl select $1') + statements = duckdb.extract_statements("insert into tbl select $1") assert statements[0].expected_result_type == [ duckdb.ExpectedResultType.CHANGED_ROWS, duckdb.ExpectedResultType.QUERY_RESULT, ] with pytest.raises( - duckdb.InvalidInputException, match='executemany requires a non-empty list of parameter sets to be provided' + duckdb.InvalidInputException, match="executemany requires a non-empty list of parameter sets to be provided" ): duckdb.executemany(statements[0]) duckdb.executemany(statements[0], [(21,), (22,), (23,)]) - assert duckdb.table('tbl').fetchall() == [(21,), (22,), (23,)] + assert duckdb.table("tbl").fetchall() == [(21,), (22,), (23,)] duckdb.execute("drop table tbl") def test_fetch_arrow_table(self): # Needed for 'fetch_arrow_table' - pyarrow = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") duckdb.execute("Create Table test (a integer)") for i in 
range(1024): - for j in range(2): - duckdb.execute("Insert Into test values ('" + str(i) + "')") + duckdb.execute("Insert Into test values ('" + str(i) + "')") + duckdb.execute("Insert Into test values ('" + str(i) + "')") duckdb.execute("Insert Into test values ('5000')") duckdb.execute("Insert Into test values ('6000')") - sql = ''' + sql = """ SELECT a, COUNT(*) AS repetitions FROM test GROUP BY a - ''' + """ result_df = duckdb.execute(sql).df() arrow_table = duckdb.execute(sql).fetch_arrow_table() arrow_df = arrow_table.to_pandas() - assert result_df['repetitions'].sum() == arrow_df['repetitions'].sum() + assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() duckdb.execute("drop table test") def test_fetch_df(self): ref = [([1, 2, 3],)] duckdb.execute("select [1,2,3]") - res_df = duckdb.fetch_df() + res_df = duckdb.fetch_df() # noqa: F841 res = duckdb.query("select * from res_df").fetchall() assert res == ref @@ -213,16 +214,16 @@ def test_fetch_df_chunk(self): duckdb.execute("CREATE table t as select range a from range(3000);") query = duckdb.execute("SELECT a FROM t") cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == 2048 cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 2048 + assert cur_chunk["a"][0] == 2048 assert len(cur_chunk) == 952 duckdb.execute("DROP TABLE t") def test_fetch_record_batch(self): # Needed for 'fetch_arrow_table' - pyarrow = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") duckdb.execute("CREATE table t as select range a from range(3000);") duckdb.execute("SELECT a FROM t") @@ -231,68 +232,68 @@ def test_fetch_record_batch(self): assert len(chunk) == 3000 def test_fetchall(self): - assert [([1, 2, 3],)] == duckdb.execute("select [1,2,3]").fetchall() + assert duckdb.execute("select [1,2,3]").fetchall() == [([1, 2, 3],)] def test_fetchdf(self): ref = [([1, 2, 3],)] duckdb.execute("select [1,2,3]") - res_df = duckdb.fetchdf() + res_df = duckdb.fetchdf() # noqa: F841 res = duckdb.query("select * from res_df").fetchall() assert res == ref def test_fetchmany(self): - assert [(0,), (1,)] == duckdb.execute("select * from range(5)").fetchmany(2) + assert duckdb.execute("select * from range(5)").fetchmany(2) == [(0,), (1,)] def test_fetchnumpy(self): numpy = pytest.importorskip("numpy") duckdb.execute("SELECT BLOB 'hello'") results = duckdb.fetchall() - assert results[0][0] == b'hello' + assert results[0][0] == b"hello" duckdb.execute("SELECT BLOB 'hello' AS a") results = duckdb.fetchnumpy() - assert results['a'] == numpy.array([b'hello'], dtype=object) + assert results["a"] == numpy.array([b"hello"], dtype=object) def test_fetchone(self): - assert (0,) == duckdb.execute("select * from range(5)").fetchone() + assert duckdb.execute("select * from range(5)").fetchone() == (0,) def test_from_arrow(self): - assert None != duckdb.from_arrow + assert duckdb.from_arrow is not None def test_from_csv_auto(self): - assert None != duckdb.from_csv_auto + assert duckdb.from_csv_auto is not None def test_from_df(self): - assert None != duckdb.from_df + assert duckdb.from_df is not None def test_from_parquet(self): - assert None != duckdb.from_parquet + assert duckdb.from_parquet is not None def test_from_query(self): - assert None != duckdb.from_query + assert duckdb.from_query is not None def test_get_table_names(self): - assert None != duckdb.get_table_names + assert duckdb.get_table_names is not None def test_install_extension(self): - assert None != 
duckdb.install_extension + assert duckdb.install_extension is not None def test_load_extension(self): - assert None != duckdb.load_extension + assert duckdb.load_extension is not None def test_query(self): - assert [(3,)] == duckdb.query("select 3").fetchall() + assert duckdb.query("select 3").fetchall() == [(3,)] def test_register(self): - assert None != duckdb.register + assert duckdb.register is not None def test_register_relation(self): con = duckdb.connect() - rel = con.sql('select [5,4,3]') + rel = con.sql("select [5,4,3]") con.register("relation", rel) con.sql("create table tbl as select * from relation") - assert con.table('tbl').fetchall() == [([5, 4, 3],)] + assert con.table("tbl").fetchall() == [([5, 4, 3],)] def test_unregister_problematic_behavior(self, duckdb_cursor): # We have a VIEW called 'vw' in the Catalog @@ -302,65 +303,65 @@ def test_unregister_problematic_behavior(self, duckdb_cursor): # Create a registered object called 'vw' arrow_result = duckdb_cursor.execute("select 42").fetch_arrow_table() with pytest.raises(duckdb.CatalogException, match='View with name "vw" already exists'): - duckdb_cursor.register('vw', arrow_result) + duckdb_cursor.register("vw", arrow_result) # Temporary views take precedence over registered objects assert duckdb_cursor.execute("select * from vw").fetchone() == (0,) # Decide that we're done with this registered object.. - duckdb_cursor.unregister('vw') + duckdb_cursor.unregister("vw") # This should not have affected the existing view: assert duckdb_cursor.execute("select * from vw").fetchone() == (0,) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_out_of_scope(self, pandas): def temporary_scope(): # Create a connection, we will return this con = duckdb.connect() # Create a dataframe - df = pandas.DataFrame({'a': [1, 2, 3]}) + df = pandas.DataFrame({"a": [1, 2, 3]}) # The dataframe has to be registered as well # making sure it does not go out of scope con.register("df", df) - rel = con.sql('select * from df') + rel = con.sql("select * from df") con.register("relation", rel) return con con = temporary_scope() - res = con.sql('select * from relation').fetchall() + res = con.sql("select * from relation").fetchall() print(res) def test_table(self): con = duckdb.connect() con.execute("create table tbl as select 1") - assert [(1,)] == con.table("tbl").fetchall() + assert con.table("tbl").fetchall() == [(1,)] def test_table_function(self): - assert None != duckdb.table_function + assert duckdb.table_function is not None def test_unregister(self): - assert None != duckdb.unregister + assert duckdb.unregister is not None def test_values(self): - assert None != duckdb.values + assert duckdb.values is not None def test_view(self): duckdb.execute("create view vw as select range(5)") - assert [([0, 1, 2, 3, 4],)] == duckdb.view("vw").fetchall() + assert duckdb.view("vw").fetchall() == [([0, 1, 2, 3, 4],)] duckdb.execute("drop view vw") def test_close(self): - assert None != duckdb.close + assert duckdb.close is not None def test_interrupt(self): - assert None != duckdb.interrupt + assert duckdb.interrupt is not None def test_wrap_shadowing(self): pd = NumpyPandas() import duckdb - df = pd.DataFrame({"a": [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) # noqa: F841 res = duckdb.sql("from df").fetchall() assert res == [(1,), (2,), (3,)] @@ -382,7 +383,8 @@ def test_connect_with_path(self, tmp_database): assert con.sql("select 42").fetchall() == 
[(42,)] with pytest.raises( - duckdb.InvalidInputException, match="Please provide either a str or a pathlib.Path, not " + duckdb.InvalidInputException, + match=re.escape("Please provide either a str or a pathlib.Path, not "), ): con = duckdb.connect(5) @@ -393,7 +395,7 @@ def test_set_pandas_analyze_sample_size(self): # Find the cached config con2 = duckdb.connect(":memory:named", config={"pandas_analyze_sample": 0}) - con2.execute(f"SET GLOBAL pandas_analyze_sample=2") + con2.execute("SET GLOBAL pandas_analyze_sample=2") # This change is reflected in 'con' because the instance was cached res = con.sql("select current_setting('pandas_analyze_sample')").fetchone() diff --git a/tests/fast/api/test_duckdb_execute.py b/tests/fast/api/test_duckdb_execute.py index fba01a0c..389659be 100644 --- a/tests/fast/api/test_duckdb_execute.py +++ b/tests/fast/api/test_duckdb_execute.py @@ -1,11 +1,12 @@ -import duckdb import pytest +import duckdb + -class TestDuckDBExecute(object): +class TestDuckDBExecute: def test_execute_basic(self, duckdb_cursor): - duckdb_cursor.execute('create table t as select 5') - res = duckdb_cursor.table('t').fetchall() + duckdb_cursor.execute("create table t as select 5") + res = duckdb_cursor.table("t").fetchall() assert res == [(5,)] def test_execute_many_basic(self, duckdb_cursor): @@ -19,11 +20,11 @@ def test_execute_many_basic(self, duckdb_cursor): """, (99,), ) - res = duckdb_cursor.table('t').fetchall() + res = duckdb_cursor.table("t").fetchall() assert res == [(99,)] @pytest.mark.parametrize( - 'rowcount', + "rowcount", [ 50, 2048, @@ -40,7 +41,7 @@ def generator(rowcount): yield min(2048, rowcount - count) count += 2048 - # FIXME: perhaps we want to test with different buffer sizes? + # TODO: perhaps we want to test with different buffer sizes? 
# noqa: TD002, TD003 # duckdb_cursor.execute("set streaming_buffer_size='1mb'") duckdb_cursor.execute(f"create table tbl as from range({rowcount})") duckdb_cursor.execute("select * from tbl") @@ -53,7 +54,7 @@ def test_execute_many_error(self, duckdb_cursor): # Prepared parameter used in a statement that is not the last with pytest.raises( - duckdb.NotImplementedException, match='Prepared parameters are only supported for the last statement' + duckdb.NotImplementedException, match="Prepared parameters are only supported for the last statement" ): duckdb_cursor.execute( """ @@ -67,17 +68,16 @@ def test_execute_many_generator(self, duckdb_cursor): to_insert = [[1], [2], [3]] def to_insert_from_generator(what): - for x in what: - yield x + yield from what gen = to_insert_from_generator(to_insert) duckdb_cursor.execute("CREATE TABLE unittest_generator (a INTEGER);") duckdb_cursor.executemany("INSERT into unittest_generator (a) VALUES (?)", gen) - assert duckdb_cursor.table('unittest_generator').fetchall() == [(1,), (2,), (3,)] + assert duckdb_cursor.table("unittest_generator").fetchall() == [(1,), (2,), (3,)] def test_execute_multiple_statements(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [5, 6, 7, 8]}) + df = pd.DataFrame({"a": [5, 6, 7, 8]}) # noqa: F841 sql = """ select * from df; select * from VALUES (1),(2),(3),(4) t(a); diff --git a/tests/fast/api/test_duckdb_query.py b/tests/fast/api/test_duckdb_query.py index 43f36603..04531e49 100644 --- a/tests/fast/api/test_duckdb_query.py +++ b/tests/fast/api/test_duckdb_query.py @@ -1,44 +1,45 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb from duckdb import Value -class TestDuckDBQuery(object): +class TestDuckDBQuery: def test_duckdb_query(self, duckdb_cursor): # we can use duckdb_cursor.sql to run both DDL statements and select statements - duckdb_cursor.sql('create view v1 as select 42 i') - rel = duckdb_cursor.sql('select * from v1') + duckdb_cursor.sql("create view v1 as select 42 i") + rel = duckdb_cursor.sql("select * from v1") assert rel.fetchall()[0][0] == 42 # also multiple statements - duckdb_cursor.sql('create view v2 as select i*2 j from v1; create view v3 as select j * 2 from v2;') - rel = duckdb_cursor.sql('select * from v3') + duckdb_cursor.sql("create view v2 as select i*2 j from v1; create view v3 as select j * 2 from v2;") + rel = duckdb_cursor.sql("select * from v3") assert rel.fetchall()[0][0] == 168 # we can run multiple select statements - we get only the last result - res = duckdb_cursor.sql('select 42; select 84;').fetchall() + res = duckdb_cursor.sql("select 42; select 84;").fetchall() assert res == [(84,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_duckdb_from_query_multiple_statements(self, pandas): - tst_df = pandas.DataFrame({'a': [1, 23, 3, 5]}) + tst_df = pandas.DataFrame({"a": [1, 23, 3, 5]}) # noqa: F841 res = duckdb.sql( - ''' + """ select 42; select * from tst_df union all select * from tst_df; - ''' + """ ).fetchall() assert res == [(1,), (23,), (3,), (5,), (1,), (23,), (3,), (5,)] def test_duckdb_query_empty_result(self): con = duckdb.connect() # show tables on empty connection does not produce any tuples - res = con.query('show tables').fetchall() + res = con.query("show tables").fetchall() assert res == [] def test_parametrized_explain(self, duckdb_cursor): @@ -57,7 +58,7 
@@ def test_parametrized_explain(self, duckdb_cursor): duckdb_cursor.execute(query, params) results = duckdb_cursor.fetchall() - assert 'EXPLAIN_ANALYZE' in results[0][1] + assert "EXPLAIN_ANALYZE" in results[0][1] def test_named_param(self): con = duckdb.connect() @@ -83,7 +84,7 @@ def test_named_param(self): from range(100) tbl(i) """, - {'param': 5, 'other_param': 10}, + {"param": 5, "other_param": 10}, ).fetchall() assert res == original_res @@ -95,14 +96,14 @@ def test_named_param_not_dict(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: name1, name2, name3", ): - con.execute("select $name1, $name2, $name3", ['name1', 'name2', 'name3']) + con.execute("select $name1, $name2, $name3", ["name1", "name2", "name3"]) def test_named_param_basic(self): con = duckdb.connect() - res = con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'name3': 'a'}).fetchall() + res = con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3, "name3": "a"}).fetchall() assert res == [ - (5, 3, 'a'), + (5, 3, "a"), ] def test_named_param_not_exhaustive(self): @@ -110,9 +111,9 @@ def test_named_param_not_exhaustive(self): with pytest.raises( duckdb.InvalidInputException, - match="Invalid Input Error: Values were not provided for the following prepared statement parameters: name3", + match="Invalid Input Error: Values were not provided for the following prepared statement parameters: name3", # noqa: E501 ): - con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3}) + con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3}) def test_named_param_excessive(self): con = duckdb.connect() @@ -121,7 +122,7 @@ def test_named_param_excessive(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: name3", ): - con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'not_a_named_param': 5}) + con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3, "not_a_named_param": 5}) def test_named_param_not_named(self): con = duckdb.connect() @@ -130,7 +131,7 @@ def test_named_param_not_named(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: 1, 2", ): - con.execute("select $1, $1, $2", {'name1': 5, 'name2': 3}) + con.execute("select $1, $1, $2", {"name1": 5, "name2": 3}) def test_named_param_mixed(self): con = duckdb.connect() @@ -138,13 +139,13 @@ def test_named_param_mixed(self): with pytest.raises( duckdb.NotImplementedException, match="Mixing named and positional parameters is not supported yet" ): - con.execute("select $name1, $1, $2", {'name1': 5, 'name2': 3}) + con.execute("select $name1, $1, $2", {"name1": 5, "name2": 3}) def test_named_param_strings_with_dollarsign(self): con = duckdb.connect() - res = con.execute("select '$name1', $name1, $name1, '$name1'", {'name1': 5}).fetchall() - assert res == [('$name1', 5, 5, '$name1')] + res = con.execute("select '$name1', $name1, $name1, '$name1'", {"name1": 5}).fetchall() + assert res == [("$name1", 5, 5, "$name1")] def test_named_param_case_insensivity(self): con = duckdb.connect() @@ -153,10 +154,10 @@ def test_named_param_case_insensivity(self): """ select $NaMe1, $NAME2, $name3 """, - {'name1': 5, 'nAmE2': 3, 'NAME3': 'a'}, + {"name1": 5, "nAmE2": 3, "NAME3": "a"}, ).fetchall() assert res == [ - (5, 3, 'a'), + (5, 3, "a"), ] def test_named_param_keyword(self): @@ -176,16 +177,16 @@ def 
test_conversion_from_tuple(self): assert result == [([21, 22, 42],)] # If wrapped in a Value, it can convert to a struct - result = con.execute("select $1", [Value(('a', 21, True), {'a': str, 'b': int, 'c': bool})]).fetchall() - assert result == [({'a': 'a', 'b': 21, 'c': True},)] + result = con.execute("select $1", [Value(("a", 21, True), {"a": str, "b": int, "c": bool})]).fetchall() + assert result == [({"a": "a", "b": 21, "c": True},)] # If the amount of items in the tuple and the children of the struct don't match # we throw an error with pytest.raises( duckdb.InvalidInputException, - match='Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children', + match="Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children", # noqa: E501 ): - result = con.execute("select $1", [Value(('a', 21, True), {'a': str, 'b': int})]).fetchall() + result = con.execute("select $1", [Value(("a", 21, True), {"a": str, "b": int})]).fetchall() # If we try to create anything other than a STRUCT or a LIST out of the tuple, we throw an error with pytest.raises(duckdb.InvalidInputException, match="Can't convert tuple to a Value of type VARCHAR"): @@ -194,12 +195,12 @@ def test_conversion_from_tuple(self): def test_column_name_behavior(self, duckdb_cursor): _ = pytest.importorskip("pandas") - expected_names = ['one', 'ONE_1'] + expected_names = ["one", "ONE_1"] df = duckdb_cursor.execute('select 1 as one, 2 as "ONE"').fetchdf() assert expected_names == list(df.columns) - duckdb_cursor.register('tbl', df) + duckdb_cursor.register("tbl", df) df = duckdb_cursor.execute("select * from tbl").fetchdf() assert expected_names == list(df.columns) diff --git a/tests/fast/api/test_explain.py b/tests/fast/api/test_explain.py index 73c198b9..61ea979c 100644 --- a/tests/fast/api/test_explain.py +++ b/tests/fast/api/test_explain.py @@ -1,43 +1,41 @@ import pytest + import duckdb -class TestExplain(object): +class TestExplain: def test_explain_basic(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain() + res = duckdb_cursor.sql("select 42").explain() assert isinstance(res, str) def test_explain_standard(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain('standard') - assert isinstance(res, str) - - res = duckdb_cursor.sql('select 42').explain('STANDARD') + res = duckdb_cursor.sql("select 42").explain("standard") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.STANDARD) + res = duckdb_cursor.sql("select 42").explain("STANDARD") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.ExplainType.STANDARD) + res = duckdb_cursor.sql("select 42").explain(duckdb.ExplainType.STANDARD) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(0) + res = duckdb_cursor.sql("select 42").explain(0) assert isinstance(res, str) def test_explain_analyze(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain('analyze') + res = duckdb_cursor.sql("select 42").explain("analyze") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain('ANALYZE') + res = duckdb_cursor.sql("select 42").explain("ANALYZE") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.ExplainType.ANALYZE) + res = duckdb_cursor.sql("select 42").explain(duckdb.ExplainType.ANALYZE) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(1) + res = duckdb_cursor.sql("select 42").explain(1) assert 
isinstance(res, str) def test_explain_df(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [42]}) - res = duckdb_cursor.sql('select * from df').explain('ANALYZE') + df = pd.DataFrame({"a": [42]}) # noqa: F841 + res = duckdb_cursor.sql("select * from df").explain("ANALYZE") assert isinstance(res, str) diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index 0a289972..154f38bd 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -1,28 +1,27 @@ -import pytest -import duckdb -import io import datetime +import io + +import pytest fsspec = pytest.importorskip("fsspec") -class TestReadParquet(object): +class TestReadParquet: def test_fsspec_deadlock(self, duckdb_cursor, tmp_path): # Create test parquet data file_path = tmp_path / "data.parquet" - duckdb_cursor.sql("COPY (FROM range(50_000)) TO '{}' (FORMAT parquet)".format(str(file_path))) - with open(file_path, "rb") as f: - parquet_data = f.read() + duckdb_cursor.sql(f"COPY (FROM range(50_000)) TO '{file_path!s}' (FORMAT parquet)") + parquet_data = file_path.read_bytes() class TestFileSystem(fsspec.AbstractFileSystem): protocol = "deadlock" @property - def fsid(self): + def fsid(self) -> str: return "deadlock" def ls(self, path, detail=True, **kwargs): - vals = [k for k in self._data.keys() if k.startswith(path)] + vals = [k for k in self._data if k.startswith(path)] if detail: return [ { @@ -44,12 +43,12 @@ def modified(self, path): def _open(self, path, **kwargs): return io.BytesIO(self._data[path]) - def __init__(self): + def __init__(self) -> None: super().__init__() self._data = {"a": parquet_data, "b": parquet_data} fsspec.register_implementation("deadlock", TestFileSystem, clobber=True) - fs = fsspec.filesystem('deadlock') + fs = fsspec.filesystem("deadlock") duckdb_cursor.register_filesystem(fs) result = duckdb_cursor.read_parquet(file_globs=["deadlock://a", "deadlock://b"], union_by_name=True) diff --git a/tests/fast/api/test_insert_into.py b/tests/fast/api/test_insert_into.py index e6d4c6ba..1214203b 100644 --- a/tests/fast/api/test_insert_into.py +++ b/tests/fast/api/test_insert_into.py @@ -1,28 +1,29 @@ -import duckdb -from pandas import DataFrame import pytest +from pandas import DataFrame + +import duckdb -class TestInsertInto(object): +class TestInsertInto: def test_insert_into_schema(self, duckdb_cursor): # open connection con = duckdb.connect() - con.execute('CREATE SCHEMA s') - con.execute('CREATE TABLE s.t (id INTEGER PRIMARY KEY)') + con.execute("CREATE SCHEMA s") + con.execute("CREATE TABLE s.t (id INTEGER PRIMARY KEY)") # make relation - df = DataFrame([1], columns=['id']) + df = DataFrame([1], columns=["id"]) rel = con.from_df(df) - rel.insert_into('s.t') + rel.insert_into("s.t") assert con.execute("select * from s.t").fetchall() == [(1,)] # This should fail since this will go to default schema with pytest.raises(duckdb.CatalogException): - rel.insert_into('t') + rel.insert_into("t") # If we add t in the default schema it should work. 
- con.execute('CREATE TABLE t (id INTEGER PRIMARY KEY)') - rel.insert_into('t') + con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)") + rel.insert_into("t") assert con.execute("select * from t").fetchall() == [(1,)] diff --git a/tests/fast/api/test_join.py b/tests/fast/api/test_join.py index 7d7f45c2..be311ec0 100644 --- a/tests/fast/api/test_join.py +++ b/tests/fast/api/test_join.py @@ -1,14 +1,15 @@ -import duckdb import pytest +import duckdb + -class TestJoin(object): +class TestJoin: def test_alias_from_sql(self): con = duckdb.connect() - rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") - rel2 = con.sql("SELECT 1 AS col1, 3 AS col3") + rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") # noqa: F841 + rel2 = con.sql("SELECT 1 AS col1, 3 AS col3") # noqa: F841 - rel = con.sql('select * from rel1 JOIN rel2 USING (col1)') + rel = con.sql("select * from rel1 JOIN rel2 USING (col1)") rel.show() res = rel.fetchall() assert res == [(1, 2, 3)] @@ -19,27 +20,27 @@ def test_relational_join(self): rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") rel2 = con.sql("SELECT 1 AS col1, 3 AS col3") - rel = rel1.join(rel2, 'col1') + rel = rel1.join(rel2, "col1") res = rel.fetchall() assert res == [(1, 2, 3)] def test_relational_join_alias_collision(self): con = duckdb.connect() - rel1 = con.sql("SELECT 1 AS col1, 2 AS col2").set_alias('a') - rel2 = con.sql("SELECT 1 AS col1, 3 AS col3").set_alias('a') + rel1 = con.sql("SELECT 1 AS col1, 2 AS col2").set_alias("a") + rel2 = con.sql("SELECT 1 AS col1, 3 AS col3").set_alias("a") - with pytest.raises(duckdb.InvalidInputException, match='Both relations have the same alias'): - rel = rel1.join(rel2, 'col1') + with pytest.raises(duckdb.InvalidInputException, match="Both relations have the same alias"): + rel1.join(rel2, "col1") def test_relational_join_with_condition(self): con = duckdb.connect() - rel1 = con.sql("SELECT 1 AS col1, 2 AS col2", alias='rel1') - rel2 = con.sql("SELECT 1 AS col1, 3 AS col3", alias='rel2') + rel1 = con.sql("SELECT 1 AS col1, 2 AS col2", alias="rel1") + rel2 = con.sql("SELECT 1 AS col1, 3 AS col3", alias="rel2") # This makes a USING clause, which is kind of unexpected behavior - rel = rel1.join(rel2, 'rel1.col1 = rel2.col1') + rel = rel1.join(rel2, "rel1.col1 = rel2.col1") rel.show() res = rel.fetchall() assert res == [(1, 2, 1, 3)] @@ -49,8 +50,8 @@ def test_deduplicated_bindings(self, duckdb_cursor): duckdb_cursor.execute("create table old as select * from (values ('42', 1), ('21', 2)) t(a, b)") duckdb_cursor.execute("create table old_1 as select * from (values ('42', 3), ('21', 4)) t(a, b)") - old = duckdb_cursor.table('old') - old_1 = duckdb_cursor.table('old_1') + old = duckdb_cursor.table("old") + old_1 = duckdb_cursor.table("old_1") join_one = old.join(old_1, "old.a == old_1.a") join_two = old.join(old_1, "old.a == old_1.a") diff --git a/tests/fast/api/test_native_tz.py b/tests/fast/api/test_native_tz.py index 6098ca08..66b06565 100644 --- a/tests/fast/api/test_native_tz.py +++ b/tests/fast/api/test_native_tz.py @@ -1,35 +1,40 @@ -import duckdb import datetime -import pytz -import os +from pathlib import Path + import pytest +from packaging.version import Version + +import duckdb pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -from packaging.version import Version -filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'tz.parquet') +filename = str(Path(__file__).parent / ".." 
/ "data" / "tz.parquet") -class TestNativeTimeZone(object): +class TestNativeTimeZone: def test_native_python_timestamp_timezone(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() - assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].hour == 14 + assert res[0].minute == 52 + assert res[0].tzinfo.zone == "America/Los_Angeles" res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchall()[0] - assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].hour == 14 + assert res[0].minute == 52 + assert res[0].tzinfo.zone == "America/Los_Angeles" res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchmany(1)[0] - assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].hour == 14 + assert res[0].minute == 52 + assert res[0].tzinfo.zone == "America/Los_Angeles" duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() - assert res[0].hour == 21 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'UTC' + assert res[0].hour == 21 + assert res[0].minute == 52 + assert res[0].tzinfo.zone == "UTC" def test_native_python_time_timezone(self, duckdb_cursor): res = duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").fetchone() @@ -41,33 +46,37 @@ def test_native_python_time_timezone(self, duckdb_cursor): def test_pandas_timestamp_timezone(self, duckdb_cursor): res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").df() - assert res.dtypes["tz"].tz.zone == 'America/Los_Angeles' - assert res['tz'][0].hour == 14 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert res["tz"][0].hour == 14 + assert res["tz"][0].minute == 52 duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").df() - assert res['tz'][0].hour == 21 and res['tz'][0].minute == 52 + assert res["tz"][0].hour == 21 + assert res["tz"][0].minute == 52 def test_pandas_timestamp_time(self, duckdb_cursor): with pytest.raises( - duckdb.NotImplementedException, match="Not implemented Error: Unsupported type \"TIME WITH TIME ZONE\"" + duckdb.NotImplementedException, match='Not implemented Error: Unsupported type "TIME WITH TIME ZONE"' ): duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").df() @pytest.mark.skipif( - Version(pa.__version__) < Version('15.0.0'), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" + Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) def test_arrow_timestamp_timezone(self, duckdb_cursor): res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") table = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table() res = table.to_pandas() - assert res.dtypes["tz"].tz.zone == 'America/Los_Angeles' - assert res['tz'][0].hour == 14 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert res["tz"][0].hour == 14 + assert res["tz"][0].minute == 52 duckdb_cursor.execute("SET timezone='UTC';") res = 
duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table().to_pandas() - assert res.dtypes["tz"].tz.zone == 'UTC' - assert res['tz'][0].hour == 21 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "UTC" + assert res["tz"][0].hour == 21 + assert res["tz"][0].minute == 52 def test_arrow_timestamp_time(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") @@ -81,8 +90,10 @@ def test_arrow_timestamp_time(self, duckdb_cursor): .fetch_arrow_table() .to_pandas() ) - assert res1['tz'][0].hour == 14 and res1['tz'][0].minute == 52 - assert res2['tz'][0].hour == res2['tz'][0].hour and res2['tz'][0].minute == res1['tz'][0].minute + assert res1["tz"][0].hour == 14 + assert res1["tz"][0].minute == 52 + assert res2["tz"][0].hour == res2["tz"][0].hour + assert res2["tz"][0].minute == res1["tz"][0].minute duckdb_cursor.execute("SET timezone='UTC';") res1 = ( @@ -95,5 +106,7 @@ def test_arrow_timestamp_time(self, duckdb_cursor): .fetch_arrow_table() .to_pandas() ) - assert res1['tz'][0].hour == 21 and res1['tz'][0].minute == 52 - assert res2['tz'][0].hour == res2['tz'][0].hour and res2['tz'][0].minute == res1['tz'][0].minute + assert res1["tz"][0].hour == 21 + assert res1["tz"][0].minute == 52 + assert res2["tz"][0].hour == res2["tz"][0].hour + assert res2["tz"][0].minute == res1["tz"][0].minute diff --git a/tests/fast/api/test_query_interrupt.py b/tests/fast/api/test_query_interrupt.py index 6334e475..4a5a02e5 100644 --- a/tests/fast/api/test_query_interrupt.py +++ b/tests/fast/api/test_query_interrupt.py @@ -1,10 +1,11 @@ -import duckdb +import _thread as thread +import platform +import threading import time + import pytest -import platform -import threading -import _thread as thread +import duckdb def send_keyboard_interrupt(): @@ -14,7 +15,7 @@ def send_keyboard_interrupt(): thread.interrupt_main() -class TestQueryInterruption(object): +class TestQueryInterruption: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="Emscripten builds cannot use threads", @@ -25,11 +26,11 @@ def test_query_interruption(self): # Start the thread thread.start() try: - res = con.execute('select count(*) from range(100000000000)').fetchall() + con.execute("select count(*) from range(100000000000)").fetchall() except RuntimeError: # If this is not reached, we could not cancel the query before it completed # indicating that the query interruption functionality is broken assert True except KeyboardInterrupt: - pytest.fail() + pytest.fail("Interrupted by user") thread.join() diff --git a/tests/fast/api/test_query_progress.py b/tests/fast/api/test_query_progress.py index f885e36d..c57a88c3 100644 --- a/tests/fast/api/test_query_progress.py +++ b/tests/fast/api/test_query_progress.py @@ -1,12 +1,14 @@ +import contextlib import platform import threading import time -import duckdb import pytest +import duckdb + -class TestQueryProgress(object): +class TestQueryProgress: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", @@ -17,13 +19,10 @@ def test_query_progress(self, reraise): conn.sql("set progress_bar_time=0") conn.sql("create table t as (select range as n from range(10000000))") - def thread_target(): + def thread_target() -> None: # run a very slow query which hopefully isn't too memory intensive. 
- with reraise: - try: - conn.execute("select max(sha1(n::varchar)) from t").fetchall() - except duckdb.InterruptException: - pass + with reraise, contextlib.suppress(duckdb.InterruptException): + conn.execute("select max(sha1(n::varchar)) from t").fetchall() thread = threading.Thread(target=thread_target) thread.start() @@ -33,7 +32,7 @@ def thread_target(): # query never progresses. This will also fail if the query is too # quick as it will be back at -1 as soon as the query is finished. - for _ in range(0, 500): + for _ in range(500): assert thread.is_alive(), "query finished too quick" if (qp1 := conn.query_progress()) > 0: break @@ -42,7 +41,7 @@ def thread_target(): pytest.fail("query start timeout") # keep monitoring and wait for the progress to increase - for _ in range(0, 500): + for _ in range(500): assert thread.is_alive(), "query finished too quick" if (qp2 := conn.query_progress()) > qp1: break diff --git a/tests/fast/api/test_read_csv.py b/tests/fast/api/test_read_csv.py index 1a297109..e7862e9b 100644 --- a/tests/fast/api/test_read_csv.py +++ b/tests/fast/api/test_read_csv.py @@ -1,18 +1,19 @@ -from multiprocessing.sharedctypes import Value import datetime -import pytest import platform +import sys +from io import BytesIO, StringIO +from pathlib import Path +from typing import NoReturn + +import pytest + import duckdb -from io import StringIO, BytesIO from duckdb import CSVLineTerminator -import sys def TestFile(name): - import os - - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', name) - return filename + filename = Path(__file__).parent / ".." / "data" / name + return str(filename) @pytest.fixture @@ -33,264 +34,263 @@ def create_temp_csv(tmp_path): return file1_path, file2_path -class TestReadCSV(object): +class TestReadCSV: def test_using_connection_wrapper(self): - rel = duckdb.read_csv(TestFile('category.csv')) + rel = duckdb.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_using_connection_wrapper_with_keyword(self): - rel = duckdb.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + rel = duckdb.read_csv(TestFile("category.csv"), dtype={"category_id": "string"}) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_no_options(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv')) + rel = duckdb_cursor.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_dtype(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype={"category_id": "string"}) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_dtype_as_list(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['string']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype=["string"]) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 
4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['double']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype=["double"]) res = rel.fetchone() print(res) - assert res == (1.0, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1.0, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_sep(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), sep=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), sep=" ") res = rel.fetchone() print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + assert res == ("1|Action|2006-02-15", datetime.time(4, 46, 27)) def test_delimiter(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), delimiter=" ") res = rel.fetchone() print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + assert res == ("1|Action|2006-02-15", datetime.time(4, 46, 27)) def test_delimiter_and_sep(self, duckdb_cursor): with pytest.raises(duckdb.InvalidInputException, match="read_csv takes either 'delimiter' or 'sep', not both"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ", sep=" ") + duckdb_cursor.read_csv(TestFile("category.csv"), delimiter=" ", sep=" ") def test_header_true(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv')) + rel = duckdb_cursor.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) - @pytest.mark.skip(reason="Issue #6011 needs to be fixed first, header=False doesn't work correctly") def test_header_false(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), header=False) + duckdb_cursor.read_csv(TestFile("category.csv"), header=False) def test_na_values(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values='Action') + rel = duckdb_cursor.read_csv(TestFile("category.csv"), na_values="Action") res = rel.fetchone() print(res) assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_na_values_list(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values=['Action', 'Animation']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), na_values=["Action", "Animation"]) res = rel.fetchone() assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) res = rel.fetchone() assert res == (2, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_skiprows(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), skiprows=1) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), skiprows=1) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) # We want to detect this at bind time def test_compression_wrong(self, duckdb_cursor): with pytest.raises(duckdb.Error, match="Input is not a GZIP stream"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), compression='gzip') + duckdb_cursor.read_csv(TestFile("category.csv"), compression="gzip") def test_quotechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quotechar="", header=False) + rel = 
duckdb_cursor.read_csv(TestFile("unquote_without_delimiter.csv"), quotechar="", header=False) res = rel.fetchone() print(res) assert res == ('"AAA"BB',) def test_quote(self, duckdb_cursor): with pytest.raises( - duckdb.Error, match="The methods read_csv and read_csv_auto do not have the \"quote\" argument." + duckdb.Error, match='The methods read_csv and read_csv_auto do not have the "quote" argument' ): - rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quote="", header=False) + duckdb_cursor.read_csv(TestFile("unquote_without_delimiter.csv"), quote="", header=False) def test_escapechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), escapechar=";", header=False) + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), escapechar=";", header=False) res = rel.limit(1, 1).fetchone() print(res) - assert res == ('345', 'TEST6', '"text""2""text"') + assert res == ("345", "TEST6", '"text""2""text"') def test_encoding_wrong(self, duckdb_cursor): with pytest.raises( duckdb.BinderException, match="Copy is only supported for UTF-8 encoded files, ENCODING 'UTF-8'" ): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding=";") + duckdb_cursor.read_csv(TestFile("quote_escape.csv"), encoding=";") def test_encoding_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding="UTF-8") + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), encoding="UTF-8") res = rel.limit(1, 1).fetchone() print(res) - assert res == (345, 'TEST6', 'text"2"text') + assert res == (345, "TEST6", 'text"2"text') def test_date_format_as_datetime(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv')) + rel = duckdb_cursor.read_csv(TestFile("datetime.csv")) res = rel.fetchone() print(res) assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_date_format_as_date(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), date_format='%Y-%m-%d') + rel = duckdb_cursor.read_csv(TestFile("datetime.csv"), date_format="%Y-%m-%d") res = rel.fetchone() print(res) assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_timestamp_format(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), timestamp_format='%Y-%m-%d %H:%M:%S') + rel = duckdb_cursor.read_csv(TestFile("datetime.csv"), timestamp_format="%Y-%m-%d %H:%M:%S") res = rel.fetchone() assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_sample_size_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('problematic.csv'), sample_size=-1) + rel = duckdb_cursor.read_csv(TestFile("problematic.csv"), sample_size=-1) res = rel.fetchone() print(res) - assert res == ('1', '1', '1') + assert res == ("1", "1", "1") def test_all_varchar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), all_varchar=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), all_varchar=True) res = rel.fetchone() print(res) - assert res == ('1', 'Action', '2006-02-15 04:46:27') + assert res == ("1", "Action", "2006-02-15 04:46:27") def test_null_padding(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = 
duckdb_cursor.read_csv(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb_cursor.read_csv(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb.read_csv(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb.read_csv(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb_cursor.from_csv_auto(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb_cursor.from_csv_auto(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] def test_normalize_names(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), normalize_names=False) df = rel.df() column_names = list(df.columns.values) # The names are not normalized, so they are capitalized - assert 'CATEGORY_ID' in column_names + assert "CATEGORY_ID" in column_names - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), normalize_names=True) df = rel.df() column_names = list(df.columns.values) # The capitalized names are normalized to lowercase instead - assert 
'CATEGORY_ID' not in column_names + assert "CATEGORY_ID" not in column_names def test_filename(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), filename=False) df = rel.df() column_names = list(df.columns.values) # The filename is not included in the returned columns - assert 'filename' not in column_names + assert "filename" not in column_names - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), filename=True) df = rel.df() column_names = list(df.columns.values) # The filename is included in the returned columns - assert 'filename' in column_names + assert "filename" in column_names def test_read_pathlib_path(self, duckdb_cursor): pathlib = pytest.importorskip("pathlib") - path = pathlib.Path(TestFile('category.csv')) + path = pathlib.Path(TestFile("category.csv")) rel = duckdb_cursor.read_csv(path) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_read_filelike(self, duckdb_cursor): pytest.importorskip("fsspec") string = StringIO("c1,c2,c3\na,b,c") res = duckdb_cursor.read_csv(string).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_read_filelike_rel_out_of_scope(self, duckdb_cursor): _ = pytest.importorskip("fsspec") @@ -321,47 +321,47 @@ def test_filelike_bytesio(self, duckdb_cursor): _ = pytest.importorskip("fsspec") string = BytesIO(b"c1,c2,c3\na,b,c") res = duckdb_cursor.read_csv(string).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_filelike_exception(self, duckdb_cursor): _ = pytest.importorskip("fsspec") class ReadError: - def __init__(self): + def __init__(self) -> None: pass - def read(self, amount=-1): + def read(self, amount=-1) -> NoReturn: raise ValueError(amount) - def seek(self, loc): + def seek(self, loc) -> int: return 0 class SeekError: - def __init__(self): + def __init__(self) -> None: pass - def read(self, amount=-1): - return b'test' + def read(self, amount=-1) -> bytes: + return b"test" - def seek(self, loc): + def seek(self, loc) -> NoReturn: raise ValueError(loc) # The MemoryFileSystem reads the content into another object, so this fails instantly obj = ReadError() - with pytest.raises(ValueError): - res = duckdb_cursor.read_csv(obj).fetchall() + with pytest.raises(ValueError, match="-1"): + duckdb_cursor.read_csv(obj).fetchall() - # For that same reason, this will not error, because the data is retrieved with 'read' and then SeekError is never used again + # For that same reason, this will not error, because the data is retrieved with 'read' and then + # SeekError is never used again obj = SeekError() - res = duckdb_cursor.read_csv(obj).fetchall() + duckdb_cursor.read_csv(obj).fetchall() def test_filelike_custom(self, duckdb_cursor): _ = pytest.importorskip("fsspec") class CustomIO: - def __init__(self): + def __init__(self) -> None: self.loc = 0 - pass def seek(self, loc): self.loc = loc @@ -377,19 +377,19 @@ def read(self, amount=-1): obj = CustomIO() res = duckdb_cursor.read_csv(obj).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_filelike_non_readable(self, duckdb_cursor): _ = pytest.importorskip("fsspec") obj = 5 - with pytest.raises(ValueError, match="Can not read from a non file-like object"): - res = 
duckdb_cursor.read_csv(obj).fetchall() + with pytest.raises(TypeError, match="Can not read from a non file-like object"): + duckdb_cursor.read_csv(obj).fetchall() def test_filelike_none(self, duckdb_cursor): _ = pytest.importorskip("fsspec") obj = None - with pytest.raises(ValueError, match="Can not read from a non file-like object"): - res = duckdb_cursor.read_csv(obj).fetchall() + with pytest.raises(TypeError, match="Can not read from a non file-like object"): + duckdb_cursor.read_csv(obj).fetchall() @pytest.mark.skip(reason="depends on garbage collector behaviour, and sporadically breaks in CI") def test_internal_object_filesystem_cleanup(self, duckdb_cursor): @@ -398,21 +398,21 @@ def test_internal_object_filesystem_cleanup(self, duckdb_cursor): class CountedObject(StringIO): instance_count = 0 - def __init__(self, str): + def __init__(self, str) -> None: CountedObject.instance_count += 1 super().__init__(str) - def __del__(self): + def __del__(self) -> None: CountedObject.instance_count -= 1 - def scoped_objects(duckdb_cursor): + def scoped_objects(duckdb_cursor) -> None: obj = CountedObject("a,b,c") rel1 = duckdb_cursor.read_csv(obj) assert rel1.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 1 @@ -421,9 +421,9 @@ def scoped_objects(duckdb_cursor): rel2 = duckdb_cursor.read_csv(obj) assert rel2.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 2 @@ -432,9 +432,9 @@ def scoped_objects(duckdb_cursor): rel3 = duckdb_cursor.read_csv(obj) assert rel3.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 3 @@ -444,28 +444,26 @@ def scoped_objects(duckdb_cursor): assert CountedObject.instance_count == 0 def test_read_csv_glob(self, tmp_path, create_temp_csv): - file1_path, file2_path = create_temp_csv - # Use the temporary file paths to read CSV files con = duckdb.connect() - rel = con.read_csv(f'{tmp_path}/file*.csv') + rel = con.read_csv(f"{tmp_path}/file*.csv") # noqa: F841 res = con.sql("select * from rel order by all").fetchall() assert res == [(1,), (2,), (3,), (4,), (5,), (6,)] @pytest.mark.xfail(condition=platform.system() == "Emscripten", reason="time zones not working") def test_read_csv_combined(self, duckdb_cursor): - CSV_FILE = TestFile('stress_test.csv') + CSV_FILE = TestFile("stress_test.csv") COLUMNS = { - 'result': 'VARCHAR', - 'table': 'BIGINT', - '_time': 'TIMESTAMPTZ', - '_measurement': 'VARCHAR', - 'bench_test': 'VARCHAR', - 'flight_id': 'VARCHAR', - 'flight_status': 'VARCHAR', - 'log_level': 'VARCHAR', - 'sys_uuid': 'VARCHAR', - 'message': 'VARCHAR', + "result": "VARCHAR", + "table": "BIGINT", + "_time": "TIMESTAMPTZ", + "_measurement": "VARCHAR", + "bench_test": "VARCHAR", + "flight_id": "VARCHAR", + "flight_status": "VARCHAR", + "log_level": "VARCHAR", + "sys_uuid": "VARCHAR", + "message": "VARCHAR", } rel = duckdb.read_csv(CSV_FILE, skiprows=1, delimiter=",", quotechar='"', escapechar="\\", dtype=COLUMNS) @@ -483,64 +481,63 @@ def test_read_csv_combined(self, duckdb_cursor): def test_read_csv_names(self, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") con = duckdb.connect() - rel = con.read_csv(str(file), names=['a', 'b', 'c']) - assert rel.columns == ['a', 'b', 'c', 'four'] + rel = con.read_csv(str(file), names=["a", "b", "c"]) + assert rel.columns == ["a", "b", "c", "four"] with 
pytest.raises(duckdb.InvalidInputException, match="read_csv only accepts 'names' as a list of strings"): - rel = con.read_csv(file, names=True) + con.read_csv(file, names=True) with pytest.raises(duckdb.InvalidInputException, match="not possible to detect the CSV Header"): - rel = con.read_csv(file, names=['a', 'b', 'c', 'd', 'e']) + con.read_csv(file, names=["a", "b", "c", "d", "e"]) # Duplicates are not okay with pytest.raises(duckdb.BinderException, match="names must have unique values"): - rel = con.read_csv(file, names=['a', 'b', 'a', 'b']) - assert rel.columns == ['a', 'b', 'a', 'b'] + con.read_csv(file, names=["a", "b", "a", "b"]) def test_read_csv_names_mixed_with_dtypes(self, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") con = duckdb.connect() rel = con.read_csv( file, - names=['a', 'b', 'c'], + names=["a", "b", "c"], dtype={ - 'a': int, - 'b': bool, - 'c': str, + "a": int, + "b": bool, + "c": str, }, ) - assert rel.columns == ['a', 'b', 'c', 'four'] - assert rel.types == ['BIGINT', 'BOOLEAN', 'VARCHAR', 'BIGINT'] + assert rel.columns == ["a", "b", "c", "four"] + assert rel.types == ["BIGINT", "BOOLEAN", "VARCHAR", "BIGINT"] # dtypes and names dont match - # FIXME: seems the order columns are named in this error is non-deterministic + # TODO: seems the order columns are named in this error is non-deterministic # noqa: TD002, TD003 # so for now I'm excluding the list of columns from the expected error expected_error = """do not exist in the CSV File""" with pytest.raises(duckdb.BinderException, match=expected_error): rel = con.read_csv( file, - names=['a', 'b', 'c'], + names=["a", "b", "c"], dtype={ - 'd': int, - 'e': bool, - 'f': str, + "d": int, + "e": bool, + "f": str, }, ) def test_read_csv_multi_file(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file1.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") file2 = tmp_path / "file2.csv" - file2.write_text('one,two,three,four\n5,6,7,8\n5,6,7,8\n5,6,7,8') + file2.write_text("one,two,three,four\n5,6,7,8\n5,6,7,8\n5,6,7,8") file3 = tmp_path / "file3.csv" - file3.write_text('one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12') + file3.write_text("one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12") con = duckdb.connect() files = [str(file1), str(file2), str(file3)] @@ -562,146 +559,145 @@ def test_read_csv_empty_list(self): con = duckdb.connect() files = [] with pytest.raises( - duckdb.InvalidInputException, match='Please provide a non-empty list of paths or file-like objects' + duckdb.InvalidInputException, match="Please provide a non-empty list of paths or file-like objects" ): - rel = con.read_csv(files) - res = rel.fetchall() + con.read_csv(files) def test_read_auto_detect(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4') + file1.write_text("one|two|three|four\n1|2|3|4") con = duckdb.connect() - rel = con.read_csv(str(file1), columns={'a': 'VARCHAR'}, auto_detect=False, header=False) - assert rel.fetchall() == [('one|two|three|four',), ('1|2|3|4',)] + rel = con.read_csv(str(file1), columns={"a": "VARCHAR"}, auto_detect=False, header=False) + assert rel.fetchall() == [("one|two|three|four",), ("1|2|3|4",)] def test_read_csv_list_invalid_path(self, tmp_path): con = duckdb.connect() file1 = tmp_path / "file1.csv" - 
file1.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file1.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") file3 = tmp_path / "file3.csv" - file3.write_text('one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12') + file3.write_text("one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12") - files = [str(file1), 'not_valid_path', str(file3)] + files = [str(file1), "not_valid_path", str(file3)] with pytest.raises(duckdb.IOException, match='No files found that match the pattern "not_valid_path"'): - rel = con.read_csv(files) - res = rel.fetchall() + con.read_csv(files) @pytest.mark.parametrize( - 'options', + "options", [ - {'lineterminator': '\\n'}, - {'lineterminator': 'LINE_FEED'}, - {'lineterminator': CSVLineTerminator.LINE_FEED}, - {'columns': {'id': 'INTEGER', 'name': 'INTEGER', 'c': 'integer', 'd': 'INTEGER'}}, - {'auto_type_candidates': ['INTEGER', 'INTEGER']}, - {'max_line_size': 10000}, - {'ignore_errors': True}, - {'ignore_errors': False}, - {'store_rejects': True}, - {'store_rejects': False}, - {'rejects_table': 'my_rejects_table'}, - {'rejects_scan': 'my_rejects_scan'}, - {'rejects_table': 'my_rejects_table', 'rejects_limit': 50}, - {'force_not_null': ['one', 'two']}, - {'buffer_size': 2097153}, - {'decimal': '.'}, - {'allow_quoted_nulls': True}, - {'allow_quoted_nulls': False}, - {'filename': True}, - {'filename': 'test'}, - {'hive_partitioning': True}, - {'hive_partitioning': False}, - {'union_by_name': True}, - {'union_by_name': False}, - {'hive_types_autocast': False}, - {'hive_types_autocast': True}, - {'hive_types': {'one': 'INTEGER', 'two': 'VARCHAR'}}, + {"lineterminator": "\\n"}, + {"lineterminator": "LINE_FEED"}, + {"lineterminator": CSVLineTerminator.LINE_FEED}, + {"columns": {"id": "INTEGER", "name": "INTEGER", "c": "integer", "d": "INTEGER"}}, + {"auto_type_candidates": ["INTEGER", "INTEGER"]}, + {"max_line_size": 10000}, + {"ignore_errors": True}, + {"ignore_errors": False}, + {"store_rejects": True}, + {"store_rejects": False}, + {"rejects_table": "my_rejects_table"}, + {"rejects_scan": "my_rejects_scan"}, + {"rejects_table": "my_rejects_table", "rejects_limit": 50}, + {"force_not_null": ["one", "two"]}, + {"buffer_size": 2097153}, + {"decimal": "."}, + {"allow_quoted_nulls": True}, + {"allow_quoted_nulls": False}, + {"filename": True}, + {"filename": "test"}, + {"hive_partitioning": True}, + {"hive_partitioning": False}, + {"union_by_name": True}, + {"union_by_name": False}, + {"hive_types_autocast": False}, + {"hive_types_autocast": True}, + {"hive_types": {"one": "INTEGER", "two": "VARCHAR"}}, ], ) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Skipping on Windows because of lineterminator option") def test_read_csv_options(self, duckdb_cursor, options, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") print(options) - if 'hive_types' in options: - with pytest.raises(duckdb.InvalidInputException, match=r'Unknown hive_type:'): + if "hive_types" in options: + with pytest.raises(duckdb.InvalidInputException, match=r"Unknown hive_type:"): rel = duckdb_cursor.read_csv(file, **options) else: rel = duckdb_cursor.read_csv(file, **options) - res = rel.fetchall() + rel.fetchall() def test_read_comment(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4#|5|6\n#bla\n1|2|3|4\n') + file1.write_text("one|two|three|four\n1|2|3|4#|5|6\n#bla\n1|2|3|4\n") con = 
duckdb.connect() - rel = con.read_csv(str(file1), columns={'a': 'VARCHAR'}, auto_detect=False, header=False, comment='#') - assert rel.fetchall() == [('one|two|three|four',), ('1|2|3|4',), ('1|2|3|4',)] + rel = con.read_csv(str(file1), columns={"a": "VARCHAR"}, auto_detect=False, header=False, comment="#") + assert rel.fetchall() == [("one|two|three|four",), ("1|2|3|4",), ("1|2|3|4",)] def test_read_enum(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('feelings\nhappy\nsad\nangry\nhappy\n') + file1.write_text("feelings\nhappy\nsad\nangry\nhappy\n") con = duckdb.connect() con.execute("CREATE TYPE mood AS ENUM ('happy', 'sad', 'angry')") - rel = con.read_csv(str(file1), dtype=['mood']) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), dtype=["mood"]) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] - rel = con.read_csv(str(file1), dtype={'feelings': 'mood'}) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), dtype={"feelings": "mood"}) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] - rel = con.read_csv(str(file1), columns={'feelings': 'mood'}) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), columns={"feelings": "mood"}) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), columns={'feelings': 'mood_2'}) + rel = con.read_csv(str(file1), columns={"feelings": "mood_2"}) with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), dtype={'feelings': 'mood_2'}) + rel = con.read_csv(str(file1), dtype={"feelings": "mood_2"}) with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), dtype=['mood_2']) + rel = con.read_csv(str(file1), dtype=["mood_2"]) def test_strict_mode(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4\n1|2|3|4|5\n1|2|3|4\n') + file1.write_text("one|two|three|four\n1|2|3|4\n1|2|3|4|5\n1|2|3|4\n") con = duckdb.connect() + rel = con.read_csv( + str(file1), + header=True, + delimiter="|", + columns={"a": "INTEGER", "b": "INTEGER", "c": "INTEGER", "d": "INTEGER"}, + auto_detect=False, + ) with pytest.raises(duckdb.InvalidInputException, match="CSV Error on Line"): - rel = con.read_csv( - str(file1), - header=True, - delimiter='|', - columns={'a': 'INTEGER', 'b': 'INTEGER', 'c': 'INTEGER', 'd': 'INTEGER'}, - auto_detect=False, - ) rel.fetchall() + rel = con.read_csv( str(file1), header=True, - delimiter='|', + delimiter="|", strict_mode=False, - columns={'a': 'INTEGER', 'b': 'INTEGER', 'c': 'INTEGER', 'd': 'INTEGER'}, + columns={"a": "INTEGER", "b": "INTEGER", "c": "INTEGER", "d": "INTEGER"}, auto_detect=False, ) assert rel.fetchall() == [(1, 2, 3, 4), (1, 2, 3, 4), (1, 2, 3, 4)] def test_union_by_name(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4') + file1.write_text("one|two|three|four\n1|2|3|4") file1 = tmp_path / "file2.csv" - file1.write_text('two|three|four|five\n2|3|4|5') + file1.write_text("two|three|four|five\n2|3|4|5") con = duckdb.connect() file_path = tmp_path / "file*.csv" rel = con.read_csv(file_path, union_by_name=True) - assert rel.columns == ['one', 'two', 'three', 
'four', 'five'] + assert rel.columns == ["one", "two", "three", "four", "five"] assert rel.fetchall() == [(1, 2, 3, 4, None), (None, 2, 3, 4, 5)] def test_thousands_separator(self, tmp_path): @@ -709,32 +705,32 @@ def test_thousands_separator(self, tmp_path): file.write_text('money\n"10,000.23"\n"1,000,000,000.01"') con = duckdb.connect() - rel = con.read_csv(file, thousands=',') + rel = con.read_csv(file, thousands=",") assert rel.fetchall() == [(10000.23,), (1000000000.01,)] with pytest.raises( duckdb.BinderException, match="Unsupported parameter for THOUSANDS: should be max one character" ): - con.read_csv(file, thousands=',,,') + con.read_csv(file, thousands=",,,") def test_skip_comment_option(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('skip this line\n# comment\nx,y,z\n1,2,3\n4,5,6') + file1.write_text("skip this line\n# comment\nx,y,z\n1,2,3\n4,5,6") con = duckdb.connect() - rel = con.read_csv(file1, comment='#', skiprows=1, all_varchar=True) - assert rel.columns == ['x', 'y', 'z'] - assert rel.fetchall() == [('1', '2', '3'), ('4', '5', '6')] + rel = con.read_csv(file1, comment="#", skiprows=1, all_varchar=True) + assert rel.columns == ["x", "y", "z"] + assert rel.fetchall() == [("1", "2", "3"), ("4", "5", "6")] def test_files_to_sniff_option(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('bar,baz\n2025-05-12,baz') + file1.write_text("bar,baz\n2025-05-12,baz") file2 = tmp_path / "file2.csv" - file2.write_text('bar,baz\nbar,baz') + file2.write_text("bar,baz\nbar,baz") file_path = tmp_path / "file*.csv" con = duckdb.connect() + rel = con.read_csv(file_path, files_to_sniff=1) with pytest.raises(duckdb.ConversionException, match="Conversion Error"): - rel = con.read_csv(file_path, files_to_sniff=1) rel.fetchall() rel = con.read_csv(file_path, files_to_sniff=-1) - assert rel.fetchall() == [('2025-05-12', 'baz'), ('bar', 'baz')] + assert rel.fetchall() == [("2025-05-12", "baz"), ("bar", "baz")] diff --git a/tests/fast/api/test_relation_to_view.py b/tests/fast/api/test_relation_to_view.py index f4a43d54..14f4cb4d 100644 --- a/tests/fast/api/test_relation_to_view.py +++ b/tests/fast/api/test_relation_to_view.py @@ -1,30 +1,31 @@ import pytest + import duckdb -class TestRelationToView(object): +class TestRelationToView: def test_values_to_view(self, duckdb_cursor): - rel = duckdb_cursor.values(['test', 'this is a long string']) + rel = duckdb_cursor.values(["test", "this is a long string"]) res = rel.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] - rel.to_view('vw1') + rel.to_view("vw1") - view = duckdb_cursor.table('vw1') + view = duckdb_cursor.table("vw1") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] def test_relation_to_view(self, duckdb_cursor): rel = duckdb_cursor.sql("select 'test', 'this is a long string'") res = rel.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] - rel.to_view('vw1') + rel.to_view("vw1") - view = duckdb_cursor.table('vw1') + view = duckdb_cursor.table("vw1") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] def test_registered_relation(self, duckdb_cursor): rel = duckdb_cursor.sql("select 'test', 'this is a long string'") @@ -33,12 +34,12 @@ def test_registered_relation(self, duckdb_cursor): # Register on a different connection is not 
allowed with pytest.raises( duckdb.InvalidInputException, - match='was created by another Connection and can therefore not be used by this Connection', + match="was created by another Connection and can therefore not be used by this Connection", ): - con.register('cross_connection', rel) + con.register("cross_connection", rel) # Register on the same connection just creates a view - duckdb_cursor.register('same_connection', rel) - view = duckdb_cursor.table('same_connection') + duckdb_cursor.register("same_connection", rel) + view = duckdb_cursor.table("same_connection") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] diff --git a/tests/fast/api/test_streaming_result.py b/tests/fast/api/test_streaming_result.py index e51f62e4..4003f20f 100644 --- a/tests/fast/api/test_streaming_result.py +++ b/tests/fast/api/test_streaming_result.py @@ -1,11 +1,12 @@ import pytest + import duckdb -class TestStreamingResult(object): +class TestStreamingResult: def test_fetch_one(self, duckdb_cursor): # fetch one - res = duckdb_cursor.sql('SELECT * FROM range(100000)') + res = duckdb_cursor.sql("SELECT * FROM range(100000)") result = [] while len(result) < 5000: tpl = res.fetchone() @@ -17,14 +18,11 @@ def test_fetch_one(self, duckdb_cursor): "SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)" ) with pytest.raises(duckdb.ConversionException): - while True: - tpl = res.fetchone() - if tpl is None: - break + res.fetchone() def test_fetch_many(self, duckdb_cursor): # fetch many - res = duckdb_cursor.sql('SELECT * FROM range(100000)') + res = duckdb_cursor.sql("SELECT * FROM range(100000)") result = [] while len(result) < 5000: tpl = res.fetchmany(10) @@ -36,20 +34,17 @@ def test_fetch_many(self, duckdb_cursor): "SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)" ) with pytest.raises(duckdb.ConversionException): - while True: - tpl = res.fetchmany(10) - if tpl is None: - break + res.fetchmany(10) def test_record_batch_reader(self, duckdb_cursor): pytest.importorskip("pyarrow") pytest.importorskip("pyarrow.dataset") # record batch reader - res = duckdb_cursor.sql('SELECT * FROM range(100000) t(i)') + res = duckdb_cursor.sql("SELECT * FROM range(100000) t(i)") reader = res.fetch_arrow_reader(batch_size=16_384) result = [] for batch in reader: - result += batch.to_pydict()['i'] + result += batch.to_pydict()["i"] assert result == list(range(100000)) # record batch reader with error @@ -60,9 +55,9 @@ def test_record_batch_reader(self, duckdb_cursor): reader = res.fetch_arrow_reader(batch_size=16_384) def test_9801(self, duckdb_cursor): - duckdb_cursor.execute('CREATE TABLE test(id INTEGER , name VARCHAR NOT NULL);') + duckdb_cursor.execute("CREATE TABLE test(id INTEGER , name VARCHAR NOT NULL);") - words = ['aaaaaaaaaaaaaaaaaaaaaaa', 'bbbb', 'ccccccccc', 'ííííííííí'] + words = ["aaaaaaaaaaaaaaaaaaaaaaa", "bbbb", "ccccccccc", "ííííííííí"] lines = [(i, words[i % 4]) for i in range(1000)] duckdb_cursor.executemany("INSERT INTO TEST (id, name) VALUES (?, ?)", lines) diff --git a/tests/fast/api/test_to_csv.py b/tests/fast/api/test_to_csv.py index e48ae1b8..97f13d8b 100644 --- a/tests/fast/api/test_to_csv.py +++ b/tests/fast/api/test_to_csv.py @@ -1,18 +1,19 @@ -import duckdb -import tempfile -import os -import pandas._testing as tm -import datetime import csv +import datetime +import os +import tempfile + import pytest -from conftest import 
NumpyPandas, ArrowPandas, getTimeSeriesData +from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData + +import duckdb -class TestToCSV(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestToCSV: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_basic_to_csv(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -20,21 +21,21 @@ def test_basic_to_csv(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_sep(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, sep=',') + rel.to_csv(temp_file_name, sep=",") - csv_rel = duckdb.read_csv(temp_file_name, sep=',') + csv_rel = duckdb.read_csv(temp_file_name, sep=",") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_na_rep(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, na_rep="test") @@ -42,10 +43,10 @@ def test_to_csv_na_rep(self, pandas): csv_rel = duckdb.read_csv(temp_file_name, na_values="test") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_header(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -53,20 +54,20 @@ def test_to_csv_header(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quotechar(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]}) + 
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, quotechar='\'', sep=',') + rel.to_csv(temp_file_name, quotechar="'", sep=",") - csv_rel = duckdb.read_csv(temp_file_name, sep=',', quotechar='\'') + csv_rel = duckdb.read_csv(temp_file_name, sep=",", quotechar="'") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_escapechar(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame( { "c_bool": [True, False], @@ -76,13 +77,13 @@ def test_to_csv_escapechar(self, pandas): } ) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, quotechar='"', escapechar='!') - csv_rel = duckdb.read_csv(temp_file_name, quotechar='"', escapechar='!') + rel.to_csv(temp_file_name, quotechar='"', escapechar="!") + csv_rel = duckdb.read_csv(temp_file_name, quotechar='"', escapechar="!") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_date_format(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame(getTimeSeriesData()) dt_index = df.index df = pandas.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index) @@ -93,82 +94,82 @@ def test_to_csv_date_format(self, pandas): assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_timestamp_format(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, timestamp_format='%m/%d/%Y') + rel.to_csv(temp_file_name, timestamp_format="%m/%d/%Y") - csv_rel = duckdb.read_csv(temp_file_name, timestamp_format='%m/%d/%Y') + csv_rel = duckdb.read_csv(temp_file_name, timestamp_format="%m/%d/%Y") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_off(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=None) 
csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_on(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting="force") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_quote_all(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=csv.QUOTE_ALL) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_encoding_incorrect(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) with pytest.raises( duckdb.InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8" ): rel.to_csv(temp_file_name, encoding="nope") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_encoding_correct(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, encoding="UTF-8") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_compression_gzip(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, compression="gzip") csv_rel = duckdb.read_csv(temp_file_name, compression="gzip") assert 
rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_partition(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -178,23 +179,23 @@ def test_to_csv_partition(self, pandas): rel = duckdb.from_df(df) rel.to_csv(temp_file_name, header=True, partition_by=["c_category"]) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);""" ) expected = [ - (True, 1.0, 42.0, 'a', 'a'), - (False, 3.2, None, 'b,c', 'a'), - (True, 3.0, 123.0, 'e', 'b'), - (True, 4.0, 321.0, 'f', 'b'), + (True, 1.0, 42.0, "a", "a"), + (False, 3.2, None, "b,c", "a"), + (True, 3.0, 123.0, "e", "b"), + (True, 4.0, 321.0, "f", "b"), ] assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_partition_with_columns_written(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -205,17 +206,17 @@ def test_to_csv_partition_with_columns_written(self, pandas): res = duckdb.sql("FROM rel order by all") rel.to_csv(temp_file_name, header=True, partition_by=["c_category"], write_partition_columns=True) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;""" ) assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -226,24 +227,24 @@ def test_to_csv_overwrite(self, pandas): rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"]) # csv to be overwritten rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"], overwrite=True) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, 
header=TRUE);""" ) # When partition columns are read from directory names, column order become different from original expected = [ - ('c', True, 1.0, 42.0, 'a', 'a'), - ('c', False, 3.2, None, 'b,c', 'a'), - ('d', True, 3.0, 123.0, 'e', 'b'), - ('d', True, 4.0, 321.0, 'f', 'b'), + ("c", True, 1.0, 42.0, "a", "a"), + ("c", False, 3.2, None, "b,c", "a"), + ("d", True, 3.0, 123.0, "e", "b"), + ("d", True, 4.0, 321.0, "f", "b"), ] assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite_with_columns_written(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -258,18 +259,18 @@ def test_to_csv_overwrite_with_columns_written(self, pandas): temp_file_name, header=True, partition_by=["c_category_1"], overwrite=True, write_partition_columns=True ) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;""" ) res = duckdb.sql("FROM rel order by all") assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite_not_enabled(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -281,14 +282,14 @@ def test_to_csv_overwrite_not_enabled(self, pandas): with pytest.raises(duckdb.IOException, match="OVERWRITE"): rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"]) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_per_thread_output(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 num_threads = duckdb.sql("select current_setting('threads')").fetchone()[0] - print('num_threads:', num_threads) + print("num_threads:", num_threads) df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -297,16 +298,16 @@ def test_to_csv_per_thread_output(self, pandas): ) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, header=True, per_thread_output=True) - csv_rel = duckdb.read_csv(f'{temp_file_name}/*.csv', header=True) + csv_rel = 
duckdb.read_csv(f"{temp_file_name}/*.csv", header=True) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_use_tmp_file(self, pandas): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], diff --git a/tests/fast/api/test_to_parquet.py b/tests/fast/api/test_to_parquet.py index d778aba3..8d8162b0 100644 --- a/tests/fast/api/test_to_parquet.py +++ b/tests/fast/api/test_to_parquet.py @@ -1,19 +1,17 @@ -import duckdb -import tempfile import os import tempfile -import pandas._testing as tm -import datetime -import csv + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestToParquet(object): +class TestToParquet: @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_basic_to_parquet(self, pd): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pd.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name) @@ -23,46 +21,41 @@ def test_basic_to_parquet(self, pd): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_compression_gzip(self, pd): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, compression="gzip") csv_rel = duckdb.read_parquet(temp_file_name, compression="gzip") assert rel.execute().fetchall() == csv_rel.execute().fetchall() def test_field_ids_auto(self): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - rel = duckdb.sql('''SELECT {i: 128} AS my_struct''') - rel.to_parquet(temp_file_name, field_ids='auto') + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + rel = duckdb.sql("""SELECT {i: 128} AS my_struct""") + rel.to_parquet(temp_file_name, field_ids="auto") parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() def test_field_ids(self): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - rel = duckdb.sql('''SELECT 1 as i, {j: 128} AS my_struct''') - rel.to_parquet(temp_file_name, field_ids=dict(i=42, my_struct={'__duckdb_field_id': 43, 'j': 44})) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + rel = duckdb.sql("""SELECT 1 as i, {j: 128} AS my_struct""") + rel.to_parquet(temp_file_name, field_ids={"i": 42, "my_struct": 
{"__duckdb_field_id": 43, "j": 44}}) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() - assert ( - [('duckdb_schema', None), ('i', 42), ('my_struct', 43), ('j', 44)] - == duckdb.sql( - f''' + assert duckdb.sql( + f""" select name,field_id from parquet_schema('{temp_file_name}') - ''' - ) - .execute() - .fetchall() - ) + """ + ).execute().fetchall() == [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('row_group_size_bytes', [122880 * 1024, '2MB']) + @pytest.mark.parametrize("row_group_size_bytes", [122880 * 1024, "2MB"]) def test_row_group_size_bytes(self, pd, row_group_size_bytes): con = duckdb.connect() con.execute("SET preserve_insertion_order=false;") - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = con.from_df(df) rel.to_parquet(temp_file_name, row_group_size_bytes=row_group_size_bytes) parquet_rel = con.read_parquet(temp_file_name) @@ -70,22 +63,22 @@ def test_row_group_size_bytes(self, pd, row_group_size_bytes): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_row_group_size(self, pd): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, row_group_size=122880) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('write_columns', [None, True, False]) + @pytest.mark.parametrize("write_columns", [None, True, False]) def test_partition(self, pd, write_columns): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -95,14 +88,14 @@ def test_partition(self, pd, write_columns): assert result.execute().fetchall() == expected @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('write_columns', [None, True, False]) + @pytest.mark.parametrize("write_columns", [None, True, False]) def test_overwrite(self, pd, write_columns): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -115,12 +108,12 @@ def test_overwrite(self, pd, write_columns): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_use_tmp_file(self, 
pd): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -131,49 +124,49 @@ def test_use_tmp_file(self, pd): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_per_thread_output(self, pd): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 num_threads = duckdb.sql("select current_setting('threads')").fetchone()[0] - print('threads:', num_threads) + print("threads:", num_threads) df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, per_thread_output=True) - result = duckdb.read_parquet(f'{temp_file_name}/*.parquet') + result = duckdb.read_parquet(f"{temp_file_name}/*.parquet") assert rel.execute().fetchall() == result.execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_append(self, pd): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) - rel.to_parquet(temp_file_name, partition_by=['category']) + rel.to_parquet(temp_file_name, partition_by=["category"]) df_to_append = pd.DataFrame( { "name": ["random"], "float": [420], - "category": ['a'], + "category": ["a"], } ) rel_to_append = duckdb.from_df(df_to_append) - rel_to_append.to_parquet(temp_file_name, partition_by=['category'], append=True) + rel_to_append.to_parquet(temp_file_name, partition_by=["category"], append=True) result = duckdb.sql(f"FROM read_parquet('{temp_file_name}/*/*.parquet', hive_partitioning=TRUE) ORDER BY name") result.show() expected = [ - ('asuka', 23.0, 'b'), - ('kaworu', 340.0, 'c'), - ('random', 420.0, 'a'), - ('rei', 321.0, 'a'), - ('shinji', 123.0, 'a'), + ("asuka", 23.0, "b"), + ("kaworu", 340.0, "c"), + ("random", 420.0, "a"), + ("rei", 321.0, "a"), + ("shinji", 123.0, "a"), ] assert result.execute().fetchall() == expected diff --git a/tests/fast/api/test_with_propagating_exceptions.py b/tests/fast/api/test_with_propagating_exceptions.py index e9cfb3c0..6f4719fb 100644 --- a/tests/fast/api/test_with_propagating_exceptions.py +++ b/tests/fast/api/test_with_propagating_exceptions.py @@ -1,18 +1,14 @@ import pytest + import duckdb -class TestWithPropagatingExceptions(object): +class TestWithPropagatingExceptions: def test_with(self): # Should propagate exception raised in the 'with duckdb.connect() ..' 
- with pytest.raises(duckdb.ParserException, match="syntax error at or near *"): - with duckdb.connect() as con: - print('before') - con.execute('invalid') - print('after') + with pytest.raises(duckdb.ParserException, match=r"syntax error at or near *"), duckdb.connect() as con: + con.execute("invalid") # Does not raise an exception with duckdb.connect() as con: - print('before') - con.execute('select 1') - print('after') + con.execute("select 1") diff --git a/tests/fast/arrow/parquet_write_roundtrip.py b/tests/fast/arrow/parquet_write_roundtrip.py index 093040c0..29d95e64 100644 --- a/tests/fast/arrow/parquet_write_roundtrip.py +++ b/tests/fast/arrow/parquet_write_roundtrip.py @@ -1,29 +1,31 @@ -import duckdb -import pytest +import datetime import tempfile + import numpy import pandas -import datetime +import pytest + +import duckdb pa = pytest.importorskip("pyarrow") def parquet_types_test(type_list): - temp = tempfile.NamedTemporaryFile() - temp_name = temp.name + with tempfile.NamedTemporaryFile() as tmp: + temp_name = tmp.name for type_pair in type_list: value_list = type_pair[0] numpy_type = type_pair[1] sql_type = type_pair[2] add_cast = len(type_pair) > 3 and type_pair[3] add_sql_cast = len(type_pair) > 4 and type_pair[4] - df = pandas.DataFrame.from_dict({'val': numpy.array(value_list, dtype=numpy_type)}) + df = pandas.DataFrame.from_dict({"val": numpy.array(value_list, dtype=numpy_type)}) duckdb_cursor = duckdb.connect() duckdb_cursor.execute(f"CREATE TABLE tmp AS SELECT val::{sql_type} val FROM df") duckdb_cursor.execute(f"COPY tmp TO '{temp_name}' (FORMAT PARQUET)") read_df = pandas.read_parquet(temp_name) if add_cast: - read_df['val'] = read_df['val'].astype(numpy_type) + read_df["val"] = read_df["val"].astype(numpy_type) assert df.equals(read_df) read_from_duckdb = duckdb_cursor.execute(f"SELECT * FROM parquet_scan('{temp_name}')").df() @@ -37,19 +39,19 @@ def parquet_types_test(type_list): assert read_df.equals(read_from_arrow) -class TestParquetRoundtrip(object): +class TestParquetRoundtrip: def test_roundtrip_numeric(self, duckdb_cursor): type_list = [ - ([-(2**7), 0, 2**7 - 1], numpy.int8, 'TINYINT'), - ([-(2**15), 0, 2**15 - 1], numpy.int16, 'SMALLINT'), - ([-(2**31), 0, 2**31 - 1], numpy.int32, 'INTEGER'), - ([-(2**63), 0, 2**63 - 1], numpy.int64, 'BIGINT'), - ([0, 42, 2**8 - 1], numpy.uint8, 'UTINYINT'), - ([0, 42, 2**16 - 1], numpy.uint16, 'USMALLINT'), - ([0, 42, 2**32 - 1], numpy.uint32, 'UINTEGER', False, True), - ([0, 42, 2**64 - 1], numpy.uint64, 'UBIGINT'), - ([0, 0.5, -0.5], numpy.float32, 'REAL'), - ([0, 0.5, -0.5], numpy.float64, 'DOUBLE'), + ([-(2**7), 0, 2**7 - 1], numpy.int8, "TINYINT"), + ([-(2**15), 0, 2**15 - 1], numpy.int16, "SMALLINT"), + ([-(2**31), 0, 2**31 - 1], numpy.int32, "INTEGER"), + ([-(2**63), 0, 2**63 - 1], numpy.int64, "BIGINT"), + ([0, 42, 2**8 - 1], numpy.uint8, "UTINYINT"), + ([0, 42, 2**16 - 1], numpy.uint16, "USMALLINT"), + ([0, 42, 2**32 - 1], numpy.uint32, "UINTEGER", False, True), + ([0, 42, 2**64 - 1], numpy.uint64, "UBIGINT"), + ([0, 0.5, -0.5], numpy.float32, "REAL"), + ([0, 0.5, -0.5], numpy.float64, "DOUBLE"), ] parquet_types_test(type_list) @@ -61,15 +63,15 @@ def test_roundtrip_timestamp(self, duckdb_cursor): datetime.datetime(1992, 7, 9, 7, 5, 33), ] type_list = [ - (date_time_list, 'datetime64[ns]', 'TIMESTAMP_NS'), - (date_time_list, 'datetime64[us]', 'TIMESTAMP'), - (date_time_list, 'datetime64[ms]', 'TIMESTAMP_MS'), - (date_time_list, 'datetime64[s]', 'TIMESTAMP_S'), - (date_time_list, 'datetime64[D]', 'DATE', True), + 
(date_time_list, "datetime64[ns]", "TIMESTAMP_NS"), + (date_time_list, "datetime64[us]", "TIMESTAMP"), + (date_time_list, "datetime64[ms]", "TIMESTAMP_MS"), + (date_time_list, "datetime64[s]", "TIMESTAMP_S"), + (date_time_list, "datetime64[D]", "DATE", True), ] parquet_types_test(type_list) def test_roundtrip_varchar(self, duckdb_cursor): - varchar_list = ['hello', 'this is a very long string', 'hello', None] - type_list = [(varchar_list, object, 'VARCHAR')] + varchar_list = ["hello", "this is a very long string", "hello", None] + type_list = [(varchar_list, object, "VARCHAR")] parquet_types_test(type_list) diff --git a/tests/fast/arrow/test_10795.py b/tests/fast/arrow/test_10795.py index 043bf4ff..5dc88402 100644 --- a/tests/fast/arrow/test_10795.py +++ b/tests/fast/arrow/test_10795.py @@ -1,12 +1,13 @@ -import duckdb import pytest -pyarrow = pytest.importorskip('pyarrow') +import duckdb + +pyarrow = pytest.importorskip("pyarrow") -@pytest.mark.parametrize('arrow_large_buffer_size', [True, False]) +@pytest.mark.parametrize("arrow_large_buffer_size", [True, False]) def test_10795(arrow_large_buffer_size): conn = duckdb.connect() conn.sql(f"set arrow_large_buffer_size={arrow_large_buffer_size}") arrow = conn.sql("select map(['non-inlined string', 'test', 'duckdb'], [42, 1337, 123]) as map").to_arrow_table() - assert arrow.to_pydict() == {'map': [[('non-inlined string', 42), ('test', 1337), ('duckdb', 123)]]} + assert arrow.to_pydict() == {"map": [[("non-inlined string", 42), ("test", 1337), ("duckdb", 123)]]} diff --git a/tests/fast/arrow/test_12384.py b/tests/fast/arrow/test_12384.py index af9c8ed2..933428f0 100644 --- a/tests/fast/arrow/test_12384.py +++ b/tests/fast/arrow/test_12384.py @@ -1,20 +1,24 @@ -import duckdb +from pathlib import Path + import pytest -import os -pa = pytest.importorskip('pyarrow') +import duckdb + +pa = pytest.importorskip("pyarrow") def test_10795(): - arrow_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'arrow_table') - with pa.memory_map(arrow_filename, 'r') as source: + arrow_filename = Path(__file__).parent / "data" / "arrow_table" + with pa.memory_map(str(arrow_filename), "r") as source: reader = pa.ipc.RecordBatchFileReader(source) taxi_fhvhv_arrow = reader.read_all() - con = duckdb.connect(database=':memory:') + con = duckdb.connect(database=":memory:") con.execute("SET TimeZone='UTC';") - con.register('taxi_fhvhv', taxi_fhvhv_arrow) - res = con.execute( - "SELECT PULocationID, pickup_datetime FROM taxi_fhvhv WHERE pickup_datetime >= '2023-01-01T00:00:00-05:00' AND PULocationID = 244" - ).fetchall() + con.register("taxi_fhvhv", taxi_fhvhv_arrow) + res = con.execute(""" + SELECT PULocationID, pickup_datetime + FROM taxi_fhvhv + WHERE pickup_datetime >= '2023-01-01T00:00:00-05:00' AND PULocationID = 244 + """).fetchall() assert len(res) == 3685 diff --git a/tests/fast/arrow/test_14344.py b/tests/fast/arrow/test_14344.py index 522228c0..8bb4ba9b 100644 --- a/tests/fast/arrow/test_14344.py +++ b/tests/fast/arrow/test_14344.py @@ -1,18 +1,18 @@ -import duckdb +import hashlib + import pytest pa = pytest.importorskip("pyarrow") -import hashlib def test_14344(duckdb_cursor): - my_table = pa.Table.from_pydict({"foo": pa.array([hashlib.sha256("foo".encode()).digest()], type=pa.binary())}) - my_table2 = pa.Table.from_pydict( - {"foo": pa.array([hashlib.sha256("foo".encode()).digest()], type=pa.binary()), "a": ["123"]} + my_table = pa.Table.from_pydict({"foo": pa.array([hashlib.sha256(b"foo").digest()], type=pa.binary())}) # noqa: F841 
+ my_table2 = pa.Table.from_pydict( # noqa: F841 + {"foo": pa.array([hashlib.sha256(b"foo").digest()], type=pa.binary()), "a": ["123"]} ) res = duckdb_cursor.sql( - f""" + """ SELECT my_table2.* EXCLUDE (foo) FROM @@ -22,4 +22,4 @@ def test_14344(duckdb_cursor): USING (foo) """ ).fetchall() - assert res == [('123',)] + assert res == [("123",)] diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index cdef8da7..f43284d3 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -1,15 +1,12 @@ import duckdb -import os try: - import pyarrow as pa - can_run = True -except: +except Exception: can_run = False -class Test2426(object): +class Test2426: def test_2426(self, duckdb_cursor): if not can_run: return @@ -18,19 +15,19 @@ def test_2426(self, duckdb_cursor): con.execute("Create Table test (a integer)") for i in range(1024): - for j in range(2): + for _j in range(2): con.execute("Insert Into test values ('" + str(i) + "')") con.execute("Insert Into test values ('5000')") con.execute("Insert Into test values ('6000')") - sql = ''' + sql = """ SELECT a, COUNT(*) AS repetitions FROM test GROUP BY a - ''' + """ result_df = con.execute(sql).df() arrow_table = con.execute(sql).fetch_arrow_table() arrow_df = arrow_table.to_pandas() - assert result_df['repetitions'].sum() == arrow_df['repetitions'].sum() + assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() diff --git a/tests/fast/arrow/test_5547.py b/tests/fast/arrow/test_5547.py index b27b29b2..32beec29 100644 --- a/tests/fast/arrow/test_5547.py +++ b/tests/fast/arrow/test_5547.py @@ -1,9 +1,10 @@ -import duckdb import pandas as pd -from pandas.testing import assert_frame_equal import pytest +from pandas.testing import assert_frame_equal + +import duckdb -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") def test_5547(): @@ -12,12 +13,12 @@ def test_5547(): tbl = pa.Table.from_pandas( pd.DataFrame.from_records( [ - dict( - id=i, - nested=dict( - a=i, - ), - ) + { + "id": i, + "nested": { + "a": i, + }, + } for i in range(num_rows) ] ) diff --git a/tests/fast/arrow/test_6584.py b/tests/fast/arrow/test_6584.py index 9a6241f9..feadc6d7 100644 --- a/tests/fast/arrow/test_6584.py +++ b/tests/fast/arrow/test_6584.py @@ -1,8 +1,10 @@ from concurrent.futures import ThreadPoolExecutor -import duckdb + import pytest -pyarrow = pytest.importorskip('pyarrow') +import duckdb + +pyarrow = pytest.importorskip("pyarrow") def f(cur, i, data): diff --git a/tests/fast/arrow/test_6796.py b/tests/fast/arrow/test_6796.py index 6690f22c..bf557038 100644 --- a/tests/fast/arrow/test_6796.py +++ b/tests/fast/arrow/test_6796.py @@ -1,11 +1,12 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -pyarrow = pytest.importorskip('pyarrow') +pyarrow = pytest.importorskip("pyarrow") -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_6796(pandas): conn = duckdb.connect() input_df = pandas.DataFrame({"foo": ["bar"]}) @@ -19,9 +20,9 @@ def test_6796(pandas): # fetching directly into Pandas works res_df = conn.execute(query).fetch_df() - res_arrow = conn.execute(query).fetch_arrow_table() + res_arrow = conn.execute(query).fetch_arrow_table() # noqa: F841 - df_arrow_table = pyarrow.Table.from_pandas(res_df) + df_arrow_table = pyarrow.Table.from_pandas(res_df) # noqa: F841 result_1 = conn.execute("select * from 
df_arrow_table order by all").fetchall() diff --git a/tests/fast/arrow/test_7652.py b/tests/fast/arrow/test_7652.py index afe3b738..516d7d1f 100644 --- a/tests/fast/arrow/test_7652.py +++ b/tests/fast/arrow/test_7652.py @@ -1,23 +1,20 @@ -import duckdb -import os -import pytest import tempfile +import pytest + pa = pytest.importorskip("pyarrow", minversion="11") pq = pytest.importorskip("pyarrow.parquet", minversion="11") -class Test7652(object): +class Test7652: def test_7652(self, duckdb_cursor): - temp_file_name = tempfile.NamedTemporaryFile(suffix='.parquet').name + with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp: + temp_file_name = tmp.name # Generate a list of values that aren't uniform in changes. generated_list = [1, 0, 2] - print("Generated values:", generated_list) - print(f"Min value: {min(generated_list)} max value: {max(generated_list)}") - # Convert list of values to a PyArrow table with a single column. - fake_table = pa.Table.from_arrays([pa.array(generated_list, pa.int64())], names=['n0']) + fake_table = pa.Table.from_arrays([pa.array(generated_list, pa.int64())], names=["n0"]) # Write that column with DELTA_BINARY_PACKED encoding with pq.ParquetWriter( @@ -36,7 +33,7 @@ def test_7652(self, duckdb_cursor): # Attempt to perform the same thing with duckdb. print("Retrieving from duckdb") - duckdb_result = list(map(lambda v: v[0], duckdb_cursor.sql(f"select * from '{temp_file_name}'").fetchall())) + duckdb_result = [v[0] for v in duckdb_cursor.sql(f"select * from '{temp_file_name}'").fetchall()] print("DuckDB result:", duckdb_result) assert min(duckdb_result) == min(generated_list) diff --git a/tests/fast/arrow/test_7699.py b/tests/fast/arrow/test_7699.py index c8c234ef..ba2f4af3 100644 --- a/tests/fast/arrow/test_7699.py +++ b/tests/fast/arrow/test_7699.py @@ -1,13 +1,13 @@ -import duckdb -import pytest import string +import pytest + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") pl = pytest.importorskip("polars") -class Test7699(object): +class Test7699: def test_7699(self, duckdb_cursor): pl_tbl = pl.DataFrame( { @@ -22,4 +22,4 @@ def test_7699(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from df1234") res = rel.fetchall() - assert res == [('K',), ('L',), ('K',), ('L',), ('M',)] + assert res == [("K",), ("L",), ("K",), ("L",), ("M",)] diff --git a/tests/fast/arrow/test_8522.py b/tests/fast/arrow/test_8522.py index 84aa125c..53a8fdfb 100644 --- a/tests/fast/arrow/test_8522.py +++ b/tests/fast/arrow/test_8522.py @@ -1,22 +1,22 @@ -import duckdb -import pytest -import string import datetime as dt +import pytest + pa = pytest.importorskip("pyarrow") # Reconstruct filters when pushing down into arrow scan # arrow supports timestamp_tz with different units than US, we only support US -# so we have to convert ConstantValues back to their native unit when pushing the filter expression containing them down to pyarrow -class Test8522(object): +# so we have to convert ConstantValues back to their native unit when pushing the filter +# expression containing them down to pyarrow +class Test8522: def test_8522(self, duckdb_cursor): - t_us = pa.Table.from_arrays( + t_us = pa.Table.from_arrays( # noqa: F841 arrays=[pa.array([dt.datetime(2022, 1, 1)])], schema=pa.schema([pa.field("time", pa.timestamp("us", tz="UTC"))]), ) - t_ms = pa.Table.from_arrays( + t_ms = pa.Table.from_arrays( # noqa: F841 arrays=[pa.array([dt.datetime(2022, 1, 1)])], schema=pa.schema([pa.field("time", pa.timestamp("ms", tz="UTC"))]), ) diff --git 
a/tests/fast/arrow/test_9443.py b/tests/fast/arrow/test_9443.py index 7de04bde..fe5a2ce1 100644 --- a/tests/fast/arrow/test_9443.py +++ b/tests/fast/arrow/test_9443.py @@ -1,14 +1,13 @@ -import duckdb +from datetime import time +from pathlib import PurePosixPath + import pytest pq = pytest.importorskip("pyarrow.parquet") pa = pytest.importorskip("pyarrow") -from datetime import time -from pathlib import PurePosixPath - -class Test9443(object): +class Test9443: def test_9443(self, tmp_path, duckdb_cursor): arrow_table = pa.Table.from_pylist( [ diff --git a/tests/fast/arrow/test_arrow_batch_index.py b/tests/fast/arrow/test_arrow_batch_index.py index dadf6f89..094360ea 100644 --- a/tests/fast/arrow/test_arrow_batch_index.py +++ b/tests/fast/arrow/test_arrow_batch_index.py @@ -1,21 +1,20 @@ -import duckdb import pytest -import pandas as pd + import duckdb pa = pytest.importorskip("pyarrow") -class TestArrowBatchIndex(object): +class TestArrowBatchIndex: def test_arrow_batch_index(self, duckdb_cursor): con = duckdb.connect() - df = con.execute('SELECT * FROM range(10000000) t(i)').df() - arrow_tbl = pa.Table.from_pandas(df) + df = con.execute("SELECT * FROM range(10000000) t(i)").df() + arrow_tbl = pa.Table.from_pandas(df) # noqa: F841 - con.execute('CREATE TABLE tbl AS SELECT * FROM arrow_tbl') + con.execute("CREATE TABLE tbl AS SELECT * FROM arrow_tbl") - result = con.execute('SELECT * FROM tbl LIMIT 5').fetchall() + result = con.execute("SELECT * FROM tbl LIMIT 5").fetchall() assert [x[0] for x in result] == [0, 1, 2, 3, 4] - result = con.execute('SELECT * FROM tbl LIMIT 5 OFFSET 777778').fetchall() + result = con.execute("SELECT * FROM tbl LIMIT 5 OFFSET 777778").fetchall() assert [x[0] for x in result] == [777778, 777779, 777780, 777781, 777782] diff --git a/tests/fast/arrow/test_arrow_binary_view.py b/tests/fast/arrow/test_arrow_binary_view.py index 7d9d0afc..4e161ac3 100644 --- a/tests/fast/arrow/test_arrow_binary_view.py +++ b/tests/fast/arrow/test_arrow_binary_view.py @@ -1,14 +1,15 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowBinaryView(object): +class TestArrowBinaryView: def test_arrow_binary_view(self, duckdb_cursor): con = duckdb.connect() tab = pa.table({"x": pa.array([b"abc", b"thisisaverybigbinaryyaymorethanfifteen", None], pa.binary_view())}) - assert con.execute("FROM tab").fetchall() == [(b'abc',), (b'thisisaverybigbinaryyaymorethanfifteen',), (None,)] + assert con.execute("FROM tab").fetchall() == [(b"abc",), (b"thisisaverybigbinaryyaymorethanfifteen",), (None,)] # By default we won't export a view assert not con.execute("FROM tab").fetch_arrow_table().equals(tab) # We do the binary view from 1.4 onwards @@ -16,5 +17,5 @@ def test_arrow_binary_view(self, duckdb_cursor): assert con.execute("FROM tab").fetch_arrow_table().equals(tab) assert con.execute("FROM tab where x = 'thisisaverybigbinaryyaymorethanfifteen'").fetchall() == [ - (b'thisisaverybigbinaryyaymorethanfifteen',) + (b"thisisaverybigbinaryyaymorethanfifteen",) ] diff --git a/tests/fast/arrow/test_arrow_case_sensitive.py b/tests/fast/arrow/test_arrow_case_sensitive.py index 6106cc75..11bca339 100644 --- a/tests/fast/arrow/test_arrow_case_sensitive.py +++ b/tests/fast/arrow/test_arrow_case_sensitive.py @@ -1,24 +1,23 @@ -import duckdb import pytest pa = pytest.importorskip("pyarrow") -class TestArrowCaseSensitive(object): +class TestArrowCaseSensitive: def test_arrow_case_sensitive(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], 
type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['A1', 'a1']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["A1", "a1"]) - duckdb_cursor.register('arrow_tbl', arrow_table) - assert duckdb_cursor.table("arrow_tbl").columns == ['A1', 'a1_1'] + duckdb_cursor.register("arrow_tbl", arrow_table) + assert duckdb_cursor.table("arrow_tbl").columns == ["A1", "a1_1"] assert duckdb_cursor.execute("select A1 from arrow_tbl;").fetchall() == [(1,)] assert duckdb_cursor.execute("select a1_1 from arrow_tbl;").fetchall() == [(1000,)] - assert arrow_table.column_names == ['A1', 'a1'] + assert arrow_table.column_names == ["A1", "a1"] def test_arrow_case_sensitive_repeated(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[1]], ['A1', 'a1_1', 'a1']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[1]], ["A1", "a1_1", "a1"]) - duckdb_cursor.register('arrow_tbl', arrow_table) - assert duckdb_cursor.table("arrow_tbl").columns == ['A1', 'a1_1', 'a1_2'] - assert arrow_table.column_names == ['A1', 'a1_1', 'a1'] + duckdb_cursor.register("arrow_tbl", arrow_table) + assert duckdb_cursor.table("arrow_tbl").columns == ["A1", "a1_1", "a1_2"] + assert arrow_table.column_names == ["A1", "a1_1", "a1"] diff --git a/tests/fast/arrow/test_arrow_decimal256.py b/tests/fast/arrow/test_arrow_decimal256.py index 0ab84d3a..d687ec8a 100644 --- a/tests/fast/arrow/test_arrow_decimal256.py +++ b/tests/fast/arrow/test_arrow_decimal256.py @@ -1,14 +1,16 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimal256(object): +class TestArrowDecimal256: def test_decimal_256_throws(self, duckdb_cursor): with duckdb.connect() as conn: - pa_decimal256 = pa.Table.from_pylist( + pa_decimal256 = pa.Table.from_pylist( # noqa: F841 [{"data": Decimal("100.00")} for _ in range(4)], pa.schema([("data", pa.decimal256(12, 4))]), ) diff --git a/tests/fast/arrow/test_arrow_decimal_32_64.py b/tests/fast/arrow/test_arrow_decimal_32_64.py index 4a960454..301d890f 100644 --- a/tests/fast/arrow/test_arrow_decimal_32_64.py +++ b/tests/fast/arrow/test_arrow_decimal_32_64.py @@ -1,14 +1,16 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimalTypes(object): +class TestArrowDecimalTypes: def test_decimal_32(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute('SET arrow_output_version = 1.5') + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_32 = pa.Table.from_pylist( [ {"data": Decimal("100.20")}, @@ -20,10 +22,10 @@ def test_decimal_32(self, duckdb_cursor): ) # Test scan assert duckdb_cursor.execute("FROM decimal_32").fetchall() == [ - (Decimal('100.20'),), - (Decimal('110.21'),), - (Decimal('31.20'),), - (Decimal('500.20'),), + (Decimal("100.20"),), + (Decimal("110.21"),), + (Decimal("31.20"),), + (Decimal("500.20"),), ] # Test filter pushdown assert duckdb_cursor.execute("SELECT COUNT(*) FROM decimal_32 where data > 100 and data < 200 ").fetchall() == [ @@ -37,7 +39,7 @@ def test_decimal_32(self, duckdb_cursor): def test_decimal_64(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute('SET arrow_output_version = 1.5') + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_64 = pa.Table.from_pylist( [ {"data": Decimal("1000.231")}, @@ 
-50,10 +52,10 @@ def test_decimal_64(self, duckdb_cursor): # Test scan assert duckdb_cursor.execute("FROM decimal_64").fetchall() == [ - (Decimal('1000.231'),), - (Decimal('1100.231'),), - (Decimal('999999999999.231'),), - (Decimal('500.200'),), + (Decimal("1000.231"),), + (Decimal("1100.231"),), + (Decimal("999999999999.231"),), + (Decimal("500.200"),), ] # Test Filter pushdown diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index 9180fa90..f79c32c4 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -1,15 +1,17 @@ -import duckdb -import pytest -import uuid +# ruff: noqa: F841 +import datetime import json +import uuid from uuid import UUID -import datetime -pa = pytest.importorskip('pyarrow', '18.0.0') +import pytest +import duckdb -class TestCanonicalExtensionTypes(object): +pa = pytest.importorskip("pyarrow", "18.0.0") + +class TestCanonicalExtensionTypes: def test_uuid(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") @@ -17,9 +19,9 @@ def test_uuid(self): storage_array = pa.array([uuid.uuid4().bytes for _ in range(4)], pa.binary(16)) storage_array = pa.uuid().wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['uuid_col']) + arrow_table = pa.Table.from_arrays([storage_array], names=["uuid_col"]) - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() assert duck_arrow.equals(arrow_table) @@ -30,14 +32,14 @@ def test_uuid_from_duck(self): arrow_table = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() assert arrow_table.to_pylist() == [ - {'uuid': UUID('00000000-0000-0000-0000-000000000000')}, - {'uuid': UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')}, - {'uuid': None}, + {"uuid": UUID("00000000-0000-0000-0000-000000000000")}, + {"uuid": UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")}, + {"uuid": None}, ] assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -45,8 +47,8 @@ def test_uuid_from_duck(self): "select '00000000-0000-0000-0000-000000000100'::UUID as uuid" ).fetch_arrow_table() - assert arrow_table.to_pylist() == [{'uuid': UUID('00000000-0000-0000-0000-000000000100')}] - assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(UUID('00000000-0000-0000-0000-000000000100'),)] + assert arrow_table.to_pylist() == [{"uuid": UUID("00000000-0000-0000-0000-000000000100")}] + assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(UUID("00000000-0000-0000-0000-000000000100"),)] def test_json(self, duckdb_cursor): data = {"name": "Pedro", "age": 28, "car": "VW Fox"} @@ -56,10 +58,10 @@ def test_json(self, duckdb_cursor): storage_array = pa.array([json_string], pa.string()) - arrow_table = pa.Table.from_arrays([storage_array], names=['json_col']) + arrow_table = pa.Table.from_arrays([storage_array], names=["json_col"]) duckdb_cursor.execute("SET arrow_lossless_conversion = true") - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() assert duck_arrow.equals(arrow_table) @@ -70,8 +72,8 @@ def test_uuid_no_def(self): res_arrow = 
duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() res_duck = duckdb_cursor.execute("from res_arrow").fetchall() assert res_duck == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -79,15 +81,15 @@ def test_uuid_no_def_lossless(self): duckdb_cursor = duckdb.connect() res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() assert res_arrow.to_pylist() == [ - {'uuid': '00000000-0000-0000-0000-000000000000'}, - {'uuid': 'ffffffff-ffff-ffff-ffff-ffffffffffff'}, - {'uuid': None}, + {"uuid": "00000000-0000-0000-0000-000000000000"}, + {"uuid": "ffffffff-ffff-ffff-ffff-ffffffffffff"}, + {"uuid": None}, ] res_duck = duckdb_cursor.execute("from res_arrow").fetchall() assert res_duck == [ - ('00000000-0000-0000-0000-000000000000',), - ('ffffffff-ffff-ffff-ffff-ffffffffffff',), + ("00000000-0000-0000-0000-000000000000",), + ("ffffffff-ffff-ffff-ffff-ffffffffffff",), (None,), ] @@ -98,8 +100,8 @@ def test_uuid_no_def_stream(self): res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_record_batch() res_duck = duckdb.execute("from res_arrow").fetchall() assert res_duck == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -109,62 +111,62 @@ def test_function(x): return x con = duckdb.connect() - con.create_function('test', test_function, ['UUID'], 'UUID', type='arrow') + con.create_function("test", test_function, ["UUID"], "UUID", type="arrow") - rel = con.sql("select ? as x", params=[uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')]) + rel = con.sql("select ? 
as x", params=[uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")]) rel.project("test(x) from t").fetchall() def test_unimplemented_extension(self, duckdb_cursor): class MyType(pa.ExtensionType): - def __init__(self): + def __init__(self) -> None: pa.ExtensionType.__init__(self, pa.binary(5), "pedro.binary") - def __arrow_ext_serialize__(self): - return b'' + def __arrow_ext_serialize__(self) -> bytes: + return b"" @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized): - return UuidTypeWrong() + def __arrow_ext_deserialize__(cls, storage_type, serialized) -> object: + return None - storage_array = pa.array(['pedro'], pa.binary(5)) + storage_array = pa.array(["pedro"], pa.binary(5)) my_type = MyType() storage_array = my_type.wrap_array(storage_array) age_array = pa.array([29], pa.int32()) - arrow_table = pa.Table.from_arrays([storage_array, age_array], names=['pedro_pedro_pedro', 'age']) + arrow_table = pa.Table.from_arrays([storage_array, age_array], names=["pedro_pedro_pedro", "age"]) - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() - assert duckdb_cursor.execute('FROM duck_arrow').fetchall() == [(b'pedro', 29)] + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() + assert duckdb_cursor.execute("FROM duck_arrow").fetchall() == [(b"pedro", 29)] def test_hugeint(self): con = duckdb.connect() con.execute("SET arrow_lossless_conversion = true") - storage_array = pa.array([b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff'], pa.binary(16)) + storage_array = pa.array([b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"], pa.binary(16)) hugeint_type = pa.opaque(pa.binary(16), "hugeint", "DuckDB") storage_array = hugeint_type.wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['numbers']) + arrow_table = pa.Table.from_arrays([storage_array], names=["numbers"]) - assert con.execute('FROM arrow_table').fetchall() == [(-1,)] + assert con.execute("FROM arrow_table").fetchall() == [(-1,)] - assert con.execute('FROM arrow_table').fetch_arrow_table().equals(arrow_table) + assert con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) con.execute("SET arrow_lossless_conversion = false") - assert not con.execute('FROM arrow_table').fetch_arrow_table().equals(arrow_table) + assert not con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) def test_uhugeint(self, duckdb_cursor): - storage_array = pa.array([b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff'], pa.binary(16)) + storage_array = pa.array([b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"], pa.binary(16)) uhugeint_type = pa.opaque(pa.binary(16), "uhugeint", "DuckDB") storage_array = uhugeint_type.wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['numbers']) + arrow_table = pa.Table.from_arrays([storage_array], names=["numbers"]) - assert duckdb_cursor.execute('FROM arrow_table').fetchall() == [(340282366920938463463374607431768211455,)] + assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(340282366920938463463374607431768211455,)] def test_bit(self): con = duckdb.connect() @@ -176,18 +178,18 @@ def test_bit(self): res_bit = con.execute("SELECT '0101011'::BIT str FROM range(5) tbl(i)").fetch_arrow_table() assert con.execute("FROM res_blob").fetchall() == [ - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), + (b"\x01\xab",), + (b"\x01\xab",), + 
(b"\x01\xab",), + (b"\x01\xab",), + (b"\x01\xab",), ] assert con.execute("FROM res_bit").fetchall() == [ - ('0101011',), - ('0101011',), - ('0101011',), - ('0101011',), - ('0101011',), + ("0101011",), + ("0101011",), + ("0101011",), + ("0101011",), + ("0101011",), ] def test_timetz(self): @@ -207,14 +209,14 @@ def test_timetz(self): def test_bignum(self): con = duckdb.connect() res_bignum = con.execute( - "SELECT '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368'::bignum a FROM range(1) tbl(i)" + "SELECT '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368'::bignum a FROM range(1) tbl(i)" # noqa: E501 ).fetch_arrow_table() - assert res_bignum.column("a").type.type_name == 'bignum' + assert res_bignum.column("a").type.type_name == "bignum" assert res_bignum.column("a").type.vendor_name == "DuckDB" assert con.execute("FROM res_bignum").fetchall() == [ ( - '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368', + "179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368", ) ] @@ -235,9 +237,9 @@ def test_extension_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array( [ - b'\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', - b'\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', - b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', + b"\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", + b"\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", + b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", ], pa.binary(16), ) @@ -245,7 +247,7 @@ def test_extension_dictionary(self, duckdb_cursor): dictionary = uhugeint_type.wrap_array(dictionary) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [ (340282366920938463463374607431768211200,), @@ -263,13 +265,13 @@ def test_boolean(self): con.execute("SET arrow_lossless_conversion = true") storage_array = pa.array([-1, 0, 1, 2, None], pa.int8()) bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), storage_array) - arrow_table = pa.Table.from_arrays([bool8_array], names=['bool8']) - assert con.execute('FROM arrow_table').fetchall() == [(True,), (False,), (True,), (True,), (None,)] - result_table = con.execute('FROM arrow_table').fetch_arrow_table() + arrow_table = pa.Table.from_arrays([bool8_array], 
names=["bool8"]) + assert con.execute("FROM arrow_table").fetchall() == [(True,), (False,), (True,), (True,), (None,)] + result_table = con.execute("FROM arrow_table").fetch_arrow_table() res_storage_array = pa.array([1, 0, 1, 1, None], pa.int8()) res_bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), res_storage_array) - res_arrow_table = pa.Table.from_arrays([res_bool8_array], names=['bool8']) + res_arrow_table = pa.Table.from_arrays([res_bool8_array], names=["bool8"]) assert result_table.equals(res_arrow_table) @@ -279,7 +281,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): pa.binary(), metadata={ "ARROW:extension:name": "foofyfoo", - "ARROW:extension:metadata": 'this is not valid json', + "ARROW:extension:metadata": "this is not valid json", }, ) schema = pa.schema([field]) @@ -296,7 +298,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): pa.binary(), metadata={ "ARROW:extension:name": "arrow.opaque", - "ARROW:extension:metadata": 'this is not valid json', + "ARROW:extension:metadata": "this is not valid json", }, ) schema = pa.schema([field]) @@ -337,9 +339,9 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): schema=schema, ) assert duckdb_cursor.sql("""DESCRIBE FROM bignum_table;""").fetchone() == ( - 'bignum_value', - 'BIGNUM', - 'YES', + "bignum_value", + "BIGNUM", + "YES", None, None, None, diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index 04a34595..0547020f 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -1,11 +1,8 @@ import duckdb -import pytest try: - import pyarrow as pa - can_run = True -except: +except Exception: can_run = False @@ -18,7 +15,7 @@ def check_equal(duckdb_conn): assert arrow_result == true_result -class TestArrowFetch(object): +class TestArrowFetch: def test_empty_table(self, duckdb_cursor): if not can_run: return @@ -83,8 +80,8 @@ def test_to_arrow_chunk_size(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") - relation = duckdb_cursor.table('t') + relation = duckdb_cursor.table("t") arrow_tbl = relation.fetch_arrow_table() - assert arrow_tbl['a'].num_chunks == 1 + assert arrow_tbl["a"].num_chunks == 1 arrow_tbl = relation.fetch_arrow_table(2048) - assert arrow_tbl['a'].num_chunks == 2 + assert arrow_tbl["a"].num_chunks == 2 diff --git a/tests/fast/arrow/test_arrow_fetch_recordbatch.py b/tests/fast/arrow/test_arrow_fetch_recordbatch.py index 24d7c2c7..0070430b 100644 --- a/tests/fast/arrow/test_arrow_fetch_recordbatch.py +++ b/tests/fast/arrow/test_arrow_fetch_recordbatch.py @@ -1,10 +1,11 @@ -import duckdb import pytest -pa = pytest.importorskip('pyarrow') +import duckdb + +pa = pytest.importorskip("pyarrow") -class TestArrowFetchRecordBatch(object): +class TestArrowFetchRecordBatch: # Test with basic numeric conversion (integers, floats, and others fall this code-path) def test_record_batch_next_batch_numeric(self, duckdb_cursor): duckdb_cursor = duckdb.connect() @@ -12,7 +13,7 @@ def test_record_batch_next_batch_numeric(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ 
-38,7 +39,7 @@ def test_record_batch_next_batch_bool(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -63,7 +64,7 @@ def test_record_batch_next_batch_varchar(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range::varchar a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -90,7 +91,7 @@ def test_record_batch_next_batch_struct(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -115,7 +116,7 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select [i,i+1] as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -141,7 +142,7 @@ def test_record_batch_next_batch_map(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select map([i], [i+1]) as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -169,7 +170,7 @@ def test_record_batch_next_batch_with_null(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -224,15 +225,15 @@ def test_record_batch_next_batch_multiple_vectors_per_chunk_error(self, duckdb_c duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(5000);") query = duckdb_cursor.execute("SELECT a FROM t") - with pytest.raises(RuntimeError, match='Approximate Batch Size of Record Batch MUST be higher than 0'): - record_batch_reader = query.fetch_record_batch(0) - with pytest.raises(TypeError, match='incompatible function arguments'): - record_batch_reader = query.fetch_record_batch(-1) + with pytest.raises(RuntimeError, match="Approximate Batch Size of Record Batch MUST be higher than 0"): + query.fetch_record_batch(0) + with pytest.raises(TypeError, match="incompatible function arguments"): + query.fetch_record_batch(-1) def test_record_batch_reader_from_relation(self, duckdb_cursor): 
duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") - relation = duckdb_cursor.table('t') + relation = duckdb_cursor.table("t") record_batch_reader = relation.record_batch() chunk = record_batch_reader.read_next_batch() assert len(chunk) == 3000 @@ -249,11 +250,9 @@ def test_record_coverage(self, duckdb_cursor): def test_record_batch_query_error(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select 'foo' as a;") - with pytest.raises(duckdb.ConversionException, match='Conversion Error'): + with pytest.raises(duckdb.ConversionException, match="Conversion Error"): # 'execute' materializes the result, causing the error directly - query = duckdb_cursor.execute("SELECT cast(a as double) FROM t") - record_batch_reader = query.fetch_record_batch(1024) - record_batch_reader.read_next_batch() + duckdb_cursor.execute("SELECT cast(a as double) FROM t") def test_many_list_batches(self): conn = duckdb.connect() @@ -281,8 +280,8 @@ def test_many_chunk_sizes(self): query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(i) num_loops = int(object_size / i) - for j in range(num_loops): - assert record_batch_reader.schema.names == ['a'] + for _j in range(num_loops): + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == i remainder = object_size % i diff --git a/tests/fast/arrow/test_arrow_fixed_binary.py b/tests/fast/arrow/test_arrow_fixed_binary.py index aa0047a8..ccbc4a17 100644 --- a/tests/fast/arrow/test_arrow_fixed_binary.py +++ b/tests/fast/arrow/test_arrow_fixed_binary.py @@ -3,19 +3,19 @@ pa = pytest.importorskip("pyarrow") -class TestArrowFixedBinary(object): +class TestArrowFixedBinary: def test_arrow_fixed_binary(self, duckdb_cursor): ids = [ None, - b'\x66\x4d\xf4\xae\xb1\x5c\xb0\x4a\xdd\x5d\x1d\x54', - b'\x66\x4d\xf4\xf0\xa3\xfc\xec\x5b\x26\x81\x4e\x1d', + b"\x66\x4d\xf4\xae\xb1\x5c\xb0\x4a\xdd\x5d\x1d\x54", + b"\x66\x4d\xf4\xf0\xa3\xfc\xec\x5b\x26\x81\x4e\x1d", ] id_array = pa.array(ids, type=pa.binary(12)) - arrow_table = pa.Table.from_arrays([id_array], names=["id"]) + arrow_table = pa.Table.from_arrays([id_array], names=["id"]) # noqa: F841 res = duckdb_cursor.sql( """ SELECT lower(hex(id)) as id FROM arrow_table """ ).fetchall() - assert res == [(None,), ('664df4aeb15cb04add5d1d54',), ('664df4f0a3fcec5b26814e1d',)] + assert res == [(None,), ("664df4aeb15cb04add5d1d54",), ("664df4f0a3fcec5b26814e1d",)] diff --git a/tests/fast/arrow/test_arrow_ipc.py b/tests/fast/arrow/test_arrow_ipc.py index 1d71eaa4..df5181e7 100644 --- a/tests/fast/arrow/test_arrow_ipc.py +++ b/tests/fast/arrow/test_arrow_ipc.py @@ -1,17 +1,18 @@ import pytest + import duckdb -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") -ipc = pytest.importorskip('pyarrow.ipc') +ipc = pytest.importorskip("pyarrow.ipc") def get_record_batch(): - data = [pa.array([1, 2, 3, 4]), pa.array(['foo', 'bar', 'baz', None]), pa.array([True, None, False, True])] - return pa.record_batch(data, names=['f0', 'f1', 'f2']) + data = [pa.array([1, 2, 3, 4]), pa.array(["foo", "bar", "baz", None]), pa.array([True, None, False, True])] + return pa.record_batch(data, names=["f0", "f1", "f2"]) -class TestArrowIPCExtension(object): +class TestArrowIPCExtension: # Only thing we can test in core is that it suggests the # instalation and loading of the extension def test_single_buffer(self, duckdb_cursor): @@ -24,11 +25,10 @@ def 
test_single_buffer(self, duckdb_cursor): buffer = sink.getvalue() - buffers = [] with pa.BufferReader(buffer) as buf_reader: # Use pyarrow.BufferReader stream = ipc.MessageReader.open_stream(buf_reader) # This fails with pytest.raises( - duckdb.Error, match="The nanoarrow community extension is needed to read the Arrow IPC protocol." + duckdb.Error, match="The nanoarrow community extension is needed to read the Arrow IPC protocol" ): - result = duckdb_cursor.from_arrow(stream).fetchall() + duckdb_cursor.from_arrow(stream).fetchall() diff --git a/tests/fast/arrow/test_arrow_list.py b/tests/fast/arrow/test_arrow_list.py index e2449fd3..b460f7e5 100644 --- a/tests/fast/arrow/test_arrow_list.py +++ b/tests/fast/arrow/test_arrow_list.py @@ -1,4 +1,3 @@ -import duckdb import numpy as np import pytest @@ -21,17 +20,13 @@ def create_and_register_arrow_table(column_list, duckdb_cursor): def create_and_register_comparison_result(column_list, duckdb_cursor): - columns = ",".join([f'{name} {dtype}' for (name, dtype, _) in column_list]) + columns = ",".join([f"{name} {dtype}" for (name, dtype, _) in column_list]) column_amount = len(column_list) assert column_amount row_amount = len(column_list[0][2]) - inserted_values = [] - for row in range(row_amount): - for col in range(column_amount): - inserted_values.append(column_list[col][2][row]) - inserted_values = tuple(inserted_values) + inserted_values = tuple(column_list[col][2][row] for row in range(row_amount) for col in range(column_amount)) - column_format = ",".join(['?' for _ in range(column_amount)]) + column_format = ",".join(["?" for _ in range(column_amount)]) row_format = ",".join([f"({column_format})" for _ in range(row_amount)]) query = f"""CREATE TABLE test ({columns}); INSERT INTO test VALUES {row_format}; @@ -41,13 +36,13 @@ def create_and_register_comparison_result(column_list, duckdb_cursor): class ListGenerationResult: - def __init__(self, list, list_view): + def __init__(self, list, list_view) -> None: self.list = list self.list_view = list_view def generate_list(child_size) -> ListGenerationResult: - input = [i for i in range(child_size)] + input = list(range(child_size)) offsets = [] sizes = [] lists = [] @@ -58,7 +53,7 @@ def generate_list(child_size) -> ListGenerationResult: if count >= child_size: break size = SIZE_MAP[i % len(SIZE_MAP)] - if size == None: + if size is None: mask.append(True) size = 0 else: @@ -73,7 +68,7 @@ def generate_list(child_size) -> ListGenerationResult: # Create a regular ListArray list_arr = pa.ListArray.from_arrays(offsets=offsets, values=input, mask=pa.array(mask, type=pa.bool_())) - if not hasattr(pa, 'ListViewArray'): + if not hasattr(pa, "ListViewArray"): return ListGenerationResult(list_arr, None) lists = list(reversed(lists)) @@ -91,24 +86,24 @@ def generate_list(child_size) -> ListGenerationResult: return ListGenerationResult(list_arr, list_view_arr) -class TestArrowListType(object): +class TestArrowListType: def test_regular_list(self, duckdb_cursor): n = 5 # Amount of lists generated_size = 3 # Size of each list list_size = -1 # Argument passed to `pa._list()` - data = [np.random.random((generated_size)) for _ in range(n)] + data = [np.random.random(generated_size) for _ in range(n)] list_type = pa.list_(pa.float32(), list_size=list_size) create_and_register_arrow_table( [ - ('a', list_type, data), + ("a", list_type, data), ], duckdb_cursor, ) create_and_register_comparison_result( [ - ('a', 'FLOAT[]', data), + ("a", "FLOAT[]", data), ], duckdb_cursor, ) @@ -120,31 +115,31 @@ def 
test_fixedsize_list(self, duckdb_cursor): generated_size = 3 # Size of each list list_size = 3 # Argument passed to `pa._list()` - data = [np.random.random((generated_size)) for _ in range(n)] + data = [np.random.random(generated_size) for _ in range(n)] list_type = pa.list_(pa.float32(), list_size=list_size) create_and_register_arrow_table( [ - ('a', list_type, data), + ("a", list_type, data), ], duckdb_cursor, ) create_and_register_comparison_result( [ - ('a', f'FLOAT[{list_size}]', data), + ("a", f"FLOAT[{list_size}]", data), ], duckdb_cursor, ) check_equal(duckdb_cursor) - @pytest.mark.skipif(not hasattr(pa, 'ListViewArray'), reason='The pyarrow version does not support ListViewArrays') - @pytest.mark.parametrize('child_size', [100000]) + @pytest.mark.skipif(not hasattr(pa, "ListViewArray"), reason="The pyarrow version does not support ListViewArrays") + @pytest.mark.parametrize("child_size", [100000]) def test_list_view(self, duckdb_cursor, child_size): res = generate_list(child_size) - list_tbl = pa.Table.from_arrays([res.list], ['x']) - list_view_tbl = pa.Table.from_arrays([res.list_view], ['x']) + list_tbl = pa.Table.from_arrays([res.list], ["x"]) # noqa: F841 + list_view_tbl = pa.Table.from_arrays([res.list_view], ["x"]) # noqa: F841 assert res.list_view.to_pylist() == res.list.to_pylist() original = duckdb_cursor.query("select * from list_tbl").fetchall() diff --git a/tests/fast/arrow/test_arrow_offsets.py b/tests/fast/arrow/test_arrow_offsets.py index 6bc94530..2b28f416 100644 --- a/tests/fast/arrow/test_arrow_offsets.py +++ b/tests/fast/arrow/test_arrow_offsets.py @@ -1,8 +1,8 @@ -import duckdb -import pytest -from pytest import mark +# ruff: noqa: F841 import datetime import decimal + +import pytest import pytz pa = pytest.importorskip("pyarrow") @@ -62,28 +62,26 @@ def decimal_value(value, precision, scale): val = str(value) actual_width = precision - scale if len(val) > actual_width: - return decimal.Decimal('9' * actual_width) + return decimal.Decimal("9" * actual_width) return decimal.Decimal(val) def expected_result(col1_null, col2_null, expected): col1 = None if col1_null else expected - if col1_null or col2_null: - col2 = None - else: - col2 = expected + col2 = None if col1_null or col2_null else expected return [(col1, col2)] -null_test_parameters = lambda: mark.parametrize( - ['col1_null', 'col2_null'], [(False, True), (True, False), (True, True), (False, False)] -) +def null_test_parameters(): + return pytest.mark.parametrize( + ("col1_null", "col2_null"), [(False, True), (True, False), (True, True), (False, False)] + ) -class TestArrowOffsets(object): +class TestArrowOffsets: @null_test_parameters() def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -100,14 +98,14 @@ def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, '131072') + assert res == expected_result(col1_null, col2_null, "131072") @null_test_parameters() def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): - tuples = [False for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = [False for i in range(MAGIC_ARRAY_SIZE)] tuples[-1] = True col1 = tuples @@ -126,13 +124,13 @@ def 
test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, True) @pytest.mark.parametrize( - ["constructor", "expected"], + ("constructor", "expected"), [ (pa_date32(), datetime.date(2328, 11, 12)), (pa_date64(), datetime.date(1970, 1, 1)), @@ -140,7 +138,7 @@ def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): ) @null_test_parameters() def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, col2_null): - tuples = [i for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = list(range(MAGIC_ARRAY_SIZE)) col1 = tuples if col1_null: @@ -158,7 +156,7 @@ def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -167,8 +165,8 @@ def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): enum_type = pa.dictionary(pa.int64(), pa.utf8()) - tuples = ['red' for i in range(MAGIC_ARRAY_SIZE)] - tuples[-1] = 'green' + tuples = ["red" for i in range(MAGIC_ARRAY_SIZE)] + tuples[-1] = "green" if col1_null: tuples[-1] = None @@ -177,7 +175,7 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): struct_tuples[-1] = None arrow_table = pa.Table.from_pydict( - {'col1': pa.array(tuples, enum_type), 'col2': pa.array(struct_tuples, pa.struct({"a": enum_type}))}, + {"col1": pa.array(tuples, enum_type), "col2": pa.array(struct_tuples, pa.struct({"a": enum_type}))}, schema=pa.schema([("col1", enum_type), ("col2", pa.struct({"a": enum_type}))]), ) res = duckdb_cursor.sql( @@ -185,14 +183,14 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, 'green') + assert res == expected_result(col1_null, col2_null, "green") @null_test_parameters() def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -209,28 +207,28 @@ def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, b'131072') + assert res == expected_result(col1_null, col2_null, b"131072") @null_test_parameters() @pytest.mark.parametrize( - ["constructor", "unit", "expected"], + ("constructor", "unit", "expected"), [ - (pa_time32(), 'ms', datetime.time(0, 2, 11, 72000)), - (pa_time32(), 's', datetime.time(23, 59, 59)), - (pa_time64(), 'ns', datetime.time(0, 0, 0, 131)), - (pa_time64(), 'us', datetime.time(0, 0, 0, 131072)), + (pa_time32(), "ms", datetime.time(0, 2, 11, 72000)), + (pa_time32(), "s", datetime.time(23, 59, 59)), + (pa_time64(), "ns", datetime.time(0, 0, 0, 131)), + (pa_time64(), "us", datetime.time(0, 0, 0, 131072)), ], ) def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): size = 
MAGIC_ARRAY_SIZE - if unit == 's': - # FIXME: We limit the size because we don't support time values > 24 hours + if unit == "s": + # TODO: We limit the size because we don't support time values > 24 hours # noqa: TD002, TD003 size = 86400 # The amount of seconds in a day - col1 = [i for i in range(0, size)] + col1 = list(range(size)) if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -247,7 +245,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -255,7 +253,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n @null_test_parameters() # NOTE: there is sadly no way to create a 'interval[months]' (tiM) type from pyarrow @pytest.mark.parametrize( - ["constructor", "expected", "converter"], + ("constructor", "expected", "converter"), [ (pa_month_day_nano_interval(), datetime.timedelta(days=3932160), month_interval), (pa_month_day_nano_interval(), datetime.timedelta(days=131072), day_interval), @@ -265,7 +263,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converter, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - col1 = [converter(i) for i in range(0, size)] + col1 = [converter(i) for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -282,25 +280,25 @@ def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converte SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @null_test_parameters() @pytest.mark.parametrize( - ["constructor", "unit", "expected"], + ("constructor", "unit", "expected"), [ - (pa_duration(), 'ms', datetime.timedelta(seconds=131, microseconds=72000)), - (pa_duration(), 's', datetime.timedelta(days=1, seconds=44672)), - (pa_duration(), 'ns', datetime.timedelta(microseconds=131)), - (pa_duration(), 'us', datetime.timedelta(microseconds=131072)), + (pa_duration(), "ms", datetime.timedelta(seconds=131, microseconds=72000)), + (pa_duration(), "s", datetime.timedelta(days=1, seconds=44672)), + (pa_duration(), "ns", datetime.timedelta(microseconds=131)), + (pa_duration(), "us", datetime.timedelta(microseconds=131072)), ], ) def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - col1 = [i for i in range(0, size)] + col1 = list(range(size)) if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -317,26 +315,26 @@ def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, co SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @null_test_parameters() @pytest.mark.parametrize( - ["constructor", "unit", "expected"], + ("constructor", "unit", "expected"), [ - (pa_timestamp(), 'ms', datetime.datetime(1970, 1, 1, 0, 2, 11, 72000, tzinfo=pytz.utc)), - (pa_timestamp(), 's', datetime.datetime(1970, 1, 2, 12, 24, 32, 0, tzinfo=pytz.utc)), - (pa_timestamp(), 'ns', datetime.datetime(1970, 1, 1, 0, 0, 0, 131, tzinfo=pytz.utc)), - (pa_timestamp(), 'us', datetime.datetime(1970, 1, 1, 0, 0, 0, 131072, 
tzinfo=pytz.utc)), + (pa_timestamp(), "ms", datetime.datetime(1970, 1, 1, 0, 2, 11, 72000, tzinfo=pytz.utc)), + (pa_timestamp(), "s", datetime.datetime(1970, 1, 2, 12, 24, 32, 0, tzinfo=pytz.utc)), + (pa_timestamp(), "ns", datetime.datetime(1970, 1, 1, 0, 0, 0, 131, tzinfo=pytz.utc)), + (pa_timestamp(), "us", datetime.datetime(1970, 1, 1, 0, 0, 0, 131072, tzinfo=pytz.utc)), ], ) def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): size = MAGIC_ARRAY_SIZE duckdb_cursor.execute("set timezone='UTC'") - col1 = [i for i in range(0, size)] + col1 = list(range(size)) if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -346,7 +344,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected arrow_table = pa.Table.from_pydict( {"col1": col1, "col2": col2}, schema=pa.schema( - [("col1", constructor(unit, 'UTC')), ("col2", pa.struct({"a": constructor(unit, 'UTC')}))] + [("col1", constructor(unit, "UTC")), ("col2", pa.struct({"a": constructor(unit, "UTC")}))] ), ) @@ -355,14 +353,14 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @null_test_parameters() def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -379,28 +377,28 @@ def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, b'131072') + assert res == expected_result(col1_null, col2_null, b"131072") @null_test_parameters() @pytest.mark.parametrize( - ["precision_scale", "expected"], + ("precision_scale", "expected"), [ - ((38, 37), decimal.Decimal('9.0000000000000000000000000000000000000')), - ((38, 24), decimal.Decimal('131072.000000000000000000000000')), - ((18, 14), decimal.Decimal('9999.00000000000000')), - ((18, 5), decimal.Decimal('131072.00000')), - ((9, 7), decimal.Decimal('99.0000000')), - ((9, 3), decimal.Decimal('131072.000')), - ((4, 2), decimal.Decimal('99.00')), - ((4, 0), decimal.Decimal('9999')), + ((38, 37), decimal.Decimal("9.0000000000000000000000000000000000000")), + ((38, 24), decimal.Decimal("131072.000000000000000000000000")), + ((18, 14), decimal.Decimal("9999.00000000000000")), + ((18, 5), decimal.Decimal("131072.00000")), + ((9, 7), decimal.Decimal("99.0000000")), + ((9, 3), decimal.Decimal("131072.000")), + ((4, 2), decimal.Decimal("99.00")), + ((4, 0), decimal.Decimal("9999")), ], ) def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_null, col2_null): precision, scale = precision_scale - col1 = [decimal_value(i, precision, scale) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [decimal_value(i, precision, scale) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -420,14 +418,14 @@ def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_ SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == 
expected_result(col1_null, col2_null, expected) @null_test_parameters() def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -443,21 +441,21 @@ def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else '131072' + res1 = None if col1_null else "131072" if col2_null: res2 = None elif col1_null: res2 = [None, None, None] else: - res2 = ['131072', '131072', '131072'] + res2 = ["131072", "131072", "131072"] assert res == [(res1, res2)] @null_test_parameters() def test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -473,21 +471,21 @@ def test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else '131072' + res1 = None if col1_null else "131072" if col2_null: res2 = None elif col1_null: res2 = (None, None, None) else: - res2 = ('131072', '131072', '131072') + res2 = ("131072", "131072", "131072") assert res == [(res1, res2)] @null_test_parameters() def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -504,21 +502,21 @@ def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else b'131072' + res1 = None if col1_null else b"131072" if col2_null: res2 = None elif col1_null: res2 = (None, None, None) else: - res2 = (b'131072', b'131073', b'131074') + res2 = (b"131072", b"131073", b"131074") assert res == [(res1, res2)] @null_test_parameters() def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -535,21 +533,21 @@ def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else b'131072' + res1 = None if col1_null else b"131072" if col2_null: res2 = None elif col1_null: res2 = [None, None, None] else: - res2 = [b'131072', b'131073', b'131074'] + res2 = [b"131072", b"131073", b"131074"] assert res == [(res1, res2)] @null_test_parameters() def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [i for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = list(range(MAGIC_ARRAY_SIZE)) if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -566,7 +564,7 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): SELECT 
col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() res1 = None if col1_null else 131072 @@ -578,10 +576,10 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): res2 = [[131072, 131072, 131072], [], None, [131072]] assert res == [(res1, res2)] - @pytest.mark.parametrize('col1_null', [True, False]) + @pytest.mark.parametrize("col1_null", [True, False]) def test_list_of_struct(self, duckdb_cursor, col1_null): # One single tuple containing a very big list - tuples = [{"a": i} for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = [{"a": i} for i in range(MAGIC_ARRAY_SIZE)] if col1_null: tuples[-1] = None tuples = [tuples] @@ -590,7 +588,7 @@ def test_list_of_struct(self, duckdb_cursor, col1_null): schema=pa.schema([("col1", pa.list_(pa.struct({"a": pa.int32()})))]), ) res = duckdb_cursor.sql( - f""" + """ SELECT col1 FROM arrow_table @@ -598,20 +596,20 @@ def test_list_of_struct(self, duckdb_cursor, col1_null): ).fetchall() res = res[0][0] for i, x in enumerate(res[:-1]): - assert x.__class__ == dict - assert x['a'] == i + assert x.__class__ is dict + assert x["a"] == i if col1_null: - assert res[-1] == None + assert res[-1] is None else: - assert res[-1]['a'] == len(res) - 1 + assert res[-1]["a"] == len(res) - 1 - @pytest.mark.parametrize(['outer_null', 'inner_null'], [(True, False), (False, True)]) + @pytest.mark.parametrize(("outer_null", "inner_null"), [(True, False), (False, True)]) def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): tuples = [[[{"a": str(i), "b": None, "c": [i]}]] for i in range(MAGIC_ARRAY_SIZE)] if outer_null: tuples[-1] = None else: - inner = [[{"a": 'aaaaaaaaaaaaaaa', "b": 'test', "c": [1, 2, 3]}] for _ in range(MAGIC_ARRAY_SIZE)] + inner = [[{"a": "aaaaaaaaaaaaaaa", "b": "test", "c": [1, 2, 3]}] for _ in range(MAGIC_ARRAY_SIZE)] if inner_null: inner[-1] = None tuples[-1] = inner @@ -635,18 +633,18 @@ def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): f""" SELECT col1 - FROM arrow_table OFFSET {MAGIC_ARRAY_SIZE-1} + FROM arrow_table OFFSET {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() if outer_null: assert res == [(None,)] else: if inner_null: - assert res[-1][-1][-1] == None + assert res[-1][-1][-1] is None else: assert res[-1][-1][-1] == 131072 - @pytest.mark.parametrize('col1_null', [True, False]) + @pytest.mark.parametrize("col1_null", [True, False]) def test_struct_of_list(self, duckdb_cursor, col1_null): # All elements are of size 1 tuples = [{"a": [str(i)]} for i in range(MAGIC_ARRAY_SIZE)] @@ -664,13 +662,13 @@ def test_struct_of_list(self, duckdb_cursor, col1_null): f""" SELECT col1 - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchone() if col1_null: - assert res[0] == None + assert res[0] is None else: - assert res[0]['a'][-1] == '131072' + assert res[0]["a"][-1] == "131072" def test_bools_with_offset(self, duckdb_cursor): bools = [False, False, False, False, True, False, False, False, False, False] diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index c293344d..47f1542b 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -1,6 +1,6 @@ -import duckdb import pytest -import os + +import duckdb pl = pytest.importorskip("polars") @@ -8,24 +8,24 @@ def polars_supports_capsule(): from packaging.version import Version - return Version(pl.__version__) >= Version('1.4.1') 
+ return Version(pl.__version__) >= Version("1.4.1") @pytest.mark.skipif( - not polars_supports_capsule(), reason='Polars version does not support the Arrow PyCapsule interface' + not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface" ) -class TestArrowPyCapsule(object): +class TestArrowPyCapsule: def test_polars_pycapsule_scan(self, duckdb_cursor): class MyObject: - def __init__(self, obj): + def __init__(self, obj) -> None: self.obj = obj self.count = 0 - def __arrow_c_stream__(self, requested_schema=None): + def __arrow_c_stream__(self, requested_schema=None) -> object: self.count += 1 return self.obj.__arrow_c_stream__(requested_schema=requested_schema) - df = pl.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) obj = MyObject(df) # Call the __arrow_c_stream__ from within DuckDB @@ -40,7 +40,7 @@ def __arrow_c_stream__(self, requested_schema=None): assert obj.count == 2 # Ensure __arrow_c_stream__ accepts a requested_schema argument as noop - capsule = obj.__arrow_c_stream__(requested_schema="foo") + capsule = obj.__arrow_c_stream__(requested_schema="foo") # noqa: F841 res = duckdb_cursor.sql("select * from capsule") assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)] assert obj.count == 3 @@ -53,7 +53,7 @@ def create_capsule(): capsule = rel.__arrow_c_stream__() return capsule - capsule = create_capsule() + capsule = create_capsule() # noqa: F841 rel2 = duckdb_cursor.sql("select * from capsule") assert rel2.fetchall() == [(i, i + 1, -i) for i in range(100)] @@ -61,9 +61,9 @@ def test_automatic_reexecution(self, duckdb_cursor): other_con = duckdb_cursor.cursor() rel = duckdb_cursor.sql("select i, i+1, -i from range(100) t(i)") - capsule_one = rel.__arrow_c_stream__() + capsule_one = rel.__arrow_c_stream__() # noqa: F841 res1 = other_con.sql("select * from capsule_one").fetchall() - capsule_two = rel.__arrow_c_stream__() + capsule_two = rel.__arrow_c_stream__() # noqa: F841 res2 = other_con.sql("select * from capsule_two").fetchall() assert len(res1) == 100 assert res1 == res2 @@ -71,17 +71,17 @@ def test_automatic_reexecution(self, duckdb_cursor): def test_consumer_interface_roundtrip(self, duckdb_cursor): def create_table(): class MyTable: - def __init__(self, rel, conn): + def __init__(self, rel, conn) -> None: self.rel = rel self.conn = conn - def __arrow_c_stream__(self, requested_schema=None): + def __arrow_c_stream__(self, requested_schema=None) -> object: return self.rel.__arrow_c_stream__(requested_schema=requested_schema) conn = duckdb.connect() rel = conn.sql("select i, i+1, -i from range(100) t(i)") return MyTable(rel, conn) - tbl = create_table() + tbl = create_table() # noqa: F841 rel2 = duckdb_cursor.sql("select * from tbl") assert rel2.fetchall() == [(i, i + 1, -i) for i in range(100)] diff --git a/tests/fast/arrow/test_arrow_recordbatchreader.py b/tests/fast/arrow/test_arrow_recordbatchreader.py index 0f8a701d..0b7852b3 100644 --- a/tests/fast/arrow/test_arrow_recordbatchreader.py +++ b/tests/fast/arrow/test_arrow_recordbatchreader.py @@ -1,20 +1,21 @@ -import duckdb -import os +from pathlib import Path + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") pyarrow.parquet = pytest.importorskip("pyarrow.parquet") pyarrow.dataset = pytest.importorskip("pyarrow.dataset") np = pytest.importorskip("numpy") -class TestArrowRecordBatchReader(object): +class TestArrowRecordBatchReader: def test_parallel_reader(self, duckdb_cursor): - duckdb_conn = 
duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -25,25 +26,22 @@ def test_parallel_reader(self, duckdb_cursor): format="parquet", ) - batches = [r for r in userdata_parquet_dataset.to_batches()] + batches = list(userdata_parquet_dataset.to_batches()) reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() rel = duckdb_conn.from_arrow(reader) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) # The reader is already consumed so this should be 0 - assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 - ) + assert rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 0 def test_parallel_reader_replacement_scans(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -54,28 +52,27 @@ def test_parallel_reader_replacement_scans(self, duckdb_cursor): format="parquet", ) - batches = [r for r in userdata_parquet_dataset.to_batches()] - reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() + batches = list(userdata_parquet_dataset.to_batches()) + reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() # noqa: F841 assert ( duckdb_conn.execute( - "select count(*) r1 from reader where first_name=\'Jose\' and salary > 134708.82" + "select count(*) r1 from reader where first_name='Jose' and salary > 134708.82" ).fetchone()[0] == 12 ) assert ( duckdb_conn.execute( - "select count(*) r2 from reader where first_name=\'Jose\' and salary > 134708.82" + "select count(*) r2 from reader where first_name='Jose' and salary > 134708.82" ).fetchone()[0] == 0 ) def test_parallel_reader_register(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -86,27 +83,22 @@ def test_parallel_reader_register(self, duckdb_cursor): format="parquet", ) - batches = [r for r in userdata_parquet_dataset.to_batches()] + batches = list(userdata_parquet_dataset.to_batches()) reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() duckdb_conn.register("bla", reader) assert ( - duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ - 0 - ] + duckdb_conn.execute("select count(*) from bla where first_name='Jose' and salary > 134708.82").fetchone()[0] == 12 ) assert ( - duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ - 0 - ] + 
duckdb_conn.execute("select count(*) from bla where first_name='Jose' and salary > 134708.82").fetchone()[0] == 0 ) def test_parallel_reader_default_conn(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -117,15 +109,13 @@ def test_parallel_reader_default_conn(self, duckdb_cursor): format="parquet", ) - batches = [r for r in userdata_parquet_dataset.to_batches()] + batches = list(userdata_parquet_dataset.to_batches()) reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() rel = duckdb.from_arrow(reader) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) # The reader is already consumed so this should be 0 - assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 - ) + assert rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 0 diff --git a/tests/fast/arrow/test_arrow_replacement_scan.py b/tests/fast/arrow/test_arrow_replacement_scan.py index a02bac10..8c372a22 100644 --- a/tests/fast/arrow/test_arrow_replacement_scan.py +++ b/tests/fast/arrow/test_arrow_replacement_scan.py @@ -1,32 +1,32 @@ -import duckdb +from pathlib import Path + import pytest -import os -import pandas as pd + +import duckdb pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") ds = pytest.importorskip("pyarrow.dataset") -class TestArrowReplacementScan(object): +class TestArrowReplacementScan: def test_arrow_table_replacement_scan(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) - df = userdata_parquet_table.to_pandas() + df = userdata_parquet_table.to_pandas() # noqa: F841 con = duckdb.connect() - for i in range(5): + for _i in range(5): assert con.execute("select count(*) from userdata_parquet_table").fetchone() == (1000,) assert con.execute("select count(*) from df").fetchone() == (1000,) @pytest.mark.skipif( - not hasattr(pa.Table, '__arrow_c_stream__'), - reason='This version of pyarrow does not support the Arrow Capsule Interface', + not hasattr(pa.Table, "__arrow_c_stream__"), + reason="This version of pyarrow does not support the Arrow Capsule Interface", ) def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): - tbl = pa.Table.from_pydict({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9]}) + tbl = pa.Table.from_pydict({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9]}) capsule = tbl.__arrow_c_stream__() rel = duckdb_cursor.sql("select * from capsule") @@ -36,38 +36,37 @@ def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from capsule where a > 3 and a < 5") assert rel.fetchall() == [(4,)] - tbl = pa.Table.from_pydict({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9], 'd': [10, 11, 12]}) - capsule = tbl.__arrow_c_stream__() + tbl = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [10, 11, 12]}) + capsule = tbl.__arrow_c_stream__() # noqa: F841 rel = 
duckdb_cursor.sql("select b, d from capsule") assert rel.fetchall() == [(i, i + 6) for i in range(4, 7)] - with pytest.raises(duckdb.InvalidInputException, match='The ArrowArrayStream was already released'): - rel = duckdb_cursor.sql("select b, d from capsule") + with pytest.raises(duckdb.InvalidInputException, match="The ArrowArrayStream was already released"): + duckdb_cursor.sql("select b, d from capsule") schema_obj = tbl.schema - schema_capsule = schema_obj.__arrow_c_schema__() + schema_capsule = schema_obj.__arrow_c_schema__() # noqa: F841 with pytest.raises( duckdb.InvalidInputException, match="""Expected a 'arrow_array_stream' PyCapsule, got: arrow_schema""" ): - rel = duckdb_cursor.sql("select b, d from schema_capsule") + duckdb_cursor.sql("select b, d from schema_capsule") def test_arrow_table_replacement_scan_view(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) con = duckdb.connect() con.execute("create view x as select * from userdata_parquet_table") del userdata_parquet_table - with pytest.raises(duckdb.CatalogException, match='Table with name userdata_parquet_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name userdata_parquet_table does not exist"): assert con.execute("select count(*) from x").fetchone() def test_arrow_dataset_replacement_scan(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - userdata_parquet_table = pq.read_table(parquet_filename) - userdata_parquet_dataset = ds.dataset(parquet_filename) + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") + pq.read_table(parquet_filename) + userdata_parquet_dataset = ds.dataset(parquet_filename) # noqa: F841 con = duckdb.connect() assert con.execute("select count(*) from userdata_parquet_dataset").fetchone() == (1000,) diff --git a/tests/fast/arrow/test_arrow_run_end_encoding.py b/tests/fast/arrow/test_arrow_run_end_encoding.py index 6315d1b7..e04f9ea0 100644 --- a/tests/fast/arrow/test_arrow_run_end_encoding.py +++ b/tests/fast/arrow/test_arrow_run_end_encoding.py @@ -1,9 +1,6 @@ -import duckdb import pytest -import pandas as pd -import duckdb -pa = pytest.importorskip("pyarrow", '21.0.0', reason="Needs pyarrow >= 21") +pa = pytest.importorskip("pyarrow", "21.0.0", reason="Needs pyarrow >= 21") pc = pytest.importorskip("pyarrow.compute") @@ -25,14 +22,14 @@ def create_list(offsets, values): def list_constructors(): result = [] result.append(create_list) - if hasattr(pa, 'ListViewArray'): + if hasattr(pa, "ListViewArray"): result.append(create_list_view) return result -class TestArrowREE(object): +class TestArrowREE: @pytest.mark.parametrize( - 'query', + "query", [ """ select @@ -46,57 +43,55 @@ class TestArrowREE(object): """, ], ) - @pytest.mark.parametrize('run_length', [4, 1, 10, 1000, 2048, 3000]) - @pytest.mark.parametrize('size', [100, 10000]) + @pytest.mark.parametrize("run_length", [4, 1, 10, 1000, 2048, 3000]) + @pytest.mark.parametrize("size", [100, 10000]) @pytest.mark.parametrize( - 'value_type', - ['UTINYINT', 'USMALLINT', 'UINTEGER', 'UBIGINT', 'TINYINT', 'SMALLINT', 'INTEGER', 'BIGINT', 'HUGEINT'], + "value_type", + ["UTINYINT", "USMALLINT", "UINTEGER", "UBIGINT", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "HUGEINT"], ) def 
test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, size, value_type): - if value_type == 'UTINYINT': - if size > 255: - size = 255 - if value_type == 'TINYINT': - if size > 127: - size = 127 + if value_type == "UTINYINT" and size > 255: + size = 255 + if value_type == "TINYINT" and size > 127: + size = 127 query = query.format(run_length, value_type, size) rel = duckdb_cursor.sql(query) - array = rel.fetch_arrow_table()['ree'] + array = rel.fetch_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) schema = pa.schema([("ree", encoded_array.type)]) - tbl = pa.Table.from_arrays([encoded_array], schema=schema) + tbl = pa.Table.from_arrays([encoded_array], schema=schema) # noqa: F841 res = duckdb_cursor.sql("select * from tbl").fetchall() assert res == expected @pytest.mark.parametrize( - ['dbtype', 'val1', 'val2'], + ("dbtype", "val1", "val2"), [ - ('TINYINT', '(-128)', '127'), - ('SMALLINT', '(-32768)', '32767'), - ('INTEGER', '(-2147483648)', '2147483647'), - ('BIGINT', '(-9223372036854775808)', '9223372036854775807'), - ('UTINYINT', '0', '255'), - ('USMALLINT', '0', '65535'), - ('UINTEGER', '0', '4294967295'), - ('UBIGINT', '0', '18446744073709551615'), - ('BOOL', 'true', 'false'), - ('VARCHAR', "'test'", "'this is a long string'"), - ('BLOB', "'\\xE0\\x9F\\x98\\x84'", "'\\xF0\\x9F\\xA6\\x86'"), - ('DATE', "'1992-03-27'", "'2204-11-01'"), - ('TIME', "'01:02:03'", "'23:41:35'"), - ('TIMESTAMP_S', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP_MS', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP_NS', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('DECIMAL(4,2)', "'12.23'", "'99.99'"), - ('DECIMAL(7,6)', "'1.234234'", "'0.000001'"), - ('DECIMAL(14,7)', "'134523.234234'", "'999999.000001'"), - ('DECIMAL(28,1)', "'12345678910111234123456789.1'", "'999999999999999999999999999.9'"), - ('UUID', "'10acd298-15d7-417c-8b59-eabb5a2bacab'", "'eeccb8c5-9943-b2bb-bb5e-222f4e14b687'"), - ('BIT', "'01010101010000'", "'01010100010101010101010101111111111'"), + ("TINYINT", "(-128)", "127"), + ("SMALLINT", "(-32768)", "32767"), + ("INTEGER", "(-2147483648)", "2147483647"), + ("BIGINT", "(-9223372036854775808)", "9223372036854775807"), + ("UTINYINT", "0", "255"), + ("USMALLINT", "0", "65535"), + ("UINTEGER", "0", "4294967295"), + ("UBIGINT", "0", "18446744073709551615"), + ("BOOL", "true", "false"), + ("VARCHAR", "'test'", "'this is a long string'"), + ("BLOB", "'\\xE0\\x9F\\x98\\x84'", "'\\xF0\\x9F\\xA6\\x86'"), + ("DATE", "'1992-03-27'", "'2204-11-01'"), + ("TIME", "'01:02:03'", "'23:41:35'"), + ("TIMESTAMP_S", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP_MS", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP_NS", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("DECIMAL(4,2)", "'12.23'", "'99.99'"), + ("DECIMAL(7,6)", "'1.234234'", "'0.000001'"), + ("DECIMAL(14,7)", "'134523.234234'", "'999999.000001'"), + ("DECIMAL(28,1)", "'12345678910111234123456789.1'", "'999999999999999999999999999.9'"), + ("UUID", "'10acd298-15d7-417c-8b59-eabb5a2bacab'", "'eeccb8c5-9943-b2bb-bb5e-222f4e14b687'"), + ("BIT", "'01010101010000'", "'01010100010101010101010101111111111'"), ], ) @pytest.mark.parametrize( @@ -107,7 +102,7 @@ def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, 
run_length, ], ) def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter): - if dbtype in ['BIT', 'UUID']: + if dbtype in ["BIT", "UUID"]: pytest.skip("BIT and UUID are currently broken (FIXME)") projection = "a, b, ree" query = """ @@ -130,49 +125,49 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) duckdb_cursor.execute(query) rel = duckdb_cursor.query("select * from ree_tbl") - expected = duckdb_cursor.query("select {} from ree_tbl where {}".format(projection, filter)).fetchall() + expected = duckdb_cursor.query(f"select {projection} from ree_tbl where {filter}").fetchall() # Create an Arrow Table from the table arrow_conversion = rel.fetch_arrow_table() arrays = { - 'ree': arrow_conversion['ree'], - 'a': arrow_conversion['a'], - 'b': arrow_conversion['b'], + "ree": arrow_conversion["ree"], + "a": arrow_conversion["a"], + "b": arrow_conversion["b"], } encoded_arrays = { - 'ree': pc.run_end_encode(arrays['ree']), - 'a': pc.run_end_encode(arrays['a']), - 'b': pc.run_end_encode(arrays['b']), + "ree": pc.run_end_encode(arrays["ree"]), + "a": pc.run_end_encode(arrays["a"]), + "b": pc.run_end_encode(arrays["b"]), } schema = pa.schema( [ - ("ree", encoded_arrays['ree'].type), - ("a", encoded_arrays['a'].type), - ("b", encoded_arrays['b'].type), + ("ree", encoded_arrays["ree"].type), + ("a", encoded_arrays["a"].type), + ("b", encoded_arrays["b"].type), ] ) - tbl = pa.Table.from_arrays([encoded_arrays['ree'], encoded_arrays['a'], encoded_arrays['b']], schema=schema) + tbl = pa.Table.from_arrays([encoded_arrays["ree"], encoded_arrays["a"], encoded_arrays["b"]], schema=schema) # noqa: F841 # Scan the Arrow Table and verify that the results are the same - res = duckdb_cursor.sql("select {} from tbl where {}".format(projection, filter)).fetchall() + res = duckdb_cursor.sql(f"select {projection} from tbl where {filter}").fetchall() assert res == expected def test_arrow_ree_empty_table(self, duckdb_cursor): duckdb_cursor.query("create table tbl (ree integer)") - rel = duckdb_cursor.table('tbl') - array = rel.fetch_arrow_table()['ree'] + rel = duckdb_cursor.table("tbl") + array = rel.fetch_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) schema = pa.schema([("ree", encoded_array.type)]) - pa_res = pa.Table.from_arrays([encoded_array], schema=schema) + pa_res = pa.Table.from_arrays([encoded_array], schema=schema) # noqa: F841 res = duckdb_cursor.sql("select * from pa_res").fetchall() assert res == expected - @pytest.mark.parametrize('projection', ['*', 'a, c, b', 'ree, a, b, c', 'c, b, a, ree', 'c', 'b, ree, c, a']) + @pytest.mark.parametrize("projection", ["*", "a, c, b", "ree, a, b, c", "c, b, a, ree", "c", "b, ree, c, a"]) def test_arrow_ree_projections(self, duckdb_cursor, projection): # Create the schema duckdb_cursor.query( @@ -199,58 +194,54 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): ) # Fetch the result as an Arrow Table - result = duckdb_cursor.table('tbl').fetch_arrow_table() + result = duckdb_cursor.table("tbl").fetch_arrow_table() # Turn 'ree' into a run-end-encoded array and reconstruct a table from it arrays = { - 'ree': pc.run_end_encode(result['ree']), - 'a': result['a'], - 'b': result['b'], - 'c': result['c'], + "ree": pc.run_end_encode(result["ree"]), + "a": result["a"], + "b": result["b"], + "c": result["c"], } schema = pa.schema( [ - ("ree", arrays['ree'].type), - ("a", arrays['a'].type), - ("b", arrays['b'].type), - ("c", arrays['c'].type), + ("ree", 
arrays["ree"].type), + ("a", arrays["a"].type), + ("b", arrays["b"].type), + ("c", arrays["c"].type), ] ) - arrow_tbl = pa.Table.from_arrays([arrays['ree'], arrays['a'], arrays['b'], arrays['c']], schema=schema) + arrow_tbl = pa.Table.from_arrays([arrays["ree"], arrays["a"], arrays["b"], arrays["c"]], schema=schema) # Verify that the array is run end encoded - ar_type = arrow_tbl['ree'].type - assert pa.types.is_run_end_encoded(ar_type) == True + ar_type = arrow_tbl["ree"].type + assert pa.types.is_run_end_encoded(ar_type) # Scan the arrow table, making projections that don't cover the entire table # This should be pushed down into arrow to only provide us with the necessary columns - res = duckdb_cursor.query( - """ - select {} from arrow_tbl - """.format( - projection - ) + res = duckdb_cursor.query( # noqa: F841 + f""" + select {projection} from arrow_tbl + """ ).fetch_arrow_table() # Verify correctness by fetching from the original table and the constructed result - expected = duckdb_cursor.query("select {} from tbl".format(projection)).fetchall() - actual = duckdb_cursor.query("select {} from res".format(projection)).fetchall() + expected = duckdb_cursor.query(f"select {projection} from tbl").fetchall() + actual = duckdb_cursor.query(f"select {projection} from res").fetchall() assert expected == actual - @pytest.mark.parametrize('create_list', list_constructors()) + @pytest.mark.parametrize("create_list", list_constructors()) def test_arrow_ree_list(self, duckdb_cursor, create_list): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, - FROM range({}) t(i) - """.format( - size - ) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -281,7 +272,7 @@ def test_arrow_ree_list(self, duckdb_cursor, create_list): structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() assert arrow_tbl.to_pylist() == result.to_pylist() @@ -314,11 +305,11 @@ def test_arrow_ree_struct(self, duckdb_cursor): iterables = [x.iterchunks() for x in columns] zipped = zip(*iterables) - structured_chunks = [pa.StructArray.from_arrays([y for y in x], names=names) for x in zipped] + structured_chunks = [pa.StructArray.from_arrays(list(x), names=names) for x in zipped] structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) - result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) # noqa: F841 + result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # noqa: F841 expected = duckdb_cursor.query("select {'ree': ree, 'a': a, 'b': b, 'c': c} as s from tbl").fetchall() actual = duckdb_cursor.query("select * from result").fetchall() @@ -329,17 +320,15 @@ def test_arrow_ree_union(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, i as a, i % 2 == 0 as b, i::VARCHAR as c - FROM range({}) t(i) - """.format( - size - ) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -368,8 +357,8 @@ def test_arrow_ree_union(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) - result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() + 
arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) # noqa: F841 + result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # noqa: F841 # Recreate the same result set expected = [] @@ -389,15 +378,13 @@ def test_arrow_ree_map(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, i as a, - FROM range({}) t(i) - """.format( - size - ) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -412,7 +399,6 @@ def test_arrow_ree_map(self, duckdb_cursor): columns[0] = pc.run_end_encode(columns[0]) # Create a (chunked) MapArray from the chunked arrays (columns) of the ArrowTable - names = unstructured.column_names iterables = [x.iterchunks() for x in columns] zipped = zip(*iterables) @@ -431,7 +417,7 @@ def test_arrow_ree_map(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Verify that the resulting scan is the same as the input @@ -441,14 +427,12 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, - FROM range({}) t(i) - """.format( - size - ) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -467,13 +451,13 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): for chunk in columns[0].iterchunks(): ree = chunk chunk_length = len(ree) - offsets = [i for i in reversed(range(chunk_length))] + offsets = list(reversed(range(chunk_length))) new_array = pa.DictionaryArray.from_arrays(indices=offsets, dictionary=ree) structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Verify that the resulting scan is the same as the input diff --git a/tests/fast/arrow/test_arrow_scanner.py b/tests/fast/arrow/test_arrow_scanner.py index 6d74ddb5..918acf65 100644 --- a/tests/fast/arrow/test_arrow_scanner.py +++ b/tests/fast/arrow/test_arrow_scanner.py @@ -1,20 +1,20 @@ +from pathlib import Path + import duckdb -import os try: import pyarrow - import pyarrow.parquet + import pyarrow.compute as pc import pyarrow.dataset + import pyarrow.parquet from pyarrow.dataset import Scanner - import pyarrow.compute as pc - import numpy as np can_run = True -except: +except Exception: can_run = False -class TestArrowScanner(object): +class TestArrowScanner: def test_parallel_scanner(self, duckdb_cursor): if not can_run: return @@ -22,7 +22,7 @@ def test_parallel_scanner(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -33,13 +33,13 @@ def test_parallel_scanner(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = 
Scanner.from_dataset(arrow_dataset, filter=scanner_filter) rel = duckdb_conn.from_arrow(arrow_scanner) - assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 + assert rel.aggregate("count(*)").execute().fetchone()[0] == 12 def test_parallel_scanner_replacement_scans(self, duckdb_cursor): if not can_run: @@ -48,7 +48,7 @@ def test_parallel_scanner_replacement_scans(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -59,9 +59,9 @@ def test_parallel_scanner_replacement_scans(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) - arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) + arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) # noqa: F841 assert duckdb_conn.execute("select count(*) from arrow_scanner").fetchone()[0] == 12 @@ -72,7 +72,7 @@ def test_parallel_scanner_register(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -83,7 +83,7 @@ def test_parallel_scanner_register(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) @@ -95,7 +95,7 @@ def test_parallel_scanner_default_conn(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -106,10 +106,10 @@ def test_parallel_scanner_default_conn(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) rel = duckdb.from_arrow(arrow_scanner) - assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 + assert rel.aggregate("count(*)").execute().fetchone()[0] == 12 diff --git a/tests/fast/arrow/test_arrow_string_view.py b/tests/fast/arrow/test_arrow_string_view.py index fc4bbd40..9ed9bece 100644 --- a/tests/fast/arrow/test_arrow_string_view.py +++ b/tests/fast/arrow/test_arrow_string_view.py @@ -1,11 +1,11 @@ -import duckdb import pytest -from packaging import version -pa = pytest.importorskip('pyarrow') +import duckdb + +pa = pytest.importorskip("pyarrow") pytestmark = pytest.mark.skipif( - not hasattr(pa, 'string_view'), reason="This version of PyArrow does not support StringViews" + not hasattr(pa, "string_view"), reason="This version of PyArrow does not 
support StringViews" ) @@ -20,14 +20,14 @@ def RoundTripStringView(query, array): # Generate an arrow table # Create a field for the array with a specific data type - field = pa.field('str_val', pa.string_view()) + field = pa.field("str_val", pa.string_view()) # Create a schema for the table using the field schema = pa.schema([field]) # Create a table using the schema and the array - gt_table = pa.Table.from_arrays([array], schema=schema) - arrow_table = con.execute("select * from gt_table").fetch_arrow_table() + gt_table = pa.Table.from_arrays([array], schema=schema) # noqa: F841 + arrow_table = con.execute("select * from gt_table").fetch_arrow_table() # noqa: F841 assert arrow_tbl[0].combine_chunks().tolist() == array.tolist() @@ -43,7 +43,7 @@ def RoundTripDuckDBInternal(query): assert res[i] == from_arrow_res[i] -class TestArrowStringView(object): +class TestArrowStringView: # Test Small Inlined String View def test_inlined_string_view(self): RoundTripStringView( @@ -77,7 +77,7 @@ def test_not_inlined_string_view(self): # Test Small Not-Inlined Strings with Null def test_not_inlined_string_view_with_null(self): RoundTripStringView( - "SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(5) tbl(i) UNION SELECT NULL order by str", + "SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(5) tbl(i) UNION SELECT NULL order by str", # noqa: E501 pa.array( [ "Imaverybigstringmuchbiggerthanfourbytes0", @@ -92,7 +92,7 @@ def test_not_inlined_string_view_with_null(self): ) # Test Mix of Inlined and Not-Inlined Strings with Null - def test_not_inlined_string_view(self): + def test_not_inlined_string_view_2(self): RoundTripStringView( "SELECT '8bytestr'||(i*10^i)::varchar str FROM range(5) tbl(i) UNION SELECT NULL order by str", pa.array( @@ -103,26 +103,26 @@ def test_not_inlined_string_view(self): # Test Over-Vector Size def test_large_string_view_inlined(self): - RoundTripDuckDBInternal('''select * from (SELECT i::varchar str FROM range(10000) tbl(i)) order by str''') + RoundTripDuckDBInternal("""select * from (SELECT i::varchar str FROM range(10000) tbl(i)) order by str""") def test_large_string_view_inlined_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_large_string_view_not_inlined(self): RoundTripDuckDBInternal( - '''select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" # noqa: E501 ) def test_large_string_view_not_inlined_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" # noqa: E501 ) def test_large_string_view_mixed_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 
'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" # noqa: E501 ) def test_multiple_data_buffers(self): @@ -143,7 +143,7 @@ def test_large_string_polars(self): # pl = pytest.importorskip('polars') # con = duckdb.connect() # con.execute("SET produce_arrow_string_view=True") - # query = '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + # query = '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' # noqa: E501 # polars_df = con.execute(query).pl() # result = con.execute(query).fetchall() # con.register('polars_df', polars_df) diff --git a/tests/fast/arrow/test_arrow_types.py b/tests/fast/arrow/test_arrow_types.py index 97f747ef..be03009c 100644 --- a/tests/fast/arrow/test_arrow_types.py +++ b/tests/fast/arrow/test_arrow_types.py @@ -1,11 +1,12 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") -class TestArrowTypes(object): +class TestArrowTypes: def test_null_type(self, duckdb_cursor): schema = pa.schema([("data", pa.null())]) inputs = [pa.array([None, None, None], type=pa.null())] @@ -17,17 +18,17 @@ def test_null_type(self, duckdb_cursor): inputs = [pa.array([None, None, None], type=pa.null())] arrow_table = pa.Table.from_arrays(inputs, schema=schema) - assert rel['data'] == arrow_table['data'] + assert rel["data"] == arrow_table["data"] def test_invalid_struct(self, duckdb_cursor): empty_struct_type = pa.struct([]) # Create an empty array with the defined struct type empty_array = pa.array([], type=empty_struct_type) - arrow_table = pa.Table.from_arrays([empty_array], schema=pa.schema([("data", empty_struct_type)])) + arrow_table = pa.Table.from_arrays([empty_array], schema=pa.schema([("data", empty_struct_type)])) # noqa: F841 with pytest.raises( duckdb.InvalidInputException, - match='Attempted to convert a STRUCT with no fields to DuckDB which is not supported', + match="Attempted to convert a STRUCT with no fields to DuckDB which is not supported", ): duckdb_cursor.sql("select * from arrow_table").fetchall() @@ -39,9 +40,6 @@ def test_invalid_union(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([sparse_union_array], schema=pa.schema([("data", sparse_union_array.type)])) with pytest.raises( duckdb.InvalidInputException, - match='Attempted to convert a UNION with no fields to DuckDB which is not supported', + match="Attempted to convert a UNION with no fields to DuckDB which is not supported", ): - duckdb_cursor.register('invalid_union', arrow_table) - - res = duckdb_cursor.sql("select * from invalid_union").fetchall() - print(res) + duckdb_cursor.register("invalid_union", arrow_table) diff --git a/tests/fast/arrow/test_arrow_union.py b/tests/fast/arrow/test_arrow_union.py index 1d853a1b..784a5433 100644 --- a/tests/fast/arrow/test_arrow_union.py +++ b/tests/fast/arrow/test_arrow_union.py @@ -1,29 +1,26 @@ -from pytest import importorskip +import pytest -importorskip('pyarrow') - -import duckdb -from pyarrow import scalar, string, large_string, list_, int32, types +pyarrow = pytest.importorskip("pyarrow") def test_nested(duckdb_cursor): - res = run(duckdb_cursor, 'select 42::UNION(name VARCHAR, attr UNION(age INT, veteran BOOL)) as res') - assert 
types.is_union(res.type) - assert res.value.value == scalar(42, type=int32()) + res = run(duckdb_cursor, "select 42::UNION(name VARCHAR, attr UNION(age INT, veteran BOOL)) as res") + assert pyarrow.types.is_union(res.type) + assert res.value.value == pyarrow.scalar(42, type=pyarrow.int32()) def test_union_contains_nested_data(duckdb_cursor): - _ = importorskip("pyarrow", minversion="11") + _ = pytest.importorskip("pyarrow", minversion="11") res = run(duckdb_cursor, "select ['hello']::UNION(first_name VARCHAR, middle_names VARCHAR[]) as res") - assert types.is_union(res.type) - assert res.value == scalar(['hello'], type=list_(string())) + assert pyarrow.types.is_union(res.type) + assert res.value == pyarrow.scalar(["hello"], type=pyarrow.list_(pyarrow.string())) def test_unions_inside_lists_structs_maps(duckdb_cursor): res = run(duckdb_cursor, "select [union_value(name := 'Frank')] as res") - assert types.is_list(res.type) - assert types.is_union(res.type.value_type) - assert res[0].value == scalar('Frank', type=string()) + assert pyarrow.types.is_list(res.type) + assert pyarrow.types.is_union(res.type.value_type) + assert res[0].value == pyarrow.scalar("Frank", type=pyarrow.string()) def test_unions_with_struct(duckdb_cursor): @@ -38,13 +35,13 @@ def test_unions_with_struct(duckdb_cursor): """ ) - rel = duckdb_cursor.table('tbl') - arrow = rel.fetch_arrow_table() + rel = duckdb_cursor.table("tbl") + arrow = rel.fetch_arrow_table() # noqa: F841 duckdb_cursor.execute("create table other as select * from arrow") - rel2 = duckdb_cursor.table('other') + rel2 = duckdb_cursor.table("other") res = rel2.fetchall() - assert res == [({'a': 42, 'b': True},)] + assert res == [({"a": 42, "b": True},)] def run(conn, query): diff --git a/tests/fast/arrow/test_arrow_version_format.py b/tests/fast/arrow/test_arrow_version_format.py index ff8699eb..d2864b15 100644 --- a/tests/fast/arrow/test_arrow_version_format.py +++ b/tests/fast/arrow/test_arrow_version_format.py @@ -1,15 +1,17 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimalTypes(object): +class TestArrowDecimalTypes: def test_decimal_v1_5(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute(f"SET arrow_output_version = 1.5") - decimal_32 = pa.Table.from_pylist( + duckdb_cursor.execute("SET arrow_output_version = 1.5") + decimal_32 = pa.Table.from_pylist( # noqa: F841 [ {"data": Decimal("100.20")}, {"data": Decimal("110.21")}, @@ -19,9 +21,10 @@ def test_decimal_v1_5(self, duckdb_cursor): pa.schema([("data", pa.decimal32(5, 2))]), ) col_type = duckdb_cursor.execute("FROM decimal_32").fetch_arrow_table().schema.field("data").type - assert col_type.bit_width == 32 and pa.types.is_decimal(col_type) + assert col_type.bit_width == 32 + assert pa.types.is_decimal(col_type) - decimal_64 = pa.Table.from_pylist( + decimal_64 = pa.Table.from_pylist( # noqa: F841 [ {"data": Decimal("1000.231")}, {"data": Decimal("1100.231")}, @@ -31,31 +34,34 @@ def test_decimal_v1_5(self, duckdb_cursor): pa.schema([("data", pa.decimal64(16, 3))]), ) col_type = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table().schema.field("data").type - assert col_type.bit_width == 64 and pa.types.is_decimal(col_type) - for version in ['1.0', '1.1', '1.2', '1.3', '1.4']: + assert col_type.bit_width == 64 + assert pa.types.is_decimal(col_type) + for version in ["1.0", "1.1", "1.2", "1.3", "1.4"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") 
result = duckdb_cursor.execute("FROM decimal_32").fetch_arrow_table() col_type = result.schema.field("data").type - assert col_type.bit_width == 128 and pa.types.is_decimal(col_type) + assert col_type.bit_width == 128 + assert pa.types.is_decimal(col_type) assert result.to_pydict() == { - 'data': [Decimal('100.20'), Decimal('110.21'), Decimal('31.20'), Decimal('500.20')] + "data": [Decimal("100.20"), Decimal("110.21"), Decimal("31.20"), Decimal("500.20")] } result = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table() col_type = result.schema.field("data").type - assert col_type.bit_width == 128 and pa.types.is_decimal(col_type) + assert col_type.bit_width == 128 + assert pa.types.is_decimal(col_type) assert result.to_pydict() == { - 'data': [Decimal('1000.231'), Decimal('1100.231'), Decimal('999999999999.231'), Decimal('500.200')] + "data": [Decimal("1000.231"), Decimal("1100.231"), Decimal("999999999999.231"), Decimal("500.200")] } def test_invalide_opt(self, duckdb_cursor): duckdb_cursor = duckdb.connect() with pytest.raises(duckdb.NotImplementedException, match="unrecognized"): - duckdb_cursor.execute(f"SET arrow_output_version = 999.9") + duckdb_cursor.execute("SET arrow_output_version = 999.9") def test_view_v1_4(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute(f"SET arrow_output_version = 1.5") + duckdb_cursor.execute("SET arrow_output_version = 1.5") duckdb_cursor.execute("SET produce_arrow_string_view=True") duckdb_cursor.execute("SET arrow_output_list_view=True") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type @@ -63,14 +69,14 @@ def test_view_v1_4(self, duckdb_cursor): col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type assert pa.types.is_list_view(col_type) - for version in ['1.0', '1.1', '1.2', '1.3']: + for version in ["1.0", "1.1", "1.2", "1.3"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_list_view(col_type) - for version in ['1.4', '1.5']: + for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert pa.types.is_string_view(col_type) @@ -80,7 +86,7 @@ def test_view_v1_4(self, duckdb_cursor): duckdb_cursor.execute("SET produce_arrow_string_view=False") duckdb_cursor.execute("SET arrow_output_list_view=False") - for version in ['1.4', '1.5']: + for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) diff --git a/tests/fast/arrow/test_binary_type.py b/tests/fast/arrow/test_binary_type.py index 489d4caf..d549c82b 100644 --- a/tests/fast/arrow/test_binary_type.py +++ b/tests/fast/arrow/test_binary_type.py @@ -1,13 +1,10 @@ import duckdb -import os try: import pyarrow as pa - from pyarrow import parquet as pq - import numpy as np can_run = True -except: +except Exception: can_run = False @@ -17,7 +14,7 @@ def create_binary_table(type): return pa.Table.from_arrays(inputs, schema=schema) -class 
TestArrowBinary(object): +class TestArrowBinary: def test_binary_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_buffer_size_option.py b/tests/fast/arrow/test_buffer_size_option.py index 46047e21..28e3bf58 100644 --- a/tests/fast/arrow/test_buffer_size_option.py +++ b/tests/fast/arrow/test_buffer_size_option.py @@ -1,11 +1,12 @@ -import duckdb import pytest +import duckdb +from duckdb.typing import VARCHAR + pa = pytest.importorskip("pyarrow") -from duckdb.typing import * -class TestArrowBufferSize(object): +class TestArrowBufferSize: def test_arrow_buffer_size(self): con = duckdb.connect() @@ -34,7 +35,7 @@ def just_return(x): return x con = duckdb.connect() - con.create_function('just_return', just_return, [VARCHAR], VARCHAR, type='arrow') + con.create_function("just_return", just_return, [VARCHAR], VARCHAR, type="arrow") res = con.query("select just_return('bla')").fetch_arrow_table() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 2f3d7a53..36e29110 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -1,20 +1,22 @@ -import duckdb -import os +from pathlib import Path + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") np = pytest.importorskip("numpy") pyarrow.parquet = pytest.importorskip("pyarrow.parquet") pyarrow.dataset = pytest.importorskip("pyarrow.dataset") -class TestArrowDataset(object): +class TestArrowDataset: def test_parallel_dataset(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -28,7 +30,7 @@ def test_parallel_dataset(self, duckdb_cursor): rel = duckdb_conn.from_arrow(userdata_parquet_dataset) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) def test_parallel_dataset_register(self, duckdb_cursor): @@ -36,7 +38,7 @@ def test_parallel_dataset_register(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -47,7 +49,7 @@ def test_parallel_dataset_register(self, duckdb_cursor): format="parquet", ) - rel = duckdb_conn.register("dataset", userdata_parquet_dataset) + duckdb_conn.register("dataset", userdata_parquet_dataset) assert ( duckdb_conn.execute( @@ -61,7 +63,7 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -72,16 +74,16 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): format="parquet", ) - rel = duckdb_conn.register("dataset", userdata_parquet_dataset) + 
duckdb_conn.register("dataset", userdata_parquet_dataset) query = duckdb_conn.execute("SELECT * FROM dataset order by id") record_batch_reader = query.fetch_record_batch(2048) - arrow_table = record_batch_reader.read_all() + arrow_table = record_batch_reader.read_all() # noqa: F841 # reorder since order of rows isn't deterministic - df = userdata_parquet_dataset.to_table().to_pandas().sort_values('id').reset_index(drop=True) + df = userdata_parquet_dataset.to_table().to_pandas().sort_values("id").reset_index(drop=True) # turn it into an arrow table - arrow_table_2 = pyarrow.Table.from_pandas(df) + arrow_table_2 = pyarrow.Table.from_pandas(df) # noqa: F841 result_1 = duckdb_conn.execute("select * from arrow_table order by all").fetchall() result_2 = duckdb_conn.execute("select * from arrow_table_2 order by all").fetchall() @@ -90,11 +92,11 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): def test_ducktyping(self, duckdb_cursor): duckdb_conn = duckdb.connect() - dataset = CustomDataset() + dataset = CustomDataset() # noqa: F841 query = duckdb_conn.execute("SELECT b FROM dataset WHERE a < 5") record_batch_reader = query.fetch_record_batch(2048) arrow_table = record_batch_reader.read_all() - assert arrow_table.equals(CustomDataset.DATA[:5].select(['b'])) + assert arrow_table.equals(CustomDataset.DATA[:5].select(["b"])) class CustomDataset(pyarrow.dataset.Dataset): @@ -102,7 +104,7 @@ class CustomDataset(pyarrow.dataset.Dataset): SCHEMA = pyarrow.schema([pyarrow.field("a", pyarrow.int64(), True), pyarrow.field("b", pyarrow.float64(), True)]) DATA = pyarrow.Table.from_arrays([pyarrow.array(range(100)), pyarrow.array(np.arange(100) * 1.0)], schema=SCHEMA) - def __init__(self): + def __init__(self) -> None: pass def scanner(self, **kwargs): @@ -114,7 +116,7 @@ def schema(self): class CustomScanner(pyarrow.dataset.Scanner): - def __init__(self, filter=None, columns=None, **kwargs): + def __init__(self, filter=None, columns=None, **kwargs) -> None: self.filter = filter self.columns = columns self.kwargs = kwargs diff --git a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 316fc689..20cf9f0f 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -1,47 +1,43 @@ import duckdb -import os -import datetime -import pytest try: import pyarrow as pa - import pandas as pd can_run = True -except: +except Exception: can_run = False -class TestArrowDate(object): +class TestArrowDate: def test_date_types(self, duckdb_cursor): if not can_run: return data = (pa.array([1000 * 60 * 60 * 24], type=pa.date64()), pa.array([1], type=pa.date32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['b'] - assert rel['b'] == arrow_table['b'] + assert rel["a"] == arrow_table["b"] + assert rel["b"] == arrow_table["b"] def test_date_null(self, duckdb_cursor): if not can_run: return data = (pa.array([None], type=pa.date64()), pa.array([None], type=pa.date32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['b'] - assert rel['b'] == arrow_table['b'] + assert rel["a"] == arrow_table["b"] + assert rel["b"] == arrow_table["b"] def test_max_date(self, duckdb_cursor): if not can_run: return data = (pa.array([2147483647], 
type=pa.date32()), pa.array([2147483647], type=pa.date32())) - result = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + result = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) data = ( pa.array([2147483647 * (1000 * 60 * 60 * 24)], type=pa.date64()), pa.array([2147483647], type=pa.date32()), ) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] - assert rel['b'] == result['b'] + assert rel["a"] == result["a"] + assert rel["b"] == result["b"] diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index 823d6b05..32c348a3 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -1,4 +1,4 @@ -import duckdb +import datetime import pytest @@ -7,17 +7,16 @@ ds = pytest.importorskip("pyarrow.dataset") np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") -import datetime Timestamp = pd.Timestamp -class TestArrowDictionary(object): +class TestArrowDictionary: def test_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)] @@ -27,14 +26,14 @@ def test_dictionary(self, duckdb_cursor): indices = pa.array(indices_list) dictionary = pa.array([10, 100, None, 999999]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,), (999999,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, pa.array(indices_list)], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, pa.array(indices_list)], ["a", "b"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(10, 0), (100, 1), (10, 0), (100, 1), (None, 2), (100, 1), (10, 0), (None, 2), (999999, 3)] * 10000 assert rel.execute().fetchall() == result @@ -43,7 +42,7 @@ def test_dictionary_null_index(self, duckdb_cursor): indices = pa.array([None, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [(None,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)] @@ -51,7 +50,7 @@ def test_dictionary_null_index(self, duckdb_cursor): indices = pa.array([None, 1, None, 1, 2, 1, 0]) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] @@ -61,19 +60,19 @@ def 
test_dictionary_null_index(self, duckdb_cursor): indices = pa.array(indices_list * 1000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 1000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 1000 assert rel.execute().fetchall() == result @pytest.mark.parametrize( - 'element', + "element", [ # list """ @@ -110,7 +109,7 @@ def test_dictionary_null_index(self, duckdb_cursor): ], ) @pytest.mark.parametrize( - 'count', + "count", [ 1, 10, @@ -123,14 +122,14 @@ def test_dictionary_null_index(self, duckdb_cursor): 5000, ], ) - @pytest.mark.parametrize('query', ["select {} as a from range({})", "select [{} for x in range({})] as a"]) + @pytest.mark.parametrize("query", ["select {} as a from range({})", "select [{} for x in range({})] as a"]) def test_dictionary_roundtrip(self, query, element, duckdb_cursor, count): query = query.format(element, count) original_rel = duckdb_cursor.sql(query) expected = original_rel.fetchall() - arrow_res = original_rel.fetch_arrow_table() + arrow_res = original_rel.fetch_arrow_table() # noqa: F841 - roundtrip_rel = duckdb_cursor.sql('select * from arrow_res') + roundtrip_rel = duckdb_cursor.sql("select * from arrow_res") actual = roundtrip_rel.fetchall() assert expected == actual assert original_rel.columns == roundtrip_rel.columns @@ -142,14 +141,14 @@ def test_dictionary_batches(self, duckdb_cursor): indices = pa.array(indices_list * 10000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000 @@ -157,14 +156,14 @@ def test_dictionary_batches(self, duckdb_cursor): def test_dictionary_lifetime(self, duckdb_cursor): tables = [] - expected = '' + expected = "" for i in range(100): if i % 3 == 0: - input = 'ABCD' * 17000 + input = "ABCD" * 17000 elif i % 3 == 1: - input = 'FOOO' * 17000 + input = "FOOO" * 17000 else: - input = 'BARR' * 17000 + input = "BARR" * 17000 expected += input array = pa.array( input, @@ -173,7 +172,7 @@ def test_dictionary_lifetime(self, duckdb_cursor): tables.append(pa.table([array], names=["x"])) # All of the tables with different dictionaries are getting merged into one dataset # This is testing that our cache is being evicted correctly - x = ds.dataset(tables) + x = 
ds.dataset(tables) # noqa: F841 res = duckdb_cursor.sql("select * from x").fetchall() expected = [(x,) for x in expected] assert res == expected @@ -186,14 +185,14 @@ def test_dictionary_batches_parallel(self, duckdb_cursor): indices = pa.array(indices_list * 10000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000 @@ -214,7 +213,7 @@ def test_dictionary_index_types(self, duckdb_cursor): for index_type in index_types: dict_array = pa.DictionaryArray.from_arrays(index_type, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result @@ -222,17 +221,17 @@ def test_dictionary_index_types(self, duckdb_cursor): def test_dictionary_strings(self, duckdb_cursor): indices_list = [None, 0, 1, 2, 3, 4, None] indices = pa.array(indices_list * 1000) - dictionary = pa.array(['Matt Daaaaaaaaamon', 'Alec Baldwin', 'Sean Penn', 'Tim Robbins', 'Samuel L. Jackson']) + dictionary = pa.array(["Matt Daaaaaaaaamon", "Alec Baldwin", "Sean Penn", "Tim Robbins", "Samuel L. Jackson"]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [ (None,), - ('Matt Daaaaaaaaamon',), - ('Alec Baldwin',), - ('Sean Penn',), - ('Tim Robbins',), - ('Samuel L. Jackson',), + ("Matt Daaaaaaaaamon",), + ("Alec Baldwin",), + ("Sean Penn",), + ("Tim Robbins",), + ("Samuel L. 
Jackson",), (None,), ] * 1000 assert rel.execute().fetchall() == result @@ -249,7 +248,7 @@ def test_dictionary_timestamps(self, duckdb_cursor): ] ) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) print(rel.execute().fetchall()) expected = [ diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index dffa9631..c3f71b65 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -1,12 +1,12 @@ -from re import S -import duckdb -import os +# ruff: noqa: F841 +import sys + import pytest -import tempfile from conftest import pandas_supports_arrow_backend -import sys from packaging.version import Version +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") ds = pytest.importorskip("pyarrow.dataset") @@ -17,7 +17,7 @@ def create_pyarrow_pandas(rel): if not pandas_supports_arrow_backend(): pytest.skip(reason="Pandas version doesn't support 'pyarrow' backend") - return rel.df().convert_dtypes(dtype_backend='pyarrow') + return rel.df().convert_dtypes(dtype_backend="pyarrow") def create_pyarrow_table(rel): @@ -34,7 +34,7 @@ def test_decimal_filter_pushdown(duckdb_cursor): np = pytest.importorskip("numpy") np.random.seed(10) - df = pl.DataFrame({'x': pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) + df = pl.DataFrame({"x": pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) query = """ SELECT @@ -178,35 +178,34 @@ def string_check_or_pushdown(connection, tbl_name, create_table): assert not match -class TestArrowFilterPushdown(object): - +class TestArrowFilterPushdown: @pytest.mark.parametrize( - 'data_type', + "data_type", [ - 'TINYINT', - 'SMALLINT', - 'INTEGER', - 'BIGINT', - 'UTINYINT', - 'USMALLINT', - 'UINTEGER', - 'UBIGINT', - 'FLOAT', - 'DOUBLE', - 'HUGEINT', - 'DECIMAL(4,1)', - 'DECIMAL(9,1)', - 'DECIMAL(18,4)', - 'DECIMAL(30,12)', + "TINYINT", + "SMALLINT", + "INTEGER", + "BIGINT", + "UTINYINT", + "USMALLINT", + "UINTEGER", + "UBIGINT", + "FLOAT", + "DOUBLE", + "HUGEINT", + "DECIMAL(4,1)", + "DECIMAL(9,1)", + "DECIMAL(18,4)", + "DECIMAL(30,12)", ], ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_numeric(self, data_type, duckdb_cursor, create_table): tbl_name = "tbl" numeric_operators(duckdb_cursor, data_type, tbl_name, create_table) numeric_check_or_pushdown(duckdb_cursor, tbl_name, create_table) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_varchar(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -259,7 +258,7 @@ def test_filter_pushdown_varchar(self, duckdb_cursor, create_table): # More complex tests for OR pushed down on string string_check_or_pushdown(duckdb_cursor, "test_varchar", create_table) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_bool(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -294,7 +293,7 @@ def 
test_filter_pushdown_bool(self, duckdb_cursor, create_table): # Try Or assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = True or b = True").fetchone()[0] == 3 - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_time(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -352,7 +351,7 @@ def test_filter_pushdown_time(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -410,7 +409,7 @@ def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): ) assert ( duckdb_cursor.execute( - "SELECT count(*) from arrow_table where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" + "SELECT count(*) from arrow_table where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" # noqa: E501 ).fetchone()[0] == 1 ) @@ -422,7 +421,7 @@ def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -482,7 +481,7 @@ def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): ) assert ( duckdb_cursor.execute( - "SELECT count(*) from arrow_table where a = '2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" + "SELECT count(*) from arrow_table where a = '2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" # noqa: E501 ).fetchone()[0] == 1 ) @@ -494,18 +493,18 @@ def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) @pytest.mark.parametrize( - ['data_type', 'value'], + ("data_type", "value"), [ - ['TINYINT', 127], - ['SMALLINT', 32767], - ['INTEGER', 2147483647], - ['BIGINT', 9223372036854775807], - ['UTINYINT', 255], - ['USMALLINT', 65535], - ['UINTEGER', 4294967295], - ['UBIGINT', 18446744073709551615], + ("TINYINT", 127), + ("SMALLINT", 32767), + ("INTEGER", 2147483647), + ("BIGINT", 9223372036854775807), + ("UTINYINT", 255), + ("USMALLINT", 65535), + ("UINTEGER", 4294967295), + ("UBIGINT", 18446744073709551615), ], ) def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_table): @@ -514,9 +513,9 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ CREATE TABLE tbl as select {value}::{data_type} as i """ ) - expected = duckdb_cursor.table('tbl').fetchall() + expected = duckdb_cursor.table("tbl").fetchall() filter = "i > 0" - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") arrow_table = create_table(rel) actual = duckdb_cursor.sql(f"select * from arrow_table where {filter}").fetchall() assert expected == actual @@ -529,11 +528,10 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ assert expected == actual 
@pytest.mark.skipif( - Version(pa.__version__) < Version('15.0.0'), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" + Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) def test_9371(self, duckdb_cursor, tmp_path): import datetime - import pathlib # connect to an in-memory database duckdb_cursor.execute("SET TimeZone='UTC';") @@ -546,7 +544,7 @@ def test_9371(self, duckdb_cursor, tmp_path): # Example data dt = datetime.datetime(2023, 8, 29, 1, tzinfo=datetime.timezone.utc) - my_arrow_table = pa.Table.from_pydict({'ts': [dt, dt, dt], 'value': [1, 2, 3]}) + my_arrow_table = pa.Table.from_pydict({"ts": [dt, dt, dt], "value": [1, 2, 3]}) df = my_arrow_table.to_pandas() df = df.set_index("ts") # SET INDEX! (It all works correctly when the index is not set) df.to_parquet(str(file_path)) @@ -557,7 +555,7 @@ def test_9371(self, duckdb_cursor, tmp_path): expected = [(1, dt), (2, dt), (3, dt)] assert output == expected - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_date(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -617,15 +615,15 @@ def test_filter_pushdown_date(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_blob(self, duckdb_cursor, create_table): import pandas df = pandas.DataFrame( { - 'a': [bytes([1]), bytes([2]), bytes([3]), None], - 'b': [bytes([1]), bytes([2]), bytes([3]), None], - 'c': [bytes([1]), bytes([2]), bytes([3]), None], + "a": [bytes([1]), bytes([2]), bytes([3]), None], + "b": [bytes([1]), bytes([2]), bytes([3]), None], + "c": [bytes([1]), bytes([2]), bytes([3]), None], } ) rel = duckdb.from_df(df) @@ -660,7 +658,7 @@ def test_filter_pushdown_blob(self, duckdb_cursor, create_table): duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '\x01' or b = '\x02'").fetchone()[0] == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table, create_pyarrow_dataset]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table, create_pyarrow_dataset]) def test_filter_pushdown_no_projection(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -685,7 +683,7 @@ def test_filter_pushdown_no_projection(self, duckdb_cursor, create_table): assert duckdb_cursor.execute("SELECT * FROM arrow_table VALUES where a = 1").fetchall() == [(1, 1, 1)] - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): import pandas @@ -697,12 +695,12 @@ def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): df2 = pandas.DataFrame(np.random.randn(date2.shape[0], 5), columns=list("ABCDE")) df2["date"] = date2 - data1 = tmp_path / 'data1.parquet' - data2 = tmp_path / 'data2.parquet' + data1 = tmp_path / "data1.parquet" + data2 = tmp_path / "data2.parquet" duckdb_cursor.execute(f"copy (select * from df1) to '{data1.as_posix()}'") duckdb_cursor.execute(f"copy (select * from df2) to '{data2.as_posix()}'") - glob_pattern = tmp_path / 'data*.parquet' + glob_pattern = tmp_path / 
"data*.parquet" table = duckdb_cursor.read_parquet(glob_pattern.as_posix()).fetch_arrow_table() output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df() @@ -710,7 +708,7 @@ def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): pandas.testing.assert_frame_equal(expected_df, output_df) # https://github.com/duckdb/duckdb/pull/4817/files#r1339973721 - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_column_removal(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -738,7 +736,7 @@ def test_filter_column_removal(self, duckdb_cursor, create_table): assert not match @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -748,10 +746,10 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ INSERT INTO test_structs VALUES - ({'a': 1, 'b': true}), - ({'a': 2, 'b': false}), + ({'a': 1, 'b': true}), + ({'a': 2, 'b': false}), (NULL), - ({'a': 3, 'b': true}), + ({'a': 3, 'b': true}), ({'a': NULL, 'b': NULL}); """ ) @@ -768,7 +766,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): ).fetchall() input = query_res[0][1] - if 'PANDAS_SCAN' in input: + if "PANDAS_SCAN" in input: pytest.skip(reason="This version of pandas does not produce an Arrow object") match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a<2.*", input, flags=re.DOTALL) assert match @@ -778,7 +776,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): query_res = duckdb_cursor.execute( """ - EXPLAIN SELECT * FROM arrow_table WHERE s.a < 3 AND s.b = true + EXPLAIN SELECT * FROM arrow_table WHERE s.a < 3 AND s.b = true """ ).fetchall() @@ -809,7 +807,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): assert not match @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -838,20 +836,20 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): ).fetchall() input = query_res[0][1] - if 'PANDAS_SCAN' in input: + if "PANDAS_SCAN" in input: pytest.skip(reason="This version of pandas does not produce an Arrow object") match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a\.b<2.*", input, flags=re.DOTALL) assert match # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.b < 2").fetchone()[0] == { - 'a': {'b': 1, 'c': False}, - 'd': {'e': 2, 'f': 'foo'}, + "a": {"b": 1, "c": False}, + "d": {"e": 2, "f": "foo"}, } query_res = duckdb_cursor.execute( """ - EXPLAIN SELECT * FROM arrow_table WHERE s.a.c=true AND s.d.e=5 + EXPLAIN SELECT * FROM arrow_table WHERE s.a.c=true AND s.d.e=5 """ ).fetchall() @@ -866,8 +864,8 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT COUNT(*) FROM arrow_table 
WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == 1 assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == { - 'a': {'b': None, 'c': True}, - 'd': {'e': 5, 'f': 'qux'}, + "a": {"b": None, "c": True}, + "d": {"e": 5, "f": "qux"}, } query_res = duckdb_cursor.execute( @@ -887,8 +885,8 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.d.f = 'bar'").fetchone()[0] == { - 'a': {'b': 3, 'c': True}, - 'd': {'e': 4, 'f': 'bar'}, + "a": {"b": 3, "c": True}, + "d": {"e": 4, "f": "bar"}, } def test_filter_pushdown_not_supported(self): @@ -899,32 +897,32 @@ def test_filter_pushdown_not_supported(self): arrow_tbl = con.execute("FROM T").fetch_arrow_table() # No projection just unsupported filter - assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, "3", 3, 3)] # No projection unsupported + supported filter - assert con.execute("from arrow_tbl where c < 4 and a > 2").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where c < 4 and a > 2").fetchall() == [(3, "3", 3, 3)] # No projection supported + unsupported + supported filter - assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, "3", 3, 3)] assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '0' ").fetchall() == [] # Projection with unsupported filter column + unsupported + supported filter - assert con.execute("select c, b from arrow_tbl where c < 4 and b == '3' and a > 2 ").fetchall() == [(3, '3')] - assert con.execute("select c, b from arrow_tbl where a > 2 and c < 4 and b == '3'").fetchall() == [(3, '3')] + assert con.execute("select c, b from arrow_tbl where c < 4 and b == '3' and a > 2 ").fetchall() == [(3, "3")] + assert con.execute("select c, b from arrow_tbl where a > 2 and c < 4 and b == '3'").fetchall() == [(3, "3")] # Projection without unsupported filter column + unsupported + supported filter - assert con.execute("select a, b from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, '3')] + assert con.execute("select a, b from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, "3")] # Lets also experiment with multiple unpush-able filters con.execute( - "CREATE TABLE T_2 as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d , i::uhugeint e, i::smallint f, i::uhugeint g FROM range(50) tbl(i)" + "CREATE TABLE T_2 as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d , i::uhugeint e, i::smallint f, i::uhugeint g FROM range(50) tbl(i)" # noqa: E501 ) arrow_tbl = con.execute("FROM T_2").fetch_arrow_table() assert con.execute( "select a, b from arrow_tbl where a > 2 and c < 40 and b == '28' and g > 15 and e < 30" - ).fetchall() == [(28, '28')] + ).fetchall() == [(28, "28")] def test_join_filter_pushdown(self, duckdb_cursor): duckdb_conn = duckdb.connect() @@ -951,18 +949,18 @@ def test_in_filter_pushdown(self, duckdb_cursor): def test_pushdown_of_optional_filter(self, duckdb_cursor): cardinality_table = pa.Table.from_pydict( { - 'column_name': [ - 'id', - 'product_code', - 'price', - 'quantity', - 'category', - 'is_available', - 'rating', - 'discount', - 'color', + "column_name": [ + "id", + "product_code", + "price", + "quantity", + 
"category", + "is_available", + "rating", + "discount", + "color", ], - 'cardinality': [100, 100, 100, 45, 5, 3, 6, 39, 5], + "cardinality": [100, 100, 100, 45, 5, 3, 6, 39, 5], } ) @@ -976,18 +974,19 @@ def test_pushdown_of_optional_filter(self, duckdb_cursor): ) res = result.fetchall() assert res == [ - ('is_available', 3), - ('category', 5), - ('color', 5), - ('rating', 6), - ('discount', 39), - ('quantity', 45), - ('id', 100), - ('product_code', 100), - ('price', 100), + ("is_available", 3), + ("category", 5), + ("color", 5), + ("rating", 6), + ("discount", 39), + ("quantity", 45), + ("id", 100), + ("product_code", 100), + ("price", 100), ] - # DuckDB intentionally violates IEEE-754 when it comes to NaNs, ensuring a total ordering where NaN is the greatest value + # DuckDB intentionally violates IEEE-754 when it comes to NaNs, ensuring a total ordering where NaN is the + # greatest value def test_nan_filter_pushdown(self, duckdb_cursor): duckdb_cursor.execute( """ @@ -1001,12 +1000,12 @@ def test_nan_filter_pushdown(self, duckdb_cursor): """ ) - def assert_equal_results(con, arrow_table, query): - duckdb_res = con.sql(query.format(table='test')).fetchall() - arrow_res = con.sql(query.format(table='arrow_table')).fetchall() + def assert_equal_results(con, arrow_table, query) -> None: + duckdb_res = con.sql(query.format(table="test")).fetchall() + arrow_res = con.sql(query.format(table="arrow_table")).fetchall() assert len(duckdb_res) == len(arrow_res) - arrow_table = duckdb_cursor.table('test').fetch_arrow_table() + arrow_table = duckdb_cursor.table("test").fetch_arrow_table() assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a > 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a >= 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a < 'NaN'::FLOAT") diff --git a/tests/fast/arrow/test_integration.py b/tests/fast/arrow/test_integration.py index d9006758..1ec3a603 100644 --- a/tests/fast/arrow/test_integration.py +++ b/tests/fast/arrow/test_integration.py @@ -1,19 +1,21 @@ -import duckdb -import os import datetime +from pathlib import Path + import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") np = pytest.importorskip("numpy") -class TestArrowIntegration(object): +class TestArrowIntegration: def test_parquet_roundtrip(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" - # TODO timestamp + # TODO timestamp # noqa: TD002, TD003, TD004 userdata_parquet_table = pq.read_table(parquet_filename) userdata_parquet_table.validate(full=True) @@ -35,8 +37,8 @@ def test_parquet_roundtrip(self, duckdb_cursor): assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True) def test_unsigned_roundtrip(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'unsigned.parquet') - cols = 'a, b, c, d' + parquet_filename = str(Path(__file__).parent / "data" / "unsigned.parquet") + cols = "a, b, c, d" unsigned_parquet_table = pq.read_table(parquet_filename) unsigned_parquet_table.validate(full=True) @@ -49,7 
+51,7 @@ def test_unsigned_roundtrip(self, duckdb_cursor): assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True) duckdb_cursor.execute( - "select NULL c_null, (c % 4 = 0)::bool c_bool, (c%128)::tinyint c_tinyint, c::smallint*1000::INT c_smallint, c::integer*100000 c_integer, c::bigint*1000000000000 c_bigint, c::float c_float, c::double c_double, 'c_' || c::string c_string from (select case when range % 2 == 0 then range else null end as c from range(-10000, 10000)) sq" + "select NULL c_null, (c % 4 = 0)::bool c_bool, (c%128)::tinyint c_tinyint, c::smallint*1000::INT c_smallint, c::integer*100000 c_integer, c::bigint*1000000000000 c_bigint, c::float c_float, c::double c_double, 'c_' || c::string c_string from (select case when range % 2 == 0 then range else null end as c from range(-10000, 10000)) sq" # noqa: E501 ) arrow_result = duckdb_cursor.fetch_arrow_table() arrow_result.validate(full=True) @@ -82,16 +84,16 @@ def test_decimals_roundtrip(self, duckdb_cursor): "SELECT typeof(a), typeof(b), typeof(c),typeof(d) from testarrow" ).fetchone() - assert arrow_result[0] == 'DECIMAL(4,2)' - assert arrow_result[1] == 'DECIMAL(9,2)' - assert arrow_result[2] == 'DECIMAL(18,2)' - assert arrow_result[3] == 'DECIMAL(30,2)' + assert arrow_result[0] == "DECIMAL(4,2)" + assert arrow_result[1] == "DECIMAL(9,2)" + assert arrow_result[2] == "DECIMAL(18,2)" + assert arrow_result[3] == "DECIMAL(30,2)" # Lets also test big number comming from arrow land data = pa.array(np.array([9999999999999999999999999999999999]), type=pa.decimal128(38, 0)) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("bigdecimal") - result = duckdb_cursor.execute('select * from bigdecimal') + result = duckdb_cursor.execute("select * from bigdecimal") assert result.fetchone()[0] == 9999999999999999999999999999999999 def test_intervals_roundtrip(self, duckdb_cursor): @@ -110,9 +112,9 @@ def test_intervals_roundtrip(self, duckdb_cursor): arr = [expected_value] data = pa.array(arr, pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervaltbl") - duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()["a"] assert duck_arrow_tbl[0].value == expected_value @@ -120,7 +122,7 @@ def test_intervals_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a INTERVAL)") duckdb_cursor.execute("INSERT INTO test VALUES (INTERVAL 1 YEAR + INTERVAL 1 DAY + INTERVAL 1 SECOND)") expected_value = pa.MonthDayNano([12, 1, 1000000000]) - duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()['a'] + duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()["a"] assert duck_tbl_arrow[0].value.months == expected_value.months assert duck_tbl_arrow[0].value.days == expected_value.days assert duck_tbl_arrow[0].value.nanoseconds == expected_value.nanoseconds @@ -140,11 +142,11 @@ def test_null_intervals_roundtrip(self, duckdb_cursor): ) arr = [None, expected_value] data = pa.array(arr, pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervalnulltbl") - duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()['a'] + duckdb_tbl_arrow = 
duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()["a"] - assert duckdb_tbl_arrow[0].value == None + assert duckdb_tbl_arrow[0].value is None assert duckdb_tbl_arrow[1].value == expected_value def test_nested_interval_roundtrip(self, duckdb_cursor): @@ -154,51 +156,51 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): second_value = pa.MonthDayNano([90, 12, 0]) dictionary = pa.array([first_value, second_value, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) duckdb_cursor.from_arrow(arrow_table).create("dictionarytbl") - duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()['a'] + duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()["a"] assert duckdb_tbl_arrow[0].value == first_value assert duckdb_tbl_arrow[1].value == second_value assert duckdb_tbl_arrow[2].value == first_value assert duckdb_tbl_arrow[3].value == second_value - assert duckdb_tbl_arrow[4].value == None + assert duckdb_tbl_arrow[4].value is None assert duckdb_tbl_arrow[5].value == second_value assert duckdb_tbl_arrow[6].value == first_value - assert duckdb_tbl_arrow[7].value == None + assert duckdb_tbl_arrow[7].value is None # List query = duckdb_cursor.sql( "SELECT a from (select list_value(INTERVAL 3 MONTHS, INTERVAL 5 DAYS, INTERVAL 10 SECONDS, NULL) as a) as t" - ).fetch_arrow_table()['a'] + ).fetch_arrow_table()["a"] assert query[0][0].value == pa.MonthDayNano([3, 0, 0]) assert query[0][1].value == pa.MonthDayNano([0, 5, 0]) assert query[0][2].value == pa.MonthDayNano([0, 0, 10000000000]) - assert query[0][3].value == None + assert query[0][3].value is None # Struct - query = "SELECT a from (SELECT STRUCT_PACK(a := INTERVAL 1 MONTHS, b := INTERVAL 10 DAYS, c:= INTERVAL 20 SECONDS) as a) as t" + query = "SELECT a from (SELECT STRUCT_PACK(a := INTERVAL 1 MONTHS, b := INTERVAL 10 DAYS, c:= INTERVAL 20 SECONDS) as a) as t" # noqa: E501 true_answer = duckdb_cursor.sql(query).fetchall() from_arrow = duckdb_cursor.from_arrow(duckdb_cursor.sql(query).fetch_arrow_table()).fetchall() - assert true_answer[0][0]['a'] == from_arrow[0][0]['a'] - assert true_answer[0][0]['b'] == from_arrow[0][0]['b'] - assert true_answer[0][0]['c'] == from_arrow[0][0]['c'] + assert true_answer[0][0]["a"] == from_arrow[0][0]["a"] + assert true_answer[0][0]["b"] == from_arrow[0][0]["b"] + assert true_answer[0][0]["c"] == from_arrow[0][0]["c"] def test_min_max_interval_roundtrip(self, duckdb_cursor): interval_min_value = pa.MonthDayNano([0, 0, 0]) interval_max_value = pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) data = pa.array([interval_min_value, interval_max_value], pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervalminmaxtbl") - duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()["a"] assert duck_arrow_tbl[0].value == pa.MonthDayNano([0, 0, 0]) assert duck_arrow_tbl[1].value == pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) def test_duplicate_column_names(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df_a = pd.DataFrame({'join_key': [1, 2, 3], 'col_a': ['a', 'b', 'c']}) - df_b = pd.DataFrame({'join_key': [1, 3, 4], 'col_a': ['x', 'y', 'z']}) + df_a = 
pd.DataFrame({"join_key": [1, 2, 3], "col_a": ["a", "b", "c"]}) # noqa: F841 + df_b = pd.DataFrame({"join_key": [1, 3, 4], "col_a": ["x", "y", "z"]}) # noqa: F841 res = duckdb_cursor.execute( """ @@ -210,15 +212,15 @@ def test_duplicate_column_names(self, duckdb_cursor): table1.join_key = table2.join_key """ ).fetch_arrow_table() - assert res.schema.names == ['join_key', 'col_a', 'join_key', 'col_a'] + assert res.schema.names == ["join_key", "col_a", "join_key", "col_a"] def test_strings_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a varchar)") # Test Small, Null and Very Big String - for i in range(0, 1000): + for _i in range(1000): duckdb_cursor.execute( - "INSERT INTO test VALUES ('Matt Damon'),(NULL), ('Jeffffreeeey Jeeeeef Baaaaaaazos'), ('X-Content-Type-Options')" + "INSERT INTO test VALUES ('Matt Damon'),(NULL), ('Jeffffreeeey Jeeeeef Baaaaaaazos'), ('X-Content-Type-Options')" # noqa: E501 ) true_result = duckdb_cursor.execute("SELECT * from test").fetchall() diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index a548818f..5426f39d 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -1,61 +1,59 @@ -import duckdb -import os -import datetime import pytest +import duckdb + try: import pyarrow as pa - import pandas as pd can_run = True -except: +except Exception: can_run = False -class TestArrowInterval(object): +class TestArrowInterval: def test_duration_types(self, duckdb_cursor): if not can_run: return expected_arrow = pa.Table.from_arrays( - [pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())], ['a'] + [pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())], ["a"] ) data = ( - pa.array([1000000000], type=pa.duration('ns')), - pa.array([1000000], type=pa.duration('us')), - pa.array([1000], pa.duration('ms')), - pa.array([1], pa.duration('s')), + pa.array([1000000000], type=pa.duration("ns")), + pa.array([1000000], type=pa.duration("us")), + pa.array([1000], pa.duration("ms")), + pa.array([1], pa.duration("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == expected_arrow['a'] - assert rel['b'] == expected_arrow['a'] - assert rel['c'] == expected_arrow['a'] - assert rel['d'] == expected_arrow['a'] + assert rel["a"] == expected_arrow["a"] + assert rel["b"] == expected_arrow["a"] + assert rel["c"] == expected_arrow["a"] + assert rel["d"] == expected_arrow["a"] def test_duration_null(self, duckdb_cursor): if not can_run: return - expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())], ['a']) + expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())], ["a"]) data = ( - pa.array([None], type=pa.duration('ns')), - pa.array([None], type=pa.duration('us')), - pa.array([None], pa.duration('ms')), - pa.array([None], pa.duration('s')), + pa.array([None], type=pa.duration("ns")), + pa.array([None], type=pa.duration("us")), + pa.array([None], pa.duration("ms")), + pa.array([None], pa.duration("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = 
duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == expected_arrow['a'] - assert rel['b'] == expected_arrow['a'] - assert rel['c'] == expected_arrow['a'] - assert rel['d'] == expected_arrow['a'] + assert rel["a"] == expected_arrow["a"] + assert rel["b"] == expected_arrow["a"] + assert rel["c"] == expected_arrow["a"] + assert rel["d"] == expected_arrow["a"] def test_duration_overflow(self, duckdb_cursor): if not can_run: return # Only seconds can overflow - data = pa.array([9223372036854775807], pa.duration('s')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([9223372036854775807], pa.duration("s")) + arrow_table = pa.Table.from_arrays([data], ["a"]) - with pytest.raises(duckdb.ConversionException, match='Could not convert Interval to Microsecond'): - arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() + with pytest.raises(duckdb.ConversionException, match="Could not convert Interval to Microsecond"): + duckdb.from_arrow(arrow_table).fetch_arrow_table() diff --git a/tests/fast/arrow/test_large_offsets.py b/tests/fast/arrow/test_large_offsets.py index 1bcdd1b7..45b078b8 100644 --- a/tests/fast/arrow/test_large_offsets.py +++ b/tests/fast/arrow/test_large_offsets.py @@ -1,9 +1,6 @@ -from re import S -import duckdb -import os import pytest -import tempfile -from conftest import pandas_supports_arrow_backend + +import duckdb pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") @@ -11,18 +8,19 @@ np = pytest.importorskip("numpy") -class TestArrowLargeOffsets(object): +class TestArrowLargeOffsets: @pytest.mark.skip(reason="CI does not have enough memory to validate this") def test_large_lists(self, duckdb_cursor): ary = pa.array([np.arange(start=0, stop=3000, dtype=np.uint8) for i in range(1_000_000)]) - tbl = pa.Table.from_pydict(dict(col=ary)) + tbl = pa.Table.from_pydict({"col": ary}) # noqa: F841 with pytest.raises( duckdb.InvalidInputException, - match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', + match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but " + "the offset of 2147481000 exceeds this", ): - res = duckdb_cursor.sql("SELECT col FROM tbl").fetch_arrow_table() + duckdb_cursor.sql("SELECT col FROM tbl").fetch_arrow_table() - tbl2 = pa.Table.from_pydict(dict(col=ary.cast(pa.large_list(pa.uint8())))) + tbl2 = pa.Table.from_pydict({"col": ary.cast(pa.large_list(pa.uint8()))}) # noqa: F841 duckdb_cursor.sql("set arrow_large_buffer_size = true") res2 = duckdb_cursor.sql("SELECT col FROM tbl2").fetch_arrow_table() res2.validate() @@ -30,13 +28,14 @@ def test_large_lists(self, duckdb_cursor): @pytest.mark.skip(reason="CI does not have enough memory to validate this") def test_large_maps(self, duckdb_cursor): ary = pa.array([np.arange(start=3000 * j, stop=3000 * (j + 1), dtype=np.uint64) for j in range(1_000_000)]) - tbl = pa.Table.from_pydict(dict(col=ary)) + tbl = pa.Table.from_pydict({"col": ary}) # noqa: F841 with pytest.raises( duckdb.InvalidInputException, - match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', + match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the " + "offset of 2147481000 exceeds this", ): - arrow_map = duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() + duckdb_cursor.sql("select map(col, col) from 
tbl").fetch_arrow_table() duckdb_cursor.sql("set arrow_large_buffer_size = true") arrow_map_large = duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() diff --git a/tests/fast/arrow/test_large_string.py b/tests/fast/arrow/test_large_string.py index 4836048d..e56e5854 100644 --- a/tests/fast/arrow/test_large_string.py +++ b/tests/fast/arrow/test_large_string.py @@ -1,17 +1,14 @@ import duckdb -import os try: import pyarrow as pa - from pyarrow import parquet as pq - import numpy as np can_run = True -except: +except Exception: can_run = False -class TestArrowLargeString(object): +class TestArrowLargeString: def test_large_string_type(self, duckdb_cursor): if not can_run: return @@ -22,4 +19,4 @@ def test_large_string_type(self, duckdb_cursor): rel = duckdb.from_arrow(arrow_table) res = rel.execute().fetchall() - assert res == [('foo',), ('baaaar',), ('b',)] + assert res == [("foo",), ("baaaar",), ("b",)] diff --git a/tests/fast/arrow/test_multiple_reads.py b/tests/fast/arrow/test_multiple_reads.py index 935a8a9c..caa03467 100644 --- a/tests/fast/arrow/test_multiple_reads.py +++ b/tests/fast/arrow/test_multiple_reads.py @@ -1,22 +1,21 @@ +from pathlib import Path + import duckdb -import os try: import pyarrow import pyarrow.parquet can_run = True -except: +except Exception: can_run = False -class TestArrowReads(object): +class TestArrowReads: def test_multiple_queries_same_relation(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' - + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename) userdata_parquet_table.validate(full=True) rel = duckdb.from_arrow(userdata_parquet_table) diff --git a/tests/fast/arrow/test_nested_arrow.py b/tests/fast/arrow/test_nested_arrow.py index 693a5155..10fbfae0 100644 --- a/tests/fast/arrow/test_nested_arrow.py +++ b/tests/fast/arrow/test_nested_arrow.py @@ -1,7 +1,7 @@ -import duckdb - import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") np = pytest.importorskip("numpy") @@ -16,23 +16,23 @@ def compare_results(duckdb_cursor, query): def arrow_to_pandas(duckdb_cursor, query): - return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()['a'].values.tolist() + return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()["a"].values.tolist() def get_use_list_view_options(): result = [] result.append(False) - if hasattr(pa, 'ListViewArray'): + if hasattr(pa, "ListViewArray"): result.append(True) return result -class TestArrowNested(object): +class TestArrowNested: def test_lists_basic(self, duckdb_cursor): # Test Constant List query = ( duckdb_cursor.query("SELECT a from (select list_value(3,5,10) as a) as t") - .fetch_arrow_table()['a'] + .fetch_arrow_table()["a"] .to_numpy() ) assert query[0][0] == 3 @@ -40,32 +40,32 @@ def test_lists_basic(self, duckdb_cursor): assert query[0][2] == 10 # Empty List - query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").fetch_arrow_table()['a'].to_numpy() + query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").fetch_arrow_table()["a"].to_numpy() assert len(query[0]) == 0 # Test Constant List With Null query = ( duckdb_cursor.query("SELECT a from (select list_value(3,NULL) as a) as t") - 
.fetch_arrow_table()['a'] + .fetch_arrow_table()["a"] .to_numpy() ) assert query[0][0] == 3 assert np.isnan(query[0][1]) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_list_types(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") # Large Lists data = pa.array([[1], None, [2]], type=pa.large_list(pa.int64())) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [([1],), (None,), ([2],)] # Fixed Size Lists data = pa.array([[1], None, [2]], type=pa.list_(pa.int64(), 1)) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [((1,),), (None,), ((2,),)] @@ -76,27 +76,27 @@ def test_list_types(self, duckdb_cursor, use_list_view): pa.array([[1], None, [2]], type=pa.large_list(pa.int64())), pa.array([[1, 2, 3], None, [2, 1]], type=pa.list_(pa.int64())), ] - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ["a", "b", "c"]) rel = duckdb_cursor.from_arrow(arrow_table) - res = rel.project('a').execute().fetchall() + res = rel.project("a").execute().fetchall() assert res == [((1,),), (None,), ((2,),)] - res = rel.project('b').execute().fetchall() + res = rel.project("b").execute().fetchall() assert res == [([1],), (None,), ([2],)] - res = rel.project('c').execute().fetchall() + res = rel.project("c").execute().fetchall() assert res == [([1, 2, 3],), (None,), ([2, 1],)] # Struct Holding different List Types - struct = [pa.StructArray.from_arrays(data, ['fixed', 'large', 'normal'])] - arrow_table = pa.Table.from_arrays(struct, ['a']) + struct = [pa.StructArray.from_arrays(data, ["fixed", "large", "normal"])] + arrow_table = pa.Table.from_arrays(struct, ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [ - ({'fixed': (1,), 'large': [1], 'normal': [1, 2, 3]},), - ({'fixed': None, 'large': None, 'normal': None},), - ({'fixed': (2,), 'large': [2], 'normal': [2, 1]},), + ({"fixed": (1,), "large": [1], "normal": [1, 2, 3]},), + ({"fixed": None, "large": None, "normal": None},), + ({"fixed": (2,), "large": [2], "normal": [2, 1]},), ] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) @pytest.mark.skip(reason="FIXME: this fails on CI") def test_lists_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -127,21 +127,21 @@ def test_lists_roundtrip(self, duckdb_cursor, use_list_view): # LIST[LIST[LIST[LIST[LIST[INTEGER]]]]]] compare_results( duckdb_cursor, - "SELECT list (lllle order by lllle) llllle from (SELECT list (llle order by llle) lllle from (SELECT list(lle order by lle) llle from (SELECT LIST(le order by le) lle FROM (SELECT LIST(i order by i) le from range(100) tbl(i) group by i%10) as t) as t1) as t2) as t3", + "SELECT list (lllle order by lllle) llllle from (SELECT list (llle order by llle) lllle from (SELECT list(lle order by lle) llle from (SELECT LIST(le order by le) lle FROM (SELECT LIST(i order by i) le from range(100) tbl(i) group by i%10) as t) as t1) as 
t2) as t3", # noqa: E501 ) compare_results( duckdb_cursor, - '''SELECT grp,lst,cs FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as cs - from (SELECT a%4 as grp, list(a order by a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T order by all;''', + """SELECT grp,lst,cs FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as cs + from (SELECT a%4 as grp, list(a order by a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T order by all;""", # noqa: E501 ) # Tests for converting multiple lists to/from Arrow with NULL values and/or strings compare_results( duckdb_cursor, - "SELECT list(st order by st) from (select i, case when i%10 then NULL else i::VARCHAR end as st from range(1000) tbl(i)) as t group by i%5 order by all", + "SELECT list(st order by st) from (select i, case when i%10 then NULL else i::VARCHAR end as st from range(1000) tbl(i)) as t group by i%5 order by all", # noqa: E501 ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_struct_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -156,7 +156,7 @@ def test_struct_roundtrip(self, duckdb_cursor, use_list_view): "SELECT a from (SELECT STRUCT_PACK(a := LIST_VALUE(1,2,3), b := i) as a FROM range(10000) tbl(i)) as t", ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -171,7 +171,7 @@ def test_map_roundtrip(self, duckdb_cursor, use_list_view): compare_results(duckdb_cursor, "SELECT a from (select MAP(LIST_VALUE(),LIST_VALUE()) as a) as t") compare_results( duckdb_cursor, - "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t", + "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t", # noqa: E501 ) compare_results( duckdb_cursor, @@ -182,30 +182,30 @@ def test_map_roundtrip(self, duckdb_cursor, use_list_view): ) compare_results( duckdb_cursor, - "SELECT m from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10000) tbl(i) group by i%5 order by all) as lst_tbl) as T", + "SELECT m from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10000) tbl(i) group by i%5 order by all) as lst_tbl) as T", # noqa: E501 ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_arrow_to_duckdb(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") map_type = pa.map_(pa.int32(), pa.int32()) values = [[(3, 12), (3, 21)], [(5, 42)]] - arrow_table = pa.table({'detail': pa.array(values, map_type)}) + arrow_table = pa.table({"detail": pa.array(values, map_type)}) with pytest.raises( duckdb.InvalidInputException, match="Arrow map contains duplicate key, which isn't supported by DuckDB map type", ): - rel = duckdb_cursor.from_arrow(arrow_table).fetchall() + duckdb_cursor.from_arrow(arrow_table).fetchall() def test_null_map_arrow_to_duckdb(self, duckdb_cursor): map_type = pa.map_(pa.int32(), pa.int32()) values = [None, [(5, 42)]] - 
arrow_table = pa.table({'detail': pa.array(values, map_type)}) + arrow_table = pa.table({"detail": pa.array(values, map_type)}) # noqa: F841 res = duckdb_cursor.sql("select * from arrow_table").fetchall() assert res == [(None,), ({5: 42},)] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_arrow_to_pandas(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") assert arrow_to_pandas( @@ -214,17 +214,17 @@ def test_map_arrow_to_pandas(self, duckdb_cursor, use_list_view): assert arrow_to_pandas(duckdb_cursor, "SELECT a from (select MAP(LIST_VALUE(),LIST_VALUE()) as a) as t") == [[]] assert arrow_to_pandas( duckdb_cursor, - "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t", - ) == [[('Jon Lajoie', 10), ('Backstreet Boys', 9), ('Tenacious D', 10)]] + "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t", # noqa: E501 + ) == [[("Jon Lajoie", 10), ("Backstreet Boys", 9), ("Tenacious D", 10)]] assert arrow_to_pandas( duckdb_cursor, "SELECT a from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(a)" ) == [[(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)]] assert arrow_to_pandas( duckdb_cursor, "SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a", - ) == [[({'i': 1, 'j': 2}, {'i': 1, 'j': 2}), ({'i': 3, 'j': 4}, {'i': 3, 'j': 4})]] + ) == [[({"i": 1, "j": 2}, {"i": 1, "j": 2}), ({"i": 3, "j": 4}, {"i": 3, "j": 4})]] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_frankstein_nested(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -237,13 +237,13 @@ def test_frankstein_nested(self, duckdb_cursor, use_list_view): # Maps embedded in a struct compare_results( duckdb_cursor, - "SELECT {'i':mp,'j':mp2} FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", + "SELECT {'i':mp,'j':mp2} FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", # noqa: E501 ) # List of maps compare_results( duckdb_cursor, - "SELECT [mp,mp2] FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", + "SELECT [mp,mp2] FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", # noqa: E501 ) # Map with list as key and/or value @@ -263,5 +263,5 @@ def test_frankstein_nested(self, duckdb_cursor, use_list_view): # MAP that is NULL entirely compare_results( duckdb_cursor, - "SELECT * FROM (VALUES (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))),(NULL), (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))), (NULL)) as a", + "SELECT * FROM (VALUES (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))),(NULL), (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))), (NULL)) as a", # noqa: E501 ) diff --git a/tests/fast/arrow/test_parallel.py b/tests/fast/arrow/test_parallel.py index 2609d1ae..817da26f 100644 --- a/tests/fast/arrow/test_parallel.py +++ b/tests/fast/arrow/test_parallel.py @@ -1,17 +1,18 @@ +from pathlib import Path + import duckdb 
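# A minimal, standalone sketch of the Arrow map round-trip the nested-arrow tests above
# exercise (it assumes only that `duckdb` and `pyarrow` are importable; the registered
# name and column name here are illustrative, not taken from the test suite).
import duckdb
import pyarrow as pa

map_type = pa.map_(pa.int32(), pa.int32())
maps = pa.table({"detail": pa.array([None, [(5, 42)]], map_type)})

con = duckdb.connect()
con.register("maps", maps)
# A NULL map survives the round trip as None; a map containing duplicate keys would
# instead raise duckdb.InvalidInputException, which is what test_map_arrow_to_duckdb asserts.
assert con.sql("SELECT * FROM maps").fetchall() == [(None,), ({5: 42},)]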
-import os try: + import numpy as np import pyarrow import pyarrow.parquet - import numpy as np can_run = True -except: +except Exception: can_run = False -class TestArrowParallel(object): +class TestArrowParallel: def test_parallel_run(self, duckdb_cursor): if not can_run: return @@ -19,7 +20,7 @@ def test_parallel_run(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") data = pyarrow.array(np.random.randint(800, size=1000000), type=pyarrow.int32()) - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(10000)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(10000)) rel = duckdb_conn.from_arrow(tbl) # Also test multiple reads assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000000 @@ -32,17 +33,15 @@ def test_parallel_types_and_different_batches(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' - + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename) for i in [7, 51, 99, 100, 101, 500, 1000, 2000]: data = pyarrow.array(np.arange(3, 7), type=pyarrow.int32()) - tbl = pyarrow.Table.from_arrays([data], ['a']) - rel_id = duckdb_conn.from_arrow(tbl) + tbl = pyarrow.Table.from_arrays([data], ["a"]) + duckdb_conn.from_arrow(tbl) userdata_parquet_table2 = pyarrow.Table.from_batches(userdata_parquet_table.to_batches(i)) rel = duckdb_conn.from_arrow(userdata_parquet_table2) - result = rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)') + result = rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)") assert result.execute().fetchone()[0] == 4 def test_parallel_fewer_batches_than_threads(self, duckdb_cursor): @@ -53,7 +52,7 @@ def test_parallel_fewer_batches_than_threads(self, duckdb_cursor): duckdb_conn.execute("PRAGMA verify_parallelism") data = pyarrow.array(np.random.randint(800, size=1000), type=pyarrow.int32()) - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(2)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(2)) rel = duckdb_conn.from_arrow(tbl) # Also test multiple reads assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000 diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 87e2f726..d5621701 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -1,13 +1,15 @@ -import duckdb -import pytest -import sys import datetime +import json + +import pytest + +import duckdb pl = pytest.importorskip("polars") arrow = pytest.importorskip("pyarrow") pl_testing = pytest.importorskip("polars.testing") -from duckdb.polars_io import _predicate_to_expression +from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402 def valid_filter(filter): @@ -20,7 +22,7 @@ def invalid_filter(filter): assert sql_expression is None -class TestPolars(object): +class TestPolars: def test_polars(self, duckdb_cursor): df = pl.DataFrame( { @@ -31,21 +33,21 @@ def test_polars(self, duckdb_cursor): } ) # scan plus return a polars dataframe - polars_result = duckdb_cursor.sql('SELECT * 
FROM df').pl() + polars_result = duckdb_cursor.sql("SELECT * FROM df").pl() pl_testing.assert_frame_equal(df, polars_result) # now do the same for a lazy dataframe - lazy_df = df.lazy() - lazy_result = duckdb_cursor.sql('SELECT * FROM lazy_df').pl() + lazy_df = df.lazy() # noqa: F841 + lazy_result = duckdb_cursor.sql("SELECT * FROM lazy_df").pl() pl_testing.assert_frame_equal(df, lazy_result) con = duckdb.connect() - con_result = con.execute('SELECT * FROM df').pl() + con_result = con.execute("SELECT * FROM df").pl() pl_testing.assert_frame_equal(df, con_result) def test_execute_polars(self, duckdb_cursor): res1 = duckdb_cursor.execute("SELECT 1 AS a, 2 AS a").pl() - assert res1.columns == ['a', 'a_1'] + assert res1.columns == ["a", "a_1"] def test_register_polars(self, duckdb_cursor): con = duckdb.connect() @@ -58,21 +60,21 @@ def test_register_polars(self, duckdb_cursor): } ) # scan plus return a polars dataframe - con.register('polars_df', df) - polars_result = con.execute('select * from polars_df').pl() + con.register("polars_df", df) + polars_result = con.execute("select * from polars_df").pl() pl_testing.assert_frame_equal(df, polars_result) - con.unregister('polars_df') - with pytest.raises(duckdb.CatalogException, match='Table with name polars_df does not exist'): + con.unregister("polars_df") + with pytest.raises(duckdb.CatalogException, match="Table with name polars_df does not exist"): con.execute("SELECT * FROM polars_df;").pl() - con.register('polars_df', df.lazy()) - polars_result = con.execute('select * from polars_df').pl() + con.register("polars_df", df.lazy()) + polars_result = con.execute("select * from polars_df").pl() pl_testing.assert_frame_equal(df, polars_result) def test_empty_polars_dataframe(self, duckdb_cursor): - polars_empty_df = pl.DataFrame() + polars_empty_df = pl.DataFrame() # noqa: F841 with pytest.raises( - duckdb.InvalidInputException, match='Provided table/dataframe must have at least one column' + duckdb.InvalidInputException, match="Provided table/dataframe must have at least one column" ): duckdb_cursor.sql("from polars_empty_df") @@ -82,7 +84,7 @@ def test_polars_from_json(self, duckdb_cursor): duckdb_cursor.sql("set arrow_lossless_conversion=false") string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""") res = duckdb_cursor.read_json(string).pl() - assert str(res['entry'][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" + assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" @pytest.mark.skipif( not hasattr(pl.exceptions, "PanicException"), reason="Polars has no PanicException in this version" @@ -92,14 +94,14 @@ def test_polars_from_json_error(self, duckdb_cursor): duckdb_cursor.sql("set arrow_lossless_conversion=true") string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""") - res = duckdb_cursor.read_json(string).pl() - assert duckdb_cursor.execute("FROM res").fetchall() == [([{'content': {'ManagedSystem': {'test': None}}}],)] + with pytest.raises(pl.exceptions.PanicException, match=r"Arrow datatype Extension\(.*\) not supported"): + duckdb_cursor.read_json(string).pl() - def test_polars_from_json_error(self, duckdb_cursor): + def test_polars_from_json_error_2(self, duckdb_cursor): conn = duckdb.connect() - my_table = conn.query("select 'x' my_str").pl() + my_table = conn.query("select 'x' my_str").pl() # noqa: F841 my_res = duckdb.query("select my_str from my_table where my_str != 'y'") - assert my_res.fetchall() == [('x',)] + assert my_res.fetchall() == 
[("x",)] def test_polars_lazy_from_conn(self, duckdb_cursor): duckdb_conn = duckdb.connect() @@ -107,7 +109,7 @@ def test_polars_lazy_from_conn(self, duckdb_cursor): result = duckdb_conn.execute("SELECT 42 as bla") lazy_df = result.pl(lazy=True) - assert lazy_df.collect().to_dicts() == [{'bla': 42}] + assert lazy_df.collect().to_dicts() == [{"bla": 42}] def test_polars_lazy(self, duckdb_cursor): con = duckdb.connect() @@ -118,43 +120,43 @@ def test_polars_lazy(self, duckdb_cursor): assert isinstance(lazy_df, pl.LazyFrame) assert lazy_df.collect().to_dicts() == [ - {'a': 'Pedro', 'b': 32}, - {'a': 'Mark', 'b': 31}, - {'a': 'Thijs', 'b': 29}, + {"a": "Pedro", "b": 32}, + {"a": "Mark", "b": 31}, + {"a": "Thijs", "b": 29}, ] - assert lazy_df.select('a').collect().to_dicts() == [{'a': 'Pedro'}, {'a': 'Mark'}, {'a': 'Thijs'}] - assert lazy_df.limit(1).collect().to_dicts() == [{'a': 'Pedro', 'b': 32}] + assert lazy_df.select("a").collect().to_dicts() == [{"a": "Pedro"}, {"a": "Mark"}, {"a": "Thijs"}] + assert lazy_df.limit(1).collect().to_dicts() == [{"a": "Pedro", "b": 32}] assert lazy_df.filter(pl.col("b") < 32).collect().to_dicts() == [ - {'a': 'Mark', 'b': 31}, - {'a': 'Thijs', 'b': 29}, + {"a": "Mark", "b": 31}, + {"a": "Thijs", "b": 29}, ] - assert lazy_df.filter(pl.col("b") < 32).select('a').collect().to_dicts() == [{'a': 'Mark'}, {'a': 'Thijs'}] + assert lazy_df.filter(pl.col("b") < 32).select("a").collect().to_dicts() == [{"a": "Mark"}, {"a": "Thijs"}] def test_polars_column_with_tricky_name(self, duckdb_cursor): # Test that a polars DataFrame with a column name that is non standard still works - df_colon = pl.DataFrame({"x:y": [1, 2]}) + df_colon = pl.DataFrame({"x:y": [1, 2]}) # noqa: F841 lf = duckdb_cursor.sql("from df_colon").pl(lazy=True) result = lf.select(pl.all()).collect() assert result.to_dicts() == [{"x:y": 1}, {"x:y": 2}] result = lf.select(pl.all()).filter(pl.col("x:y") == 1).collect() assert result.to_dicts() == [{"x:y": 1}] - df_space = pl.DataFrame({"x y": [1, 2]}) + df_space = pl.DataFrame({"x y": [1, 2]}) # noqa: F841 lf = duckdb_cursor.sql("from df_space").pl(lazy=True) result = lf.select(pl.all()).collect() assert result.to_dicts() == [{"x y": 1}, {"x y": 2}] result = lf.select(pl.all()).filter(pl.col("x y") == 1).collect() assert result.to_dicts() == [{"x y": 1}] - df_dot = pl.DataFrame({"x.y": [1, 2]}) + df_dot = pl.DataFrame({"x.y": [1, 2]}) # noqa: F841 lf = duckdb_cursor.sql("from df_dot").pl(lazy=True) result = lf.select(pl.all()).collect() assert result.to_dicts() == [{"x.y": 1}, {"x.y": 2}] result = lf.select(pl.all()).filter(pl.col("x.y") == 1).collect() assert result.to_dicts() == [{"x.y": 1}] - df_quote = pl.DataFrame({'"xy"': [1, 2]}) + df_quote = pl.DataFrame({'"xy"': [1, 2]}) # noqa: F841 lf = duckdb_cursor.sql("from df_quote").pl(lazy=True) result = lf.select(pl.all()).collect() assert result.to_dicts() == [{'"xy"': 1}, {'"xy"': 2}] @@ -162,23 +164,23 @@ def test_polars_column_with_tricky_name(self, duckdb_cursor): assert result.to_dicts() == [{'"xy"': 1}] @pytest.mark.parametrize( - 'data_type', + "data_type", [ - 'TINYINT', - 'SMALLINT', - 'INTEGER', - 'BIGINT', - 'UTINYINT', - 'USMALLINT', - 'UINTEGER', - 'UBIGINT', - 'FLOAT', - 'DOUBLE', - 'HUGEINT', - 'DECIMAL(4,1)', - 'DECIMAL(9,1)', - 'DECIMAL(18,4)', - 'DECIMAL(30,12)', + "TINYINT", + "SMALLINT", + "INTEGER", + "BIGINT", + "UTINYINT", + "USMALLINT", + "UINTEGER", + "UBIGINT", + "FLOAT", + "DOUBLE", + "HUGEINT", + "DECIMAL(4,1)", + "DECIMAL(9,1)", + "DECIMAL(18,4)", + "DECIMAL(30,12)", ], ) def 
test_polars_lazy_pushdown_numeric(self, data_type, duckdb_cursor): @@ -272,7 +274,7 @@ def test_polars_lazy_pushdown_bool(self, duckdb_cursor): lazy_df = duck_tbl.pl(lazy=True) # == True - assert lazy_df.filter(pl.col("a") == True).select(pl.len()).collect().item() == 2 + assert lazy_df.filter(pl.col("a")).select(pl.len()).collect().item() == 2 # IS NULL assert lazy_df.filter(pl.col("a").is_null()).select(pl.len()).collect().item() == 1 @@ -281,17 +283,17 @@ def test_polars_lazy_pushdown_bool(self, duckdb_cursor): assert lazy_df.filter(pl.col("a").is_not_null()).select(pl.len()).collect().item() == 3 # AND - assert lazy_df.filter((pl.col("a") == True) & (pl.col("b") == True)).select(pl.len()).collect().item() == 1 + assert lazy_df.filter((pl.col("a")) & (pl.col("b"))).select(pl.len()).collect().item() == 1 # OR - assert lazy_df.filter((pl.col("a") == True) | (pl.col("b") == True)).select(pl.len()).collect().item() == 3 + assert lazy_df.filter((pl.col("a")) | (pl.col("b"))).select(pl.len()).collect().item() == 3 # Validate Filters - valid_filter(pl.col("a") == True) + valid_filter(pl.col("a")) valid_filter(pl.col("a").is_null()) valid_filter(pl.col("a").is_not_null()) - valid_filter((pl.col("a") == True) & (pl.col("b") == True)) - valid_filter((pl.col("a") == True) | (pl.col("b") == True)) + valid_filter((pl.col("a")) & (pl.col("b"))) + valid_filter((pl.col("a")) | (pl.col("b"))) def test_polars_lazy_pushdown_time(self, duckdb_cursor): duckdb_cursor.execute( @@ -388,8 +390,8 @@ def test_polars_lazy_pushdown_timestamp(self, duckdb_cursor): ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1) ts_2020 = datetime.datetime(2020, 3, 1, 10, 0, 1) - # These will require a cast, which we currently do not support, hence the filter won't be pushed down, but the results - # Should still be correct, and we check we can't really pushdown the filter yet. + # These will require a cast, which we currently do not support, hence the filter won't be pushed down, but + # the results should still be correct, and we check we can't really pushdown the filter yet. 
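# A small sketch of the lazy-scan behaviour these pushdown tests rely on (it assumes a
# duckdb build that ships `duckdb.polars_io`, as on this branch; the table and data are
# illustrative).
import duckdb
import polars as pl
from duckdb.polars_io import _predicate_to_expression

con = duckdb.connect()
con.execute("CREATE TABLE t AS SELECT range AS a FROM range(10)")
lazy_df = con.table("t").pl(lazy=True)

# A plain comparison translates to a DuckDB expression and can be pushed into the scan;
# _predicate_to_expression returns None for predicates it cannot translate (for example
# ones that would need a cast), in which case Polars filters the rows itself and the
# result is still correct -- the situation the timestamp comment above describes.
print(_predicate_to_expression(pl.col("a") > 5))
assert lazy_df.filter(pl.col("a") > 5).select(pl.len()).collect().item() == 4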
# == assert lazy_df.filter(pl.col("a") == ts_2008).select(pl.len()).collect().item() == 1 @@ -524,9 +526,9 @@ def test_polars_lazy_pushdown_blob(self, duckdb_cursor): df = pandas.DataFrame( { - 'a': [bytes([1]), bytes([2]), bytes([3]), None], - 'b': [bytes([1]), bytes([2]), bytes([3]), None], - 'c': [bytes([1]), bytes([2]), bytes([3]), None], + "a": [bytes([1]), bytes([2]), bytes([3]), None], + "b": [bytes([1]), bytes([2]), bytes([3]), None], + "c": [bytes([1]), bytes([2]), bytes([3]), None], } ) duck_tbl = duckdb.from_df(df) @@ -604,3 +606,50 @@ def test_polars_lazy_many_batches(self, duckdb_cursor): correct = duckdb_cursor.execute("FROM t").fetchall() assert res == correct + + def test_invalid_expr_json(self): + bad_key_expr = """ + { + "BinaryExpr": { + "left": { "Column": "foo" }, + "middle": "Gt", + "right": { "Literal": { "Int": 5 } } + } + } + """ + with pytest.raises(KeyError, match="'op'"): + _pl_tree_to_sql(json.loads(bad_key_expr)) + + bad_type_expr = """ + { + "BinaryExpr": { + "left": { "Column": [ "foo" ] }, + "op": "Gt", + "right": { "Literal": { "Int": 5 } } + } + } + """ + with pytest.raises(AssertionError, match="The col name of a Column should be a str but got"): + _pl_tree_to_sql(json.loads(bad_type_expr)) + + def test_decimal_scale(self): + scalar_decimal_no_scale = """ + { "Scalar": { + "Decimal": [ + 1, + 0 + ] + } } + """ + assert _pl_tree_to_sql(json.loads(scalar_decimal_no_scale)) == "1" + + scalar_decimal_scale = """ + { "Scalar": { + "Decimal": [ + 1, + 38, + 0 + ] + } } + """ + assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1" diff --git a/tests/fast/arrow/test_progress.py b/tests/fast/arrow/test_progress.py index c20ebe51..e11a3c41 100644 --- a/tests/fast/arrow/test_progress.py +++ b/tests/fast/arrow/test_progress.py @@ -1,14 +1,15 @@ -import duckdb import os + import pytest +import duckdb + pyarrow_parquet = pytest.importorskip("pyarrow.parquet") -import sys -class TestProgressBarArrow(object): +class TestProgressBarArrow: def test_progress_arrow(self): - if os.name == 'nt': + if os.name == "nt": return np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") @@ -18,9 +19,9 @@ def test_progress_arrow(self): duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == 49999995000000 # Multiple Threads duckdb_conn.execute("PRAGMA threads=4") @@ -28,9 +29,9 @@ def test_progress_arrow(self): assert result.execute().fetchone()[0] == 49999995000000 # More than one batch - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(100)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(100)) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == 49999995000000 # Single Thread @@ -40,7 +41,7 @@ def test_progress_arrow(self): assert py_res == 49999995000000 def test_progress_arrow_empty(self): - if os.name == 'nt': + if os.name == "nt": return np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") @@ -50,7 +51,7 @@ def test_progress_arrow_empty(self): duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = 
pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') - assert result.execute().fetchone()[0] == None + result = rel.aggregate("sum(a)") + assert result.execute().fetchone()[0] is None diff --git a/tests/fast/arrow/test_projection_pushdown.py b/tests/fast/arrow/test_projection_pushdown.py index 802259e1..fbd258e0 100644 --- a/tests/fast/arrow/test_projection_pushdown.py +++ b/tests/fast/arrow/test_projection_pushdown.py @@ -1,11 +1,9 @@ -import duckdb -import os import pytest -class TestArrowProjectionPushdown(object): +class TestArrowProjectionPushdown: def test_projection_pushdown_no_filter(self, duckdb_cursor): - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") duckdb_cursor.execute( @@ -27,8 +25,8 @@ def test_projection_pushdown_no_filter(self, duckdb_cursor): assert duckdb_cursor.execute("SELECT sum(c) FROM arrow_table").fetchall() == [(333,)] # RecordBatch does not use projection pushdown, test that this also still works - record_batch = arrow_table.to_batches()[0] + record_batch = arrow_table.to_batches()[0] # noqa: F841 assert duckdb_cursor.execute("SELECT sum(c) FROM record_batch").fetchall() == [(333,)] - arrow_dataset = ds.dataset(arrow_table) + arrow_dataset = ds.dataset(arrow_table) # noqa: F841 assert duckdb_cursor.execute("SELECT sum(c) FROM arrow_dataset").fetchall() == [(333,)] diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index 726b0f6a..b9bc5a21 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -1,77 +1,73 @@ import duckdb -import os -import datetime -import pytest try: import pyarrow as pa - import pandas as pd can_run = True -except: +except Exception: can_run = False -class TestArrowTime(object): +class TestArrowTime: def test_time_types(self, duckdb_cursor): if not can_run: return data = ( - pa.array([1], type=pa.time32('s')), - pa.array([1000], type=pa.time32('ms')), - pa.array([1000000], pa.time64('us')), - pa.array([1000000000], pa.time64('ns')), + pa.array([1], type=pa.time32("s")), + pa.array([1000], type=pa.time32("ms")), + pa.array([1000000], pa.time64("us")), + pa.array([1000000000], pa.time64("ns")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['c'] - assert rel['b'] == arrow_table['c'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['c'] + assert rel["a"] == arrow_table["c"] + assert rel["b"] == arrow_table["c"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["c"] def test_time_null(self, duckdb_cursor): if not can_run: return data = ( - pa.array([None], type=pa.time32('s')), - pa.array([None], type=pa.time32('ms')), - pa.array([None], pa.time64('us')), - pa.array([None], pa.time64('ns')), + pa.array([None], type=pa.time32("s")), + pa.array([None], type=pa.time32("ms")), + pa.array([None], pa.time64("us")), + pa.array([None], pa.time64("ns")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['c'] - assert rel['b'] == 
arrow_table['c'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['c'] + assert rel["a"] == arrow_table["c"] + assert rel["b"] == arrow_table["c"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["c"] def test_max_times(self, duckdb_cursor): if not can_run: return - data = pa.array([2147483647000000], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647000000], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) # Max Sec - data = pa.array([2147483647], type=pa.time32('s')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647], type=pa.time32("s")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] + assert rel["a"] == result["a"] # Max MSec - data = pa.array([2147483647000], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) - data = pa.array([2147483647], type=pa.time32('ms')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647000], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) + data = pa.array([2147483647], type=pa.time32("ms")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] + assert rel["a"] == result["a"] # Max NSec - data = pa.array([9223372036854774], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) - data = pa.array([9223372036854774000], type=pa.time64('ns')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([9223372036854774], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) + data = pa.array([9223372036854774000], type=pa.time64("ns")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - print(rel['a']) - print(result['a']) - assert rel['a'] == result['a'] + print(rel["a"]) + print(result["a"]) + assert rel["a"] == result["a"] diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 4fdadf49..7e338626 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -1,9 +1,11 @@ -import duckdb -import pytest import datetime + +import pytest import pytz -pa = pytest.importorskip('pyarrow') +import duckdb + +pa = pytest.importorskip("pyarrow") def generate_table(current_time, precision, timezone): @@ -13,30 +15,30 @@ def generate_table(current_time, precision, timezone): return pa.Table.from_arrays(inputs, schema=schema) -timezones = ['UTC', 'BET', 'CET', 'Asia/Kathmandu'] +timezones = ["UTC", "BET", "CET", "Asia/Kathmandu"] -class TestArrowTimestampsTimezone(object): +class TestArrowTimestampsTimezone: def test_timestamp_timezone(self, duckdb_cursor): - precisions = ['us', 's', 'ns', 'ms'] + precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59, tzinfo=pytz.UTC) con = duckdb.connect() con.execute("SET TimeZone = 'UTC'") for precision in precisions: - arrow_table = generate_table(current_time, precision, 'UTC') + arrow_table = generate_table(current_time, precision, "UTC") res_utc = con.from_arrow(arrow_table).execute().fetchall() assert res_utc[0][0] == current_time def test_timestamp_timezone_overflow(self, duckdb_cursor): - precisions = ['s', 'ms'] + precisions = ["s", "ms"] current_time = 9223372036854775807 for precision in precisions: - 
with pytest.raises(duckdb.ConversionException, match='Could not convert'): - arrow_table = generate_table(current_time, precision, 'UTC') - res_utc = duckdb.from_arrow(arrow_table).execute().fetchall() + arrow_table = generate_table(current_time, precision, "UTC") + with pytest.raises(duckdb.ConversionException, match="Could not convert"): + duckdb.from_arrow(arrow_table).execute().fetchall() def test_timestamp_tz_to_arrow(self, duckdb_cursor): - precisions = ['us', 's', 'ns', 'ms'] + precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) con = duckdb.connect() for precision in precisions: @@ -44,16 +46,16 @@ def test_timestamp_tz_to_arrow(self, duckdb_cursor): con.execute("SET TimeZone = '" + timezone + "'") arrow_table = generate_table(current_time, precision, timezone) res = con.from_arrow(arrow_table).fetch_arrow_table() - assert res[0].type == pa.timestamp('us', tz=timezone) - assert res == generate_table(current_time, 'us', timezone) + assert res[0].type == pa.timestamp("us", tz=timezone) + assert res == generate_table(current_time, "us", timezone) def test_timestamp_tz_with_null(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') + rel = con.table("t") arrow_tbl = rel.fetch_arrow_table() - con.register('t2', arrow_tbl) + con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() @@ -61,8 +63,8 @@ def test_timestamp_stream(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') + rel = con.table("t") arrow_tbl = rel.record_batch().read_all() - con.register('t2', arrow_tbl) + con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() diff --git a/tests/fast/arrow/test_timestamps.py b/tests/fast/arrow/test_timestamps.py index c2529c83..b00b7982 100644 --- a/tests/fast/arrow/test_timestamps.py +++ b/tests/fast/arrow/test_timestamps.py @@ -1,77 +1,75 @@ -import duckdb -import os import datetime -import pytest + +import duckdb try: import pyarrow as pa - import pandas as pd can_run = True -except: +except Exception: can_run = False -class TestArrowTimestamps(object): +class TestArrowTimestamps: def test_timestamp_types(self, duckdb_cursor): if not can_run: return data = ( - pa.array([datetime.datetime.now()], type=pa.timestamp('ns')), - pa.array([datetime.datetime.now()], type=pa.timestamp('us')), - pa.array([datetime.datetime.now()], pa.timestamp('ms')), - pa.array([datetime.datetime.now()], pa.timestamp('s')), + pa.array([datetime.datetime.now()], type=pa.timestamp("ns")), + pa.array([datetime.datetime.now()], type=pa.timestamp("us")), + pa.array([datetime.datetime.now()], pa.timestamp("ms")), + pa.array([datetime.datetime.now()], pa.timestamp("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['a'] - assert rel['b'] == arrow_table['b'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['d'] + assert rel["a"] == arrow_table["a"] + assert rel["b"] == arrow_table["b"] + assert rel["c"] == 
arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_timestamp_nulls(self, duckdb_cursor): if not can_run: return data = ( - pa.array([None], type=pa.timestamp('ns')), - pa.array([None], type=pa.timestamp('us')), - pa.array([None], pa.timestamp('ms')), - pa.array([None], pa.timestamp('s')), + pa.array([None], type=pa.timestamp("ns")), + pa.array([None], type=pa.timestamp("us")), + pa.array([None], pa.timestamp("ms")), + pa.array([None], pa.timestamp("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['a'] - assert rel['b'] == arrow_table['b'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['d'] + assert rel["a"] == arrow_table["a"] + assert rel["b"] == arrow_table["b"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_timestamp_overflow(self, duckdb_cursor): if not can_run: return data = ( - pa.array([9223372036854775807], pa.timestamp('s')), - pa.array([9223372036854775807], pa.timestamp('ms')), - pa.array([9223372036854775807], pa.timestamp('us')), + pa.array([9223372036854775807], pa.timestamp("s")), + pa.array([9223372036854775807], pa.timestamp("ms")), + pa.array([9223372036854775807], pa.timestamp("us")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ["a", "b", "c"]) arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert arrow_from_duck['a'] == arrow_table['a'] - assert arrow_from_duck['b'] == arrow_table['b'] - assert arrow_from_duck['c'] == arrow_table['c'] + assert arrow_from_duck["a"] == arrow_table["a"] + assert arrow_from_duck["b"] == arrow_table["b"] + assert arrow_from_duck["c"] == arrow_table["c"] expected = (datetime.datetime(9999, 12, 31, 23, 59, 59, 999999),) duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('a::TIMESTAMP_US') + res = duck_rel.project("a::TIMESTAMP_US") result = res.fetchone() assert result == expected duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('b::TIMESTAMP_US') + res = duck_rel.project("b::TIMESTAMP_US") result = res.fetchone() assert result == expected duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('c::TIMESTAMP_NS') + res = duck_rel.project("c::TIMESTAMP_NS") result = res.fetchone() assert result == expected diff --git a/tests/fast/arrow/test_tpch.py b/tests/fast/arrow/test_tpch.py index ff4a0445..cb6024cf 100644 --- a/tests/fast/arrow/test_tpch.py +++ b/tests/fast/arrow/test_tpch.py @@ -1,13 +1,13 @@ import pytest + import duckdb try: import pyarrow import pyarrow.parquet - import numpy as np can_run = True -except: +except Exception: can_run = False @@ -24,7 +24,7 @@ def check_result(result, answers): db_result = result.fetchone() cq_results = q_res.split("|") # The end of the rows, continue - if cq_results == [''] and str(db_result) == 'None' or str(db_result[0]) == 'None': + if (cq_results == [""] and str(db_result) == "None") or str(db_result[0]) == "None": continue ans_result = [munge(cell) for cell in cq_results] db_result = [munge(cell) for cell in db_result] @@ -34,12 +34,12 @@ def check_result(result, answers): @pytest.mark.skip(reason="Test needs to be adapted to missing TPCH extension") -class TestTPCHArrow(object): +class TestTPCHArrow: def 
test_tpch_arrow(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() @@ -69,7 +69,7 @@ def test_tpch_arrow_01(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() @@ -97,7 +97,7 @@ def test_tpch_arrow_batch(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() diff --git a/tests/fast/arrow/test_unregister.py b/tests/fast/arrow/test_unregister.py index c63ef0d6..0aceaea1 100644 --- a/tests/fast/arrow/test_unregister.py +++ b/tests/fast/arrow/test_unregister.py @@ -1,47 +1,36 @@ -import pytest -import tempfile import gc -import duckdb -import os +import tempfile +from pathlib import Path -try: - import pyarrow - import pyarrow.parquet +import pytest + +import duckdb - can_run = True -except: - can_run = False +pyarrow = pytest.importorskip("pyarrow") +pytest.importorskip("pyarrow.parquet") -class TestArrowUnregister(object): +class TestArrowUnregister: def test_arrow_unregister1(self, duckdb_cursor): - if not can_run: - return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' - + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) connection = duckdb.connect(":memory:") connection.register("arrow_table", arrow_table_obj) - arrow_table_2 = connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() + connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.unregister("arrow_table") - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() - with pytest.raises(duckdb.CatalogException, match='View with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="View with name arrow_table does not exist"): connection.execute("DROP VIEW arrow_table;") connection.execute("DROP VIEW IF EXISTS arrow_table;") def test_arrow_unregister2(self, duckdb_cursor): - if not can_run: - return - fd, db = tempfile.mkstemp() - os.close(fd) - os.remove(db) + with tempfile.NamedTemporaryFile() as tmp: + db = tmp.name connection = duckdb.connect(db) - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) connection.register("arrow_table", arrow_table_obj) 
connection.unregister("arrow_table") # Attempting to unregister. @@ -49,7 +38,7 @@ def test_arrow_unregister2(self, duckdb_cursor): # Reconnecting while Arrow Table still in mem. connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.close() del arrow_table_obj @@ -57,6 +46,6 @@ def test_arrow_unregister2(self, duckdb_cursor): # Reconnecting after Arrow Table is freed. connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.close() diff --git a/tests/fast/arrow/test_view.py b/tests/fast/arrow/test_view.py index 54acb336..769e9532 100644 --- a/tests/fast/arrow/test_view.py +++ b/tests/fast/arrow/test_view.py @@ -1,16 +1,16 @@ -import duckdb -import os +from pathlib import Path + import pytest pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") -class TestArrowView(object): +class TestArrowView: def test_arrow_view(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = str(Path(__file__).parent / "data" / "userdata1.parquet") userdata_parquet_table = pa.parquet.read_table(parquet_filename) userdata_parquet_table.validate(full=True) - duckdb_cursor.from_arrow(userdata_parquet_table).create_view('arrow_view') - assert duckdb_cursor.execute("PRAGMA show_tables").fetchone() == ('arrow_view',) + duckdb_cursor.from_arrow(userdata_parquet_table).create_view("arrow_view") + assert duckdb_cursor.execute("PRAGMA show_tables").fetchone() == ("arrow_view",) assert duckdb_cursor.execute("select avg(salary)::INT from arrow_view").fetchone()[0] == 149005 diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index 3735ff6e..66a11f12 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -1,15 +1,16 @@ """The support for scaning over numpy arrays reuses many codes for pandas. Therefore, we only test the new codes and exec paths. 
-""" +""" # noqa: D205 -import sys -import numpy as np -import duckdb from datetime import timedelta + +import numpy as np import pytest +import duckdb + -class TestScanNumpy(object): +class TestScanNumpy: def test_scan_numpy(self, duckdb_cursor): z = np.array([1, 2, 3]) res = duckdb_cursor.sql("select * from z").fetchall() @@ -29,11 +30,11 @@ def test_scan_numpy(self, duckdb_cursor): z = np.array(["zzz", "xxx"]) res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [('zzz',), ('xxx',)] + assert res == [("zzz",), ("xxx",)] z = [np.array(["zzz", "xxx"]), np.array([1, 2])] res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [('zzz', 1), ('xxx', 2)] + assert res == [("zzz", 1), ("xxx", 2)] # test ndarray with dtype = object (python dict) z = [] @@ -42,9 +43,9 @@ def test_scan_numpy(self, duckdb_cursor): z = np.array(z) res = duckdb_cursor.sql("select * from z").fetchall() assert res == [ - ({'3': 0},), - ({'2': 1},), - ({'1': 2},), + ({"3": 0},), + ({"2": 1},), + ({"1": 2},), ] # test timedelta @@ -75,12 +76,12 @@ def test_scan_numpy(self, duckdb_cursor): # dict of mixed types z = {"z": np.array([1, 2, 3]), "x": np.array(["z", "x", "c"])} res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [(1, 'z'), (2, 'x'), (3, 'c')] + assert res == [(1, "z"), (2, "x"), (3, "c")] # list of mixed types z = [np.array([1, 2, 3]), np.array(["z", "x", "c"])] res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [(1, 'z'), (2, 'x'), (3, 'c')] + assert res == [(1, "z"), (2, "x"), (3, "c")] # currently unsupported formats, will throw duckdb.InvalidInputException diff --git a/tests/fast/pandas/test_2304.py b/tests/fast/pandas/test_2304.py index 6fc355e5..c60b1b4a 100644 --- a/tests/fast/pandas/test_2304.py +++ b/tests/fast/pandas/test_2304.py @@ -1,41 +1,43 @@ -import duckdb import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasMergeSameName(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestPandasMergeSameName: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_2304(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id_1': [1, 1, 1, 2, 2], - 'agedate': np.array(['2010-01-01', '2010-02-01', '2010-03-01', '2020-02-01', '2020-03-01']).astype( - 'datetime64[D]' + "id_1": [1, 1, 1, 2, 2], + "agedate": np.array(["2010-01-01", "2010-02-01", "2010-03-01", "2020-02-01", "2020-03-01"]).astype( + "datetime64[D]" ), - 'age': [1, 2, 3, 1, 2], - 'v': [1.1, 1.2, 1.3, 2.1, 2.2], + "age": [1, 2, 3, 1, 2], + "v": [1.1, 1.2, 1.3, 2.1, 2.2], } ) df2 = pandas.DataFrame( { - 'id_1': [1, 1, 2], - 'agedate': np.array(['2010-01-01', '2010-02-01', '2020-03-01']).astype('datetime64[D]'), - 'v2': [11.1, 11.2, 21.2], + "id_1": [1, 1, 2], + "agedate": np.array(["2010-01-01", "2010-02-01", "2020-03-01"]).astype("datetime64[D]"), + "v2": [11.1, 11.2, 21.2], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) query = """SELECT * from df1 LEFT OUTER JOIN df2 - ON (df1.id_1=df2.id_1 and df1.agedate=df2.agedate) order by df1.id_1, df1.agedate, df1.age, df1.v, df2.id_1,df2.agedate,df2.v2""" + ON (df1.id_1=df2.id_1 and df1.agedate=df2.agedate) + order by df1.id_1, df1.agedate, df1.age, df1.v, df2.id_1,df2.agedate,df2.v2""" result_df = con.execute(query).fetchdf() expected_result = con.execute(query).fetchall() - 
con.register('result_df', result_df) + con.register("result_df", result_df) rel = con.sql( """ select * from result_df order by @@ -52,32 +54,32 @@ def test_2304(self, duckdb_cursor, pandas): assert result == expected_result - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pd_names(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1, 1, 2], - 'id_1': [1, 1, 2], - 'id_3': [1, 1, 2], + "id": [1, 1, 2], + "id_1": [1, 1, 2], + "id_3": [1, 1, 2], } ) - df2 = pandas.DataFrame({'id': [1, 1, 2], 'id_1': [1, 1, 2], 'id_2': [1, 1, 1]}) + df2 = pandas.DataFrame({"id": [1, 1, 2], "id_1": [1, 1, 2], "id_2": [1, 1, 1]}) exp_result = pandas.DataFrame( { - 'id': [1, 1, 2, 1, 1], - 'id_1': [1, 1, 2, 1, 1], - 'id_3': [1, 1, 2, 1, 1], - 'id_2': [1, 1, 2, 1, 1], - 'id_1_1': [1, 1, 2, 1, 1], - 'id_2_1': [1, 1, 1, 1, 1], + "id": [1, 1, 2, 1, 1], + "id_1": [1, 1, 2, 1, 1], + "id_3": [1, 1, 2, 1, 1], + "id_2": [1, 1, 2, 1, 1], + "id_1_1": [1, 1, 2, 1, 1], + "id_2_1": [1, 1, 1, 1, 1], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) query = """SELECT * from df1 LEFT OUTER JOIN df2 ON (df1.id_1=df2.id_1)""" @@ -85,30 +87,30 @@ def test_pd_names(self, duckdb_cursor, pandas): result_df = con.execute(query).fetchdf() pandas.testing.assert_frame_equal(exp_result, result_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_repeat_name(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1], - 'id_1': [1], - 'id_2': [1], + "id": [1], + "id_1": [1], + "id_2": [1], } ) - df2 = pandas.DataFrame({'id': [1]}) + df2 = pandas.DataFrame({"id": [1]}) exp_result = pandas.DataFrame( { - 'id': [1], - 'id_1': [1], - 'id_2': [1], - 'id_3': [1], + "id": [1], + "id_1": [1], + "id_2": [1], + "id_3": [1], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) result_df = con.execute( """ diff --git a/tests/fast/pandas/test_append_df.py b/tests/fast/pandas/test_append_df.py index 18805a5a..d93cfa2d 100644 --- a/tests/fast/pandas/test_append_df.py +++ b/tests/fast/pandas/test_append_df.py @@ -1,38 +1,39 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestAppendDF(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestAppendDF: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_to_table_append(self, duckdb_cursor, pandas): conn = duckdb.connect() conn.execute("Create table integers (i integer)") df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) - conn.append('integers', df_in) - assert conn.execute('select count(*) from integers').fetchone()[0] == 5 + conn.append("integers", df_in) + assert conn.execute("select count(*) from integers").fetchone()[0] == 5 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool, c varchar)") - df_in = pandas.DataFrame({'c': ['duck', 'db'], 'b': [False, True], 'a': [4, 2]}) + df_in = pandas.DataFrame({"c": ["duck", "db"], "b": [False, 
True], "a": [4, 2]}) # By default we append by position, causing the following exception: with pytest.raises( duckdb.ConversionException, match="Conversion Error: Could not convert string 'duck' to INT32" ): - con.append('tbl', df_in) + con.append("tbl", df_in) # When we use 'by_name' we instead append by name - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() - assert res == [(4, False, 'duck'), (2, True, 'db')] + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() + assert res == [(4, False, "duck"), (2, True, "db")] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name_quoted(self, pandas): con = duckdb.connect() con.execute( @@ -41,32 +42,32 @@ def test_append_by_name_quoted(self, pandas): """ ) df_in = pandas.DataFrame({"needs to be quoted": [1, 2, 3]}) - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() assert res == [(1, None), (2, None), (3, None)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name_no_exact_match(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool)") - df_in = pandas.DataFrame({'c': ['a', 'b'], 'b': [True, False], 'a': [42, 1337]}) + df_in = pandas.DataFrame({"c": ["a", "b"], "b": [True, False], "a": [42, 1337]}) # Too many columns raises an error, because the columns cant be found in the targeted table with pytest.raises(duckdb.BinderException, match='Table "tbl" does not have a column with name "c"'): - con.append('tbl', df_in, by_name=True) + con.append("tbl", df_in, by_name=True) - df_in = pandas.DataFrame({'b': [False, False, False]}) + df_in = pandas.DataFrame({"b": [False, False, False]}) # Not matching all columns is not a problem, as they will be filled with NULL instead - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() # 'a' got filled by NULL automatically because it wasn't inserted into assert res == [(None, False), (None, False), (None, False)] # Empty the table con.execute("create or replace table tbl (a integer, b bool)") - df_in = pandas.DataFrame({'a': [1, 2, 3]}) - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + df_in = pandas.DataFrame({"a": [1, 2, 3]}) + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() # Also works for missing columns *after* the supplied ones assert res == [(1, None), (2, None), (3, None)] diff --git a/tests/fast/pandas/test_bug2281.py b/tests/fast/pandas/test_bug2281.py index 703baf4b..a85517ba 100644 --- a/tests/fast/pandas/test_bug2281.py +++ b/tests/fast/pandas/test_bug2281.py @@ -1,18 +1,15 @@ -import duckdb -import os -import datetime -import pytest -import pandas as pd import io +import pandas as pd + -class TestPandasStringNull(object): +class TestPandasStringNull: def test_pandas_string_null(self, duckdb_cursor): - csv = u'''what,is_control,is_test + csv = """what,is_control,is_test ,0,0 -foo,1,0''' +foo,1,0""" df = pd.read_csv(io.StringIO(csv)) duckdb_cursor.register("c", df) - duckdb_cursor.execute('select what, count(*) from c group by what') - df_result = duckdb_cursor.fetchdf() + duckdb_cursor.execute("select what, count(*) from c group by what") + 
duckdb_cursor.fetchdf() assert True # Should not crash ^^ diff --git a/tests/fast/pandas/test_bug5922.py b/tests/fast/pandas/test_bug5922.py index af9be167..b75ddf1b 100644 --- a/tests/fast/pandas/test_bug5922.py +++ b/tests/fast/pandas/test_bug5922.py @@ -1,16 +1,17 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasAcceptFloat16(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestPandasAcceptFloat16: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_accept_float16(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'col': [1, 2, 3]}) - df16 = df.astype({'col': 'float16'}) + df = pandas.DataFrame({"col": [1, 2, 3]}) + df16 = df.astype({"col": "float16"}) # noqa: F841 con = duckdb.connect() - con.execute('CREATE TABLE tbl AS SELECT * FROM df16') - con.execute('select * from tbl') + con.execute("CREATE TABLE tbl AS SELECT * FROM df16") + con.execute("select * from tbl") df_result = con.fetchdf() - df32 = df.astype({'col': 'float32'}) - assert (df32['col'] == df_result['col']).all() + df32 = df.astype({"col": "float32"}) + assert (df32["col"] == df_result["col"]).all() diff --git a/tests/fast/pandas/test_copy_on_write.py b/tests/fast/pandas/test_copy_on_write.py index dc484f1b..176c2133 100644 --- a/tests/fast/pandas/test_copy_on_write.py +++ b/tests/fast/pandas/test_copy_on_write.py @@ -1,9 +1,11 @@ -import duckdb +import datetime + import pytest +import duckdb + # https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html -pandas = pytest.importorskip('pandas', '1.5', reason='copy_on_write does not exist in earlier versions') -import datetime +pandas = pytest.importorskip("pandas", "1.5", reason="copy_on_write does not exist in earlier versions") # Make sure the variable get's properly reset even in case of error @@ -21,11 +23,11 @@ def convert_to_result(col): return [(x,) for x in col] -class TestCopyOnWrite(object): +class TestCopyOnWrite: @pytest.mark.parametrize( - 'col', + "col", [ - ['a', 'b', 'this is a long string'], + ["a", "b", "this is a long string"], [1.2334, None, 234.12], [123234, -213123, 2324234], [datetime.date(1990, 12, 7), None, datetime.date(1940, 1, 13)], @@ -33,14 +35,14 @@ class TestCopyOnWrite(object): ], ) def test_copy_on_write(self, col): - assert pandas.options.mode.copy_on_write == True + assert pandas.options.mode.copy_on_write con = duckdb.connect() - df_in = pandas.DataFrame( + df_in = pandas.DataFrame( # noqa: F841 { - 'numbers': col, + "numbers": col, } ) - rel = con.sql('select * from df_in') + rel = con.sql("select * from df_in") res = rel.fetchall() print(res) expected = convert_to_result(col) diff --git a/tests/fast/pandas/test_create_table_from_pandas.py b/tests/fast/pandas/test_create_table_from_pandas.py index 69234dc7..436fd0c8 100644 --- a/tests/fast/pandas/test_create_table_from_pandas.py +++ b/tests/fast/pandas/test_create_table_from_pandas.py @@ -1,13 +1,12 @@ import pytest +from conftest import ArrowPandas, NumpyPandas + import duckdb -import numpy as np -import sys -from conftest import NumpyPandas, ArrowPandas def assert_create(internal_data, expected_result, data_type, pandas): conn = duckdb.connect() - df_in = pandas.DataFrame(data=internal_data, dtype=data_type) + df_in = pandas.DataFrame(data=internal_data, dtype=data_type) # noqa: F841 conn.execute("CREATE TABLE t AS SELECT * FROM df_in") @@ -25,13 +24,11 @@ def assert_create_register(internal_data, 
expected_result, data_type, pandas): assert result == expected_result -class TestCreateTableFromPandas(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestCreateTableFromPandas: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_integer_create_table(self, duckdb_cursor, pandas): - if sys.version_info.major < 3: - return - # FIXME: This should work with other data types e.g., int8... - data_types = ['Int8', 'Int16', 'Int32', 'Int64'] + # TODO: This should work with other data types e.g., int8... # noqa: TD002, TD003 + data_types = ["Int8", "Int16", "Int32", "Int64"] internal_data = [1, 2, 3, 4] expected_result = [(1,), (2,), (3,), (4,)] for data_type in data_types: @@ -39,4 +36,4 @@ def test_integer_create_table(self, duckdb_cursor, pandas): assert_create_register(internal_data, expected_result, data_type, pandas) assert_create(internal_data, expected_result, data_type, pandas) - # FIXME: Also test other data types + # TODO: Also test other data types # noqa: TD002, TD003 diff --git a/tests/fast/pandas/test_date_as_datetime.py b/tests/fast/pandas/test_date_as_datetime.py index 038f24a8..484674ea 100644 --- a/tests/fast/pandas/test_date_as_datetime.py +++ b/tests/fast/pandas/test_date_as_datetime.py @@ -1,13 +1,14 @@ +import datetime + import pandas as pd + import duckdb -import datetime -import pytest def run_checks(df): - assert type(df['d'][0]) is datetime.date - assert df['d'][0] == datetime.date(1992, 7, 30) - assert pd.isnull(df['d'][1]) + assert type(df["d"][0]) is datetime.date + assert df["d"][0] == datetime.date(1992, 7, 30) + assert pd.isnull(df["d"][1]) def test_date_as_datetime(): @@ -22,7 +23,7 @@ def test_date_as_datetime(): run_checks(con.execute("Select * from t").fetch_df(date_as_object=True)) # Relation Methods - rel = con.table('t') + rel = con.table("t") run_checks(rel.df(date_as_object=True)) run_checks(rel.to_df(date_as_object=True)) diff --git a/tests/fast/pandas/test_datetime_time.py b/tests/fast/pandas/test_datetime_time.py index cda96e6b..0b2642b0 100644 --- a/tests/fast/pandas/test_datetime_time.py +++ b/tests/fast/pandas/test_datetime_time.py @@ -1,31 +1,33 @@ -import duckdb +from datetime import datetime, time, timezone + import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas -from datetime import datetime, timezone, time, timedelta +from conftest import ArrowPandas, NumpyPandas + +import duckdb _ = pytest.importorskip("pandas", minversion="2.0.0") -class TestDateTimeTime(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestDateTimeTime: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_high(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(23, 1, 34.234345) AS '0'").df() data = [time(hour=23, minute=1, second=34, microsecond=234345)] - df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_low(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(00, 01, 1.000) AS '0'").df() data = [time(hour=0, minute=1, second=1)] - df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + 
df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('input', ['2263-02-28', '9999-01-01']) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("input", ["2263-02-28", "9999-01-01"]) def test_pandas_datetime_big(self, pandas, input): duckdb_con = duckdb.connect() @@ -33,8 +35,8 @@ def test_pandas_datetime_big(self, pandas, input): duckdb_con.execute(f"INSERT INTO TEST VALUES ('{input}')") res = duckdb_con.execute("select * from test").df() - date_value = np.array([f'{input}'], dtype='datetime64[us]') - df = pandas.DataFrame({'date': date_value}) + date_value = np.array([f"{input}"], dtype="datetime64[us]") + df = pandas.DataFrame({"date": date_value}) pandas.testing.assert_frame_equal(res, df) def test_timezone_datetime(self): @@ -45,6 +47,6 @@ def test_timezone_datetime(self): original = dt stringified = str(dt) - original_res = con.execute('select ?::TIMESTAMPTZ', [original]).fetchone() - stringified_res = con.execute('select ?::TIMESTAMPTZ', [stringified]).fetchone() + original_res = con.execute("select ?::TIMESTAMPTZ", [original]).fetchone() + stringified_res = con.execute("select ?::TIMESTAMPTZ", [stringified]).fetchone() assert original_res == stringified_res diff --git a/tests/fast/pandas/test_datetime_timestamp.py b/tests/fast/pandas/test_datetime_timestamp.py index e3b26501..c6d4e3a9 100644 --- a/tests/fast/pandas/test_datetime_timestamp.py +++ b/tests/fast/pandas/test_datetime_timestamp.py @@ -1,29 +1,28 @@ -import duckdb import datetime -import numpy as np + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas from packaging.version import Version pd = pytest.importorskip("pandas") -class TestDateTimeTimeStamp(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestDateTimeTimeStamp: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_high(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql("SELECT '2260-01-01 23:59:00'::TIMESTAMP AS '0'").df() - df_in = pandas.DataFrame( + df_in = pandas.DataFrame( # noqa: F841 { 0: pandas.Series( data=[datetime.datetime(year=2260, month=1, day=1, hour=23, minute=59)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) } ) df_out = duckdb_cursor.sql("select * from df_in").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_low(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -32,27 +31,27 @@ def test_timestamp_low(self, pandas, duckdb_cursor): ).df() df_in = pandas.DataFrame( { - '0': pandas.Series( + "0": pandas.Series( data=[ pandas.Timestamp( datetime.datetime(year=1680, month=1, day=1, hour=23, minute=59, microsecond=234243), - unit='us', + unit="us", ) ], - dtype='datetime64[us]', + dtype="datetime64[us]", ) } ) - print('original:', duckdb_time['0'].dtype) - print('df_in:', df_in['0'].dtype) + print("original:", duckdb_time["0"].dtype) + print("df_in:", df_in["0"].dtype) df_out = duckdb_cursor.sql("select * from df_in").df() - print('df_out:', df_out['0'].dtype) + print("df_out:", df_out["0"].dtype) pandas.testing.assert_frame_equal(df_out, 
duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -62,10 +61,10 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): offset = datetime.timedelta(hours=-2) timezone = datetime.timezone(offset) - df_in = pandas.DataFrame( + df_in = pandas.DataFrame( # noqa: F841 { 0: pandas.Series( - data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype="object" ) } ) @@ -75,9 +74,9 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -88,10 +87,10 @@ def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): offset = datetime.timedelta(hours=-19) timezone = datetime.timezone(offset) - df_in = pandas.DataFrame( + df_in = pandas.DataFrame( # noqa: F841 { 0: pandas.Series( - data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype="object" ) } ) @@ -99,9 +98,9 @@ def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -112,10 +111,10 @@ def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): offset = datetime.timedelta(hours=14) timezone = datetime.timezone(offset) - df_in = pandas.DataFrame( + df_in = pandas.DataFrame( # noqa: F841 { 0: pandas.Series( - data=[datetime.datetime(year=2021, month=12, day=31, hour=23, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2021, month=12, day=31, hour=23, tzinfo=timezone)], dtype="object" ) } ) @@ -123,16 +122,16 @@ def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('unit', ['ms', 'ns', 
's']) + @pytest.mark.parametrize("unit", ["ms", "ns", "s"]) def test_timestamp_timezone_coverage(self, unit, duckdb_cursor): pd = pytest.importorskip("pandas") - ts_df = pd.DataFrame( - {'ts': pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype=f'datetime64[{unit}]')} + ts_df = pd.DataFrame( # noqa: F841 + {"ts": pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype=f"datetime64[{unit}]")} ) - usecond_df = pd.DataFrame( - {'ts': pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype='datetime64[us]')} + usecond_df = pd.DataFrame( # noqa: F841 + {"ts": pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype="datetime64[us]")} ) query = """ @@ -142,12 +141,12 @@ def test_timestamp_timezone_coverage(self, unit, duckdb_cursor): """ duckdb_cursor.sql("set TimeZone = 'UTC'") - utc_usecond = duckdb_cursor.sql(query.format('usecond_df')).df() - utc_other = duckdb_cursor.sql(query.format('ts_df')).df() + utc_usecond = duckdb_cursor.sql(query.format("usecond_df")).df() + utc_other = duckdb_cursor.sql(query.format("ts_df")).df() duckdb_cursor.sql("set TimeZone = 'America/Los_Angeles'") - us_usecond = duckdb_cursor.sql(query.format('usecond_df')).df() - us_other = duckdb_cursor.sql(query.format('ts_df')).df() + us_usecond = duckdb_cursor.sql(query.format("usecond_df")).df() + us_other = duckdb_cursor.sql(query.format("ts_df")).df() pd.testing.assert_frame_equal(utc_usecond, utc_other) pd.testing.assert_frame_equal(us_usecond, us_other) diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 114f8e3f..96cd426d 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -1,16 +1,16 @@ -import duckdb -import datetime import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def create_generic_dataframe(data, pandas): - return pandas.DataFrame({'col0': pandas.Series(data=data, dtype='object')}) + return pandas.DataFrame({"col0": pandas.Series(data=data, dtype="object")}) -class TestResolveObjectColumns(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestResolveObjectColumns: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_correct(self, duckdb_cursor, pandas): print(pandas.backend) duckdb_conn = duckdb.connect() @@ -21,7 +21,7 @@ def test_sample_low_correct(self, duckdb_cursor, pandas): duckdb_df = duckdb_conn.query("select * FROM (VALUES (1000008), (6), (9), (4), (1), (6)) as '0'").df() pandas.testing.assert_frame_equal(duckdb_df, roundtripped_df, check_dtype=False) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=2") @@ -31,9 +31,9 @@ def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): df = create_generic_dataframe(data, pandas) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() # Sample high enough to detect mismatch in types, fallback to VARCHAR - assert roundtripped_df['col0'].dtype == np.dtype('object') + assert roundtripped_df["col0"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def 
test_sample_zero(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() # Disable dataframe analyze @@ -42,12 +42,12 @@ def test_sample_zero(self, duckdb_cursor, pandas): df = create_generic_dataframe(data, pandas) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() # Always converts to VARCHAR - if pandas.backend == 'pyarrow': - assert roundtripped_df['col0'].dtype == np.dtype('int64') + if pandas.backend == "pyarrow": + assert roundtripped_df["col0"].dtype == np.dtype("int64") else: - assert roundtripped_df['col0'].dtype == np.dtype('object') + assert roundtripped_df["col0"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_incorrect_undetected(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=1") @@ -55,7 +55,7 @@ def test_sample_low_incorrect_undetected(self, duckdb_cursor, pandas): df = create_generic_dataframe(data, pandas) # Sample size is too low to detect the mismatch, exception is raised when trying to convert with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"): - roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() + duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() def test_reset_analyze_sample_setting(self, duckdb_cursor): duckdb_cursor.execute("SET pandas_analyze_sample=5") @@ -65,10 +65,10 @@ def test_reset_analyze_sample_setting(self, duckdb_cursor): res = duckdb_cursor.execute("select current_setting('pandas_analyze_sample')").fetchall() assert res == [(1000,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_10750(self, duckdb_cursor, pandas): max_row_number = 2000 - data = {'id': [i for i in range(max_row_number + 1)], 'content': [None for _ in range(max_row_number + 1)]} + data = {"id": list(range(max_row_number + 1)), "content": [None for _ in range(max_row_number + 1)]} pdf = pandas.DataFrame(data=data) duckdb_cursor.register("content", pdf) diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index ed89f324..58ae0c94 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ b/tests/fast/pandas/test_df_object_resolution.py @@ -1,19 +1,22 @@ -import duckdb +# ruff: noqa: F841 import datetime -import numpy as np -import platform -import pytest import decimal import math -from decimal import Decimal +import platform import re -from conftest import NumpyPandas, ArrowPandas +from decimal import Decimal + +import numpy as np +import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb standard_vector_size = duckdb.__standard_vector_size__ def create_generic_dataframe(data, pandas): - return pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + return pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) def create_repeated_nulls(size): @@ -25,15 +28,15 @@ def create_repeated_nulls(size): def create_trailing_non_null(size): data = [None for _ in range(size - 1)] - data.append('this is a long string') + data.append("this is a long string") return data class IntString: - def __init__(self, value: int): + def __init__(self, value: int) -> None: self.value = value - def __str__(self): + def __str__(self) -> str: return 
str(self.value) @@ -43,16 +46,16 @@ def ConvertStringToDecimal(data: list, pandas): for i in range(len(data)): if isinstance(data[i], str): data[i] = decimal.Decimal(data[i]) - data = pandas.Series(data=data, dtype='object') + data = pandas.Series(data=data, dtype="object") return data class ObjectPair: - def __init__(self, obj1, obj2): + def __init__(self, obj1, obj2) -> None: self.first = obj1 self.second = obj2 - def __repr__(self): + def __repr__(self) -> str: return str([self.first, self.second]) @@ -61,13 +64,13 @@ def construct_list(pair): def construct_struct(pair): - return [{'v1': pair.first}, {'v1': pair.second}] + return [{"v1": pair.first}, {"v1": pair.second}] def construct_map(pair): return [ - {'key': ['v1', 'v2'], "value": [pair.first, pair.first]}, - {'key': ['v1', 'v2'], "value": [pair.second, pair.second]}, + {"key": ["v1", "v2"], "value": [pair.first, pair.first]}, + {"key": ["v1", "v2"], "value": [pair.second, pair.second]}, ] @@ -81,159 +84,160 @@ def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, assert expected_type == rel.types[0] -class TestResolveObjectColumns(object): - # TODO: add support for ArrowPandas - @pytest.mark.parametrize('pandas', [NumpyPandas()]) +class TestResolveObjectColumns: + # TODO: add support for ArrowPandas # noqa: TD002, TD003 + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integers(self, pandas, duckdb_cursor): data = [5, 0, 3] df_in = create_generic_dataframe(data, pandas) - # These are float64 because pandas would force these to be float64 even if we set them to int8, int16, int32, int64 respectively - df_expected_res = pandas.DataFrame({'0': pandas.Series(data=data, dtype='int32')}) + # These are float64 because pandas would force these to be float64 even if we set them to int8, int16, + # int32, int64 respectively + df_expected_res = pandas.DataFrame({"0": pandas.Series(data=data, dtype="int32")}) df_out = duckdb_cursor.sql("SELECT * FROM df_in").df() print(df_out) pandas.testing.assert_frame_equal(df_expected_res, df_out) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_correct(self, pandas, duckdb_cursor): - data = [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] - df = pandas.DataFrame({'0': pandas.Series(data=data)}) + data = [{"a": 1, "b": 3, "c": 3, "d": 7}] + df = pandas.DataFrame({"0": pandas.Series(data=data)}) duckdb_col = duckdb_cursor.sql("SELECT {a: 1, b: 3, c: 3, d: 7} as '0'").df() converted_col = duckdb_cursor.sql("SELECT * FROM df").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_different_keys(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'e': 7}], #'e' instead of 'd' as key - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "e": 7}], #'e' instead of 'd' as key + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() y = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'e'], 'value': 
[1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "e"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_incorrect_amount_of_keys(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3}], # incorrect amount of keys - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3}], # incorrect amount of keys + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() y = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c'], 'value': [1, 3, 3]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c"], "value": [1, 3, 3]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'string'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": "string"}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'string'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], + [{"a": 1, "b": 3, "c": 3, "d": "string"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_null(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ [None], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 
7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ [None], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'test'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": "test"}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ - [{'a': '1', 'b': '3', 'c': '3', 'd': 'test'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], - [{'a': '1', 'b': '3', 'c': '3'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], + [{"a": "1", "b": "3", "c": "3", "d": "test"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], + [{"a": "1", "b": "3", "c": "3"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_correct(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x as 'a'").df() duckdb_cursor.sql( """ @@ -253,10 +257,10 @@ def test_map_correct(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('sample_size', [1, 10]) - @pytest.mark.parametrize('fill', [1000, 10000]) - @pytest.mark.parametrize('get_data', [create_repeated_nulls, create_trailing_non_null]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("sample_size", [1, 10]) + @pytest.mark.parametrize("fill", [1000, 10000]) + @pytest.mark.parametrize("get_data", [create_repeated_nulls, 
create_trailing_non_null]) def test_analyzing_nulls(self, pandas, duckdb_cursor, fill, sample_size, get_data): data = get_data(fill) df1 = pandas.DataFrame(data={"col1": data}) @@ -265,9 +269,9 @@ def test_analyzing_nulls(self, pandas, duckdb_cursor, fill, sample_size, get_dat pandas.testing.assert_frame_equal(df1, df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_nested_map(self, pandas, duckdb_cursor): - df = pandas.DataFrame(data={'col1': [{'a': {'b': {'x': 'A', 'y': 'B'}}}, {'c': {'b': {'x': 'A'}}}]}) + df = pandas.DataFrame(data={"col1": [{"a": {"b": {"x": "A", "y": "B"}}}, {"c": {"b": {"x": "A"}}}]}) rel = duckdb_cursor.sql("select * from df") expected_rel = duckdb_cursor.sql( @@ -283,18 +287,18 @@ def test_nested_map(self, pandas, duckdb_cursor): expected_res = str(expected_rel) assert res == expected_res - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 'test']}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, "test"]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql( """ @@ -319,69 +323,66 @@ def test_map_value_upgrade(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_duplicate(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': ['a', 'a', 'b'], 'value': [4, 0, 4]}]]) - with pytest.raises(duckdb.InvalidInputException, match="Map keys must be unique."): + x = pandas.DataFrame([[{"key": ["a", "a", "b"], "value": [4, 0, 4]}]]) + with pytest.raises(duckdb.InvalidInputException, match="Map keys must be unique"): duckdb_cursor.sql("select * from x").show() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': [None, 'a', 'b'], 'value': [4, 0, 4]}]]) - with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): + x = pandas.DataFrame([[{"key": [None, "a", "b"], "value": [4, 0, 4]}]]) + with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_nullkeylist(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': None, 'value': None}]]) + x = pandas.DataFrame([[{"key": None, "value": None}]]) converted_col = duckdb_cursor.sql("select * from 
x").df() duckdb_col = duckdb_cursor.sql("SELECT MAP(NULL, NULL) as '0'").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'a': 4, None: 0, 'c': 4}], [{'a': 4, None: 0, 'd': 4}]]) - with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): + x = pandas.DataFrame([[{"a": 4, None: 0, "c": 4}], [{"a": 4, None: 0, "d": 4}]]) + with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_nullkey_coverage(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': None, 'value': None}], - [{'key': None, None: 5}], + [{"key": None, "value": None}], + [{"key": None, None: 5}], ] ) - with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): + with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_structs_in_nested_types(self, pandas, duckdb_cursor): # This test is testing a bug that occurred when type upgrades occurred inside nested types # STRUCT(key1 varchar) + STRUCT(key1 varchar, key2 varchar) turns into MAP # But when inside a nested structure, this upgrade did not happen properly pairs = { - 'v1': ObjectPair({'key1': 21}, {'key1': 21, 'key2': 42}), - 'v2': ObjectPair({'key1': 21}, {'key2': 21}), - 'v3': ObjectPair({'key1': 21, 'key2': 42}, {'key1': 21}), - 'v4': ObjectPair({}, {'key1': 21}), + "v1": ObjectPair({"key1": 21}, {"key1": 21, "key2": 42}), + "v2": ObjectPair({"key1": 21}, {"key2": 21}), + "v3": ObjectPair({"key1": 21, "key2": 42}, {"key1": 21}), + "v4": ObjectPair({}, {"key1": 21}), } - for _, pair in pairs.items(): - check_struct_upgrade('MAP(VARCHAR, INTEGER)[]', construct_list, pair, pandas, duckdb_cursor) + for pair in pairs.values(): + check_struct_upgrade("MAP(VARCHAR, INTEGER)[]", construct_list, pair, pandas, duckdb_cursor) for key, pair in pairs.items(): - if key == 'v4': - expected_type = 'MAP(VARCHAR, MAP(VARCHAR, INTEGER))' - else: - expected_type = 'STRUCT(v1 MAP(VARCHAR, INTEGER))' + expected_type = "MAP(VARCHAR, MAP(VARCHAR, INTEGER))" if key == "v4" else "STRUCT(v1 MAP(VARCHAR, INTEGER))" check_struct_upgrade(expected_type, construct_struct, pair, pandas, duckdb_cursor) - for key, pair in pairs.items(): - check_struct_upgrade('MAP(VARCHAR, MAP(VARCHAR, INTEGER))', construct_map, pair, pandas, duckdb_cursor) + for pair in pairs.values(): + check_struct_upgrade("MAP(VARCHAR, MAP(VARCHAR, INTEGER))", construct_map, pair, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_structs_of_different_sizes(self, pandas, duckdb_cursor): # This list has both a STRUCT(v1) and a STRUCT(v1, v2) member # Those can't be combined @@ -404,9 +405,9 @@ def test_structs_of_different_sizes(self, pandas, duckdb_cursor): ) res = duckdb_cursor.query("select typeof(col) from df").fetchall() # So we fall back to 
converting them as VARCHAR instead - assert res == [('MAP(VARCHAR, VARCHAR)[]',), ('MAP(VARCHAR, VARCHAR)[]',)] + assert res == [("MAP(VARCHAR, VARCHAR)[]",), ("MAP(VARCHAR, VARCHAR)[]",)] - malformed_struct = duckdb.Value({"v1": 1, "v2": 2}, duckdb.struct_type({'v1': int})) + malformed_struct = duckdb.Value({"v1": 1, "v2": 2}, duckdb.struct_type({"v1": int})) with pytest.raises( duckdb.InvalidInputException, match=re.escape( @@ -414,9 +415,8 @@ def test_structs_of_different_sizes(self, pandas, duckdb_cursor): ), ): res = duckdb_cursor.execute("select $1", [malformed_struct]) - print(res) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_key_conversion(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ @@ -428,48 +428,48 @@ def test_struct_key_conversion(self, pandas, duckdb_cursor): duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_correct(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [[5], [34], [-245]]}]) + x = pandas.DataFrame([{"0": [[5], [34], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], [34], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_contains_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [[5], None, [-245]]}]) + x = pandas.DataFrame([{"0": [[5], None, [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], NULL, [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_starts_with_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [None, [5], [-245]]}]) + x = pandas.DataFrame([{"0": [None, [5], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [NULL, [5], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [['5'], [34], [-245]]}]) + x = pandas.DataFrame([{"0": [["5"], [34], [-245]]}]) duckdb_rel = duckdb_cursor.sql("select [['5'], ['34'], ['-245']] as '0'") duckdb_col = duckdb_rel.df() converted_col = duckdb_cursor.sql("select * from x").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_column_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ [[1, 25, 300]], [[500, 345, 30]], - [[50, 'a', 67]], + [[50, "a", 67]], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = 
duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql( """ @@ -498,29 +498,29 @@ def test_list_column_value_upgrade(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_ubigint_object_conversion(self, pandas, duckdb_cursor): # UBIGINT + TINYINT would result in HUGEINT, but conversion to HUGEINT is not supported yet from pandas->duckdb # So this instead becomes a DOUBLE data = [18446744073709551615, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - if pandas.backend == 'numpy_nullable': - float64 = np.dtype('float64') - assert isinstance(converted_col['0'].dtype, float64.__class__) == True + if pandas.backend == "numpy_nullable": + float64 = np.dtype("float64") + assert isinstance(converted_col["0"].dtype, float64.__class__) else: - uint64 = np.dtype('uint64') - assert isinstance(converted_col['0'].dtype, uint64.__class__) == True + uint64 = np.dtype("uint64") + assert isinstance(converted_col["0"].dtype, uint64.__class__) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_double_object_conversion(self, pandas, duckdb_cursor): data = [18446744073709551616, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - double_dtype = np.dtype('float64') - assert isinstance(converted_col['0'].dtype, double_dtype.__class__) == True + double_dtype = np.dtype("float64") + assert isinstance(converted_col["0"].dtype, double_dtype.__class__) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="older numpy raises a warning when running with Pyodide", @@ -551,51 +551,51 @@ def test_numpy_object_with_stride(self, pandas, duckdb_cursor): (9, 18, 0), ] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numpy_stringliterals(self, pandas, duckdb_cursor): df = pandas.DataFrame({"x": list(map(np.str_, range(3)))}) res = duckdb_cursor.execute("select * from df").fetchall() - assert res == [('0',), ('1',), ('2',)] + assert res == [("0",), ("1",), ("2",)] - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integer_conversion_fail(self, pandas, duckdb_cursor): data = [2**10000, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - print(converted_col['0']) - double_dtype = np.dtype('object') - assert isinstance(converted_col['0'].dtype, double_dtype.__class__) == True + print(converted_col["0"]) + double_dtype = np.dtype("object") + assert isinstance(converted_col["0"].dtype, double_dtype.__class__) # Most of the time numpy.datetime64 is just a wrapper around a datetime.datetime object # But to support arbitrary precision, it can fall back to using an `int` internally - 
@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) # Which we don't support yet + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) # Which we don't support yet def test_numpy_datetime(self, pandas, duckdb_cursor): numpy = pytest.importorskip("numpy") data = [] - data += [numpy.datetime64('2022-12-10T21:38:24.578696')] * standard_vector_size - data += [numpy.datetime64('2022-02-21T06:59:23.324812')] * standard_vector_size - data += [numpy.datetime64('1974-06-05T13:12:01.000000')] * standard_vector_size - data += [numpy.datetime64('2049-01-13T00:24:31.999999')] * standard_vector_size - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + data += [numpy.datetime64("2022-12-10T21:38:24.578696")] * standard_vector_size + data += [numpy.datetime64("2022-02-21T06:59:23.324812")] * standard_vector_size + data += [numpy.datetime64("1974-06-05T13:12:01.000000")] * standard_vector_size + data += [numpy.datetime64("2049-01-13T00:24:31.999999")] * standard_vector_size + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() - assert len(res['dates'].__array__()) == 4 + assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_numpy_datetime_int_internally(self, pandas, duckdb_cursor): numpy = pytest.importorskip("numpy") - data = [numpy.datetime64('2022-12-10T21:38:24.0000000000001')] - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + data = [numpy.datetime64("2022-12-10T21:38:24.0000000000001")] + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) with pytest.raises( duckdb.ConversionException, match=re.escape("Conversion Error: Unimplemented type for cast (BIGINT -> TIMESTAMP)"), ): rel = duckdb.query_df(x, "x", "create table dates as select dates::TIMESTAMP WITHOUT TIME ZONE from x") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ @@ -605,10 +605,10 @@ def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): ] ) duckdb_col = duckdb_cursor.sql("select * from x").df() - df_expected_res = pandas.DataFrame({'0': pandas.Series(['4', '2', '0'])}) + df_expected_res = pandas.DataFrame({"0": pandas.Series(["4", "2", "0"])}) pandas.testing.assert_frame_equal(duckdb_col, df_expected_res) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal(self, pandas, duckdb_cursor): # DuckDB uses DECIMAL where possible, so all the 'float' types here are actually DECIMAL reference_query = """ @@ -623,15 +623,16 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): ) tbl(a, b, c); """ duckdb_cursor.execute(reference_query) - # Because of this we need to wrap these native floats as DECIMAL for this test, to avoid these decimals being "upgraded" to DOUBLE + # Because of this we need to wrap these native floats as DECIMAL for this test, to avoid these decimals being + # "upgraded" to DOUBLE x = pandas.DataFrame( { - '0': ConvertStringToDecimal([5, '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal( - [5002340, 13, '-12.0000000005', '7453324234.0', None, '-324234234'], pandas + "0": ConvertStringToDecimal([5, 
"12.0", "-123.0", "-234234.0", None, "1.234"], pandas), + "1": ConvertStringToDecimal( + [5002340, 13, "-12.0000000005", "7453324234.0", None, "-324234234"], pandas ), - '2': ConvertStringToDecimal( - ['-234234234234.0', '324234234.00000005', -128, 345345, '1E5', '1324234359'], pandas + "2": ConvertStringToDecimal( + ["-234234234234.0", "324234234.00000005", -128, 345345, "1E5", "1324234359"], pandas ), } ) @@ -640,10 +641,10 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): x = pandas.DataFrame( - {'0': [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} + {"0": [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} ) conversion = duckdb_cursor.sql("select * from x").fetchall() print(conversion[0][0].__class__) @@ -655,12 +656,12 @@ def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): assert math.isinf(conversion[3][0]) assert math.isinf(conversion[4][0]) assert math.isinf(conversion[5][0]) - assert str(conversion) == '[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]' + assert str(conversion) == "[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]" # Test that the column 'offset' is actually used when converting, @pytest.mark.parametrize( - 'pandas', [NumpyPandas(), ArrowPandas()] + "pandas", [NumpyPandas(), ArrowPandas()] ) # and that the same 2048 (STANDARD_VECTOR_SIZE) values are not being scanned over and over again def test_multiple_chunks(self, pandas, duckdb_cursor): data = [] @@ -668,13 +669,13 @@ def test_multiple_chunks(self, pandas, duckdb_cursor): data += [datetime.date(2022, 9, 14) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 15) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 16) for x in range(standard_vector_size)] - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() - assert len(res['dates'].__array__()) == 4 + assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): - duckdb_cursor.execute(f"SET GLOBAL pandas_analyze_sample=4096") + duckdb_cursor.execute("SET GLOBAL pandas_analyze_sample=4096") duckdb_cursor.execute( "create table dates as select '2022-09-14'::DATE + INTERVAL (i::INTEGER) DAY as i from range(4096) tbl(i);" ) @@ -683,8 +684,8 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): date_df = res.copy() # Convert the dataframe to datetime - date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert str(date_df['i'].dtype) == 'object' + date_df["i"] = pandas.to_datetime(res["i"]).dt.date + assert str(date_df["i"].dtype) == "object" expected_res = [ ( @@ -707,10 +708,10 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): assert expected_res == actual_res # Now interleave nulls into the dataframe - duckdb_cursor.execute('drop table dates') - for i in range(0, len(res['i']), 2): - res.loc[i, 'i'] = None - duckdb_cursor.execute('create table dates as select * from res') + duckdb_cursor.execute("drop table 
dates") + for i in range(0, len(res["i"]), 2): + res.loc[i, "i"] = None + duckdb_cursor.execute("create table dates as select * from res") expected_res = [ ( @@ -721,8 +722,8 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): ] # Convert the dataframe to datetime date_df = res.copy() - date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert str(date_df['i'].dtype) == 'object' + date_df["i"] = pandas.to_datetime(res["i"]).dt.date + assert str(date_df["i"].dtype) == "object" actual_res = duckdb_cursor.sql( """ @@ -736,47 +737,47 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): assert expected_res == actual_res - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_mixed_object_types(self, pandas, duckdb_cursor): x = pandas.DataFrame( { - 'nested': pandas.Series( - data=[{'a': 1, 'b': 2}, [5, 4, 3], {'key': [1, 2, 3], 'value': ['a', 'b', 'c']}], dtype='object' + "nested": pandas.Series( + data=[{"a": 1, "b": 2}, [5, 4, 3], {"key": [1, 2, 3], "value": ["a", "b", "c"]}], dtype="object" ), } ) res = duckdb_cursor.sql("select * from x").df() - assert res['nested'].dtype == np.dtype('object') + assert res["nested"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_struct(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ { # STRUCT(b STRUCT(x VARCHAR, y VARCHAR)) - 'a': {'b': {'x': 'A', 'y': 'B'}} + "a": {"b": {"x": "A", "y": "B"}} }, { # STRUCT(b STRUCT(x VARCHAR)) - 'a': {'b': {'x': 'A'}} + "a": {"b": {"x": "A"}} }, ] ) # The dataframe has incompatible struct schemas in the nested child # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) res = duckdb_cursor.sql("select * from x").fetchall() - assert res == [({'b': {'x': 'A', 'y': 'B'}},), ({'b': {'x': 'A'}},)] + assert res == [({"b": {"x": "A", "y": "B"}},), ({"b": {"x": "A"}},)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): x = pandas.DataFrame( { - 'a': [ + "a": [ [ # STRUCT(x VARCHAR, y VARCHAR)[] - {'x': 'A', 'y': 'B'}, + {"x": "A", "y": "B"}, # STRUCT(x VARCHAR)[] - {'x': 'A'}, + {"x": "A"}, ] ] } @@ -784,16 +785,16 @@ def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): # The dataframe has incompatible struct schemas in the nested child # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) res = duckdb_cursor.sql("select * from x").fetchall() - assert res == [([{'x': 'A', 'y': 'B'}, {'x': 'A'}],)] + assert res == [([{"x": "A", "y": "B"}, {"x": "A"}],)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_analyze_sample_too_small(self, pandas, duckdb_cursor): data = [1 for _ in range(9)] + [[1, 2, 3]] + [1 for _ in range(9991)] - x = pandas.DataFrame({'a': pandas.Series(data=data)}) + x = pandas.DataFrame({"a": pandas.Series(data=data)}) with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"): res = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_zero_fractional(self, pandas, 
duckdb_cursor): decimals = pandas.DataFrame( data={ @@ -826,7 +827,7 @@ def test_numeric_decimal_zero_fractional(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( @@ -842,10 +843,10 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): duckdb_cursor.execute(reference_query) x = pandas.DataFrame( { - '0': ConvertStringToDecimal(['5', '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal([5002340, 13, '-12.0000000005', 7453324234, None, '-324234234'], pandas), - '2': ConvertStringToDecimal( - [-234234234234, '324234234.00000005', -128, 345345, 0, '1324234359'], pandas + "0": ConvertStringToDecimal(["5", "12.0", "-123.0", "-234234.0", None, "1.234"], pandas), + "1": ConvertStringToDecimal([5002340, 13, "-12.0000000005", 7453324234, None, "-324234234"], pandas), + "2": ConvertStringToDecimal( + [-234234234234, "324234234.00000005", -128, 345345, 0, "1324234359"], pandas ), } ) @@ -857,7 +858,7 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): print(conversion) @pytest.mark.parametrize( - 'pandas', [NumpyPandas(), ArrowPandas()] + "pandas", [NumpyPandas(), ArrowPandas()] ) # result: [('1E-28',), ('10000000000000000000000000.0',)] def test_numeric_decimal_combined(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( @@ -878,7 +879,7 @@ def test_numeric_decimal_combined(self, pandas, duckdb_cursor): print(conversion) # result: [('1234.0',), ('123456789.0',), ('1234567890123456789.0',), ('0.1234567890123456789',)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( data={ @@ -906,7 +907,7 @@ def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): print(reference) print(conversion) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): # The widths of these decimal values are bigger than the max supported width for DECIMAL data = [ @@ -927,7 +928,7 @@ def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): data = [ Decimal("1.234"), @@ -959,7 +960,7 @@ def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_out_of_range(self, pandas, duckdb_cursor): data = [Decimal("1.234567890123456789012345678901234567"), Decimal("123456789012345678901234567890123456.0")] decimals = pandas.DataFrame(data={"0": data}) diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index b8de512a..871132ae 100644 --- 
a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -1,9 +1,7 @@ -import duckdb -import datetime -import numpy as np import pytest -import copy -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb from duckdb import Value NULL = None @@ -12,8 +10,7 @@ def check_equal(conn, df, reference_query, data): duckdb_conn = duckdb.connect() duckdb_conn.execute(reference_query, parameters=[data]) - res = duckdb_conn.query('SELECT * FROM tbl').fetchall() - df_res = duckdb_conn.query('SELECT * FROM tbl').df() + res = duckdb_conn.query("SELECT * FROM tbl").fetchall() out = conn.sql("SELECT * FROM df").fetchall() assert res == out @@ -23,40 +20,40 @@ def create_reference_query(): return query -class TestDFRecursiveNested(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestDFRecursiveNested: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_structs(self, duckdb_cursor, pandas): - data = [[{'a': 5}, NULL, {'a': NULL}], NULL, [{'a': 5}, NULL, {'a': NULL}]] + data = [[{"a": 5}, NULL, {"a": NULL}], NULL, [{"a": 5}, NULL, {"a": NULL}]] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'STRUCT(a INTEGER)[]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "STRUCT(a INTEGER)[]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_map(self, duckdb_cursor, pandas): # LIST(MAP(VARCHAR, VARCHAR)) - data = [[{5: NULL}, NULL, {}], NULL, [NULL, {3: NULL, 2: 'a', 4: NULL}, {'a': 1, 'b': 2, 'c': 3}]] + data = [[{5: NULL}, NULL, {}], NULL, [NULL, {3: NULL, 2: "a", 4: NULL}, {"a": 1, "b": 2, "c": 3}]] reference_query = create_reference_query() print(reference_query) - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'MAP(VARCHAR, VARCHAR)[][]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "MAP(VARCHAR, VARCHAR)[][]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_list(self, duckdb_cursor, pandas): # LIST(LIST(LIST(LIST(INTEGER)))) data = [[[[3, NULL, 5], NULL], NULL, [[5, -20, NULL]]], NULL, [[[NULL]], [[]], NULL]] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'INTEGER[][][][]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "INTEGER[][][][]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_struct(self, duckdb_cursor, pandas): # STRUCT(STRUCT(STRUCT(LIST))) data = { - 'A': {'a': {'1': [1, 2, 3]}, 'b': NULL, 'c': {'1': NULL}}, - 'B': {'a': {'1': [1, NULL, 3]}, 'b': NULL, 'c': {'1': NULL}}, + "A": {"a": {"1": [1, 2, 3]}, "b": NULL, "c": {"1": NULL}}, + "B": {"a": {"1": [1, NULL, 3]}, "b": NULL, "c": {"1": NULL}}, } reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) check_equal( duckdb_cursor, df, @@ -92,7 +89,7 @@ def test_recursive_struct(self, 
duckdb_cursor, pandas): ), ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_map(self, duckdb_cursor, pandas): # MAP( # MAP( @@ -102,42 +99,42 @@ def test_recursive_map(self, duckdb_cursor, pandas): # INTEGER # ) data = { - 'key': [ - {'key': [5, 6, 7], 'value': [{'key': [8], 'value': [NULL]}, NULL, {'key': [9], 'value': ['a']}]}, - {'key': [], 'value': []}, + "key": [ + {"key": [5, 6, 7], "value": [{"key": [8], "value": [NULL]}, NULL, {"key": [9], "value": ["a"]}]}, + {"key": [], "value": []}, ], - 'value': [1, 2], + "value": [1, 2], } reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) check_equal( - duckdb_cursor, df, reference_query, Value(data, 'MAP(MAP(INTEGER, MAP(INTEGER, VARCHAR)), INTEGER)') + duckdb_cursor, df, reference_query, Value(data, "MAP(MAP(INTEGER, MAP(INTEGER, VARCHAR)), INTEGER)") ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_stresstest(self, duckdb_cursor, pandas): data = [ { - 'a': { - 'key': [ + "a": { + "key": [ # key 1 - {'1': [5, 4, 3], '2': [8, 7, 6], '3': [1, 2, 3]}, + {"1": [5, 4, 3], "2": [8, 7, 6], "3": [1, 2, 3]}, # key 2 - {'1': [], '2': NULL, '3': [NULL, 0, NULL]}, + {"1": [], "2": NULL, "3": [NULL, 0, NULL]}, ], - 'value': [ + "value": [ # value 1 - [{'A': 'abc', 'B': 'def', 'C': NULL}], + [{"A": "abc", "B": "def", "C": NULL}], # value 2 [NULL], ], }, - 'b': NULL, - 'c': {'key': [], 'value': []}, + "b": NULL, + "c": {"key": [], "value": []}, } ] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) duckdb_type = """ STRUCT( a MAP( diff --git a/tests/fast/pandas/test_fetch_df_chunk.py b/tests/fast/pandas/test_fetch_df_chunk.py index 1973a729..4fba64ea 100644 --- a/tests/fast/pandas/test_fetch_df_chunk.py +++ b/tests/fast/pandas/test_fetch_df_chunk.py @@ -1,10 +1,11 @@ import pytest + import duckdb VECTOR_SIZE = duckdb.__standard_vector_size__ -class TestType(object): +class TestType: def test_fetch_df_chunk(self): size = 3000 con = duckdb.connect() @@ -13,16 +14,16 @@ def test_fetch_df_chunk(self): # Fetch the first chunk cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == VECTOR_SIZE # Fetch the second chunk, can't be entirely filled cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == VECTOR_SIZE + assert cur_chunk["a"][0] == VECTOR_SIZE expected = size - VECTOR_SIZE assert len(cur_chunk) == expected - @pytest.mark.parametrize('size', [3000, 10000, 100000, VECTOR_SIZE - 1, VECTOR_SIZE + 1, VECTOR_SIZE]) + @pytest.mark.parametrize("size", [3000, 10000, 100000, VECTOR_SIZE - 1, VECTOR_SIZE + 1, VECTOR_SIZE]) def test_monahan(self, size): con = duckdb.connect() con.execute(f"CREATE table t as select range a from range({size});") @@ -52,12 +53,12 @@ def test_fetch_df_chunk_parameter(self): # Return 2 vectors cur_chunk = query.fetch_df_chunk(2) - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == VECTOR_SIZE * 2 # Return Default 1 vector cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == VECTOR_SIZE * 2 + assert cur_chunk["a"][0] == VECTOR_SIZE * 2 assert len(cur_chunk) == VECTOR_SIZE # Return 0 vectors @@ -69,7 +70,7 @@ def test_fetch_df_chunk_parameter(self): # Return more 
vectors than we have remaining cur_chunk = query.fetch_df_chunk(3) - assert cur_chunk['a'][0] == fetched + assert cur_chunk["a"][0] == fetched assert len(cur_chunk) == expected # These shouldn't throw errors (Just emmit empty chunks) @@ -88,5 +89,5 @@ def test_fetch_df_chunk_negative_parameter(self): query = con.execute("SELECT a FROM t") # Return -1 vector should not work - with pytest.raises(TypeError, match='incompatible function arguments'): - cur_chunk = query.fetch_df_chunk(-1) + with pytest.raises(TypeError, match="incompatible function arguments"): + query.fetch_df_chunk(-1) diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index 5727429f..3bf46c10 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -1,19 +1,19 @@ +import numpy as np import pytest + import duckdb -import sys pd = pytest.importorskip("pandas") -import numpy as np def compare_results(con, query, expected): expected = pd.DataFrame.from_dict(expected) unsorted_res = con.query(query).df() - print(unsorted_res, unsorted_res['a'][0].__class__) + print(unsorted_res, unsorted_res["a"][0].__class__) df_duck = con.query("select * from unsorted_res order by all").df() - print(df_duck, df_duck['a'][0].__class__) - print(expected, expected['a'][0].__class__) + print(df_duck, df_duck["a"][0].__class__) + print(expected, expected["a"][0].__class__) pd.testing.assert_frame_equal(df_duck, expected) @@ -55,7 +55,7 @@ def list_test_cases(): }), ("SELECT a from (SELECT LIST(i) as a FROM range(10000) tbl(i)) as t", { 'a': [ - list(range(0, 10000)) + list(range(10000)) ] }), ("SELECT LIST(i) as a FROM range(5) tbl(i) group by i%2 order by all", { @@ -146,13 +146,13 @@ def list_test_cases(): return test_cases -class TestFetchNested(object): - @pytest.mark.parametrize('query, expected', list_test_cases()) +class TestFetchNested: + @pytest.mark.parametrize(("query", "expected"), list_test_cases()) def test_fetch_df_list(self, duckdb_cursor, query, expected): compare_results(duckdb_cursor, query, expected) # fmt: off - @pytest.mark.parametrize('query, expected', [ + @pytest.mark.parametrize(("query", "expected"), [ ("SELECT a from (SELECT STRUCT_PACK(a := 42, b := 43) as a) as t", { 'a': [ {'a': 42, 'b': 43} @@ -192,22 +192,11 @@ def test_fetch_df_list(self, duckdb_cursor, query, expected): ] }), ]) - # fmt: on def test_struct_df(self, duckdb_cursor, query, expected): compare_results(duckdb_cursor, query, expected) # fmt: off - @pytest.mark.parametrize('query, expected, expected_error', [ - ("SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t", { - 'a': [ - { - '1':10, - '2':9, - '3':8, - '4':7 - } - ] - }, ""), + @pytest.mark.parametrize(("query", "expected", "expected_error"), [ ("SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t", { 'a': [ { @@ -242,7 +231,7 @@ def test_struct_df(self, duckdb_cursor, query, expected): } ] }, ""), - ("SELECT m as a from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10) tbl(i) group by i%5 order by all) as lst_tbl) as T", { + ("SELECT m as a from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10) tbl(i) group by i%5 order by all) as lst_tbl) as T", { # noqa: E501 'a': [ { '0':0, @@ -278,7 +267,7 @@ def test_struct_df(self, duckdb_cursor, query, expected): } ] }, "Map keys must be unique"), - ("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious 
D','Jon Lajoie' ),LIST_VALUE(10,9,10,11)) as a) as t", { + ("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D','Jon Lajoie' ),LIST_VALUE(10,9,10,11)) as a) as t", { # noqa: E501 'a': [ { 'key': ['Jon Lajoie', 'Backstreet Boys', 'Tenacious D', 'Jon Lajoie'], @@ -286,7 +275,7 @@ def test_struct_df(self, duckdb_cursor, query, expected): } ] }, "Map keys must be unique"), - ("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', NULL, 'Tenacious D',NULL,NULL ),LIST_VALUE(10,9,10,11,13)) as a) as t", { + ("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', NULL, 'Tenacious D',NULL,NULL ),LIST_VALUE(10,9,10,11,13)) as a) as t", { # noqa: E501 'a': [ { 'key': ['Jon Lajoie', None, 'Tenacious D', None, None], @@ -302,7 +291,7 @@ def test_struct_df(self, duckdb_cursor, query, expected): } ] }, "Map keys can not be NULL"), - ("SELECT a from (select MAP(LIST_VALUE(NULL, NULL, NULL,NULL,NULL ),LIST_VALUE(NULL, NULL, NULL,NULL,NULL )) as a) as t", { + ("SELECT a from (select MAP(LIST_VALUE(NULL, NULL, NULL,NULL,NULL ),LIST_VALUE(NULL, NULL, NULL,NULL,NULL )) as a) as t", { # noqa: E501 'a': [ { 'key': [None, None, None, None, None], @@ -311,7 +300,6 @@ def test_struct_df(self, duckdb_cursor, query, expected): ] }, "Map keys can not be NULL"), ]) - # fmt: on def test_map_df(self, duckdb_cursor, query, expected, expected_error): if not expected_error: compare_results(duckdb_cursor, query, expected) @@ -320,7 +308,7 @@ def test_map_df(self, duckdb_cursor, query, expected, expected_error): compare_results(duckdb_cursor, query, expected) # fmt: off - @pytest.mark.parametrize('query, expected', [ + @pytest.mark.parametrize(("query", "expected"), [ (""" SELECT [ {'i':1,'j':2}, @@ -359,7 +347,7 @@ def test_map_df(self, duckdb_cursor, query, expected, expected_error): }), (""" SELECT {'i':mp,'j':mp2} as a FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t - """, { + """, { # noqa: E501 'a': [ { 'i': { @@ -379,7 +367,7 @@ def test_map_df(self, duckdb_cursor, query, expected, expected_error): }), (""" SELECT [mp,mp2] as a FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t - """, { + """, { # noqa: E501 'a': [ [ { @@ -459,6 +447,5 @@ def test_map_df(self, duckdb_cursor, query, expected, expected_error): ] }), ]) - # fmt: on def test_nested_mix(self, duckdb_cursor, query, expected): compare_results(duckdb_cursor, query, expected) diff --git a/tests/fast/pandas/test_implicit_pandas_scan.py b/tests/fast/pandas/test_implicit_pandas_scan.py index e6f0b9f4..76f2c200 100644 --- a/tests/fast/pandas/test_implicit_pandas_scan.py +++ b/tests/fast/pandas/test_implicit_pandas_scan.py @@ -1,42 +1,43 @@ # simple DB API testcase -import duckdb import pandas as pd import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas from packaging.version import Version +import duckdb + numpy_nullable_df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val4", "CoL2": 17}]) try: from pandas.compat import pa_version_under7p0 pyarrow_dtypes_enabled = not pa_version_under7p0 -except: +except Exception: pyarrow_dtypes_enabled = False -if Version(pd.__version__) >= Version('2.0.0') and pyarrow_dtypes_enabled: +if Version(pd.__version__) >= Version("2.0.0") and pyarrow_dtypes_enabled: pyarrow_df = numpy_nullable_df.convert_dtypes(dtype_backend="pyarrow") else: # dtype_backend is not supported 
in pandas < 2.0.0 pyarrow_df = numpy_nullable_df -class TestImplicitPandasScan(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestImplicitPandasScan: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_local_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() - df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) - r1 = con.execute('select * from df').fetchdf() + df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) # noqa: F841 + r1 = con.execute("select * from df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_global_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() - r1 = con.execute(f'select * from {pandas.backend}_df').fetchdf() + r1 = con.execute(f"select * from {pandas.backend}_df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val4" assert r1["CoL2"][0] == 1.05 diff --git a/tests/fast/pandas/test_import_cache.py b/tests/fast/pandas/test_import_cache.py index 32eab7b0..eb1c8fb8 100644 --- a/tests/fast/pandas/test_import_cache.py +++ b/tests/fast/pandas/test_import_cache.py @@ -1,28 +1,29 @@ -from conftest import NumpyPandas, ArrowPandas -import duckdb import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_import_cache_explicit_dtype(pandas): - df = pandas.DataFrame( + df = pandas.DataFrame( # noqa: F841 { - 'id': [1, 2, 3], - 'value': pandas.Series(['123.123', pandas.NaT, pandas.NA], dtype=pandas.StringDtype(storage='python')), + "id": [1, 2, 3], + "value": pandas.Series(["123.123", pandas.NaT, pandas.NA], dtype=pandas.StringDtype(storage="python")), } ) con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df['value'][1] is None - assert result_df['value'][2] is None + assert result_df["value"][1] is None + assert result_df["value"][2] is None -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_import_cache_implicit_dtype(pandas): - df = pandas.DataFrame({'id': [1, 2, 3], 'value': pandas.Series(['123.123', pandas.NaT, pandas.NA])}) + df = pandas.DataFrame({"id": [1, 2, 3], "value": pandas.Series(["123.123", pandas.NaT, pandas.NA])}) # noqa: F841 con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df['value'][1] is None - assert result_df['value'][2] is None + assert result_df["value"][1] is None + assert result_df["value"][2] is None diff --git a/tests/fast/pandas/test_issue_1767.py b/tests/fast/pandas/test_issue_1767.py index e37f19e1..48d3e852 100644 --- a/tests/fast/pandas/test_issue_1767.py +++ b/tests/fast/pandas/test_issue_1767.py @@ -1,15 +1,14 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb -import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb # Join from pandas not matching identical strings #1767 -class TestIssue1767(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestIssue1767: + 
@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_unicode_join_pandas(self, duckdb_cursor, pandas): A = pandas.DataFrame({"key": ["a", "п"]}) B = pandas.DataFrame({"key": ["a", "п"]}) @@ -18,6 +17,6 @@ def test_unicode_join_pandas(self, duckdb_cursor, pandas): q = arrow.query("""SELECT key FROM "A" FULL JOIN "B" USING ("key") ORDER BY key""") result = q.df() - d = {'key': ["a", "п"]} + d = {"key": ["a", "п"]} df = pandas.DataFrame(data=d) pandas.testing.assert_frame_equal(result, df) diff --git a/tests/fast/pandas/test_limit.py b/tests/fast/pandas/test_limit.py index 4a03c24f..51c4a382 100644 --- a/tests/fast/pandas/test_limit.py +++ b/tests/fast/pandas/test_limit.py @@ -1,25 +1,26 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestLimitPandas(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestLimitPandas: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_limit_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) limit_df = duckdb.limit(df_in, 2) assert len(limit_df.execute().fetchall()) == 2 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_aggregate_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( { - 'numbers': [1, 2, 2, 2], + "numbers": [1, 2, 2, 2], } ) - aggregate_df = duckdb.aggregate(df_in, 'count(numbers)', 'numbers').order('all') + aggregate_df = duckdb.aggregate(df_in, "count(numbers)", "numbers").order("all") assert aggregate_df.execute().fetchall() == [(1,), (3,)] diff --git a/tests/fast/pandas/test_pandas_arrow.py b/tests/fast/pandas/test_pandas_arrow.py index 8729362d..0cb1f00d 100644 --- a/tests/fast/pandas/test_pandas_arrow.py +++ b/tests/fast/pandas/test_pandas_arrow.py @@ -1,19 +1,21 @@ -import duckdb -import pytest import datetime +import numpy as np +import pytest from conftest import pandas_supports_arrow_backend -pd = pytest.importorskip("pandas", '2.0.0') -import numpy as np -from pandas.api.types import is_integer_dtype +import duckdb + +pd = pytest.importorskip("pandas", "2.0.0") + +from pandas.api.types import is_integer_dtype # noqa: E402 @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") -class TestPandasArrow(object): +class TestPandasArrow: def test_pandas_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': pd.Series([5, 4, 3])}).convert_dtypes() + df = pd.DataFrame({"a": pd.Series([5, 4, 3])}).convert_dtypes() # noqa: F841 con = duckdb.connect() res = con.sql("select * from df").fetchall() assert res == [(5,), (4,), (3,)] @@ -21,8 +23,8 @@ def test_pandas_arrow(self, duckdb_cursor): def test_mixed_columns(self): df = pd.DataFrame( { - 'strings': pd.Series(['abc', 'DuckDB', 'quack', 'quack']), - 'timestamps': pd.Series( + "strings": pd.Series(["abc", "DuckDB", "quack", "quack"]), + "timestamps": pd.Series( [ datetime.datetime(1990, 10, 21), datetime.datetime(2023, 1, 11), @@ -30,23 +32,23 @@ def test_mixed_columns(self): datetime.datetime(1990, 10, 21), ] ), - 'objects': pd.Series([[5, 4, 3], 'test', None, {'a': 42}]), - 'integers': np.ndarray((4,), buffer=np.array([1, 2, 3, 4, 5]), offset=np.int_().itemsize, dtype=int), + "objects": pd.Series([[5, 4, 3], "test", None, {"a": 42}]), + "integers": 
np.ndarray((4,), buffer=np.array([1, 2, 3, 4, 5]), offset=np.int_().itemsize, dtype=int), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") # noqa: F841 con = duckdb.connect() with pytest.raises( - duckdb.InvalidInputException, match='The dataframe could not be converted to a pyarrow.lib.Table' + duckdb.InvalidInputException, match=r"The dataframe could not be converted to a pyarrow\.lib\.Table" ): - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() # noqa: F841 numpy_df = pd.DataFrame( - {'a': np.ndarray((2,), buffer=np.array([1, 2, 3]), offset=np.int_().itemsize, dtype=int)} - ).convert_dtypes(dtype_backend='numpy_nullable') + {"a": np.ndarray((2,), buffer=np.array([1, 2, 3]), offset=np.int_().itemsize, dtype=int)} + ).convert_dtypes(dtype_backend="numpy_nullable") arrow_df = pd.DataFrame( { - 'a': pd.Series( + "a": pd.Series( [ datetime.datetime(1990, 10, 21), datetime.datetime(2023, 1, 11), @@ -55,45 +57,45 @@ def test_mixed_columns(self): ] ) } - ).convert_dtypes(dtype_backend='pyarrow') - python_df = pd.DataFrame({'a': pd.Series(['test', [5, 4, 3], {'a': 42}])}).convert_dtypes() + ).convert_dtypes(dtype_backend="pyarrow") + python_df = pd.DataFrame({"a": pd.Series(["test", [5, 4, 3], {"a": 42}])}).convert_dtypes() - df = pd.concat([numpy_df['a'], arrow_df['a'], python_df['a']], axis=1, keys=['numpy', 'arrow', 'python']) - assert is_integer_dtype(df.dtypes['numpy']) - assert isinstance(df.dtypes['arrow'], pd.ArrowDtype) - assert isinstance(df.dtypes['python'], np.dtype('O').__class__) + df = pd.concat([numpy_df["a"], arrow_df["a"], python_df["a"]], axis=1, keys=["numpy", "arrow", "python"]) + assert is_integer_dtype(df.dtypes["numpy"]) + assert isinstance(df.dtypes["arrow"], pd.ArrowDtype) + assert isinstance(df.dtypes["python"], np.dtype("O").__class__) with pytest.raises( - duckdb.InvalidInputException, match='The dataframe could not be converted to a pyarrow.lib.Table' + duckdb.InvalidInputException, match=r"The dataframe could not be converted to a pyarrow\.lib\.Table" ): - res = con.sql('select * from df').fetchall() + con.sql("select * from df").fetchall() def test_empty_df(self): df = pd.DataFrame( { - 'string': pd.Series(data=[], dtype='string'), - 'object': pd.Series(data=[], dtype='object'), - 'Int64': pd.Series(data=[], dtype='Int64'), - 'Float64': pd.Series(data=[], dtype='Float64'), - 'bool': pd.Series(data=[], dtype='bool'), - 'datetime64[ns]': pd.Series(data=[], dtype='datetime64[ns]'), - 'datetime64[ms]': pd.Series(data=[], dtype='datetime64[ms]'), - 'datetime64[us]': pd.Series(data=[], dtype='datetime64[us]'), - 'datetime64[s]': pd.Series(data=[], dtype='datetime64[s]'), - 'category': pd.Series(data=[], dtype='category'), - 'timedelta64[ns]': pd.Series(data=[], dtype='timedelta64[ns]'), + "string": pd.Series(data=[], dtype="string"), + "object": pd.Series(data=[], dtype="object"), + "Int64": pd.Series(data=[], dtype="Int64"), + "Float64": pd.Series(data=[], dtype="Float64"), + "bool": pd.Series(data=[], dtype="bool"), + "datetime64[ns]": pd.Series(data=[], dtype="datetime64[ns]"), + "datetime64[ms]": pd.Series(data=[], dtype="datetime64[ms]"), + "datetime64[us]": pd.Series(data=[], dtype="datetime64[us]"), + "datetime64[s]": pd.Series(data=[], dtype="datetime64[s]"), + "category": pd.Series(data=[], dtype="category"), + "timedelta64[ns]": pd.Series(data=[], dtype="timedelta64[ns]"), } ) - pyarrow_df = 
df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") # noqa: F841 con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() assert res == [] def test_completely_null_df(self): df = pd.DataFrame( { - 'a': pd.Series( + "a": pd.Series( data=[ None, np.nan, @@ -102,35 +104,35 @@ def test_completely_null_df(self): ) } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") # noqa: F841 con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() assert res == [(None,), (None,), (None,)] def test_mixed_nulls(self): df = pd.DataFrame( { - 'float': pd.Series(data=[4.123123, None, 7.23456], dtype='Float64'), - 'int64': pd.Series(data=[-234234124, 709329413, pd.NA], dtype='Int64'), - 'bool': pd.Series(data=[np.nan, True, False], dtype='boolean'), - 'string': pd.Series(data=['NULL', None, 'quack']), - 'list[str]': pd.Series(data=[['Huey', 'Dewey', 'Louie'], [None, pd.NA, np.nan, 'DuckDB'], None]), - 'datetime64': pd.Series( + "float": pd.Series(data=[4.123123, None, 7.23456], dtype="Float64"), + "int64": pd.Series(data=[-234234124, 709329413, pd.NA], dtype="Int64"), + "bool": pd.Series(data=[np.nan, True, False], dtype="boolean"), + "string": pd.Series(data=["NULL", None, "quack"]), + "list[str]": pd.Series(data=[["Huey", "Dewey", "Louie"], [None, pd.NA, np.nan, "DuckDB"], None]), + "datetime64": pd.Series( data=[datetime.datetime(2011, 8, 16, 22, 7, 8), None, datetime.datetime(2010, 4, 26, 18, 14, 14)] ), - 'date': pd.Series(data=[datetime.date(2008, 5, 28), datetime.date(2013, 7, 14), None]), + "date": pd.Series(data=[datetime.date(2008, 5, 28), datetime.date(2013, 7, 14), None]), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") # noqa: F841 con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchone() + res = con.sql("select * from pyarrow_df").fetchone() assert res == ( 4.123123, -234234124, None, - 'NULL', - ['Huey', 'Dewey', 'Louie'], + "NULL", + ["Huey", "Dewey", "Louie"], datetime.datetime(2011, 8, 16, 22, 7, 8), datetime.date(2008, 5, 28), ) diff --git a/tests/fast/pandas/test_pandas_category.py b/tests/fast/pandas/test_pandas_category.py index e86a97d9..39db1bb8 100644 --- a/tests/fast/pandas/test_pandas_category.py +++ b/tests/fast/pandas/test_pandas_category.py @@ -1,13 +1,14 @@ -import duckdb -import pandas as pd import numpy +import pandas as pd import pytest +import duckdb + def check_category_equal(category): df_in = pd.DataFrame( { - 'x': pd.Categorical(category, ordered=True), + "x": pd.Categorical(category, ordered=True), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() @@ -23,7 +24,7 @@ def check_create_table(category): conn = duckdb.connect() conn.execute("PRAGMA enable_verification") - df_in = pd.DataFrame({'x': pd.Categorical(category, ordered=True), 'y': pd.Categorical(category, ordered=True)}) + df_in = pd.DataFrame({"x": pd.Categorical(category, ordered=True), "y": pd.Categorical(category, ordered=True)}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() assert df_in.equals(df_out) @@ -39,7 +40,7 @@ def check_create_table(category): conn.execute("INSERT INTO t1 VALUES ('2','2')") res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall() - assert res == [('1',)] + assert res == [("1",)] res = 
conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x)").fetchall() assert res == conn.execute("SELECT x FROM t1").fetchall() @@ -54,29 +55,29 @@ def check_create_table(category): conn.execute("DROP TABLE t1") -class TestCategory(object): +class TestCategory: def test_category_simple(self, duckdb_cursor): - df_in = pd.DataFrame({'float': [1.0, 2.0, 1.0], 'int': pd.Series([1, 2, 1], dtype="category")}) + df_in = pd.DataFrame({"float": [1.0, 2.0, 1.0], "int": pd.Series([1, 2, 1], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() print(duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) - print(df_out['int']) - assert numpy.all(df_out['float'] == numpy.array([1.0, 2.0, 1.0])) - assert numpy.all(df_out['int'] == numpy.array([1, 2, 1])) + print(df_out["int"]) + assert numpy.all(df_out["float"] == numpy.array([1.0, 2.0, 1.0])) + assert numpy.all(df_out["int"] == numpy.array([1, 2, 1])) def test_category_nulls(self, duckdb_cursor): - df_in = pd.DataFrame({'int': pd.Series([1, 2, None], dtype="category")}) + df_in = pd.DataFrame({"int": pd.Series([1, 2, None], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() print(duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) - assert df_out['int'][0] == 1 - assert df_out['int'][1] == 2 - assert pd.isna(df_out['int'][2]) + assert df_out["int"][0] == 1 + assert df_out["int"][1] == 2 + assert pd.isna(df_out["int"][2]) def test_category_string(self, duckdb_cursor): - check_category_equal(['foo', 'bla', 'zoo', 'foo', 'foo', 'bla']) + check_category_equal(["foo", "bla", "zoo", "foo", "foo", "bla"]) def test_category_string_null(self, duckdb_cursor): - check_category_equal(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla']) + check_category_equal(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"]) def test_category_string_null_bug_4747(self, duckdb_cursor): check_category_equal([str(i) for i in range(160)] + [None]) @@ -84,51 +85,49 @@ def test_category_string_null_bug_4747(self, duckdb_cursor): def test_categorical_fetchall(self, duckdb_cursor): df_in = pd.DataFrame( { - 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + "x": pd.Categorical(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"], ordered=True), } ) assert duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall() == [ - ('foo',), - ('bla',), + ("foo",), + ("bla",), (None,), - ('zoo',), - ('foo',), - ('foo',), + ("zoo",), + ("foo",), + ("foo",), (None,), - ('bla',), + ("bla",), ] def test_category_string_uint8(self, duckdb_cursor): - category = [] - for i in range(10): - category.append(str(i)) + category = [str(i) for i in range(10)] check_create_table(category) def test_empty_categorical(self, duckdb_cursor): - empty_categoric_df = pd.DataFrame({'category': pd.Series(dtype='category')}) + empty_categoric_df = pd.DataFrame({"category": pd.Series(dtype="category")}) # noqa: F841 duckdb_cursor.execute("CREATE TABLE test AS SELECT * FROM empty_categoric_df") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [] with pytest.raises(duckdb.ConversionException, match="Could not convert string 'test' to UINT8"): duckdb_cursor.execute("insert into test VALUES('test')") duckdb_cursor.execute("insert into test VALUES(NULL)") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [(None,)] def test_category_fetch_df_chunk(self, 
duckdb_cursor): con = duckdb.connect() - categories = ['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'] + categories = ["foo", "bla", None, "zoo", "foo", "foo", None, "bla"] result = categories * 256 categories = result * 2 df_result = pd.DataFrame( { - 'x': pd.Categorical(result, ordered=True), + "x": pd.Categorical(result, ordered=True), } ) df_in = pd.DataFrame( { - 'x': pd.Categorical(categories, ordered=True), + "x": pd.Categorical(categories, ordered=True), } ) con.register("data", df_in) @@ -146,8 +145,8 @@ def test_category_fetch_df_chunk(self, duckdb_cursor): def test_category_mix(self, duckdb_cursor): df_in = pd.DataFrame( { - 'float': [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 0.0], - 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + "float": [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 0.0], + "x": pd.Categorical(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"], ordered=True), } ) diff --git a/tests/fast/pandas/test_pandas_df_none.py b/tests/fast/pandas/test_pandas_df_none.py index 50e1553c..5f61b3bb 100644 --- a/tests/fast/pandas/test_pandas_df_none.py +++ b/tests/fast/pandas/test_pandas_df_none.py @@ -1,12 +1,8 @@ -import pandas as pd -import pytest import duckdb -import sys -import gc -class TestPandasDFNone(object): +class TestPandasDFNone: # This used to decrease the ref count of None def test_none_deref(self): con = duckdb.connect() - df = con.sql("select NULL::VARCHAR as a from range(1000000)").df() + df = con.sql("select NULL::VARCHAR as a from range(1000000)").df() # noqa: F841 diff --git a/tests/fast/pandas/test_pandas_enum.py b/tests/fast/pandas/test_pandas_enum.py index 9dc13a64..17b2e3c2 100644 --- a/tests/fast/pandas/test_pandas_enum.py +++ b/tests/fast/pandas/test_pandas_enum.py @@ -1,9 +1,10 @@ import pandas as pd import pytest + import duckdb -class TestPandasEnum(object): +class TestPandasEnum: def test_3480(self, duckdb_cursor): duckdb_cursor.execute( """ @@ -14,8 +15,8 @@ def test_3480(self, duckdb_cursor): ); """ ) - df = duckdb_cursor.query(f"SELECT * FROM tab LIMIT 0;").to_df() - assert df["cat"].cat.categories.equals(pd.Index(['marie', 'duchess', 'toulouse'])) + df = duckdb_cursor.query("SELECT * FROM tab LIMIT 0;").to_df() + assert df["cat"].cat.categories.equals(pd.Index(["marie", "duchess", "toulouse"])) duckdb_cursor.execute("DROP TABLE tab") duckdb_cursor.execute("DROP TYPE cat") @@ -32,16 +33,17 @@ def test_3479(self, duckdb_cursor): df = pd.DataFrame( { - "cat2": pd.Series(['duchess', 'toulouse', 'marie', None, "berlioz", "o_malley"], dtype="category"), + "cat2": pd.Series(["duchess", "toulouse", "marie", None, "berlioz", "o_malley"], dtype="category"), "amt": [1, 2, 3, 4, 5, 6], } ) - duckdb_cursor.register('df', df) + duckdb_cursor.register("df", df) with pytest.raises( duckdb.ConversionException, - match='Type UINT8 with value 0 can\'t be cast because the value is out of range for the destination type UINT8', + match="Type UINT8 with value 0 can't be cast because the value is out of range for the destination " + "type UINT8", ): - duckdb_cursor.execute(f"INSERT INTO tab SELECT * FROM df;") + duckdb_cursor.execute("INSERT INTO tab SELECT * FROM df;") assert duckdb_cursor.execute("select * from tab").fetchall() == [] duckdb_cursor.execute("DROP TABLE tab") diff --git a/tests/fast/pandas/test_pandas_limit.py b/tests/fast/pandas/test_pandas_limit.py index 506d5dd5..4c765e96 100644 --- a/tests/fast/pandas/test_pandas_limit.py +++ b/tests/fast/pandas/test_pandas_limit.py @@ -1,14 +1,12 @@ import duckdb -import 
pandas as pd -import pytest -class TestPandasLimit(object): +class TestPandasLimit: def test_pandas_limit(self, duckdb_cursor): con = duckdb.connect() - df = con.execute('select * from range(10000000) tbl(i)').df() + df = con.execute("select * from range(10000000) tbl(i)").df() # noqa: F841 - con.execute('SET threads=8') + con.execute("SET threads=8") - limit_df = con.execute('SELECT * FROM df WHERE i=334 OR i>9967864 LIMIT 5').df() - assert list(limit_df['i']) == [334, 9967865, 9967866, 9967867, 9967868] + limit_df = con.execute("SELECT * FROM df WHERE i=334 OR i>9967864 LIMIT 5").df() + assert list(limit_df["i"]) == [334, 9967865, 9967866, 9967867, 9967868] diff --git a/tests/fast/pandas/test_pandas_na.py b/tests/fast/pandas/test_pandas_na.py index f165d180..6462c298 100644 --- a/tests/fast/pandas/test_pandas_na.py +++ b/tests/fast/pandas/test_pandas_na.py @@ -1,52 +1,53 @@ +import platform + import numpy as np -import datetime -import duckdb import pytest -import platform -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def assert_nullness(items, null_indices): for i in range(len(items)): if i in null_indices: - assert items[i] == None + assert items[i] is None else: - assert items[i] != None + assert items[i] is not None @pytest.mark.skipif(platform.system() == "Emscripten", reason="Pandas interaction is broken in Pyodide 3.11") -class TestPandasNA(object): - @pytest.mark.parametrize('rows', [100, duckdb.__standard_vector_size__, 5000, 1000000]) - @pytest.mark.parametrize('pd', [NumpyPandas(), ArrowPandas()]) +class TestPandasNA: + @pytest.mark.parametrize("rows", [100, duckdb.__standard_vector_size__, 5000, 1000000]) + @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_pandas_string_null(self, duckdb_cursor, rows, pd): df: pd.DataFrame = pd.DataFrame(index=np.arange(rows)) df["string_column"] = pd.Series(dtype="string") e_df_rel = duckdb_cursor.from_df(df) - assert e_df_rel.types == ['VARCHAR'] + assert e_df_rel.types == ["VARCHAR"] roundtrip = e_df_rel.df() - assert roundtrip['string_column'].dtype == 'object' - expected = pd.DataFrame({'string_column': [None for _ in range(rows)]}) + assert roundtrip["string_column"].dtype == "object" + expected = pd.DataFrame({"string_column": [None for _ in range(rows)]}) pd.testing.assert_frame_equal(expected, roundtrip) def test_pandas_na(self, duckdb_cursor): - pd = pytest.importorskip('pandas', minversion='1.0.0', reason='Support for pandas.NA has not been added yet') + pd = pytest.importorskip("pandas", minversion="1.0.0", reason="Support for pandas.NA has not been added yet") # DataFrame containing a single pd.NA df = pd.DataFrame(pd.Series([pd.NA])) res = duckdb_cursor.execute("select * from df").fetchall() - assert res[0][0] == None + assert res[0][0] is None # DataFrame containing multiple values, with a pd.NA mixed in null_index = 3 - df = pd.DataFrame(pd.Series([3, 1, 2, pd.NA, 8, 6])) + df = pd.DataFrame(pd.Series([3, 1, 2, pd.NA, 8, 6])) # noqa: F841 res = duckdb_cursor.execute("select * from df").fetchall() - items = [x[0] for x in [y for y in res]] + items = [x[0] for x in list(res)] assert_nullness(items, [null_index]) # Test if pd.NA behaves the same as np.nan once converted nan_df = pd.DataFrame( { - 'a': [ + "a": [ 1.123, 5.23234, np.nan, @@ -60,7 +61,7 @@ def test_pandas_na(self, duckdb_cursor): ) na_df = pd.DataFrame( { - 'a': [ + "a": [ 1.123, 5.23234, pd.NA, @@ -72,16 +73,16 @@ def test_pandas_na(self, duckdb_cursor): ] } ) - assert 
str(nan_df['a'].dtype) == 'float64' - assert str(na_df['a'].dtype) == 'object' # pd.NA values turn the column into 'object' + assert str(nan_df["a"].dtype) == "float64" + assert str(na_df["a"].dtype) == "object" # pd.NA values turn the column into 'object' nan_result = duckdb_cursor.execute("select * from nan_df").df() na_result = duckdb_cursor.execute("select * from na_df").df() pd.testing.assert_frame_equal(nan_result, na_result) # Mixed with stringified pd.NA values - na_string_df = pd.DataFrame({'a': [str(pd.NA), str(pd.NA), pd.NA, str(pd.NA), pd.NA, pd.NA, pd.NA, str(pd.NA)]}) + na_string_df = pd.DataFrame({"a": [str(pd.NA), str(pd.NA), pd.NA, str(pd.NA), pd.NA, pd.NA, pd.NA, str(pd.NA)]}) # noqa: F841 null_indices = [2, 4, 5, 6] res = duckdb_cursor.execute("select * from na_string_df").fetchall() - items = [x[0] for x in [y for y in res]] + items = [x[0] for x in list(res)] assert_nullness(items, null_indices) diff --git a/tests/fast/pandas/test_pandas_object.py b/tests/fast/pandas/test_pandas_object.py index c00fcbc2..4c1de99f 100644 --- a/tests/fast/pandas/test_pandas_object.py +++ b/tests/fast/pandas/test_pandas_object.py @@ -1,36 +1,37 @@ -import pandas as pd -import duckdb import datetime + import numpy as np -import random +import pandas as pd + +import duckdb -class TestPandasObject(object): +class TestPandasObject: def test_object_lotof_nulls(self): # Test mostly null column data = [None] + [1] + [None] * 10000 # Last element is 1, others are None - pandas_df = pd.DataFrame(data, columns=['c'], dtype=object) + pandas_df = pd.DataFrame(data, columns=["c"], dtype=object) # noqa: F841 con = duckdb.connect() - assert con.execute('FROM pandas_df where c is not null').fetchall() == [(1.0,)] + assert con.execute("FROM pandas_df where c is not null").fetchall() == [(1.0,)] # Test all nulls, should return varchar data = [None] * 10000 # Last element is 1, others are None - pandas_df_2 = pd.DataFrame(data, columns=['c'], dtype=object) - assert con.execute('FROM pandas_df_2 limit 1').fetchall() == [(None,)] - assert con.execute('select typeof(c) FROM pandas_df_2 limit 1').fetchall() == [('"NULL"',)] + pandas_df_2 = pd.DataFrame(data, columns=["c"], dtype=object) # noqa: F841 + assert con.execute("FROM pandas_df_2 limit 1").fetchall() == [(None,)] + assert con.execute("select typeof(c) FROM pandas_df_2 limit 1").fetchall() == [('"NULL"',)] def test_object_to_string(self, duckdb_cursor): - con = duckdb.connect(database=':memory:', read_only=False) - x = pd.DataFrame([[1, 'a', 2], [1, None, 2], [1, 1.1, 2], [1, 1.1, 2], [1, 1.1, 2]]) + con = duckdb.connect(database=":memory:", read_only=False) + x = pd.DataFrame([[1, "a", 2], [1, None, 2], [1, 1.1, 2], [1, 1.1, 2], [1, 1.1, 2]]) x = x.iloc[1:].copy() # middle col now entirely native float items - con.register('view2', x) - df = con.execute('select * from view2').fetchall() + con.register("view2", x) + df = con.execute("select * from view2").fetchall() assert df == [(1, None, 2), (1, 1.1, 2), (1, 1.1, 2), (1, 1.1, 2)] def test_tuple_to_list(self, duckdb_cursor): - tuple_df = pd.DataFrame.from_dict( - dict( - nums=[ + tuple_df = pd.DataFrame.from_dict( # noqa: F841 + { + "nums": [ ( 1, 2, @@ -42,22 +43,22 @@ def test_tuple_to_list(self, duckdb_cursor): 6, ), ] - ) + } ) duckdb_cursor.execute("CREATE TABLE test as SELECT * FROM tuple_df") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [([1, 2, 3],), ([4, 5, 6],)] def test_2273(self, duckdb_cursor): - df_in = 
pd.DataFrame([[datetime.date(1992, 7, 30)]]) + df_in = pd.DataFrame([[datetime.date(1992, 7, 30)]]) # noqa: F841 assert duckdb_cursor.query("Select * from df_in").fetchall() == [(datetime.date(1992, 7, 30),)] def test_object_to_string_with_stride(self, duckdb_cursor): data = np.array([["a", "b", "c"], [1, 2, 3], [1, 2, 3], [11, 22, 33]]) df = pd.DataFrame(data=data[1:,], columns=data[0]) duckdb_cursor.register("object_with_strides", df) - res = duckdb_cursor.sql('select * from object_with_strides').fetchall() - assert res == [('1', '2', '3'), ('1', '2', '3'), ('11', '22', '33')] + res = duckdb_cursor.sql("select * from object_with_strides").fetchall() + assert res == [("1", "2", "3"), ("1", "2", "3"), ("11", "22", "33")] def test_2499(self, duckdb_cursor): df = pd.DataFrame( @@ -65,11 +66,11 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.881040697801939}, - {'a': 0.9922600577751953}, - {'a': 0.1589674833259317}, - {'a': 0.8928451262745073}, - {'a': 0.07022897889168278}, + {"a": 0.881040697801939}, + {"a": 0.9922600577751953}, + {"a": 0.1589674833259317}, + {"a": 0.8928451262745073}, + {"a": 0.07022897889168278}, ], dtype=object, ) @@ -77,11 +78,11 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.8759413504156746}, - {'a': 0.055784331256246156}, - {'a': 0.8605151517439655}, - {'a': 0.40807139339337695}, - {'a': 0.8429048322459952}, + {"a": 0.8759413504156746}, + {"a": 0.055784331256246156}, + {"a": 0.8605151517439655}, + {"a": 0.40807139339337695}, + {"a": 0.8429048322459952}, ], dtype=object, ) @@ -89,19 +90,19 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.9697093934032401}, - {'a': 0.9529257667149468}, - {'a': 0.21398182248591713}, - {'a': 0.6328512122275955}, - {'a': 0.5146953214092728}, + {"a": 0.9697093934032401}, + {"a": 0.9529257667149468}, + {"a": 0.21398182248591713}, + {"a": 0.6328512122275955}, + {"a": 0.5146953214092728}, ], dtype=object, ) ], ], - columns=['col'], + columns=["col"], ) - con = duckdb.connect(database=':memory:', read_only=False) - con.register('df', df) - assert con.execute('select count(*) from df').fetchone() == (3,) + con = duckdb.connect(database=":memory:", read_only=False) + con.register("df", df) + assert con.execute("select count(*) from df").fetchone() == (3,) diff --git a/tests/fast/pandas/test_pandas_string.py b/tests/fast/pandas/test_pandas_string.py index 494823ad..d1302f89 100644 --- a/tests/fast/pandas/test_pandas_string.py +++ b/tests/fast/pandas/test_pandas_string.py @@ -1,27 +1,28 @@ -import duckdb -import pandas as pd import numpy +import pandas as pd + +import duckdb -class TestPandasString(object): +class TestPandasString: def test_pandas_string(self, duckdb_cursor): - strings = numpy.array(['foo', 'bar', 'baz']) + strings = numpy.array(["foo", "bar", "baz"]) # https://pandas.pydata.org/pandas-docs/stable/user_guide/text.html df_in = pd.DataFrame( { - 'object': pd.Series(strings, dtype='object'), + "object": pd.Series(strings, dtype="object"), } ) # Only available in pandas 1.0.0 - if hasattr(pd, 'StringDtype'): - df_in['string'] = pd.Series(strings, dtype=pd.StringDtype()) + if hasattr(pd, "StringDtype"): + df_in["string"] = pd.Series(strings, dtype=pd.StringDtype()) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert numpy.all(df_out['object'] == strings) - if hasattr(pd, 'StringDtype'): - assert numpy.all(df_out['string'] == strings) + assert numpy.all(df_out["object"] == strings) + if hasattr(pd, "StringDtype"): + assert numpy.all(df_out["string"] == strings) def 
test_bug_2467(self, duckdb_cursor): N = 1_000_000 @@ -31,15 +32,12 @@ def test_bug_2467(self, duckdb_cursor): con = duckdb.connect() con.register("df", df) con.execute( - f""" + """ CREATE TABLE t1 AS SELECT * FROM df """ ) - assert ( - con.execute( - f""" + assert con.execute( + """ SELECT count(*) from t1 """ - ).fetchall() - == [(3000000,)] - ) + ).fetchall() == [(3000000,)] diff --git a/tests/fast/pandas/test_pandas_timestamp.py b/tests/fast/pandas/test_pandas_timestamp.py index 8e17db21..6311f3ba 100644 --- a/tests/fast/pandas/test_pandas_timestamp.py +++ b/tests/fast/pandas/test_pandas_timestamp.py @@ -1,36 +1,36 @@ -import duckdb +from datetime import datetime + import pandas import pytest - -from datetime import datetime -from pytz import timezone from conftest import pandas_2_or_higher +import duckdb + -@pytest.mark.parametrize('timezone', ['UTC', 'CET', 'Asia/Kathmandu']) +@pytest.mark.parametrize("timezone", ["UTC", "CET", "Asia/Kathmandu"]) @pytest.mark.skipif(not pandas_2_or_higher(), reason="Pandas <2.0.0 does not support timezones in the metadata string") def test_run_pandas_with_tz(timezone): con = duckdb.connect() con.execute(f"SET TimeZone = '{timezone}'") df = pandas.DataFrame( { - 'timestamp': pandas.Series( - data=[pandas.Timestamp(year=2022, month=1, day=1, hour=10, minute=15, tz=timezone, unit='us')], - dtype=f'datetime64[us, {timezone}]', + "timestamp": pandas.Series( + data=[pandas.Timestamp(year=2022, month=1, day=1, hour=10, minute=15, tz=timezone, unit="us")], + dtype=f"datetime64[us, {timezone}]", ) } ) duck_df = con.from_df(df).df() - assert duck_df['timestamp'][0] == df['timestamp'][0] + assert duck_df["timestamp"][0] == df["timestamp"][0] def test_timestamp_conversion(duckdb_cursor): - tzinfo = pandas.Timestamp('2024-01-01 00:00:00+0100', tz='Europe/Copenhagen').tzinfo - ts_df = pandas.DataFrame( + tzinfo = pandas.Timestamp("2024-01-01 00:00:00+0100", tz="Europe/Copenhagen").tzinfo + ts_df = pandas.DataFrame( # noqa: F841 { "ts": [ - pandas.Timestamp('2024-01-01 00:00:00+0100', tz=tzinfo), - pandas.Timestamp('2024-01-02 00:00:00+0100', tz=tzinfo), + pandas.Timestamp("2024-01-01 00:00:00+0100", tz=tzinfo), + pandas.Timestamp("2024-01-02 00:00:00+0100", tz=tzinfo), ] } ) diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index aeb33ea4..7510cb28 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -1,17 +1,19 @@ -import duckdb -import pytest -import pandas as pd -import numpy import string -from packaging import version import warnings from contextlib import suppress +import numpy +import pandas as pd +import pytest +from packaging import version + +import duckdb + def round_trip(data, pandas_type): df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype=pandas_type), + "object": pd.Series(data, dtype=pandas_type), } ) @@ -21,9 +23,9 @@ def round_trip(data, pandas_type): assert df_out.equals(df_in) -class TestNumpyNullableTypes(object): +class TestNumpyNullableTypes: def test_pandas_numeric(self): - base_df = pd.DataFrame({'a': range(10)}) + base_df = pd.DataFrame({"a": range(10)}) data_types = [ "uint8", @@ -46,7 +48,7 @@ def test_pandas_numeric(self): "float64", ] - if version.parse(pd.__version__) >= version.parse('1.2.0'): + if version.parse(pd.__version__) >= version.parse("1.2.0"): # These DTypes where added in 1.2.0 data_types.extend(["Float32", "Float64"]) # Generate a dataframe with all the types, in the form of: @@ -57,25 +59,25 @@ def 
test_pandas_numeric(self): for letter, dtype in zip(string.ascii_lowercase, data_types): data[letter] = base_df.a.astype(dtype) - df = pd.DataFrame.from_dict(data) + df = pd.DataFrame.from_dict(data) # noqa: F841 conn = duckdb.connect() - out_df = conn.execute('select * from df').df() + out_df = conn.execute("select * from df").df() # Verify that the types in the out_df are correct - # FIXME: we don't support outputting pandas specific types (i.e UInt64) + # TODO: we don't support outputting pandas specific types (i.e UInt64) # noqa: TD002, TD003 for letter, item in zip(string.ascii_lowercase, data_types): column_name = letter assert str(out_df[column_name].dtype) == item.lower() def test_pandas_unsigned(self, duckdb_cursor): - unsigned_types = ['uint8', 'uint16', 'uint32', 'uint64'] + unsigned_types = ["uint8", "uint16", "uint32", "uint64"] data = numpy.array([0, 1, 2, 3]) for u_type in unsigned_types: round_trip(data, u_type) def test_pandas_bool(self, duckdb_cursor): data = numpy.array([True, False, False, True]) - round_trip(data, 'bool') + round_trip(data, "bool") def test_pandas_masked_float64(self, duckdb_cursor, tmp_path): pa = pytest.importorskip("pyarrow") @@ -92,7 +94,7 @@ def test_pandas_masked_float64(self, duckdb_cursor, tmp_path): pq.write_table(pa.Table.from_pandas(testdf), parquet_path) # Read the Parquet file back into a DataFrame - testdf2 = pd.read_parquet(parquet_path) + testdf2 = pd.read_parquet(parquet_path) # noqa: F841 # Use duckdb_cursor to query the parquet data result = duckdb_cursor.execute("SELECT MIN(value) FROM testdf2").fetchall() @@ -102,107 +104,107 @@ def test_pandas_boolean(self, duckdb_cursor): data = numpy.array([True, None, pd.NA, numpy.nan, True]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='boolean'), + "object": pd.Series(data, dtype="boolean"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert pd.isna(df_out['object'][1]) - assert pd.isna(df_out['object'][2]) - assert pd.isna(df_out['object'][3]) - assert df_out['object'][4] == df_in['object'][4] + assert df_out["object"][0] == df_in["object"][0] + assert pd.isna(df_out["object"][1]) + assert pd.isna(df_out["object"][2]) + assert pd.isna(df_out["object"][3]) + assert df_out["object"][4] == df_in["object"][4] def test_pandas_float32(self, duckdb_cursor): data = numpy.array([0.1, 0.32, 0.78, numpy.nan]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='float32'), + "object": pd.Series(data, dtype="float32"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert df_out['object'][1] == df_in['object'][1] - assert df_out['object'][2] == df_in['object'][2] - assert pd.isna(df_out['object'][3]) + assert df_out["object"][0] == df_in["object"][0] + assert df_out["object"][1] == df_in["object"][1] + assert df_out["object"][2] == df_in["object"][2] + assert pd.isna(df_out["object"][3]) def test_pandas_float64(self): - data = numpy.array([0.233, numpy.nan, 3456.2341231, float('-inf'), -23424.45345, float('+inf'), 0.0000000001]) + data = numpy.array([0.233, numpy.nan, 3456.2341231, float("-inf"), -23424.45345, float("+inf"), 0.0000000001]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='float64'), + "object": pd.Series(data, dtype="float64"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() for i in range(len(data)): - if pd.isna(df_out['object'][i]): + if pd.isna(df_out["object"][i]): assert i == 1 
continue - assert df_out['object'][i] == df_in['object'][i] + assert df_out["object"][i] == df_in["object"][i] def test_pandas_interval(self, duckdb_cursor): - if pd.__version__ != '1.2.4': + if pd.__version__ != "1.2.4": return data = numpy.array([2069211000000000, numpy.datetime64("NaT")]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='timedelta64[ns]'), + "object": pd.Series(data, dtype="timedelta64[ns]"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert pd.isnull(df_out['object'][1]) + assert df_out["object"][0] == df_in["object"][0] + assert pd.isnull(df_out["object"][1]) def test_pandas_encoded_utf8(self, duckdb_cursor): - data = u'\u00c3' # Unicode data - data = [data.encode('utf8')] + data = "\u00c3" # Unicode data + data = [data.encode("utf8")] expected_result = data[0] - df_in = pd.DataFrame({'object': pd.Series(data, dtype='object')}) + df_in = pd.DataFrame({"object": pd.Series(data, dtype="object")}) result = duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchone()[0] assert result == expected_result @pytest.mark.parametrize( - 'dtype', + "dtype", [ - 'bool', - 'utinyint', - 'usmallint', - 'uinteger', - 'ubigint', - 'tinyint', - 'smallint', - 'integer', - 'bigint', - 'float', - 'double', + "bool", + "utinyint", + "usmallint", + "uinteger", + "ubigint", + "tinyint", + "smallint", + "integer", + "bigint", + "float", + "double", ], ) def test_producing_nullable_dtypes(self, duckdb_cursor, dtype): class Input: - def __init__(self, value, expected_dtype): + def __init__(self, value, expected_dtype) -> None: self.value = value self.expected_dtype = expected_dtype inputs = { - 'bool': Input('true', 'BooleanDtype'), - 'utinyint': Input('255', 'UInt8Dtype'), - 'usmallint': Input('65535', 'UInt16Dtype'), - 'uinteger': Input('4294967295', 'UInt32Dtype'), - 'ubigint': Input('18446744073709551615', 'UInt64Dtype'), - 'tinyint': Input('-128', 'Int8Dtype'), - 'smallint': Input('-32768', 'Int16Dtype'), - 'integer': Input('-2147483648', 'Int32Dtype'), - 'bigint': Input('-9223372036854775808', 'Int64Dtype'), - 'float': Input('268043421344044473239570760152672894976.0000000000', 'float32'), - 'double': Input( - '14303088389124869511075243108389716684037132417196499782261853698893384831666205572097390431189931733040903060865714975797777061496396865611606109149583360363636503436181348332896211726552694379264498632046075093077887837955077425420408952536212326792778411457460885268567735875437456412217418386401944141824.0000000000', - 'float64', + "bool": Input("true", "BooleanDtype"), + "utinyint": Input("255", "UInt8Dtype"), + "usmallint": Input("65535", "UInt16Dtype"), + "uinteger": Input("4294967295", "UInt32Dtype"), + "ubigint": Input("18446744073709551615", "UInt64Dtype"), + "tinyint": Input("-128", "Int8Dtype"), + "smallint": Input("-32768", "Int16Dtype"), + "integer": Input("-2147483648", "Int32Dtype"), + "bigint": Input("-9223372036854775808", "Int64Dtype"), + "float": Input("268043421344044473239570760152672894976.0000000000", "float32"), + "double": Input( + "14303088389124869511075243108389716684037132417196499782261853698893384831666205572097390431189931733040903060865714975797777061496396865611606109149583360363636503436181348332896211726552694379264498632046075093077887837955077425420408952536212326792778411457460885268567735875437456412217418386401944141824.0000000000", + "float64", ), } @@ -222,7 +224,7 @@ def __init__(self, value, expected_dtype): rel = duckdb_cursor.sql(query) # 
Pandas <= 2.2.3 does not convert without throwing a warning - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) with suppress(TypeError): df = rel.df() warnings.resetwarnings() @@ -231,4 +233,4 @@ def __init__(self, value, expected_dtype): expected_dtype = getattr(pd, input.expected_dtype) else: expected_dtype = numpy.dtype(input.expected_dtype) - assert isinstance(df['a'].dtype, expected_dtype) + assert isinstance(df["a"].dtype, expected_dtype) diff --git a/tests/fast/pandas/test_pandas_unregister.py b/tests/fast/pandas/test_pandas_unregister.py index 794e5910..ab83eb42 100644 --- a/tests/fast/pandas/test_pandas_unregister.py +++ b/tests/fast/pandas/test_pandas_unregister.py @@ -1,32 +1,31 @@ -import duckdb -import pytest -import tempfile -import os import gc +import tempfile + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasUnregister(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestPandasUnregister: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister1(self, duckdb_cursor, pandas): df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) connection = duckdb.connect(":memory:") connection.register("dataframe", df) - df2 = connection.execute("SELECT * FROM dataframe;").fetchdf() + df2 = connection.execute("SELECT * FROM dataframe;").fetchdf() # noqa: F841 connection.unregister("dataframe") - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() - with pytest.raises(duckdb.CatalogException, match='View with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="View with name dataframe does not exist"): connection.execute("DROP VIEW dataframe;") connection.execute("DROP VIEW IF EXISTS dataframe;") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister2(self, duckdb_cursor, pandas): - fd, db = tempfile.mkstemp() - os.close(fd) - os.remove(db) + with tempfile.NamedTemporaryFile() as tmp: + db = tmp.name connection = duckdb.connect(db) df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) @@ -39,7 +38,7 @@ def test_pandas_unregister2(self, duckdb_cursor, pandas): connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() connection.close() @@ -50,6 +49,6 @@ def test_pandas_unregister2(self, duckdb_cursor, pandas): # Reconnecting after DataFrame freed. 
connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() connection.close() diff --git a/tests/fast/pandas/test_pandas_update.py b/tests/fast/pandas/test_pandas_update.py index 663d6da2..bc1740d9 100644 --- a/tests/fast/pandas/test_pandas_update.py +++ b/tests/fast/pandas/test_pandas_update.py @@ -1,13 +1,14 @@ -import duckdb import pandas as pd +import duckdb + -class TestPandasUpdateList(object): +class TestPandasUpdateList: def test_pandas_update_list(self, duckdb_cursor): - duckdb_cursor = duckdb.connect(':memory:') - duckdb_cursor.execute('create table t (l int[])') - duckdb_cursor.execute('insert into t values ([1, 2]), ([3,4])') - duckdb_cursor.execute('update t set l = [5, 6]') - expected = pd.DataFrame({'l': [[5, 6], [5, 6]]}) - res = duckdb_cursor.execute('select * from t').fetchdf() + duckdb_cursor = duckdb.connect(":memory:") + duckdb_cursor.execute("create table t (l int[])") + duckdb_cursor.execute("insert into t values ([1, 2]), ([3,4])") + duckdb_cursor.execute("update t set l = [5, 6]") + expected = pd.DataFrame({"l": [[5, 6], [5, 6]]}) + res = duckdb_cursor.execute("select * from t").fetchdf() pd.testing.assert_frame_equal(expected, res) diff --git a/tests/fast/pandas/test_parallel_pandas_scan.py b/tests/fast/pandas/test_parallel_pandas_scan.py index a9fd99b9..9ac7b738 100644 --- a/tests/fast/pandas/test_parallel_pandas_scan.py +++ b/tests/fast/pandas/test_parallel_pandas_scan.py @@ -1,14 +1,15 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb -import numpy import datetime + +import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def run_parallel_queries(main_table, left_join_table, expected_df, pandas, iteration_count=5): - for i in range(0, iteration_count): + for _i in range(iteration_count): output_df = None sql = """ select @@ -24,8 +25,8 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera try: duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - duckdb_conn.register('main_table', main_table) - duckdb_conn.register('left_join_table', left_join_table) + duckdb_conn.register("main_table", main_table) + duckdb_conn.register("left_join_table", left_join_table) output_df = duckdb_conn.execute(sql).fetchdf() pandas.testing.assert_frame_equal(expected_df, output_df) print(output_df) @@ -35,70 +36,70 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera duckdb_conn.close() -class TestParallelPandasScan(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestParallelPandasScan: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_scan(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": 3}]) left_join_table = pandas.DataFrame([{"join_column": 3, "other_column": 4}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_ascii_text(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": "text"}]) left_join_table = 
pandas.DataFrame([{"join_column": "text", "other_column": "more text"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"mühleisen"}]) - left_join_table = pandas.DataFrame([{"join_column": u"mühleisen", "other_column": u"höhöhö"}]) + main_table = pandas.DataFrame([{"join_column": "mühleisen"}]) + left_join_table = pandas.DataFrame([{"join_column": "mühleisen", "other_column": "höhöhö"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_complex_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"鴨"}]) - left_join_table = pandas.DataFrame([{"join_column": u"鴨", "other_column": u"數據庫"}]) + main_table = pandas.DataFrame([{"join_column": "鴨"}]) + left_join_table = pandas.DataFrame([{"join_column": "鴨", "other_column": "數據庫"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_emojis(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) - left_join_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": u"🦆🍞🦆"}]) + main_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) + left_join_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": "🦆🍞🦆"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_object(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({'join_column': pandas.Series([3], dtype="Int8")}) + main_table = pandas.DataFrame({"join_column": pandas.Series([3], dtype="Int8")}) left_join_table = pandas.DataFrame( - {'join_column': pandas.Series([3], dtype="Int8"), 'other_column': pandas.Series([4], dtype="Int8")} + {"join_column": pandas.Series([3], dtype="Int8"), "other_column": pandas.Series([4], dtype="Int8")} ) expected_df = pandas.DataFrame( {"join_column": numpy.array([3], dtype=numpy.int8), "other_column": numpy.array([4], dtype=numpy.int8)} ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_timestamp(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({'join_column': [pandas.Timestamp('20180310T11:17:54Z')]}) + main_table = pandas.DataFrame({"join_column": [pandas.Timestamp("20180310T11:17:54Z")]}) left_join_table = pandas.DataFrame( { - 'join_column': [pandas.Timestamp('20180310T11:17:54Z')], - 'other_column': [pandas.Timestamp('20190310T11:17:54Z')], + "join_column": [pandas.Timestamp("20180310T11:17:54Z")], + "other_column": [pandas.Timestamp("20190310T11:17:54Z")], } ) expected_df = pandas.DataFrame( { - "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), - 
"other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), + "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), + "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), } ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_empty(self, duckdb_cursor, pandas): - df_empty = pandas.DataFrame({'A': []}) + df_empty = pandas.DataFrame({"A": []}) duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - duckdb_conn.register('main_table', df_empty) - assert duckdb_conn.execute('select * from main_table').fetchall() == [] + duckdb_conn.register("main_table", df_empty) + assert duckdb_conn.execute("select * from main_table").fetchall() == [] diff --git a/tests/fast/pandas/test_partitioned_pandas_scan.py b/tests/fast/pandas/test_partitioned_pandas_scan.py index 32c5352f..c1ab7b34 100644 --- a/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -1,16 +1,15 @@ -import duckdb -import pandas as pd import numpy -import datetime -import time +import pandas as pd + +import duckdb -class TestPartitionedPandasScan(object): +class TestPartitionedPandasScan: def test_parallel_pandas(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) + con.register("df", df) seq_results = con.execute("SELECT SUM(i) FROM df").fetchall() diff --git a/tests/fast/pandas/test_progress_bar.py b/tests/fast/pandas/test_progress_bar.py index 241cedd6..78764624 100644 --- a/tests/fast/pandas/test_progress_bar.py +++ b/tests/fast/pandas/test_progress_bar.py @@ -1,17 +1,16 @@ -import duckdb -import pandas as pd import numpy -import datetime -import time +import pandas as pd + +import duckdb -class TestProgressBarPandas(object): +class TestProgressBarPandas: def test_progress_pandas_single(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) - con.register('df_2', df) + con.register("df", df) + con.register("df_2", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") result = con.execute("SELECT SUM(df.i) FROM df inner join df_2 on (df.i = df_2.i)").fetchall() @@ -19,10 +18,10 @@ def test_progress_pandas_single(self, duckdb_cursor): def test_progress_pandas_parallel(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) - con.register('df_2', df) + con.register("df", df) + con.register("df_2", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") con.execute("PRAGMA threads=4") @@ -31,9 +30,9 @@ def test_progress_pandas_parallel(self, duckdb_cursor): def test_progress_pandas_empty(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': []}) - con.register('df', df) + df = pd.DataFrame({"i": []}) + con.register("df", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") result = con.execute("SELECT SUM(df.i) from df").fetchall() - 
assert result[0][0] == None + assert result[0][0] is None diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index e693e75c..87f49f04 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -1,21 +1,20 @@ -import duckdb -import os import pytest - from conftest import pandas_supports_arrow_backend +import duckdb + pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") -_ = pytest.importorskip("pandas", '2.0.0') +_ = pytest.importorskip("pandas", "2.0.0") @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") -class TestArrowDFProjectionPushdown(object): +class TestArrowDFProjectionPushdown: def test_projection_pushdown_no_filter(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (a INTEGER, b INTEGER, c INTEGER)") duckdb_conn.execute("INSERT INTO test VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") duck_tbl = duckdb_conn.table("test") - arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') + arrow_table = duck_tbl.df().convert_dtypes(dtype_backend="pyarrow") duckdb_conn.register("testarrowtable", arrow_table) assert duckdb_conn.execute("SELECT sum(a) FROM testarrowtable").fetchall() == [(111,)] diff --git a/tests/fast/pandas/test_same_name.py b/tests/fast/pandas/test_same_name.py index f48eb7eb..ff499ddf 100644 --- a/tests/fast/pandas/test_same_name.py +++ b/tests/fast/pandas/test_same_name.py @@ -1,80 +1,78 @@ -import pytest -import duckdb import pandas as pd -class TestMultipleColumnsSameName(object): +class TestMultipleColumnsSameName: def test_multiple_columns_with_same_name(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) - duckdb_cursor.register('df_view', df) + duckdb_cursor.register("df_view", df) - assert duckdb_cursor.table("df_view").columns == ['a', 'a_1', 'd'] + assert duckdb_cursor.table("df_view").columns == ["a", "a_1", "d"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_multiple_columns_with_same_name_relation(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) rel = duckdb_cursor.from_df(df) assert rel.query("df_view", "DESCRIBE df_view;").fetchall() == [ - ('a', 'BIGINT', 'YES', None, None, None), - ('a_1', 'BIGINT', 'YES', None, None, None), - ('d', 'BIGINT', 'YES', None, None, None), + ("a", "BIGINT", "YES", None, None, None), + ("a_1", "BIGINT", "YES", None, None, None), + ("d", "BIGINT", "YES", None, None, None), ] assert rel.query("df_view", "select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert rel.query("df_view", "select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert 
all(df.columns == ["a", "a", "d"]), df.columns def test_multiple_columns_with_same_name_replacement_scans(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) assert duckdb_cursor.execute("select a_1 from df;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_3669(self, duckdb_cursor): - df = pd.DataFrame([(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)], columns=['a_1', 'a', 'a']) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['a_1', 'a', 'a_2'] + df = pd.DataFrame([(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)], columns=["a_1", "a", "a"]) + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["a_1", "a", "a_2"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a_1', 'a', 'a']), df.columns + assert all(df.columns == ["a_1", "a", "a"]), df.columns def test_minimally_rename(self, duckdb_cursor): df = pd.DataFrame( - [(1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15), (4, 8, 12, 16)], columns=['a_1', 'a', 'a', 'a_2'] + [(1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15), (4, 8, 12, 16)], columns=["a_1", "a", "a", "a_2"] ) - duckdb_cursor.register('df_view', df) + duckdb_cursor.register("df_view", df) rel = duckdb_cursor.table("df_view") res = rel.columns - assert res == ['a_1', 'a', 'a_2', 'a_2_1'] + assert res == ["a_1", "a", "a_2", "a_2_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a_2 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] assert duckdb_cursor.execute("select a_2_1 from df_view;").fetchall() == [(13,), (14,), (15,), (16,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a_1', 'a', 'a', 'a_2']), df.columns + assert all(df.columns == ["a_1", "a", "a", "a_2"]), df.columns def test_multiple_columns_with_same_name_2(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'a_1': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "a_1": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a_1"}) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['a', 'a_1', 'a_1_1'] + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["a", "a_1", "a_1_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] def test_case_insensitive(self, duckdb_cursor): - df = pd.DataFrame({'A_1': [1, 2, 3, 4], 'a_1': [9, 10, 11, 12]}) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['A_1', 
'a_1_1'] + df = pd.DataFrame({"A_1": [1, 2, 3, 4], "a_1": [9, 10, 11, 12]}) + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["A_1", "a_1_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] diff --git a/tests/fast/pandas/test_stride.py b/tests/fast/pandas/test_stride.py index 5efe8d56..cbe23cfd 100644 --- a/tests/fast/pandas/test_stride.py +++ b/tests/fast/pandas/test_stride.py @@ -1,34 +1,36 @@ +import datetime + +import numpy as np import pandas as pd + import duckdb -import numpy as np -import datetime -class TestPandasStride(object): +class TestPandasStride: def test_stride(self, duckdb_cursor): expected_df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() pd.testing.assert_frame_equal(expected_df, output_df) def test_stride_fp32(self, duckdb_cursor): - expected_df = pd.DataFrame(np.arange(20, dtype='float32').reshape(5, 4), columns=["a", "b", "c", "d"]) + expected_df = pd.DataFrame(np.arange(20, dtype="float32").reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert str(output_df[col].dtype) == 'float32' + assert str(output_df[col].dtype) == "float32" pd.testing.assert_frame_equal(expected_df, output_df) def test_stride_datetime(self, duckdb_cursor): - df = pd.DataFrame({'date': pd.Series(pd.date_range("2024-01-01", freq="D", periods=100))}) + df = pd.DataFrame({"date": pd.Series(pd.date_range("2024-01-01", freq="D", periods=100))}) df = df.loc[::23,] roundtrip = duckdb_cursor.sql("select * from df").df() expected = pd.DataFrame( { - 'date': [ + "date": [ datetime.datetime(2024, 1, 1), datetime.datetime(2024, 1, 24), datetime.datetime(2024, 2, 16), @@ -40,13 +42,13 @@ def test_stride_datetime(self, duckdb_cursor): pd.testing.assert_frame_equal(roundtrip, expected) def test_stride_timedelta(self, duckdb_cursor): - df = pd.DataFrame({'date': [datetime.timedelta(days=i) for i in range(100)]}) + df = pd.DataFrame({"date": [datetime.timedelta(days=i) for i in range(100)]}) df = df.loc[::23,] roundtrip = duckdb_cursor.sql("select * from df").df() expected = pd.DataFrame( { - 'date': [ + "date": [ datetime.timedelta(days=0), datetime.timedelta(days=23), datetime.timedelta(days=46), @@ -58,10 +60,10 @@ def test_stride_timedelta(self, duckdb_cursor): pd.testing.assert_frame_equal(roundtrip, expected) def test_stride_fp64(self, duckdb_cursor): - expected_df = pd.DataFrame(np.arange(20, dtype='float64').reshape(5, 4), columns=["a", "b", "c", "d"]) + expected_df = pd.DataFrame(np.arange(20, dtype="float64").reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert str(output_df[col].dtype) == 'float64' + assert str(output_df[col].dtype) == "float64" pd.testing.assert_frame_equal(expected_df, output_df) diff --git a/tests/fast/pandas/test_timedelta.py b/tests/fast/pandas/test_timedelta.py index 5c6aa4b9..7c41c593 100644 --- 
a/tests/fast/pandas/test_timedelta.py +++ b/tests/fast/pandas/test_timedelta.py @@ -1,17 +1,19 @@ +import datetime import platform + import pandas as pd -import duckdb -import datetime import pytest +import duckdb + -class TestTimedelta(object): +class TestTimedelta: def test_timedelta_positive(self, duckdb_cursor): duckdb_interval = duckdb_cursor.query( "SELECT '2290-01-01 23:59:00'::TIMESTAMP - '2000-01-01 23:59:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=9151574400000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) @@ -20,7 +22,7 @@ def test_timedelta_basic(self, duckdb_cursor): "SELECT '2290-08-30 23:53:40'::TIMESTAMP - '2000-02-01 01:56:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=9169797460000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) @@ -29,24 +31,24 @@ def test_timedelta_negative(self, duckdb_cursor): "SELECT '2000-01-01 23:59:00'::TIMESTAMP - '2290-01-01 23:59:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=-9151574400000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) - @pytest.mark.parametrize('days', [1, 9999]) - @pytest.mark.parametrize('seconds', [0, 60]) + @pytest.mark.parametrize("days", [1, 9999]) + @pytest.mark.parametrize("seconds", [0, 60]) @pytest.mark.parametrize( - 'microseconds', + "microseconds", [ 0, 232493, 999_999, ], ) - @pytest.mark.parametrize('milliseconds', [0, 999]) - @pytest.mark.parametrize('minutes', [0, 60]) - @pytest.mark.parametrize('hours', [0, 24]) - @pytest.mark.parametrize('weeks', [0, 51]) + @pytest.mark.parametrize("milliseconds", [0, 999]) + @pytest.mark.parametrize("minutes", [0, 60]) + @pytest.mark.parametrize("hours", [0, 24]) + @pytest.mark.parametrize("weeks", [0, 51]) @pytest.mark.skipif(platform.system() == "Emscripten", reason="Bind parameters are broken when running on Pyodide") def test_timedelta_coverage(self, duckdb_cursor, days, seconds, microseconds, milliseconds, minutes, hours, weeks): def create_duck_interval(days, seconds, microseconds, milliseconds, minutes, hours, weeks) -> str: @@ -77,8 +79,8 @@ def create_python_interval( equality = "select {value} = $1, {value}, $1" equality = equality.format(value=duck_interval) res, a, b = duckdb_cursor.execute(equality, [val]).fetchone() - if res != True: - # FIXME: in some cases intervals that are identical don't compare equal. + if not res: + # TODO: in some cases intervals that are identical don't compare equal. 
# noqa: TD002, TD003 assert a == b else: - assert res == True + assert res diff --git a/tests/fast/pandas/test_timestamp.py b/tests/fast/pandas/test_timestamp.py index 0a580025..81651634 100644 --- a/tests/fast/pandas/test_timestamp.py +++ b/tests/fast/pandas/test_timestamp.py @@ -1,54 +1,56 @@ -import duckdb import datetime import os -import pytest -import pandas as pd import platform + +import pandas as pd +import pytest from conftest import pandas_2_or_higher +import duckdb + -class TestPandasTimestamps(object): - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) +class TestPandasTimestamps: + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_types_roundtrip(self, unit): d = { - 'time': pd.Series( + "time": pd.Series( [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit)], - dtype=f'datetime64[{unit}]', + dtype=f"datetime64[{unit}]", ) } df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() assert df_from_duck.equals(df) - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_timezone_roundtrip(self, unit): if pandas_2_or_higher(): - dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit=unit, tz='UTC') - expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='us', tz='UTC') + dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit=unit, tz="UTC") + expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="us", tz="UTC") else: # Older versions of pandas only support 'ns' as timezone unit - expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='ns', tz='UTC') - dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='ns', tz='UTC') + expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="ns", tz="UTC") + dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="ns", tz="UTC") conn = duckdb.connect() conn.execute("SET TimeZone =UTC") d = { - 'time': pd.Series( - [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit, tz='UTC')], + "time": pd.Series( + [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit, tz="UTC")], dtype=dtype, ) } df = pd.DataFrame(data=d) - # Our timezone aware type is in US (microseconds), when we scan a timestamp column that isn't US and has timezone info, - # we convert the time unit to US + # Our timezone aware type is in US (microseconds), when we scan a timestamp column that isn't US and has + # timezone info, we convert the time unit to US expected = pd.DataFrame(data=d, dtype=expected_dtype) df_from_duck = conn.from_df(df).df() assert df_from_duck.equals(expected) - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_nulls(self, unit): - d = {'time': pd.Series([pd.Timestamp(None, unit=unit)], dtype=f'datetime64[{unit}]')} + d = {"time": pd.Series([pd.Timestamp(None, unit=unit)], dtype=f"datetime64[{unit}]")} df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() assert df_from_duck.equals(df) @@ -56,10 +58,10 @@ def test_timestamp_nulls(self, unit): def test_timestamp_timedelta(self): df = pd.DataFrame( { - 'a': [pd.Timedelta(1, unit='s')], - 'b': [pd.Timedelta(None, unit='s')], - 'c': [pd.Timedelta(1, unit='us')], - 'd': [pd.Timedelta(1, unit='ms')], + "a": [pd.Timedelta(1, unit="s")], + "b": [pd.Timedelta(None, unit="s")], + "c": [pd.Timedelta(1, unit="us")], + "d": [pd.Timedelta(1, unit="ms")], } ) df_from_duck = duckdb.from_df(df).df() @@ -78,4 +80,4 @@ def 
test_timestamp_timezone(self, duckdb_cursor): """ ) res = rel.df() - assert res['dateTime'][0] == res['dateTime_1'][0] + assert res["dateTime"][0] == res["dateTime_1"][0] diff --git a/tests/fast/relational_api/test_groupings.py b/tests/fast/relational_api/test_groupings.py index fc81deba..250df7ad 100644 --- a/tests/fast/relational_api/test_groupings.py +++ b/tests/fast/relational_api/test_groupings.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + @pytest.fixture def con(): @@ -17,12 +18,12 @@ def con(): ) AS tbl(a, b, c)) """ ) - yield conn + return conn -class TestGroupings(object): +class TestGroupings: def test_basic_grouping(self, con): - rel = con.table('tbl').sum("a", "b") + rel = con.table("tbl").sum("a", "b") res = rel.fetchall() assert res == [(7,), (2,), (5,)] @@ -31,7 +32,7 @@ def test_basic_grouping(self, con): assert res == res2 def test_cubed(self, con): - rel = con.table('tbl').sum("a", "CUBE (b)").order("ALL") + rel = con.table("tbl").sum("a", "CUBE (b)").order("ALL") res = rel.fetchall() assert res == [(2,), (5,), (7,), (14,)] @@ -40,7 +41,7 @@ def test_cubed(self, con): assert res == res2 def test_rollup(self, con): - rel = con.table('tbl').sum("a", "ROLLUP (b, c)").order("ALL") + rel = con.table("tbl").sum("a", "ROLLUP (b, c)").order("ALL") res = rel.fetchall() assert res == [(1,), (1,), (2,), (2,), (2,), (3,), (5,), (5,), (7,), (14,)] diff --git a/tests/fast/relational_api/test_joins.py b/tests/fast/relational_api/test_joins.py index 8eb365d5..726fdac8 100644 --- a/tests/fast/relational_api/test_joins.py +++ b/tests/fast/relational_api/test_joins.py @@ -1,5 +1,6 @@ -import duckdb import pytest + +import duckdb from duckdb import ColumnExpression @@ -26,62 +27,62 @@ def con(): ) AS t(a, b)) """ ) - yield conn + return conn -class TestRAPIJoins(object): +class TestRAPIJoins: def test_outer_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'outer') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "outer") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, None, None), (None, None, 3, 5)] def test_inner_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'inner') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "inner") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4)] def test_anti_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'anti') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "anti") res = rel.fetchall() # Only output the row(s) from A where the condition is false assert res == [(3, 2)] def test_left_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'left') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "left") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, None, None)] def test_right_join(self, 
con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'right') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "right") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (None, None, 3, 5)] def test_semi_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'semi') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "semi") res = rel.fetchall() assert res == [(1, 1), (2, 1)] def test_cross_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') + a = con.table("tbl_a") + b = con.table("tbl_b") rel = a.cross(b) res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, 1, 4), (1, 1, 3, 5), (2, 1, 3, 5), (3, 2, 3, 5)] diff --git a/tests/fast/relational_api/test_pivot.py b/tests/fast/relational_api/test_pivot.py index d78df656..7052568e 100644 --- a/tests/fast/relational_api/test_pivot.py +++ b/tests/fast/relational_api/test_pivot.py @@ -1,13 +1,11 @@ -import duckdb -import pytest -import os import tempfile +from pathlib import Path -class TestPivot(object): +class TestPivot: def test_pivot_issue_14600(self, duckdb_cursor): duckdb_cursor.sql( - "create table input_data as select unnest(['u','v','w']) as a, unnest(['x','y','z']) as b, unnest([1,2,3]) as c;" + "create table input_data as select unnest(['u','v','w']) as a, unnest(['x','y','z']) as b, unnest([1,2,3]) as c;" # noqa: E501 ) pivot_1 = duckdb_cursor.query("pivot input_data on a using max(c) group by b;") pivot_2 = duckdb_cursor.query("pivot input_data on b using max(c) group by a;") @@ -20,11 +18,10 @@ def test_pivot_issue_14600(self, duckdb_cursor): def test_pivot_issue_14601(self, duckdb_cursor): duckdb_cursor.sql( - "create table input_data as select unnest(['u','v','w']) as a, unnest(['x','y','z']) as b, unnest([1,2,3]) as c;" + "create table input_data as select unnest(['u','v','w']) as a, unnest(['x','y','z']) as b, unnest([1,2,3]) as c;" # noqa: E501 ) pivot_1 = duckdb_cursor.query("pivot input_data on a using max(c) group by b;") pivot_1.create("pivot_1") export_dir = tempfile.mkdtemp() duckdb_cursor.query(f"EXPORT DATABASE '{export_dir}'") - with open(os.path.join(export_dir, "schema.sql"), "r") as f: - assert 'CREATE TYPE' not in f.read() + assert "CREATE TYPE" not in (Path(export_dir) / "schema.sql").read_text() diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index 29202759..ffb7e303 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -1,7 +1,7 @@ -import duckdb -from decimal import Decimal import pytest +import duckdb + @pytest.fixture(autouse=True) def setup_and_teardown_of_table(duckdb_cursor): @@ -23,45 +23,45 @@ def setup_and_teardown_of_table(duckdb_cursor): duckdb_cursor.execute("drop table agg") -@pytest.fixture() +@pytest.fixture def table(duckdb_cursor): return duckdb_cursor.table("agg") -class TestRAPIAggregations(object): +class TestRAPIAggregations: # General aggregate functions def test_any_value(self, table): result = table.order("id, t").any_value("v").execute().fetchall() expected = [(1,)] assert len(result) == len(expected) - assert all([r == e for r, 
e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( table.order("id, t").any_value("v", groups="id", projected_columns="id").order("id").execute().fetchall() ) expected = [(1, 1), (2, 11), (3, 5)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_arg_max(self, table): result = table.arg_max("t", "v").execute().fetchall() expected = [(-1,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.arg_max("t", "v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 3), (2, -1), (3, -2)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_arg_min(self, table): result = table.arg_min("t", "v").execute().fetchall() expected = [(0,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.arg_min("t", "v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 4), (3, 0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_avg(self, table): result = table.avg("v").execute().fetchall() @@ -78,41 +78,41 @@ def test_bit_and(self, table): result = table.bit_and("v").execute().fetchall() expected = [(0,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.bit_and("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 0), (2, 10), (3, 5)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bit_or(self, table): result = table.bit_or("v").execute().fetchall() expected = [(-1,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.bit_or("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 3), (2, 11), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bit_xor(self, table): result = table.bit_xor("v").execute().fetchall() expected = [(-7,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.bit_xor("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 1), (3, -6)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bitstring_agg(self, table): result = table.bitstring_agg("v").execute().fetchall() expected = [("1011001000011",)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.bitstring_agg("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 
"0011000000000"), (2, "0000000000011"), (3, "1000001000000")] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) with pytest.raises(duckdb.InvalidInputException): table.bitstring_agg("v", min="1") with pytest.raises(duckdb.InvalidTypeException): @@ -122,216 +122,214 @@ def test_bool_and(self, table): result = table.bool_and("v::BOOL").execute().fetchall() expected = [(True,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.bool_and("t::BOOL", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, True), (2, True), (3, False)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bool_or(self, table): result = table.bool_or("v::BOOL").execute().fetchall() expected = [(True,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.bool_or("v::BOOL", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, True), (2, True), (3, True)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_count(self, table): result = table.count("*").execute().fetchall() expected = [(8,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.count("*", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 3), (2, 2), (3, 3)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_value_counts(self, table): result = table.value_counts("v").execute().fetchall() expected = [(None, 0), (-1, 1), (1, 2), (2, 1), (5, 1), (10, 1), (11, 1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.value_counts("v", groups="v").order("v").execute().fetchall() expected = [(-1, 1), (1, 2), (2, 1), (5, 1), (10, 1), (11, 1), (None, 0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_favg(self, table): result = [round(r[0], 2) for r in table.favg("f").execute().fetchall()] expected = [5.12] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ (r[0], round(r[1], 2)) for r in table.favg("f", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.25), (2, 5.24), (3, 9.92)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_first(self, table): result = table.first("v").execute().fetchall() expected = [(1,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.first("v", "id", "id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] 
assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_last(self, table): result = table.last("v").execute().fetchall() expected = [(None,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.last("v", "id", "id").order("id").execute().fetchall() expected = [(1, 2), (2, 11), (3, None)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_fsum(self, table): result = [round(r[0], 2) for r in table.fsum("f").execute().fetchall()] expected = [40.99] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ (r[0], round(r[1], 2)) for r in table.fsum("f", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.75), (2, 10.49), (3, 29.75)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_geomean(self, table): result = [round(r[0], 2) for r in table.geomean("f").execute().fetchall()] expected = [0.67] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ (r[0], round(r[1], 2)) for r in table.geomean("f", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.05), (2, 0.65), (3, 9.52)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_histogram(self, table): result = table.histogram("v").execute().fetchall() expected = [({-1: 1, 1: 2, 2: 1, 5: 1, 10: 1, 11: 1},)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.histogram("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, {1: 2, 2: 1}), (2, {10: 1, 11: 1}), (3, {-1: 1, 5: 1})] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_list(self, table): result = table.list("v").execute().fetchall() expected = [([1, 1, 2, 10, 11, -1, 5, None],)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.list("v", groups="id order by t asc", projected_columns="id").order("id").execute().fetchall() expected = [(1, [1, 1, 2]), (2, [10, 11]), (3, [-1, 5, None])] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_max(self, table): result = table.max("v").execute().fetchall() expected = [(11,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.max("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 11), (3, 5)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, 
expected)) def test_min(self, table): result = table.min("v").execute().fetchall() expected = [(-1,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.min("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_product(self, table): result = table.product("v").execute().fetchall() expected = [(-1100,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.product("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 110), (3, -5)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_string_agg(self, table): result = table.string_agg("s", sep="/").execute().fetchall() - expected = [('h/e/l/l/o/,/wor/ld',)] + expected = [("h/e/l/l/o/,/wor/ld",)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( table.string_agg("s", sep="/", groups="id order by t asc", projected_columns="id") .order("id") .execute() .fetchall() ) - expected = [(1, 'h/e/l'), (2, 'l/o'), (3, ',/wor/ld')] + expected = [(1, "h/e/l"), (2, "l/o"), (3, ",/wor/ld")] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_sum(self, table): result = table.sum("v").execute().fetchall() expected = [(29,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.sum("v", groups="id", projected_columns="id").execute().fetchall() expected = [(1, 4), (2, 21), (3, 4)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) - # TODO: Approximate aggregate functions + # TODO: Approximate aggregate functions # noqa: TD002, TD003 - # TODO: Statistical aggregate functions + # TODO: Statistical aggregate functions # noqa: TD002, TD003 def test_median(self, table): result = table.median("v").execute().fetchall() expected = [(2.0,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.median("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1.0), (2, 10.5), (3, 2.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_mode(self, table): result = table.mode("v").execute().fetchall() expected = [(1,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.mode("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def 
test_quantile_cont(self, table): result = table.quantile_cont("v").execute().fetchall() expected = [(2.0,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) - result = [ - list(map(lambda x: round(x, 2), r[0])) for r in table.quantile_cont("v", q=[0.1, 0.5]).execute().fetchall() - ] + assert all(r == e for r, e in zip(result, expected)) + result = [[round(x, 2) for x in r[0]] for r in table.quantile_cont("v", q=[0.1, 0.5]).execute().fetchall()] expected = [[0.2, 2.0]] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = table.quantile_cont("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1.0), (2, 10.5), (3, 2.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ - (r[0], list(map(lambda x: round(x, 2), r[1]))) + (r[0], [round(x, 2) for x in r[1]]) for r in table.quantile_cont("v", q=[0.2, 0.5], groups="id", projected_columns="id") .order("id") .execute() @@ -339,82 +337,82 @@ def test_quantile_cont(self, table): ] expected = [(1, [1.0, 1.0]), (2, [10.2, 10.5]), (3, [0.2, 2.0])] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.parametrize("f", ["quantile_disc", "quantile"]) def test_quantile_disc(self, table, f): result = getattr(table, f)("v").execute().fetchall() expected = [(2,)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = getattr(table, f)("v", q=[0.2, 0.5]).execute().fetchall() expected = [([1, 2],)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = getattr(table, f)("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( getattr(table, f)("v", q=[0.2, 0.8], groups="id", projected_columns="id").order("id").execute().fetchall() ) expected = [(1, [1, 2]), (2, [10, 11]), (3, [-1, 5])] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_std_pop(self, table): result = [round(r[0], 2) for r in table.stddev_pop("v").execute().fetchall()] expected = [4.36] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ (r[0], round(r[1], 2)) for r in table.stddev_pop("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.47), (2, 0.5), (3, 3.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.parametrize("f", ["stddev_samp", "stddev", "std"]) def test_std_samp(self, table, f): result = [round(r[0], 2) for r in getattr(table, f)("v").execute().fetchall()] expected = [4.71] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, 
expected)) result = [ (r[0], round(r[1], 2)) for r in getattr(table, f)("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.58), (2, 0.71), (3, 4.24)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_var_pop(self, table): result = [round(r[0], 2) for r in table.var_pop("v").execute().fetchall()] expected = [18.98] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ (r[0], round(r[1], 2)) for r in table.var_pop("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.22), (2, 0.25), (3, 9.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.parametrize("f", ["var_samp", "variance", "var"]) def test_var_samp(self, table, f): result = [round(r[0], 2) for r in getattr(table, f)("v").execute().fetchall()] expected = [22.14] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ (r[0], round(r[1], 2)) for r in getattr(table, f)("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.33), (2, 0.5), (3, 18.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_describe(self, table): assert table.describe().fetchall() is not None diff --git a/tests/fast/relational_api/test_rapi_close.py b/tests/fast/relational_api/test_rapi_close.py index 270c58f5..969e2792 100644 --- a/tests/fast/relational_api/test_rapi_close.py +++ b/tests/fast/relational_api/test_rapi_close.py @@ -1,9 +1,10 @@ -import duckdb import pytest +import duckdb + # A closed connection should invalidate all relation's methods -class TestRAPICloseConnRel(object): +class TestRAPICloseConnRel: def test_close_conn_rel(self, duckdb_cursor): con = duckdb.connect() con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") @@ -11,153 +12,153 @@ def test_close_conn_rel(self, duckdb_cursor): rel = con.table("items") con.close() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): len(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.aggregate("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.any_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.apply("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, 
match="Connection has already been closed"): rel.arg_max("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.arg_min("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetch_arrow_table() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.avg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_and("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_or("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_xor("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bitstring_agg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bool_and("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bool_or("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.count("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.create("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.create_view("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.cume_dist("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.dense_rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.describe() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.distinct() - with 
pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.execute() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.favg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchall() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchnumpy() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchone() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.filter("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.first("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.first_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fsum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.geomean("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.histogram("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.insert("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.insert_into("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.lag("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.last("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.last_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.lead("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + 
with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel.limit(1)) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.list("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.map(lambda df: df['col0'].add(42).to_frame()) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): + rel.map(lambda df: df["col0"].add(42).to_frame()) + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.max("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.mean("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.median("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.min("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.mode("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.n_tile("", 1) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.nth_value("", "", 1) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.order("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.percent_rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.product("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.project("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile_cont("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile_disc("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already 
been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.query("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.rank_dense("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.row_number("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.std("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev_pop("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev_samp("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.string_agg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.sum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.to_arrow_table() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.to_df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var_pop("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var_samp("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.variance("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.write_csv("") con = duckdb.connect() @@ -166,14 +167,14 @@ def test_close_conn_rel(self, duckdb_cursor): valid_rel = con.table("items") # Test these bad boys when left 
relation is valid - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.union(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.except_(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.intersect(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - valid_rel.join(rel.set_alias('rel'), "rel.items = valid_rel.items") + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): + valid_rel.join(rel.set_alias("rel"), "rel.items = valid_rel.items") def test_del_conn(self, duckdb_cursor): con = duckdb.connect() @@ -181,5 +182,5 @@ def test_del_conn(self, duckdb_cursor): con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") rel = con.table("items") del con - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel) diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 01c8a460..33eb2f7e 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -1,34 +1,35 @@ -import duckdb import pytest +import duckdb + -class TestRAPIDescription(object): +class TestRAPIDescription: def test_rapi_description(self, duckdb_cursor): - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") desc = res.description names = [x[0] for x in desc] types = [x[1] for x in desc] - assert names == ['a', 'b'] - assert types == ['INTEGER', 'BIGINT'] - assert (all([x == duckdb.NUMBER for x in types])) + assert names == ["a", "b"] + assert types == ["INTEGER", "BIGINT"] + assert all(x == duckdb.NUMBER for x in types) def test_rapi_describe(self, duckdb_cursor): np = pytest.importorskip("numpy") - pd = pytest.importorskip("pandas") - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + pytest.importorskip("pandas") + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") duck_describe = res.describe().df() - np.testing.assert_array_equal(duck_describe['aggr'], ['count', 'mean', 'stddev', 'min', 'max', 'median']) - np.testing.assert_array_equal(duck_describe['a'], [1, 42, float('nan'), 42, 42, 42]) - np.testing.assert_array_equal(duck_describe['b'], [1, 84, float('nan'), 84, 84, 84]) + np.testing.assert_array_equal(duck_describe["aggr"], ["count", "mean", "stddev", "min", "max", "median"]) + np.testing.assert_array_equal(duck_describe["a"], [1, 42, float("nan"), 42, 42, 42]) + np.testing.assert_array_equal(duck_describe["b"], [1, 84, float("nan"), 84, 84, 84]) # now with more values res = duckdb_cursor.query( - 'select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)' + "select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)" # noqa: E501 ) duck_describe = res.describe().df() - 
np.testing.assert_allclose(duck_describe['i'], [5000.0, 4999.0, 2887.0400066504103, 0.0, 9998.0, 4999.0]) - np.testing.assert_allclose(duck_describe['j'], [10000.0, 49995.0, 28868.956799071675, 0.0, 99990.0, 49995.0]) - np.testing.assert_allclose(duck_describe['k'], [10000.0, 4258.3518, 2459.207430770227, 0.0, 8517.0, 4258.5]) + np.testing.assert_allclose(duck_describe["i"], [5000.0, 4999.0, 2887.0400066504103, 0.0, 9998.0, 4999.0]) + np.testing.assert_allclose(duck_describe["j"], [10000.0, 49995.0, 28868.956799071675, 0.0, 99990.0, 49995.0]) + np.testing.assert_allclose(duck_describe["k"], [10000.0, 4258.3518, 2459.207430770227, 0.0, 8517.0, 4258.5]) # describe data with other (non-numeric) types res = duckdb_cursor.query("select 'hello world' AS a, [1, 2, 3] AS b") @@ -38,8 +39,8 @@ def test_rapi_describe(self, duckdb_cursor): # describe mixed table res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b, 'hello world' AS c") duck_describe = res.describe().df() - np.testing.assert_array_equal(duck_describe['a'], [1, 42, float('nan'), 42, 42, 42]) - np.testing.assert_array_equal(duck_describe['b'], [1, 84, float('nan'), 84, 84, 84]) + np.testing.assert_array_equal(duck_describe["a"], [1, 42, float("nan"), 42, 42, 42]) + np.testing.assert_array_equal(duck_describe["b"], [1, 84, float("nan"), 84, 84, 84]) # timestamps res = duckdb_cursor.query("select timestamp '1992-01-01', date '2000-01-01'") diff --git a/tests/fast/relational_api/test_rapi_functions.py b/tests/fast/relational_api/test_rapi_functions.py index 92de4c2c..143aa8df 100644 --- a/tests/fast/relational_api/test_rapi_functions.py +++ b/tests/fast/relational_api/test_rapi_functions.py @@ -1,12 +1,12 @@ import duckdb -class TestRAPIFunctions(object): +class TestRAPIFunctions: def test_rapi_str_print(self, duckdb_cursor): - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") assert str(res) is not None res.show() def test_rapi_relation_sql_query(self): - res = duckdb.table_function('range', [10]) + res = duckdb.table_function("range", [10]) assert res.sql_query() == 'SELECT * FROM "range"(10)' diff --git a/tests/fast/relational_api/test_rapi_query.py b/tests/fast/relational_api/test_rapi_query.py index 92f87776..25f8c323 100644 --- a/tests/fast/relational_api/test_rapi_query.py +++ b/tests/fast/relational_api/test_rapi_query.py @@ -1,21 +1,23 @@ -import duckdb -import pytest import platform import sys +import pytest + +import duckdb + -@pytest.fixture() +@pytest.fixture def tbl_table(): con = duckdb.default_connection() con.execute("drop table if exists tbl CASCADE") con.execute("create table tbl (i integer)") yield - con.execute('drop table tbl CASCADE') + con.execute("drop table tbl CASCADE") -@pytest.fixture() +@pytest.fixture def scoped_default(duckdb_cursor): - default = duckdb.connect(':default:') + default = duckdb.connect(":default:") duckdb.set_default_connection(duckdb_cursor) # Overwrite the default connection yield @@ -23,11 +25,11 @@ def scoped_default(duckdb_cursor): duckdb.set_default_connection(default) -class TestRAPIQuery(object): - @pytest.mark.parametrize('steps', [1, 2, 3, 4]) +class TestRAPIQuery: + @pytest.mark.parametrize("steps", [1, 2, 3, 4]) def test_query_chain(self, steps): con = duckdb.default_connection() - amount = int(1000000) + amount = 1000000 rel = None for _ in range(steps): rel = con.query(f"select i from range({amount}::BIGINT) tbl(i)") @@ -36,7 +38,7 @@ def test_query_chain(self, steps): result = 
rel.execute() assert len(result.fetchall()) == amount - @pytest.mark.parametrize('input', [[5, 4, 3], [], [1000]]) + @pytest.mark.parametrize("input", [[5, 4, 3], [], [1000]]) def test_query_table(self, tbl_table, input): con = duckdb.default_connection() rel = con.table("tbl") @@ -45,7 +47,7 @@ def test_query_table(self, tbl_table, input): # Querying a table relation rel = rel.query("x", "select * from x") result = rel.execute() - assert result.fetchall() == [tuple([x]) for x in input] + assert result.fetchall() == [(x,) for x in input] def test_query_table_basic(self, tbl_table): con = duckdb.default_connection() @@ -98,80 +100,80 @@ def test_query_table_unrelated(self, tbl_table): def test_query_non_select_result(self, duckdb_cursor): with pytest.raises(duckdb.ParserException, match="syntax error"): - duckdb_cursor.query('selec 42') + duckdb_cursor.query("selec 42") - res = duckdb_cursor.query('explain select 42').fetchall() + res = duckdb_cursor.query("explain select 42").fetchall() assert len(res) > 0 - res = duckdb_cursor.query('describe select 42::INT AS column_name').fetchall() - assert res[0][0] == 'column_name' + res = duckdb_cursor.query("describe select 42::INT AS column_name").fetchall() + assert res[0][0] == "column_name" - res = duckdb_cursor.query('create or replace table tbl_non_select_result(i integer)') + res = duckdb_cursor.query("create or replace table tbl_non_select_result(i integer)") assert res is None - res = duckdb_cursor.query('insert into tbl_non_select_result values (42)') + res = duckdb_cursor.query("insert into tbl_non_select_result values (42)") assert res is None - res = duckdb_cursor.query('insert into tbl_non_select_result values (84) returning *').fetchall() + res = duckdb_cursor.query("insert into tbl_non_select_result values (84) returning *").fetchall() assert res == [(84,)] - res = duckdb_cursor.query('select * from tbl_non_select_result').fetchall() + res = duckdb_cursor.query("select * from tbl_non_select_result").fetchall() assert res == [(42,), (84,)] - res = duckdb_cursor.query('insert into tbl_non_select_result select * from range(10000) returning *').fetchall() + res = duckdb_cursor.query("insert into tbl_non_select_result select * from range(10000) returning *").fetchall() assert len(res) == 10000 - res = duckdb_cursor.query('show tables').fetchall() + res = duckdb_cursor.query("show tables").fetchall() assert len(res) > 0 - res = duckdb_cursor.query('drop table tbl_non_select_result') + res = duckdb_cursor.query("drop table tbl_non_select_result") assert res is None def test_replacement_scan_recursion(self, duckdb_cursor): depth_limit = 1000 - if sys.platform.startswith('win') or platform.system() == "Emscripten": + if sys.platform.startswith("win") or platform.system() == "Emscripten": # With the default we reach a stack overflow in the CI for windows # and also outside of it for Pyodide depth_limit = 250 duckdb_cursor.execute(f"SET max_expression_depth TO {depth_limit}") - rel = duckdb_cursor.sql('select 42 a, 21 b') - rel = duckdb_cursor.sql('select a+a a, b+b b from rel') - other_rel = duckdb_cursor.sql('select a from rel') + rel = duckdb_cursor.sql("select 42 a, 21 b") + rel = duckdb_cursor.sql("select a+a a, b+b b from rel") # noqa: F841 + other_rel = duckdb_cursor.sql("select a from rel") res = other_rel.fetchall() assert res == [(84,)] def test_set_default_connection(self, scoped_default): duckdb.sql("create table t as select 42") - assert duckdb.table('t').fetchall() == [(42,)] - con = duckdb.connect(':default:') + assert 
duckdb.table("t").fetchall() == [(42,)] + con = duckdb.connect(":default:") # Uses the same db as the module - assert con.table('t').fetchall() == [(42,)] + assert con.table("t").fetchall() == [(42,)] con2 = duckdb.connect() con2.sql("create table t as select 21") - assert con2.table('t').fetchall() == [(21,)] + assert con2.table("t").fetchall() == [(21,)] # Change the db used by the module duckdb.set_default_connection(con2) - with pytest.raises(duckdb.CatalogException, match='Table with name d does not exist'): - con2.table('d').fetchall() + with pytest.raises(duckdb.CatalogException, match="Table with name d does not exist"): + con2.table("d").fetchall() - assert duckdb.table('t').fetchall() == [(21,)] + assert duckdb.table("t").fetchall() == [(21,)] duckdb.sql("create table d as select [1,2,3]") - assert duckdb.table('d').fetchall() == [([1, 2, 3],)] - assert con2.table('d').fetchall() == [([1, 2, 3],)] + assert duckdb.table("d").fetchall() == [([1, 2, 3],)] + assert con2.table("d").fetchall() == [([1, 2, 3],)] def test_set_default_connection_error(self, scoped_default): - with pytest.raises(TypeError, match='Invoked with: None'): + with pytest.raises(TypeError, match="Invoked with: None"): # set_default_connection does not allow None duckdb.set_default_connection(None) - with pytest.raises(TypeError, match='Invoked with: 5'): + with pytest.raises(TypeError, match="Invoked with: 5"): duckdb.set_default_connection(5) assert duckdb.sql("select 42").fetchall() == [(42,)] diff --git a/tests/fast/relational_api/test_rapi_windows.py b/tests/fast/relational_api/test_rapi_windows.py index 7c13debc..28d533b7 100644 --- a/tests/fast/relational_api/test_rapi_windows.py +++ b/tests/fast/relational_api/test_rapi_windows.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + @pytest.fixture(autouse=True) def setup_and_teardown_of_table(duckdb_cursor): @@ -15,14 +16,14 @@ def setup_and_teardown_of_table(duckdb_cursor): (2, 11, -1, 10.45, 'o'), (3, -1, 0, 13.32, ','), (3, 5, -2, 9.87, 'wor'), - (3, null, 10, 6.56, 'ld'); + (3, null, 10, 6.56, 'ld'); """ ) yield duckdb_cursor.execute("drop table win") -@pytest.fixture() +@pytest.fixture def table(duckdb_cursor): return duckdb_cursor.table("win") @@ -33,7 +34,7 @@ def test_row_number(self, table): result = table.row_number("over ()").execute().fetchall() expected = list(range(1, 9)) assert len(result) == len(expected) - assert all([r[0] == e for r, e in zip(result, expected)]) + assert all(r[0] == e for r, e in zip(result, expected)) result = table.row_number("over (partition by id order by t asc)", "id, v, t").order("id").execute().fetchall() expected = [ (1, 1, 1, 1), @@ -46,34 +47,34 @@ def test_row_number(self, table): (3, None, 10, 3), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_rank(self, table): result = table.rank("over ()").execute().fetchall() expected = [1] * 8 assert len(result) == len(expected) - assert all([r[0] == e for r, e in zip(result, expected)]) + assert all(r[0] == e for r, e in zip(result, expected)) result = table.rank("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [(1, 1, 1), (1, 1, 1), (1, 2, 3), (2, 10, 1), (2, 11, 2), (3, -1, 1), (3, 5, 2), (3, None, 3)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.parametrize("f", ["dense_rank", 
"rank_dense"]) def test_dense_rank(self, table, f): result = getattr(table, f)("over ()").execute().fetchall() expected = [1] * 8 assert len(result) == len(expected) - assert all([r[0] == e for r, e in zip(result, expected)]) + assert all(r[0] == e for r, e in zip(result, expected)) result = getattr(table, f)("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [(1, 1, 1), (1, 1, 1), (1, 2, 2), (2, 10, 1), (2, 11, 2), (3, -1, 1), (3, 5, 2), (3, None, 3)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_percent_rank(self, table): result = table.percent_rank("over ()").execute().fetchall() expected = [0.0] * 8 assert len(result) == len(expected) - assert all([r[0] == e for r, e in zip(result, expected)]) + assert all(r[0] == e for r, e in zip(result, expected)) result = table.percent_rank("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [ (1, 1, 0.0), @@ -86,13 +87,13 @@ def test_percent_rank(self, table): (3, None, 1.0), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_cume_dist(self, table): result = table.cume_dist("over ()").execute().fetchall() expected = [1.0] * 8 assert len(result) == len(expected) - assert all([r[0] == e for r, e in zip(result, expected)]) + assert all(r[0] == e for r, e in zip(result, expected)) result = table.cume_dist("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [ (1, 1, 2 / 3), @@ -105,13 +106,13 @@ def test_cume_dist(self, table): (3, None, 1.0), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_ntile(self, table): result = table.n_tile("over (order by v)", 3, "v").execute().fetchall() expected = [(-1, 1), (1, 1), (1, 1), (2, 2), (5, 2), (10, 2), (11, 3), (None, 3)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_lag(self, table): result = ( @@ -131,7 +132,7 @@ def test_lag(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( table.lag("v", "over (partition by id order by t asc)", default_value="-1", projected_columns="id, v, t") .order("id") @@ -149,7 +150,7 @@ def test_lag(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( table.lag("v", "over (partition by id order by t asc)", offset=2, projected_columns="id, v, t") .order("id") @@ -167,7 +168,7 @@ def test_lag(self, table): (3, None, 10, 5), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_lead(self, table): result = ( @@ -187,7 +188,7 @@ def test_lead(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( table.lead("v", "over (partition by id order by t asc)", default_value="-1", projected_columns="id, v, t") .order("id") 
@@ -205,7 +206,7 @@ def test_lead(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( table.lead("v", "over (partition by id order by t asc)", offset=2, projected_columns="id, v, t") .order("id") @@ -223,7 +224,7 @@ def test_lead(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_first_value(self, table): result = ( @@ -243,7 +244,7 @@ def test_first_value(self, table): (3, None, 10, 5), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_last_value(self, table): result = ( @@ -267,7 +268,7 @@ def test_last_value(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_nth_value(self, table): result = ( @@ -287,7 +288,7 @@ def test_nth_value(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( table.nth_value("v", "over (partition by id order by t asc)", offset=4, projected_columns="id, v, t") .order("id") @@ -305,7 +306,7 @@ def test_nth_value(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) # agg functions within win def test_any_value(self, table): @@ -317,7 +318,7 @@ def test_any_value(self, table): ) expected = [(1, 1), (1, 1), (1, 1), (2, 11), (2, 11), (3, 5), (3, 5), (3, 5)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_arg_max(self, table): result = ( @@ -328,7 +329,7 @@ def test_arg_max(self, table): ) expected = [(1, 3), (1, 3), (1, 3), (2, -1), (2, -1), (3, -2), (3, -2), (3, -2)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_arg_min(self, table): result = ( @@ -339,7 +340,7 @@ def test_arg_min(self, table): ) expected = [(1, 2), (1, 2), (1, 2), (2, 4), (2, 4), (3, 0), (3, 0), (3, 0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_avg(self, table): result = [ @@ -347,7 +348,8 @@ def test_avg(self, table): for r in ( table.avg( "v", - window_spec="over (partition by id order by t asc rows between unbounded preceding and current row)", + window_spec="over (partition by id order by t asc rows between unbounded preceding and " + "current row)", projected_columns="id", ) .order("id") @@ -357,7 +359,7 @@ def test_avg(self, table): ] expected = [(1, 1.0), (1, 1.0), (1, 1.33), (2, 11.0), (2, 10.5), (3, 5.0), (3, 2.0), (3, 2.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bit_and(self, table): result = ( @@ -372,7 +374,7 @@ def test_bit_and(self, table): ) expected = [(1, 1), (1, 1), (1, 0), (2, 11), (2, 10), (3, 5), (3, 5), (3, 5)] assert len(result) == 
len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bit_or(self, table): result = ( @@ -387,7 +389,7 @@ def test_bit_or(self, table): ) expected = [(1, 1), (1, 1), (1, 3), (2, 11), (2, 11), (3, 5), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bit_xor(self, table): result = ( @@ -402,20 +404,15 @@ def test_bit_xor(self, table): ) expected = [(1, 1), (1, 0), (1, 2), (2, 11), (2, 1), (3, 5), (3, -6), (3, -6)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bitstring_agg(self, table): with pytest.raises(duckdb.BinderException, match="Could not retrieve required statistics"): - result = ( - table.bitstring_agg( - "v", - window_spec="over (partition by id order by t asc rows between unbounded preceding and current row)", - projected_columns="id", - ) - .order("id") - .execute() - .fetchall() - ) + table.bitstring_agg( + "v", + window_spec="over (partition by id order by t asc rows between unbounded preceding and current row)", + projected_columns="id", + ).order("id").execute().fetchall() result = ( table.bitstring_agg( "v", @@ -429,17 +426,17 @@ def test_bitstring_agg(self, table): .fetchall() ) expected = [ - (1, '0010000000000'), - (1, '0010000000000'), - (1, '0011000000000'), - (2, '0000000000001'), - (2, '0000000000011'), - (3, '0000001000000'), - (3, '1000001000000'), - (3, '1000001000000'), + (1, "0010000000000"), + (1, "0010000000000"), + (1, "0011000000000"), + (2, "0000000000001"), + (2, "0000000000011"), + (3, "0000001000000"), + (3, "1000001000000"), + (3, "1000001000000"), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bool_and(self, table): result = ( @@ -450,7 +447,7 @@ def test_bool_and(self, table): ) expected = [(1, True), (1, True), (1, True), (2, True), (2, True), (3, False), (3, False), (3, False)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_bool_or(self, table): result = ( @@ -461,7 +458,7 @@ def test_bool_or(self, table): ) expected = [(1, True), (1, True), (1, True), (2, True), (2, True), (3, True), (3, True), (3, True)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_count(self, table): result = ( @@ -476,7 +473,7 @@ def test_count(self, table): ) expected = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (3, 1), (3, 2), (3, 3)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_favg(self, table): result = [ @@ -492,7 +489,7 @@ def test_favg(self, table): ] expected = [(1, 0.21), (1, 0.38), (1, 0.25), (2, 10.45), (2, 5.24), (3, 9.87), (3, 11.59), (3, 9.92)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_fsum(self, table): result = [ @@ -508,7 +505,7 @@ def test_fsum(self, table): ] expected = [(1, 0.21), (1, 0.75), (1, 0.75), (2, 10.45), (2, 10.49), (3, 9.87), (3, 23.19), (3, 29.75)] assert len(result) == len(expected) - 
assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.skip(reason="geomean is not supported from a windowing context") def test_geomean(self, table): @@ -536,7 +533,7 @@ def test_histogram(self, table): (3, {-1: 1, 5: 1}), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_list(self, table): result = ( @@ -560,7 +557,7 @@ def test_list(self, table): (3, [5, -1, None]), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_max(self, table): result = ( @@ -575,7 +572,7 @@ def test_max(self, table): ) expected = [(1, 1), (1, 1), (1, 2), (2, 11), (2, 11), (3, 5), (3, 5), (3, 5)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_min(self, table): result = ( @@ -590,7 +587,7 @@ def test_min(self, table): ) expected = [(1, 1), (1, 1), (1, 1), (2, 11), (2, 10), (3, 5), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_product(self, table): result = ( @@ -605,7 +602,7 @@ def test_product(self, table): ) expected = [(1, 1), (1, 1), (1, 2), (2, 11), (2, 110), (3, 5), (3, -5), (3, -5)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_string_agg(self, table): result = ( @@ -619,9 +616,9 @@ def test_string_agg(self, table): .execute() .fetchall() ) - expected = [(1, 'e'), (1, 'e/h'), (1, 'e/h/l'), (2, 'o'), (2, 'o/l'), (3, 'wor'), (3, 'wor/,'), (3, 'wor/,/ld')] + expected = [(1, "e"), (1, "e/h"), (1, "e/h/l"), (2, "o"), (2, "o/l"), (3, "wor"), (3, "wor/,"), (3, "wor/,/ld")] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_sum(self, table): result = ( @@ -636,7 +633,7 @@ def test_sum(self, table): ) expected = [(1, 1), (1, 2), (1, 4), (2, 11), (2, 21), (3, 5), (3, 4), (3, 4)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_median(self, table): result = ( @@ -651,7 +648,7 @@ def test_median(self, table): ) expected = [(1, 1.0), (1, 1.0), (1, 1.0), (2, 11.0), (2, 10.5), (3, 5.0), (3, 2.0), (3, 2.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_mode(self, table): result = ( @@ -666,7 +663,7 @@ def test_mode(self, table): ) expected = [(1, 2), (1, 2), (1, 1), (2, 10), (2, 10), (3, None), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_quantile_cont(self, table): result = ( @@ -681,9 +678,9 @@ def test_quantile_cont(self, table): ) expected = [(1, 2.0), (1, 1.5), (1, 1.0), (2, 10.0), (2, 10.5), (3, None), (3, -1.0), (3, 2.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = [ - (r[0], list(map(lambda x: round(x, 2), r[1])) if r[1] is not None 
else None) + (r[0], [round(x, 2) for x in r[1]] if r[1] is not None else None) for r in table.quantile_cont( "v", q=[0.2, 0.5], @@ -705,7 +702,7 @@ def test_quantile_cont(self, table): (3, [0.2, 2.0]), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.parametrize("f", ["quantile_disc", "quantile"]) def test_quantile_disc(self, table, f): @@ -721,7 +718,7 @@ def test_quantile_disc(self, table, f): ) expected = [(1, 2), (1, 1), (1, 1), (2, 10), (2, 10), (3, None), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) result = ( getattr(table, f)( "v", @@ -744,7 +741,7 @@ def test_quantile_disc(self, table, f): (3, [-1, 5]), ] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_stddev_pop(self, table): result = [ @@ -760,7 +757,7 @@ def test_stddev_pop(self, table): ] expected = [(1, 0.0), (1, 0.5), (1, 0.47), (2, 0.0), (2, 0.5), (3, None), (3, 0.0), (3, 3.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.parametrize("f", ["stddev_samp", "stddev", "std"]) def test_stddev_samp(self, table, f): @@ -777,7 +774,7 @@ def test_stddev_samp(self, table, f): ] expected = [(1, None), (1, 0.71), (1, 0.58), (2, None), (2, 0.71), (3, None), (3, None), (3, 4.24)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) def test_var_pop(self, table): result = [ @@ -793,7 +790,7 @@ def test_var_pop(self, table): ] expected = [(1, 0.0), (1, 0.25), (1, 0.22), (2, 0.0), (2, 0.25), (3, None), (3, 0.0), (3, 9.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) @pytest.mark.parametrize("f", ["var_samp", "variance", "var"]) def test_var_samp(self, table, f): @@ -810,4 +807,4 @@ def test_var_samp(self, table, f): ] expected = [(1, None), (1, 0.5), (1, 0.33), (2, None), (2, 0.5), (3, None), (3, None), (3, 18.0)] assert len(result) == len(expected) - assert all([r == e for r, e in zip(result, expected)]) + assert all(r == e for r, e in zip(result, expected)) diff --git a/tests/fast/relational_api/test_table_function.py b/tests/fast/relational_api/test_table_function.py index 4f4a1016..4dcf5f93 100644 --- a/tests/fast/relational_api/test_table_function.py +++ b/tests/fast/relational_api/test_table_function.py @@ -1,17 +1,19 @@ -import duckdb +from pathlib import Path + import pytest -import os -script_path = os.path.dirname(__file__) +import duckdb + +script_path = Path(__file__).parent -class TestTableFunction(object): +class TestTableFunction: def test_table_function(self, duckdb_cursor): - path = os.path.join(script_path, '..', 'data/integers.csv') - rel = duckdb_cursor.table_function('read_csv', [path]) + path = str(script_path / ".." 
/ "data/integers.csv") + rel = duckdb_cursor.table_function("read_csv", [path]) res = rel.fetchall() assert res == [(1, 10, 0), (2, 50, 30)] # Provide only a string as argument, should error, needs a list with pytest.raises(duckdb.InvalidInputException, match=r"'params' has to be a list of parameters"): - rel = duckdb_cursor.table_function('read_csv', path) + rel = duckdb_cursor.table_function("read_csv", path) diff --git a/tests/fast/spark/test_replace_column_value.py b/tests/fast/spark/test_replace_column_value.py index 33940616..17a2254e 100644 --- a/tests/fast/spark/test_replace_column_value.py +++ b/tests/fast/spark/test_replace_column_value.py @@ -4,7 +4,7 @@ from spark_namespace.sql.types import Row -class TestReplaceValue(object): +class TestReplaceValue: # https://sparkbyexamples.com/pyspark/pyspark-replace-column-values/?expand_article=1 def test_replace_value(self, spark): address = [(1, "14851 Jeffrey Rd", "DE"), (2, "43421 Margarita St", "NY"), (3, "13111 Siemon Ave", "CA")] @@ -13,7 +13,7 @@ def test_replace_value(self, spark): # Replace part of string with another string from spark_namespace.sql.functions import regexp_replace - df2 = df.withColumn('address', regexp_replace('address', 'Rd', 'Road')) + df2 = df.withColumn("address", regexp_replace("address", "Rd", "Road")) # Replace string column value conditionally from spark_namespace.sql.functions import when @@ -21,24 +21,24 @@ def test_replace_value(self, spark): res = df2.collect() print(res) df2 = df.withColumn( - 'address', - when(df.address.endswith('Rd'), regexp_replace(df.address, 'Rd', 'Road')) - .when(df.address.endswith('St'), regexp_replace(df.address, 'St', 'Street')) - .when(df.address.endswith('Ave'), regexp_replace(df.address, 'Ave', 'Avenue')) + "address", + when(df.address.endswith("Rd"), regexp_replace(df.address, "Rd", "Road")) + .when(df.address.endswith("St"), regexp_replace(df.address, "St", "Street")) + .when(df.address.endswith("Ave"), regexp_replace(df.address, "Ave", "Avenue")) .otherwise(df.address), ) res = df2.collect() print(res) expected = [ - Row(id=1, address='14851 Jeffrey Road', state='DE'), - Row(id=2, address='43421 Margarita Street', state='NY'), - Row(id=3, address='13111 Siemon Avenue', state='CA'), + Row(id=1, address="14851 Jeffrey Road", state="DE"), + Row(id=2, address="43421 Margarita Street", state="NY"), + Row(id=3, address="13111 Siemon Avenue", state="CA"), ] print(expected) assert res == expected # Replace all substrings of the specified string value that match regexp with rep. 
- df3 = spark.createDataFrame([('100-200',)], ['str']) - res = df3.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() - expected = [Row(d='-----')] + df3 = spark.createDataFrame([("100-200",)], ["str"]) + res = df3.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() + expected = [Row(d="-----")] print(expected) assert res == expected diff --git a/tests/fast/spark/test_replace_empty_value.py b/tests/fast/spark/test_replace_empty_value.py index 71a9f25f..0a078167 100644 --- a/tests/fast/spark/test_replace_empty_value.py +++ b/tests/fast/spark/test_replace_empty_value.py @@ -2,42 +2,41 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql.types import Row # https://sparkbyexamples.com/pyspark/pyspark-replace-empty-value-with-none-on-dataframe-2/?expand_article=1 -class TestReplaceEmpty(object): +class TestReplaceEmpty: def test_replace_empty(self, spark): # Create the dataframe data = [("", "CA"), ("Julia", ""), ("Robert", ""), ("", "NJ")] df = spark.createDataFrame(data, ["name", "state"]) - res = df.select('name').collect() - assert res == [Row(name=''), Row(name='Julia'), Row(name='Robert'), Row(name='')] - res = df.select('state').collect() - assert res == [Row(state='CA'), Row(state=''), Row(state=''), Row(state='NJ')] + res = df.select("name").collect() + assert res == [Row(name=""), Row(name="Julia"), Row(name="Robert"), Row(name="")] + res = df.select("state").collect() + assert res == [Row(state="CA"), Row(state=""), Row(state=""), Row(state="NJ")] # Replace name # CASE WHEN "name" == '' THEN NULL ELSE "name" END from spark_namespace.sql.functions import col, when df2 = df.withColumn("name", when(col("name") == "", None).otherwise(col("name"))) - assert df2.columns == ['name', 'state'] - res = df2.select('name').collect() - assert res == [Row(name=None), Row(name='Julia'), Row(name='Robert'), Row(name=None)] + assert df2.columns == ["name", "state"] + res = df2.select("name").collect() + assert res == [Row(name=None), Row(name="Julia"), Row(name="Robert"), Row(name=None)] # Replace state + name from spark_namespace.sql.functions import col, when df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in df.columns]) - assert df2.columns == ['name', 'state'] - key_f = lambda x: x.name or x.state + assert df2.columns == ["name", "state"] + res = df2.sort("name", "state").collect() expected_res = [ - Row(name=None, state='CA'), - Row(name=None, state='NJ'), - Row(name='Julia', state=None), - Row(name='Robert', state=None), + Row(name=None, state="CA"), + Row(name=None, state="NJ"), + Row(name="Julia", state=None), + Row(name="Robert", state=None), ] assert res == expected_res @@ -46,15 +45,17 @@ def test_replace_empty(self, spark): from spark_namespace.sql.functions import col, when replaceCols = ["state"] - df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in replaceCols]).sort(col('state')) - assert df2.columns == ['state'] + df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in replaceCols]).sort(col("state")) + assert df2.columns == ["state"] + + def key_f(x): + return x.state or "" - key_f = lambda x: x.state or "" res = df2.collect() assert sorted(res, key=key_f) == sorted( [ - Row(state='CA'), - Row(state='NJ'), + Row(state="CA"), + Row(state="NJ"), Row(state=None), Row(state=None), ], diff --git a/tests/fast/spark/test_spark_arrow_table.py b/tests/fast/spark/test_spark_arrow_table.py index 57c81599..fc773562 
100644 --- a/tests/fast/spark/test_spark_arrow_table.py +++ b/tests/fast/spark/test_spark_arrow_table.py @@ -2,8 +2,6 @@ _ = pytest.importorskip("duckdb.experimental.spark") pa = pytest.importorskip("pyarrow") -from spark_namespace import USE_ACTUAL_SPARK - from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql.dataframe import DataFrame diff --git a/tests/fast/spark/test_spark_catalog.py b/tests/fast/spark/test_spark_catalog.py index 7f523abd..8a07a0a7 100644 --- a/tests/fast/spark/test_spark_catalog.py +++ b/tests/fast/spark/test_spark_catalog.py @@ -3,19 +3,19 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.catalog import Table, Database, Column +from spark_namespace.sql.catalog import Column, Database, Table -class TestSparkCatalog(object): +class TestSparkCatalog: def test_list_databases(self, spark): dbs = spark.catalog.listDatabases() if USE_ACTUAL_SPARK: assert all(isinstance(db, Database) for db in dbs) else: assert dbs == [ - Database(name='memory', description=None, locationUri=''), - Database(name='system', description=None, locationUri=''), - Database(name='temp', description=None, locationUri=''), + Database(name="memory", description=None, locationUri=""), + Database(name="system", description=None, locationUri=""), + Database(name="temp", description=None, locationUri=""), ] def test_list_tables(self, spark): @@ -26,31 +26,31 @@ def test_list_tables(self, spark): if not USE_ACTUAL_SPARK: # Skip this if we're using actual Spark because we can't create tables # with our setup. - spark.sql('create table tbl(a varchar)') + spark.sql("create table tbl(a varchar)") tbls = spark.catalog.listTables() assert tbls == [ Table( - name='tbl', - database='memory', - description='CREATE TABLE tbl(a VARCHAR);', - tableType='', + name="tbl", + database="memory", + description="CREATE TABLE tbl(a VARCHAR);", + tableType="", isTemporary=False, ) ] @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup") def test_list_columns(self, spark): - spark.sql('create table tbl(a varchar, b bool)') - columns = spark.catalog.listColumns('tbl') + spark.sql("create table tbl(a varchar, b bool)") + columns = spark.catalog.listColumns("tbl") assert columns == [ - Column(name='a', description=None, dataType='VARCHAR', nullable=True, isPartition=False, isBucket=False), - Column(name='b', description=None, dataType='BOOLEAN', nullable=True, isPartition=False, isBucket=False), + Column(name="a", description=None, dataType="VARCHAR", nullable=True, isPartition=False, isBucket=False), + Column(name="b", description=None, dataType="BOOLEAN", nullable=True, isPartition=False, isBucket=False), ] - # FIXME: should this error instead? - non_existant_columns = spark.catalog.listColumns('none_existant') + # TODO: should this error instead? 
# noqa: TD002, TD003 + non_existant_columns = spark.catalog.listColumns("none_existant") assert non_existant_columns == [] - spark.sql('create view vw as select * from tbl') - view_columns = spark.catalog.listColumns('vw') + spark.sql("create view vw as select * from tbl") + view_columns = spark.catalog.listColumns("vw") assert view_columns == columns diff --git a/tests/fast/spark/test_spark_column.py b/tests/fast/spark/test_spark_column.py index e56ba9ee..b2656643 100644 --- a/tests/fast/spark/test_spark_column.py +++ b/tests/fast/spark/test_spark_column.py @@ -2,42 +2,41 @@ _ = pytest.importorskip("duckdb.experimental.spark") +import re + from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.column import Column -from spark_namespace.sql.functions import struct, array, col -from spark_namespace.sql.types import Row from spark_namespace.errors import PySparkTypeError - -import duckdb -import re +from spark_namespace.sql.functions import array, col, struct +from spark_namespace.sql.types import Row -class TestSparkColumn(object): +class TestSparkColumn: def test_struct_column(self, spark): df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)]) - # FIXME: column names should be set explicitly using the Row, rather than letting duckdb assign defaults (col0, col1, etc..) + # TODO: column names should be set explicitly using the Row, rather than letting duckdb # noqa: TD002, TD003 + # assign defaults(col0, col1, etc..) if USE_ACTUAL_SPARK: - df = df.withColumn('struct', struct(df.a, df.b)) + df = df.withColumn("struct", struct(df.a, df.b)) else: - df = df.withColumn('struct', struct(df.col0, df.col1)) - assert 'struct' in df - new_col = df.schema['struct'] + df = df.withColumn("struct", struct(df.col0, df.col1)) + assert "struct" in df + new_col = df.schema["struct"] if USE_ACTUAL_SPARK: - assert 'a' in df.schema['struct'].dataType.fieldNames() - assert 'b' in df.schema['struct'].dataType.fieldNames() + assert "a" in df.schema["struct"].dataType.fieldNames() + assert "b" in df.schema["struct"].dataType.fieldNames() else: - assert 'col0' in new_col.dataType - assert 'col1' in new_col.dataType + assert "col0" in new_col.dataType + assert "col1" in new_col.dataType with pytest.raises( PySparkTypeError, match=re.escape("[NOT_COLUMN] Argument `col` should be a Column, got str.") ): - df = df.withColumn('struct', 'yes') + df = df.withColumn("struct", "yes") def test_array_column(self, spark): - df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)], ['a', 'b', 'c', 'd']) + df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)], ["a", "b", "c", "d"]) df2 = df.select( array(df["a"], df["b"]).alias("array"), diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index d88b03eb..95a6b3a8 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -2,43 +2,40 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError, PySparkValueError +from spark_namespace.sql.column import Column +from spark_namespace.sql.functions import col, struct, when from spark_namespace.sql.types import ( - LongType, - StructType, + ArrayType, BooleanType, - StructField, - StringType, IntegerType, LongType, - Row, - ArrayType, MapType, + Row, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when -from spark_namespace.sql.column import Column -import duckdb -import re - -from 
spark_namespace.errors import PySparkValueError, PySparkTypeError def assert_column_objects_equal(col1: Column, col2: Column): - assert type(col1) == type(col2) + assert type(col1) is type(col2) if not USE_ACTUAL_SPARK: assert col1.expr == col2.expr -class TestDataFrame(object): +class TestDataFrame: def test_dataframe_from_list_of_tuples(self, spark): # Valid address = [(1, "14851 Jeffrey Rd", "DE"), (2, "43421 Margarita St", "NY"), (3, "13111 Siemon Ave", "CA")] df = spark.createDataFrame(address, ["id", "address", "state"]) res = df.collect() assert res == [ - Row(id=1, address='14851 Jeffrey Rd', state='DE'), - Row(id=2, address='43421 Margarita St', state='NY'), - Row(id=3, address='13111 Siemon Ave', state='CA'), + Row(id=1, address="14851 Jeffrey Rd", state="DE"), + Row(id=2, address="43421 Margarita St", state="NY"), + Row(id=3, address="13111 Siemon Ave", state="CA"), ] # Tuples of different sizes @@ -48,23 +45,22 @@ def test_dataframe_from_list_of_tuples(self, spark): from py4j.protocol import Py4JJavaError with pytest.raises(Py4JJavaError): - df = spark.createDataFrame(address, ["id", "address", "state"]) - df.collect() + spark.createDataFrame(address, ["id", "address", "state"]) else: with pytest.raises(PySparkTypeError, match="LENGTH_SHOULD_BE_THE_SAME"): - df = spark.createDataFrame(address, ["id", "address", "state"]) + spark.createDataFrame(address, ["id", "address", "state"]) # Dataframe instead of list with pytest.raises(PySparkTypeError, match="SHOULD_NOT_DATAFRAME"): - df = spark.createDataFrame(df, ["id", "address", "state"]) + spark.createDataFrame(df, ["id", "address", "state"]) # Not a list with pytest.raises(TypeError, match="not iterable"): - df = spark.createDataFrame(5, ["id", "address", "test"]) + spark.createDataFrame(5, ["id", "address", "test"]) # Empty list if not USE_ACTUAL_SPARK: - # FIXME: Spark raises PySparkValueError [CANNOT_INFER_EMPTY_SCHEMA] + # TODO: Spark raises PySparkValueError [CANNOT_INFER_EMPTY_SCHEMA] # noqa: TD002, TD003 df = spark.createDataFrame([], ["id", "address", "test"]) res = df.collect() assert res == [] @@ -73,7 +69,10 @@ def test_dataframe_from_list_of_tuples(self, spark): address = [(1, "14851 Jeffrey Rd", "DE"), (2, "43421 Margarita St", "NY"), (3, "13111 Siemon Ave", "DE")] df = spark.createDataFrame(address, ["id", "address", "id"]) res = df.collect() - exptected_res_str = "[Row(id=1, address='14851 Jeffrey Rd', id='DE'), Row(id=2, address='43421 Margarita St', id='NY'), Row(id=3, address='13111 Siemon Ave', id='DE')]" + exptected_res_str = ( + "[Row(id=1, address='14851 Jeffrey Rd', id='DE'), Row(id=2, address='43421 " + "Margarita St', id='NY'), Row(id=3, address='13111 Siemon Ave', id='DE')]" + ) if USE_ACTUAL_SPARK: # Spark uses string for both ID columns. DuckDB correctly infers the types. 
exptected_res_str = ( @@ -83,50 +82,50 @@ def test_dataframe_from_list_of_tuples(self, spark): # Not enough column names if not USE_ACTUAL_SPARK: - # FIXME: Spark does not raise this error + # TODO: Spark does not raise this error # noqa: TD002, TD003 with pytest.raises(PySparkValueError, match="number of columns in the DataFrame don't match"): - df = spark.createDataFrame(address, ["id", "address"]) + spark.createDataFrame(address, ["id", "address"]) # Empty column names list # Columns are filled in with default names - # TODO: check against Spark behavior + # TODO: check against Spark behavior # noqa: TD002, TD003 df = spark.createDataFrame(address, []) res = df.collect() assert res == [ - Row(col0=1, col1='14851 Jeffrey Rd', col2='DE'), - Row(col0=2, col1='43421 Margarita St', col2='NY'), - Row(col0=3, col1='13111 Siemon Ave', col2='DE'), + Row(col0=1, col1="14851 Jeffrey Rd", col2="DE"), + Row(col0=2, col1="43421 Margarita St", col2="NY"), + Row(col0=3, col1="13111 Siemon Ave", col2="DE"), ] # Too many column names if not USE_ACTUAL_SPARK: # In Spark, this leads to an IndexError with pytest.raises(PySparkValueError, match="number of columns in the DataFrame don't match"): - df = spark.createDataFrame(address, ["id", "address", "one", "two", "three"]) + spark.createDataFrame(address, ["id", "address", "one", "two", "three"]) # Column names is not a list (but is iterable) if not USE_ACTUAL_SPARK: # These things do not work in Spark or throw different errors - df = spark.createDataFrame(address, {'a': 5, 'b': 6, 'c': 42}) + df = spark.createDataFrame(address, {"a": 5, "b": 6, "c": 42}) res = df.collect() assert res == [ - Row(a=1, b='14851 Jeffrey Rd', c='DE'), - Row(a=2, b='43421 Margarita St', c='NY'), - Row(a=3, b='13111 Siemon Ave', c='DE'), + Row(a=1, b="14851 Jeffrey Rd", c="DE"), + Row(a=2, b="43421 Margarita St", c="NY"), + Row(a=3, b="13111 Siemon Ave", c="DE"), ] # Column names is not a list (string, becomes a single column name) with pytest.raises(PySparkValueError, match="number of columns in the DataFrame don't match"): - df = spark.createDataFrame(address, 'a') + spark.createDataFrame(address, "a") with pytest.raises(TypeError, match="must be an iterable, not int"): - df = spark.createDataFrame(address, 5) + spark.createDataFrame(address, 5) def test_dataframe(self, spark): # Create DataFrame df = spark.createDataFrame([("Scala", 25000), ("Spark", 35000), ("PHP", 21000)]) res = df.collect() - assert res == [Row(col0='Scala', col1=25000), Row(col0='Spark', col1=35000), Row(col0='PHP', col1=21000)] + assert res == [Row(col0="Scala", col1=25000), Row(col0="Spark", col1=35000), Row(col0="PHP", col1=21000)] @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup") def test_writing_to_table(self, spark): @@ -136,65 +135,65 @@ def test_writing_to_table(self, spark): create table sample_table("_1" bool, "_2" integer) """ ) - spark.sql('insert into sample_table VALUES (True, 42)') + spark.sql("insert into sample_table VALUES (True, 42)") spark.table("sample_table").write.saveAsTable("sample_hive_table") df3 = spark.sql("SELECT _1,_2 FROM sample_hive_table") res = df3.collect() assert res == [Row(_1=True, _2=42)] schema = df3.schema - assert schema == StructType([StructField('_1', BooleanType(), True), StructField('_2', IntegerType(), True)]) + assert schema == StructType([StructField("_1", BooleanType(), True), StructField("_2", IntegerType(), True)]) def test_dataframe_collect(self, spark): - df = spark.createDataFrame([(42,), (21,)]).toDF('a') 
+ df = spark.createDataFrame([(42,), (21,)]).toDF("a") res = df.collect() - assert str(res) == '[Row(a=42), Row(a=21)]' + assert str(res) == "[Row(a=42), Row(a=21)]" def test_dataframe_from_rows(self, spark): columns = ["language", "users_count"] data = [("Java", "20000"), ("Python", "100000"), ("Scala", "3000")] - rowData = map(lambda x: Row(*x), data) + rowData = (Row(*x) for x in data) df = spark.createDataFrame(rowData, columns) res = df.collect() assert res == [ - Row(language='Java', users_count='20000'), - Row(language='Python', users_count='100000'), - Row(language='Scala', users_count='3000'), + Row(language="Java", users_count="20000"), + Row(language="Python", users_count="100000"), + Row(language="Scala", users_count="3000"), ] def test_empty_df(self, spark): schema = StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ) df = spark.createDataFrame([], schema=schema) res = df.collect() - # TODO: assert that the types and column names are correct + # TODO: assert that the types and column names are correct # noqa: TD002, TD003 assert res == [] def test_df_from_pandas(self, spark): import pandas as pd - df = spark.createDataFrame(pd.DataFrame({'a': [42, 21], 'b': [True, False]})) + df = spark.createDataFrame(pd.DataFrame({"a": [42, 21], "b": [True, False]})) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_from_struct_type(self, spark): - schema = StructType([StructField('a', LongType()), StructField('b', BooleanType())]) + schema = StructType([StructField("a", LongType()), StructField("b", BooleanType())]) df = spark.createDataFrame([(42, True), (21, False)], schema) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_from_name_list(self, spark): - df = spark.createDataFrame([(42, True), (21, False)], ['a', 'b']) + df = spark.createDataFrame([(42, True), (21, False)], ["a", "b"]) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_creation_coverage(self, spark): - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType data2 = [ ("James", "", "Smith", "36636", "M", 3000), @@ -218,11 +217,11 @@ def test_df_creation_coverage(self, spark): df = spark.createDataFrame(data=data2, schema=schema) res = df.collect() assert res == [ - Row(firstname='James', middlename='', lastname='Smith', id='36636', gender='M', salary=3000), - Row(firstname='Michael', middlename='Rose', lastname='', id='40288', gender='M', salary=4000), - Row(firstname='Robert', middlename='', lastname='Williams', id='42114', gender='M', salary=4000), - Row(firstname='Maria', middlename='Anne', lastname='Jones', id='39192', gender='F', salary=4000), - Row(firstname='Jen', middlename='Mary', lastname='Brown', id='', gender='F', salary=-1), + Row(firstname="James", middlename="", lastname="Smith", id="36636", gender="M", salary=3000), + Row(firstname="Michael", middlename="Rose", lastname="", id="40288", gender="M", salary=4000), + Row(firstname="Robert", middlename="", lastname="Williams", id="42114", gender="M", salary=4000), + Row(firstname="Maria", middlename="Anne", lastname="Jones", id="39192", gender="F", salary=4000), + Row(firstname="Jen", 
middlename="Mary", lastname="Brown", id="", gender="F", salary=-1), ] def test_df_nested_struct(self, spark): @@ -236,18 +235,18 @@ def test_df_nested_struct(self, spark): structureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -255,24 +254,24 @@ def test_df_nested_struct(self, spark): res = df2.collect() expected_res = [ Row( - name={'firstname': 'James', 'middlename': '', 'lastname': 'Smith'}, id='36636', gender='M', salary=3100 + name={"firstname": "James", "middlename": "", "lastname": "Smith"}, id="36636", gender="M", salary=3100 ), Row( - name={'firstname': 'Michael', 'middlename': 'Rose', 'lastname': ''}, id='40288', gender='M', salary=4300 + name={"firstname": "Michael", "middlename": "Rose", "lastname": ""}, id="40288", gender="M", salary=4300 ), Row( - name={'firstname': 'Robert', 'middlename': '', 'lastname': 'Williams'}, - id='42114', - gender='M', + name={"firstname": "Robert", "middlename": "", "lastname": "Williams"}, + id="42114", + gender="M", salary=1400, ), Row( - name={'firstname': 'Maria', 'middlename': 'Anne', 'lastname': 'Jones'}, - id='39192', - gender='F', + name={"firstname": "Maria", "middlename": "Anne", "lastname": "Jones"}, + id="39192", + gender="F", salary=5500, ), - Row(name={'firstname': 'Jen', 'middlename': 'Mary', 'lastname': 'Brown'}, id='', gender='F', salary=-1), + Row(name={"firstname": "Jen", "middlename": "Mary", "lastname": "Brown"}, id="", gender="F", salary=-1), ] if USE_ACTUAL_SPARK: expected_res = [Row(name=Row(**r.name), id=r.id, gender=r.gender, salary=r.salary) for r in expected_res] @@ -281,24 +280,24 @@ def test_df_nested_struct(self, spark): assert schema == StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), True, ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) def test_df_columns(self, spark): - from spark_namespace.sql.functions import col, struct, when + from spark_namespace.sql.functions import col structureData = [ (("James", "", "Smith"), "36636", "M", 3100), @@ -310,18 +309,18 @@ def test_df_columns(self, spark): structureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('id', StringType(), True), - 
StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -339,25 +338,24 @@ def test_df_columns(self, spark): ), ).drop("id", "gender", "salary") - assert 'OtherInfo' in updatedDF.columns + assert "OtherInfo" in updatedDF.columns def test_array_and_map_type(self, spark): - """Array & Map""" - - arrayStructureSchema = StructType( + """Array & Map.""" + StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('hobbies', ArrayType(StringType()), True), - StructField('properties', MapType(StringType(), StringType()), True), + StructField("hobbies", ArrayType(StringType()), True), + StructField("properties", MapType(StringType(), StringType()), True), ] ) diff --git a/tests/fast/spark/test_spark_dataframe_sort.py b/tests/fast/spark/test_spark_dataframe_sort.py index 20558197..edf77917 100644 --- a/tests/fast/spark/test_spark_dataframe_sort.py +++ b/tests/fast/spark/test_spark_dataframe_sort.py @@ -3,14 +3,14 @@ _ = pytest.importorskip("duckdb.experimental.spark") import spark_namespace.errors -from spark_namespace.sql.types import Row -from spark_namespace.sql.functions import desc, asc -from spark_namespace.errors import PySparkTypeError, PySparkValueError from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError, PySparkValueError +from spark_namespace.sql.functions import asc, desc +from spark_namespace.sql.types import Row -class TestDataFrameSort(object): - data = [(56, "Carol"), (20, "Alice"), (3, "Dave"), (3, "Anna"), (1, "Ben")] +class TestDataFrameSort: + data = ((56, "Carol"), (20, "Alice"), (3, "Dave"), (3, "Anna"), (1, "Ben")) def test_sort_ascending(self, spark): df = spark.createDataFrame(self.data, ["age", "name"]) @@ -84,18 +84,18 @@ def test_sort_invalid_column(self, spark): df = spark.createDataFrame(self.data, ["age", "name"]) with pytest.raises(PySparkTypeError): - df = df.sort(dict(a=1)) + df = df.sort({"a": 1}) def test_sort_with_desc(self, spark): df = spark.createDataFrame(self.data, ["age", "name"]) df = df.sort(desc("name")) res = df.collect() assert res == [ - Row(age=3, name='Dave'), - Row(age=56, name='Carol'), - Row(age=1, name='Ben'), - Row(age=3, name='Anna'), - Row(age=20, name='Alice'), + Row(age=3, name="Dave"), + Row(age=56, name="Carol"), + Row(age=1, name="Ben"), + Row(age=3, name="Anna"), + Row(age=20, name="Alice"), ] def test_sort_with_asc(self, spark): @@ -103,9 +103,9 @@ def test_sort_with_asc(self, spark): df = df.sort(asc("name")) res = df.collect() assert res == [ - Row(age=20, name='Alice'), - Row(age=3, name='Anna'), - Row(age=1, name='Ben'), - Row(age=56, name='Carol'), - Row(age=3, name='Dave'), + Row(age=20, name="Alice"), + Row(age=3, name="Anna"), + Row(age=1, name="Ben"), + Row(age=56, name="Carol"), + Row(age=3, name="Dave"), ] diff --git a/tests/fast/spark/test_spark_drop_duplicates.py b/tests/fast/spark/test_spark_drop_duplicates.py index 6dc7f573..cd658c77 100644 --- a/tests/fast/spark/test_spark_drop_duplicates.py +++ b/tests/fast/spark/test_spark_drop_duplicates.py @@ -1,6 +1,4 @@ import pytest - - from 
spark_namespace.sql.types import ( Row, ) @@ -8,7 +6,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestDataFrameDropDuplicates(object): +class TestDataFrameDropDuplicates: @pytest.mark.parametrize("method", ["dropDuplicates", "drop_duplicates"]) def test_spark_drop_duplicates(self, method, spark): # Prepare Data @@ -34,15 +32,15 @@ def test_spark_drop_duplicates(self, method, spark): res = distinctDF.collect() # James | Sales had a duplicate, has been removed expected = [ - Row(employee_name='James', department='Sales', salary=3000), - Row(employee_name='Jeff', department='Marketing', salary=3000), - Row(employee_name='Jen', department='Finance', salary=3900), - Row(employee_name='Kumar', department='Marketing', salary=2000), - Row(employee_name='Maria', department='Finance', salary=3000), - Row(employee_name='Michael', department='Sales', salary=4600), - Row(employee_name='Robert', department='Sales', salary=4100), - Row(employee_name='Saif', department='Sales', salary=4100), - Row(employee_name='Scott', department='Finance', salary=3300), + Row(employee_name="James", department="Sales", salary=3000), + Row(employee_name="Jeff", department="Marketing", salary=3000), + Row(employee_name="Jen", department="Finance", salary=3900), + Row(employee_name="Kumar", department="Marketing", salary=2000), + Row(employee_name="Maria", department="Finance", salary=3000), + Row(employee_name="Michael", department="Sales", salary=4600), + Row(employee_name="Robert", department="Sales", salary=4100), + Row(employee_name="Saif", department="Sales", salary=4100), + Row(employee_name="Scott", department="Finance", salary=3300), ] assert res == expected @@ -52,14 +50,14 @@ def test_spark_drop_duplicates(self, method, spark): assert res2 == res expected_subset = [ - Row(department='Finance', salary=3000), - Row(department='Finance', salary=3300), - Row(department='Finance', salary=3900), - Row(department='Marketing', salary=2000), - Row(department='Marketing', salary=3000), - Row(epartment='Sales', salary=3000), - Row(department='Sales', salary=4100), - Row(department='Sales', salary=4600), + Row(department="Finance", salary=3000), + Row(department="Finance", salary=3300), + Row(department="Finance", salary=3900), + Row(department="Marketing", salary=2000), + Row(department="Marketing", salary=3000), + Row(epartment="Sales", salary=3000), + Row(department="Sales", salary=4100), + Row(department="Sales", salary=4600), ] dropDisDF = getattr(df, method)(["department", "salary"]).sort("department", "salary") diff --git a/tests/fast/spark/test_spark_except.py b/tests/fast/spark/test_spark_except.py index 434ac613..dd6c802d 100644 --- a/tests/fast/spark/test_spark_except.py +++ b/tests/fast/spark/test_spark_except.py @@ -1,10 +1,8 @@ -import platform import pytest _ = pytest.importorskip("duckdb.experimental.spark") from duckdb.experimental.spark.sql.types import Row -from duckdb.experimental.spark.sql.functions import col @pytest.fixture @@ -19,7 +17,6 @@ def df2(spark): class TestDataFrameIntersect: def test_exceptAll(self, spark, df, df2): - df3 = df.exceptAll(df2).sort(*df.columns) res = df3.collect() diff --git a/tests/fast/spark/test_spark_filter.py b/tests/fast/spark/test_spark_filter.py index fb6f0b1a..8d412a2c 100644 --- a/tests/fast/spark/test_spark_filter.py +++ b/tests/fast/spark/test_spark_filter.py @@ -2,26 +2,20 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError +from 
spark_namespace.sql.functions import array_contains, col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, ArrayType, - MapType, + Row, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.errors import PySparkTypeError -import duckdb -import re -class TestDataFrameFilter(object): +class TestDataFrameFilter: def test_dataframe_filter(self, spark): data = [ (("James", "", "Smith"), ["Java", "Scala", "C++"], "OH", "M"), @@ -35,18 +29,18 @@ def test_dataframe_filter(self, spark): schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('languages', ArrayType(StringType()), True), - StructField('state', StringType(), True), - StructField('gender', StringType(), True), + StructField("languages", ArrayType(StringType()), True), + StructField("state", StringType(), True), + StructField("gender", StringType(), True), ] ) @@ -57,53 +51,55 @@ def test_dataframe_filter(self, spark): # Using equals condition df2 = df.filter(df.state == "OH") res = df2.collect() - assert res[0].state == 'OH' + assert res[0].state == "OH" # not equals condition df2 = df.filter(df.state != "OH") df2 = df.filter(~(df.state == "OH")) res = df2.collect() for item in res: - assert item.state == 'NY' or item.state == 'CA' + assert item.state == "NY" or item.state == "CA" df2 = df.filter(col("state") == "OH") res = df2.collect() - assert res[0].state == 'OH' + assert res[0].state == "OH" df2 = df.filter("gender == 'M'") res = df2.collect() - assert res[0].gender == 'M' + assert res[0].gender == "M" df2 = df.filter("gender != 'M'") res = df2.collect() - assert res[0].gender == 'F' + assert res[0].gender == "F" df2 = df.filter("gender <> 'M'") res = df2.collect() - assert res[0].gender == 'F' + assert res[0].gender == "F" # Filter multiple condition df2 = df.filter((df.state == "OH") & (df.gender == "M")) res = df2.collect() assert len(res) == 2 for item in res: - assert item.gender == 'M' and item.state == 'OH' + assert item.gender == "M" + assert item.state == "OH" # Filter IS IN List values li = ["OH", "NY"] df2 = df.filter(df.state.isin(li)) res = df2.collect() for item in res: - assert item.state == 'OH' or item.state == 'NY' + assert item.state == "OH" or item.state == "NY" # Filter NOT IS IN List values # These show all records with NY (NY is not part of the list) df2 = df.filter(~df.state.isin(li)) res = df2.collect() for item in res: - assert item.state != 'OH' and item.state != 'NY' + assert item.state != "OH" + assert item.state != "NY" - df2 = df.filter(df.state.isin(li) == False) + df2 = df.filter(df.state.isin(li) == False) # noqa: E712 res2 = df2.collect() assert res2 == res @@ -111,19 +107,19 @@ def test_dataframe_filter(self, spark): df2 = df.filter(df.state.startswith("N")) res = df2.collect() for item in res: - assert item.state == 'NY' + assert item.state == "NY" # using endswith df2 = df.filter(df.state.endswith("H")) res = df2.collect() for item in res: - assert item.state == 'OH' + assert item.state == "OH" # contains df2 = df.filter(df.state.contains("H")) res = df2.collect() for item in res: - assert 
item.state == 'OH' + assert item.state == "OH" data2 = [(2, "Michael Rose"), (3, "Robert Williams"), (4, "Rames Rose"), (5, "Rames rose")] df2 = spark.createDataFrame(data=data2, schema=["id", "name"]) @@ -131,56 +127,56 @@ def test_dataframe_filter(self, spark): # like - SQL LIKE pattern df3 = df2.filter(df2.name.like("%rose%")) res = df3.collect() - assert res == [Row(id=5, name='Rames rose')] + assert res == [Row(id=5, name="Rames rose")] # rlike - SQL RLIKE pattern (LIKE with Regex) # This check case insensitive df3 = df2.filter(df2.name.rlike("(?i)^*rose$")) res = df3.collect() - assert res == [Row(id=2, name='Michael Rose'), Row(id=4, name='Rames Rose'), Row(id=5, name='Rames rose')] + assert res == [Row(id=2, name="Michael Rose"), Row(id=4, name="Rames Rose"), Row(id=5, name="Rames rose")] df2 = df.filter(array_contains(df.languages, "Java")) res = df2.collect() - james_name = {'firstname': 'James', 'middlename': '', 'lastname': 'Smith'} - anna_name = {'firstname': 'Anna', 'middlename': 'Rose', 'lastname': ''} + james_name = {"firstname": "James", "middlename": "", "lastname": "Smith"} + anna_name = {"firstname": "Anna", "middlename": "Rose", "lastname": ""} if USE_ACTUAL_SPARK: james_name = Row(**james_name) anna_name = Row(**anna_name) assert res == [ Row( name=james_name, - languages=['Java', 'Scala', 'C++'], - state='OH', - gender='M', + languages=["Java", "Scala", "C++"], + state="OH", + gender="M", ), Row( name=anna_name, - languages=['Spark', 'Java', 'C++'], - state='CA', - gender='F', + languages=["Spark", "Java", "C++"], + state="CA", + gender="F", ), ] df2 = df.filter(df.name.lastname == "Williams") res = df2.collect() - julia_name = {'firstname': 'Julia', 'middlename': '', 'lastname': 'Williams'} - mike_name = {'firstname': 'Mike', 'middlename': 'Mary', 'lastname': 'Williams'} + julia_name = {"firstname": "Julia", "middlename": "", "lastname": "Williams"} + mike_name = {"firstname": "Mike", "middlename": "Mary", "lastname": "Williams"} if USE_ACTUAL_SPARK: julia_name = Row(**julia_name) mike_name = Row(**mike_name) assert res == [ Row( name=julia_name, - languages=['CSharp', 'VB'], - state='OH', - gender='F', + languages=["CSharp", "VB"], + state="OH", + gender="F", ), Row( name=mike_name, - languages=['Python', 'VB'], - state='OH', - gender='M', + languages=["Python", "VB"], + state="OH", + gender="M", ), ] @@ -188,4 +184,4 @@ def test_invalid_condition_type(self, spark): df = spark.createDataFrame([(1, "A")], ["A", "B"]) with pytest.raises(PySparkTypeError): - df = df.filter(dict(a=1)) + df = df.filter({"a": 1}) diff --git a/tests/fast/spark/test_spark_function_concat_ws.py b/tests/fast/spark/test_spark_function_concat_ws.py index 82f19cd1..b4268d0f 100644 --- a/tests/fast/spark/test_spark_function_concat_ws.py +++ b/tests/fast/spark/test_spark_function_concat_ws.py @@ -1,11 +1,11 @@ import pytest _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col, concat_ws from spark_namespace.sql.types import Row -from spark_namespace.sql.functions import concat_ws, col -class TestReplaceEmpty(object): +class TestReplaceEmpty: def test_replace_empty(self, spark): data = [ ("firstRowFirstColumn", "firstRowSecondColumn"), diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py index f83e0ef2..9ee2ffc2 100644 --- a/tests/fast/spark/test_spark_functions_array.py +++ b/tests/fast/spark/test_spark_functions_array.py @@ -1,10 +1,11 @@ -import pytest import platform +import pytest + _ = 
pytest.importorskip("duckdb.experimental.spark") +from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", @@ -75,7 +76,7 @@ def test_array_min(self, spark): ] def test_get(self, spark): - df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) res = df.select(sf.get(df.data, 1).alias("r")).collect() assert res == [Row(r="b")] @@ -87,25 +88,25 @@ def test_get(self, spark): assert res == [Row(r=None)] res = df.select(sf.get(df.data, "index").alias("r")).collect() - assert res == [Row(r='b')] + assert res == [Row(r="b")] res = df.select(sf.get(df.data, sf.col("index") - 1).alias("r")).collect() - assert res == [Row(r='a')] + assert res == [Row(r="a")] def test_flatten(self, spark): - df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) res = df.select(sf.flatten(df.data).alias("r")).collect() assert res == [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] def test_array_compact(self, spark): - df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) + df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) - res = df.select(sf.array_compact(df.data).alias("v")).collect() + df.select(sf.array_compact(df.data).alias("v")).collect() assert [Row(v=[1, 2, 3]), Row(v=[4, 5, 4])] def test_array_remove(self, spark): - df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) res = df.select(sf.array_remove(df.data, 1).alias("v")).collect() assert res == [Row(v=[2, 3]), Row(v=[])] @@ -126,102 +127,102 @@ def test_array_append(self, spark): df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")], ["c1", "c2"]) res = df.select(sf.array_append(df.c1, df.c2).alias("r")).collect() - assert res == [Row(r=['b', 'a', 'c', 'c'])] + assert res == [Row(r=["b", "a", "c", "c"])] - res = df.select(sf.array_append(df.c1, 'x')).collect() - assert res == [Row(r=['b', 'a', 'c', 'x'])] + res = df.select(sf.array_append(df.c1, "x")).collect() + assert res == [Row(r=["b", "a", "c", "x"])] def test_array_insert(self, spark): df = spark.createDataFrame( - [(['a', 'b', 'c'], 2, 'd'), (['a', 'b', 'c', 'e'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')], - ['data', 'pos', 'val'], + [(["a", "b", "c"], 2, "d"), (["a", "b", "c", "e"], 2, "d"), (["c", "b", "a"], -2, "d")], + ["data", "pos", "val"], ) - res = df.select(sf.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() + res = df.select(sf.array_insert(df.data, df.pos.cast("integer"), df.val).alias("data")).collect() assert res == [ - Row(data=['a', 'd', 'b', 'c']), - Row(data=['a', 'd', 'b', 'c', 'e']), - Row(data=['c', 'b', 'd', 'a']), + Row(data=["a", "d", "b", "c"]), + Row(data=["a", "d", "b", "c", "e"]), + Row(data=["c", "b", "d", "a"]), ] - res = df.select(sf.array_insert(df.data, 5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, 5, "hello").alias("data")).collect() assert res == [ - Row(data=['a', 'b', 'c', None, 'hello']), - Row(data=['a', 'b', 'c', 'e', 'hello']), - Row(data=['c', 'b', 'a', None, 'hello']), + Row(data=["a", "b", "c", None, "hello"]), + Row(data=["a", "b", "c", "e", "hello"]), + Row(data=["c", 
"b", "a", None, "hello"]), ] - res = df.select(sf.array_insert(df.data, -5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, -5, "hello").alias("data")).collect() assert res == [ - Row(data=['hello', None, 'a', 'b', 'c']), - Row(data=['hello', 'a', 'b', 'c', 'e']), - Row(data=['hello', None, 'c', 'b', 'a']), + Row(data=["hello", None, "a", "b", "c"]), + Row(data=["hello", "a", "b", "c", "e"]), + Row(data=["hello", None, "c", "b", "a"]), ] def test_slice(self, spark): - df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) res = df.select(sf.slice(df.x, 2, 2).alias("sliced")).collect() assert res == [Row(sliced=[2, 3]), Row(sliced=[5])] def test_sort_array(self, spark): - df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) + df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) - res = df.select(sf.sort_array(df.data).alias('r')).collect() + res = df.select(sf.sort_array(df.data).alias("r")).collect() assert res == [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] - res = df.select(sf.sort_array(df.data, asc=False).alias('r')).collect() + res = df.select(sf.sort_array(df.data, asc=False).alias("r")).collect() assert res == [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] @pytest.mark.parametrize(("null_replacement", "expected_joined_2"), [(None, "a"), ("replaced", "a,replaced")]) def test_array_join(self, spark, null_replacement, expected_joined_2): - df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) res = df.select(sf.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect() - assert res == [Row(joined='a,b,c'), Row(joined=expected_joined_2)] + assert res == [Row(joined="a,b,c"), Row(joined=expected_joined_2)] def test_array_position(self, spark): - df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) res = df.select(sf.array_position(df.data, "a").alias("pos")).collect() assert res == [Row(pos=3), Row(pos=0)] def test_array_preprend(self, spark): - df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) res = df.select(sf.array_prepend(df.data, 1).alias("pre")).collect() assert res == [Row(pre=[1, 2, 3, 4]), Row(pre=[1])] def test_array_repeat(self, spark): - df = spark.createDataFrame([('ab',)], ['data']) + df = spark.createDataFrame([("ab",)], ["data"]) - res = df.select(sf.array_repeat(df.data, 3).alias('r')).collect() - assert res == [Row(r=['ab', 'ab', 'ab'])] + res = df.select(sf.array_repeat(df.data, 3).alias("r")).collect() + assert res == [Row(r=["ab", "ab", "ab"])] def test_array_size(self, spark): - df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) + df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) - res = df.select(sf.array_size(df.data).alias('r')).collect() + res = df.select(sf.array_size(df.data).alias("r")).collect() assert res == [Row(r=3), Row(r=None)] def test_array_sort(self, spark): - df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) + df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) - res = df.select(sf.array_sort(df.data).alias('r')).collect() + res = df.select(sf.array_sort(df.data).alias("r")).collect() assert res == [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] def 
test_arrays_overlap(self, spark): df = spark.createDataFrame( - [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ['x', 'y'] + [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ["x", "y"] ) res = df.select(sf.arrays_overlap(df.x, df.y).alias("overlap")).collect() assert res == [Row(overlap=True), Row(overlap=False), Row(overlap=None), Row(overlap=None)] def test_arrays_zip(self, spark): - df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3']) + df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ["vals1", "vals2", "vals3"]) - res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect() - # FIXME: The structure of the results should be the same + res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias("zipped")).collect() + # TODO: The structure of the results should be the same # noqa: TD002, TD003 if USE_ACTUAL_SPARK: assert res == [ Row( diff --git a/tests/fast/spark/test_spark_functions_base64.py b/tests/fast/spark/test_spark_functions_base64.py index 734a5275..44e4a7cd 100644 --- a/tests/fast/spark/test_spark_functions_base64.py +++ b/tests/fast/spark/test_spark_functions_base64.py @@ -5,7 +5,7 @@ from spark_namespace.sql import functions as F -class TestSparkFunctionsBase64(object): +class TestSparkFunctionsBase64: def test_base64(self, spark): data = [ ("quack",), @@ -40,4 +40,4 @@ def test_unbase64(self, spark): .select("decoded_value") .collect() ) - assert res[0].decoded_value == b'quack' + assert res[0].decoded_value == b"quack" diff --git a/tests/fast/spark/test_spark_functions_date.py b/tests/fast/spark/test_spark_functions_date.py index 2a51d9b8..914d33f6 100644 --- a/tests/fast/spark/test_spark_functions_date.py +++ b/tests/fast/spark/test_spark_functions_date.py @@ -1,4 +1,5 @@ import warnings + import pytest _ = pytest.importorskip("duckdb.experimental.spark") @@ -6,11 +7,11 @@ from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as F -from spark_namespace.sql.types import Row from spark_namespace.sql.functions import col +from spark_namespace.sql.types import Row -class TestsSparkFunctionsDate(object): +class TestsSparkFunctionsDate: def test_date_trunc(self, spark): df = spark.createDataFrame( [(datetime(2019, 1, 23, 14, 34, 9, 87539),)], @@ -145,43 +146,43 @@ def test_second(self, spark): assert result[0].second_num == 45 def test_unix_date(self, spark): - df = spark.createDataFrame([('1970-01-02',)], ['t']) - res = df.select(F.unix_date(df.t.cast("date")).alias('n')).collect() + df = spark.createDataFrame([("1970-01-02",)], ["t"]) + res = df.select(F.unix_date(df.t.cast("date")).alias("n")).collect() assert res == [Row(n=1)] def test_unix_micros(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_micros(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_micros(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200000000)] def test_unix_millis(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_millis(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_millis(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200000)] def 
test_unix_seconds(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_seconds(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_seconds(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200)] def test_weekday(self, spark): - df = spark.createDataFrame([('2015-04-08',)], ['dt']) - res = df.select(F.weekday(df.dt.cast("date")).alias('day')).collect() + df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + res = df.select(F.weekday(df.dt.cast("date")).alias("day")).collect() assert res == [Row(day=2)] def test_to_date(self, spark): - df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - res = df.select(F.to_date(df.t).alias('date')).collect() + df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + res = df.select(F.to_date(df.t).alias("date")).collect() assert res == [Row(date=date(1997, 2, 28))] def test_to_timestamp(self, spark): - df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - res = df.select(F.to_timestamp(df.t).alias('dt')).collect() + df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + res = df.select(F.to_timestamp(df.t).alias("dt")).collect() assert res == [Row(dt=datetime(1997, 2, 28, 10, 30))] def test_to_timestamp_ltz(self, spark): df = spark.createDataFrame([("2016-12-31",)], ["e"]) - res = df.select(F.to_timestamp_ltz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ltz(df.e).alias("r")).collect() assert res == [Row(r=datetime(2016, 12, 31, 0, 0))] @@ -194,15 +195,15 @@ def test_to_timestamp_ntz(self, spark): if USE_ACTUAL_SPARK: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - res = df.select(F.to_timestamp_ntz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ntz(df.e).alias("r")).collect() else: - res = df.select(F.to_timestamp_ntz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ntz(df.e).alias("r")).collect() assert res == [Row(r=datetime(2016, 4, 8, 0, 0))] def test_last_day(self, spark): - df = spark.createDataFrame([('1997-02-10',)], ['d']) + df = spark.createDataFrame([("1997-02-10",)], ["d"]) - res = df.select(F.last_day(df.d.cast("date")).alias('date')).collect() + res = df.select(F.last_day(df.d.cast("date")).alias("date")).collect() assert res == [Row(date=date(1997, 2, 28))] def test_add_months(self, spark): @@ -219,12 +220,12 @@ def test_add_months(self, spark): assert result[0].with_col == date(2024, 7, 12) def test_date_diff(self, spark): - df = spark.createDataFrame([('2015-04-08', '2015-05-10')], ["d1", "d2"]) + df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) - result_data = df.select(F.date_diff(col("d2").cast('DATE'), col("d1").cast('DATE')).alias("diff")).collect() + result_data = df.select(F.date_diff(col("d2").cast("DATE"), col("d1").cast("DATE")).alias("diff")).collect() assert result_data[0]["diff"] == -32 - result_data = df.select(F.date_diff(col("d1").cast('DATE'), col("d2").cast('DATE')).alias("diff")).collect() + result_data = df.select(F.date_diff(col("d1").cast("DATE"), col("d2").cast("DATE")).alias("diff")).collect() assert result_data[0]["diff"] == 32 def test_try_to_timestamp(self, spark): @@ -239,4 +240,4 @@ def test_try_to_timestamp_with_format(self, spark): res = df.select(F.try_to_timestamp(df.t, format=F.lit("%Y-%m-%d %H:%M:%S")).alias("dt")).collect() assert res[0].dt == datetime(1997, 2, 28, 10, 30) assert 
res[1].dt is None - assert res[2].dt is None \ No newline at end of file + assert res[2].dt is None diff --git a/tests/fast/spark/test_spark_functions_expr.py b/tests/fast/spark/test_spark_functions_expr.py index 7cc47735..f14dbcce 100644 --- a/tests/fast/spark/test_spark_functions_expr.py +++ b/tests/fast/spark/test_spark_functions_expr.py @@ -5,7 +5,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestSparkFunctionsExpr(object): +class TestSparkFunctionsExpr: def test_expr(self, spark): df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) res = df.select("name", F.expr("length(name)").alias("str_len")).collect() diff --git a/tests/fast/spark/test_spark_functions_hash.py b/tests/fast/spark/test_spark_functions_hash.py index 7b14f29e..d1890990 100644 --- a/tests/fast/spark/test_spark_functions_hash.py +++ b/tests/fast/spark/test_spark_functions_hash.py @@ -4,7 +4,7 @@ from spark_namespace.sql import functions as F -class TestSparkFunctionsHash(object): +class TestSparkFunctionsHash: def test_md5(self, spark): data = [ ("quack",), diff --git a/tests/fast/spark/test_spark_functions_hex.py b/tests/fast/spark/test_spark_functions_hex.py index e5cbf12f..54caaf28 100644 --- a/tests/fast/spark/test_spark_functions_hex.py +++ b/tests/fast/spark/test_spark_functions_hex.py @@ -1,11 +1,10 @@ import pytest -import sys _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql import functions as F -class TestSparkFunctionsHex(object): +class TestSparkFunctionsHex: def test_hex_string_col(self, spark): data = [ ("quack",), @@ -20,7 +19,7 @@ def test_hex_string_col(self, spark): def test_hex_binary_col(self, spark): data = [ - (b'quack',), + (b"quack",), ] res = ( spark.createDataFrame(data, ["firstColumn"]) @@ -32,7 +31,7 @@ def test_hex_binary_col(self, spark): def test_hex_integer_col(self, spark): data = [ - (int(42),), + (42,), ] res = ( spark.createDataFrame(data, ["firstColumn"]) @@ -65,4 +64,4 @@ def test_unhex(self, spark): .select("unhex_value") .collect() ) - assert res[0].unhex_value == b'quack' + assert res[0].unhex_value == b"quack" diff --git a/tests/fast/spark/test_spark_functions_miscellaneous.py b/tests/fast/spark/test_spark_functions_miscellaneous.py index 87b6b776..f6af47fe 100644 --- a/tests/fast/spark/test_spark_functions_miscellaneous.py +++ b/tests/fast/spark/test_spark_functions_miscellaneous.py @@ -30,38 +30,38 @@ def test_call_function(self, spark): ] def test_octet_length(self, spark): - df = spark.createDataFrame([('cat',)], ['c1']) - res = df.select(F.octet_length('c1').alias("o")).collect() + df = spark.createDataFrame([("cat",)], ["c1"]) + res = df.select(F.octet_length("c1").alias("o")).collect() assert res == [Row(o=3)] def test_positive(self, spark): - df = spark.createDataFrame([(-1,), (0,), (1,)], ['v']) + df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) res = df.select(F.positive("v").alias("p")).collect() assert res == [Row(p=-1), Row(p=0), Row(p=1)] def test_sequence(self, spark): - df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) - res = df1.select(F.sequence('C1', 'C2').alias('r')).collect() + df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) + res = df1.select(F.sequence("C1", "C2").alias("r")).collect() assert res == [Row(r=[-2, -1, 0, 1, 2])] - df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) - res = df2.select(F.sequence('C1', 'C2', 'C3').alias('r')).collect() + df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) + res = df2.select(F.sequence("C1", "C2", 
"C3").alias("r")).collect() assert res == [Row(r=[4, 2, 0, -2, -4])] def test_like(self, spark): - df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - res = df.select(F.like(df.a, df.b).alias('r')).collect() + df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + res = df.select(F.like(df.a, df.b).alias("r")).collect() assert res == [Row(r=True)] - df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ['a', 'b']) - res = df.select(F.like(df.a, df.b, F.lit('/')).alias('r')).collect() + df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"]) + res = df.select(F.like(df.a, df.b, F.lit("/")).alias("r")).collect() assert res == [Row(r=True)] def test_ilike(self, spark): - df = spark.createDataFrame([("Spark", "spark")], ['a', 'b']) - res = df.select(F.ilike(df.a, df.b).alias('r')).collect() + df = spark.createDataFrame([("Spark", "spark")], ["a", "b"]) + res = df.select(F.ilike(df.a, df.b).alias("r")).collect() assert res == [Row(r=True)] - df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ['a', 'b']) - res = df.select(F.ilike(df.a, df.b, F.lit('/')).alias('r')).collect() + df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"]) + res = df.select(F.ilike(df.a, df.b, F.lit("/")).alias("r")).collect() assert res == [Row(r=True)] diff --git a/tests/fast/spark/test_spark_functions_null.py b/tests/fast/spark/test_spark_functions_null.py index 3f5ee31b..2bcfd94a 100644 --- a/tests/fast/spark/test_spark_functions_null.py +++ b/tests/fast/spark/test_spark_functions_null.py @@ -7,7 +7,7 @@ from spark_namespace.sql.types import Row -class TestsSparkFunctionsNull(object): +class TestsSparkFunctionsNull: def test_coalesce(self, spark): data = [ (None, 2), @@ -62,7 +62,7 @@ def test_nvl2(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.nvl2(df.a, df.b, df.c).alias('r')).collect() + res = df.select(F.nvl2(df.a, df.b, df.c).alias("r")).collect() assert res == [Row(r=6), Row(r=9)] def test_ifnull(self, spark): @@ -92,7 +92,7 @@ def test_nullif(self, spark): ], ["a", "b"], ) - res = df.select(F.nullif(df.a, df.b).alias('r')).collect() + res = df.select(F.nullif(df.a, df.b).alias("r")).collect() assert res == [Row(r=None), Row(r=1)] def test_isnull(self, spark): @@ -116,4 +116,4 @@ def test_isnotnull(self, spark): def test_equal_null(self, spark): df = spark.createDataFrame([(1, 1), (None, 2), (None, None)], ("a", "b")) res = df.select(F.equal_null("a", F.col("b")).alias("r")).collect() - assert res == [Row(r=True), Row(r=False), Row(r=True)] \ No newline at end of file + assert res == [Row(r=True), Row(r=False), Row(r=True)] diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 9c4bafb9..8378aafa 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -3,13 +3,14 @@ _ = pytest.importorskip("duckdb.experimental.spark") import math + import numpy as np from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row -class TestSparkFunctionsNumeric(object): +class TestSparkFunctionsNumeric: def test_greatest(self, spark): data = [ (1, 2), @@ -131,8 +132,8 @@ def test_exp(self, spark): df = spark.createDataFrame(data, ["firstColumn"]) df = df.withColumn("exp_value", sf.exp(sf.col("firstColumn"))) res = df.select("exp_value").collect() - 
round(res[0].exp_value, 2) == 2 - res[1].exp_value == 1 + assert round(res[0].exp_value, 2) == 2 + assert res[1].exp_value == 1 def test_factorial(self, spark): data = [ @@ -168,8 +169,8 @@ def test_ln(self, spark): df = spark.createDataFrame(data, ["firstColumn"]) df = df.withColumn("ln_value", sf.ln(sf.col("firstColumn"))) res = df.select("ln_value").collect() - round(res[0].ln_value, 2) == 1 - res[1].ln_value == 0 + assert round(res[0].ln_value, 2) == 1 + assert res[1].ln_value == 0 def test_degrees(self, spark): data = [ @@ -179,8 +180,8 @@ def test_degrees(self, spark): df = spark.createDataFrame(data, ["firstColumn"]) df = df.withColumn("degrees_value", sf.degrees(sf.col("firstColumn"))) res = df.select("degrees_value").collect() - round(res[0].degrees_value, 2) == 180 - res[1].degrees_value == 0 + assert round(res[0].degrees_value, 2) == 180 + assert res[1].degrees_value == 0 def test_radians(self, spark): data = [ @@ -190,8 +191,8 @@ def test_radians(self, spark): df = spark.createDataFrame(data, ["firstColumn"]) df = df.withColumn("radians_value", sf.radians(sf.col("firstColumn"))) res = df.select("radians_value").collect() - round(res[0].radians_value, 2) == 3.14 - res[1].radians_value == 0 + assert round(res[0].radians_value, 2) == 3.14 + assert res[1].radians_value == 0 def test_atan(self, spark): data = [ @@ -201,8 +202,8 @@ def test_atan(self, spark): df = spark.createDataFrame(data, ["firstColumn"]) df = df.withColumn("atan_value", sf.atan(sf.col("firstColumn"))) res = df.select("atan_value").collect() - round(res[0].atan_value, 2) == 0.79 - res[1].atan_value == 0 + assert round(res[0].atan_value, 2) == 0.79 + assert res[1].atan_value == 0 def test_atan2(self, spark): data = [ @@ -214,20 +215,20 @@ def test_atan2(self, spark): # Both columns df2 = df.withColumn("atan2_value", sf.atan2(sf.col("firstColumn"), "secondColumn")) res = df2.select("atan2_value").collect() - round(res[0].atan2_value, 2) == 0.79 - res[1].atan2_value == 0 + assert round(res[0].atan2_value, 2) == 0.79 + assert res[1].atan2_value == 0 # Both literals df2 = df.withColumn("atan2_value_lit", sf.atan2(1, 1)) res = df2.select("atan2_value_lit").collect() - round(res[0].atan2_value_lit, 2) == 0.79 - round(res[1].atan2_value_lit, 2) == 0.79 + assert round(res[0].atan2_value_lit, 2) == 0.79 + assert round(res[1].atan2_value_lit, 2) == 0.79 # One literal, one column df2 = df.withColumn("atan2_value_lit_col", sf.atan2(1.0, sf.col("secondColumn"))) res = df2.select("atan2_value_lit_col").collect() - round(res[0].atan2_value_lit_col, 2) == 0.79 - res[1].atan2_value_lit_col == 0 + assert round(res[0].atan2_value_lit_col, 2) == 0.79 + assert round(res[1].atan2_value_lit_col, 2) == 1.57 def test_tan(self, spark): data = [ @@ -237,8 +238,8 @@ def test_tan(self, spark): df = spark.createDataFrame(data, ["firstColumn"]) df = df.withColumn("tan_value", sf.tan(sf.col("firstColumn"))) res = df.select("tan_value").collect() - res[0].tan_value == 0 - round(res[1].tan_value, 2) == 1.56 + assert res[0].tan_value == 0 + assert round(res[1].tan_value, 2) == 1.56 def test_round(self, spark): data = [ @@ -290,7 +291,7 @@ def test_asin(self, spark): if USE_ACTUAL_SPARK: assert np.isnan(res[1].asin_value) else: - # FIXME: DuckDB should return NaN here. Reason is that + # TODO: DuckDB should return NaN here. 
Reason is that # noqa: TD002, TD003 # ConstantExpression(float("nan")) gives NULL and not NaN assert res[1].asin_value is None @@ -301,7 +302,7 @@ def test_corr(self, spark): # Have to use a groupby to test this as agg is not yet implemented without df = spark.createDataFrame(zip(a, b, ["group1"] * N), ["a", "b", "g"]) - res = df.groupBy("g").agg(sf.corr("a", "b").alias('c')).collect() + res = df.groupBy("g").agg(sf.corr("a", "b").alias("c")).collect() assert pytest.approx(res[0].c) == 1 def test_cot(self, spark): @@ -330,13 +331,15 @@ def test_pow(self, spark): def test_random(self, spark): df = spark.range(0, 2, 1) - res = df.withColumn('rand', sf.rand()).collect() + res = df.withColumn("rand", sf.rand()).collect() assert isinstance(res[0].rand, float) - assert res[0].rand >= 0 and res[0].rand < 1 + assert res[0].rand >= 0 + assert res[0].rand < 1 assert isinstance(res[1].rand, float) - assert res[1].rand >= 0 and res[1].rand < 1 + assert res[1].rand >= 0 + assert res[1].rand < 1 @pytest.mark.parametrize("sign_func", [sf.sign, sf.signum]) def test_sign(self, spark, sign_func): @@ -355,4 +358,4 @@ def test_negative(self, spark): res = df.collect() assert res[0].value == 0 assert res[1].value == -2 - assert res[2].value == -3 \ No newline at end of file + assert res[2].value == -3 diff --git a/tests/fast/spark/test_spark_functions_string.py b/tests/fast/spark/test_spark_functions_string.py index e90cca11..ba2a540d 100644 --- a/tests/fast/spark/test_spark_functions_string.py +++ b/tests/fast/spark/test_spark_functions_string.py @@ -7,7 +7,7 @@ from spark_namespace.sql.types import Row -class TestSparkFunctionsString(object): +class TestSparkFunctionsString: def test_length(self, spark): data = [ ("firstRowFirstColumn",), @@ -152,47 +152,47 @@ def test_btrim(self, spark): "SL", ) ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.btrim(df.a, df.b).alias('r')).collect() - assert res == [Row(r='parkSQ')] + res = df.select(F.btrim(df.a, df.b).alias("r")).collect() + assert res == [Row(r="parkSQ")] - df = spark.createDataFrame([(" SparkSQL ",)], ['a']) - res = df.select(F.btrim(df.a).alias('r')).collect() - assert res == [Row(r='SparkSQL')] + df = spark.createDataFrame([(" SparkSQL ",)], ["a"]) + res = df.select(F.btrim(df.a).alias("r")).collect() + assert res == [Row(r="SparkSQL")] def test_char(self, spark): df = spark.createDataFrame( [(65,), (65 + 256,), (66 + 256,)], [ - 'a', + "a", ], ) - res = df.select(F.char(df.a).alias('ch')).collect() - assert res == [Row(ch='A'), Row(ch='A'), Row(ch='B')] + res = df.select(F.char(df.a).alias("ch")).collect() + assert res == [Row(ch="A"), Row(ch="A"), Row(ch="B")] def test_encode(self, spark): - df = spark.createDataFrame([('abcd',)], ['c']) + df = spark.createDataFrame([("abcd",)], ["c"]) res = df.select(F.encode("c", "UTF-8").alias("encoded")).collect() - # FIXME: Should return the same type + # TODO: Should return the same type # noqa: TD002, TD003 if USE_ACTUAL_SPARK: - assert res == [Row(encoded=bytearray(b'abcd'))] + assert res == [Row(encoded=bytearray(b"abcd"))] else: - assert res == [Row(encoded=b'abcd')] + assert res == [Row(encoded=b"abcd")] def test_split(self, spark): df = spark.createDataFrame( - [('oneAtwoBthreeC',)], + [("oneAtwoBthreeC",)], [ - 's', + "s", ], ) - res = df.select(F.split(df.s, '[ABC]').alias('s')).collect() - assert res == [Row(s=['one', 'two', 'three', ''])] + res = df.select(F.split(df.s, "[ABC]").alias("s")).collect() + assert res == [Row(s=["one", "two", "three", ""])] def test_split_part(self, spark): df = 
spark.createDataFrame( @@ -206,8 +206,8 @@ def test_split_part(self, spark): ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='13')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="13")] # If any input is null, should return null df = spark.createDataFrame( @@ -225,8 +225,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r=None), Row(r='11')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r=None), Row(r="11")] # If partNum is out of range, should return an empty string df = spark.createDataFrame( @@ -239,8 +239,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="")] # If partNum is negative, parts are counted backwards df = spark.createDataFrame( @@ -253,8 +253,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='13')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="13")] # If the delimiter is an empty string, the return should be empty df = spark.createDataFrame( @@ -267,8 +267,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="")] def test_substr(self, spark): df = spark.createDataFrame( @@ -282,7 +282,7 @@ def test_substr(self, spark): ["a", "b", "c"], ) res = df.select(F.substr("a", "b", "c").alias("s")).collect() - assert res == [Row(s='k')] + assert res == [Row(s="k")] df = spark.createDataFrame( [ @@ -295,21 +295,21 @@ def test_substr(self, spark): ["a", "b", "c"], ) res = df.select(F.substr("a", "b").alias("s")).collect() - assert res == [Row(s='k SQL')] + assert res == [Row(s="k SQL")] def test_find_in_set(self, spark): string_array = "abc,b,ab,c,def" - df = spark.createDataFrame([("ab", string_array), ("b,c", string_array), ("z", string_array)], ['a', 'b']) + df = spark.createDataFrame([("ab", string_array), ("b,c", string_array), ("z", string_array)], ["a", "b"]) - res = df.select(F.find_in_set(df.a, df.b).alias('r')).collect() + res = df.select(F.find_in_set(df.a, df.b).alias("r")).collect() assert res == [Row(r=3), Row(r=0), Row(r=0)] def test_initcap(self, spark): - df = spark.createDataFrame([('ab cd',)], ['a']) + df = spark.createDataFrame([("ab cd",)], ["a"]) - res = df.select(F.initcap("a").alias('v')).collect() - assert res == [Row(v='Ab Cd')] + res = df.select(F.initcap("a").alias("v")).collect() + assert res == [Row(v="Ab Cd")] def test_left(self, spark): df = spark.createDataFrame( @@ -327,11 +327,11 @@ def test_left(self, spark): -3, ), ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.left(df.a, df.b).alias('r')).collect() - assert res == [Row(r='Spa'), Row(r=''), Row(r='')] + res = df.select(F.left(df.a, df.b).alias("r")).collect() + assert res == [Row(r="Spa"), Row(r=""), Row(r="")] def test_right(self, spark): df = spark.createDataFrame( @@ -349,39 +349,39 @@ def test_right(self, spark): -3, ), ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.right(df.a, 
df.b).alias('r')).collect() - assert res == [Row(r='SQL'), Row(r=''), Row(r='')] + res = df.select(F.right(df.a, df.b).alias("r")).collect() + assert res == [Row(r="SQL"), Row(r=""), Row(r="")] def test_levenshtein(self, spark): - df = spark.createDataFrame([("kitten", "sitting"), ("saturdays", "sunday")], ['a', 'b']) + df = spark.createDataFrame([("kitten", "sitting"), ("saturdays", "sunday")], ["a", "b"]) - res = df.select(F.levenshtein(df.a, df.b).alias('r'), F.levenshtein(df.a, df.b, 3).alias('r_th')).collect() + res = df.select(F.levenshtein(df.a, df.b).alias("r"), F.levenshtein(df.a, df.b, 3).alias("r_th")).collect() assert res == [Row(r=3, r_th=3), Row(r=4, r_th=-1)] def test_lpad(self, spark): df = spark.createDataFrame( - [('abcd',)], + [("abcd",)], [ - 's', + "s", ], ) - res = df.select(F.lpad(df.s, 6, '#').alias('s')).collect() - assert res == [Row(s='##abcd')] + res = df.select(F.lpad(df.s, 6, "#").alias("s")).collect() + assert res == [Row(s="##abcd")] def test_rpad(self, spark): df = spark.createDataFrame( - [('abcd',)], + [("abcd",)], [ - 's', + "s", ], ) - res = df.select(F.rpad(df.s, 6, '#').alias('s')).collect() - assert res == [Row(s='abcd##')] + res = df.select(F.rpad(df.s, 6, "#").alias("s")).collect() + assert res == [Row(s="abcd##")] def test_printf(self, spark): df = spark.createDataFrame( @@ -395,79 +395,79 @@ def test_printf(self, spark): ["a", "b", "c"], ) res = df.select(F.printf("a", "b", "c").alias("r")).collect() - assert res == [Row(r='aa123cc')] + assert res == [Row(r="aa123cc")] @pytest.mark.parametrize("regexp_func", [F.regexp, F.regexp_like]) def test_regexp_and_regexp_like(self, spark, regexp_func): df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.lit(r'(\d+)')).alias("m")).collect() + res = df.select(regexp_func("str", F.lit(r"(\d+)")).alias("m")).collect() assert res[0].m is True df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.lit(r'\d{2}b')).alias("m")).collect() + res = df.select(regexp_func("str", F.lit(r"\d{2}b")).alias("m")).collect() assert res[0].m is False df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.col("regexp")).alias("m")).collect() + res = df.select(regexp_func("str", F.col("regexp")).alias("m")).collect() assert res[0].m is True def test_regexp_count(self, spark): df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - res = df.select(F.regexp_count('str', F.lit(r'\d+')).alias('d')).collect() + res = df.select(F.regexp_count("str", F.lit(r"\d+")).alias("d")).collect() assert res == [Row(d=3)] - res = df.select(F.regexp_count('str', F.lit(r'mmm')).alias('d')).collect() + res = df.select(F.regexp_count("str", F.lit(r"mmm")).alias("d")).collect() assert res == [Row(d=0)] - res = df.select(F.regexp_count("str", F.col("regexp")).alias('d')).collect() + res = df.select(F.regexp_count("str", F.col("regexp")).alias("d")).collect() assert res == [Row(d=3)] def test_regexp_extract(self, spark): - df = spark.createDataFrame([('100-200',)], ['str']) - res = df.select(F.regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() - assert res == [Row(d='100')] + df = spark.createDataFrame([("100-200",)], ["str"]) + res = df.select(F.regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() + assert res == [Row(d="100")] - df = spark.createDataFrame([('foo',)], ['str']) - res = df.select(F.regexp_extract('str', r'(\d+)', 
1).alias('d')).collect() - assert res == [Row(d='')] + df = spark.createDataFrame([("foo",)], ["str"]) + res = df.select(F.regexp_extract("str", r"(\d+)", 1).alias("d")).collect() + assert res == [Row(d="")] - df = spark.createDataFrame([('aaaac',)], ['str']) - res = df.select(F.regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() - assert res == [Row(d='')] + df = spark.createDataFrame([("aaaac",)], ["str"]) + res = df.select(F.regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() + assert res == [Row(d="")] def test_regexp_extract_all(self, spark): df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)')).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)")).alias("d")).collect() + assert res == [Row(d=["100", "300"])] - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)'), 1).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)"), 1).alias("d")).collect() + assert res == [Row(d=["100", "300"])] - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)'), 2).alias('d')).collect() - assert res == [Row(d=['200', '400'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)"), 2).alias("d")).collect() + assert res == [Row(d=["200", "400"])] - res = df.select(F.regexp_extract_all('str', F.col("regexp")).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.col("regexp")).alias("d")).collect() + assert res == [Row(d=["100", "300"])] def test_regexp_substr(self, spark): df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - res = df.select(F.regexp_substr('str', F.lit(r'\d+')).alias('d')).collect() - assert res == [Row(d='1')] + res = df.select(F.regexp_substr("str", F.lit(r"\d+")).alias("d")).collect() + assert res == [Row(d="1")] - res = df.select(F.regexp_substr('str', F.lit(r'mmm')).alias('d')).collect() + res = df.select(F.regexp_substr("str", F.lit(r"mmm")).alias("d")).collect() assert res == [Row(d=None)] - res = df.select(F.regexp_substr("str", F.col("regexp")).alias('d')).collect() - assert res == [Row(d='1')] + res = df.select(F.regexp_substr("str", F.col("regexp")).alias("d")).collect() + assert res == [Row(d="1")] def test_repeat(self, spark): df = spark.createDataFrame( - [('ab',)], + [("ab",)], [ - 's', + "s", ], ) - res = df.select(F.repeat(df.s, 3).alias('s')).collect() - assert res == [Row(s='ababab')] + res = df.select(F.repeat(df.s, 3).alias("s")).collect() + assert res == [Row(s="ababab")] def test_reverse(self, spark): data = [ diff --git a/tests/fast/spark/test_spark_group_by.py b/tests/fast/spark/test_spark_group_by.py index 8b66901f..fafd747d 100644 --- a/tests/fast/spark/test_spark_group_by.py +++ b/tests/fast/spark/test_spark_group_by.py @@ -3,47 +3,35 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, - ArrayType, - MapType, -) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains from spark_namespace.sql.functions import ( - sum, + any_value, + approx_count_distinct, avg, + col, + covar_pop, + covar_samp, + first, + last, max, - min, - stddev_samp, - stddev, + median, + mode, + 
product, + skewness, std, + stddev, stddev_pop, + stddev_samp, + sum, var_pop, var_samp, variance, - mean, - mode, - median, - product, - count, - skewness, - any_value, - approx_count_distinct, - covar_pop, - covar_samp, - first, - last, +) +from spark_namespace.sql.types import ( + Row, ) -class TestDataFrameGroupBy(object): +class TestDataFrameGroupBy: def test_group_by(self, spark): simpleData = [ ("James", "Sales", "NY", 90000, 34, 10000), @@ -62,40 +50,49 @@ def test_group_by(self, spark): df2 = df.groupBy("department").sum("salary").sort("department") res = df2.collect() - expected = "[Row(department='Finance', sum(salary)=351000), Row(department='Marketing', sum(salary)=171000), Row(department='Sales', sum(salary)=257000)]" + expected = ( + "[Row(department='Finance', sum(salary)=351000), Row(department='Marketing', sum(salary)=171000), " + "Row(department='Sales', sum(salary)=257000)]" + ) assert str(res) == expected df2 = df.groupBy("department").count().sort("department") res = df2.collect() assert ( - str(res) - == "[Row(department='Finance', count=4), Row(department='Marketing', count=2), Row(department='Sales', count=3)]" + str(res) == "[Row(department='Finance', count=4), Row(department='Marketing', count=2), " + "Row(department='Sales', count=3)]" ) df2 = df.groupBy("department").min("salary").sort("department") res = df2.collect() assert ( str(res) - == "[Row(department='Finance', min(salary)=79000), Row(department='Marketing', min(salary)=80000), Row(department='Sales', min(salary)=81000)]" + == "[Row(department='Finance', min(salary)=79000), Row(department='Marketing', min(salary)=80000), " + "Row(department='Sales', min(salary)=81000)]" ) df2 = df.groupBy("department").max("salary").sort("department") res = df2.collect() assert ( str(res) - == "[Row(department='Finance', max(salary)=99000), Row(department='Marketing', max(salary)=91000), Row(department='Sales', max(salary)=90000)]" + == "[Row(department='Finance', max(salary)=99000), Row(department='Marketing', max(salary)=91000), " + "Row(department='Sales', max(salary)=90000)]" ) df2 = df.groupBy("department").avg("salary").sort("department") res = df2.collect() assert ( str(res) - == "[Row(department='Finance', avg(salary)=87750.0), Row(department='Marketing', avg(salary)=85500.0), Row(department='Sales', avg(salary)=85666.66666666667)]" + == "[Row(department='Finance', avg(salary)=87750.0), Row(department='Marketing', avg(salary)=85500.0), " + "Row(department='Sales', avg(salary)=85666.66666666667)]" ) df2 = df.groupBy("department").mean("salary").sort("department") res = df2.collect() - expected_res_str = "[Row(department='Finance', mean(salary)=87750.0), Row(department='Marketing', mean(salary)=85500.0), Row(department='Sales', mean(salary)=85666.66666666667)]" + expected_res_str = ( + "[Row(department='Finance', mean(salary)=87750.0), Row(department='Marketing', " + "mean(salary)=85500.0), Row(department='Sales', mean(salary)=85666.66666666667)]" + ) if USE_ACTUAL_SPARK: expected_res_str = expected_res_str.replace("mean(", "avg(") assert str(res) == expected_res_str @@ -103,8 +100,12 @@ def test_group_by(self, spark): df2 = df.groupBy("department", "state").sum("salary", "bonus").sort("department", "state") res = df2.collect() assert ( - str(res) - == "[Row(department='Finance', state='CA', sum(salary)=189000, sum(bonus)=47000), Row(department='Finance', state='NY', sum(salary)=162000, sum(bonus)=34000), Row(department='Marketing', state='CA', sum(salary)=80000, sum(bonus)=18000), Row(department='Marketing', 
state='NY', sum(salary)=91000, sum(bonus)=21000), Row(department='Sales', state='CA', sum(salary)=81000, sum(bonus)=23000), Row(department='Sales', state='NY', sum(salary)=176000, sum(bonus)=30000)]" + str(res) == "[Row(department='Finance', state='CA', sum(salary)=189000, sum(bonus)=47000), " + "Row(department='Finance', state='NY', sum(salary)=162000, sum(bonus)=34000), " + "Row(department='Marketing', state='CA', sum(salary)=80000, sum(bonus)=18000), " + "Row(department='Marketing', state='NY', sum(salary)=91000, sum(bonus)=21000), " + "Row(department='Sales', state='CA', sum(salary)=81000, sum(bonus)=23000), " + "Row(department='Sales', state='NY', sum(salary)=176000, sum(bonus)=30000)]" ) df2 = ( @@ -122,7 +123,11 @@ def test_group_by(self, spark): res = df2.collect() assert ( str(res) - == "[Row(department='Finance', sum_salary=351000, avg_salary=87750.0, sum_bonus=81000, max_bonus=24000, any_state='CA', distinct_state=2), Row(department='Marketing', sum_salary=171000, avg_salary=85500.0, sum_bonus=39000, max_bonus=21000, any_state='CA', distinct_state=2), Row(department='Sales', sum_salary=257000, avg_salary=85666.66666666667, sum_bonus=53000, max_bonus=23000, any_state='NY', distinct_state=2)]" + == "[Row(department='Finance', sum_salary=351000, avg_salary=87750.0, sum_bonus=81000, max_bonus=24000, " + "any_state='CA', distinct_state=2), Row(department='Marketing', sum_salary=171000, avg_salary=85500.0, " + "sum_bonus=39000, max_bonus=21000, any_state='CA', distinct_state=2), Row(department='Sales', " + "sum_salary=257000, avg_salary=85666.66666666667, sum_bonus=53000, max_bonus=23000, any_state='NY', " + "distinct_state=2)]" ) df2 = ( @@ -141,7 +146,9 @@ def test_group_by(self, spark): print(str(res)) assert ( str(res) - == "[Row(department='Finance', sum_salary=351000, avg_salary=87750.0, sum_bonus=81000, max_bonus=24000, any_state='CA'), Row(department='Sales', sum_salary=257000, avg_salary=85666.66666666667, sum_bonus=53000, max_bonus=23000, any_state='NY')]" + == "[Row(department='Finance', sum_salary=351000, avg_salary=87750.0, sum_bonus=81000, max_bonus=24000, " + "any_state='CA'), Row(department='Sales', sum_salary=257000, avg_salary=85666.66666666667, " + "sum_bonus=53000, max_bonus=23000, any_state='NY')]" ) df = spark.createDataFrame( @@ -170,12 +177,12 @@ def test_group_by_empty(self, spark): res = df.groupBy(["name", "age"]).count().sort("name").collect() assert ( - str(res) - == "[Row(name='1', age=2, count=1), Row(name='2', age=2, count=1), Row(name='3', age=2, count=1), Row(name='4', age=5, count=1)]" + str(res) == "[Row(name='1', age=2, count=1), Row(name='2', age=2, count=1), Row(name='3', age=2, count=1), " + "Row(name='4', age=5, count=1)]" ) res = df.groupBy("name").count().columns - assert res == ['name', 'count'] + assert res == ["name", "count"] def test_group_by_first_and_last(self, spark): df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) @@ -188,7 +195,7 @@ def test_group_by_first_and_last(self, spark): .collect() ) - assert res == [Row(name='Alice', first_age=None, last_age=2), Row(name='Bob', first_age=5, last_age=5)] + assert res == [Row(name="Alice", first_age=None, last_age=2), Row(name="Bob", first_age=5, last_age=5)] def test_standard_deviations(self, spark): df = spark.createDataFrame( @@ -265,7 +272,7 @@ def test_group_by_mean(self, spark): res = df.groupBy("course").agg(median("earnings").alias("m")).collect() - assert sorted(res, key=lambda x: x.course) == [Row(course='Java', m=22000), Row(course='dotNET', 
m=10000)] + assert sorted(res, key=lambda x: x.course) == [Row(course="Java", m=22000), Row(course="dotNET", m=10000)] def test_group_by_mode(self, spark): df = spark.createDataFrame( @@ -282,17 +289,17 @@ def test_group_by_mode(self, spark): res = df.groupby("course").agg(mode("year").alias("mode")).collect() - assert sorted(res, key=lambda x: x.course) == [Row(course='Java', mode=2012), Row(course='dotNET', mode=2012)] + assert sorted(res, key=lambda x: x.course) == [Row(course="Java", mode=2012), Row(course="dotNET", mode=2012)] def test_group_by_product(self, spark): - df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3) - res = df.groupBy('mod3').agg(product('x').alias('product')).orderBy("mod3").collect() + df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) + res = df.groupBy("mod3").agg(product("x").alias("product")).orderBy("mod3").collect() assert res == [Row(mod3=0, product=162), Row(mod3=1, product=28), Row(mod3=2, product=80)] def test_group_by_skewness(self, spark): df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"]) res = df.groupBy("group").agg(skewness(df.c).alias("v")).collect() - # FIXME: Why is this different? + # TODO: Why is this different? # noqa: TD002, TD003 if USE_ACTUAL_SPARK: assert pytest.approx(res[0].v) == 0.7071067811865475 else: diff --git a/tests/fast/spark/test_spark_intersect.py b/tests/fast/spark/test_spark_intersect.py index 7fd97d40..8ec67dd0 100644 --- a/tests/fast/spark/test_spark_intersect.py +++ b/tests/fast/spark/test_spark_intersect.py @@ -1,10 +1,8 @@ -import platform import pytest _ = pytest.importorskip("duckdb.experimental.spark") from duckdb.experimental.spark.sql.types import Row -from duckdb.experimental.spark.sql.functions import col @pytest.fixture @@ -19,7 +17,6 @@ def df2(spark): class TestDataFrameIntersect: def test_intersect(self, spark, df, df2): - df3 = df.intersect(df2).sort(df.C1) res = df3.collect() @@ -29,7 +26,6 @@ def test_intersect(self, spark, df, df2): ] def test_intersect_all(self, spark, df, df2): - df3 = df.intersectAll(df2).sort(df.C1) res = df3.collect() diff --git a/tests/fast/spark/test_spark_join.py b/tests/fast/spark/test_spark_join.py index c7ef9878..5ca8ca63 100644 --- a/tests/fast/spark/test_spark_join.py +++ b/tests/fast/spark/test_spark_join.py @@ -2,20 +2,10 @@ _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture @@ -30,7 +20,7 @@ def dataframe_a(spark): ] empColumns = ["emp_id", "name", "superior_emp_id", "year_joined", "emp_dept_id", "gender", "salary"] dataframe = spark.createDataFrame(data=emp, schema=empColumns) - yield dataframe + return dataframe @pytest.fixture @@ -38,10 +28,10 @@ def dataframe_b(spark): dept = [("Finance", 10), ("Marketing", 20), ("Sales", 30), ("IT", 40)] deptColumns = ["dept_name", "dept_id"] dataframe = spark.createDataFrame(data=dept, schema=deptColumns) - yield dataframe + return dataframe -class TestDataFrameJoin(object): +class TestDataFrameJoin: def test_inner_join(self, dataframe_a, dataframe_b): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, "inner") df = df.sort(*df.columns) @@ -49,63 +39,63 @@ def 
test_inner_join(self, dataframe_a, dataframe_b): expected = [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), ] assert sorted(res) == sorted(expected) - @pytest.mark.parametrize('how', ['outer', 'fullouter', 'full', 'full_outer']) + @pytest.mark.parametrize("how", ["outer", "fullouter", "full", "full_outer"]) def test_outer_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -114,66 +104,66 @@ def test_outer_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), Row( emp_id=6, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='50', - gender='', + year_joined="2010", + emp_dept_id="50", + gender="", salary=-1, dept_name=None, dept_id=None, @@ -186,14 +176,14 @@ def test_outer_join(self, dataframe_a, dataframe_b, how): emp_dept_id=None, gender=None, salary=None, - dept_name='Sales', + dept_name="Sales", dept_id=30, ), ], key=lambda x: x.emp_id or 0, ) - @pytest.mark.parametrize('how', ['right', 'rightouter', 'right_outer']) + 
@pytest.mark.parametrize("how", ["right", "rightouter", "right_outer"]) def test_right_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -202,57 +192,57 @@ def test_right_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), Row( @@ -263,14 +253,14 @@ def test_right_join(self, dataframe_a, dataframe_b, how): emp_dept_id=None, gender=None, salary=None, - dept_name='Sales', + dept_name="Sales", dept_id=30, ), ], key=lambda x: x.emp_id or 0, ) - @pytest.mark.parametrize('how', ['semi', 'leftsemi', 'left_semi']) + @pytest.mark.parametrize("how", ["semi", "leftsemi", "left_semi"]) def test_semi_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -279,59 +269,59 @@ def test_semi_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, ), ] ) - @pytest.mark.parametrize('how', ['anti', 'leftanti', 'left_anti']) + @pytest.mark.parametrize("how", ["anti", "leftanti", "left_anti"]) def test_anti_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) res = df.collect() assert res == [ - Row(emp_id=6, name='Brown', 
superior_emp_id=2, year_joined='2010', emp_dept_id='50', gender='', salary=-1) + Row(emp_id=6, name="Brown", superior_emp_id=2, year_joined="2010", emp_dept_id="50", gender="", salary=-1) ] def test_self_join(self, dataframe_a): @@ -351,11 +341,11 @@ def test_self_join(self, dataframe_a): res = df.collect() assert sorted(res, key=lambda x: x.emp_id) == sorted( [ - Row(emp_id=2, name='Rose', superior_emp_id=1, superior_emp_name='Smith'), - Row(emp_id=3, name='Williams', superior_emp_id=1, superior_emp_name='Smith'), - Row(emp_id=4, name='Jones', superior_emp_id=2, superior_emp_name='Rose'), - Row(emp_id=5, name='Brown', superior_emp_id=2, superior_emp_name='Rose'), - Row(emp_id=6, name='Brown', superior_emp_id=2, superior_emp_name='Rose'), + Row(emp_id=2, name="Rose", superior_emp_id=1, superior_emp_name="Smith"), + Row(emp_id=3, name="Williams", superior_emp_id=1, superior_emp_name="Smith"), + Row(emp_id=4, name="Jones", superior_emp_id=2, superior_emp_name="Rose"), + Row(emp_id=5, name="Brown", superior_emp_id=2, superior_emp_name="Rose"), + Row(emp_id=6, name="Brown", superior_emp_id=2, superior_emp_name="Rose"), ], key=lambda x: x.emp_id, ) @@ -382,33 +372,38 @@ def test_cross_join(self, spark): ) def test_join_with_using_clause(self, spark, dataframe_a): - dataframe_a = dataframe_a.select('name', 'year_joined') + dataframe_a = dataframe_a.select("name", "year_joined") - df = dataframe_a.alias('df1') - df2 = dataframe_a.alias('df2') - res = df.join(df2, ['name', 'year_joined']).sort('name', 'year_joined') + df = dataframe_a.alias("df1") + df2 = dataframe_a.alias("df2") + res = df.join(df2, ["name", "year_joined"]).sort("name", "year_joined") res = res.collect() assert res == [ - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Jones', year_joined='2005'), - Row(name='Rose', year_joined='2010'), - Row(name='Smith', year_joined='2018'), - Row(name='Williams', year_joined='2010'), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Jones", year_joined="2005"), + Row(name="Rose", year_joined="2010"), + Row(name="Smith", year_joined="2018"), + Row(name="Williams", year_joined="2010"), ] def test_join_with_common_column(self, spark, dataframe_a): - dataframe_a = dataframe_a.select('name', 'year_joined') + dataframe_a = dataframe_a.select("name", "year_joined") - df = dataframe_a.alias('df1') - df2 = dataframe_a.alias('df2') - res = df.join(df2, df.name == df2.name).sort('df1.name') + df = dataframe_a.alias("df1") + df2 = dataframe_a.alias("df2") + res = df.join(df2, df.name == df2.name).sort("df1.name") res = res.collect() assert ( - str(res) - == "[Row(name='Brown', year_joined='2010', name='Brown', year_joined='2010'), Row(name='Brown', year_joined='2010', name='Brown', year_joined='2010'), Row(name='Brown', year_joined='2010', name='Brown', year_joined='2010'), Row(name='Brown', year_joined='2010', name='Brown', year_joined='2010'), Row(name='Jones', year_joined='2005', name='Jones', year_joined='2005'), Row(name='Rose', year_joined='2010', name='Rose', year_joined='2010'), Row(name='Smith', year_joined='2018', name='Smith', year_joined='2018'), Row(name='Williams', year_joined='2010', name='Williams', year_joined='2010')]" + str(res) == "[Row(name='Brown', year_joined='2010', name='Brown', year_joined='2010'), Row(name='Brown', " + 
"year_joined='2010', name='Brown', year_joined='2010'), Row(name='Brown', year_joined='2010', " + "name='Brown', year_joined='2010'), Row(name='Brown', year_joined='2010', name='Brown', " + "year_joined='2010'), Row(name='Jones', year_joined='2005', name='Jones', year_joined='2005'), " + "Row(name='Rose', year_joined='2010', name='Rose', year_joined='2010'), Row(name='Smith', " + "year_joined='2018', name='Smith', year_joined='2018'), Row(name='Williams', year_joined='2010', " + "name='Williams', year_joined='2010')]" ) @pytest.mark.xfail(condition=True, reason="Selecting from a duplicate binding causes an error") diff --git a/tests/fast/spark/test_spark_limit.py b/tests/fast/spark/test_spark_limit.py index c00496a0..eb88fc6a 100644 --- a/tests/fast/spark/test_spark_limit.py +++ b/tests/fast/spark/test_spark_limit.py @@ -7,7 +7,7 @@ ) -class TestDataFrameLimit(object): +class TestDataFrameLimit: def test_dataframe_limit(self, spark): df = spark.sql("select * from range(100000)") df2 = df.limit(10) diff --git a/tests/fast/spark/test_spark_order_by.py b/tests/fast/spark/test_spark_order_by.py index 92aa4d3a..030db4b8 100644 --- a/tests/fast/spark/test_spark_order_by.py +++ b/tests/fast/spark/test_spark_order_by.py @@ -2,24 +2,13 @@ _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -import duckdb -import re -class TestDataFrameOrderBy(object): +class TestDataFrameOrderBy: def test_order_by(self, spark): simpleData = [ ("James", "Sales", "NY", 90000, 34, 10000), @@ -38,15 +27,15 @@ def test_order_by(self, spark): df2 = df.sort("department", "state") res1 = df2.collect() assert res1 == [ - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", 
salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), ] df2 = df.sort(col("department"), col("state")) @@ -60,15 +49,15 @@ def test_order_by(self, spark): df2 = df.sort(df.department.asc(), df.state.desc()) res1 = df2.collect() assert res1 == [ - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] df2 = df.sort(col("department").asc(), col("state").desc()) @@ -94,15 +83,15 @@ def test_order_by(self, spark): ) res = df2.collect() assert res == [ - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + 
Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] def test_null_ordering(self, spark): @@ -130,56 +119,56 @@ def test_null_ordering(self, spark): res = df.orderBy("value1", "value2").collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy("value1", "value2", ascending=True).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy("value1", "value2", ascending=False).collect() assert res == [ - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), Row(value1=3, value2=None), - Row(value1=2, value2='A'), - Row(value1=None, value2='A'), + Row(value1=2, value2="A"), + Row(value1=None, value2="A"), ] res = df.orderBy(df.value1, df.value2).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy(df.value1.asc(), df.value2.asc()).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy(df.value1.desc(), df.value2.desc()).collect() assert res == [ - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), Row(value1=3, value2=None), - Row(value1=2, value2='A'), - Row(value1=None, value2='A'), + Row(value1=2, value2="A"), + Row(value1=None, value2="A"), ] res = df.orderBy(df.value1, df.value2, ascending=[True, False]).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2="A"), Row(value1=3, value2=None), ] diff --git a/tests/fast/spark/test_spark_pandas_dataframe.py b/tests/fast/spark/test_spark_pandas_dataframe.py index dcec77a8..4e468d29 100644 --- a/tests/fast/spark/test_spark_pandas_dataframe.py +++ b/tests/fast/spark/test_spark_pandas_dataframe.py @@ -3,42 +3,34 @@ _ = pytest.importorskip("duckdb.experimental.spark") pd = pytest.importorskip("pandas") +from pandas.testing import assert_frame_equal from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, IntegerType, - LongType, Row, - ArrayType, - MapType, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when -import duckdb -import re -from pandas.testing import assert_frame_equal @pytest.fixture def pandasDF(spark): - data = [['Scott', 50], ['Jeff', 45], ['Thomas', 54], ['Ann', 34]] + data = [["Scott", 50], ["Jeff", 45], ["Thomas", 54], ["Ann", 34]] # Create the pandas DataFrame - df = pd.DataFrame(data, 
columns=['Name', 'Age']) - yield df + df = pd.DataFrame(data, columns=["Name", "Age"]) + return df -class TestPandasDataFrame(object): +class TestPandasDataFrame: def test_pd_conversion_basic(self, spark, pandasDF): sparkDF = spark.createDataFrame(pandasDF) res = sparkDF.collect() sparkDF.show() expected = [ - Row(Name='Scott', Age=50), - Row(Name='Jeff', Age=45), - Row(Name='Thomas', Age=54), - Row(Name='Ann', Age=34), + Row(Name="Scott", Age=50), + Row(Name="Jeff", Age=45), + Row(Name="Thomas", Age=54), + Row(Name="Ann", Age=34), ] assert res == expected @@ -47,7 +39,10 @@ def test_pd_conversion_schema(self, spark, pandasDF): sparkDF = spark.createDataFrame(pandasDF, schema=mySchema) sparkDF.show() res = sparkDF.collect() - expected = "[Row(First Name='Scott', Age=50), Row(First Name='Jeff', Age=45), Row(First Name='Thomas', Age=54), Row(First Name='Ann', Age=34)]" + expected = ( + "[Row(First Name='Scott', Age=50), Row(First Name='Jeff', Age=45), " + "Row(First Name='Thomas', Age=54), Row(First Name='Ann', Age=34)]" + ) assert str(res) == expected def test_spark_to_pandas_dataframe(self, spark, pandasDF): diff --git a/tests/fast/spark/test_spark_readcsv.py b/tests/fast/spark/test_spark_readcsv.py index 8e6c0515..27cee47e 100644 --- a/tests/fast/spark/test_spark_readcsv.py +++ b/tests/fast/spark/test_spark_readcsv.py @@ -2,26 +2,15 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import Row from spark_namespace import USE_ACTUAL_SPARK -import textwrap +from spark_namespace.sql.types import Row -class TestSparkReadCSV(object): +class TestSparkReadCSV: def test_read_csv(self, spark, tmp_path): - file_path = tmp_path / 'basic.csv' - with open(file_path, 'w+') as f: - f.write( - textwrap.dedent( - """ - 1,2 - 3,4 - 5,6 - """ - ) - ) - file_path = file_path.as_posix() - df = spark.read.csv(file_path) + file_path = tmp_path / "basic.csv" + file_path.write_text("1,2\n3,4\n5,6\n") + df = spark.read.csv(file_path.as_posix()) res = df.collect() expected_res = sorted([Row(column0=1, column1=2), Row(column0=3, column1=4), Row(column0=5, column1=6)]) diff --git a/tests/fast/spark/test_spark_readjson.py b/tests/fast/spark/test_spark_readjson.py index a6ad05f0..aa8d8ec5 100644 --- a/tests/fast/spark/test_spark_readjson.py +++ b/tests/fast/spark/test_spark_readjson.py @@ -2,16 +2,15 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace.sql.types import Row -import textwrap -import duckdb -class TestSparkReadJson(object): +class TestSparkReadJson: def test_read_json(self, duckdb_cursor, spark, tmp_path): - file_path = tmp_path / 'basic.parquet' + file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() duckdb_cursor.execute(f"COPY (select 42 a, true b, 'this is a long string' c) to '{file_path}' (FORMAT JSON)") df = spark.read.json(file_path) res = df.collect() - assert res == [Row(a=42, b=True, c='this is a long string')] + assert res == [Row(a=42, b=True, c="this is a long string")] diff --git a/tests/fast/spark/test_spark_readparquet.py b/tests/fast/spark/test_spark_readparquet.py index a08ab16d..2f182650 100644 --- a/tests/fast/spark/test_spark_readparquet.py +++ b/tests/fast/spark/test_spark_readparquet.py @@ -2,18 +2,17 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace.sql.types import Row -import textwrap -import duckdb -class TestSparkReadParquet(object): +class TestSparkReadParquet: def test_read_parquet(self, duckdb_cursor, spark, tmp_path): - file_path = tmp_path / 'basic.parquet' + 
file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() duckdb_cursor.execute( f"COPY (select 42 a, true b, 'this is a long string' c) to '{file_path}' (FORMAT PARQUET)" ) df = spark.read.parquet(file_path) res = df.collect() - assert res == [Row(a=42, b=True, c='this is a long string')] + assert res == [Row(a=42, b=True, c="this is a long string")] diff --git a/tests/fast/spark/test_spark_runtime_config.py b/tests/fast/spark/test_spark_runtime_config.py index 5e93ed63..fc6749b5 100644 --- a/tests/fast/spark/test_spark_runtime_config.py +++ b/tests/fast/spark/test_spark_runtime_config.py @@ -5,10 +5,10 @@ from spark_namespace import USE_ACTUAL_SPARK -class TestSparkRuntimeConfig(object): +class TestSparkRuntimeConfig: def test_spark_runtime_config(self, spark): # This fetches the internal runtime config from the session - spark.conf + spark.conf # noqa: B018 @pytest.mark.skipif( USE_ACTUAL_SPARK, reason="Getting an error with our local PySpark setup. Unclear why but not a priority." @@ -22,4 +22,4 @@ def test_spark_runtime_config_set(self, spark): def test_spark_runtime_config_get(self, spark): # Get a Spark Config with pytest.raises(KeyError): - partitions = spark.conf.get("spark.sql.shuffle.partitions") + spark.conf.get("spark.sql.shuffle.partitions") diff --git a/tests/fast/spark/test_spark_session.py b/tests/fast/spark/test_spark_session.py index 7c338898..d36f4f08 100644 --- a/tests/fast/spark/test_spark_session.py +++ b/tests/fast/spark/test_spark_session.py @@ -1,85 +1,79 @@ import pytest +from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.sql.types import Row + from duckdb.experimental.spark.exception import ( ContributionsAcceptedError, ) -from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql import SparkSession -class TestSparkSession(object): +class TestSparkSession: def test_spark_session_default(self): - session = SparkSession.builder.getOrCreate() + SparkSession.builder.getOrCreate() def test_spark_session(self): - session = SparkSession.builder.master("local[1]").appName('SparkByExamples.com').getOrCreate() + SparkSession.builder.master("local[1]").appName("SparkByExamples.com").getOrCreate() def test_new_session(self, spark: SparkSession): - session = spark.newSession() + spark.newSession() - @pytest.mark.skip(reason='not tested yet') + @pytest.mark.skip(reason="not tested yet") def test_retrieve_same_session(self): - spark = SparkSession.builder.master('test').appName('test2').getOrCreate() + spark1 = SparkSession.builder.master("test").appName("test2").getOrCreate() spark2 = SparkSession.builder.getOrCreate() # Same connection should be returned - assert spark == spark2 + assert spark1 == spark2 def test_config(self): # Usage of config() - spark = ( - SparkSession.builder.master("local[1]") - .appName("SparkByExamples.com") - .config("spark.some.config.option", "config-value") - .getOrCreate() - ) + SparkSession.builder.master("local[1]").appName("SparkByExamples.com").config( + "spark.some.config.option", "config-value" + ).getOrCreate() @pytest.mark.skip(reason="enableHiveSupport is not implemented yet") def test_hive_support(self): # Enabling Hive to use in Spark - spark = ( - SparkSession.builder.master("local[1]") - .appName("SparkByExamples.com") - .config("spark.sql.warehouse.dir", "/spark-warehouse") - .enableHiveSupport() - .getOrCreate() - ) + SparkSession.builder.master("local[1]").appName("SparkByExamples.com").config( 
+ "spark.sql.warehouse.dir", "/spark-warehouse" + ).enableHiveSupport().getOrCreate() @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Different version numbers") def test_version(self, spark): version = spark.version - assert version == '1.0.0' + assert version == "1.0.0" def test_get_active_session(self, spark): - active_session = spark.getActiveSession() + spark.getActiveSession() def test_read(self, spark): - reader = spark.read + reader = spark.read # noqa: F841 def test_write(self, spark): - df = spark.sql('select 42') - writer = df.write + df = spark.sql("select 42") + writer = df.write # noqa: F841 def test_read_stream(self, spark): - reader = spark.readStream + reader = spark.readStream # noqa: F841 def test_spark_context(self, spark): - context = spark.sparkContext + context = spark.sparkContext # noqa: F841 def test_sql(self, spark): - df = spark.sql('select 42') + spark.sql("select 42") def test_stop_context(self, spark): - context = spark.sparkContext + context = spark.sparkContext # noqa: F841 spark.stop() @pytest.mark.skipif( USE_ACTUAL_SPARK, reason="Can't create table with the local PySpark setup in the CI/CD pipeline" ) def test_table(self, spark): - spark.sql('create table tbl(a varchar(10))') - df = spark.table('tbl') + spark.sql("create table tbl(a varchar(10))") + spark.table("tbl") def test_range(self, spark): res_1 = spark.range(3).collect() @@ -96,4 +90,4 @@ def test_range(self, spark): spark.range(0, 10, 2, 2) def test_udf(self, spark): - udf_registration = spark.udf + udf_registration = spark.udf # noqa: F841 diff --git a/tests/fast/spark/test_spark_to_csv.py b/tests/fast/spark/test_spark_to_csv.py index 5048e579..10e0028c 100644 --- a/tests/fast/spark/test_spark_to_csv.py +++ b/tests/fast/spark/test_spark_to_csv.py @@ -1,12 +1,13 @@ -import pytest -import tempfile - +import csv +import datetime import os -_ = pytest.importorskip("duckdb.experimental.spark") - +import pytest +from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData from spark_namespace import USE_ACTUAL_SPARK +from duckdb import InvalidInputException, read_csv + if USE_ACTUAL_SPARK: pytest.skip( "Skipping these tests as right now," @@ -15,12 +16,7 @@ allow_module_level=True, ) -from duckdb import connect, InvalidInputException, read_csv -from conftest import NumpyPandas, ArrowPandas, getTimeSeriesData -from spark_namespace import USE_ACTUAL_SPARK -import pandas._testing as tm -import datetime -import csv +pytest.importorskip("duckdb.experimental.spark") @pytest.fixture @@ -34,26 +30,26 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_ints(request, spark): pandas = request.param - dataframe = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) - yield dataframe + dataframe = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) + return dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_strings(request, spark): pandas = request.param - dataframe = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) - yield dataframe + dataframe = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) + return dataframe -class TestSparkToCSV(object): +class TestSparkToCSV: def test_basic_to_csv(self, pandas_df_ints, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df 
= spark.createDataFrame(pandas_df_ints) @@ -64,19 +60,19 @@ def test_basic_to_csv(self, pandas_df_ints, spark, tmp_path): assert df.collect() == csv_rel.collect() def test_to_csv_sep(self, pandas_df_ints, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df = spark.createDataFrame(pandas_df_ints) - df.write.csv(temp_file_name, sep=',') + df.write.csv(temp_file_name, sep=",") - csv_rel = spark.read.csv(temp_file_name, sep=',') + csv_rel = spark.read.csv(temp_file_name, sep=",") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_na_rep(self, pandas, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 + pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -85,10 +81,10 @@ def test_to_csv_na_rep(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name, nullValue="test") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_header(self, pandas, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 + pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -97,22 +93,22 @@ def test_to_csv_header(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name) assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quotechar(self, pandas, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 - pandas_df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]}) + pandas_df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, quote='\'', sep=',') + df.write.csv(temp_file_name, quote="'", sep=",") - csv_rel = spark.read.csv(temp_file_name, sep=',', quote='\'') + csv_rel = spark.read.csv(temp_file_name, sep=",", quote="'") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_escapechar(self, pandas, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 pandas_df = pandas.DataFrame( { "c_bool": [True, False], @@ -124,13 +120,13 @@ def test_to_csv_escapechar(self, pandas, spark, tmp_path): df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, quote='"', escape='!') - csv_rel = spark.read.csv(temp_file_name, quote='"', escape='!') + df.write.csv(temp_file_name, quote='"', escape="!") + csv_rel = 
spark.read.csv(temp_file_name, quote='"', escape="!") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_date_format(self, pandas, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 pandas_df = pandas.DataFrame(getTimeSeriesData()) dt_index = pandas_df.index pandas_df = pandas.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index) @@ -143,22 +139,22 @@ def test_to_csv_date_format(self, pandas, spark, tmp_path): assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_timestamp_format(self, pandas, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - pandas_df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + pandas_df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, timestampFormat='%m/%d/%Y') + df.write.csv(temp_file_name, timestampFormat="%m/%d/%Y") - csv_rel = spark.read.csv(temp_file_name, timestampFormat='%m/%d/%Y') + csv_rel = spark.read.csv(temp_file_name, timestampFormat="%m/%d/%Y") assert df.collect() == csv_rel.collect() def test_to_csv_quoting_off(self, pandas_df_strings, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df = spark.createDataFrame(pandas_df_strings) df.write.csv(temp_file_name, quoteAll=None) @@ -166,7 +162,7 @@ def test_to_csv_quoting_off(self, pandas_df_strings, spark, tmp_path): assert df.collect() == csv_rel.collect() def test_to_csv_quoting_on(self, pandas_df_strings, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df = spark.createDataFrame(pandas_df_strings) df.write.csv(temp_file_name, quoteAll="force") @@ -174,7 +170,7 @@ def test_to_csv_quoting_on(self, pandas_df_strings, spark, tmp_path): assert df.collect() == csv_rel.collect() def test_to_csv_quoting_quote_all(self, pandas_df_strings, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df = spark.createDataFrame(pandas_df_strings) df.write.csv(temp_file_name, quoteAll=csv.QUOTE_ALL) @@ -182,7 +178,7 @@ def test_to_csv_quoting_quote_all(self, pandas_df_strings, spark, tmp_path): assert df.collect() == csv_rel.collect() def test_to_csv_encoding_incorrect(self, pandas_df_strings, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df = spark.createDataFrame(pandas_df_strings) with pytest.raises( InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8" @@ -190,7 +186,7 @@ def test_to_csv_encoding_incorrect(self, pandas_df_strings, spark, tmp_path): df.write.csv(temp_file_name, encoding="nope") def test_to_csv_encoding_correct(self, pandas_df_strings, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, 
"temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df = spark.createDataFrame(pandas_df_strings) df.write.csv(temp_file_name, encoding="UTF-8") csv_rel = spark.read.csv(temp_file_name) @@ -198,10 +194,11 @@ def test_to_csv_encoding_correct(self, pandas_df_strings, spark, tmp_path): @pytest.mark.skipif( USE_ACTUAL_SPARK, - reason="This test uses DuckDB to read the CSV. However, this does not work if Spark created it as Spark creates a folder instead of a single file.", + reason="This test uses DuckDB to read the CSV. However, this does not work if Spark created it as " + "Spark creates a folder instead of a single file.", ) def test_compression_gzip(self, pandas_df_strings, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.csv") + temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 df = spark.createDataFrame(pandas_df_strings) df.write.csv(temp_file_name, compression="gzip") diff --git a/tests/fast/spark/test_spark_to_parquet.py b/tests/fast/spark/test_spark_to_parquet.py index 68a10f65..d120bec6 100644 --- a/tests/fast/spark/test_spark_to_parquet.py +++ b/tests/fast/spark/test_spark_to_parquet.py @@ -1,8 +1,7 @@ -import pytest -import tempfile - import os +import pytest + _ = pytest.importorskip("duckdb.experimental.spark") @@ -17,12 +16,12 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe -class TestSparkToParquet(object): +class TestSparkToParquet: def test_basic_to_parquet(self, df, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.parquet") + temp_file_name = os.path.join(tmp_path, "temp_file.parquet") # noqa: PTH118 df.write.parquet(temp_file_name) @@ -31,7 +30,7 @@ def test_basic_to_parquet(self, df, spark, tmp_path): assert sorted(df.collect()) == sorted(csv_rel.collect()) def test_compressed_to_parquet(self, df, spark, tmp_path): - temp_file_name = os.path.join(tmp_path, "temp_file.parquet") + temp_file_name = os.path.join(tmp_path, "temp_file.parquet") # noqa: PTH118 df.write.parquet(temp_file_name, compression="ZSTD") diff --git a/tests/fast/spark/test_spark_transform.py b/tests/fast/spark/test_spark_transform.py index 83e219a5..bf1c7b01 100644 --- a/tests/fast/spark/test_spark_transform.py +++ b/tests/fast/spark/test_spark_transform.py @@ -3,19 +3,8 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture @@ -26,7 +15,7 @@ def array_df(spark): ("Robert,,Williams", ["CSharp", "VB"], ["Spark", "Python"]), ] dataframe = spark.createDataFrame(data=data, schema=["Name", "Languages1", "Languages2"]) - yield dataframe + return dataframe @pytest.fixture @@ -40,10 +29,10 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_transform(self, spark, df): # Custom transformation 1 from spark_namespace.sql.functions import upper @@ -62,16 +51,16 @@ def apply_discount(df): df2 = df.transform(to_upper_str_columns).transform(reduce_price, 
1000).transform(apply_discount) res = df2.collect() assert res == [ - Row(CourseName='JAVA', fee=4000, discount=5, new_fee=3000, discounted_fee=2850.0), - Row(CourseName='PYTHON', fee=4600, discount=10, new_fee=3600, discounted_fee=3240.0), - Row(CourseName='SCALA', fee=4100, discount=15, new_fee=3100, discounted_fee=2635.0), - Row(CourseName='SCALA', fee=4500, discount=15, new_fee=3500, discounted_fee=2975.0), - Row(CourseName='PHP', fee=3000, discount=20, new_fee=2000, discounted_fee=1600.0), + Row(CourseName="JAVA", fee=4000, discount=5, new_fee=3000, discounted_fee=2850.0), + Row(CourseName="PYTHON", fee=4600, discount=10, new_fee=3600, discounted_fee=3240.0), + Row(CourseName="SCALA", fee=4100, discount=15, new_fee=3100, discounted_fee=2635.0), + Row(CourseName="SCALA", fee=4500, discount=15, new_fee=3500, discounted_fee=2975.0), + Row(CourseName="PHP", fee=3000, discount=20, new_fee=2000, discounted_fee=1600.0), ] # https://sparkbyexamples.com/pyspark/pyspark-transform-function/ - @pytest.mark.skip(reason='LambdaExpressions are currently under development, waiting til that is finished') + @pytest.mark.skip(reason="LambdaExpressions are currently under development, waiting til that is finished") def test_transform_function(self, spark, array_df): - from spark_namespace.sql.functions import upper, transform + from spark_namespace.sql.functions import transform, upper df.select(transform("Languages1", lambda x: upper(x)).alias("languages1")).show() diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index fb6e6102..7e72aad6 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -9,43 +9,42 @@ "Skipping these tests as they use test_all_types() which is specific to DuckDB", allow_module_level=True ) -from spark_namespace.sql.types import Row from spark_namespace.sql.types import ( - StringType, + ArrayType, BinaryType, BitstringType, - UUIDType, BooleanType, + ByteType, DateType, - TimestampType, - TimestampNTZType, - TimeType, - TimeNTZType, - TimestampNanosecondNTZType, - TimestampMilisecondNTZType, - TimestampSecondNTZType, + DayTimeIntervalType, DecimalType, DoubleType, FloatType, - ByteType, - UnsignedByteType, - ShortType, - UnsignedShortType, + HugeIntegerType, IntegerType, - UnsignedIntegerType, LongType, - UnsignedLongType, - HugeIntegerType, - UnsignedHugeIntegerType, - DayTimeIntervalType, - ArrayType, MapType, + ShortType, + StringType, StructField, StructType, + TimeNTZType, + TimestampMilisecondNTZType, + TimestampNanosecondNTZType, + TimestampNTZType, + TimestampSecondNTZType, + TimestampType, + TimeType, + UnsignedByteType, + UnsignedHugeIntegerType, + UnsignedIntegerType, + UnsignedLongType, + UnsignedShortType, + UUIDType, ) -class TestTypes(object): +class TestTypes: def test_all_types_schema(self, spark): # Create DataFrame df = spark.sql( @@ -58,10 +57,10 @@ def test_all_types_schema(self, spark): fixed_int_array, fixed_varchar_array, fixed_nested_int_array, - fixed_nested_varchar_array, - fixed_struct_array, - struct_of_fixed_array, - fixed_array_of_int_list, + fixed_nested_varchar_array, + fixed_struct_array, + struct_of_fixed_array, + fixed_array_of_int_list, list_of_fixed_int_array, bignum ) from test_all_types() @@ -70,65 +69,65 @@ def test_all_types_schema(self, spark): schema = df.schema assert schema == StructType( [ - StructField('bool', BooleanType(), True), - StructField('tinyint', ByteType(), True), - StructField('smallint', ShortType(), True), - StructField('int', IntegerType(), 
True), - StructField('bigint', LongType(), True), - StructField('hugeint', HugeIntegerType(), True), - StructField('uhugeint', UnsignedHugeIntegerType(), True), - StructField('utinyint', UnsignedByteType(), True), - StructField('usmallint', UnsignedShortType(), True), - StructField('uint', UnsignedIntegerType(), True), - StructField('ubigint', UnsignedLongType(), True), - StructField('date', DateType(), True), - StructField('time', TimeNTZType(), True), - StructField('timestamp', TimestampNTZType(), True), - StructField('timestamp_s', TimestampSecondNTZType(), True), - StructField('timestamp_ms', TimestampNanosecondNTZType(), True), - StructField('timestamp_ns', TimestampMilisecondNTZType(), True), - StructField('time_tz', TimeType(), True), - StructField('timestamp_tz', TimestampType(), True), - StructField('float', FloatType(), True), - StructField('double', DoubleType(), True), - StructField('dec_4_1', DecimalType(4, 1), True), - StructField('dec_9_4', DecimalType(9, 4), True), - StructField('dec_18_6', DecimalType(18, 6), True), - StructField('dec38_10', DecimalType(38, 10), True), - StructField('uuid', UUIDType(), True), - StructField('interval', DayTimeIntervalType(0, 3), True), - StructField('varchar', StringType(), True), - StructField('blob', BinaryType(), True), - StructField('bit', BitstringType(), True), - StructField('int_array', ArrayType(IntegerType(), True), True), - StructField('double_array', ArrayType(DoubleType(), True), True), - StructField('date_array', ArrayType(DateType(), True), True), - StructField('timestamp_array', ArrayType(TimestampNTZType(), True), True), - StructField('timestamptz_array', ArrayType(TimestampType(), True), True), - StructField('varchar_array', ArrayType(StringType(), True), True), - StructField('nested_int_array', ArrayType(ArrayType(IntegerType(), True), True), True), + StructField("bool", BooleanType(), True), + StructField("tinyint", ByteType(), True), + StructField("smallint", ShortType(), True), + StructField("int", IntegerType(), True), + StructField("bigint", LongType(), True), + StructField("hugeint", HugeIntegerType(), True), + StructField("uhugeint", UnsignedHugeIntegerType(), True), + StructField("utinyint", UnsignedByteType(), True), + StructField("usmallint", UnsignedShortType(), True), + StructField("uint", UnsignedIntegerType(), True), + StructField("ubigint", UnsignedLongType(), True), + StructField("date", DateType(), True), + StructField("time", TimeNTZType(), True), + StructField("timestamp", TimestampNTZType(), True), + StructField("timestamp_s", TimestampSecondNTZType(), True), + StructField("timestamp_ms", TimestampNanosecondNTZType(), True), + StructField("timestamp_ns", TimestampMilisecondNTZType(), True), + StructField("time_tz", TimeType(), True), + StructField("timestamp_tz", TimestampType(), True), + StructField("float", FloatType(), True), + StructField("double", DoubleType(), True), + StructField("dec_4_1", DecimalType(4, 1), True), + StructField("dec_9_4", DecimalType(9, 4), True), + StructField("dec_18_6", DecimalType(18, 6), True), + StructField("dec38_10", DecimalType(38, 10), True), + StructField("uuid", UUIDType(), True), + StructField("interval", DayTimeIntervalType(0, 3), True), + StructField("varchar", StringType(), True), + StructField("blob", BinaryType(), True), + StructField("bit", BitstringType(), True), + StructField("int_array", ArrayType(IntegerType(), True), True), + StructField("double_array", ArrayType(DoubleType(), True), True), + StructField("date_array", ArrayType(DateType(), True), True), 
+ StructField("timestamp_array", ArrayType(TimestampNTZType(), True), True), + StructField("timestamptz_array", ArrayType(TimestampType(), True), True), + StructField("varchar_array", ArrayType(StringType(), True), True), + StructField("nested_int_array", ArrayType(ArrayType(IntegerType(), True), True), True), StructField( - 'struct', - StructType([StructField('a', IntegerType(), True), StructField('b', StringType(), True)]), + "struct", + StructType([StructField("a", IntegerType(), True), StructField("b", StringType(), True)]), True, ), StructField( - 'struct_of_arrays', + "struct_of_arrays", StructType( [ - StructField('a', ArrayType(IntegerType(), True), True), - StructField('b', ArrayType(StringType(), True), True), + StructField("a", ArrayType(IntegerType(), True), True), + StructField("b", ArrayType(StringType(), True), True), ] ), True, ), StructField( - 'array_of_structs', + "array_of_structs", ArrayType( - StructType([StructField('a', IntegerType(), True), StructField('b', StringType(), True)]), True + StructType([StructField("a", IntegerType(), True), StructField("b", StringType(), True)]), True ), True, ), - StructField('map', MapType(StringType(), StringType(), True), True), + StructField("map", MapType(StringType(), StringType(), True), True), ] ) diff --git a/tests/fast/spark/test_spark_udf.py b/tests/fast/spark/test_spark_udf.py index 3b5a5d36..cee0f256 100644 --- a/tests/fast/spark/test_spark_udf.py +++ b/tests/fast/spark/test_spark_udf.py @@ -3,9 +3,8 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestSparkUDF(object): +class TestSparkUDF: def test_udf_register(self, spark): - def to_upper_fn(s: str) -> str: return s.upper() diff --git a/tests/fast/spark/test_spark_union.py b/tests/fast/spark/test_spark_union.py index ea889e05..588c7ecd 100644 --- a/tests/fast/spark/test_spark_union.py +++ b/tests/fast/spark/test_spark_union.py @@ -1,10 +1,11 @@ import platform + import pytest _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import Row from spark_namespace.sql.functions import col +from spark_namespace.sql.types import Row @pytest.fixture @@ -18,7 +19,7 @@ def df(spark): columns = ["employee_name", "department", "state", "salary", "age", "bonus"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe @pytest.fixture @@ -32,23 +33,23 @@ def df2(spark): ] columns2 = ["employee_name", "department", "state", "salary", "age", "bonus"] dataframe = spark.createDataFrame(data=simpleData2, schema=columns2) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_merge_with_union(self, df, df2): unionDF = df.union(df2) res = unionDF.collect() assert res == [ - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - 
Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), ] unionDF = df.unionAll(df2) res2 = unionDF.collect() @@ -60,11 +61,11 @@ def test_merge_without_duplicates(self, df, df2): disDF = df.union(df2).distinct().sort(col("employee_name")) res = disDF.collect() assert res == [ - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] diff --git a/tests/fast/spark/test_spark_union_by_name.py b/tests/fast/spark/test_spark_union_by_name.py index 08f3c62b..bec539a2 100644 --- a/tests/fast/spark/test_spark_union_by_name.py +++ b/tests/fast/spark/test_spark_union_by_name.py @@ -4,48 +4,37 @@ from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture def df1(spark): data = [("James", 34), ("Michael", 56), ("Robert", 30), ("Maria", 24)] dataframe = spark.createDataFrame(data=data, schema=["name", "id"]) - yield dataframe + return dataframe @pytest.fixture def df2(spark): data2 = [(34, "James"), (45, "Maria"), (45, "Jen"), (34, "Jeff")] dataframe = spark.createDataFrame(data=data2, schema=["id", "name"]) - 
yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_union_by_name(self, df1, df2): rel = df1.unionByName(df2) res = rel.collect() expected = [ - Row(name='James', id=34), - Row(name='Michael', id=56), - Row(name='Robert', id=30), - Row(name='Maria', id=24), - Row(name='James', id=34), - Row(name='Maria', id=45), - Row(name='Jen', id=45), - Row(name='Jeff', id=34), + Row(name="James", id=34), + Row(name="Michael", id=56), + Row(name="Robert", id=30), + Row(name="Maria", id=24), + Row(name="James", id=34), + Row(name="Maria", id=45), + Row(name="Jen", id=45), + Row(name="Jeff", id=34), ] assert res == expected @@ -53,13 +42,13 @@ def test_union_by_name_allow_missing_cols(self, df1, df2): rel = df1.unionByName(df2.drop("id"), allowMissingColumns=True) res = rel.collect() expected = [ - Row(name='James', id=34), - Row(name='Michael', id=56), - Row(name='Robert', id=30), - Row(name='Maria', id=24), - Row(name='James', id=None), - Row(name='Maria', id=None), - Row(name='Jen', id=None), - Row(name='Jeff', id=None), + Row(name="James", id=34), + Row(name="Michael", id=56), + Row(name="Robert", id=30), + Row(name="Maria", id=24), + Row(name="James", id=None), + Row(name="Maria", id=None), + Row(name="Jen", id=None), + Row(name="Jeff", id=None), ] assert res == expected diff --git a/tests/fast/spark/test_spark_with_column.py b/tests/fast/spark/test_spark_with_column.py index 80da34c3..4ea62fe1 100644 --- a/tests/fast/spark/test_spark_with_column.py +++ b/tests/fast/spark/test_spark_with_column.py @@ -2,41 +2,27 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, - ArrayType, - MapType, -) -from spark_namespace.sql.functions import col, struct, when, lit from spark_namespace import USE_ACTUAL_SPARK -import duckdb -import re +from spark_namespace.sql.functions import col, lit -class TestWithColumn(object): +class TestWithColumn: def test_with_column(self, spark): data = [ - ('James', '', 'Smith', '1991-04-01', 'M', 3000), - ('Michael', 'Rose', '', '2000-05-19', 'M', 4000), - ('Robert', '', 'Williams', '1978-09-05', 'M', 4000), - ('Maria', 'Anne', 'Jones', '1967-12-01', 'F', 4000), - ('Jen', 'Mary', 'Brown', '1980-02-17', 'F', -1), + ("James", "", "Smith", "1991-04-01", "M", 3000), + ("Michael", "Rose", "", "2000-05-19", "M", 4000), + ("Robert", "", "Williams", "1978-09-05", "M", 4000), + ("Maria", "Anne", "Jones", "1967-12-01", "F", 4000), + ("Jen", "Mary", "Brown", "1980-02-17", "F", -1), ] columns = ["firstname", "middlename", "lastname", "dob", "gender", "salary"] df = spark.createDataFrame(data=data, schema=columns) - assert df.schema['salary'].dataType.typeName() == ('long' if USE_ACTUAL_SPARK else 'integer') + assert df.schema["salary"].dataType.typeName() == ("long" if USE_ACTUAL_SPARK else "integer") # The type of 'salary' has been cast to Bigint new_df = df.withColumn("salary", col("salary").cast("BIGINT")) - assert new_df.schema['salary'].dataType.typeName() == 'long' + assert new_df.schema["salary"].dataType.typeName() == "long" # Replace the 'salary' column with '(salary * 100)' df2 = df.withColumn("salary", col("salary") * 100) @@ -50,16 +36,16 @@ def test_with_column(self, spark): df2 = df.withColumn("Country", lit("USA")) res = df2.collect() - assert res[0].Country == 'USA' + assert res[0].Country == "USA" df2 = df.withColumn("Country", lit("USA")).withColumn("anotherColumn", 
lit("anotherValue")) res = df2.collect() - assert res[0].Country == 'USA' - assert res[1].anotherColumn == 'anotherValue' + assert res[0].Country == "USA" + assert res[1].anotherColumn == "anotherValue" df2 = df.withColumnRenamed("gender", "sex") - assert 'gender' not in df2.columns - assert 'sex' in df2.columns + assert "gender" not in df2.columns + assert "sex" in df2.columns df2 = df.drop("salary") - assert 'salary' not in df2.columns + assert "salary" not in df2.columns diff --git a/tests/fast/spark/test_spark_with_column_renamed.py b/tests/fast/spark/test_spark_with_column_renamed.py index 168ff23a..73dd5606 100644 --- a/tests/fast/spark/test_spark_with_column_renamed.py +++ b/tests/fast/spark/test_spark_with_column_renamed.py @@ -2,69 +2,61 @@ _ = pytest.importorskip("duckdb.experimental.spark") + +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, IntegerType, - LongType, - Row, - ArrayType, - MapType, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when, lit -import duckdb -import re -class TestWithColumnRenamed(object): +class TestWithColumnRenamed: def test_with_column_renamed(self, spark): dataDF = [ - (('James', '', 'Smith'), '1991-04-01', 'M', 3000), - (('Michael', 'Rose', ''), '2000-05-19', 'M', 4000), - (('Robert', '', 'Williams'), '1978-09-05', 'M', 4000), - (('Maria', 'Anne', 'Jones'), '1967-12-01', 'F', 4000), - (('Jen', 'Mary', 'Brown'), '1980-02-17', 'F', -1), + (("James", "", "Smith"), "1991-04-01", "M", 3000), + (("Michael", "Rose", ""), "2000-05-19", "M", 4000), + (("Robert", "", "Williams"), "1978-09-05", "M", 4000), + (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), + (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('dob', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("dob", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) df = spark.createDataFrame(data=dataDF, schema=schema) df2 = df.withColumnRenamed("dob", "DateOfBirth").withColumnRenamed("salary", "salary_amount") - assert 'dob' not in df2.columns - assert 'salary' not in df2.columns - assert 'DateOfBirth' in df2.columns - assert 'salary_amount' in df2.columns + assert "dob" not in df2.columns + assert "salary" not in df2.columns + assert "DateOfBirth" in df2.columns + assert "salary_amount" in df2.columns - schema2 = StructType( + StructType( [ StructField( - 'full name', + "full name", StructType( [ - StructField('fname', StringType(), True), - StructField('mname', StringType(), True), - StructField('lname', StringType(), True), + StructField("fname", StringType(), True), + StructField("mname", StringType(), True), + StructField("lname", StringType(), True), ] ), ), @@ -72,9 +64,9 @@ def test_with_column_renamed(self, spark): ) df2 = df.withColumnRenamed("name", "full name") - assert 'name' not in df2.columns - assert 'full name' in 
df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() df2 = df.select( col("name").alias("full name"), @@ -82,9 +74,9 @@ def test_with_column_renamed(self, spark): col("gender"), col("salary"), ) - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() df2 = df.select( col("name.firstname").alias("fname"), @@ -94,5 +86,5 @@ def test_with_column_renamed(self, spark): col("gender"), col("salary"), ) - assert 'firstname' not in df2.columns - assert 'fname' in df2.columns + assert "firstname" not in df2.columns + assert "fname" in df2.columns diff --git a/tests/fast/spark/test_spark_with_columns.py b/tests/fast/spark/test_spark_with_columns.py index 6e1bedea..244d40a3 100644 --- a/tests/fast/spark/test_spark_with_columns.py +++ b/tests/fast/spark/test_spark_with_columns.py @@ -3,27 +3,27 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.functions import col, lit from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.sql.functions import col, lit class TestWithColumns: def test_with_columns(self, spark): data = [ - ('James', '', 'Smith', '1991-04-01', 'M', 3000), - ('Michael', 'Rose', '', '2000-05-19', 'M', 4000), - ('Robert', '', 'Williams', '1978-09-05', 'M', 4000), - ('Maria', 'Anne', 'Jones', '1967-12-01', 'F', 4000), - ('Jen', 'Mary', 'Brown', '1980-02-17', 'F', -1), + ("James", "", "Smith", "1991-04-01", "M", 3000), + ("Michael", "Rose", "", "2000-05-19", "M", 4000), + ("Robert", "", "Williams", "1978-09-05", "M", 4000), + ("Maria", "Anne", "Jones", "1967-12-01", "F", 4000), + ("Jen", "Mary", "Brown", "1980-02-17", "F", -1), ] columns = ["firstname", "middlename", "lastname", "dob", "gender", "salary"] df = spark.createDataFrame(data=data, schema=columns) - assert df.schema['salary'].dataType.typeName() == ('long' if USE_ACTUAL_SPARK else 'integer') + assert df.schema["salary"].dataType.typeName() == ("long" if USE_ACTUAL_SPARK else "integer") # The type of 'salary' has been cast to Bigint new_df = df.withColumns({"salary": col("salary").cast("BIGINT")}) - assert new_df.schema['salary'].dataType.typeName() == 'long' + assert new_df.schema["salary"].dataType.typeName() == "long" # Replace the 'salary' column with '(salary * 100)' and add a new column # from an existing column @@ -34,12 +34,12 @@ def test_with_columns(self, spark): df2 = df.withColumns({"Country": lit("USA")}) res = df2.collect() - assert res[0].Country == 'USA' + assert res[0].Country == "USA" df2 = df.withColumns({"Country": lit("USA")}).withColumns({"anotherColumn": lit("anotherValue")}) res = df2.collect() - assert res[0].Country == 'USA' - assert res[1].anotherColumn == 'anotherValue' + assert res[0].Country == "USA" + assert res[1].anotherColumn == "anotherValue" df2 = df.drop("salary") - assert 'salary' not in df2.columns + assert "salary" not in df2.columns diff --git a/tests/fast/spark/test_spark_with_columns_renamed.py b/tests/fast/spark/test_spark_with_columns_renamed.py index 99c4ce63..8c24062b 100644 --- a/tests/fast/spark/test_spark_with_columns_renamed.py +++ b/tests/fast/spark/test_spark_with_columns_renamed.py @@ -1,4 +1,5 @@ import re + import pytest _ = 
pytest.importorskip("duckdb.experimental.spark") @@ -6,47 +7,47 @@ from spark_namespace import USE_ACTUAL_SPARK -class TestWithColumnsRenamed(object): +class TestWithColumnsRenamed: def test_with_columns_renamed(self, spark): dataDF = [ - (('James', '', 'Smith'), '1991-04-01', 'M', 3000), - (('Michael', 'Rose', ''), '2000-05-19', 'M', 4000), - (('Robert', '', 'Williams'), '1978-09-05', 'M', 4000), - (('Maria', 'Anne', 'Jones'), '1967-12-01', 'F', 4000), - (('Jen', 'Mary', 'Brown'), '1980-02-17', 'F', -1), + (("James", "", "Smith"), "1991-04-01", "M", 3000), + (("Michael", "Rose", ""), "2000-05-19", "M", 4000), + (("Robert", "", "Williams"), "1978-09-05", "M", 4000), + (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), + (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('dob', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("dob", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) df = spark.createDataFrame(data=dataDF, schema=schema) df2 = df.withColumnsRenamed({"dob": "DateOfBirth", "salary": "salary_amount"}) - assert 'dob' not in df2.columns - assert 'salary' not in df2.columns - assert 'DateOfBirth' in df2.columns - assert 'salary_amount' in df2.columns + assert "dob" not in df2.columns + assert "salary" not in df2.columns + assert "DateOfBirth" in df2.columns + assert "salary_amount" in df2.columns df2 = df.withColumnsRenamed({"name": "full name"}) - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() # PySpark does not raise an error. This is a convenience we provide in DuckDB. 
if not USE_ACTUAL_SPARK: diff --git a/tests/fast/sqlite/test_types.py b/tests/fast/sqlite/test_types.py index d4be447a..b06228fc 100644 --- a/tests/fast/sqlite/test_types.py +++ b/tests/fast/sqlite/test_types.py @@ -27,8 +27,8 @@ import datetime import decimal import unittest + import duckdb -import pytest class DuckDBTypeTests(unittest.TestCase): @@ -42,76 +42,76 @@ def tearDown(self): self.con.close() def test_CheckString(self): - self.cur.execute("insert into test(s) values (?)", (u"Österreich",)) + self.cur.execute("insert into test(s) values (?)", ("Österreich",)) self.cur.execute("select s from test") row = self.cur.fetchone() - self.assertEqual(row[0], u"Österreich") + assert row[0] == "Österreich" def test_CheckSmallInt(self): self.cur.execute("insert into test(i) values (?)", (42,)) self.cur.execute("select i from test") row = self.cur.fetchone() - self.assertEqual(row[0], 42) + assert row[0] == 42 def test_CheckLargeInt(self): num = 2**40 self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("select i from test") row = self.cur.fetchone() - self.assertEqual(row[0], num) + assert row[0] == num def test_CheckFloat(self): val = 3.14 self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() - self.assertEqual(row[0], val) + assert row[0] == val def test_CheckDecimalTooBig(self): val = 17.29 self.cur.execute("insert into test(f) values (?)", (decimal.Decimal(val),)) self.cur.execute("select f from test") row = self.cur.fetchone() - self.assertEqual(row[0], val) + assert row[0] == val def test_CheckDecimal(self): - val = '17.29' + val = "17.29" val = decimal.Decimal(val) self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() - self.assertEqual(row[0], self.cur.execute("select 17.29::DOUBLE").fetchone()[0]) + assert row[0] == self.cur.execute("select 17.29::DOUBLE").fetchone()[0] def test_CheckDecimalWithExponent(self): - val = '1E5' + val = "1E5" val = decimal.Decimal(val) self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() - self.assertEqual(row[0], self.cur.execute("select 1.00000::DOUBLE").fetchone()[0]) + assert row[0] == self.cur.execute("select 1.00000::DOUBLE").fetchone()[0] def test_CheckNaN(self): import math - val = decimal.Decimal('nan') + val = decimal.Decimal("nan") self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() - self.assertEqual(math.isnan(row[0]), True) + assert math.isnan(row[0]) def test_CheckInf(self): - val = decimal.Decimal('inf') + val = decimal.Decimal("inf") self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() - self.assertEqual(row[0], val) + assert row[0] == val def test_CheckBytesBlob(self): val = b"Guglhupf" self.cur.execute("insert into test(b) values (?)", (val,)) self.cur.execute("select b from test") row = self.cur.fetchone() - self.assertEqual(row[0], val) + assert row[0] == val def test_CheckMemoryviewBlob(self): sample = b"Guglhupf" @@ -119,27 +119,27 @@ def test_CheckMemoryviewBlob(self): self.cur.execute("insert into test(b) values (?)", (val,)) self.cur.execute("select b from test") row = self.cur.fetchone() - self.assertEqual(row[0], sample) + assert row[0] == sample def test_CheckMemoryviewFromhexBlob(self): - sample = bytes.fromhex('00FF0F2E3D4C5B6A798800FF00') + sample = 
bytes.fromhex("00FF0F2E3D4C5B6A798800FF00") val = memoryview(sample) self.cur.execute("insert into test(b) values (?)", (val,)) self.cur.execute("select b from test") row = self.cur.fetchone() - self.assertEqual(row[0], sample) + assert row[0] == sample def test_CheckNoneBlob(self): val = None self.cur.execute("insert into test(b) values (?)", (val,)) self.cur.execute("select b from test") row = self.cur.fetchone() - self.assertEqual(row[0], val) + assert row[0] == val def test_CheckUnicodeExecute(self): - self.cur.execute(u"select 'Österreich'") + self.cur.execute("select 'Österreich'") row = self.cur.fetchone() - self.assertEqual(row[0], u"Österreich") + assert row[0] == "Österreich" class CommonTableExpressionTests(unittest.TestCase): @@ -154,24 +154,24 @@ def tearDown(self): def test_CheckCursorDescriptionCTESimple(self): self.cur.execute("with one as (select 1) select * from one") - self.assertIsNotNone(self.cur.description) - self.assertEqual(self.cur.description[0][0], "1") + assert self.cur.description is not None + assert self.cur.description[0][0] == "1" def test_CheckCursorDescriptionCTESMultipleColumns(self): self.cur.execute("insert into test values(1)") self.cur.execute("insert into test values(2)") self.cur.execute("with testCTE as (select * from test) select * from testCTE") - self.assertIsNotNone(self.cur.description) - self.assertEqual(self.cur.description[0][0], "x") + assert self.cur.description is not None + assert self.cur.description[0][0] == "x" def test_CheckCursorDescriptionCTE(self): self.cur.execute("insert into test values (1)") self.cur.execute("with bar as (select * from test) select * from test where x = 1") - self.assertIsNotNone(self.cur.description) - self.assertEqual(self.cur.description[0][0], "x") + assert self.cur.description is not None + assert self.cur.description[0][0] == "x" self.cur.execute("with bar as (select * from test) select * from test where x = 2") - self.assertIsNotNone(self.cur.description) - self.assertEqual(self.cur.description[0][0], "x") + assert self.cur.description is not None + assert self.cur.description[0][0] == "x" class DateTimeTests(unittest.TestCase): @@ -189,51 +189,51 @@ def test_CheckDate(self): self.cur.execute("insert into test(d) values (?)", (d,)) self.cur.execute("select d from test") d2 = self.cur.fetchone()[0] - self.assertEqual(d, d2) + assert d == d2 def test_CheckTime(self): t = datetime.time(7, 15, 0) self.cur.execute("insert into test(t) values (?)", (t,)) self.cur.execute("select t from test") t2 = self.cur.fetchone()[0] - self.assertEqual(t, t2) + assert t == t2 def test_CheckTimestamp(self): ts = datetime.datetime(2004, 2, 14, 7, 15, 0) self.cur.execute("insert into test(ts) values (?)", (ts,)) self.cur.execute("select ts from test") ts2 = self.cur.fetchone()[0] - self.assertEqual(ts, ts2) + assert ts == ts2 def test_CheckSqlTimestamp(self): - now = datetime.datetime.now(datetime.UTC) if hasattr(datetime, 'UTC') else datetime.datetime.utcnow() + now = datetime.datetime.now(datetime.UTC) if hasattr(datetime, "UTC") else datetime.datetime.utcnow() self.cur.execute("insert into test(ts) values (current_timestamp)") self.cur.execute("select ts from test") ts = self.cur.fetchone()[0] - self.assertEqual(type(ts), datetime.datetime) - self.assertEqual(ts.year, now.year) + assert type(ts) is datetime.datetime + assert ts.year == now.year def test_CheckDateTimeSubSeconds(self): ts = datetime.datetime(2004, 2, 14, 7, 15, 0, 500000) self.cur.execute("insert into test(ts) values (?)", (ts,)) self.cur.execute("select 
ts from test") ts2 = self.cur.fetchone()[0] - self.assertEqual(ts, ts2) + assert ts == ts2 def test_CheckTimeSubSeconds(self): t = datetime.time(7, 15, 0, 500000) self.cur.execute("insert into test(t) values (?)", (t,)) self.cur.execute("select t from test") t2 = self.cur.fetchone()[0] - self.assertEqual(t, t2) + assert t == t2 def test_CheckDateTimeSubSecondsFloatingPoint(self): ts = datetime.datetime(2004, 2, 14, 7, 15, 0, 510241) self.cur.execute("insert into test(ts) values (?)", (ts,)) self.cur.execute("select ts from test") ts2 = self.cur.fetchone()[0] - self.assertEqual(ts.year, ts2.year) - self.assertEqual(ts2.microsecond, 510241) + assert ts.year == ts2.year + assert ts2.microsecond == 510241 class ListTests(unittest.TestCase): @@ -249,44 +249,24 @@ def tearDown(self): def test_CheckEmptyList(self): val = [] self.cur.execute("insert into test values (?, ?)", (val, val)) - self.assertEqual( - self.cur.execute("select * from test").fetchall(), - [(val, val)], - ) + assert self.cur.execute("select * from test").fetchall() == [(val, val)] def test_CheckSingleList(self): val = [1, 2, 3] self.cur.execute("insert into test(single) values (?)", (val,)) - self.assertEqual( - self.cur.execute("select * from test").fetchall(), - [(val, None)], - ) + assert self.cur.execute("select * from test").fetchall() == [(val, None)] def test_CheckNestedList(self): val = [[1], [2], [3, 4]] self.cur.execute("insert into test(nested) values (?)", (val,)) - self.assertEqual( - self.cur.execute("select * from test").fetchall(), - [ - ( - None, - val, - ) - ], - ) + assert self.cur.execute("select * from test").fetchall() == [(None, val)] def test_CheckNone(self): val = None self.cur.execute("insert into test values (?, ?)", (val, val)) - self.assertEqual( - self.cur.execute("select * from test").fetchall(), - [(val, val)], - ) + assert self.cur.execute("select * from test").fetchall() == [(val, val)] def test_CheckEmbeddedNone(self): val = [None] self.cur.execute("insert into test values (?, ?)", (val, val)) - self.assertEqual( - self.cur.execute("select * from test").fetchall(), - [(val, val)], - ) + assert self.cur.execute("select * from test").fetchall() == [(val, val)] diff --git a/tests/fast/test_alex_multithread.py b/tests/fast/test_alex_multithread.py index 92768ec0..243779be 100644 --- a/tests/fast/test_alex_multithread.py +++ b/tests/fast/test_alex_multithread.py @@ -1,8 +1,9 @@ import platform -import duckdb from threading import Thread, current_thread + import pytest +import duckdb pytestmark = pytest.mark.xfail( condition=platform.system() == "Emscripten", @@ -30,18 +31,19 @@ def insert_from_same_connection(duckdb_cursor): duckdb_cursor.execute("""INSERT INTO my_inserts VALUES (?)""", (thread_name,)) -class TestPythonMultithreading(object): +class TestPythonMultithreading: def test_multiple_cursors(self, duckdb_cursor): duckdb_con = duckdb.connect() # In Memory DuckDB duckdb_con.execute("""CREATE OR REPLACE TABLE my_inserts (thread_name varchar)""") thread_count = 3 - threads = [] # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results - for i in range(thread_count): - threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) + threads = [ + Thread(target=insert_from_cursor, args=(duckdb_con,), name="my_thread_" + str(i)) + for i in range(thread_count) + ] for thread in threads: thread.start() @@ -50,9 +52,9 @@ def test_multiple_cursors(self, duckdb_cursor): thread.join() assert 
duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] def test_same_connection(self, duckdb_cursor): @@ -67,7 +69,7 @@ def test_same_connection(self, duckdb_cursor): # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): cursors.append(duckdb_con.cursor()) - threads.append(Thread(target=insert_from_same_connection, args=(cursors[i],), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_same_connection, args=(cursors[i],), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -76,9 +78,9 @@ def test_same_connection(self, duckdb_cursor): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] def test_multiple_cursors_persisted(self, tmp_database): @@ -86,12 +88,14 @@ def test_multiple_cursors_persisted(self, tmp_database): duckdb_con.execute("""CREATE OR REPLACE TABLE my_inserts (thread_name varchar)""") thread_count = 3 - threads = [] # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results - for i in range(thread_count): - threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) + threads = [ + Thread(target=insert_from_cursor, args=(duckdb_con,), name="my_thread_" + str(i)) + for i in range(thread_count) + ] + for thread in threads: thread.start() @@ -99,9 +103,9 @@ def test_multiple_cursors_persisted(self, tmp_database): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] duckdb_con.close() @@ -110,12 +114,13 @@ def test_same_connection_persisted(self, tmp_database): duckdb_con.execute("""CREATE OR REPLACE TABLE my_inserts (thread_name varchar)""") thread_count = 3 - threads = [] # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results - for i in range(thread_count): - threads.append(Thread(target=insert_from_same_connection, args=(duckdb_con,), name='my_thread_' + str(i))) + threads = [ + Thread(target=insert_from_same_connection, args=(duckdb_con,), name="my_thread_" + str(i)) + for i in range(thread_count) + ] for thread in threads: thread.start() @@ -123,8 +128,8 @@ def test_same_connection_persisted(self, tmp_database): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] duckdb_con.close() diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index 2128f9f1..77074fdc 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -1,18 +1,20 @@ -import duckdb -import pandas as pd -import numpy as np import datetime import math +import warnings +from contextlib import suppress from decimal import Decimal from uuid import UUID -import pytz + +import numpy as np +import pandas as pd import pytest -import warnings -from contextlib import suppress +import pytz + +import duckdb def 
replace_with_ndarray(obj): - if hasattr(obj, '__getitem__'): + if hasattr(obj, "__getitem__"): if isinstance(obj, dict): for key, value in obj.items(): obj[key] = replace_with_ndarray(value) @@ -25,22 +27,17 @@ def replace_with_ndarray(obj): # we need to write our own equality function that considers nan==nan for testing purposes def recursive_equality(o1, o2): - import math - - if type(o1) != type(o2): + if type(o1) is not type(o2): return False - if type(o1) == float and math.isnan(o1) and math.isnan(o2): + if type(o1) == float and math.isnan(o1) and math.isnan(o2): # noqa: E721 return True if o1 is np.ma.masked and o2 is np.ma.masked: return True try: if len(o1) != len(o2): return False - for i in range(len(o1)): - if not recursive_equality(o1[i], o2[i]): - return False - return True - except: + return all(recursive_equality(o1[i], o2[i]) for i in range(len(o1))) + except Exception: return o1 == o2 @@ -114,70 +111,70 @@ def recursive_equality(o1, o2): ] -class TestAllTypes(object): - @pytest.mark.parametrize('cur_type', all_types) +class TestAllTypes: + @pytest.mark.parametrize("cur_type", all_types) def test_fetchall(self, cur_type): conn = duckdb.connect() conn.execute("SET TimeZone =UTC") # We replace these values since the extreme ranges are not supported in native-python. replacement_values = { - 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", - 'timestamp_s': "'1990-01-01 00:00:00'::TIMESTAMP_S", - 'timestamp_ns': "'1990-01-01 00:00:00'::TIMESTAMP_NS", - 'timestamp_ms': "'1990-01-01 00:00:00'::TIMESTAMP_MS", - 'timestamp_tz': "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", - 'date': "'1990-01-01'::DATE", - 'date_array': "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", - 'timestamp_array': "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", - 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", + "timestamp": "'1990-01-01 00:00:00'::TIMESTAMP", + "timestamp_s": "'1990-01-01 00:00:00'::TIMESTAMP_S", + "timestamp_ns": "'1990-01-01 00:00:00'::TIMESTAMP_NS", + "timestamp_ms": "'1990-01-01 00:00:00'::TIMESTAMP_MS", + "timestamp_tz": "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", + "date": "'1990-01-01'::DATE", + "date_array": "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", + "timestamp_array": "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", # noqa: E501 + "timestamptz_array": "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", # noqa: E501 } adjusted_values = { - 'time': """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, - 'time_tz': """CASE WHEN time_tz = '24:00:00-1559'::TIMETZ THEN '23:59:59.999999-1559'::TIMETZ ELSE time_tz END AS "time_tz" """, + "time": """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, + "time_tz": """CASE WHEN time_tz = '24:00:00-1559'::TIMETZ THEN '23:59:59.999999-1559'::TIMETZ ELSE time_tz END AS "time_tz" """, # noqa: E501 } min_datetime = datetime.datetime.min min_datetime_with_utc = min_datetime.replace(tzinfo=pytz.UTC) max_datetime = datetime.datetime.max max_datetime_with_utc = max_datetime.replace(tzinfo=pytz.UTC) 
correct_answer_map = { - 'bool': [(False,), (True,), (None,)], - 'tinyint': [(-128,), (127,), (None,)], - 'smallint': [(-32768,), (32767,), (None,)], - 'int': [(-2147483648,), (2147483647,), (None,)], - 'bigint': [(-9223372036854775808,), (9223372036854775807,), (None,)], - 'hugeint': [ + "bool": [(False,), (True,), (None,)], + "tinyint": [(-128,), (127,), (None,)], + "smallint": [(-32768,), (32767,), (None,)], + "int": [(-2147483648,), (2147483647,), (None,)], + "bigint": [(-9223372036854775808,), (9223372036854775807,), (None,)], + "hugeint": [ (-170141183460469231731687303715884105728,), (170141183460469231731687303715884105727,), (None,), ], - 'utinyint': [(0,), (255,), (None,)], - 'usmallint': [(0,), (65535,), (None,)], - 'uint': [(0,), (4294967295,), (None,)], - 'ubigint': [(0,), (18446744073709551615,), (None,)], - 'time': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], - 'float': [(-3.4028234663852886e38,), (3.4028234663852886e38,), (None,)], - 'double': [(-1.7976931348623157e308,), (1.7976931348623157e308,), (None,)], - 'dec_4_1': [(Decimal('-999.9'),), (Decimal('999.9'),), (None,)], - 'dec_9_4': [(Decimal('-99999.9999'),), (Decimal('99999.9999'),), (None,)], - 'dec_18_6': [(Decimal('-999999999999.999999'),), (Decimal('999999999999.999999'),), (None,)], - 'dec38_10': [ - (Decimal('-9999999999999999999999999999.9999999999'),), - (Decimal('9999999999999999999999999999.9999999999'),), + "utinyint": [(0,), (255,), (None,)], + "usmallint": [(0,), (65535,), (None,)], + "uint": [(0,), (4294967295,), (None,)], + "ubigint": [(0,), (18446744073709551615,), (None,)], + "time": [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + "float": [(-3.4028234663852886e38,), (3.4028234663852886e38,), (None,)], + "double": [(-1.7976931348623157e308,), (1.7976931348623157e308,), (None,)], + "dec_4_1": [(Decimal("-999.9"),), (Decimal("999.9"),), (None,)], + "dec_9_4": [(Decimal("-99999.9999"),), (Decimal("99999.9999"),), (None,)], + "dec_18_6": [(Decimal("-999999999999.999999"),), (Decimal("999999999999.999999"),), (None,)], + "dec38_10": [ + (Decimal("-9999999999999999999999999999.9999999999"),), + (Decimal("9999999999999999999999999999.9999999999"),), (None,), ], - 'uuid': [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + "uuid": [ + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ], - 'varchar': [('🦆🦆🦆🦆🦆🦆',), ('goo\0se',), (None,)], - 'json': [('🦆🦆🦆🦆🦆🦆',), ('goose',), (None,)], - 'blob': [(b'thisisalongblob\x00withnullbytes',), (b'\x00\x00\x00a',), (None,)], - 'bit': [('0010001001011100010101011010111',), ('10101',), (None,)], - 'small_enum': [('DUCK_DUCK_ENUM',), ('GOOSE',), (None,)], - 'medium_enum': [('enum_0',), ('enum_299',), (None,)], - 'large_enum': [('enum_0',), ('enum_69999',), (None,)], - 'date_array': [ + "varchar": [("🦆🦆🦆🦆🦆🦆",), ("goo\0se",), (None,)], + "json": [("🦆🦆🦆🦆🦆🦆",), ("goose",), (None,)], + "blob": [(b"thisisalongblob\x00withnullbytes",), (b"\x00\x00\x00a",), (None,)], + "bit": [("0010001001011100010101011010111",), ("10101",), (None,)], + "small_enum": [("DUCK_DUCK_ENUM",), ("GOOSE",), (None,)], + "medium_enum": [("enum_0",), ("enum_299",), (None,)], + "large_enum": [("enum_0",), ("enum_69999",), (None,)], + "date_array": [ ( [], [datetime.date(1970, 1, 1), None, datetime.date.min, datetime.date.max], @@ -186,7 +183,7 @@ def test_fetchall(self, cur_type): ], ) ], - 'timestamp_array': [ + "timestamp_array": [ ( [], 
[datetime.datetime(1970, 1, 1), None, datetime.datetime.min, datetime.datetime.max], @@ -195,7 +192,7 @@ def test_fetchall(self, cur_type): ], ), ], - 'timestamptz_array': [ + "timestamptz_array": [ ( [], [ @@ -209,67 +206,67 @@ def test_fetchall(self, cur_type): ], ), ], - 'int_array': [([],), ([42, 999, None, None, -42],), (None,)], - 'varchar_array': [([],), (['🦆🦆🦆🦆🦆🦆', 'goose', None, ''],), (None,)], - 'double_array': [([],), ([42.0, float('nan'), float('inf'), float('-inf'), None, -42.0],), (None,)], - 'nested_int_array': [ + "int_array": [([],), ([42, 999, None, None, -42],), (None,)], + "varchar_array": [([],), (["🦆🦆🦆🦆🦆🦆", "goose", None, ""],), (None,)], + "double_array": [([],), ([42.0, float("nan"), float("inf"), float("-inf"), None, -42.0],), (None,)], + "nested_int_array": [ ([],), ([[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]],), (None,), ], - 'struct': [({'a': None, 'b': None},), ({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'},), (None,)], - 'struct_of_arrays': [ - ({'a': None, 'b': None},), - ({'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']},), + "struct": [({"a": None, "b": None},), ({"a": 42, "b": "🦆🦆🦆🦆🦆🦆"},), (None,)], + "struct_of_arrays": [ + ({"a": None, "b": None},), + ({"a": [42, 999, None, None, -42], "b": ["🦆🦆🦆🦆🦆🦆", "goose", None, ""]},), (None,), ], - 'array_of_structs': [([],), ([{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None],), (None,)], - 'map': [ + "array_of_structs": [([],), ([{"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None],), (None,)], + "map": [ ({},), - ({'key1': '🦆🦆🦆🦆🦆🦆', 'key2': 'goose'},), + ({"key1": "🦆🦆🦆🦆🦆🦆", "key2": "goose"},), (None,), ], - 'time_tz': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], - 'interval': [ + "time_tz": [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + "interval": [ (datetime.timedelta(0),), (datetime.timedelta(days=30969, seconds=999, microseconds=999999),), (None,), ], - 'timestamp': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'date': [(datetime.date(1990, 1, 1),)], - 'timestamp_s': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_ns': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_ms': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_tz': [(datetime.datetime(1990, 1, 1, 0, 0, tzinfo=pytz.UTC),)], - 'union': [('Frank',), (5,), (None,)], - 'fixed_int_array': [((None, 2, 3),), ((4, 5, 6),), (None,)], - 'fixed_varchar_array': [(('a', None, 'c'),), (('d', 'e', 'f'),), (None,)], - 'fixed_nested_int_array': [ + "timestamp": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "date": [(datetime.date(1990, 1, 1),)], + "timestamp_s": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_ns": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_ms": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_tz": [(datetime.datetime(1990, 1, 1, 0, 0, tzinfo=pytz.UTC),)], + "union": [("Frank",), (5,), (None,)], + "fixed_int_array": [((None, 2, 3),), ((4, 5, 6),), (None,)], + "fixed_varchar_array": [(("a", None, "c"),), (("d", "e", "f"),), (None,)], + "fixed_nested_int_array": [ (((None, 2, 3), None, (None, 2, 3)),), (((4, 5, 6), (None, 2, 3), (4, 5, 6)),), (None,), ], - 'fixed_nested_varchar_array': [ - ((('a', None, 'c'), None, ('a', None, 'c')),), - ((('d', 'e', 'f'), ('a', None, 'c'), ('d', 'e', 'f')),), + "fixed_nested_varchar_array": [ + ((("a", None, "c"), None, ("a", None, "c")),), + ((("d", "e", "f"), ("a", None, "c"), ("d", "e", "f")),), (None,), ], - 'fixed_struct_array': [ - (({'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, 
{'a': None, 'b': None}),), - (({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, {'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}),), + "fixed_struct_array": [ + (({"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, {"a": None, "b": None}),), + (({"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, {"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}),), (None,), ], - 'struct_of_fixed_array': [ - ({'a': (None, 2, 3), 'b': ('a', None, 'c')},), - ({'a': (4, 5, 6), 'b': ('d', 'e', 'f')},), + "struct_of_fixed_array": [ + ({"a": (None, 2, 3), "b": ("a", None, "c")},), + ({"a": (4, 5, 6), "b": ("d", "e", "f")},), (None,), ], - 'fixed_array_of_int_list': [ + "fixed_array_of_int_list": [ (([], [42, 999, None, None, -42], []),), (([42, 999, None, None, -42], [], [42, 999, None, None, -42]),), (None,), ], - 'list_of_fixed_int_array': [ + "list_of_fixed_int_array": [ ([(None, 2, 3), (4, 5, 6), (None, 2, 3)],), ([(4, 5, 6), (None, 2, 3), (4, 5, 6)],), (None,), @@ -278,14 +275,14 @@ def test_fetchall(self, cur_type): if cur_type in replacement_values: result = conn.execute("select " + replacement_values[cur_type]).fetchall() elif cur_type in adjusted_values: - result = conn.execute(f'select {adjusted_values[cur_type]} from test_all_types()').fetchall() + result = conn.execute(f"select {adjusted_values[cur_type]} from test_all_types()").fetchall() else: result = conn.execute(f'select "{cur_type}" from test_all_types()').fetchall() correct_result = correct_answer_map[cur_type] assert recursive_equality(result, correct_result) def test_bytearray_with_nulls(self): - con = duckdb.connect(database=':memory:') + con = duckdb.connect(database=":memory:") con.execute("CREATE TABLE test (content BLOB)") want = bytearray([1, 2, 0, 3, 4]) con.execute("INSERT INTO test VALUES (?)", [want]) @@ -295,90 +292,90 @@ def test_bytearray_with_nulls(self): # Don't truncate the array on the nullbyte assert want == bytearray(got) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_fetchnumpy(self, cur_type): conn = duckdb.connect() correct_answer_map = { - 'bool': np.ma.array( + "bool": np.ma.array( [False, True, False], mask=[0, 0, 1], ), - 'tinyint': np.ma.array( + "tinyint": np.ma.array( [-128, 127, -1], mask=[0, 0, 1], dtype=np.int8, ), - 'smallint': np.ma.array( + "smallint": np.ma.array( [-32768, 32767, -1], mask=[0, 0, 1], dtype=np.int16, ), - 'int': np.ma.array( + "int": np.ma.array( [-2147483648, 2147483647, -1], mask=[0, 0, 1], dtype=np.int32, ), - 'bigint': np.ma.array( + "bigint": np.ma.array( [-9223372036854775808, 9223372036854775807, -1], mask=[0, 0, 1], dtype=np.int64, ), - 'utinyint': np.ma.array( + "utinyint": np.ma.array( [0, 255, 42], mask=[0, 0, 1], dtype=np.uint8, ), - 'usmallint': np.ma.array( + "usmallint": np.ma.array( [0, 65535, 42], mask=[0, 0, 1], dtype=np.uint16, ), - 'uint': np.ma.array( + "uint": np.ma.array( [0, 4294967295, 42], mask=[0, 0, 1], dtype=np.uint32, ), - 'ubigint': np.ma.array( + "ubigint": np.ma.array( [0, 18446744073709551615, 42], mask=[0, 0, 1], dtype=np.uint64, ), - 'float': np.ma.array( + "float": np.ma.array( [-3.4028234663852886e38, 3.4028234663852886e38, 42.0], mask=[0, 0, 1], dtype=np.float32, ), - 'double': np.ma.array( + "double": np.ma.array( [-1.7976931348623157e308, 1.7976931348623157e308, 42.0], mask=[0, 0, 1], dtype=np.float64, ), - 'uuid': np.ma.array( + "uuid": np.ma.array( [ - UUID('00000000-0000-0000-0000-000000000000'), - UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), - UUID('00000000-0000-0000-0000-000000000042'), + UUID("00000000-0000-0000-0000-000000000000"), 
+ UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + UUID("00000000-0000-0000-0000-000000000042"), ], mask=[0, 0, 1], dtype=object, ), - 'varchar': np.ma.array( - ['🦆🦆🦆🦆🦆🦆', 'goo\0se', "42"], + "varchar": np.ma.array( + ["🦆🦆🦆🦆🦆🦆", "goo\0se", "42"], mask=[0, 0, 1], dtype=object, ), - 'json': np.ma.array( - ['🦆🦆🦆🦆🦆🦆', 'goose', "42"], + "json": np.ma.array( + ["🦆🦆🦆🦆🦆🦆", "goose", "42"], mask=[0, 0, 1], dtype=object, ), - 'blob': np.ma.array( - [b'thisisalongblob\x00withnullbytes', b'\x00\x00\x00a', b"42"], + "blob": np.ma.array( + [b"thisisalongblob\x00withnullbytes", b"\x00\x00\x00a", b"42"], mask=[0, 0, 1], dtype=object, ), - 'interval': np.ma.array( + "interval": np.ma.array( [ np.timedelta64(0), np.timedelta64(2675722599999999000), @@ -388,7 +385,7 @@ def test_fetchnumpy(self, cur_type): ), # For timestamp_ns, the lowest value is out-of-range for numpy, # such that the conversion yields "Not a Time" - 'timestamp_ns': np.ma.array( + "timestamp_ns": np.ma.array( [ np.datetime64("NaT"), np.datetime64(9223372036854775806, "ns"), @@ -397,21 +394,21 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], ), # Enums don't have a numpy equivalent and yield pandas Categorical. - 'small_enum': pd.Categorical( - ['DUCK_DUCK_ENUM', 'GOOSE', np.nan], + "small_enum": pd.Categorical( + ["DUCK_DUCK_ENUM", "GOOSE", np.nan], ordered=True, ), - 'medium_enum': pd.Categorical( - ['enum_0', 'enum_299', np.nan], + "medium_enum": pd.Categorical( + ["enum_0", "enum_299", np.nan], ordered=True, ), - 'large_enum': pd.Categorical( - ['enum_0', 'enum_69999', np.nan], + "large_enum": pd.Categorical( + ["enum_0", "enum_69999", np.nan], ordered=True, ), # The following types don't have a numpy equivalent and yield # object arrays: - 'int_array': np.ma.array( + "int_array": np.ma.array( [ [], [42, 999, None, None, -42], @@ -420,25 +417,25 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], dtype=object, ), - 'varchar_array': np.ma.array( + "varchar_array": np.ma.array( [ [], - ['🦆🦆🦆🦆🦆🦆', 'goose', None, ''], + ["🦆🦆🦆🦆🦆🦆", "goose", None, ""], None, ], mask=[0, 0, 1], dtype=object, ), - 'double_array': np.ma.array( + "double_array": np.ma.array( [ [], - [42.0, float('nan'), float('inf'), float('-inf'), None, -42.0], + [42.0, float("nan"), float("inf"), float("-inf"), None, -42.0], None, ], mask=[0, 0, 1], dtype=object, ), - 'nested_int_array': np.ma.array( + "nested_int_array": np.ma.array( [ [], [[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]], @@ -447,53 +444,53 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], dtype=object, ), - 'struct': np.ma.array( + "struct": np.ma.array( [ - {'a': None, 'b': None}, - {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, + {"a": None, "b": None}, + {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None, ], mask=[0, 0, 1], dtype=object, ), - 'struct_of_arrays': np.ma.array( + "struct_of_arrays": np.ma.array( [ - {'a': None, 'b': None}, - {'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']}, + {"a": None, "b": None}, + {"a": [42, 999, None, None, -42], "b": ["🦆🦆🦆🦆🦆🦆", "goose", None, ""]}, None, ], mask=[0, 0, 1], dtype=object, ), - 'array_of_structs': np.ma.array( + "array_of_structs": np.ma.array( [ [], - [{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None], + [{"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None], None, ], mask=[0, 0, 1], dtype=object, ), - 'map': np.ma.array( + "map": np.ma.array( [ {}, - {'key1': '🦆🦆🦆🦆🦆🦆', 'key2': 'goose'}, + {"key1": "🦆🦆🦆🦆🦆🦆", "key2": "goose"}, None, ], mask=[0, 0, 1], dtype=object, ), - 'time': np.ma.array( - ['00:00:00', '24:00:00', None], + 
"time": np.ma.array( + ["00:00:00", "24:00:00", None], mask=[0, 0, 1], dtype=object, ), - 'time_tz': np.ma.array( - ['00:00:00', '23:59:59.999999', None], + "time_tz": np.ma.array( + ["00:00:00", "23:59:59.999999", None], mask=[0, 0, 1], dtype=object, ), - 'union': np.ma.array(['Frank', 5, None], mask=[0, 0, 1], dtype=object), + "union": np.ma.array(["Frank", 5, None], mask=[0, 0, 1], dtype=object), } correct_answer_map = replace_with_ndarray(correct_answer_map) @@ -535,19 +532,19 @@ def test_fetchnumpy(self, cur_type): assert np.all(result.mask == correct_answer.mask) np.testing.assert_equal(result, correct_answer) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_arrow(self, cur_type): try: - import pyarrow as pa - except: + pass + except Exception: return # We skip those since the extreme ranges are not supported in arrow. - replacement_values = {'interval': "INTERVAL '2 years'"} + replacement_values = {"interval": "INTERVAL '2 years'"} # We do not round trip enum types - enum_types = {'small_enum', 'medium_enum', 'large_enum', 'double_array'} + enum_types = {"small_enum", "medium_enum", "large_enum", "double_array"} # uhugeint currently not supported by arrow - skip_types = {'uhugeint'} + skip_types = {"uhugeint"} if cur_type in skip_types: return @@ -565,35 +562,35 @@ def test_arrow(self, cur_type): round_trip_arrow_table = conn.execute("select * from arrow_table").fetch_arrow_table() assert arrow_table.equals(round_trip_arrow_table, check_metadata=True) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_pandas(self, cur_type): # We skip those since the extreme ranges are not supported in python. replacement_values = { - 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", - 'timestamp_s': "'1990-01-01 00:00:00'::TIMESTAMP_S", - 'timestamp_ns': "'1990-01-01 00:00:00'::TIMESTAMP_NS", - 'timestamp_ms': "'1990-01-01 00:00:00'::TIMESTAMP_MS", - 'timestamp_tz': "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", - 'date': "'1990-01-01'::DATE", - 'date_array': "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", - 'timestamp_array': "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", - 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", + "timestamp": "'1990-01-01 00:00:00'::TIMESTAMP", + "timestamp_s": "'1990-01-01 00:00:00'::TIMESTAMP_S", + "timestamp_ns": "'1990-01-01 00:00:00'::TIMESTAMP_NS", + "timestamp_ms": "'1990-01-01 00:00:00'::TIMESTAMP_MS", + "timestamp_tz": "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", + "date": "'1990-01-01'::DATE", + "date_array": "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", + "timestamp_array": "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", # noqa: E501 + "timestamptz_array": "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", # noqa: E501 } adjusted_values = { - 'time': """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, + "time": """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, } conn = duckdb.connect() # Pandas <= 
2.2.3 does not convert without throwing a warning conn.execute("SET timezone = UTC") - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) with suppress(TypeError): if cur_type in replacement_values: dataframe = conn.execute("select " + replacement_values[cur_type]).df() elif cur_type in adjusted_values: - dataframe = conn.execute(f'select {adjusted_values[cur_type]} from test_all_types()').df() + dataframe = conn.execute(f"select {adjusted_values[cur_type]} from test_all_types()").df() else: - dataframe = conn.execute(f'select "{cur_type}" from test_all_types()').df() + dataframe = conn.execute(f'select "{cur_type}" from test_all_types()').df() # noqa: F841 print(cur_type) round_trip_dataframe = conn.execute("select * from dataframe").df() result_dataframe = conn.execute("select * from dataframe").fetchall() diff --git a/tests/fast/test_ambiguous_prepare.py b/tests/fast/test_ambiguous_prepare.py index 998367ec..464dd79f 100644 --- a/tests/fast/test_ambiguous_prepare.py +++ b/tests/fast/test_ambiguous_prepare.py @@ -1,12 +1,10 @@ import duckdb -import pandas as pd -import pytest -class TestAmbiguousPrepare(object): +class TestAmbiguousPrepare: def test_bool(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("select ?, ?, ?", (True, 42, [1, 2, 3])).fetchall() - assert res[0][0] == True + assert res[0][0] assert res[0][1] == 42 assert res[0][2] == [1, 2, 3] diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 4fcbd49c..d1afb4d8 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -1,41 +1,35 @@ -import pandas -import numpy as np -import datetime -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestCaseAlias(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestCaseAlias: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): - import numpy as np - import datetime - import duckdb - - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) - r1 = con.from_df(df).query('df', 'select * from df').df() + r1 = con.from_df(df).query("df", "select * from df").df() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - r2 = con.from_df(df).query('df', 'select COL1, COL2 from df').df() + r2 = con.from_df(df).query("df", "select COL1, COL2 from df").df() assert r2["COL1"][0] == "val1" assert r2["COL1"][1] == "val3" assert r2["CoL2"][0] == 1.05 assert r2["CoL2"][1] == 17 - r3 = con.from_df(df).query('df', 'select COL1, COL2 from df ORDER BY COL1').df() + r3 = con.from_df(df).query("df", "select COL1, COL2 from df ORDER BY COL1").df() assert r3["COL1"][0] == "val1" assert r3["COL1"][1] == "val3" assert r3["CoL2"][0] == 1.05 assert r3["CoL2"][1] == 17 - r4 = con.from_df(df).query('df', 'select COL1, COL2 from df GROUP BY COL1, COL2 ORDER BY COL1').df() + r4 = con.from_df(df).query("df", "select COL1, COL2 from df GROUP BY COL1, COL2 ORDER BY COL1").df() assert r4["COL1"][0] == "val1" assert r4["COL1"][1] == "val3" assert r4["CoL2"][0] == 1.05 diff --git a/tests/fast/test_context_manager.py b/tests/fast/test_context_manager.py index 2ac451d1..b6a9ebb2 100644 --- a/tests/fast/test_context_manager.py +++ 
b/tests/fast/test_context_manager.py @@ -1,7 +1,7 @@ import duckdb -class TestContextManager(object): +class TestContextManager: def test_context_manager(self, duckdb_cursor): - with duckdb.connect(database=':memory:', read_only=False) as con: + with duckdb.connect(database=":memory:", read_only=False) as con: assert con.execute("select 1").fetchall() == [(1,)] diff --git a/tests/fast/test_duckdb_api.py b/tests/fast/test_duckdb_api.py index f5dcfb60..d779a368 100644 --- a/tests/fast/test_duckdb_api.py +++ b/tests/fast/test_duckdb_api.py @@ -1,8 +1,9 @@ -import duckdb import sys +import duckdb + def test_duckdb_api(): res = duckdb.execute("SELECT name, value FROM duckdb_settings() WHERE name == 'duckdb_api'") formatted_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - assert res.fetchall() == [('duckdb_api', f'python/{formatted_python_version}')] + assert res.fetchall() == [("duckdb_api", f"python/{formatted_python_version}")] diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index 289d88a9..5e61b455 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -1,19 +1,20 @@ +import datetime import platform -import duckdb + import pytest -from duckdb.typing import INTEGER, VARCHAR, TIMESTAMP + +import duckdb from duckdb import ( - Expression, - ConstantExpression, + CaseExpression, + CoalesceOperator, ColumnExpression, + ConstantExpression, + FunctionExpression, LambdaExpression, - CoalesceOperator, StarExpression, - FunctionExpression, - CaseExpression, ) -from duckdb.value.constant import Value, IntegerValue -import datetime +from duckdb.typing import INTEGER, TIMESTAMP, VARCHAR +from duckdb.value.constant import IntegerValue, Value pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", @@ -21,7 +22,7 @@ ) -@pytest.fixture(scope='function') +@pytest.fixture def filter_rel(): con = duckdb.connect() rel = con.sql( @@ -36,9 +37,10 @@ def filter_rel(): """ ) yield rel + con.close() -class TestExpression(object): +class TestExpression: def test_constant_expression(self): con = duckdb.connect() @@ -59,7 +61,7 @@ def test_constant_expression(self): res = rel.fetchall() assert res == [(5,)] - @pytest.mark.skipif(platform.system() == 'Windows', reason="There is some weird interaction in Windows CI") + @pytest.mark.skipif(platform.system() == "Windows", reason="There is some weird interaction in Windows CI") def test_column_expression(self): con = duckdb.connect() @@ -71,12 +73,12 @@ def test_column_expression(self): 3 as c """ ) - column = ColumnExpression('a') + column = ColumnExpression("a") rel2 = rel.select(column) res = rel2.fetchall() assert res == [(1,)] - column = ColumnExpression('d') + column = ColumnExpression("d") with pytest.raises(duckdb.BinderException, match='Referenced column "d" not found'): rel2 = rel.select(column) @@ -89,9 +91,9 @@ def test_coalesce_operator(self): """ ) - rel2 = rel.select(CoalesceOperator(ConstantExpression(None), ConstantExpression('hello').cast(int))) + rel2 = rel.select(CoalesceOperator(ConstantExpression(None), ConstantExpression("hello").cast(int))) res = rel2.explain() - assert 'COALESCE' in res + assert "COALESCE" in res with pytest.raises(duckdb.ConversionException, match="Could not convert string 'hello' to INT64"): rel2.fetchall() @@ -103,8 +105,8 @@ def test_coalesce_operator(self): """ ) - with pytest.raises(duckdb.InvalidInputException, match='Please provide at least one argument'): - rel3 = rel.select(CoalesceOperator()) + with 
pytest.raises(duckdb.InvalidInputException, match="Please provide at least one argument"): + rel.select(CoalesceOperator()) rel4 = rel.select(CoalesceOperator(ConstantExpression(None))) assert rel4.fetchone() == (None,) @@ -112,7 +114,7 @@ def test_coalesce_operator(self): rel5 = rel.select(CoalesceOperator(ConstantExpression(42))) assert rel5.fetchone() == (42,) - exprtest = con.table('exprtest') + exprtest = con.table("exprtest") rel6 = exprtest.select(CoalesceOperator(ColumnExpression("a"))) res = rel6.fetchall() assert res == [(42,), (43,), (None,), (45,)] @@ -193,17 +195,17 @@ def test_column_expression_explain(self): """ ) rel = rel.select( - ConstantExpression("a").alias('c0'), - ConstantExpression(42).alias('c1'), - ConstantExpression(None).alias('c2'), + ConstantExpression("a").alias("c0"), + ConstantExpression(42).alias("c1"), + ConstantExpression(None).alias("c2"), ) res = rel.explain() - assert 'c0' in res - assert 'c1' in res + assert "c0" in res + assert "c1" in res # 'c2' is not in the explain result because it shows NULL instead - assert 'NULL' in res + assert "NULL" in res res = rel.fetchall() - assert res == [('a', 42, None)] + assert res == [("a", 42, None)] def test_column_expression_table(self): con = duckdb.connect() @@ -219,10 +221,10 @@ def test_column_expression_table(self): """ ) - rel = con.table('tbl') - rel2 = rel.select('c0', 'c1', 'c2') + rel = con.table("tbl") + rel2 = rel.select("c0", "c1", "c2") res = rel2.fetchall() - assert res == [('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'h', 'i')] + assert res == [("a", "b", "c"), ("d", "e", "f"), ("g", "h", "i")] def test_column_expression_view(self): con = duckdb.connect() @@ -241,18 +243,18 @@ def test_column_expression_view(self): CREATE VIEW v1 as select c0 as c3, c2 as c4 from tbl; """ ) - rel = con.view('v1') - rel2 = rel.select('c3', 'c4') + rel = con.view("v1") + rel2 = rel.select("c3", "c4") res = rel2.fetchall() - assert res == [('a', 'c'), ('d', 'f'), ('g', 'i')] + assert res == [("a", "c"), ("d", "f"), ("g", "i")] def test_column_expression_replacement_scan(self): con = duckdb.connect() pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [42, 43, 0], 'b': [True, False, True], 'c': [23.123, 623.213, 0.30234]}) + df = pd.DataFrame({"a": [42, 43, 0], "b": [True, False, True], "c": [23.123, 623.213, 0.30234]}) # noqa: F841 rel = con.sql("select * from df") - rel2 = rel.select('a', 'b') + rel2 = rel.select("a", "b") res = rel2.fetchall() assert res == [(42, True), (43, False), (0, True)] @@ -271,7 +273,7 @@ def test_add_operator(self): ) constant = ConstantExpression(val) - col = ColumnExpression('b') + col = ColumnExpression("b") expr = col + constant rel = rel.select(expr, expr) @@ -288,7 +290,7 @@ def test_binary_function_expression(self): 5 as b """ ) - function = FunctionExpression("-", ColumnExpression('b'), ColumnExpression('a')) + function = FunctionExpression("-", ColumnExpression("b"), ColumnExpression("a")) rel2 = rel.select(function) res = rel2.fetchall() assert res == [(4,)] @@ -301,7 +303,7 @@ def test_negate_expression(self): select 5 as a """ ) - col = ColumnExpression('a') + col = ColumnExpression("a") col = -col rel = rel.select(col) res = rel.fetchall() @@ -317,8 +319,8 @@ def test_subtract_expression(self): 1 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 - col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -337,8 +339,8 @@ def test_multiply_expression(self): 2 as b """ ) - col1 = 
ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 * col2 rel = rel.select(expr) res = rel.fetchall() @@ -354,8 +356,8 @@ def test_division_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 / col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -376,8 +378,8 @@ def test_modulus_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 % col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -393,8 +395,8 @@ def test_power_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1**col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -411,9 +413,9 @@ def test_between_expression(self): 3 as c """ ) - a = ColumnExpression('a') - b = ColumnExpression('b') - c = ColumnExpression('c') + a = ColumnExpression("a") + b = ColumnExpression("b") + c = ColumnExpression("c") # 5 BETWEEN 2 AND 3 -> false assert rel.select(a.between(b, c)).fetchall() == [(False,)] @@ -437,32 +439,32 @@ def test_collate_expression(self): """ ) - col1 = ColumnExpression('c0') - col2 = ColumnExpression('c1') + col1 = ColumnExpression("c0") + col2 = ColumnExpression("c1") - lower_a = ConstantExpression('a') - upper_a = ConstantExpression('A') + lower_a = ConstantExpression("a") + upper_a = ConstantExpression("A") # SELECT c0 LIKE 'a' == True - assert rel.select(FunctionExpression('~~', col1, lower_a)).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col1, lower_a)).fetchall() == [(True,)] # SELECT c0 LIKE 'A' == False - assert rel.select(FunctionExpression('~~', col1, upper_a)).fetchall() == [(False,)] + assert rel.select(FunctionExpression("~~", col1, upper_a)).fetchall() == [(False,)] # SELECT c0 LIKE 'A' COLLATE NOCASE == True - assert rel.select(FunctionExpression('~~', col1, upper_a.collate('NOCASE'))).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col1, upper_a.collate("NOCASE"))).fetchall() == [(True,)] # SELECT c1 LIKE 'a' == False - assert rel.select(FunctionExpression('~~', col2, lower_a)).fetchall() == [(False,)] + assert rel.select(FunctionExpression("~~", col2, lower_a)).fetchall() == [(False,)] # SELECT c1 LIKE 'a' COLLATE NOCASE == True - assert rel.select(FunctionExpression('~~', col2, lower_a.collate('NOCASE'))).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col2, lower_a.collate("NOCASE"))).fetchall() == [(True,)] - with pytest.raises(duckdb.BinderException, match='collations are only supported for type varchar'): - rel.select(FunctionExpression('~~', col2, lower_a).collate('NOCASE')) + with pytest.raises(duckdb.BinderException, match="collations are only supported for type varchar"): + rel.select(FunctionExpression("~~", col2, lower_a).collate("NOCASE")) - with pytest.raises(duckdb.CatalogException, match='Collation with name non-existant does not exist'): - rel.select(FunctionExpression('~~', col2, lower_a.collate('non-existant'))) + with pytest.raises(duckdb.CatalogException, match="Collation with name non-existant does not exist"): + rel.select(FunctionExpression("~~", col2, lower_a.collate("non-existant"))) def test_equality_expression(self): con = duckdb.connect() @@ -475,9 +477,9 @@ def test_equality_expression(self): 5 as c """ ) - 
col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") expr1 = col1 == col2 expr2 = col1 == col3 rel2 = rel.select(expr1, expr2) @@ -497,22 +499,22 @@ def test_lambda_expression(self): # Use a tuple of strings as 'lhs' func = FunctionExpression( "list_reduce", - ColumnExpression('a'), - LambdaExpression(('x', 'y'), ColumnExpression('x') + ColumnExpression('y')), + ColumnExpression("a"), + LambdaExpression(("x", "y"), ColumnExpression("x") + ColumnExpression("y")), ) rel2 = rel.select(func) res = rel2.fetchall() assert res == [(6,)] # Use only a string name as 'lhs' - func = FunctionExpression("list_apply", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('x') + 3)) + func = FunctionExpression("list_apply", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("x") + 3)) rel2 = rel.select(func) res = rel2.fetchall() assert res == [([4, 5, 6],)] # 'row' is not a lambda function, so it doesn't accept a lambda expression - func = FunctionExpression("row", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('x') + 3)) - with pytest.raises(duckdb.BinderException, match='This scalar function does not support lambdas'): + func = FunctionExpression("row", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("x") + 3)) + with pytest.raises(duckdb.BinderException, match="This scalar function does not support lambdas"): rel2 = rel.select(func) # lhs has to be a tuple of strings or a single string @@ -520,11 +522,11 @@ def test_lambda_expression(self): ValueError, match="Please provide 'lhs' as either a tuple containing strings, or a single string" ): func = FunctionExpression( - "list_filter", ColumnExpression('a'), LambdaExpression(42, ColumnExpression('x') + 3) + "list_filter", ColumnExpression("a"), LambdaExpression(42, ColumnExpression("x") + 3) ) func = FunctionExpression( - "list_filter", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('y') != 3) + "list_filter", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("y") != 3) ) with pytest.raises(duckdb.BinderException, match='Referenced column "y" not found in FROM clause'): rel2 = rel.select(func) @@ -540,9 +542,9 @@ def test_inequality_expression(self): 5 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") expr1 = col1 != col2 expr2 = col1 != col3 rel2 = rel.select(expr1, expr2) @@ -561,10 +563,10 @@ def test_comparison_expressions(self): 3 as d """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') - col4 = ColumnExpression('d') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") + col4 = ColumnExpression("d") # Greater than expr1 = col1 > col2 @@ -606,11 +608,11 @@ def test_expression_alias(self): select 1 as a """ ) - col = ColumnExpression('a') - col = col.alias('b') + col = ColumnExpression("a") + col = col.alias("b") rel2 = rel.select(col) - assert rel2.columns == ['b'] + assert rel2.columns == ["b"] def test_star_expression(self): con = duckdb.connect() @@ -628,7 +630,7 @@ def test_star_expression(self): assert res == [(1, 2)] # With exclude list - star = StarExpression(exclude=['a']) + star = StarExpression(exclude=["a"]) rel2 = rel.select(star) res = rel2.fetchall() assert res == [(2,)] @@ -644,13 +646,13 @@ def 
test_struct_expression(self): """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - expr = FunctionExpression('struct_pack', col1, col2).alias('struct') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + expr = FunctionExpression("struct_pack", col1, col2).alias("struct") rel = rel.select(expr) res = rel.fetchall() - assert res == [({'a': 1, 'b': 2},)] + assert res == [({"a": 1, "b": 2},)] def test_function_expression_udf(self): con = duckdb.connect() @@ -658,7 +660,7 @@ def test_function_expression_udf(self): def my_simple_func(a: int, b: int, c: int) -> int: return a + b + c - con.create_function('my_func', my_simple_func) + con.create_function("my_func", my_simple_func) rel = con.sql( """ @@ -668,10 +670,10 @@ def my_simple_func(a: int, b: int, c: int) -> int: 3 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') - expr = FunctionExpression('my_func', col1, col2, col3) + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") + expr = FunctionExpression("my_func", col1, col2, col3) rel2 = rel.select(expr) res = rel2.fetchall() assert res == [(6,)] @@ -688,10 +690,10 @@ def test_function_expression_basic(self): ) tbl(text, start, "end") """ ) - expr = FunctionExpression('array_slice', "start", "text", "end") + expr = FunctionExpression("array_slice", "start", "text", "end") rel2 = rel.select(expr) res = rel2.fetchall() - assert res == [('tes',), ('his is',), ('di',)] + assert res == [("tes",), ("his is",), ("di",)] def test_column_expression_function_coverage(self): con = duckdb.connect() @@ -707,11 +709,11 @@ def test_column_expression_function_coverage(self): """ ) - rel = con.table('tbl') - expr = FunctionExpression('||', FunctionExpression('||', 'c0', 'c1'), 'c2') + rel = con.table("tbl") + expr = FunctionExpression("||", FunctionExpression("||", "c0", "c1"), "c2") rel2 = rel.select(expr) res = rel2.fetchall() - assert res == [('abc',), ('def',), ('ghi',)] + assert res == [("abc",), ("def",), ("ghi",)] def test_function_expression_aggregate(self): con = duckdb.connect() @@ -725,11 +727,11 @@ def test_function_expression_aggregate(self): ) tbl(text) """ ) - expr = FunctionExpression('first', 'text') + expr = FunctionExpression("first", "text") with pytest.raises( - duckdb.BinderException, match='Binder Error: Aggregates cannot be present in a Project relation!' + duckdb.BinderException, match="Binder Error: Aggregates cannot be present in a Project relation!" 
): - rel2 = rel.select(expr) + rel.select(expr) def test_case_expression(self): con = duckdb.connect() @@ -743,9 +745,9 @@ def test_case_expression(self): """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") const1 = ConstantExpression(IntegerValue(1)) # CASE WHEN col1 > 1 THEN 5 ELSE NULL @@ -796,17 +798,17 @@ def test_implicit_constant_conversion(self): def test_numeric_overflow(self): con = duckdb.connect() - rel = con.sql('select 3000::SHORT salary') + rel = con.sql("select 3000::SHORT salary") + expr = ColumnExpression("salary") * 100 + rel2 = rel.select(expr) with pytest.raises(duckdb.OutOfRangeException, match="Overflow in multiplication of INT16"): - expr = ColumnExpression("salary") * 100 - rel2 = rel.select(expr) - res = rel2.fetchall() + rel2.fetchall() + val = duckdb.Value(100, duckdb.typing.TINYINT) + expr2 = ColumnExpression("salary") * val + rel3 = rel.select(expr2) with pytest.raises(duckdb.OutOfRangeException, match="Overflow in multiplication of INT16"): - val = duckdb.Value(100, duckdb.typing.TINYINT) - expr = ColumnExpression("salary") * val - rel2 = rel.select(expr) - res = rel2.fetchall() + rel3.fetchall() def test_struct_column_expression(self): con = duckdb.connect() @@ -823,7 +825,7 @@ def test_filter_equality(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(1, 'a'), (1, 'b')] + assert res == [(1, "a"), (1, "b")] def test_filter_not(self, filter_rel): expr = ColumnExpression("a") == 1 @@ -832,18 +834,18 @@ def test_filter_not(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(2, 'b'), (3, 'c'), (4, 'a')] + assert res == [(2, "b"), (3, "c"), (4, "a")] def test_filter_and(self, filter_rel): expr = ColumnExpression("a") == 1 expr = ~expr # AND operator - expr = expr & ('b' != ConstantExpression('b')) + expr = expr & (ConstantExpression("b") != "b") rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] def test_filter_or(self, filter_rel): # OR operator @@ -851,7 +853,7 @@ def test_filter_or(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(1, 'a'), (1, 'b'), (4, 'a')] + assert res == [(1, "a"), (1, "b"), (4, "a")] def test_filter_mixed(self, filter_rel): # Mixed @@ -861,7 +863,7 @@ def test_filter_mixed(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(1, 'a'), (4, 'a')] + assert res == [(1, "a"), (4, "a")] def test_empty_in(self, filter_rel): expr = ColumnExpression("a") @@ -884,7 +886,7 @@ def test_filter_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(1, 'a'), (2, 'b'), (1, 'b')] + assert res == [(1, "a"), (2, "b"), (1, "b")] def test_filter_not_in(self, filter_rel): expr = ColumnExpression("a") @@ -894,7 +896,7 @@ def test_filter_not_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] # NOT IN expression expr = ColumnExpression("a") @@ -902,7 +904,7 @@ def test_filter_not_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res 
== [(3, "c"), (4, "a")] def test_null(self): con = duckdb.connect() @@ -924,7 +926,7 @@ def test_null(self): assert res == [(False,), (False,), (True,), (False,), (False,)] res2 = rel.filter(b.isnotnull()).fetchall() - assert res2 == [(1, 'a'), (2, 'b'), (4, 'c'), (5, 'a')] + assert res2 == [(1, "a"), (2, "b"), (4, "c"), (5, "a")] def test_sort(self): con = duckdb.connect() @@ -956,12 +958,12 @@ def test_sort(self): # Nulls first rel2 = rel.sort(b.desc().nulls_first()) res = rel2.b.fetchall() - assert res == [(None,), ('c',), ('b',), ('a',), ('a',)] + assert res == [(None,), ("c",), ("b",), ("a",), ("a",)] # Nulls last rel2 = rel.sort(b.desc().nulls_last()) res = rel2.b.fetchall() - assert res == [('c',), ('b',), ('a',), ('a',), (None,)] + assert res == [("c",), ("b",), ("a",), ("a",), (None,)] def test_aggregate(self): con = duckdb.connect() @@ -981,19 +983,18 @@ def test_aggregate_error(self): res = rel.aggregate([5]).execute().fetchone()[0] assert res == 5 + class MyClass: + def __init__(self) -> None: + pass + # Providing something that can not be converted into an expression is an error: with pytest.raises( - duckdb.InvalidInputException, match='Invalid Input Error: Please provide arguments of type Expression!' + duckdb.InvalidInputException, match="Invalid Input Error: Please provide arguments of type Expression!" ): - - class MyClass: - def __init__(self): - pass - - res = rel.aggregate([MyClass()]).fetchone()[0] + rel.aggregate([MyClass()]).fetchone()[0] with pytest.raises( duckdb.InvalidInputException, match="Please provide either a string or list of Expression objects, not ", ): - res = rel.aggregate(5).execute().fetchone() + rel.aggregate(5).execute().fetchone() diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index eaa86398..f9f08266 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -1,26 +1,22 @@ import logging import sys -from pathlib import Path +from pathlib import Path, PurePosixPath from shutil import copyfileobj -from typing import Callable, List -from os.path import exists -from pathlib import PurePosixPath +from typing import Callable + +import pytest import duckdb from duckdb import DuckDBPyConnection, InvalidInputException -from pytest import raises, importorskip, fixture, MonkeyPatch, mark -importorskip('fsspec', '2022.11.0') -from fsspec import filesystem, AbstractFileSystem -from fsspec.implementations.memory import MemoryFileSystem -from fsspec.implementations.local import LocalFileOpener, LocalFileSystem +fsspec = pytest.importorskip("fsspec", "2022.11.0") -FILENAME = 'integers.csv' +FILENAME = "integers.csv" logging.basicConfig(level=logging.DEBUG) -def intercept(monkeypatch: MonkeyPatch, obj: object, name: str) -> List[str]: +def intercept(monkeypatch: pytest.MonkeyPatch, obj: object, name: str) -> list[str]: error_occurred = [] orig = getattr(obj, name) @@ -29,25 +25,25 @@ def ceptor(*args, **kwargs): return orig(*args, **kwargs) except Exception as e: error_occurred.append(e) - raise e + raise monkeypatch.setattr(obj, name, ceptor) return error_occurred -@fixture() +@pytest.fixture def duckdb_cursor(): with duckdb.connect() as conn: yield conn -@fixture() +@pytest.fixture def memory(): - fs = filesystem('memory', skip_instance_cache=True) + fs = fsspec.filesystem("memory", skip_instance_cache=True) # ensure each instance is independent (to work around a weird quirk in fsspec) fs.store = {} - fs.pseudo_dirs = [''] + fs.pseudo_dirs = [""] # copy csv into memory filesystem add_file(fs) @@ -55,41 +51,41 
@@ def memory(): def add_file(fs, filename=FILENAME): - with (Path(__file__).parent / 'data' / filename).open('rb') as source, fs.open(filename, 'wb') as dest: + with (Path(__file__).parent / "data" / filename).open("rb") as source, fs.open(filename, "wb") as dest: copyfileobj(source, dest) class TestPythonFilesystem: def test_unregister_non_existent_filesystem(self, duckdb_cursor: DuckDBPyConnection): - with raises(InvalidInputException): - duckdb_cursor.unregister_filesystem('fake') + with pytest.raises(InvalidInputException): + duckdb_cursor.unregister_filesystem("fake") - def test_memory_filesystem(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): + def test_memory_filesystem(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): duckdb_cursor.register_filesystem(memory) - assert memory.protocol == 'memory' + assert memory.protocol == "memory" duckdb_cursor.execute(f"select * from 'memory://{FILENAME}'") assert duckdb_cursor.fetchall() == [(1, 10, 0), (2, 50, 30)] - duckdb_cursor.unregister_filesystem('memory') + duckdb_cursor.unregister_filesystem("memory") def test_reject_abstract_filesystem(self, duckdb_cursor: DuckDBPyConnection): - with raises(InvalidInputException): - duckdb_cursor.register_filesystem(AbstractFileSystem()) + with pytest.raises(InvalidInputException): + duckdb_cursor.register_filesystem(fsspec.AbstractFileSystem()) def test_unregister_builtin(self, require: Callable[[str], DuckDBPyConnection]): - duckdb_cursor = require('httpfs') - assert duckdb_cursor.filesystem_is_registered('S3FileSystem') == True - duckdb_cursor.unregister_filesystem('S3FileSystem') - assert duckdb_cursor.filesystem_is_registered('S3FileSystem') == False + duckdb_cursor = require("httpfs") + assert duckdb_cursor.filesystem_is_registered("S3FileSystem") + duckdb_cursor.unregister_filesystem("S3FileSystem") + assert not duckdb_cursor.filesystem_is_registered("S3FileSystem") def test_multiple_protocol_filesystems(self, duckdb_cursor: DuckDBPyConnection): - class ExtendedMemoryFileSystem(MemoryFileSystem): - protocol = ('file', 'local') + class ExtendedMemoryFileSystem(fsspec.implementations.memory.MemoryFileSystem): + protocol = ("file", "local") # defer to the original implementation that doesn't hardcode the protocol - _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) + _strip_protocol = classmethod(fsspec.AbstractFileSystem._strip_protocol.__func__) memory = ExtendedMemoryFileSystem(skip_instance_cache=True) add_file(memory) @@ -99,56 +95,56 @@ class ExtendedMemoryFileSystem(MemoryFileSystem): assert duckdb_cursor.fetchall() == [(1, 10, 0), (2, 50, 30)] - def test_write(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): + def test_write(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute("copy (select 1) to 'memory://01.csv' (FORMAT CSV, HEADER 0)") - assert memory.open('01.csv').read() == b'1\n' + assert memory.open("01.csv").read() == b"1\n" - def test_null_bytes(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): - with memory.open('test.csv', 'wb') as fh: - fh.write(b'hello\n\0world\0') + def test_null_bytes(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): + with memory.open("test.csv", "wb") as fh: + fh.write(b"hello\n\0world\0") duckdb_cursor.register_filesystem(memory) - duckdb_cursor.execute('select * from read_csv("memory://test.csv", header = 0, quote = \'"\', escape 
= \'"\')') + duckdb_cursor.execute("select * from read_csv(\"memory://test.csv\", header = 0, quote = '\"', escape = '\"')") - assert duckdb_cursor.fetchall() == [('hello',), ('\0world\0',)] + assert duckdb_cursor.fetchall() == [("hello",), ("\0world\0",)] - def test_read_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): - filename = 'binary_string.parquet' + def test_read_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): + filename = "binary_string.parquet" add_file(memory, filename) duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute(f"select * from read_parquet('memory://{filename}')") - assert duckdb_cursor.fetchall() == [(b'foo',), (b'bar',), (b'baz',)] + assert duckdb_cursor.fetchall() == [(b"foo",), (b"bar",), (b"baz",)] - def test_write_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): + def test_write_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): duckdb_cursor.register_filesystem(memory) - filename = 'output.parquet' + filename = "output.parquet" - duckdb_cursor.execute(f'''COPY (SELECT 1) TO 'memory://{filename}' (FORMAT PARQUET);''') + duckdb_cursor.execute(f"""COPY (SELECT 1) TO 'memory://{filename}' (FORMAT PARQUET);""") - assert memory.open(filename).read().startswith(b'PAR1') + assert memory.open(filename).read().startswith(b"PAR1") - def test_when_fsspec_not_installed(self, duckdb_cursor: DuckDBPyConnection, monkeypatch: MonkeyPatch): - monkeypatch.setitem(sys.modules, 'fsspec', None) + def test_when_fsspec_not_installed(self, duckdb_cursor: DuckDBPyConnection, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setitem(sys.modules, "fsspec", None) - with raises(ModuleNotFoundError): + with pytest.raises(ModuleNotFoundError): duckdb_cursor.register_filesystem(None) - @mark.skipif(sys.version_info < (3, 8), reason="ArrowFSWrapper requires python 3.8 or higher") + @pytest.mark.skipif(sys.version_info < (3, 8), reason="ArrowFSWrapper requires python 3.8 or higher") def test_arrow_fs_wrapper(self, tmp_path: Path, duckdb_cursor: DuckDBPyConnection): - fs = importorskip('pyarrow.fs') + fs = pytest.importorskip("pyarrow.fs") from fsspec.implementations.arrow import ArrowFSWrapper local = fs.LocalFileSystem() local_fsspec = ArrowFSWrapper(local, skip_instance_cache=True) # posix calls here required as ArrowFSWrapper only supports url-like paths (not Windows paths) filename = str(PurePosixPath(tmp_path.as_posix()) / "test.csv") - with local_fsspec.open(filename, mode='w') as f: + with local_fsspec.open(filename, mode="w") as f: f.write("a,b,c\n") f.write("1,2,3\n") f.write("4,5,6\n") @@ -158,114 +154,116 @@ def test_arrow_fs_wrapper(self, tmp_path: Path, duckdb_cursor: DuckDBPyConnectio assert duckdb_cursor.fetchall() == [(1, 2, 3), (4, 5, 6)] - def test_database_attach(self, tmp_path: Path, monkeypatch: MonkeyPatch): - db_path = str(tmp_path / 'hello.db') + def test_database_attach(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + db_path = tmp_path / "hello.db" # setup a database to attach later - with duckdb.connect(db_path) as conn: + with duckdb.connect(str(db_path)) as conn: conn.execute( - ''' + """ CREATE TABLE t (id int); INSERT INTO t VALUES (0) - ''' + """ ) - assert exists(db_path) + assert db_path.exists() with duckdb.connect() as conn: - fs = filesystem('file', skip_instance_cache=True) - write_errors = intercept(monkeypatch, LocalFileOpener, 'write') + fs = fsspec.filesystem("file", skip_instance_cache=True) + write_errors = 
intercept(monkeypatch, fsspec.implementations.local.LocalFileOpener, "write") conn.register_filesystem(fs) db_path_posix = str(PurePosixPath(tmp_path.as_posix()) / "hello.db") conn.execute(f"ATTACH 'file://{db_path_posix}'") - conn.execute('INSERT INTO hello.t VALUES (1)') + conn.execute("INSERT INTO hello.t VALUES (1)") - conn.execute('FROM hello.t') + conn.execute("FROM hello.t") assert conn.fetchall() == [(0,), (1,)] # duckdb sometimes seems to swallow write errors, so we use this to ensure that # isn't happening assert not write_errors - def test_copy_partition(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): + def test_copy_partition(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute("copy (select 1 as a, 2 as b) to 'memory://root' (partition_by (a), HEADER 0)") - assert memory.open('/root/a=1/data_0.csv').read() == b'2\n' + assert memory.open("/root/a=1/data_0.csv").read() == b"2\n" - def test_copy_partition_with_columns_written(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): + def test_copy_partition_with_columns_written( + self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem + ): duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute( "copy (select 1 as a) to 'memory://root' (partition_by (a), HEADER 0, WRITE_PARTITION_COLUMNS)" ) - assert memory.open('/root/a=1/data_0.csv').read() == b'1\n' + assert memory.open("/root/a=1/data_0.csv").read() == b"1\n" - def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): + def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute( "copy (select 2 as a, 3 as b, 4 as c) to 'memory://partition' (partition_by (a), HEADER 0)" ) - path = 'memory:///partition/*/*.csv' + path = "memory:///partition/*/*.csv" query = "SELECT * FROM read_csv_auto('" + path + "'" # hive partitioning - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: auto detection - duckdb_cursor.execute(query + ');') + duckdb_cursor.execute(query + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=1" + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: no cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=0' + ');') - assert duckdb_cursor.fetchall() == [(3, 4, '2')] + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=0" + ");") + assert duckdb_cursor.fetchall() == [(3, 4, "2")] def test_read_hive_partition_with_columns_written( - self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem + self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem ): duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute( "copy (select 2 as a) to 'memory://partition' (partition_by (a), HEADER 0, WRITE_PARTITION_COLUMNS)" ) - path = 'memory:///partition/*/*.csv' + path = "memory:///partition/*/*.csv" query = "SELECT * FROM read_csv_auto('" + path + "'" # hive partitioning - duckdb_cursor.execute(query + ', 
HIVE_PARTITIONING=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: auto detection - duckdb_cursor.execute(query + ');') + duckdb_cursor.execute(query + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=1" + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: no cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=0' + ');') - assert duckdb_cursor.fetchall() == [(2, '2')] + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=0" + ");") + assert duckdb_cursor.fetchall() == [(2, "2")] def test_parallel_union_by_name(self, tmp_path): - pa = importorskip('pyarrow') - pq = importorskip('pyarrow.parquet') - fsspec = importorskip('fsspec') + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pytest.importorskip("fsspec") table1 = pa.Table.from_pylist( [ - {'time': 1719568210134107692, 'col1': 1}, + {"time": 1719568210134107692, "col1": 1}, ] ) table1_path = tmp_path / "table1.parquet" @@ -273,16 +271,16 @@ def test_parallel_union_by_name(self, tmp_path): table2 = pa.Table.from_pylist( [ - {'time': 1719568210134107692, 'col1': 1}, + {"time": 1719568210134107692, "col1": 1}, ] ) table2_path = tmp_path / "table2.parquet" pq.write_table(table2, table2_path) c = duckdb.connect() - c.register_filesystem(LocalFileSystem()) + c.register_filesystem(fsspec.implementations.local.LocalFileSystem()) - q = f"SELECT * FROM read_parquet('file://{tmp_path}/table*.parquet', union_by_name = TRUE) ORDER BY time DESC LIMIT 1" + q = f"SELECT * FROM read_parquet('file://{tmp_path}/table*.parquet', union_by_name = TRUE) ORDER BY time DESC LIMIT 1" # noqa: E501 res = c.sql(q).fetchall() assert res == [(1719568210134107692, 1)] diff --git a/tests/fast/test_get_table_names.py b/tests/fast/test_get_table_names.py index c11b8a65..161abed2 100644 --- a/tests/fast/test_get_table_names.py +++ b/tests/fast/test_get_table_names.py @@ -1,29 +1,30 @@ -import duckdb import pytest +import duckdb + -class TestGetTableNames(object): +class TestGetTableNames: def test_table_success(self, duckdb_cursor): conn = duckdb.connect() table_names = conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") - assert table_names == {'my_table2', 'my_table3', 'my_table1'} + assert table_names == {"my_table2", "my_table3", "my_table1"} def test_table_fail(self, duckdb_cursor): conn = duckdb.connect() conn.close() with pytest.raises(duckdb.ConnectionException, match="Connection already closed"): - table_names = conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") + conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") def test_qualified_parameter_basic(self): conn = duckdb.connect() # Default (qualified=False) table_names = conn.get_table_names("SELECT * FROM test_table") - assert table_names == {'test_table'} + assert table_names == {"test_table"} # Explicit qualified=False table_names = conn.get_table_names("SELECT * FROM test_table", qualified=False) - assert table_names == {'test_table'} + assert table_names == {"test_table"} def test_qualified_parameter_schemas(self): conn = duckdb.connect() @@ -31,11 +32,11 @@ def test_qualified_parameter_schemas(self): # Default (qualified=False) query = 
"SELECT * FROM test_schema.schema_table, main_table" table_names = conn.get_table_names(query) - assert table_names == {'schema_table', 'main_table'} + assert table_names == {"schema_table", "main_table"} # Test with qualified names table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'test_schema.schema_table', 'main_table'} + assert table_names == {"test_schema.schema_table", "main_table"} def test_qualified_parameter_catalogs(self): conn = duckdb.connect() @@ -45,11 +46,11 @@ def test_qualified_parameter_catalogs(self): # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'catalog_table', 'regular_table'} + assert table_names == {"catalog_table", "regular_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'catalog1.test_schema.catalog_table', 'regular_table'} + assert table_names == {"catalog1.test_schema.catalog_table", "regular_table"} def test_qualified_parameter_quoted_identifiers(self): conn = duckdb.connect() @@ -59,7 +60,7 @@ def test_qualified_parameter_quoted_identifiers(self): # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'Table.With.Dots', 'Table With Spaces'} + assert table_names == {"Table.With.Dots", "Table With Spaces"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) @@ -67,45 +68,45 @@ def test_qualified_parameter_quoted_identifiers(self): def test_expanded_views(self): conn = duckdb.connect() - conn.execute('CREATE TABLE my_table(i INT)') - conn.execute('CREATE VIEW v1 AS SELECT * FROM my_table') + conn.execute("CREATE TABLE my_table(i INT)") + conn.execute("CREATE VIEW v1 AS SELECT * FROM my_table") # Test that v1 expands to my_table - query = 'SELECT col_a FROM v1' + query = "SELECT col_a FROM v1" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'my_table'} + assert table_names == {"my_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'my_table'} + assert table_names == {"my_table"} def test_expanded_views_with_schema(self): conn = duckdb.connect() - conn.execute('CREATE SCHEMA my_schema') - conn.execute('CREATE TABLE my_schema.my_table(i INT)') - conn.execute('CREATE VIEW v1 AS SELECT * FROM my_schema.my_table') + conn.execute("CREATE SCHEMA my_schema") + conn.execute("CREATE TABLE my_schema.my_table(i INT)") + conn.execute("CREATE VIEW v1 AS SELECT * FROM my_schema.my_table") # Test that v1 expands to my_table - query = 'SELECT col_a FROM v1' + query = "SELECT col_a FROM v1" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'my_table'} + assert table_names == {"my_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'my_schema.my_table'} + assert table_names == {"my_schema.my_table"} def test_select_function(self): conn = duckdb.connect() - query = 'SELECT EXTRACT(second FROM i) FROM timestamps;' + query = "SELECT EXTRACT(second FROM i) FROM timestamps;" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'timestamps'} + assert table_names == {"timestamps"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'timestamps'} + assert table_names == {"timestamps"} diff --git a/tests/fast/test_import_export.py b/tests/fast/test_import_export.py index 
2fce1636..28748d6a 100644 --- a/tests/fast/test_import_export.py +++ b/tests/fast/test_import_export.py @@ -1,10 +1,10 @@ -import duckdb -import pytest -from os import path import shutil -import os from pathlib import Path +import pytest + +import duckdb + def export_database(export_location): # Create the db @@ -30,11 +30,15 @@ def import_database(import_location): def move_database(export_location, import_location): - assert path.exists(export_location) - assert path.exists(import_location) + export_dir = Path(export_location) + import_dir = Path(import_location) + assert export_dir.exists() + assert export_dir.is_dir() + assert import_dir.exists() + assert import_dir.is_dir() - for file in ['schema.sql', 'load.sql', 'tbl.csv']: - shutil.move(path.join(export_location, file), import_location) + for file in ["schema.sql", "load.sql", "tbl.csv"]: + shutil.move(export_dir / file, import_dir) def export_move_and_import(export_path, import_path): @@ -56,25 +60,24 @@ def export_and_import_empty_db(db_path, _): class TestDuckDBImportExport: - @pytest.mark.parametrize('routine', [export_move_and_import, export_and_import_empty_db]) + @pytest.mark.parametrize("routine", [export_move_and_import, export_and_import_empty_db]) def test_import_and_export(self, routine, tmp_path_factory): export_path = str(tmp_path_factory.mktemp("export_dbs", numbered=True)) import_path = str(tmp_path_factory.mktemp("import_dbs", numbered=True)) routine(export_path, import_path) def test_import_empty_db(self, tmp_path_factory): - import_path = str(tmp_path_factory.mktemp("empty_db", numbered=True)) + import_path = Path(tmp_path_factory.mktemp("empty_db", numbered=True)) # Create an empty db folder structure - Path(Path(import_path) / 'load.sql').touch() - Path(Path(import_path) / 'schema.sql').touch() + (import_path / "load.sql").touch() + (import_path / "schema.sql").touch() con = duckdb.connect() con.execute(f"import database '{import_path}'") # Put a single comment into the 'schema.sql' file - with open(Path(import_path) / 'schema.sql', 'w') as f: - f.write('--\n') + (import_path / "schema.sql").write_text("--\n") con.close() con = duckdb.connect() diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index 1465b68a..a61efd2e 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -1,12 +1,11 @@ -import duckdb -import tempfile -import os import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestInsert(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestInsert: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_insert(self, pandas): test_df = pandas.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) # connect to an in-memory temporary database @@ -15,19 +14,19 @@ def test_insert(self, pandas): cursor = conn.cursor() conn.execute("CREATE TABLE test (i INTEGER, j STRING)") rel = conn.table("test") - rel.insert([1, 'one']) - rel.insert([2, 'two']) - rel.insert([3, 'three']) - rel_a3 = cursor.table('test').project('CAST(i as BIGINT)i, j').to_df() + rel.insert([1, "one"]) + rel.insert([2, "two"]) + rel.insert([3, "three"]) + rel_a3 = cursor.table("test").project("CAST(i as BIGINT)i, j").to_df() pandas.testing.assert_frame_equal(rel_a3, test_df) def test_insert_with_schema(self, duckdb_cursor): duckdb_cursor.sql("create schema not_main") duckdb_cursor.sql("create table not_main.tbl as select * from range(10)") - res = 
duckdb_cursor.table('not_main.tbl').fetchall() + res = duckdb_cursor.table("not_main.tbl").fetchall() assert len(res) == 10 - # FIXME: This is not currently supported - with pytest.raises(duckdb.CatalogException, match='Table with name tbl does not exist'): - duckdb_cursor.table('not_main.tbl').insert([42, 21, 1337]) + # TODO: This is not currently supported # noqa: TD002, TD003 + with pytest.raises(duckdb.CatalogException, match="Table with name tbl does not exist"): + duckdb_cursor.table("not_main.tbl").insert([42, 21, 1337]) diff --git a/tests/fast/test_json_logging.py b/tests/fast/test_json_logging.py index a7f305f3..3e1f184e 100644 --- a/tests/fast/test_json_logging.py +++ b/tests/fast/test_json_logging.py @@ -1,19 +1,20 @@ import json -import duckdb import pytest +import duckdb + def _parse_json_func(error_prefix: str): - """Helper to check that the error message is indeed parsable json""" + """Helper to check that the error message is indeed parsable json.""" - def parse_func(exception): + def parse_func(exception) -> bool: msg = exception.args[0] assert msg.startswith(error_prefix) json_str = msg.split(error_prefix, 1)[1] try: json.loads(json_str) - except: + except Exception: return False return True diff --git a/tests/fast/test_many_con_same_file.py b/tests/fast/test_many_con_same_file.py index 6b7362a6..705fbf9c 100644 --- a/tests/fast/test_many_con_same_file.py +++ b/tests/fast/test_many_con_same_file.py @@ -1,7 +1,10 @@ -import duckdb -import os +import contextlib +from pathlib import Path + import pytest +import duckdb + def get_tables(con): tbls = con.execute("SHOW TABLES").fetchall() @@ -11,10 +14,8 @@ def get_tables(con): def test_multiple_writes(): - try: - os.remove("test.db") - except: - pass + with contextlib.suppress(Exception): + Path("test.db").unlink() con1 = duckdb.connect("test.db") con2 = duckdb.connect("test.db") con1.execute("CREATE TABLE foo1 as SELECT 1 as a, 2 as b") @@ -23,15 +24,13 @@ def test_multiple_writes(): con1.close() con3 = duckdb.connect("test.db") tbls = get_tables(con3) - assert tbls == ['bar1', 'foo1'] + assert tbls == ["bar1", "foo1"] del con1 del con2 del con3 - try: - os.remove("test.db") - except: - pass + with contextlib.suppress(Exception): + Path("test.db").unlink() def test_multiple_writes_memory(): @@ -41,9 +40,9 @@ def test_multiple_writes_memory(): con2.execute("CREATE TABLE bar1 as SELECT 2 as a, 3 as b") con3 = duckdb.connect(":memory:") tbls = get_tables(con1) - assert tbls == ['foo1'] + assert tbls == ["foo1"] tbls = get_tables(con2) - assert tbls == ['bar1'] + assert tbls == ["bar1"] tbls = get_tables(con3) assert tbls == [] del con1 @@ -58,7 +57,7 @@ def test_multiple_writes_named_memory(): con2.execute("CREATE TABLE bar1 as SELECT 2 as a, 3 as b") con3 = duckdb.connect(":memory:1") tbls = get_tables(con3) - assert tbls == ['bar1', 'foo1'] + assert tbls == ["bar1", "foo1"] del con1 del con2 del con3 @@ -70,17 +69,17 @@ def test_diff_config(): duckdb.ConnectionException, match="Can't open a connection to same database file with a different configuration than existing connections", ): - con2 = duckdb.connect("test.db", True) + duckdb.connect("test.db", True) con1.close() del con1 def test_diff_config_extended(): - con1 = duckdb.connect("test.db", config={'null_order': 'NULLS FIRST'}) + con1 = duckdb.connect("test.db", config={"null_order": "NULLS FIRST"}) with pytest.raises( duckdb.ConnectionException, match="Can't open a connection to same database file with a different configuration than existing connections", ): - con2 = 
duckdb.connect("test.db") + duckdb.connect("test.db") con1.close() del con1 diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index 4dbd1a36..336b2775 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -1,44 +1,45 @@ -import duckdb -import numpy -import pytest -from datetime import date, timedelta import re -from conftest import NumpyPandas, ArrowPandas +from datetime import date, timedelta +from typing import NoReturn + +import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb # column count differs from bind def evil1(df): if len(df) == 0: - return df['col0'].to_frame() + return df["col0"].to_frame() else: return df -class TestMap(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) +class TestMap: + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_evil_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) - with pytest.raises(duckdb.InvalidInputException, match='Expected 1 columns from UDF, got 2'): - rel = testrel.map(evil1, schema={'i': str}) - df = rel.df() - print(df) + rel = testrel.map(evil1, schema={"i": str}) + with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): + rel.df() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) conn = duckdb_cursor - conn.execute('CREATE TABLE t (a integer)') - empty_rel = conn.table('t') + conn.execute("CREATE TABLE t (a integer)") + empty_rel = conn.table("t") - newdf1 = testrel.map(lambda df: df['col0'].add(42).to_frame()) - newdf2 = testrel.map(lambda df: df['col0'].astype('string').to_frame()) - newdf3 = testrel.map(lambda df: df) + testrel.map(lambda df: df["col0"].add(42).to_frame()) + testrel.map(lambda df: df["col0"].astype("string").to_frame()) + testrel.map(lambda df: df) # column type differs from bind def evil2(df): result = df.copy(deep=True) if len(result) == 0: - result['col0'] = result['col0'].astype('double') + result["col0"] = result["col0"].astype("double") return result # column name differs from bind @@ -48,36 +49,36 @@ def evil3(df): return df # does not return a df - def evil4(df): + def evil4(df) -> int: return 42 # straight up throws exception - def evil5(df): + def evil5(df) -> NoReturn: raise TypeError def return_dataframe(df): - return pandas.DataFrame({'A': [1]}) + return pandas.DataFrame({"A": [1]}) def return_big_dataframe(df): - return pandas.DataFrame({'A': [1] * 5000}) + return pandas.DataFrame({"A": [1] * 5000}) - def return_none(df): + def return_none(df) -> None: return None def return_empty_df(df): return pandas.DataFrame() - with pytest.raises(duckdb.InvalidInputException, match='Expected 1 columns from UDF, got 2'): + with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): print(testrel.map(evil1).df()) - with pytest.raises(duckdb.InvalidInputException, match='UDF column type mismatch'): + with pytest.raises(duckdb.InvalidInputException, match="UDF column type mismatch"): print(testrel.map(evil2).df()) - with pytest.raises(duckdb.InvalidInputException, match='UDF column name mismatch'): + with pytest.raises(duckdb.InvalidInputException, match="UDF column name mismatch"): print(testrel.map(evil3).df()) with pytest.raises( - duckdb.InvalidInputException, match="Expected the UDF to return an object of type 'pandas.DataFrame'" + duckdb.InvalidInputException, match=r"Expected the UDF to return an object of type 
'pandas\.DataFrame'" ): print(testrel.map(evil4).df()) @@ -92,19 +93,19 @@ def return_empty_df(df): with pytest.raises(TypeError): print(testrel.map().df()) - testrel.map(return_dataframe).df().equals(pandas.DataFrame({'A': [1]})) + testrel.map(return_dataframe).df().equals(pandas.DataFrame({"A": [1]})) with pytest.raises( - duckdb.InvalidInputException, match='UDF returned more than 2048 rows, which is not allowed.' + duckdb.InvalidInputException, match="UDF returned more than 2048 rows, which is not allowed" ): testrel.map(return_big_dataframe).df() - empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({'A': []})) + empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({"A": []})) - with pytest.raises(duckdb.InvalidInputException, match='No return value from Python function'): + with pytest.raises(duckdb.InvalidInputException, match="No return value from Python function"): testrel.map(return_none).df() - with pytest.raises(duckdb.InvalidInputException, match='Need a DataFrame with at least one column'): + with pytest.raises(duckdb.InvalidInputException, match="Need a DataFrame with at least one column"): testrel.map(return_empty_df).df() def test_map_with_object_column(self, duckdb_cursor): @@ -115,21 +116,21 @@ def return_with_no_modification(df): # when a dataframe with 'object' column is returned, we use the content to infer the type # when the dataframe is empty, this results in NULL, which is not desirable # in this case we assume the returned type should be the same as the input type - duckdb_cursor.values([b'1234']).map(return_with_no_modification).fetchall() + duckdb_cursor.values([b"1234"]).map(return_with_no_modification).fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_isse_3237(self, duckdb_cursor, pandas): def process(rel): def mapper(x): - dates = x['date'].to_numpy("datetime64[us]") - days = x['days_to_add'].to_numpy("int") + dates = x["date"].to_numpy("datetime64[us]") + days = x["days_to_add"].to_numpy("int") x["result1"] = pandas.Series( [pandas.to_datetime(y[0]).date() + timedelta(days=y[1].item()) for y in zip(dates, days)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) x["result2"] = pandas.Series( [pandas.to_datetime(y[0]).date() + timedelta(days=-y[1].item()) for y in zip(dates, days)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) return x @@ -140,22 +141,22 @@ def mapper(x): return rel df = pandas.DataFrame( - {'date': pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), 'days_to_add': [1, 2]} + {"date": pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), "days_to_add": [1, 2]} ) rel = duckdb.from_df(df) rel = process(rel) x = rel.fetchdf() - assert x['days_to_add'].to_numpy()[0] == 1 + assert x["days_to_add"].to_numpy()[0] == 1 def test_explicit_schema(self): def cast_to_string(df): - df['i'] = df['i'].astype(str) + df["i"] = df["i"].astype(str) return df con = duckdb.connect() - rel = con.sql('select i from range (10) tbl(i)') + rel = con.sql("select i from range (10) tbl(i)") assert rel.types[0] == duckdb.NUMBER - mapped_rel = rel.map(cast_to_string, schema={'i': str}) + mapped_rel = rel.map(cast_to_string, schema={"i": str}) assert mapped_rel.types[0] == duckdb.STRING def test_explicit_schema_returntype_mismatch(self): @@ -163,76 +164,77 @@ def does_nothing(df): return df con = duckdb.connect() - rel = con.sql('select i from range(10) tbl(i)') + rel = con.sql("select i 
from range(10) tbl(i)") # expects the mapper to return a string column - rel = rel.map(does_nothing, schema={'i': str}) + rel = rel.map(does_nothing, schema={"i": str}) with pytest.raises( duckdb.InvalidInputException, match=re.escape("UDF column type mismatch, expected [VARCHAR], got [BIGINT]") ): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_name_mismatch(self, pandas): def renames_column(df): - return pandas.DataFrame({'a': df['i']}) + return pandas.DataFrame({"a": df["i"]}) con = duckdb.connect() - rel = con.sql('select i from range(10) tbl(i)') - rel = rel.map(renames_column, schema={'i': int}) - with pytest.raises(duckdb.InvalidInputException, match=re.escape('UDF column name mismatch')): + rel = con.sql("select i from range(10) tbl(i)") + rel = rel.map(renames_column, schema={"i": int}) + with pytest.raises(duckdb.InvalidInputException, match=re.escape("UDF column name mismatch")): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_error(self, pandas): def no_op(df): return df con = duckdb.connect() - rel = con.sql('select 42') + rel = con.sql("select 42") with pytest.raises( duckdb.InvalidInputException, match=re.escape("Invalid Input Error: 'schema' should be given as a Dict[str, DuckDBType]"), ): rel.map(no_op, schema=[int]) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_returns_non_dataframe(self, pandas): def returns_series(df): - return df.loc[:, 'i'] + return df.loc[:, "i"] con = duckdb.connect() - rel = con.sql('select i, i as j from range(10) tbl(i)') + rel = con.sql("select i, i as j from range(10) tbl(i)") with pytest.raises( duckdb.InvalidInputException, match=re.escape( - "Expected the UDF to return an object of type 'pandas.DataFrame', found '<class 'pandas.core.series.Series'>' instead" + "Expected the UDF to return an object of type 'pandas.DataFrame', found " + "'<class 'pandas.core.series.Series'>' instead" ), ): rel = rel.map(returns_series) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_columncount_mismatch(self, pandas): def returns_subset(df): - return pandas.DataFrame({'i': df.loc[:, 'i']}) + return pandas.DataFrame({"i": df.loc[:, "i"]}) con = duckdb.connect() - rel = con.sql('select i, i as j from range(10) tbl(i)') - rel = rel.map(returns_subset, schema={'i': int, 'j': int}) + rel = con.sql("select i, i as j from range(10) tbl(i)") + rel = rel.map(returns_subset, schema={"i": int, "j": int}) with pytest.raises( - duckdb.InvalidInputException, match='Invalid Input Error: Expected 2 columns from UDF, got 1' + duckdb.InvalidInputException, match="Invalid Input Error: Expected 2 columns from UDF, got 1" ): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_pyarrow_df(self, pandas): # PyArrow backed dataframes only exist on pandas >= 2.0.0 - _ = pytest.importorskip("pandas", "2.0.0") + pytest.importorskip("pandas", "2.0.0") def basic_function(df): # Create a pyarrow backed dataframe - df = pandas.DataFrame({'a': [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend='pyarrow') + df = pandas.DataFrame({"a": [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend="pyarrow") return df con = duckdb.connect() with pytest.raises(duckdb.InvalidInputException): - rel = con.sql('select 42').map(basic_function) + con.sql("select 
42").map(basic_function) diff --git a/tests/fast/test_metatransaction.py b/tests/fast/test_metatransaction.py index 158bb6a9..35d7c239 100644 --- a/tests/fast/test_metatransaction.py +++ b/tests/fast/test_metatransaction.py @@ -7,10 +7,10 @@ NUMBER_OF_COLUMNS = 1 -class TestMetaTransaction(object): +class TestMetaTransaction: def test_fetchmany(self, duckdb_cursor): duckdb_cursor.execute("CREATE SEQUENCE id_seq") - column_names = ',\n'.join([f'column_{i} FLOAT' for i in range(1, NUMBER_OF_COLUMNS + 1)]) + column_names = ",\n".join([f"column_{i} FLOAT" for i in range(1, NUMBER_OF_COLUMNS + 1)]) create_table_query = f""" CREATE TABLE my_table ( id INTEGER DEFAULT nextval('id_seq'), @@ -23,7 +23,7 @@ def test_fetchmany(self, duckdb_cursor): for i in range(20): # Then insert a large amount of tuples, triggering a parallel execution data = np.random.rand(NUMBER_OF_ROWS, NUMBER_OF_COLUMNS) - columns = [f'Column_{i+1}' for i in range(NUMBER_OF_COLUMNS)] + columns = [f"Column_{i + 1}" for i in range(NUMBER_OF_COLUMNS)] df = pd.DataFrame(data, columns=columns) df_columns = ", ".join(df.columns) # This gets executed in parallel, causing NextValFunction to be called in parallel diff --git a/tests/fast/test_multi_statement.py b/tests/fast/test_multi_statement.py index db82eaf3..76ac0b4b 100644 --- a/tests/fast/test_multi_statement.py +++ b/tests/fast/test_multi_statement.py @@ -1,42 +1,40 @@ -import duckdb -import os +import contextlib import shutil +from pathlib import Path + +import duckdb -class TestMultiStatement(object): +class TestMultiStatement: def test_multi_statement(self, duckdb_cursor): - import duckdb - - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # test empty statement - con.execute('') + con.execute("") # run multiple statements in one call to execute con.execute( - ''' + """ CREATE TABLE integers(i integer); insert into integers select * from range(10); select * from integers; - ''' + """ ) results = [x[0] for x in con.fetchall()] assert results == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # test export/import - export_location = os.path.join(os.getcwd(), 'duckdb_pytest_dir_export') - try: + export_location = Path.cwd() / "duckdb_pytest_dir_export" + with contextlib.suppress(Exception): shutil.rmtree(export_location) - except: - pass - con.execute('CREATE TABLE integers2(i INTEGER)') - con.execute('INSERT INTO integers2 VALUES (1), (5), (7), (1928)') - con.execute("EXPORT DATABASE '%s'" % (export_location,)) + con.execute("CREATE TABLE integers2(i INTEGER)") + con.execute("INSERT INTO integers2 VALUES (1), (5), (7), (1928)") + con.execute(f"EXPORT DATABASE '{export_location}'") # reset connection - con = duckdb.connect(':memory:') - con.execute("IMPORT DATABASE '%s'" % (export_location,)) - integers = [x[0] for x in con.execute('SELECT * FROM integers').fetchall()] - integers2 = [x[0] for x in con.execute('SELECT * FROM integers2').fetchall()] + con = duckdb.connect(":memory:") + con.execute(f"IMPORT DATABASE '{export_location}'") + integers = [x[0] for x in con.execute("SELECT * FROM integers").fetchall()] + integers2 = [x[0] for x in con.execute("SELECT * FROM integers2").fetchall()] assert integers == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] assert integers2 == [1, 5, 7, 1928] shutil.rmtree(export_location) diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index 1ffdfc25..dfefb918 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -1,13 +1,13 @@ import platform -import duckdb -import pytest -import threading import 
queue as Queue +import threading +from pathlib import Path + import numpy as np -from conftest import NumpyPandas, ArrowPandas -import os -from typing import List +import pytest +from conftest import ArrowPandas, NumpyPandas +import duckdb pytestmark = pytest.mark.xfail( condition=platform.system() == "Emscripten", @@ -16,16 +16,16 @@ def connect_duck(duckdb_conn): - out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() + out = duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchall() assert out == [(42,), (84,), (None,), (128,)] -def everything_succeeded(results: List[bool]): - return all([result == True for result in results]) +def everything_succeeded(results: list[bool]): + return all(result for result in results) class DuckDBThreaded: - def __init__(self, duckdb_insert_thread_count, thread_function, pandas): + def __init__(self, duckdb_insert_thread_count, thread_function, pandas) -> None: self.duckdb_insert_thread_count = duckdb_insert_thread_count self.threads = [] self.thread_function = thread_function @@ -36,22 +36,22 @@ def multithread_test(self, result_verification=everything_succeeded): queue = Queue.Queue() # Create all threads - for i in range(0, self.duckdb_insert_thread_count): + for i in range(self.duckdb_insert_thread_count): self.threads.append( threading.Thread( - target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name='duckdb_thread_' + str(i) + target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name="duckdb_thread_" + str(i) ) ) # Record for every thread if they succeeded or not thread_results = [] - for i in range(0, len(self.threads)): + for i in range(len(self.threads)): self.threads[i].start() thread_result: bool = queue.get(timeout=60) thread_results.append(thread_result) # Finish all threads - for i in range(0, len(self.threads)): + for i in range(len(self.threads)): self.threads[i].join() # Assert that the results are what we expected @@ -60,9 +60,9 @@ def multithread_test(self, result_verification=everything_succeeded): def execute_query_same_connection(duckdb_conn, queue, pandas): try: - out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)') + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(False) - except: + except Exception: queue.put(True) @@ -70,9 +70,9 @@ def execute_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)') + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(True) - except: + except Exception: queue.put(False) @@ -80,9 +80,9 @@ def insert_runtime_error(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") queue.put(False) - except: + except Exception: queue.put(True) @@ -104,9 +104,9 @@ def execute_many_query(duckdb_conn, queue, pandas): ) # Larger example that inserts many records at a time purchases = [ - ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ("2006-03-28", "BUY", "IBM", 1000, 45.00), + ("2006-04-05", "BUY", "MSFT", 1000, 72.00), + ("2006-04-06", "SELL", "IBM", 500, 53.00), ] duckdb_conn.executemany( """ @@ -115,7 +115,7 @@ def 
execute_many_query(duckdb_conn, queue, pandas): purchases, ) queue.put(True) - except: + except Exception: queue.put(False) @@ -123,9 +123,9 @@ def fetchone_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchone() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchone() queue.put(True) - except: + except Exception: queue.put(False) @@ -133,9 +133,9 @@ def fetchall_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchall() queue.put(True) - except: + except Exception: queue.put(False) @@ -145,7 +145,7 @@ def conn_close(duckdb_conn, queue, pandas): try: duckdb_conn.close() queue.put(True) - except: + except Exception: queue.put(False) @@ -153,9 +153,9 @@ def fetchnp_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchnumpy() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchnumpy() queue.put(True) - except: + except Exception: queue.put(False) @@ -163,9 +163,9 @@ def fetchdf_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchdf() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchdf() queue.put(True) - except: + except Exception: queue.put(False) @@ -173,9 +173,9 @@ def fetchdf_chunk_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_df_chunk() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_df_chunk() queue.put(True) - except: + except Exception: queue.put(False) @@ -183,9 +183,9 @@ def fetch_arrow_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_arrow_table() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_arrow_table() queue.put(True) - except: + except Exception: queue.put(False) @@ -193,9 +193,9 @@ def fetch_record_batch_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_record_batch() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_record_batch() queue.put(True) - except: + except Exception: queue.put(False) @@ -205,12 +205,12 @@ def transaction_query(duckdb_conn, queue, pandas): duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") try: duckdb_conn.begin() - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") duckdb_conn.rollback() - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") duckdb_conn.commit() queue.put(True) - except: + except Exception: queue.put(False) @@ -218,47 +218,47 @@ def df_append(duckdb_conn, 
queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.append('T', df) + duckdb_conn.append("T", df) queue.put(True) - except: + except Exception: queue.put(False) def df_register(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.register('T', df) + duckdb_conn.register("T", df) queue.put(True) - except: + except Exception: queue.put(False) def df_unregister(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.register('T', df) - duckdb_conn.unregister('T') + duckdb_conn.register("T", df) + duckdb_conn.unregister("T") queue.put(True) - except: + except Exception: queue.put(False) def arrow_register_unregister(duckdb_conn, queue, pandas): # Get a new connection - pa = pytest.importorskip('pyarrow') + pa = pytest.importorskip("pyarrow") duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({"my_column": pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: - duckdb_conn.register('T', arrow_tbl) - duckdb_conn.unregister('T') + duckdb_conn.register("T", arrow_tbl) + duckdb_conn.unregister("T") queue.put(True) - except: + except Exception: queue.put(False) @@ -267,9 +267,9 @@ def table(duckdb_conn, queue, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") try: - out = duckdb_conn.table('T') + duckdb_conn.table("T") queue.put(True) - except: + except Exception: queue.put(False) @@ -279,9 +279,9 @@ def view(duckdb_conn, queue, pandas): duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") duckdb_conn.execute("CREATE VIEW V as (SELECT * FROM T)") try: - out = duckdb_conn.values([5, 'five']) + duckdb_conn.values([5, "five"]) queue.put(True) - except: + except Exception: queue.put(False) @@ -289,9 +289,9 @@ def values(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - out = duckdb_conn.values([5, 'five']) + duckdb_conn.values([5, "five"]) queue.put(True) - except: + except Exception: queue.put(False) @@ -299,68 +299,68 @@ def from_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - out = duckdb_conn.from_query("select i from (values (42), (84), (NULL), (128)) tbl(i)") + duckdb_conn.from_query("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(True) - except: + except Exception: queue.put(False) def from_df(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(['bla', 'blabla'] * 10, columns=['A']) + df = pandas.DataFrame(["bla", "blabla"] * 10, columns=["A"]) # noqa: F841 try: - out = duckdb_conn.execute("select * from df").fetchall() + duckdb_conn.execute("select * from df").fetchall() queue.put(True) - except: + except Exception: queue.put(False) def from_arrow(duckdb_conn, queue, pandas): # Get a new connection - pa = pytest.importorskip('pyarrow') + pa = pytest.importorskip("pyarrow") 
duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({"my_column": pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: - out = duckdb_conn.from_arrow(arrow_tbl) + duckdb_conn.from_arrow(arrow_tbl) queue.put(True) - except: + except Exception: queue.put(False) def from_csv_auto(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') + filename = str(Path(__file__).parent / "data" / "integers.csv") try: - out = duckdb_conn.from_csv_auto(filename) + duckdb_conn.from_csv_auto(filename) queue.put(True) - except: + except Exception: queue.put(False) def from_parquet(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') + filename = str(Path(__file__).parent / "data" / "binary_string.parquet") try: - out = duckdb_conn.from_parquet(filename) + duckdb_conn.from_parquet(filename) queue.put(True) - except: + except Exception: queue.put(False) -def description(duckdb_conn, queue, pandas): +def description(_, queue, __): # Get a new connection duckdb_conn = duckdb.connect() - duckdb_conn.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + duckdb_conn.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") duckdb_conn.execute("INSERT INTO test VALUES (TRUE, '01:01:01', 'bla' )") rel = duckdb_conn.table("test") rel.execute() try: - rel.description + rel.description # noqa: B018 queue.put(True) - except: + except Exception: queue.put(False) @@ -368,145 +368,143 @@ def cursor(duckdb_conn, queue, pandas): # Get a new connection cx = duckdb_conn.cursor() try: - cx.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + cx.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") queue.put(False) - except: + except Exception: queue.put(True) -class TestDuckMultithread(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class TestDuckMultithread: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute_many(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_many_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchone(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchone_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchall(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchall_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_close(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, conn_close, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + 
@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchnp(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchnp_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchdf(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchdf_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchdfchunk(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchdf_chunk_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetcharrow(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, fetch_arrow_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetch_record_batch(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, fetch_record_batch_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_transaction(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, transaction_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_append(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_append, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_register(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_register, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_unregister(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_unregister, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_arrow_register_unregister(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, arrow_register_unregister, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_table(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, table, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_view(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, view, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + 
@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_values(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, values, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_query(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_DF(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_df, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, from_arrow, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_csv_auto(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_csv_auto, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_parquet(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_parquet, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_description(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, description, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_cursor(self, duckdb_cursor, pandas): - def only_some_succeed(results: List[bool]): - if not any([result == True for result in results]): - return False - if all([result == True for result in results]): + def only_some_succeed(results: list[bool]) -> bool: + if not any(result for result in results): return False - return True + return not all(result for result in results) duck_threads = DuckDBThreaded(10, cursor, pandas) duck_threads.multithread_test(only_some_succeed) diff --git a/tests/fast/test_non_default_conn.py b/tests/fast/test_non_default_conn.py index bc9fa5f0..97b67fe8 100644 --- a/tests/fast/test_non_default_conn.py +++ b/tests/fast/test_non_default_conn.py @@ -1,11 +1,13 @@ -import pandas as pd -import numpy as np -import duckdb +import importlib import os import tempfile +import pandas as pd + +import duckdb + -class TestNonDefaultConn(object): +class TestNonDefaultConn: def test_values(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb.values([1], connection=duckdb_cursor).insert_into("t") @@ -18,40 +20,38 @@ def test_query(self, duckdb_cursor): assert duckdb_cursor.from_query("select count(*) from t").execute().fetchall()[0] == (1,) def test_from_csv(self, duckdb_cursor): - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 duckdb_cursor.execute("create table t (a integer)") 
duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_csv(temp_file_name, index=False) rel = duckdb_cursor.from_csv_auto(temp_file_name) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_parquet(self, duckdb_cursor): - try: - import pyarrow as pa - except ImportError: + if not importlib.util.find_spec("pyarrow"): return - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_parquet(temp_file_name, index=False) rel = duckdb_cursor.from_parquet(temp_file_name) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.df(test_df, connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) rel = duckdb_cursor.from_df(test_df) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_arrow(self, duckdb_cursor): try: import pyarrow as pa - except: + except Exception: return duckdb_cursor.execute("create table t (a integer)") @@ -59,55 +59,55 @@ def test_from_arrow(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_arrow = pa.Table.from_pandas(test_df) rel = duckdb_cursor.from_arrow(test_arrow) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) rel = duckdb.arrow(test_arrow, connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_filter_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.filter(test_df, "i < 2", connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_project_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) rel = duckdb.project(test_df, "i", connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert 
rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_agg_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) rel = duckdb.aggregate(test_df, "count(*) as i", connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (4, 4) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (4, 4) def test_distinct_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 1, 2, 3, 4]}) rel = duckdb.distinct(test_df, connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_limit_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.limit(test_df, 1, connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_query_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) - rel = duckdb.query_df(test_df, 't_2', 'select * from t inner join t_2 on (a = i)', connection=duckdb_cursor) + rel = duckdb.query_df(test_df, "t_2", "select * from t inner join t_2 on (a = i)", connection=duckdb_cursor) assert rel.fetchall()[0] == (1, 1) def test_query_order(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) - rel = duckdb.order(test_df, 'i', connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + rel = duckdb.order(test_df, "i", connection=duckdb_cursor) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) diff --git a/tests/fast/test_parameter_list.py b/tests/fast/test_parameter_list.py index 032b1b9c..22413999 100644 --- a/tests/fast/test_parameter_list.py +++ b/tests/fast/test_parameter_list.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestParameterList(object): +class TestParameterList: def test_bool(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table bool_table (a bool)") @@ -11,23 +12,23 @@ def test_bool(self, duckdb_cursor): res = conn.execute("select count(*) from bool_table where a =?", [True]) assert res.fetchone()[0] == 1 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_exception(self, duckdb_cursor, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create table bool_table (a bool)") conn.execute("insert into bool_table values (TRUE)") - with 
pytest.raises(duckdb.NotImplementedException, match='Unable to transform'): - res = conn.execute("select count(*) from bool_table where a =?", [df_in]) + with pytest.raises(duckdb.NotImplementedException, match="Unable to transform"): + conn.execute("select count(*) from bool_table where a =?", [df_in]) def test_explicit_nan_param(self): con = duckdb.default_connection() - res = con.execute('select isnan(cast(? as double))', (float("nan"),)) - assert res.fetchone()[0] == True + res = con.execute("select isnan(cast(? as double))", (float("nan"),)) + assert res.fetchone()[0] def test_string_parameter(self, duckdb_cursor): conn = duckdb.connect() diff --git a/tests/fast/test_parquet.py b/tests/fast/test_parquet.py index 51d8d276..3f6b1889 100644 --- a/tests/fast/test_parquet.py +++ b/tests/fast/test_parquet.py @@ -1,54 +1,54 @@ -import duckdb +from pathlib import Path + import pytest -import os -import tempfile -import pandas as pd + +import duckdb VARCHAR = duckdb.typing.VARCHAR BIGINT = duckdb.typing.BIGINT -filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') +filename = str(Path(__file__).parent / "data" / "binary_string.parquet") @pytest.fixture(scope="session") def tmp_parquets(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp('parquets', numbered=True) - tmp_parquets = [str(tmp_dir / ('tmp' + str(i) + '.parquet')) for i in range(1, 4)] + tmp_dir = tmp_path_factory.mktemp("parquets", numbered=True) + tmp_parquets = [str(tmp_dir / ("tmp" + str(i) + ".parquet")) for i in range(1, 4)] return tmp_parquets -class TestParquet(object): +class TestParquet: def test_scan_binary(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_from_parquet_binary(self, duckdb_cursor): rel = duckdb.from_parquet(filename) - assert rel.types == ['BLOB'] + assert rel.types == ["BLOB"] res = rel.execute().fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_scan_binary_as_string(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute( "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=True) limit 1" ).fetchall() - assert res[0] == ('VARCHAR',) + assert res[0] == ("VARCHAR",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "',binary_as_string=True)").fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_binary_as_string(self, duckdb_cursor): rel = duckdb.from_parquet(filename, True) assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_file_row_number(self, duckdb_cursor): rel = duckdb.from_parquet(filename, binary_as_string=True, file_row_number=True) @@ -56,7 +56,7 @@ def test_from_parquet_file_row_number(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", 0, ) @@ -66,7 +66,7 @@ def test_from_parquet_filename(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", filename, ) @@ -75,7 +75,7 @@ def test_from_parquet_list_binary_as_string(self, duckdb_cursor): assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def 
test_from_parquet_list_file_row_number(self, duckdb_cursor): rel = duckdb.from_parquet([filename], binary_as_string=True, file_row_number=True) @@ -83,7 +83,7 @@ def test_from_parquet_list_file_row_number(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", 0, ) @@ -93,41 +93,41 @@ def test_from_parquet_list_filename(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", filename, ) def test_parquet_binary_as_string_pragma(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) conn.execute("PRAGMA binary_as_string=1") res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('VARCHAR',) + assert res[0] == ("VARCHAR",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) res = conn.execute( "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=False) limit 1" ).fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "',binary_as_string=False)").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) conn.execute("PRAGMA binary_as_string=0") res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): duckdb.execute("PRAGMA binary_as_string=1") @@ -136,7 +136,7 @@ def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_union_by_name(self, tmp_parquets): conn = duckdb.connect() @@ -159,7 +159,7 @@ def test_from_parquet_union_by_name(self, tmp_parquets): + "' (format 'parquet');" ) - rel = duckdb.from_parquet(tmp_parquets, union_by_name=True).order('a') + rel = duckdb.from_parquet(tmp_parquets, union_by_name=True).order("a") assert rel.execute().fetchall() == [ ( 1, diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index 6e1460e2..3b32547e 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 -""" -Unit tests for pypi_cleanup.py +"""Unit tests for pypi_cleanup.py. 
Run with: python -m pytest test_pypi_cleanup.py -v """ +import logging import os from unittest.mock import Mock, patch @@ -14,52 +14,63 @@ duckdb_packaging = pytest.importorskip("duckdb_packaging") -from duckdb_packaging.pypi_cleanup import ( - PyPICleanup, CsrfParser, PyPICleanupError, AuthenticationError, ValidationError, - setup_logging, validate_username, create_argument_parser, session_with_retries, - load_credentials, validate_arguments, main +from duckdb_packaging.pypi_cleanup import ( # noqa: E402 + AuthenticationError, + CleanMode, + CsrfParser, + PyPICleanup, + PyPICleanupError, + ValidationError, + create_argument_parser, + load_credentials, + main, + session_with_retries, + setup_logging, + validate_arguments, + validate_username, ) + class TestValidation: """Test input validation functions.""" - + def test_validate_username_valid(self): """Test valid usernames.""" assert validate_username("user123") == "user123" assert validate_username(" user.name ") == "user.name" assert validate_username("test-user_name") == "test-user_name" assert validate_username("a") == "a" - + def test_validate_username_invalid(self): """Test invalid usernames.""" from argparse import ArgumentTypeError - + with pytest.raises(ArgumentTypeError, match="cannot be empty"): validate_username("") - + with pytest.raises(ArgumentTypeError, match="cannot be empty"): validate_username(" ") - + with pytest.raises(ArgumentTypeError, match="too long"): validate_username("a" * 101) - + with pytest.raises(ArgumentTypeError, match="Invalid username format"): validate_username("-invalid") - + with pytest.raises(ArgumentTypeError, match="Invalid username format"): validate_username("invalid-") - + def test_validate_arguments_dry_run(self): """Test argument validation for dry run mode.""" args = Mock(dry_run=True, username=None, max_nightlies=2) validate_arguments(args) # Should not raise - + def test_validate_arguments_live_mode_no_username(self): """Test argument validation for live mode without username.""" args = Mock(dry_run=False, username=None, max_nightlies=2) with pytest.raises(ValidationError, match="username is required"): validate_arguments(args) - + def test_validate_arguments_negative_nightlies(self): """Test argument validation with negative max nightlies.""" args = Mock(dry_run=True, username="test", max_nightlies=-1) @@ -69,31 +80,25 @@ def test_validate_arguments_negative_nightlies(self): class TestCredentials: """Test credential loading.""" - - def test_load_credentials_dry_run(self): - """Test credential loading in dry run mode.""" - password, otp = load_credentials(dry_run=True) - assert password is None - assert otp is None - - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass', 'PYPI_CLEANUP_OTP': 'test_otp'}) + + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test_pass", "PYPI_CLEANUP_OTP": "test_otp"}) def test_load_credentials_live_mode_success(self): """Test successful credential loading in live mode.""" - password, otp = load_credentials(dry_run=False) - assert password == 'test_pass' - assert otp == 'test_otp' - + password, otp = load_credentials() + assert password == "test_pass" + assert otp == "test_otp" + @patch.dict(os.environ, {}, clear=True) def test_load_credentials_missing_password(self): """Test credential loading with missing password.""" with pytest.raises(ValidationError, match="PYPI_CLEANUP_PASSWORD"): - load_credentials(dry_run=False) - - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass'}) + load_credentials() + + @patch.dict(os.environ, 
{"PYPI_CLEANUP_PASSWORD": "test_pass"}) def test_load_credentials_missing_otp(self): """Test credential loading with missing OTP.""" with pytest.raises(ValidationError, match="PYPI_CLEANUP_OTP"): - load_credentials(dry_run=False) + load_credentials() class TestUtilities: @@ -105,56 +110,56 @@ def test_create_session_with_retries(self): assert isinstance(session, requests.Session) # Verify retry adapter is mounted adapter = session.get_adapter("https://example.com") - assert hasattr(adapter, 'max_retries') - retries = getattr(adapter, 'max_retries') + assert hasattr(adapter, "max_retries") + retries = adapter.max_retries assert isinstance(retries, Retry) - @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") def test_setup_logging_normal(self, mock_basicConfig): """Test logging setup in normal mode.""" - setup_logging(verbose=False) + setup_logging() mock_basicConfig.assert_called_once() call_args = mock_basicConfig.call_args[1] - assert call_args['level'] == 20 # INFO level + assert call_args["level"] == 20 # INFO level - @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") def test_setup_logging_verbose(self, mock_basicConfig): """Test logging setup in verbose mode.""" - setup_logging(verbose=True) + setup_logging(level=logging.DEBUG) mock_basicConfig.assert_called_once() call_args = mock_basicConfig.call_args[1] - assert call_args['level'] == 10 # DEBUG level + assert call_args["level"] == 10 # DEBUG level class TestCsrfParser: """Test CSRF token parser.""" - + def test_csrf_parser_simple_form(self): """Test parsing CSRF token from simple form.""" - html = ''' + html = """
         <form action="/test">
             <input name="csrf_token" value="abc123"/>
         </form>
-        '''
+        """
         parser = CsrfParser("/test")
         parser.feed(html)
         assert parser.csrf == "abc123"
-
     def test_csrf_parser_multiple_forms(self):
         """Test parsing CSRF token when multiple forms exist."""
-        html = '''
+        html = """
         <form action="/other">
             <input name="csrf_token" value="wrong"/>
         </form>
         <form action="/test">
             <input name="csrf_token" value="correct"/>
         </form>
-        '''
+        """
         parser = CsrfParser("/test")
         parser.feed(html)
         assert parser.csrf == "correct"
-
     def test_csrf_parser_no_token(self):
         """Test parser when no CSRF token is found."""
         html = '
        ' @@ -165,36 +170,72 @@ def test_csrf_parser_no_token(self): class TestPyPICleanup: """Test the main PyPICleanup class.""" + @pytest.fixture def cleanup_dryrun_max_2(self) -> PyPICleanup: - return PyPICleanup("https://test.pypi.org/", False, 2) + return PyPICleanup("https://test.pypi.org/", CleanMode.LIST_ONLY, 2) @pytest.fixture def cleanup_dryrun_max_0(self) -> PyPICleanup: - return PyPICleanup("https://test.pypi.org/", False, 0) + return PyPICleanup("https://test.pypi.org/", CleanMode.LIST_ONLY, 0) @pytest.fixture def cleanup_max_2(self) -> PyPICleanup: - return PyPICleanup("https://test.pypi.org/", True, 2, - username="", password="", otp="") + return PyPICleanup( + "https://test.pypi.org/", CleanMode.DELETE, 2, username="", password="", otp="" + ) def test_determine_versions_to_delete_max_2(self, cleanup_dryrun_max_2): start_state = { "0.1.0", - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", - "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.0", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.0.1", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.0", + "1.1.0.post1", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.0", + "2.0.1.dev974", + "2.0.1.rc1", + "2.0.1.rc2", + "2.0.1.rc3", } expected_deletions = { - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", "1.1.1.dev142", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", - "2.0.1.dev974" + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.1.dev974", } versions_to_delete = cleanup_dryrun_max_2._determine_versions_to_delete(start_state) assert versions_to_delete == expected_deletions @@ -202,35 +243,82 @@ def test_determine_versions_to_delete_max_2(self, cleanup_dryrun_max_2): def test_determine_versions_to_delete_max_0(self, cleanup_dryrun_max_0): start_state = { "0.1.0", - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", - "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.0", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.0.1", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.0", + "1.1.0.post1", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.0", + "2.0.1.dev974", + "2.0.1.rc1", + "2.0.1.rc2", + "2.0.1.rc3", } expected_deletions = { - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", 
"1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", - "2.0.1.dev974" + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.1.dev974", } versions_to_delete = cleanup_dryrun_max_0._determine_versions_to_delete(start_state) assert versions_to_delete == expected_deletions def test_determine_versions_to_delete_only_devs_max_2(self, cleanup_dryrun_max_2): start_state = { - "1.0.0.dev1", "1.0.0.dev2", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", "2.0.0.dev602", "2.0.1.dev974", } expected_deletions = { - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", "1.1.0.dev34", "1.1.1.dev142", } @@ -239,19 +327,28 @@ def test_determine_versions_to_delete_only_devs_max_2(self, cleanup_dryrun_max_2 def test_determine_versions_to_delete_only_devs_max_0_fails(self, cleanup_dryrun_max_0): start_state = { - "1.0.0.dev1", "1.0.0.dev2", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", "2.0.0.dev602", "2.0.1.dev974", } with pytest.raises(PyPICleanupError, match="Safety check failed"): cleanup_dryrun_max_0._determine_versions_to_delete(start_state) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_versions') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._determine_versions_to_delete') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_versions") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._determine_versions_to_delete") def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, cleanup_dryrun_max_2): mock_fetch.return_value = {"1.0.0.dev1"} mock_determine.return_value = {"1.0.0.dev1"} @@ -264,14 +361,14 @@ def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, mock_determine.assert_called_once() mock_delete.assert_not_called() - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions") def test_execute_cleanup_no_releases(self, mock_fetch, cleanup_dryrun_max_2): mock_fetch.return_value = {} with session_with_retries() as session: result = cleanup_dryrun_max_2._execute_cleanup(session) assert result == 0 - @patch('requests.Session.get') + @patch("requests.Session.get") def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): """Test successful package release 
fetching.""" mock_response = Mock() @@ -288,19 +385,21 @@ def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): assert releases == {"1.0.0", "1.0.0.dev1"} - @patch('requests.Session.get') + @patch("requests.Session.get") def test_fetch_released_versions_not_found(self, mock_get, cleanup_dryrun_max_2): """Test package release fetching when package not found.""" mock_response = Mock() mock_response.raise_for_status.side_effect = requests.HTTPError("404") mock_get.return_value = mock_response - with pytest.raises(PyPICleanupError, match="Failed to fetch package information"): - with session_with_retries() as session: - cleanup_dryrun_max_2._fetch_released_versions(session) + with ( + pytest.raises(PyPICleanupError, match="Failed to fetch package information"), + session_with_retries() as session, + ): + cleanup_dryrun_max_2._fetch_released_versions(session) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('requests.Session.post') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token") + @patch("requests.Session.post") def test_authenticate_success(self, mock_post, mock_csrf, cleanup_max_2): """Test successful authentication.""" mock_csrf.return_value = "csrf123" @@ -313,11 +412,11 @@ def test_authenticate_success(self, mock_post, mock_csrf, cleanup_max_2): mock_csrf.assert_called_once_with(session, "/account/login/") mock_post.assert_called_once() - assert mock_post.call_args.args[0].endswith('/account/login/') + assert mock_post.call_args.args[0].endswith("/account/login/") - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('requests.Session.post') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token") + @patch("requests.Session.post") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth") def test_authenticate_with_2fa(self, mock_2fa, mock_post, mock_csrf, cleanup_max_2): mock_csrf.return_value = "csrf123" mock_response = Mock() @@ -332,7 +431,7 @@ def test_authenticate_missing_credentials(self, cleanup_dryrun_max_2): with pytest.raises(AuthenticationError, match="Username and password are required"): cleanup_dryrun_max_2._authenticate(None) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version") def test_delete_versions_success(self, mock_delete, cleanup_max_2): """Test successful version deletion.""" versions = {"1.0.0.dev1", "1.0.0.dev2"} @@ -343,7 +442,7 @@ def test_delete_versions_success(self, mock_delete, cleanup_max_2): assert mock_delete.call_count == 2 - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version") def test_delete_versions_partial_failure(self, mock_delete, cleanup_max_2): """Test version deletion with partial failures.""" versions = {"1.0.0.dev1", "1.0.0.dev2"} @@ -360,75 +459,75 @@ def test_delete_single_version_safety_check(self, cleanup_max_2): class TestArgumentParser: """Test command line argument parsing.""" - + def test_argument_parser_creation(self): """Test argument parser creation.""" parser = create_argument_parser() assert parser.prog is not None - + def test_parse_args_prod_dry_run(self): """Test parsing arguments for production dry run.""" parser = create_argument_parser() - args = parser.parse_args(['--prod', '--dry-run']) - + args = 
parser.parse_args(["--prod", "--dry-run"]) + assert args.prod is True assert args.test is False assert args.dry_run is True assert args.max_nightlies == 2 assert args.verbose is False - + def test_parse_args_test_with_username(self): """Test parsing arguments for test with username.""" parser = create_argument_parser() - args = parser.parse_args(['--test', '-u', 'testuser', '--verbose']) - + args = parser.parse_args(["--test", "-u", "testuser", "--verbose"]) + assert args.test is True assert args.prod is False - assert args.username == 'testuser' + assert args.username == "testuser" assert args.verbose is True - + def test_parse_args_missing_host(self): """Test parsing arguments with missing host selection.""" parser = create_argument_parser() - + with pytest.raises(SystemExit): - parser.parse_args(['--dry-run']) # Missing --prod or --test + parser.parse_args(["--dry-run"]) # Missing --prod or --test class TestMainFunction: """Test the main function.""" - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup') - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test', 'PYPI_CLEANUP_OTP': 'test'}) + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup") + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test", "PYPI_CLEANUP_OTP": "test"}) def test_main_success(self, mock_cleanup_class, mock_setup_logging): """Test successful main function execution.""" mock_cleanup = Mock() mock_cleanup.run.return_value = 0 mock_cleanup_class.return_value = mock_cleanup - - with patch('sys.argv', ['pypi_cleanup.py', '--test', '-u', 'testuser']): + + with patch("sys.argv", ["pypi_cleanup.py", "--test", "-u", "testuser"]): result = main() - + assert result == 0 mock_setup_logging.assert_called_once() mock_cleanup.run.assert_called_once() - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") def test_main_validation_error(self, mock_setup_logging): """Test main function with validation error.""" - with patch('sys.argv', ['pypi_cleanup.py', '--test']): # Missing username for live mode + with patch("sys.argv", ["pypi_cleanup.py", "--test"]): # Missing username for live mode result = main() - + assert result == 2 # Validation error exit code - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') - @patch('duckdb_packaging.pypi_cleanup.validate_arguments') + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") + @patch("duckdb_packaging.pypi_cleanup.validate_arguments") def test_main_keyboard_interrupt(self, mock_validate, mock_setup_logging): """Test main function with keyboard interrupt.""" mock_validate.side_effect = KeyboardInterrupt() - - with patch('sys.argv', ['pypi_cleanup.py', '--test', '--dry-run']): + + with patch("sys.argv", ["pypi_cleanup.py", "--test", "--dry-run"]): result = main() - + assert result == 130 # Keyboard interrupt exit code diff --git a/tests/fast/test_pytorch.py b/tests/fast/test_pytorch.py index 365585cc..c0b9392d 100644 --- a/tests/fast/test_pytorch.py +++ b/tests/fast/test_pytorch.py @@ -1,8 +1,8 @@ -import duckdb import pytest +import duckdb -torch = pytest.importorskip('torch') +torch = pytest.importorskip("torch") @pytest.mark.skip(reason="some issues with Numpy, to be reverted") @@ -15,16 +15,16 @@ def test_pytorch(): # Test from connection duck_torch = con.execute("select * from t").torch() duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], 
torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Test from relation duck_torch = con.sql("select * from t").torch() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Test all Numeric Types - numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + numeric_types = ["TINYINT", "SMALLINT", "BIGINT", "HUGEINT", "FLOAT", "DOUBLE", "DECIMAL(4,1)", "UTINYINT"] for supported_type in numeric_types: con = duckdb.connect() @@ -32,8 +32,8 @@ def test_pytorch(): con.execute("insert into t values (1,2), (3,4)") duck_torch = con.sql("select * from t").torch() duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Comment out test that might fail or not depending on pytorch versions # with pytest.raises(TypeError, match="can't convert"): diff --git a/tests/fast/test_relation.py b/tests/fast/test_relation.py index 8e68c149..f1b4b1fd 100644 --- a/tests/fast/test_relation.py +++ b/tests/fast/test_relation.py @@ -1,16 +1,18 @@ -import duckdb -import numpy as np +# ruff: noqa: F841 +import datetime +import gc +import os import platform import tempfile -import os + +import numpy as np import pandas as pd import pytest from conftest import ArrowPandas, NumpyPandas -import datetime -import gc -from duckdb import ColumnExpression -from duckdb.typing import BIGINT, VARCHAR, TINYINT, BOOLEAN +import duckdb +from duckdb import ColumnExpression +from duckdb.typing import BIGINT, BOOLEAN, TINYINT, VARCHAR @pytest.fixture(scope="session") @@ -25,11 +27,11 @@ def get_relation(conn): return conn.from_df(test_df) -class TestRelation(object): +class TestRelation: def test_csv_auto(self): conn = duckdb.connect() df_rel = get_relation(conn) - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) test_df.to_csv(temp_file_name, index=False) @@ -37,10 +39,10 @@ def test_csv_auto(self): csv_rel = duckdb.from_csv_auto(temp_file_name) assert df_rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_view(self, duckdb_cursor, pandas): - def create_view(duckdb_cursor): - df_in = pandas.DataFrame({'numbers': [1, 2, 3, 4, 5]}) + def create_view(duckdb_cursor) -> None: + df_in = pandas.DataFrame({"numbers": [1, 2, 3, 4, 5]}) rel = duckdb_cursor.query("select * from df_in") rel.to_view("my_view") @@ -49,9 +51,10 @@ def create_view(duckdb_cursor): # The df_in object is no longer reachable rel1 = duckdb_cursor.query("select * from df_in") # But it **is** reachable through our 'my_view' VIEW - # Because a Relation was created that references the df_in, the 'df_in' TableRef was injected with an 
ExternalDependency on the dataframe object - # We then created a VIEW from that Relation, which in turn copied this 'df_in' TableRef into the ViewCatalogEntry - # Because of this, the df_in object will stay alive for as long as our 'my_view' entry exists. + # Because a Relation was created that references the df_in, the 'df_in' TableRef was injected with an + # ExternalDependency on the dataframe object. We then created a VIEW from that Relation, which in turn copied + # this 'df_in' TableRef into the ViewCatalogEntry. Because of this, the df_in object will stay alive for as + # long as our 'my_view' entry exists. rel2 = duckdb_cursor.query("select * from my_view") res = rel2.fetchall() assert res == [(1,), (2,), (3,), (4,), (5,)] @@ -59,23 +62,23 @@ def create_view(duckdb_cursor): def test_filter_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.filter('i > 1').execute().fetchall() == [(2, 'two'), (3, 'three'), (4, 'four')] + assert rel.filter("i > 1").execute().fetchall() == [(2, "two"), (3, "three"), (4, "four")] def test_projection_operator_single(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.project('i').execute().fetchall() == [(1,), (2,), (3,), (4,)] + assert rel.project("i").execute().fetchall() == [(1,), (2,), (3,), (4,)] def test_projection_operator_double(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.order('j').execute().fetchall() == [(4, 'four'), (1, 'one'), (3, 'three'), (2, 'two')] + assert rel.order("j").execute().fetchall() == [(4, "four"), (1, "one"), (3, "three"), (2, "two")] def test_limit_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.limit(2).execute().fetchall() == [(1, 'one'), (2, 'two')] - assert rel.limit(2, offset=1).execute().fetchall() == [(2, 'two'), (3, 'three')] + assert rel.limit(2).execute().fetchall() == [(1, "one"), (2, "two")] + assert rel.limit(2, offset=1).execute().fetchall() == [(2, "two"), (3, "three")] def test_intersect_operator(self): conn = duckdb.connect() @@ -86,23 +89,23 @@ def test_intersect_operator(self): rel = conn.from_df(test_df) rel_2 = conn.from_df(test_df_2) - assert rel.intersect(rel_2).order('i').execute().fetchall() == [(3,), (4,)] + assert rel.intersect(rel_2).order("i").execute().fetchall() == [(3,), (4,)] def test_aggregate_operator(self): conn = duckdb.connect() rel = get_relation(conn) assert rel.aggregate("sum(i)").execute().fetchall() == [(10,)] - assert rel.aggregate("j, sum(i)").order('#2').execute().fetchall() == [ - ('one', 1), - ('two', 2), - ('three', 3), - ('four', 4), + assert rel.aggregate("j, sum(i)").order("#2").execute().fetchall() == [ + ("one", 1), + ("two", 2), + ("three", 3), + ("four", 4), ] def test_relation_fetch_df_chunk(self, duckdb_cursor): duckdb_cursor.execute(f"create table tbl as select * from range({duckdb.__standard_vector_size__ * 3})") - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") # default arguments df1 = rel.fetch_df_chunk() assert len(df1) == duckdb.__standard_vector_size__ @@ -111,43 +114,43 @@ def test_relation_fetch_df_chunk(self, duckdb_cursor): assert len(df2) == duckdb.__standard_vector_size__ * 2 duckdb_cursor.execute( - f"create table dates as select (DATE '2021/02/21' + INTERVAL (i) DAYS)::DATE a from range({duckdb.__standard_vector_size__ * 4}) t(i)" + f"create table dates as select (DATE '2021/02/21' + INTERVAL (i) DAYS)::DATE a from range({duckdb.__standard_vector_size__ * 4}) t(i)" # noqa: E501 ) - rel = duckdb_cursor.table('dates') + rel = 
duckdb_cursor.table("dates") # default arguments df1 = rel.fetch_df_chunk() assert len(df1) == duckdb.__standard_vector_size__ - assert df1['a'][0].__class__ == pd.Timestamp + assert df1["a"][0].__class__ == pd.Timestamp # date as object df1 = rel.fetch_df_chunk(date_as_object=True) assert len(df1) == duckdb.__standard_vector_size__ - assert df1['a'][0].__class__ == datetime.date + assert df1["a"][0].__class__ == datetime.date # vectors and date as object df1 = rel.fetch_df_chunk(2, date_as_object=True) assert len(df1) == duckdb.__standard_vector_size__ * 2 - assert df1['a'][0].__class__ == datetime.date + assert df1["a"][0].__class__ == datetime.date def test_distinct_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.distinct().order('all').execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + assert rel.distinct().order("all").execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_union_operator(self): conn = duckdb.connect() rel = get_relation(conn) print(rel.union(rel).execute().fetchall()) assert rel.union(rel).execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_join_operator(self): @@ -156,11 +159,11 @@ def test_join_operator(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel2 = conn.from_df(test_df) - assert rel.join(rel2, 'i').execute().fetchall() == [ - (1, 'one', 'one'), - (2, 'two', 'two'), - (3, 'three', 'three'), - (4, 'four', 'four'), + assert rel.join(rel2, "i").execute().fetchall() == [ + (1, "one", "one"), + (2, "two", "two"), + (3, "three", "three"), + (4, "four", "four"), ] def test_except_operator(self): @@ -176,10 +179,10 @@ def test_create_operator(self): rel = conn.from_df(test_df) rel.create("test_df") assert conn.query("select * from test_df").execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_create_view_operator(self): @@ -188,31 +191,31 @@ def test_create_view_operator(self): rel = conn.from_df(test_df) rel.create_view("test_df") assert conn.query("select * from test_df").execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_update_relation(self, duckdb_cursor): duckdb_cursor.sql("create table tbl (a varchar default 'test', b int)") - duckdb_cursor.table('tbl').insert(['hello', 21]) - duckdb_cursor.table('tbl').insert(['hello', 42]) + duckdb_cursor.table("tbl").insert(["hello", 21]) + duckdb_cursor.table("tbl").insert(["hello", 42]) # UPDATE tbl SET a = DEFAULT where b = 42 - duckdb_cursor.table('tbl').update( - {'a': duckdb.DefaultExpression()}, condition=duckdb.ColumnExpression('b') == 42 + duckdb_cursor.table("tbl").update( + {"a": duckdb.DefaultExpression()}, condition=duckdb.ColumnExpression("b") == 42 ) - assert duckdb_cursor.table('tbl').fetchall() == [('hello', 21), ('test', 42)] + assert duckdb_cursor.table("tbl").fetchall() == [("hello", 21), ("test", 42)] - rel = duckdb_cursor.table('tbl') - with pytest.raises(duckdb.InvalidInputException, match='Please provide at least one set expression'): + rel = duckdb_cursor.table("tbl") + with 
pytest.raises(duckdb.InvalidInputException, match="Please provide at least one set expression"): rel.update({}) with pytest.raises( - duckdb.InvalidInputException, match='Please provide the column name as the key of the dictionary' + duckdb.InvalidInputException, match="Please provide the column name as the key of the dictionary" ): rel.update({1: 21}) - with pytest.raises(duckdb.BinderException, match='Referenced update column c not found in table!'): - rel.update({'c': 21}) + with pytest.raises(duckdb.BinderException, match="Referenced update column c not found in table!"): + rel.update({"c": 21}) with pytest.raises( duckdb.InvalidInputException, match="Please provide 'set' as a dictionary of column name to Expression" ): @@ -221,11 +224,11 @@ def test_update_relation(self, duckdb_cursor): duckdb.InvalidInputException, match="Please provide an object of type Expression as the value, not ", ): - rel.update({'a': {21}}) + rel.update({"a": {21}}) def test_value_relation(self, duckdb_cursor): # Needs at least one input - with pytest.raises(duckdb.InvalidInputException, match='Could not create a ValueRelation without any inputs'): + with pytest.raises(duckdb.InvalidInputException, match="Could not create a ValueRelation without any inputs"): duckdb_cursor.values() # From a list of (python) values @@ -233,28 +236,28 @@ def test_value_relation(self, duckdb_cursor): assert rel.fetchall() == [(1, 2, 3)] # From an Expression - rel = duckdb_cursor.values(duckdb.ConstantExpression('test')) - assert rel.fetchall() == [('test',)] + rel = duckdb_cursor.values(duckdb.ConstantExpression("test")) + assert rel.fetchall() == [("test",)] # From multiple Expressions rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), duckdb.ConstantExpression('2'), duckdb.ConstantExpression('3') + duckdb.ConstantExpression("1"), duckdb.ConstantExpression("2"), duckdb.ConstantExpression("3") ) - assert rel.fetchall() == [('1', '2', '3')] + assert rel.fetchall() == [("1", "2", "3")] # From Expressions mixed with random values - with pytest.raises(duckdb.InvalidInputException, match='Please provide arguments of type Expression!'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide arguments of type Expression!"): rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), - {'test'}, - duckdb.ConstantExpression('3'), + duckdb.ConstantExpression("1"), + {"test"}, + duckdb.ConstantExpression("3"), ) # From Expressions mixed with values that *can* be autocast to Expression rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), + duckdb.ConstantExpression("1"), 2, - duckdb.ConstantExpression('3'), + duckdb.ConstantExpression("3"), ) const = duckdb.ConstantExpression @@ -264,21 +267,21 @@ def test_value_relation(self, duckdb_cursor): # From mismatching tuples of Expressions with pytest.raises( - duckdb.InvalidInputException, match='Mismatch between length of tuples in input, expected 3 but found 2' + duckdb.InvalidInputException, match="Mismatch between length of tuples in input, expected 3 but found 2" ): rel = duckdb_cursor.values((const(1), const(2), const(3)), (const(5), const(4))) # From an empty tuple - with pytest.raises(duckdb.InvalidInputException, match='Please provide a non-empty tuple'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide a non-empty tuple"): rel = duckdb_cursor.values(()) # Mixing tuples with Expressions - with pytest.raises(duckdb.InvalidInputException, match='Expected objects of type tuple'): + with 
pytest.raises(duckdb.InvalidInputException, match="Expected objects of type tuple"): rel = duckdb_cursor.values((const(1), const(2), const(3)), const(4)) # Using Expressions that can't be resolved: with pytest.raises(duckdb.BinderException, match='Referenced column "a" not found in FROM clause!'): - duckdb_cursor.values(duckdb.ColumnExpression('a')) + duckdb_cursor.values(duckdb.ColumnExpression("a")) def test_insert_into_operator(self): conn = duckdb.connect() @@ -290,23 +293,23 @@ def test_insert_into_operator(self): rel.insert_into("test_table3") # Inserting elements into table_3 - print(conn.values([5, 'five']).insert_into("test_table3")) + print(conn.values([5, "five"]).insert_into("test_table3")) rel_3 = conn.table("test_table3") - rel_3.insert([6, 'six']) + rel_3.insert([6, "six"]) assert rel_3.execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), - (5, 'five'), - (6, 'six'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), + (5, "five"), + (6, "six"), ] def test_write_csv_operator(self): conn = duckdb.connect() df_rel = get_relation(conn) - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df_rel.write_csv(temp_file_name) csv_rel = duckdb.from_csv_auto(temp_file_name) @@ -316,8 +319,8 @@ def test_table_update_with_schema(self, duckdb_cursor): duckdb_cursor.sql("create schema not_main;") duckdb_cursor.sql("create table not_main.tbl as select * from range(10) t(a)") - duckdb_cursor.table('not_main.tbl').update({'a': 21}, condition=ColumnExpression('a') == 5) - res = duckdb_cursor.table('not_main.tbl').fetchall() + duckdb_cursor.table("not_main.tbl").update({"a": 21}, condition=ColumnExpression("a") == 5) + res = duckdb_cursor.table("not_main.tbl").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (21,), (6,), (7,), (8,), (9,)] def test_table_update_with_catalog(self, duckdb_cursor): @@ -325,8 +328,8 @@ def test_table_update_with_catalog(self, duckdb_cursor): duckdb_cursor.sql("create schema pg.not_main;") duckdb_cursor.sql("create table pg.not_main.tbl as select * from range(10) t(a)") - duckdb_cursor.table('pg.not_main.tbl').update({'a': 21}, condition=ColumnExpression('a') == 5) - res = duckdb_cursor.table('pg.not_main.tbl').fetchall() + duckdb_cursor.table("pg.not_main.tbl").update({"a": 21}, condition=ColumnExpression("a") == 5) + res = duckdb_cursor.table("pg.not_main.tbl").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (21,), (6,), (7,), (8,), (9,)] def test_get_attr_operator(self): @@ -335,50 +338,50 @@ def test_get_attr_operator(self): rel = conn.table("test") assert rel.alias == "test" assert rel.type == "TABLE_RELATION" - assert rel.columns == ['i'] - assert rel.types == ['INTEGER'] + assert rel.columns == ["i"] + assert rel.types == ["INTEGER"] def test_query_fail(self): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): rel.query("select j from test") def test_execute_fail(self): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): rel.execute("select j from test") def test_df_proj(self): test_df = 
pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.project(test_df, 'i') + rel = duckdb.project(test_df, "i") assert rel.execute().fetchall() == [(1,), (2,), (3,), (4,)] def test_relation_lifetime(self, duckdb_cursor): def create_relation(con): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return con.sql("select * from df") assert create_relation(duckdb_cursor).fetchall() == [(1,), (2,), (3,)] def create_simple_join(con): - df1 = pd.DataFrame({'a': ['a', 'b', 'c'], 'b': [1, 2, 3]}) - df2 = pd.DataFrame({'a': ['a', 'b', 'c'], 'b': [4, 5, 6]}) + df1 = pd.DataFrame({"a": ["a", "b", "c"], "b": [1, 2, 3]}) + df2 = pd.DataFrame({"a": ["a", "b", "c"], "b": [4, 5, 6]}) return con.sql("select * from df1 JOIN df2 USING (a, a)") - assert create_simple_join(duckdb_cursor).fetchall() == [('a', 1, 4), ('b', 2, 5), ('c', 3, 6)] + assert create_simple_join(duckdb_cursor).fetchall() == [("a", 1, 4), ("b", 2, 5), ("c", 3, 6)] def create_complex_join(con): - df1 = pd.DataFrame({'a': [1], '1': [1]}) - df2 = pd.DataFrame({'a': [1], '2': [2]}) - df3 = pd.DataFrame({'a': [1], '3': [3]}) - df4 = pd.DataFrame({'a': [1], '4': [4]}) - df5 = pd.DataFrame({'a': [1], '5': [5]}) - df6 = pd.DataFrame({'a': [1], '6': [6]}) + df1 = pd.DataFrame({"a": [1], "1": [1]}) + df2 = pd.DataFrame({"a": [1], "2": [2]}) + df3 = pd.DataFrame({"a": [1], "3": [3]}) + df4 = pd.DataFrame({"a": [1], "4": [4]}) + df5 = pd.DataFrame({"a": [1], "5": [5]}) + df6 = pd.DataFrame({"a": [1], "6": [6]}) query = "select * from df1" for i in range(5): query += f" JOIN df{i + 2} USING (a, a)" @@ -407,7 +410,7 @@ def test_project_on_types(self): assert projection.columns == ["c2", "c4"] # select bigint, tinyint and a type that isn't there - projection = rel.select_types([BIGINT, "tinyint", con.struct_type({'a': VARCHAR, 'b': TINYINT})]) + projection = rel.select_types([BIGINT, "tinyint", con.struct_type({"a": VARCHAR, "b": TINYINT})]) assert projection.columns == ["c0", "c1"] ## select with empty projection list, not possible @@ -420,30 +423,30 @@ def test_project_on_types(self): def test_df_alias(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.alias(test_df, 'dfzinho') + rel = duckdb.alias(test_df, "dfzinho") assert rel.alias == "dfzinho" def test_df_filter(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.filter(test_df, 'i > 1') - assert rel.execute().fetchall() == [(2, 'two'), (3, 'three'), (4, 'four')] + rel = duckdb.filter(test_df, "i > 1") + assert rel.execute().fetchall() == [(2, "two"), (3, "three"), (4, "four")] def test_df_order_by(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.order(test_df, 'j') - assert rel.execute().fetchall() == [(4, 'four'), (1, 'one'), (3, 'three'), (2, 'two')] + rel = duckdb.order(test_df, "j") + assert rel.execute().fetchall() == [(4, "four"), (1, "one"), (3, "three"), (2, "two")] def test_df_distinct(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.distinct(test_df).order('i') - assert rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + rel = duckdb.distinct(test_df).order("i") + assert rel.execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_df_write_csv(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": 
["one", "two", "three", "four"]}) - temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) + temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 duckdb.write_csv(test_df, temp_file_name) csv_rel = duckdb.from_csv_auto(temp_file_name) - assert csv_rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + assert csv_rel.execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_join_types(self): test_df1 = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) @@ -452,9 +455,9 @@ def test_join_types(self): rel1 = con.from_df(test_df1) rel2 = con.from_df(test_df2) - assert rel1.join(rel2, 'i=j', 'inner').aggregate('count()').fetchone()[0] == 2 + assert rel1.join(rel2, "i=j", "inner").aggregate("count()").fetchone()[0] == 2 - assert rel1.join(rel2, 'i=j', 'left').aggregate('count()').fetchone()[0] == 4 + assert rel1.join(rel2, "i=j", "left").aggregate("count()").fetchone()[0] == 4 def test_fetchnumpy(self): start, stop = -1000, 2000 @@ -482,21 +485,21 @@ def test_fetchnumpy(self): assert len(res["a"]) == size assert np.all(res["a"] == np.arange(start, start + size)) - with pytest.raises(duckdb.ConversionException, match="Conversion Error.*out of range.*"): + with pytest.raises(duckdb.ConversionException, match=r"Conversion Error.*out of range.*"): # invalid conversion of negative integer to UINTEGER rel.project("CAST(a as UINTEGER)").fetchnumpy() def test_close(self): - def counter(): + def counter() -> int: counter.count += 1 return 42 counter.count = 0 conn = duckdb.connect() - conn.create_function('my_counter', counter, [], BIGINT) + conn.create_function("my_counter", counter, [], BIGINT) # Create a relation - rel = conn.sql('select my_counter()') + rel = conn.sql("select my_counter()") # Execute the relation once rel.fetchall() assert counter.count == 1 @@ -508,20 +511,20 @@ def counter(): assert counter.count == 2 # Verify that the query is run at least once if it's closed before it was executed. - rel = conn.sql('select my_counter()') + rel = conn.sql("select my_counter()") rel.close() assert counter.count == 3 def test_relation_print(self): con = duckdb.connect() con.execute("Create table t1 as select * from range(1000000)") - rel1 = con.table('t1') + rel1 = con.table("t1") text1 = str(rel1) - assert '? rows' in text1 - assert '>9999 rows' in text1 + assert "? 
rows" in text1 + assert ">9999 rows" in text1 @pytest.mark.parametrize( - 'num_rows', + "num_rows", [ 1024, 2048, @@ -531,7 +534,8 @@ def test_relation_print(self): 10000000, marks=pytest.mark.skipif( condition=platform.system() == "Emscripten", - reason="Emscripten/Pyodide builds run out of memory at this scale, and error might not thrown reliably", + reason="Emscripten/Pyodide builds run out of memory at this scale, and error might not " + "thrown reliably", ), ), ], @@ -541,7 +545,7 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): query = f"call repeat_row(42, 'test', 'this is a long string', true, num_rows={num_rows})" rel = duckdb_cursor.sql(query) res = rel.fetchone() - assert res != None + assert res is not None res = rel.fetchmany(num_rows) assert len(res) == num_rows - 1 @@ -551,11 +555,11 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): res = rel.fetchmany(5) assert len(res) == 0 res = rel.fetchone() - assert res == None + assert res is None rel.execute() res = rel.fetchone() - assert res != None + assert res is not None res = rel.fetchall() assert len(res) == num_rows - 1 @@ -563,7 +567,7 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): assert len(res) == num_rows rel = duckdb_cursor.sql(query) - projection = rel.select('column0') + projection = rel.select("column0") assert projection.fetchall() == [(42,) for _ in range(num_rows)] filtered = rel.filter("column1 != 'test'") @@ -575,71 +579,71 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): ): rel.insert([1, 2, 3, 4]) - query_rel = rel.query('x', "select 42 from x where column0 != 42") + query_rel = rel.query("x", "select 42 from x where column0 != 42") assert query_rel.fetchall() == [] distinct_rel = rel.distinct() - assert distinct_rel.fetchall() == [(42, 'test', 'this is a long string', True)] + assert distinct_rel.fetchall() == [(42, "test", "this is a long string", True)] limited_rel = rel.limit(50) assert len(limited_rel.fetchall()) == 50 # Using parameters also results in a MaterializedRelation materialized_one = duckdb_cursor.sql("select * from range(?)", params=[10]).project( - ColumnExpression('range').cast(str).alias('range') + ColumnExpression("range").cast(str).alias("range") ) materialized_two = duckdb_cursor.sql("call repeat('a', 5)") - joined_rel = materialized_one.join(materialized_two, 'range != a') + joined_rel = materialized_one.join(materialized_two, "range != a") res = joined_rel.fetchall() assert len(res) == 50 relation = duckdb_cursor.sql("select a from materialized_two") - assert relation.fetchone() == ('a',) + assert relation.fetchone() == ("a",) described = materialized_one.describe() res = described.fetchall() - assert res == [('count', '10'), ('mean', None), ('stddev', None), ('min', '0'), ('max', '9'), ('median', None)] + assert res == [("count", "10"), ("mean", None), ("stddev", None), ("min", "0"), ("max", "9"), ("median", None)] unioned_rel = materialized_one.union(materialized_two) res = unioned_rel.fetchall() assert res == [ - ('0',), - ('1',), - ('2',), - ('3',), - ('4',), - ('5',), - ('6',), - ('7',), - ('8',), - ('9',), - ('a',), - ('a',), - ('a',), - ('a',), - ('a',), + ("0",), + ("1",), + ("2",), + ("3",), + ("4",), + ("5",), + ("6",), + ("7",), + ("8",), + ("9",), + ("a",), + ("a",), + ("a",), + ("a",), + ("a",), ] except_rel = unioned_rel.except_(materialized_one) res = except_rel.fetchall() - assert res == [tuple('a') for _ in range(5)] + assert res == [tuple("a") for _ in range(5)] - intersect_rel = 
unioned_rel.intersect(materialized_one).order('range') + intersect_rel = unioned_rel.intersect(materialized_one).order("range") res = intersect_rel.fetchall() - assert res == [('0',), ('1',), ('2',), ('3',), ('4',), ('5',), ('6',), ('7',), ('8',), ('9',)] + assert res == [("0",), ("1",), ("2",), ("3",), ("4",), ("5",), ("6",), ("7",), ("8",), ("9",)] def test_materialized_relation_view(self, duckdb_cursor): - def create_view(duckdb_cursor): + def create_view(duckdb_cursor) -> None: duckdb_cursor.sql( """ create table tbl(a varchar); insert into tbl values ('test') returning * """ - ).to_view('vw') + ).to_view("vw") create_view(duckdb_cursor) res = duckdb_cursor.sql("select * from vw").fetchone() - assert res == ('test',) + assert res == ("test",) def test_materialized_relation_view2(self, duckdb_cursor): # This creates a MaterializedRelation @@ -651,21 +655,22 @@ def test_materialized_relation_view2(self, duckdb_cursor): # Create a VIEW that contains a ColumnDataRef rel.create_view("test", True) # Override the existing relation, the original MaterializedRelation has now gone out of scope - # The VIEW still works because the CDC that is being referenced is kept alive through the MaterializedDependency item + # The VIEW still works because the CDC that is being referenced is kept alive through the + # MaterializedDependency item rel = duckdb_cursor.sql("select * from test") res = rel.fetchall() - assert res == [([2], ['Alice'])] + assert res == [([2], ["Alice"])] def test_serialized_materialized_relation(self, tmp_database): con = duckdb.connect(tmp_database) - def create_view(con, view_name: str): + def create_view(con, view_name: str) -> None: rel = con.sql("select 'this is not a small string ' || range::varchar from range(?)", params=[10]) rel.to_view(view_name) - expected = [(f'this is not a small string {i}',) for i in range(10)] + expected = [(f"this is not a small string {i}",) for i in range(10)] - create_view(con, 'vw') + create_view(con, "vw") res = con.sql("select * from vw").fetchall() assert res == expected diff --git a/tests/fast/test_relation_dependency_leak.py b/tests/fast/test_relation_dependency_leak.py index ca505704..659e1c28 100644 --- a/tests/fast/test_relation_dependency_leak.py +++ b/tests/fast/test_relation_dependency_leak.py @@ -1,5 +1,6 @@ -import numpy as np import os + +import numpy as np import pytest try: @@ -8,8 +9,7 @@ can_run = True except ImportError: can_run = False -from conftest import NumpyPandas, ArrowPandas - +from conftest import ArrowPandas, NumpyPandas psutil = pytest.importorskip("psutil") @@ -31,43 +31,43 @@ def from_df(pandas, duckdb_cursor): def from_arrow(pandas, duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_table) def arrow_replacement(pandas, duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) # noqa: F841 duckdb_cursor.query("select sum(a) from arrow_table").fetchall() def pandas_replacement(pandas, duckdb_cursor): - df = pandas.DataFrame({"x": np.random.rand(1_000_000)}) + df = pandas.DataFrame({"x": np.random.rand(1_000_000)}) # noqa: F841 duckdb_cursor.query("select sum(x) from df").fetchall() -class TestRelationDependencyMemoryLeak(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +class 
TestRelationDependencyMemoryLeak: + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow_leak(self, pandas, duckdb_cursor): if not can_run: return check_memory(from_arrow, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_df_leak(self, pandas, duckdb_cursor): check_memory(from_df, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_arrow_replacement_scan_leak(self, pandas, duckdb_cursor): if not can_run: return check_memory(arrow_replacement, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_replacement_scan_leak(self, pandas, duckdb_cursor): check_memory(pandas_replacement, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_view_leak(self, pandas, duckdb_cursor): rel = from_df(pandas, duckdb_cursor) rel.create_view("bla") diff --git a/tests/fast/test_replacement_scan.py b/tests/fast/test_replacement_scan.py index 233439df..0e27016a 100644 --- a/tests/fast/test_replacement_scan.py +++ b/tests/fast/test_replacement_scan.py @@ -1,20 +1,23 @@ -import duckdb -import os +# ruff: noqa: F841 +from pathlib import Path + import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pl = pytest.importorskip("polars") pd = pytest.importorskip("pandas") def using_table(con, to_scan, object_name): - local_scope = {'con': con, object_name: to_scan, 'object_name': object_name} - exec(f"result = con.table(object_name)", globals(), local_scope) + local_scope = {"con": con, object_name: to_scan, "object_name": object_name} + exec("result = con.table(object_name)", globals(), local_scope) return local_scope["result"] def using_sql(con, to_scan, object_name): - local_scope = {'con': con, object_name: to_scan, 'object_name': object_name} + local_scope = {"con": con, object_name: to_scan, "object_name": object_name} exec(f"result = con.sql('select * from \"{object_name}\"')", globals(), local_scope) return local_scope["result"] @@ -60,40 +63,40 @@ def fetch_relation(rel): def from_pandas(): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return df def from_arrow(): - schema = pa.schema([('field_1', pa.int64())]) + schema = pa.schema([("field_1", pa.int64())]) df = pa.RecordBatchReader.from_batches(schema, [pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema)]) return df def create_relation(conn, query: str) -> duckdb.DuckDBPyRelation: - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return conn.sql(query) -class TestReplacementScan(object): +class TestReplacementScan: def test_csv_replacement(self): con = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') - res = con.execute("select count(*) from '%s'" % (filename)) + filename = str(Path(__file__).parent / "data" / "integers.csv") + res = con.execute(f"select count(*) from '{filename}'") assert res.fetchone()[0] == 2 def test_parquet_replacement(self): con = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') - res = con.execute("select 
count(*) from '%s'" % (filename)) + filename = str(Path(__file__).parent / "data" / "binary_string.parquet") + res = con.execute(f"select count(*) from '{filename}'") assert res.fetchone()[0] == 3 - @pytest.mark.parametrize('get_relation', [using_table, using_sql]) + @pytest.mark.parametrize("get_relation", [using_table, using_sql]) @pytest.mark.parametrize( - 'fetch_method', + "fetch_method", [fetch_polars, fetch_df, fetch_arrow, fetch_arrow_table, fetch_arrow_record_batch, fetch_relation], ) - @pytest.mark.parametrize('object_name', ['tbl', 'table', 'select', 'update']) + @pytest.mark.parametrize("object_name", ["tbl", "table", "select", "update"]) def test_table_replacement_scans(self, duckdb_cursor, get_relation, fetch_method, object_name): base_rel = duckdb_cursor.values([1, 2, 3]) to_scan = fetch_method(base_rel) @@ -105,29 +108,29 @@ def test_table_replacement_scans(self, duckdb_cursor, get_relation, fetch_method def test_scan_global(self, duckdb_cursor): duckdb_cursor.execute("set python_enable_replacements=false") - with pytest.raises(duckdb.CatalogException, match='Table with name global_polars_df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name global_polars_df does not exist"): # We set the depth to look for global variables to 0 so it's never found duckdb_cursor.sql("select * from global_polars_df") duckdb_cursor.execute("set python_enable_replacements=true") # Now the depth is 1, which is enough to locate the variable rel = duckdb_cursor.sql("select * from global_polars_df") res = rel.fetchone() - assert res == (1, 'banana', 5, 'beetle') + assert res == (1, "banana", 5, "beetle") def test_scan_local(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) - def inner_func(duckdb_cursor): + def inner_func(duckdb_cursor) -> None: duckdb_cursor.execute("set python_enable_replacements=false") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # We set the depth to look for local variables to 0 so it's never found duckdb_cursor.sql("select * from df") duckdb_cursor.execute("set python_enable_replacements=true") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # Here it's still not found, because it's not visible to this frame duckdb_cursor.sql("select * from df") - df = pd.DataFrame({'a': [4, 5, 6]}) + df = pd.DataFrame({"a": [4, 5, 6]}) duckdb_cursor.execute("set python_enable_replacements=true") # We can find the newly defined 'df' with depth 1 rel = duckdb_cursor.sql("select * from df") @@ -137,12 +140,13 @@ def inner_func(duckdb_cursor): inner_func(duckdb_cursor) def test_scan_local_unlimited(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) - def inner_func(duckdb_cursor): + def inner_func(duckdb_cursor) -> None: duckdb_cursor.execute("set python_enable_replacements=true") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): - # We set the depth to look for local variables to 1 so it's still not found because it wasn't defined in this function + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): + # We set the depth to look for local variables to 1 so it's still not found because it wasn't defined + # in this function 
duckdb_cursor.sql("select * from df") duckdb_cursor.execute("set python_scan_all_frames=true") # Now we can find 'df' because we also scan the previous frame(s) @@ -155,37 +159,37 @@ def inner_func(duckdb_cursor): def test_replacement_scan_relapi(self): con = duckdb.connect() - pyrel1 = con.query('from (values (42), (84), (120)) t(i)') + pyrel1 = con.query("from (values (42), (84), (120)) t(i)") assert isinstance(pyrel1, duckdb.DuckDBPyRelation) assert pyrel1.fetchall() == [(42,), (84,), (120,)] - pyrel2 = con.query('from pyrel1 limit 2') + pyrel2 = con.query("from pyrel1 limit 2") assert isinstance(pyrel2, duckdb.DuckDBPyRelation) assert pyrel2.fetchall() == [(42,), (84,)] - pyrel3 = con.query('select i + 100 from pyrel2') - assert type(pyrel3) == duckdb.DuckDBPyRelation + pyrel3 = con.query("select i + 100 from pyrel2") + assert type(pyrel3) is duckdb.DuckDBPyRelation assert pyrel3.fetchall() == [(142,), (184,)] def test_replacement_scan_not_found(self): con = duckdb.connect() con.execute("set python_scan_all_frames=true") - with pytest.raises(duckdb.CatalogException, match='Table with name non_existant does not exist'): - res = con.sql("select * from non_existant").fetchall() + with pytest.raises(duckdb.CatalogException, match="Table with name non_existant does not exist"): + con.sql("select * from non_existant").fetchall() def test_replacement_scan_alias(self): con = duckdb.connect() - pyrel1 = con.query('from (values (1, 2)) t(i, j)') - pyrel2 = con.query('from (values (1, 10)) t(i, k)') - pyrel3 = con.query('from pyrel1 join pyrel2 using(i)') - assert type(pyrel3) == duckdb.DuckDBPyRelation + pyrel1 = con.query("from (values (1, 2)) t(i, j)") + pyrel2 = con.query("from (values (1, 10)) t(i, k)") + pyrel3 = con.query("from pyrel1 join pyrel2 using(i)") + assert type(pyrel3) is duckdb.DuckDBPyRelation assert pyrel3.fetchall() == [(1, 2, 10)] def test_replacement_scan_pandas_alias(self): con = duckdb.connect() - df1 = con.query('from (values (1, 2)) t(i, j)').df() - df2 = con.query('from (values (1, 10)) t(i, k)').df() - df3 = con.query('from df1 join df2 using(i)') + df1 = con.query("from (values (1, 2)) t(i, j)").df() + df2 = con.query("from (values (1, 10)) t(i, k)").df() + df3 = con.query("from df1 join df2 using(i)") assert df3.fetchall() == [(1, 2, 10)] def test_replacement_scan_after_creation(self, duckdb_cursor): @@ -194,14 +198,15 @@ def test_replacement_scan_after_creation(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from df") duckdb_cursor.execute("drop table df") - df = pd.DataFrame({'b': [1, 2, 3]}) + df = pd.DataFrame({"b": [1, 2, 3]}) res = rel.fetchall() - # FIXME: this should error instead, the 'df' table we relied on has been removed and replaced with a replacement scan + # TODO: this should error instead, the 'df' table we relied on has been removed # noqa: TD002, TD003 + # and replaced with a replacement scan assert res == [(1,), (2,), (3,)] def test_replacement_scan_caching(self, duckdb_cursor): def return_rel(conn): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) rel = conn.sql("select * from df") return rel @@ -220,7 +225,7 @@ def test_replacement_scan_fail(self): con.execute("select count(*) from random_object").fetchone() @pytest.mark.parametrize( - 'df_create', + "df_create", [ from_pandas, from_arrow, @@ -258,7 +263,7 @@ def test_cte(self, duckdb_cursor, df_create): if df_create == from_arrow: # Because the RecordBatchReader is destructive, it's empty after the first scan # But we reference it multiple times, so the 
subsequent reads have no data to read - # FIXME: this should probably throw an error... + # TODO: this should probably throw an error... # noqa: TD002, TD003 assert len(res) >= 0 else: assert res == [([1, 2, 3],), ([1, 2, 3],), ([1, 2, 3],)] @@ -333,7 +338,7 @@ def test_same_name_cte(self, duckdb_cursor): def test_use_with_view(self, duckdb_cursor): rel = create_relation(duckdb_cursor, "select * from df") - rel.create_view('v1') + rel.create_view("v1") del rel rel = duckdb_cursor.sql("select * from v1") @@ -341,14 +346,15 @@ def test_use_with_view(self, duckdb_cursor): assert res == [(1,), (2,), (3,)] duckdb_cursor.execute("drop view v1") - def create_view_in_func(con): + def create_view_in_func(con) -> None: df = pd.DataFrame({"a": [1, 2, 3]}) - con.execute('CREATE VIEW v1 AS SELECT * FROM df') + con.execute("CREATE VIEW v1 AS SELECT * FROM df") create_view_in_func(duckdb_cursor) - # FIXME: this should be fixed in the future, likely by unifying the behavior of .sql and .execute - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + # TODO: this should be fixed in the future, likely by unifying the behavior of # noqa: TD002, TD003 + # .sql and .execute + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): rel = duckdb_cursor.sql("select * from v1") def test_recursive_cte(self, duckdb_cursor): @@ -410,7 +416,7 @@ def test_multiple_replacements(self, duckdb_cursor): """ rel = duckdb_cursor.sql(query) res = rel.fetchall() - assert res == [(2, 'Bob', None), (3, 'Charlie', None), (4, 'David', 1.0), (5, 'Eve', 1.0)] + assert res == [(2, "Bob", None), (3, "Charlie", None), (4, "David", 1.0), (5, "Eve", 1.0)] def test_cte_at_different_levels(self, duckdb_cursor): query = """ @@ -460,19 +466,17 @@ def test_replacement_disabled(self): ## disable external access con.execute("set enable_external_access=false") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist!'): - rel = create_relation(con, "select * from df") - res = rel.fetchall() + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist!"): + create_relation(con, "select * from df") with pytest.raises( - duckdb.InvalidInputException, match='Cannot change enable_external_access setting while database is running' + duckdb.InvalidInputException, match="Cannot change enable_external_access setting while database is running" ): con.execute("set enable_external_access=true") # Create connection with external access disabled - con = duckdb.connect(config={'enable_external_access': False}) - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist!'): - rel = create_relation(con, "select * from df") - res = rel.fetchall() + con = duckdb.connect(config={"enable_external_access": False}) + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist!"): + create_relation(con, "select * from df") # Create regular connection, disable inbetween creation and execution con = duckdb.connect() @@ -487,23 +491,23 @@ def test_replacement_disabled(self): assert res == [(1,), (2,), (3,)] def test_replacement_of_cross_connection_relation(self): - con1 = duckdb.connect(':memory:') - con2 = duckdb.connect(':memory:') - con1.query('create table integers(i int)') - con2.query('create table integers(v varchar)') - con1.query('insert into integers values (42)') - con2.query('insert into integers values (\'xxx\')') - rel1 = con1.query('select * from integers') + con1 = 
duckdb.connect(":memory:") + con2 = duckdb.connect(":memory:") + con1.query("create table integers(i int)") + con2.query("create table integers(v varchar)") + con1.query("insert into integers values (42)") + con2.query("insert into integers values ('xxx')") + rel1 = con1.query("select * from integers") with pytest.raises( duckdb.InvalidInputException, - match=r'The object was created by another Connection and can therefore not be used by this Connection.', + match=r"The object was created by another Connection and can therefore not be used by this Connection.", ): - con2.query('from rel1') + con2.query("from rel1") del con1 with pytest.raises( duckdb.InvalidInputException, - match=r'The object was created by another Connection and can therefore not be used by this Connection.', + match=r"The object was created by another Connection and can therefore not be used by this Connection.", ): - con2.query('from rel1') + con2.query("from rel1") diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index af68e268..4210a437 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -1,49 +1,51 @@ -import duckdb -import pytest import datetime +import pytest + +import duckdb + -class TestPythonResult(object): +class TestPythonResult: def test_result_closed(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE integers (i integer)') - cursor.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + cursor.execute("CREATE TABLE integers (i integer)") + cursor.execute("INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)") rel = connection.table("integers") res = rel.aggregate("sum(i)").execute() res.close() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchone() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchall() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchnumpy() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_table() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_reader(1) def test_result_describe_types(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + cursor.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") cursor.execute("INSERT INTO test VALUES (TRUE, '01:01:01', 'bla' )") rel = connection.table("test") res = rel.execute() assert res.description == [ - ('i', 'BOOLEAN', None, None, None, None, None), - ('j', 'TIME', None, None, None, None, None), - ('k', 'VARCHAR', None, None, None, None, None), + ("i", "BOOLEAN", None, None, None, None, None), + ("j", "TIME", None, None, None, None, None), + ("k", "VARCHAR", None, None, None, None, None), ] def test_result_timestamps(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = 
connection.cursor() cursor.execute( - 'CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );' + "CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );" # noqa: E501 ) cursor.execute( - "INSERT INTO timestamps VALUES ('2008-01-01 00:00:11','2008-01-01 00:00:01.794','2008-01-01 00:00:01.98926','2008-01-01 00:00:01.899268321' )" + "INSERT INTO timestamps VALUES ('2008-01-01 00:00:11','2008-01-01 00:00:01.794','2008-01-01 00:00:01.98926','2008-01-01 00:00:01.899268321' )" # noqa: E501 ) rel = connection.table("timestamps") @@ -59,12 +61,12 @@ def test_result_timestamps(self, duckdb_cursor): def test_result_interval(self): connection = duckdb.connect() cursor = connection.cursor() - cursor.execute('CREATE TABLE IF NOT EXISTS intervals (ivals INTERVAL)') + cursor.execute("CREATE TABLE IF NOT EXISTS intervals (ivals INTERVAL)") cursor.execute("INSERT INTO intervals VALUES ('1 day'), ('2 second'), ('1 microsecond')") rel = connection.table("intervals") res = rel.execute() - assert res.description == [('ivals', 'INTERVAL', None, None, None, None, None)] + assert res.description == [("ivals", "INTERVAL", None, None, None, None, None)] assert res.fetchall() == [ (datetime.timedelta(days=1.0),), (datetime.timedelta(seconds=2.0),), @@ -74,4 +76,4 @@ def test_result_interval(self): def test_description_uuid(self): connection = duckdb.connect() connection.execute("select uuid();") - connection.description + connection.description # noqa: B018 diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 29e81d1e..9f1975a0 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -1,12 +1,18 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb + + +def closed(): + return pytest.raises(duckdb.ConnectionException, match="Connection already closed") + -closed = lambda: pytest.raises(duckdb.ConnectionException, match='Connection already closed') -no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match='No open result set') +def no_result_set(): + return pytest.raises(duckdb.InvalidInputException, match="No open result set") -class TestRuntimeError(object): +class TestRuntimeError: def test_fetch_error(self): con = duckdb.connect() con.execute("create table tbl as select 'hello' i") @@ -20,7 +26,7 @@ def test_df_error(self): con.execute("select i::int from tbl").df() def test_arrow_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() con.execute("create table tbl as select 'hello' i") @@ -34,83 +40,83 @@ def test_register_error(self): con.register(py_obj, "v") def test_arrow_fetch_table_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_table() def test_arrow_record_batch_reader_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = 
arrow_relation.execute() res.close() - with pytest.raises(duckdb.ProgrammingError, match='There is no query result'): + with pytest.raises(duckdb.ProgrammingError, match="There is no query result"): res.fetch_arrow_reader(1) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_cache_fetchall(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.ProgrammingError, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.ProgrammingError, match="Table with name df_in does not exist"): # Even when we preserve ExternalDependency objects correctly, this is not supported # Relations only save dependencies for their immediate TableRefs, # so the dependency of 'x' on 'df_in' is not registered in 'rel' rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_cache_execute(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.ProgrammingError, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.ProgrammingError, match="Table with name df_in does not exist"): rel.execute() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_query_error(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.CatalogException, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df_in does not exist"): rel.query("bla", "select * from bla") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_conn_broken_statement_error(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") del df_in - with pytest.raises(duckdb.CatalogException, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df_in does not exist"): conn.execute("select 1; select * from x; select 3;") def test_conn_prepared_statement_error(self): @@ -118,17 +124,17 @@ def test_conn_prepared_statement_error(self): conn.execute("create table integers (a integer, b integer)") with pytest.raises( duckdb.InvalidInputException, - match='Values were not provided for the following prepared statement parameters: 2', + match="Values were not provided for the following prepared statement parameters: 2", ): conn.execute("select * from integers where a =? 
and b=?", [1]) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_closed_conn_exceptions(self, pandas): conn = duckdb.connect() conn.close() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) diff --git a/tests/fast/test_sql_expression.py b/tests/fast/test_sql_expression.py index 771be84d..f3cf41ca 100644 --- a/tests/fast/test_sql_expression.py +++ b/tests/fast/test_sql_expression.py @@ -1,5 +1,6 @@ -import duckdb import pytest + +import duckdb from duckdb import ( ColumnExpression, ConstantExpression, @@ -7,9 +8,8 @@ ) -class TestSQLExpression(object): +class TestSQLExpression: def test_sql_expression_basic(self, duckdb_cursor): - # Test simple constant expressions expr = SQLExpression("42") rel = duckdb_cursor.sql("SELECT 1").select(expr) @@ -17,7 +17,7 @@ def test_sql_expression_basic(self, duckdb_cursor): expr = SQLExpression("'hello'") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('hello',)] + assert rel.fetchall() == [("hello",)] expr = SQLExpression("NULL") rel = duckdb_cursor.sql("SELECT 1").select(expr) @@ -43,14 +43,13 @@ def test_sql_expression_basic(self, duckdb_cursor): # Test function calls expr = SQLExpression("UPPER('test')") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('TEST',)] + assert rel.fetchall() == [("TEST",)] expr = SQLExpression("CONCAT('hello', ' ', 'world')") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('hello world',)] + assert rel.fetchall() == [("hello world",)] def test_sql_expression_with_columns(self, duckdb_cursor): - # Create a test table duckdb_cursor.execute( """ @@ -75,12 +74,12 @@ def test_sql_expression_with_columns(self, duckdb_cursor): expr = SQLExpression("UPPER(b)") rel2 = rel.select(expr) - assert rel2.fetchall() == [('ONE',), ('TWO',), ('THREE',)] + assert rel2.fetchall() == [("ONE",), ("TWO",), ("THREE",)] # Test complex expressions expr = SQLExpression("CASE WHEN a > 1 THEN b ELSE 'default' END") rel2 = rel.select(expr) - assert rel2.fetchall() == [('default',), ('two',), ('three',)] + assert rel2.fetchall() == [("default",), ("two",), ("three",)] # Test combining with other expression types expr1 = SQLExpression("a + 5") @@ -122,8 +121,8 @@ def test_sql_expression_alias(self, duckdb_cursor): rel = duckdb_cursor.table("test_alias") expr = SQLExpression("a + 10").alias("a_plus_10") rel2 = rel.select(expr, "b") - assert rel2.fetchall() == [(11, 'one'), (12, 'two')] - assert rel2.columns == ['a_plus_10', 'b'] + assert rel2.fetchall() == [(11, "one"), (12, "two")] + assert rel2.columns == ["a_plus_10", "b"] def test_sql_expression_in_filter(self, duckdb_cursor): duckdb_cursor.execute( @@ -142,18 +141,18 @@ def test_sql_expression_in_filter(self, duckdb_cursor): # Test filter with SQL expression expr = SQLExpression("a > 2") rel2 = rel.filter(expr) - assert rel2.fetchall() == [(3, 'three'), (4, 'four')] + assert rel2.fetchall() == [(3, "three"), (4, "four")] # Test complex filter expr = SQLExpression("a % 2 = 0 AND b LIKE '%o%'") rel2 = rel.filter(expr) - assert rel2.fetchall() == [(2, 'two'), (4, 'four')] + assert rel2.fetchall() == [(2, "two"), (4, "four")] # Test combining with other expression types expr1 = SQLExpression("a > 1") expr2 = ColumnExpression("b") == ConstantExpression("four") rel2 = rel.filter(expr1 & expr2) - assert rel2.fetchall() == [(4, 'four')] + assert rel2.fetchall() == [(4, "four")] def 
test_sql_expression_in_aggregates(self, duckdb_cursor): duckdb_cursor.execute( @@ -176,14 +175,14 @@ def test_sql_expression_in_aggregates(self, duckdb_cursor): # Test aggregation with group by expr = SQLExpression("SUM(c)") - rel2 = rel.aggregate([expr, "b"]).sort('b') + rel2 = rel.aggregate([expr, "b"]).sort("b") result = rel2.fetchall() - assert result == [(30, 'group1'), (70, 'group2')] + assert result == [(30, "group1"), (70, "group2")] # Test multiple aggregations expr1 = SQLExpression("SUM(a)").alias("sum_a") expr2 = SQLExpression("AVG(c)").alias("avg_c") - rel2 = rel.aggregate([expr1, expr2], "b").sort('sum_a', 'avg_c') + rel2 = rel.aggregate([expr1, expr2], "b").sort("sum_a", "avg_c") result = rel2.fetchall() result.sort() assert result == [(3, 15.0), (7, 35.0)] diff --git a/tests/fast/test_string_annotation.py b/tests/fast/test_string_annotation.py index c5500c66..a5ea4cfd 100644 --- a/tests/fast/test_string_annotation.py +++ b/tests/fast/test_string_annotation.py @@ -1,11 +1,13 @@ -import duckdb -import pytest import sys -from typing import Union + +# we need typing.Union in our import cache +from typing import Union # noqa: F401 + +import pytest def make_annotated_function(type: str): - def test_base(): + def test_base() -> None: return None import types @@ -14,31 +16,27 @@ def test_base(): test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__ ) # Add the 'type' string as return_annotation - test_function.__annotations__ = {'return': type} + test_function.__annotations__ = {"return": type} return test_function def python_version_lower_than_3_10(): - import sys - - if sys.version_info[0] < 3: - return True if sys.version_info[1] < 10: return True return False -class TestStringAnnotation(object): +class TestStringAnnotation: @pytest.mark.skipif( python_version_lower_than_3_10(), reason="inspect.signature(eval_str=True) only supported since 3.10 and higher" ) @pytest.mark.parametrize( - ['input', 'expected'], + ("input", "expected"), [ - ('str', 'VARCHAR'), - ('list[str]', 'VARCHAR[]'), - ('dict[str, str]', 'MAP(VARCHAR, VARCHAR)'), - ('dict[Union[str, bool], str]', 'MAP(UNION(u1 VARCHAR, u2 BOOLEAN), VARCHAR)'), + ("str", "VARCHAR"), + ("list[str]", "VARCHAR[]"), + ("dict[str, str]", "MAP(VARCHAR, VARCHAR)"), + ("dict[Union[str, bool], str]", "MAP(UNION(u1 VARCHAR, u2 BOOLEAN), VARCHAR)"), ], ) def test_string_annotations(self, duckdb_cursor, input, expected): @@ -46,7 +44,7 @@ def test_string_annotations(self, duckdb_cursor, input, expected): func = make_annotated_function(input) sig = signature(func) - assert sig.return_annotation.__class__ == str + assert sig.return_annotation.__class__ is str duckdb_cursor.create_function("foo", func) rel = duckdb_cursor.sql("select foo()") diff --git a/tests/fast/test_tf.py b/tests/fast/test_tf.py index b65acec6..ceec2ee0 100644 --- a/tests/fast/test_tf.py +++ b/tests/fast/test_tf.py @@ -1,8 +1,8 @@ -import duckdb import pytest +import duckdb -tf = pytest.importorskip('tensorflow') +tf = pytest.importorskip("tensorflow") def test_tf(): @@ -14,16 +14,16 @@ def test_tf(): # Test from connection duck_tf = con.execute("select * from t").tf() duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) # Test from relation 
duck_tf = con.sql("select * from t").tf() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) # Test all Numeric Types - numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + numeric_types = ["TINYINT", "SMALLINT", "BIGINT", "HUGEINT", "FLOAT", "DOUBLE", "DECIMAL(4,1)", "UTINYINT"] for supported_type in numeric_types: con = duckdb.connect() @@ -31,5 +31,5 @@ def test_tf(): con.execute("insert into t values (1,2), (3,4)") duck_tf = con.sql("select * from t").tf() duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) diff --git a/tests/fast/test_transaction.py b/tests/fast/test_transaction.py index 54deaf82..0dfabafa 100644 --- a/tests/fast/test_transaction.py +++ b/tests/fast/test_transaction.py @@ -1,20 +1,19 @@ import duckdb -import pandas as pd -class TestConnectionTransaction(object): +class TestConnectionTransaction: def test_transaction(self, duckdb_cursor): con = duckdb.connect() - con.execute('create table t (i integer)') - con.execute('insert into t values (1)') + con.execute("create table t (i integer)") + con.execute("insert into t values (1)") con.begin() - con.execute('insert into t values (1)') - assert con.execute('select count (*) from t').fetchone()[0] == 2 + con.execute("insert into t values (1)") + assert con.execute("select count (*) from t").fetchone()[0] == 2 con.rollback() - assert con.execute('select count (*) from t').fetchone()[0] == 1 + assert con.execute("select count (*) from t").fetchone()[0] == 1 con.begin() - con.execute('insert into t values (1)') - assert con.execute('select count (*) from t').fetchone()[0] == 2 + con.execute("insert into t values (1)") + assert con.execute("select count (*) from t").fetchone()[0] == 2 con.commit() - assert con.execute('select count (*) from t').fetchone()[0] == 2 + assert con.execute("select count (*) from t").fetchone()[0] == 2 diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index 6f648179..0eb96716 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -1,166 +1,163 @@ -import duckdb -import os -import pandas as pd -import pytest -from typing import Union, Optional import sys +from typing import Optional, Union + +import pytest +import duckdb +import duckdb.typing from duckdb.typing import ( - SQLNULL, - BOOLEAN, - TINYINT, - UTINYINT, - SMALLINT, - USMALLINT, - INTEGER, - UINTEGER, BIGINT, - UBIGINT, - HUGEINT, - UHUGEINT, - UUID, - FLOAT, - DOUBLE, + BIT, + BLOB, + BOOLEAN, DATE, + DOUBLE, + FLOAT, + HUGEINT, + INTEGER, + INTERVAL, + SMALLINT, + SQLNULL, + TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, - DuckDBPyType, - TIME, - TIME_TZ, TIMESTAMP_TZ, + TINYINT, + UBIGINT, + UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, + UUID, VARCHAR, - BLOB, - BIT, - INTERVAL, + DuckDBPyType, ) -import duckdb.typing -class TestType(object): +class TestType: def test_sqltype(self): - assert str(duckdb.sqltype('struct(a VARCHAR, b BIGINT)')) == 'STRUCT(a VARCHAR, b BIGINT)' - # todo: add tests with invalid type_str + assert 
str(duckdb.sqltype("struct(a VARCHAR, b BIGINT)")) == "STRUCT(a VARCHAR, b BIGINT)" + # TODO: add tests with invalid type_str # noqa: TD002, TD003 def test_primitive_types(self): assert str(SQLNULL) == '"NULL"' - assert str(BOOLEAN) == 'BOOLEAN' - assert str(TINYINT) == 'TINYINT' - assert str(UTINYINT) == 'UTINYINT' - assert str(SMALLINT) == 'SMALLINT' - assert str(USMALLINT) == 'USMALLINT' - assert str(INTEGER) == 'INTEGER' - assert str(UINTEGER) == 'UINTEGER' - assert str(BIGINT) == 'BIGINT' - assert str(UBIGINT) == 'UBIGINT' - assert str(HUGEINT) == 'HUGEINT' - assert str(UHUGEINT) == 'UHUGEINT' - assert str(UUID) == 'UUID' - assert str(FLOAT) == 'FLOAT' - assert str(DOUBLE) == 'DOUBLE' - assert str(DATE) == 'DATE' - assert str(TIMESTAMP) == 'TIMESTAMP' - assert str(TIMESTAMP_MS) == 'TIMESTAMP_MS' - assert str(TIMESTAMP_NS) == 'TIMESTAMP_NS' - assert str(TIMESTAMP_S) == 'TIMESTAMP_S' - assert str(TIME) == 'TIME' - assert str(TIME_TZ) == 'TIME WITH TIME ZONE' - assert str(TIMESTAMP_TZ) == 'TIMESTAMP WITH TIME ZONE' - assert str(VARCHAR) == 'VARCHAR' - assert str(BLOB) == 'BLOB' - assert str(BIT) == 'BIT' - assert str(INTERVAL) == 'INTERVAL' + assert str(BOOLEAN) == "BOOLEAN" + assert str(TINYINT) == "TINYINT" + assert str(UTINYINT) == "UTINYINT" + assert str(SMALLINT) == "SMALLINT" + assert str(USMALLINT) == "USMALLINT" + assert str(INTEGER) == "INTEGER" + assert str(UINTEGER) == "UINTEGER" + assert str(BIGINT) == "BIGINT" + assert str(UBIGINT) == "UBIGINT" + assert str(HUGEINT) == "HUGEINT" + assert str(UHUGEINT) == "UHUGEINT" + assert str(UUID) == "UUID" + assert str(FLOAT) == "FLOAT" + assert str(DOUBLE) == "DOUBLE" + assert str(DATE) == "DATE" + assert str(TIMESTAMP) == "TIMESTAMP" + assert str(TIMESTAMP_MS) == "TIMESTAMP_MS" + assert str(TIMESTAMP_NS) == "TIMESTAMP_NS" + assert str(TIMESTAMP_S) == "TIMESTAMP_S" + assert str(TIME) == "TIME" + assert str(TIME_TZ) == "TIME WITH TIME ZONE" + assert str(TIMESTAMP_TZ) == "TIMESTAMP WITH TIME ZONE" + assert str(VARCHAR) == "VARCHAR" + assert str(BLOB) == "BLOB" + assert str(BIT) == "BIT" + assert str(INTERVAL) == "INTERVAL" def test_list_type(self): type = duckdb.list_type(BIGINT) - assert str(type) == 'BIGINT[]' + assert str(type) == "BIGINT[]" def test_array_type(self): type = duckdb.array_type(BIGINT, 3) - assert str(type) == 'BIGINT[3]' + assert str(type) == "BIGINT[3]" def test_struct_type(self): - type = duckdb.struct_type({'a': BIGINT, 'b': BOOLEAN}) - assert str(type) == 'STRUCT(a BIGINT, b BOOLEAN)' + type = duckdb.struct_type({"a": BIGINT, "b": BOOLEAN}) + assert str(type) == "STRUCT(a BIGINT, b BOOLEAN)" - # FIXME: create an unnamed struct when fields are provided as a list + # TODO: create an unnamed struct when fields are provided as a list # noqa: TD002, TD003 type = duckdb.struct_type([BIGINT, BOOLEAN]) - assert str(type) == 'STRUCT(v1 BIGINT, v2 BOOLEAN)' + assert str(type) == "STRUCT(v1 BIGINT, v2 BOOLEAN)" def test_incomplete_struct_type(self): with pytest.raises( - duckdb.InvalidInputException, match='Could not convert empty dictionary to a duckdb STRUCT type' + duckdb.InvalidInputException, match="Could not convert empty dictionary to a duckdb STRUCT type" ): - type = duckdb.typing.DuckDBPyType(dict()) + duckdb.typing.DuckDBPyType({}) def test_map_type(self): type = duckdb.map_type(duckdb.sqltype("BIGINT"), duckdb.sqltype("DECIMAL(10, 2)")) - assert str(type) == 'MAP(BIGINT, DECIMAL(10,2))' + assert str(type) == "MAP(BIGINT, DECIMAL(10,2))" def test_decimal_type(self): type = duckdb.decimal_type(5, 3) - assert str(type) 
== 'DECIMAL(5,3)' + assert str(type) == "DECIMAL(5,3)" def test_string_type(self): type = duckdb.string_type() - assert str(type) == 'VARCHAR' + assert str(type) == "VARCHAR" def test_string_type_collation(self): - type = duckdb.string_type('NOCASE') + type = duckdb.string_type("NOCASE") # collation does not show up in the string representation.. - assert str(type) == 'VARCHAR' + assert str(type) == "VARCHAR" def test_union_type(self): type = duckdb.union_type([BIGINT, VARCHAR, TINYINT]) - assert str(type) == 'UNION(v1 BIGINT, v2 VARCHAR, v3 TINYINT)' - - type = duckdb.union_type({'a': BIGINT, 'b': VARCHAR, 'c': TINYINT}) - assert str(type) == 'UNION(a BIGINT, b VARCHAR, c TINYINT)' + assert str(type) == "UNION(v1 BIGINT, v2 VARCHAR, v3 TINYINT)" - import sys + type = duckdb.union_type({"a": BIGINT, "b": VARCHAR, "c": TINYINT}) + assert str(type) == "UNION(a BIGINT, b VARCHAR, c TINYINT)" @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires >= python3.9") def test_implicit_convert_from_builtin_type(self): type = duckdb.list_type(list[str]) assert str(type.child) == "VARCHAR[]" - mapping = {str: 'VARCHAR', int: 'BIGINT', bytes: 'BLOB', bytearray: 'BLOB', bool: 'BOOLEAN', float: 'DOUBLE'} + mapping = {str: "VARCHAR", int: "BIGINT", bytes: "BLOB", bytearray: "BLOB", bool: "BOOLEAN", float: "DOUBLE"} for duckdb_type, expected in mapping.items(): res = duckdb.list_type(duckdb_type) assert str(res.child) == expected - res = duckdb.list_type({'a': str, 'b': int}) - assert str(res.child) == 'STRUCT(a VARCHAR, b BIGINT)' + res = duckdb.list_type({"a": str, "b": int}) + assert str(res.child) == "STRUCT(a VARCHAR, b BIGINT)" res = duckdb.list_type(dict[str, int]) - assert str(res.child) == 'MAP(VARCHAR, BIGINT)' + assert str(res.child) == "MAP(VARCHAR, BIGINT)" res = duckdb.list_type(list[str]) - assert str(res.child) == 'VARCHAR[]' + assert str(res.child) == "VARCHAR[]" res = duckdb.list_type(list[dict[str, dict[list[str], str]]]) - assert str(res.child) == 'MAP(VARCHAR, MAP(VARCHAR[], VARCHAR))[]' + assert str(res.child) == "MAP(VARCHAR, MAP(VARCHAR[], VARCHAR))[]" res = duckdb.list_type(list[Union[str, int]]) - assert str(res.child) == 'UNION(u1 VARCHAR, u2 BIGINT)[]' + assert str(res.child) == "UNION(u1 VARCHAR, u2 BIGINT)[]" def test_implicit_convert_from_numpy(self, duckdb_cursor): np = pytest.importorskip("numpy") type_mapping = { - 'bool': 'BOOLEAN', - 'int8': 'TINYINT', - 'uint8': 'UTINYINT', - 'int16': 'SMALLINT', - 'uint16': 'USMALLINT', - 'int32': 'INTEGER', - 'uint32': 'UINTEGER', - 'int64': 'BIGINT', - 'uint64': 'UBIGINT', - 'float16': 'FLOAT', - 'float32': 'FLOAT', - 'float64': 'DOUBLE', + "bool": "BOOLEAN", + "int8": "TINYINT", + "uint8": "UTINYINT", + "int16": "SMALLINT", + "uint16": "USMALLINT", + "int32": "INTEGER", + "uint32": "UINTEGER", + "int64": "BIGINT", + "uint64": "UBIGINT", + "float16": "FLOAT", + "float32": "FLOAT", + "float64": "DOUBLE", } builtins = [] @@ -189,58 +186,72 @@ def test_implicit_convert_from_numpy(self, duckdb_cursor): def test_attribute_accessor(self): type = duckdb.row_type([BIGINT, duckdb.list_type(duckdb.map_type(BLOB, BIT))]) - assert hasattr(type, 'a') == False - assert hasattr(type, 'v1') == True + assert not hasattr(type, "a") + assert hasattr(type, "v1") - field_one = type['v1'] - assert str(field_one) == 'BIGINT' + field_one = type["v1"] + assert str(field_one) == "BIGINT" field_one = type.v1 - assert str(field_one) == 'BIGINT' + assert str(field_one) == "BIGINT" - field_two = type['v2'] - assert str(field_two) == 'MAP(BLOB, BIT)[]' + 
field_two = type["v2"] + assert str(field_two) == "MAP(BLOB, BIT)[]" child_type = type.v2.child - assert str(child_type) == 'MAP(BLOB, BIT)' + assert str(child_type) == "MAP(BLOB, BIT)" def test_json_type(self): - json_type = duckdb.type('JSON') + json_type = duckdb.type("JSON") val = duckdb.Value('{"duck": 42}', json_type) res = duckdb.execute("select typeof($1)", [val]).fetchone() - assert res == ('JSON',) + assert res == ("JSON",) def test_struct_from_dict(self): - res = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) - assert res == 'STRUCT(a VARCHAR, b VARCHAR)[]' + res = duckdb.list_type({"a": VARCHAR, "b": VARCHAR}) + assert res == "STRUCT(a VARCHAR, b VARCHAR)[]" + + def test_hash_method(self): + type1 = duckdb.list_type({"a": VARCHAR, "b": VARCHAR}) + type2 = duckdb.list_type({"b": VARCHAR, "a": VARCHAR}) + type3 = VARCHAR + + type_set = set() + type_set.add(type1) + type_set.add(type2) + type_set.add(type3) + + type_set.add(type1) + expected = ["STRUCT(a VARCHAR, b VARCHAR)[]", "STRUCT(b VARCHAR, a VARCHAR)[]", "VARCHAR"] + assert sorted([str(x) for x in list(type_set)]) == expected # NOTE: we can support this, but I don't think going through hoops for an outdated version of python is worth it @pytest.mark.skipif(sys.version_info < (3, 9), reason="python3.7 does not store Optional[..] in a recognized way") def test_optional(self): type = duckdb.typing.DuckDBPyType(Optional[str]) - assert type == 'VARCHAR' + assert type == "VARCHAR" type = duckdb.typing.DuckDBPyType(Optional[Union[int, bool]]) - assert type == 'UNION(u1 BIGINT, u2 BOOLEAN)' + assert type == "UNION(u1 BIGINT, u2 BOOLEAN)" type = duckdb.typing.DuckDBPyType(Optional[list[int]]) - assert type == 'BIGINT[]' + assert type == "BIGINT[]" type = duckdb.typing.DuckDBPyType(Optional[dict[int, str]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[Union[Optional[str], Optional[bool]]]) - assert type == 'UNION(u1 VARCHAR, u2 BOOLEAN)' + assert type == "UNION(u1 VARCHAR, u2 BOOLEAN)" type = duckdb.typing.DuckDBPyType(Union[str, None]) - assert type == 'VARCHAR' + assert type == "VARCHAR" @pytest.mark.skipif(sys.version_info < (3, 10), reason="'str | None' syntax requires Python 3.10 or higher") def test_optional_310(self): type = duckdb.typing.DuckDBPyType(str | None) - assert type == 'VARCHAR' + assert type == "VARCHAR" def test_children_attribute(self): - assert DuckDBPyType('INTEGER[]').children == [('child', DuckDBPyType('INTEGER'))] - assert DuckDBPyType('INTEGER[2]').children == [('child', DuckDBPyType('INTEGER')), ('size', 2)] - assert DuckDBPyType('INTEGER[2][3]').children == [('child', DuckDBPyType('INTEGER[2]')), ('size', 3)] - assert DuckDBPyType("ENUM('a', 'b', 'c')").children == [('values', ['a', 'b', 'c'])] + assert DuckDBPyType("INTEGER[]").children == [("child", DuckDBPyType("INTEGER"))] + assert DuckDBPyType("INTEGER[2]").children == [("child", DuckDBPyType("INTEGER")), ("size", 2)] + assert DuckDBPyType("INTEGER[2][3]").children == [("child", DuckDBPyType("INTEGER[2]")), ("size", 3)] + assert DuckDBPyType("ENUM('a', 'b', 'c')").children == [("values", ["a", "b", "c"])] diff --git a/tests/fast/test_type_explicit.py 
b/tests/fast/test_type_explicit.py index 23dcddc3..3b9fe334 100644 --- a/tests/fast/test_type_explicit.py +++ b/tests/fast/test_type_explicit.py @@ -1,20 +1,19 @@ import duckdb -class TestMap(object): - +class TestMap: def test_array_list_tuple_ambiguity(self): con = duckdb.connect() - res = con.sql("SELECT $arg", params={'arg': (1, 2)}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": (1, 2)}).fetchall()[0][0] assert res == [1, 2] # By using an explicit duckdb.Value with an array type, we should convert the input as an array # and get an array (tuple) back typ = duckdb.array_type(duckdb.typing.BIGINT, 2) val = duckdb.Value((1, 2), typ) - res = con.sql("SELECT $arg", params={'arg': val}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": val}).fetchall()[0][0] assert res == (1, 2) val = duckdb.Value([3, 4], typ) - res = con.sql("SELECT $arg", params={'arg': val}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": val}).fetchall()[0][0] assert res == (3, 4) diff --git a/tests/fast/test_unicode.py b/tests/fast/test_unicode.py index b697f84a..f1ed8501 100644 --- a/tests/fast/test_unicode.py +++ b/tests/fast/test_unicode.py @@ -1,13 +1,13 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb import pandas as pd +import duckdb + -class TestUnicode(object): +class TestUnicode: def test_unicode_pandas_scan(self, duckdb_cursor): - con = duckdb.connect(database=':memory:', read_only=False) - test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", u"ë"]}) - con.register('test_df_view', test_df) - con.execute('SELECT i, j, LENGTH(j) FROM test_df_view').fetchall() + con = duckdb.connect(database=":memory:", read_only=False) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", "ë"]}) + con.register("test_df_view", test_df) + con.execute("SELECT i, j, LENGTH(j) FROM test_df_view").fetchall() diff --git a/tests/fast/test_union.py b/tests/fast/test_union.py index 912caff9..2cd096d7 100644 --- a/tests/fast/test_union.py +++ b/tests/fast/test_union.py @@ -1,8 +1,7 @@ import duckdb -import pandas as pd -class TestUnion(object): +class TestUnion: def test_union_by_all(self): connection = duckdb.connect() @@ -44,8 +43,8 @@ def test_union_by_all(self): (13, 14, 15, 16, 17), ] - df_1 = connection.execute("FROM tbl1").df() - df_2 = connection.execute("FROM tbl2").df() + df_1 = connection.execute("FROM tbl1").df() # noqa: F841 + df_2 = connection.execute("FROM tbl2").df() # noqa: F841 query = """ select diff --git a/tests/fast/test_value.py b/tests/fast/test_value.py index 4f74516c..58aa7a4d 100644 --- a/tests/fast/test_value.py +++ b/tests/fast/test_value.py @@ -1,77 +1,70 @@ -import duckdb -from pytest import raises -from duckdb import NotImplementedException, InvalidInputException -from duckdb.value.constant import ( - Value, - NullValue, - BooleanValue, - UnsignedBinaryValue, - UnsignedShortValue, - UnsignedIntegerValue, - UnsignedLongValue, - BinaryValue, - ShortValue, - IntegerValue, - LongValue, - HugeIntegerValue, - UnsignedHugeIntegerValue, - FloatValue, - DoubleValue, - DecimalValue, - StringValue, - UUIDValue, - BitValue, - BlobValue, - DateValue, - IntervalValue, - TimestampValue, - TimestampSecondValue, - TimestampMilisecondValue, - TimestampNanosecondValue, - TimestampTimeZoneValue, - TimeValue, - TimeTimeZoneValue, -) -import uuid import datetime -import pytest import decimal +import uuid + +import pytest +import duckdb +from duckdb import InvalidInputException, NotImplementedException from duckdb.typing import ( - SQLNULL, + 
BIGINT, + BIT, + BLOB, BOOLEAN, - TINYINT, - UTINYINT, - SMALLINT, - USMALLINT, + DATE, + DOUBLE, + FLOAT, + HUGEINT, INTEGER, - UINTEGER, - BIGINT, + INTERVAL, + SMALLINT, + SQLNULL, + TIME, + TIMESTAMP, + TINYINT, UBIGINT, - HUGEINT, UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, UUID, - FLOAT, - DOUBLE, - DATE, - TIMESTAMP, - TIMESTAMP_MS, - TIMESTAMP_NS, - TIMESTAMP_S, - TIME, - TIME_TZ, - TIMESTAMP_TZ, VARCHAR, - BLOB, - BIT, - INTERVAL, +) +from duckdb.value.constant import ( + BinaryValue, + BitValue, + BlobValue, + BooleanValue, + DateValue, + DecimalValue, + DoubleValue, + FloatValue, + HugeIntegerValue, + IntegerValue, + IntervalValue, + LongValue, + NullValue, + ShortValue, + StringValue, + TimestampMilisecondValue, + TimestampNanosecondValue, + TimestampSecondValue, + TimestampValue, + TimeValue, + UnsignedBinaryValue, + UnsignedHugeIntegerValue, + UnsignedIntegerValue, + UnsignedLongValue, + UnsignedShortValue, + UUIDValue, + Value, ) -class TestValue(object): +class TestValue: # This excludes timezone aware values, as those are a pain to test @pytest.mark.parametrize( - 'item', + "item", [ (BOOLEAN, BooleanValue(True), True), (UTINYINT, UnsignedBinaryValue(129), 129), @@ -88,17 +81,17 @@ class TestValue(object): (DOUBLE, DoubleValue(0.23234234234), 0.23234234234), ( duckdb.decimal_type(12, 8), - DecimalValue(decimal.Decimal('1234.12345678'), 12, 8), - decimal.Decimal('1234.12345678'), + DecimalValue(decimal.Decimal("1234.12345678"), 12, 8), + decimal.Decimal("1234.12345678"), ), - (VARCHAR, StringValue('this is a long string'), 'this is a long string'), + (VARCHAR, StringValue("this is a long string"), "this is a long string"), ( UUID, - UUIDValue(uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), - uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), + UUIDValue(uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")), + uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), ), - (BIT, BitValue(b'010101010101'), '010101010101'), - (BLOB, BlobValue(b'\x00\x00\x00a'), b'\x00\x00\x00a'), + (BIT, BitValue(b"010101010101"), "010101010101"), + (BLOB, BlobValue(b"\x00\x00\x00a"), b"\x00\x00\x00a"), (DATE, DateValue(datetime.date(2000, 5, 4)), datetime.date(2000, 5, 4)), (INTERVAL, IntervalValue(datetime.timedelta(days=5)), datetime.timedelta(days=5)), ( @@ -116,10 +109,10 @@ def test_value_helpers(self, item): expected_value = item[2] con = duckdb.connect() - observed_type = con.execute('select typeof(a) from (select $1) tbl(a)', [value_object]).fetchall()[0][0] + observed_type = con.execute("select typeof(a) from (select $1) tbl(a)", [value_object]).fetchall()[0][0] assert observed_type == str(expected_type) - con.execute('select $1', [value_object]) + con.execute("select $1", [value_object]) result = con.fetchone() result = result[0] assert result == expected_value @@ -129,10 +122,10 @@ def test_float_to_decimal_prevention(self): con = duckdb.connect() with pytest.raises(duckdb.ConversionException, match="Can't losslessly convert"): - con.execute('select $1', [value]).fetchall() + con.execute("select $1", [value]).fetchall() @pytest.mark.parametrize( - 'value', + "value", [ TimestampSecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), TimestampMilisecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), @@ -142,12 +135,12 @@ def test_float_to_decimal_prevention(self): def test_timestamp_sec_not_supported(self, value): con = duckdb.connect() with pytest.raises( - duckdb.NotImplementedException, match="Conversion from 'datetime' to type .* is not implemented yet" + 
duckdb.NotImplementedException, match=r"Conversion from 'datetime' to type .* is not implemented yet" ): - con.execute('select $1', [value]).fetchall() + con.execute("select $1", [value]).fetchall() @pytest.mark.parametrize( - 'target_type,test_value,expected_conversion_success', + ("target_type", "test_value", "expected_conversion_success"), [ (TINYINT, 0, True), (TINYINT, 255, False), @@ -187,11 +180,12 @@ def test_numeric_values(self, target_type, test_value, expected_conversion_succe value = Value(test_value, target_type) con = duckdb.connect() - work = lambda: con.execute('select typeof(a) from (select $1) tbl(a)', [value]).fetchall() + def work(): + return con.execute("select typeof(a) from (select $1) tbl(a)", [value]).fetchall() if expected_conversion_success: res = work() assert str(target_type) == res[0][0] else: - with raises((NotImplementedException, InvalidInputException)): + with pytest.raises((NotImplementedException, InvalidInputException)): work() diff --git a/tests/fast/test_version.py b/tests/fast/test_version.py index cdeb42b0..81f72855 100644 --- a/tests/fast/test_version.py +++ b/tests/fast/test_version.py @@ -1,6 +1,7 @@ -import duckdb import sys +import duckdb + def test_version(): assert duckdb.__version__ != "0.0.0" diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 7a3c7a68..d21052c8 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -1,24 +1,28 @@ -""" -Tests for duckdb_pytooling versioning functionality. -""" +"""Tests for duckdb_pytooling versioning functionality.""" + import os +import subprocess import unittest +from unittest.mock import MagicMock, patch import pytest -import subprocess -from unittest.mock import patch, MagicMock duckdb_packaging = pytest.importorskip("duckdb_packaging") -from duckdb_packaging._versioning import ( - parse_version, +from duckdb_packaging._versioning import ( # noqa: E402 format_version, - git_tag_to_pep440, - pep440_to_git_tag, get_current_version, get_git_describe, + git_tag_to_pep440, + parse_version, + pep440_to_git_tag, +) +from duckdb_packaging.setuptools_scm_version import ( # noqa: E402 + _bump_dev_version, + _tag_to_version, + forced_version_from_env, + version_scheme, ) -from duckdb_packaging.setuptools_scm_version import _bump_version, version_scheme, forced_version_from_env class TestVersionParsing(unittest.TestCase): @@ -106,29 +110,29 @@ class TestSetupToolsScmIntegration(unittest.TestCase): def test_bump_version_exact_tag(self): """Test bump_version with exact tag (distance=0, dirty=False).""" - assert _bump_version("1.2.3", 0, False) == "1.2.3" - assert _bump_version("1.2.3.post1", 0, False) == "1.2.3.post1" + assert _tag_to_version("1.2.3") == "1.2.3" + assert _tag_to_version("1.2.3.post1") == "1.2.3.post1" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_bump_version_with_distance(self): """Test bump_version with distance from tag.""" - assert _bump_version("1.2.3", 5, False) == "1.3.0.dev5" - + assert _bump_dev_version("1.2.3", 5) == "1.3.0.dev5" + # Post-release development - assert _bump_version("1.2.3.post1", 3, False) == "1.2.3.post2.dev3" + assert _bump_dev_version("1.2.3.post1", 3) == "1.2.3.post2.dev3" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '0'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "0"}) def test_bump_version_release_branch(self): """Test bump_version on bugfix branch.""" - assert _bump_version("1.2.3", 5, False) == 
"1.2.4.dev5" + assert _bump_dev_version("1.2.3", 5) == "1.2.4.dev5" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_bump_version_dirty(self): """Test bump_version with dirty working directory.""" - assert _bump_version("1.2.3", 0, True) == "1.3.0.dev0" - assert _bump_version("1.2.3.post1", 0, True) == "1.2.3.post2.dev0" + with pytest.raises(ValueError, match="Dev distance is 0, cannot bump version"): + _bump_dev_version("1.2.3", 0) - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_version_scheme_function(self): """Test the version_scheme function that setuptools_scm calls.""" # Mock setuptools_scm version object @@ -136,61 +140,60 @@ def test_version_scheme_function(self): mock_version.tag = "1.2.3" mock_version.distance = 5 mock_version.dirty = False - + result = version_scheme(mock_version) assert result == "1.3.0.dev5" def test_bump_version_invalid_format(self): """Test bump_version with invalid version format.""" - with pytest.raises(ValueError, match="Incorrect version format"): - _bump_version("invalid", 0, False) + with pytest.raises(ValueError, match="Invalid version format"): + _tag_to_version("invalid") + with pytest.raises(ValueError, match="Invalid version format"): + _bump_dev_version("invalid", 1) class TestGitOperations(unittest.TestCase): """Test git-related operations (mocked).""" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_success(self, mock_run): """Test successful current version retrieval.""" mock_run.return_value.stdout = "v1.2.3\n" mock_run.return_value.check = True - + result = get_current_version() assert result == "1.2.3" mock_run.assert_called_once_with( - ["git", "describe", "--tags", "--abbrev=0"], - capture_output=True, - text=True, - check=True + ["git", "describe", "--tags", "--abbrev=0"], capture_output=True, text=True, check=True ) - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_with_post_release(self, mock_run): """Test current version retrieval with post-release tag.""" mock_run.return_value.stdout = "v1.2.3-post1\n" mock_run.return_value.check = True - + result = get_current_version() assert result == "1.2.3.post1" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_no_tags(self, mock_run): """Test current version retrieval when no tags exist.""" mock_run.side_effect = subprocess.CalledProcessError(1, "git") - + result = get_current_version() assert result is None - @patch('subprocess.run') + @patch("subprocess.run") def test_get_git_describe_success(self, mock_run): """Test successful git describe.""" mock_run.return_value.stdout = "v1.2.3-5-g1234567\n" mock_run.return_value.check = True - + result = get_git_describe() assert result == "v1.2.3-5-g1234567" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_git_describe_no_tags(self, mock_run): """Test git describe when no tags exist.""" mock_run.side_effect = subprocess.CalledProcessError(1, "git") @@ -202,21 +205,21 @@ def test_get_git_describe_no_tags(self, mock_run): class TestEnvironmentVariableHandling(unittest.TestCase): """Test environment variable handling in setuptools_scm integration.""" - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'v1.2.3-5-g1234567'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "v1.2.3-5-g1234567"}) def test_override_git_describe_basic(self): """Test OVERRIDE_GIT_DESCRIBE with 
basic format.""" forced_version_from_env() # Check that the environment variable was processed - assert 'SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB' in os.environ + assert "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" in os.environ - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'v1.2.3-post1-3-g1234567'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "v1.2.3-post1-3-g1234567"}) def test_override_git_describe_post_release(self): """Test OVERRIDE_GIT_DESCRIBE with post-release format.""" forced_version_from_env() # Check that post-release was converted correctly - assert 'SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB' in os.environ + assert "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" in os.environ - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'invalid-format'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "invalid-format"}) def test_override_git_describe_invalid(self): """Test OVERRIDE_GIT_DESCRIBE with invalid format.""" with pytest.raises(ValueError, match="Invalid git describe override"): diff --git a/tests/fast/test_windows_abs_path.py b/tests/fast/test_windows_abs_path.py index bc9f05ec..c3a606bd 100644 --- a/tests/fast/test_windows_abs_path.py +++ b/tests/fast/test_windows_abs_path.py @@ -1,31 +1,32 @@ -import duckdb -import pytest -import os import shutil +import sys +from pathlib import Path + +import pytest + +import duckdb -class TestWindowsAbsPath(object): - def test_windows_path_accent(self): - if os.name != 'nt': - return - current_directory = os.getcwd() - test_dir = os.path.join(current_directory, 'tést') - if os.path.isdir(test_dir): +@pytest.mark.skipif(not sys.platform.startswith("win"), reason="Tests only run on Windows") +class TestWindowsAbsPath: + def test_windows_path_accent(self, monkeypatch): + test_dir = Path.cwd() / "tést" + if test_dir.exists(): shutil.rmtree(test_dir) - os.mkdir(test_dir) + test_dir.mkdir() - dbname = 'test.db' - dbpath = os.path.join(test_dir, dbname) - con = duckdb.connect(dbpath) + dbname = "test.db" + dbpath = test_dir / dbname + con = duckdb.connect(str(dbpath)) con.execute("CREATE OR REPLACE TABLE int AS SELECT * FROM range(10) t(i)") res = con.execute("SELECT COUNT(*) FROM int").fetchall() assert res[0][0] == 10 del res del con - os.chdir('tést') - dbpath = os.path.join('..', dbpath) - con = duckdb.connect(dbpath) + monkeypatch.chdir("tést") + rel_dbpath = Path("..") / dbpath + con = duckdb.connect(str(rel_dbpath)) res = con.execute("SELECT COUNT(*) FROM int").fetchall() assert res[0][0] == 10 del res @@ -37,33 +38,18 @@ def test_windows_path_accent(self): del res del con - os.chdir('..') - def test_windows_abs_path(self): - if os.name != 'nt': - return - current_directory = os.getcwd() - dbpath = os.path.join(current_directory, 'test.db') - con = duckdb.connect(dbpath) - con.execute("CREATE OR REPLACE TABLE int AS SELECT * FROM range(10) t(i)") - res = con.execute("SELECT COUNT(*) FROM int").fetchall() - assert res[0][0] == 10 - del res - del con - - assert dbpath[1] == ':' - # remove the drive separator and reconnect - dbpath = dbpath[2:] - con = duckdb.connect(dbpath) - res = con.execute("SELECT COUNT(*) FROM int").fetchall() - assert res[0][0] == 10 - del res - del con - - # forward slashes work as well - dbpath = dbpath.replace('\\', '/') - con = duckdb.connect(dbpath) - res = con.execute("SELECT COUNT(*) FROM int").fetchall() - assert res[0][0] == 10 - del res - del con + # setup paths to test with + dbpath = Path.cwd() / "test.db" + abspath = str(dbpath.resolve()) + assert abspath[1] == ":" + no_drive_path = 
abspath[2:] + fwd_slash_path = no_drive_path.replace("\\", "/") + + for testpath in (abspath, no_drive_path, fwd_slash_path): + con = duckdb.connect(testpath) + con.execute("CREATE OR REPLACE TABLE int AS SELECT * FROM range(10) t(i)") + res = con.execute("SELECT COUNT(*) FROM int").fetchall() + assert res[0][0] == 10 + del res + del con diff --git a/tests/fast/types/test_blob.py b/tests/fast/types/test_blob.py index 162859d2..74f7f0b8 100644 --- a/tests/fast/types/test_blob.py +++ b/tests/fast/types/test_blob.py @@ -1,13 +1,12 @@ -import duckdb import numpy -class TestBlob(object): +class TestBlob: def test_blob(self, duckdb_cursor): duckdb_cursor.execute("SELECT BLOB 'hello'") results = duckdb_cursor.fetchall() - assert results[0][0] == b'hello' + assert results[0][0] == b"hello" duckdb_cursor.execute("SELECT BLOB 'hello' AS a") results = duckdb_cursor.fetchnumpy() - assert results['a'] == numpy.array([b'hello'], dtype=object) + assert results["a"] == numpy.array([b"hello"], dtype=object) diff --git a/tests/fast/types/test_boolean.py b/tests/fast/types/test_boolean.py index 8e8d2147..b97415dd 100644 --- a/tests/fast/types/test_boolean.py +++ b/tests/fast/types/test_boolean.py @@ -1,9 +1,5 @@ -import duckdb -import numpy - - -class TestBoolean(object): +class TestBoolean: def test_bool(self, duckdb_cursor): duckdb_cursor.execute("SELECT TRUE") results = duckdb_cursor.fetchall() - assert results[0][0] == True + assert results[0][0] diff --git a/tests/fast/types/test_datetime_date.py b/tests/fast/types/test_datetime_date.py index 9efb6bd1..d1c3d30b 100644 --- a/tests/fast/types/test_datetime_date.py +++ b/tests/fast/types/test_datetime_date.py @@ -1,8 +1,9 @@ -import duckdb import datetime +import duckdb + -class TestDateTimeDate(object): +class TestDateTimeDate: def test_date_infinity(self): con = duckdb.connect() # Positive infinity diff --git a/tests/fast/types/test_datetime_datetime.py b/tests/fast/types/test_datetime_datetime.py index 08a9953d..c486f9c9 100644 --- a/tests/fast/types/test_datetime_datetime.py +++ b/tests/fast/types/test_datetime_datetime.py @@ -1,32 +1,34 @@ -import duckdb import datetime + import pytest +import duckdb + def create_query(positive, type): - inf = 'infinity' if positive else '-infinity' + inf = "infinity" if positive else "-infinity" return f""" select '{inf}'::{type} """ -class TestDateTimeDateTime(object): - @pytest.mark.parametrize('positive', [True, False]) +class TestDateTimeDateTime: + @pytest.mark.parametrize("positive", [True, False]) @pytest.mark.parametrize( - 'type', + "type", [ - 'TIMESTAMP', - 'TIMESTAMP_S', - 'TIMESTAMP_MS', - 'TIMESTAMP_NS', - 'TIMESTAMPTZ', - 'TIMESTAMP_US', + "TIMESTAMP", + "TIMESTAMP_S", + "TIMESTAMP_MS", + "TIMESTAMP_NS", + "TIMESTAMPTZ", + "TIMESTAMP_US", ], ) def test_timestamp_infinity(self, positive, type): con = duckdb.connect() - if type in ['TIMESTAMP_S', 'TIMESTAMP_MS', 'TIMESTAMP_NS']: + if type in ["TIMESTAMP_S", "TIMESTAMP_MS", "TIMESTAMP_NS"]: # Infinity (both positive and negative) is not supported for non-usecond timetamps return diff --git a/tests/fast/types/test_decimal.py b/tests/fast/types/test_decimal.py index 30cb13e7..a5013dcd 100644 --- a/tests/fast/types/test_decimal.py +++ b/tests/fast/types/test_decimal.py @@ -1,26 +1,26 @@ +from decimal import Decimal + import numpy -import pandas -from decimal import * -class TestDecimal(object): +class TestDecimal: def test_decimal(self, duckdb_cursor): duckdb_cursor.execute( - 'SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 
49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL' + "SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL" # noqa: E501 ) result = duckdb_cursor.fetchall() assert result == [ - (Decimal('1.2'), Decimal('100.3'), Decimal('320938.4298'), Decimal('49082094824.904820482094'), None) + (Decimal("1.2"), Decimal("100.3"), Decimal("320938.4298"), Decimal("49082094824.904820482094"), None) ] def test_decimal_numpy(self, duckdb_cursor): duckdb_cursor.execute( - 'SELECT 1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d' + "SELECT 1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d" # noqa: E501 ) result = duckdb_cursor.fetchnumpy() assert result == { - 'a': numpy.array([1.2]), - 'b': numpy.array([100.3]), - 'c': numpy.array([320938.4298]), - 'd': numpy.array([49082094824.904820482094]), + "a": numpy.array([1.2]), + "b": numpy.array([100.3]), + "c": numpy.array([320938.4298]), + "d": numpy.array([49082094824.904820482094]), } diff --git a/tests/fast/types/test_hugeint.py b/tests/fast/types/test_hugeint.py index f0254380..aa8c900d 100644 --- a/tests/fast/types/test_hugeint.py +++ b/tests/fast/types/test_hugeint.py @@ -1,14 +1,13 @@ import numpy -import pandas -class TestHugeint(object): +class TestHugeint: def test_hugeint(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 437894723897234238947043214') + duckdb_cursor.execute("SELECT 437894723897234238947043214") result = duckdb_cursor.fetchall() assert result == [(437894723897234238947043214,)] def test_hugeint_numpy(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 1::HUGEINT AS i') + duckdb_cursor.execute("SELECT 1::HUGEINT AS i") result = duckdb_cursor.fetchnumpy() - assert result == {'i': numpy.array([1.0])} + assert result == {"i": numpy.array([1.0])} diff --git a/tests/fast/types/test_nan.py b/tests/fast/types/test_nan.py index b714ae6c..0d9e6122 100644 --- a/tests/fast/types/test_nan.py +++ b/tests/fast/types/test_nan.py @@ -1,12 +1,14 @@ -import numpy as np import datetime -import duckdb + +import numpy as np import pytest +import duckdb + pandas = pytest.importorskip("pandas") -class TestPandasNaN(object): +class TestPandasNaN: def test_pandas_nan(self, duckdb_cursor): # create a DataFrame with some basic values df = pandas.DataFrame([{"col1": "val1", "col2": 1.05}, {"col1": "val3", "col2": np.nan}]) @@ -15,34 +17,34 @@ def test_pandas_nan(self, duckdb_cursor): # now create a new column with the current time # (FIXME: we replace the microseconds with 0 for now, because we only support millisecond resolution) current_time = datetime.datetime.now().replace(microsecond=0) - df['datetest'] = current_time + df["datetest"] = current_time # introduce a NaT (Not a Time value) - df.loc[0, 'datetest'] = pandas.NaT + df.loc[0, "datetest"] = pandas.NaT # now pass the DF through duckdb: - conn = duckdb.connect(':memory:') - conn.register('testing_null_values', df) + conn = duckdb.connect(":memory:") + conn.register("testing_null_values", df) # scan the DF and fetch the results normally - results = conn.execute('select * from testing_null_values').fetchall() - assert results[0][0] == 'val1' + results = conn.execute("select * from testing_null_values").fetchall() + assert results[0][0] == "val1" assert results[0][1] == 1.05 - assert results[0][2] == None - assert results[0][3] == None - assert results[1][0] == 
'val3' - assert results[1][1] == None - assert results[1][2] == 'val3' + assert results[0][2] is None + assert results[0][3] is None + assert results[1][0] == "val3" + assert results[1][1] is None + assert results[1][2] == "val3" assert results[1][3] == current_time # now fetch the results as numpy: - result_np = conn.execute('select * from testing_null_values').fetchnumpy() - assert result_np['col1'][0] == df['col1'][0] - assert result_np['col1'][1] == df['col1'][1] - assert result_np['col2'][0] == df['col2'][0] - - assert result_np['col2'].mask[1] - assert result_np['newcol1'].mask[0] - assert result_np['newcol1'][1] == df['newcol1'][1] - - result_df = conn.execute('select * from testing_null_values').fetchdf() - assert pandas.isnull(result_df['datetest'][0]) - assert result_df['datetest'][1] == df['datetest'][1] + result_np = conn.execute("select * from testing_null_values").fetchnumpy() + assert result_np["col1"][0] == df["col1"][0] + assert result_np["col1"][1] == df["col1"][1] + assert result_np["col2"][0] == df["col2"][0] + + assert result_np["col2"].mask[1] + assert result_np["newcol1"].mask[0] + assert result_np["newcol1"][1] == df["newcol1"][1] + + result_df = conn.execute("select * from testing_null_values").fetchdf() + assert pandas.isnull(result_df["datetest"][0]) + assert result_df["datetest"][1] == df["datetest"][1] diff --git a/tests/fast/types/test_nested.py b/tests/fast/types/test_nested.py index e005b3f3..824b2825 100644 --- a/tests/fast/types/test_nested.py +++ b/tests/fast/types/test_nested.py @@ -1,7 +1,4 @@ -import duckdb - - -class TestNested(object): +class TestNested: def test_lists(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT LIST_VALUE(1, 2, 3, 4) ").fetchall() assert result == [([1, 2, 3, 4],)] @@ -23,24 +20,24 @@ def test_nested_lists(self, duckdb_cursor): def test_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := 43)").fetchall() - assert result == [({'a': 42, 'b': 43},)] + assert result == [({"a": 42, "b": 43},)] result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := NULL)").fetchall() - assert result == [({'a': 42, 'b': None},)] + assert result == [({"a": 42, "b": None},)] def test_unnamed_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT row('aa','bb') AS x").fetchall() - assert result == [(('aa', 'bb'),)] + assert result == [(("aa", "bb"),)] result = duckdb_cursor.execute("SELECT row('aa',NULL) AS x").fetchall() - assert result == [(('aa', None),)] + assert result == [(("aa", None),)] def test_nested_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, 7))").fetchall() - assert result == [({'a': 42, 'b': [10, 9, 8, 7]},)] + assert result == [({"a": 42, "b": [10, 9, 8, 7]},)] result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, NULL))").fetchall() - assert result == [({'a': 42, 'b': [10, 9, 8, None]},)] + assert result == [({"a": 42, "b": [10, 9, 8, None]},)] def test_map(self, duckdb_cursor): result = duckdb_cursor.execute("select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7))").fetchall() diff --git a/tests/fast/types/test_null.py b/tests/fast/types/test_null.py index fa4105b6..27f287c8 100644 --- a/tests/fast/types/test_null.py +++ b/tests/fast/types/test_null.py @@ -1,7 +1,4 @@ -import traceback - - -class TestNull(object): +class TestNull: def test_fetchone_null(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE atable (Value int)") 
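# --- Editorial sketch, not part of the patch ---------------------------------
# test_nan.py above registers a pandas DataFrame and asserts missing values with
# `is None` instead of `== None`. A condensed illustration of that behaviour
# (view name and column names are illustrative, taken from the test itself):
#
#   import duckdb
#   import numpy as np
#   import pandas as pd
#
#   con = duckdb.connect()
#   df = pd.DataFrame({"col1": ["val1", "val3"], "col2": [1.05, np.nan]})
#   con.register("testing_null_values", df)
#   rows = con.execute("select * from testing_null_values").fetchall()
#   assert rows[0] == ("val1", 1.05)
#   assert rows[1][1] is None  # NaN surfaces as None on the Python side
# ------------------------------------------------------------------------------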
duckdb_cursor.execute("INSERT INTO atable VALUES (1)") diff --git a/tests/fast/types/test_numeric.py b/tests/fast/types/test_numeric.py index f25b72b1..6540735d 100644 --- a/tests/fast/types/test_numeric.py +++ b/tests/fast/types/test_numeric.py @@ -1,14 +1,10 @@ -import duckdb -import numpy - - def check_result(duckdb_cursor, value, type): duckdb_cursor.execute("SELECT " + str(value) + "::" + type) results = duckdb_cursor.fetchall() assert results[0][0] == value -class TestNumeric(object): +class TestNumeric: def test_numeric_results(self, duckdb_cursor): check_result(duckdb_cursor, 1, "TINYINT") check_result(duckdb_cursor, 1, "SMALLINT") diff --git a/tests/fast/types/test_numpy.py b/tests/fast/types/test_numpy.py index 42ae33a0..b5fe6b3c 100644 --- a/tests/fast/types/test_numpy.py +++ b/tests/fast/types/test_numpy.py @@ -1,17 +1,18 @@ -import duckdb -import numpy as np import datetime -import pytest + +import numpy as np + +import duckdb -class TestNumpyDatetime64(object): +class TestNumpyDatetime64: def test_numpy_datetime64(self, duckdb_cursor): duckdb_con = duckdb.connect() duckdb_con.execute("create table tbl(col TIMESTAMP)") duckdb_con.execute( "insert into tbl VALUES (CAST(? AS TIMESTAMP WITHOUT TIME ZONE))", - parameters=[np.datetime64('2022-02-08T06:01:38.761310')], + parameters=[np.datetime64("2022-02-08T06:01:38.761310")], ) assert [(datetime.datetime(2022, 2, 8, 6, 1, 38, 761310),)] == duckdb_con.execute( "select * from tbl" @@ -24,11 +25,11 @@ def test_numpy_datetime_big(self): duckdb_con.execute("INSERT INTO TEST VALUES ('2263-02-28')") res1 = duckdb_con.execute("select * from test").fetchnumpy() - date_value = {'date': np.array(['2263-02-28'], dtype='datetime64[us]')} + date_value = {"date": np.array(["2263-02-28"], dtype="datetime64[us]")} assert res1 == date_value def test_numpy_enum_conversion(self, duckdb_cursor): - arr = np.array(['a', 'b', 'c']) + arr = np.array(["a", "b", "c"]) rel = duckdb_cursor.sql("select * from arr") - res = rel.fetchnumpy()['column0'] + res = rel.fetchnumpy()["column0"] np.testing.assert_equal(res, arr) diff --git a/tests/fast/types/test_object_int.py b/tests/fast/types/test_object_int.py index ce153d49..f0665535 100644 --- a/tests/fast/types/test_object_int.py +++ b/tests/fast/types/test_object_int.py @@ -1,30 +1,31 @@ -import numpy as np -import datetime -import duckdb -import pytest import warnings from contextlib import suppress +import numpy as np +import pytest + +import duckdb + -class TestPandasObjectInteger(object): +class TestPandasObjectInteger: # Signed Masked Integer types def test_object_integer(self, duckdb_cursor): pd = pytest.importorskip("pandas") df_in = pd.DataFrame( { - 'int8': pd.Series([None, 1, -1], dtype="Int8"), - 'int16': pd.Series([None, 1, -1], dtype="Int16"), - 'int32': pd.Series([None, 1, -1], dtype="Int32"), - 'int64': pd.Series([None, 1, -1], dtype="Int64"), + "int8": pd.Series([None, 1, -1], dtype="Int8"), + "int16": pd.Series([None, 1, -1], dtype="Int16"), + "int32": pd.Series([None, 1, -1], dtype="Int32"), + "int64": pd.Series([None, 1, -1], dtype="Int64"), } ) - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) df_expected_res = pd.DataFrame( { - 'int8': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int8'), - 'int16': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int16'), - 'int32': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int32'), - 
'int64': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int64'), + "int8": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int8"), + "int16": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int16"), + "int32": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int32"), + "int64": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int64"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() @@ -37,22 +38,22 @@ def test_object_uinteger(self, duckdb_cursor): with suppress(TypeError): df_in = pd.DataFrame( { - 'uint8': pd.Series([None, 1, 255], dtype="UInt8"), - 'uint16': pd.Series([None, 1, 65535], dtype="UInt16"), - 'uint32': pd.Series([None, 1, 4294967295], dtype="UInt32"), - 'uint64': pd.Series([None, 1, 18446744073709551615], dtype="UInt64"), + "uint8": pd.Series([None, 1, 255], dtype="UInt8"), + "uint16": pd.Series([None, 1, 65535], dtype="UInt16"), + "uint32": pd.Series([None, 1, 4294967295], dtype="UInt32"), + "uint64": pd.Series([None, 1, 18446744073709551615], dtype="UInt64"), } ) - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) df_expected_res = pd.DataFrame( { - 'uint8': pd.Series(np.ma.masked_array([0, 1, 255], mask=[True, False, False]), dtype='UInt8'), - 'uint16': pd.Series(np.ma.masked_array([0, 1, 65535], mask=[True, False, False]), dtype='UInt16'), - 'uint32': pd.Series( - np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype='UInt32' + "uint8": pd.Series(np.ma.masked_array([0, 1, 255], mask=[True, False, False]), dtype="UInt8"), + "uint16": pd.Series(np.ma.masked_array([0, 1, 65535], mask=[True, False, False]), dtype="UInt16"), + "uint32": pd.Series( + np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype="UInt32" ), - 'uint64': pd.Series( - np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype='UInt64' + "uint64": pd.Series( + np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype="UInt64" ), } ) @@ -63,20 +64,20 @@ def test_object_uinteger(self, duckdb_cursor): # Unsigned Masked float/double types def test_object_float(self, duckdb_cursor): # Require pandas 1.2.0 >= for this, because Float32|Float64 was not added before this version - pd = pytest.importorskip("pandas", '1.2.0') + pd = pytest.importorskip("pandas", "1.2.0") df_in = pd.DataFrame( { - 'float32': pd.Series([None, 1, 4294967295], dtype="Float32"), - 'float64': pd.Series([None, 1, 18446744073709551615], dtype="Float64"), + "float32": pd.Series([None, 1, 4294967295], dtype="Float32"), + "float64": pd.Series([None, 1, 18446744073709551615], dtype="Float64"), } ) df_expected_res = pd.DataFrame( { - 'float32': pd.Series( - np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype='float32' + "float32": pd.Series( + np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype="float32" ), - 'float64': pd.Series( - np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype='float64' + "float64": pd.Series( + np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype="float64" ), } ) diff --git a/tests/fast/types/test_time_tz.py b/tests/fast/types/test_time_tz.py index 66475df8..8173e110 100644 --- a/tests/fast/types/test_time_tz.py +++ b/tests/fast/types/test_time_tz.py @@ -1,17 +1,16 @@ -import numpy as np +import 
datetime from datetime import time, timezone -import duckdb + import pytest -import datetime pandas = pytest.importorskip("pandas") -class TestTimeTz(object): +class TestTimeTz: def test_time_tz(self, duckdb_cursor): - df = pandas.DataFrame({"col1": [time(1, 2, 3, tzinfo=timezone.utc)]}) + df = pandas.DataFrame({"col1": [time(1, 2, 3, tzinfo=timezone.utc)]}) # noqa: F841 - sql = f'SELECT * FROM df' + sql = "SELECT * FROM df" duckdb_cursor.execute(sql) diff --git a/tests/fast/types/test_unsigned.py b/tests/fast/types/test_unsigned.py index 6ac50727..5639d33b 100644 --- a/tests/fast/types/test_unsigned.py +++ b/tests/fast/types/test_unsigned.py @@ -1,7 +1,7 @@ -class TestUnsigned(object): +class TestUnsigned: def test_unsigned(self, duckdb_cursor): - duckdb_cursor.execute('create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)') - duckdb_cursor.execute('insert into unsigned values (1,1,1,1), (null,null,null,null)') - duckdb_cursor.execute('select * from unsigned order by a nulls first') + duckdb_cursor.execute("create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)") + duckdb_cursor.execute("insert into unsigned values (1,1,1,1), (null,null,null,null)") + duckdb_cursor.execute("select * from unsigned order by a nulls first") result = duckdb_cursor.fetchall() assert result == [(None, None, None, None), (1, 1, 1, 1)] diff --git a/tests/fast/udf/test_null_filtering.py b/tests/fast/udf/test_null_filtering.py index 208a9246..e5c0d546 100644 --- a/tests/fast/udf/test_null_filtering.py +++ b/tests/fast/udf/test_null_filtering.py @@ -1,17 +1,33 @@ -import duckdb +import datetime +import uuid +from typing import Any, NamedTuple + import pytest -pd = pytest.importorskip("pandas") -pa = pytest.importorskip('pyarrow', '18.0.0') -from typing import Union -import pyarrow.compute as pc -import uuid -import datetime -import numpy as np -import cmath -from typing import NamedTuple, Any, List +import duckdb +from duckdb.typing import ( + BIGINT, + BLOB, + BOOLEAN, + DATE, + DOUBLE, + FLOAT, + INTEGER, + INTERVAL, + SMALLINT, + TIME, + TIMESTAMP, + TINYINT, + UBIGINT, + UINTEGER, + USMALLINT, + UTINYINT, + UUID, + VARCHAR, +) -from duckdb.typing import * +pd = pytest.importorskip("pandas") +pa = pytest.importorskip("pyarrow", "18.0.0") class Candidate(NamedTuple): @@ -22,28 +38,28 @@ class Candidate(NamedTuple): def layout(index: int): return [ - ['x', 'x', 'y'], - ['x', None, 'y'], - [None, 'y', None], - ['x', None, None], - [None, None, 'y'], + ["x", "x", "y"], + ["x", None, "y"], + [None, "y", None], + ["x", None, None], + [None, None, "y"], [None, None, None], ][index] def get_table_data(): - def add_variations(data, index: int): + def add_variations(data, index: int) -> None: data.extend( [ { - 'a': layout(index), - 'b': layout(0), - 'c': layout(0), + "a": layout(index), + "b": layout(0), + "c": layout(0), }, { - 'a': layout(0), - 'b': layout(0), - 'c': layout(index), + "a": layout(0), + "b": layout(0), + "c": layout(index), }, ] ) @@ -83,9 +99,9 @@ def get_types(): 2147483647, ), Candidate(UBIGINT, 18446744073709551615, 9223372036854776000), - Candidate(VARCHAR, 'long_string_test', 'smallstring'), + Candidate(VARCHAR, "long_string_test", "smallstring"), Candidate( - UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), uuid.UUID('ffffffff-ffff-ffff-ffff-000000000000') + UUID, uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), uuid.UUID("ffffffff-ffff-ffff-ffff-000000000000") ), Candidate( FLOAT, @@ -106,8 +122,8 @@ def get_types(): ), Candidate( BLOB, - 
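# --- Editorial sketch, not part of the patch ---------------------------------
# test_time_tz.py keeps a seemingly unused `df` (and silences F841) because
# DuckDB's replacement scan resolves the name in SQL against local variables.
# A hedged, self-contained illustration of that pattern:
#
#   import duckdb
#   import pandas as pd
#
#   df = pd.DataFrame({"col1": [1, 2, 3]})  # referenced by name in the SQL below
#   con = duckdb.connect()
#   assert con.sql("select count(*) from df").fetchall() == [(3,)]
# ------------------------------------------------------------------------------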
b'\xf6\x96\xb0\x85', - b'\x85\xb0\x96\xf6', + b"\xf6\x96\xb0\x85", + b"\x85\xb0\x96\xf6", ), Candidate( INTERVAL, @@ -120,24 +136,24 @@ def get_types(): False, ), Candidate( - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}, - {'v1': [5, 4, 3, 2, 1], 'v2': ['non-inlined-string', 'a', 'b', 'c', 'duckdb']}, + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + {"v1": [1, 2, 3], "v2": ["a", "non-inlined string", "duckdb"]}, + {"v1": [5, 4, 3, 2, 1], "v2": ["non-inlined-string", "a", "b", "c", "duckdb"]}, ), - Candidate(duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string'], ['non-inlined-string', 'test']), + Candidate(duckdb.list_type("VARCHAR"), ["the", "duck", "non-inlined string"], ["non-inlined-string", "test"]), ] def construct_query(tuples) -> str: def construct_values_list(row, start_param_idx): parameter_count = len(row) - parameters = [f'${x+start_param_idx}' for x in range(parameter_count)] - parameters = '(' + ', '.join(parameters) + ')' + parameters = [f"${x + start_param_idx}" for x in range(parameter_count)] + parameters = "(" + ", ".join(parameters) + ")" return parameters row_size = len(tuples[0]) values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)] - values_list = ', '.join(values_list) + values_list = ", ".join(values_list) query = f""" select * from (values {values_list}) @@ -148,25 +164,25 @@ def construct_values_list(row, start_param_idx): def construct_parameters(tuples, dbtype): parameters = [] for row in tuples: - parameters.extend(list([duckdb.Value(x, dbtype) for x in row])) + parameters.extend([duckdb.Value(x, dbtype) for x in row]) return parameters -class TestUDFNullFiltering(object): +class TestUDFNullFiltering: @pytest.mark.parametrize( - 'table_data', + "table_data", get_table_data(), ) @pytest.mark.parametrize( - 'test_type', + "test_type", get_types(), ) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_null_filtering(self, duckdb_cursor, table_data: dict, test_type: Candidate, udf_type): - null_count = sum([1 for x in list(zip(*table_data.values())) if any([y == None for y in x])]) + null_count = sum([1 for x in list(zip(*table_data.values())) if any(y is None for y in x)]) row_count = len(table_data) table_data = { - key: [None if not x else test_type.variant_one if x == 'x' else test_type.variant_two for x in value] + key: [None if not x else test_type.variant_one if x == "x" else test_type.variant_two for x in value] for key, value in table_data.items() } @@ -174,26 +190,26 @@ def test_null_filtering(self, duckdb_cursor, table_data: dict, test_type: Candid query = construct_query(tuples) parameters = construct_parameters(tuples, test_type.type) rel = duckdb_cursor.sql(query + " t(a, b, c)", params=parameters) - rel.to_table('tbl') + rel.to_table("tbl") rel.show() def my_func(*args): - if udf_type == 'arrow': + if udf_type == "arrow": my_func.count += len(args[0]) else: my_func.count += 1 return args[0] def create_parameters(table_data, dbtype): - return ", ".join(f'{key}::{dbtype}' for key in list(table_data.keys())) + return ", ".join(f"{key}::{dbtype}" for key in list(table_data.keys())) my_func.count = 0 - duckdb_cursor.create_function('test', my_func, None, test_type.type, type=udf_type) + duckdb_cursor.create_function("test", my_func, None, test_type.type, type=udf_type) query = f"select test({create_parameters(table_data, test_type.type)}) from tbl" result = 
duckdb_cursor.sql(query).fetchall() expected_output = [ - (t[0],) if not any(x == None for x in t) else (None,) for t in list(zip(*table_data.values())) + (t[0],) if not any(x is None for x in t) else (None,) for t in list(zip(*table_data.values())) ] assert result == expected_output assert len(result) == row_count @@ -201,24 +217,24 @@ def create_parameters(table_data, dbtype): assert my_func.count == row_count - null_count @pytest.mark.parametrize( - 'table_data', + "table_data", [ [1, 2, 3, 4], [1, 2, None, 4], ], ) def test_nulls_from_default_null_handling_native(self, duckdb_cursor, table_data): - def returns_null(x): + def returns_null(x) -> None: return None - df = pd.DataFrame({'a': table_data}) + df = pd.DataFrame({"a": table_data}) # noqa: F841 duckdb_cursor.execute("create table tbl as select * from df") - duckdb_cursor.create_function('test', returns_null, [str], int, type='native') - with pytest.raises(duckdb.InvalidInputException, match='The UDF is not expected to return NULL values'): - result = duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() + duckdb_cursor.create_function("test", returns_null, [str], int, type="native") + with pytest.raises(duckdb.InvalidInputException, match="The UDF is not expected to return NULL values"): + duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() @pytest.mark.parametrize( - 'table_data', + "table_data", [ [1, 2, 3, 4], [1, 2, None, 4], @@ -226,12 +242,11 @@ def returns_null(x): ) def test_nulls_from_default_null_handling_arrow(self, duckdb_cursor, table_data): def returns_null(x): - l = x.to_pylist() - return pa.array([None for _ in l], type=pa.int64()) + lst = x.to_pylist() + return pa.array([None for _ in lst], type=pa.int64()) - df = pd.DataFrame({'a': table_data}) + df = pd.DataFrame({"a": table_data}) # noqa: F841 duckdb_cursor.execute("create table tbl as select * from df") - duckdb_cursor.create_function('test', returns_null, [str], int, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='The UDF is not expected to return NULL values'): - result = duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() - print(result) + duckdb_cursor.create_function("test", returns_null, [str], int, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="The UDF is not expected to return NULL values"): + duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index 15dd6b2b..7ced339a 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -1,79 +1,72 @@ -import duckdb -import os import pytest +import duckdb +from duckdb.typing import BIGINT, VARCHAR + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -from typing import Union -import pyarrow.compute as pc -import uuid -import datetime -import numpy as np -import cmath - -from duckdb.typing import * -class TestRemoveFunction(object): +class TestRemoveFunction: def test_not_created(self): con = duckdb.connect() with pytest.raises( duckdb.InvalidInputException, - match="No function by the name of 'not_a_registered_function' was found in the list of registered functions", + match="No function by the name of 'not_a_registered_function' was found in the list of " + "registered functions", ): - con.remove_function('not_a_registered_function') + con.remove_function("not_a_registered_function") def test_double_remove(self): def func(x: int) -> int: return x con = 
duckdb.connect() - con.create_function('func', func) - con.sql('select func(42)') - con.remove_function('func') + con.create_function("func", func) + con.sql("select func(42)") + con.remove_function("func") with pytest.raises( duckdb.InvalidInputException, match="No function by the name of 'func' was found in the list of registered functions", ): - con.remove_function('func') + con.remove_function("func") - with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): - con.sql('select func(42)') + with pytest.raises(duckdb.CatalogException, match="Scalar Function with name func does not exist!"): + con.sql("select func(42)") def test_use_after_remove(self): def func(x: int) -> int: return x con = duckdb.connect() - con.create_function('func', func) - rel = con.sql('select func(42)') - con.remove_function('func') + con.create_function("func", func) + rel = con.sql("select func(42)") + con.remove_function("func") """ Error: Catalog Error: Scalar Function with name func does not exist! """ - with pytest.raises( - duckdb.CatalogException, match='Scalar Function with name func does not exist!' - ): - res = rel.fetchall() + with pytest.raises(duckdb.CatalogException, match="Scalar Function with name func does not exist!"): + rel.fetchall() def test_use_after_remove_and_recreation(self): def func(x: str) -> str: return x con = duckdb.connect() - con.create_function('func', func) + con.create_function("func", func) + + with pytest.raises(duckdb.BinderException, match="No function matches the given name"): + con.sql("select func(42)") - with pytest.raises(duckdb.BinderException, match='No function matches the given name'): - rel1 = con.sql('select func(42)') rel2 = con.sql("select func('test'::VARCHAR)") - con.remove_function('func') + con.remove_function("func") def also_func(x: int) -> int: return x - con.create_function('func', also_func) - with pytest.raises(duckdb.BinderException, match='No function matches the given name'): - res = rel2.fetchall() + con.create_function("func", also_func) + with pytest.raises(duckdb.BinderException, match="No function matches the given name"): + rel2.fetchall() def test_overwrite_name(self): def func(x): @@ -81,7 +74,7 @@ def func(x): con = duckdb.connect() # create first version of the function - con.create_function('func', func, [BIGINT], BIGINT) + con.create_function("func", func, [BIGINT], BIGINT) # create relation that uses the function rel1 = con.sql("select func('3')") @@ -91,19 +84,20 @@ def other_func(x): with pytest.raises( duckdb.NotImplementedException, - match="A function by the name of 'func' is already created, creating multiple functions with the same name is not supported yet, please remove it first", + match="A function by the name of 'func' is already created, creating multiple functions with the " + "same name is not supported yet, please remove it first", ): - con.create_function('func', other_func, [VARCHAR], VARCHAR) + con.create_function("func", other_func, [VARCHAR], VARCHAR) - con.remove_function('func') + con.remove_function("func") with pytest.raises( - duckdb.CatalogException, match='Catalog Error: Scalar Function with name func does not exist!' + duckdb.CatalogException, match="Catalog Error: Scalar Function with name func does not exist!" 
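# --- Editorial sketch, not part of the patch ---------------------------------
# The registration lifecycle exercised by TestRemoveFunction, condensed:
# create a UDF from an annotated function, use it, remove it, and the name
# is gone again (function name below is hypothetical):
#
#   import duckdb
#
#   def identity(x: int) -> int:
#       return x
#
#   con = duckdb.connect()
#   con.create_function("identity_fn", identity)
#   assert con.sql("select identity_fn(42)").fetchall() == [(42,)]
#   con.remove_function("identity_fn")
#   try:
#       con.sql("select identity_fn(42)")
#   except duckdb.CatalogException:
#       pass  # the scalar function no longer exists, as the tests above expect
# ------------------------------------------------------------------------------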
): # Attempted to execute the relation using the 'func' function, but it was deleted rel1.fetchall() - con.create_function('func', other_func, [VARCHAR], VARCHAR) + con.create_function("func", other_func, [VARCHAR], VARCHAR) # create relation that uses the new version rel2 = con.sql("select func('test')") @@ -111,5 +105,5 @@ def other_func(x): res1 = rel1.fetchall() res2 = rel2.fetchall() # This has been converted to string, because the previous version of the function no longer exists - assert res1 == [('3',)] - assert res2 == [('test',)] + assert res1 == [("3",)] + assert res2 == [("test",)] diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index 61648c20..40e0d4de 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -1,17 +1,37 @@ -import duckdb -import os -import pytest - -pd = pytest.importorskip("pandas") -pa = pytest.importorskip('pyarrow', '18.0.0') -from typing import Union -import pyarrow.compute as pc -import uuid +import cmath import datetime +import uuid +from typing import Any, NoReturn + import numpy as np -import cmath +import pytest -from duckdb.typing import * +import duckdb +from duckdb.typing import ( + BIGINT, + BLOB, + BOOLEAN, + DATE, + DOUBLE, + FLOAT, + HUGEINT, + INTEGER, + INTERVAL, + SMALLINT, + TIME, + TIMESTAMP, + TINYINT, + UBIGINT, + UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, + UUID, + VARCHAR, +) + +pd = pytest.importorskip("pandas") +pa = pytest.importorskip("pyarrow", "18.0.0") def make_annotated_function(type): @@ -25,14 +45,14 @@ def test_base(x): test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__ ) # Add annotations for the return type and 'x' - test_function.__annotations__ = {'return': type, 'x': type} + test_function.__annotations__ = {"return": type, "x": type} return test_function -class TestScalarUDF(object): - @pytest.mark.parametrize('function_type', ['native', 'arrow']) +class TestScalarUDF: + @pytest.mark.parametrize("function_type", ["native", "arrow"]) @pytest.mark.parametrize( - 'test_type', + "test_type", [ (TINYINT, -42), (SMALLINT, -512), @@ -43,21 +63,21 @@ class TestScalarUDF(object): (UINTEGER, 4294967295), (UBIGINT, 18446744073709551615), (HUGEINT, 18446744073709551616), - (VARCHAR, 'long_string_test'), - (UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), + (VARCHAR, "long_string_test"), + (UUID, uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")), (FLOAT, 0.12246409803628922), (DOUBLE, 123142.12312416293784721232344), (DATE, datetime.date(2005, 3, 11)), (TIMESTAMP, datetime.datetime(2009, 2, 13, 11, 5, 53)), (TIME, datetime.time(14, 1, 12)), - (BLOB, b'\xf6\x96\xb0\x85'), + (BLOB, b"\xf6\x96\xb0\x85"), (INTERVAL, datetime.timedelta(days=30969, seconds=999, microseconds=999999)), (BOOLEAN, True), ( - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}, + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + {"v1": [1, 2, 3], "v2": ["a", "non-inlined string", "duckdb"]}, ), - (duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string']), + (duckdb.list_type("VARCHAR"), ["the", "duck", "non-inlined string"]), ], ) def test_type_coverage(self, test_type, function_type): @@ -67,18 +87,18 @@ def test_type_coverage(self, test_type, function_type): test_function = make_annotated_function(type) con = duckdb.connect() - con.create_function('test', test_function, type=function_type) + con.create_function("test", test_function, type=function_type) # Single value 
- res = con.execute(f"select test(?::{str(type)})", [value]).fetchall() + res = con.execute(f"select test(?::{type!s})", [value]).fetchall() assert res[0][0] == value # NULLs - res = con.execute(f"select res from (select ?, test(NULL::{str(type)}) as res)", [value]).fetchall() - assert res[0][0] == None + res = con.execute(f"select res from (select ?, test(NULL::{type!s}) as res)", [value]).fetchall() + assert res[0][0] is None # Multiple chunks size = duckdb.__standard_vector_size__ * 3 - res = con.execute(f"select test(x) from repeat(?::{str(type)}, {size}) as tbl(x)", [value]).fetchall() + res = con.execute(f"select test(x) from repeat(?::{type!s}, {size}) as tbl(x)", [value]).fetchall() assert len(res) == size # Mixed NULL/NON-NULL @@ -88,7 +108,7 @@ def test_type_coverage(self, test_type, function_type): f""" select test( case when (x > 0.5) then - ?::{str(type)} + ?::{type!s} else NULL end @@ -102,7 +122,7 @@ def test_type_coverage(self, test_type, function_type): f""" select case when (x > 0.5) then - ?::{str(type)} + ?::{type!s} else NULL end @@ -113,73 +133,74 @@ def test_type_coverage(self, test_type, function_type): assert expected == actual # Using 'relation.project' - con.execute(f"create table tbl as select ?::{str(type)} as x", [value]) - table_rel = con.table('tbl') - res = table_rel.project('test(x)').fetchall() + con.execute(f"create table tbl as select ?::{type!s} as x", [value]) + table_rel = con.table("tbl") + res = table_rel.project("test(x)").fetchall() assert res[0][0] == value - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_map_coverage(self, udf_type): def no_op(x): return x con = duckdb.connect() - map_type = con.map_type('VARCHAR', 'BIGINT') - con.create_function('test_map', no_op, [map_type], map_type, type=udf_type) + map_type = con.map_type("VARCHAR", "BIGINT") + con.create_function("test_map", no_op, [map_type], map_type, type=udf_type) rel = con.sql("select test_map(map(['non-inlined string', 'test', 'duckdb'], [42, 1337, 123]))") res = rel.fetchall() - assert res == [({'non-inlined string': 42, 'test': 1337, 'duckdb': 123},)] + assert res == [({"non-inlined string": 42, "test": 1337, "duckdb": 123},)] - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_exceptions(self, udf_type): - def raises_exception(x): - raise AttributeError("error") + def raises_exception(x) -> NoReturn: + msg = "error" + raise AttributeError(msg) con = duckdb.connect() - con.create_function('raises', raises_exception, [BIGINT], BIGINT, type=udf_type) + con.create_function("raises", raises_exception, [BIGINT], BIGINT, type=udf_type) with pytest.raises( duckdb.InvalidInputException, - match=' Python exception occurred while executing the UDF: AttributeError: error', + match=" Python exception occurred while executing the UDF: AttributeError: error", ): - res = con.sql('select raises(3)').fetchall() + res = con.sql("select raises(3)").fetchall() - con.remove_function('raises') + con.remove_function("raises") con.create_function( - 'raises', raises_exception, [BIGINT], BIGINT, exception_handling='return_null', type=udf_type + "raises", raises_exception, [BIGINT], BIGINT, exception_handling="return_null", type=udf_type ) - res = con.sql('select raises(3) from range(5)').fetchall() + res = con.sql("select raises(3) from range(5)").fetchall() assert res == [(None,), (None,), (None,), (None,), (None,)] def test_non_callable(self): con = 
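# --- Editorial sketch, not part of the patch ---------------------------------
# test_exceptions above shows the two exception-handling modes. A minimal
# illustration of exception_handling="return_null", where a raising UDF yields
# NULL output instead of failing the query (function name is hypothetical):
#
#   import duckdb
#   from duckdb.typing import BIGINT
#
#   def always_raises(x):
#       raise AttributeError("error")
#
#   con = duckdb.connect()
#   con.create_function("always_raises", always_raises, [BIGINT], BIGINT,
#                       exception_handling="return_null", type="native")
#   assert con.sql("select always_raises(3)").fetchall() == [(None,)]
# ------------------------------------------------------------------------------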
duckdb.connect() with pytest.raises(TypeError): - con.create_function('func', 5, [BIGINT], BIGINT, type='arrow') + con.create_function("func", 5, [BIGINT], BIGINT, type="arrow") class MyCallable: - def __init__(self): + def __init__(self) -> None: pass - def __call__(self, x): + def __call__(self, x: Any) -> Any: # noqa: ANN401 return x my_callable = MyCallable() - con.create_function('func', my_callable, [BIGINT], BIGINT, type='arrow') - res = con.sql('select func(5)').fetchall() + con.create_function("func", my_callable, [BIGINT], BIGINT, type="arrow") + res = con.sql("select func(5)").fetchall() assert res == [(5,)] # pyarrow does not support creating an array filled with pd.NA values - @pytest.mark.parametrize('udf_type', ['native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_pd_nan(self, duckdb_type, udf_type): def return_pd_nan(): - if udf_type == 'native': + if udf_type == "native": return pd.NA con = duckdb.connect() - con.create_function('return_pd_nan', return_pd_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type) + con.create_function("return_pd_nan", return_pd_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type) - res = con.sql('select return_pd_nan()').fetchall() - assert res[0][0] == None + res = con.sql("select return_pd_nan()").fetchall() + assert res[0][0] is None def test_side_effects(self): def count() -> int: @@ -190,21 +211,21 @@ def count() -> int: count.counter = 0 con = duckdb.connect() - con.create_function('my_counter', count, side_effects=False) - res = con.sql('select my_counter() from range(10)').fetchall() + con.create_function("my_counter", count, side_effects=False) + res = con.sql("select my_counter() from range(10)").fetchall() assert res == [(0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,)] count.counter = 0 - con.remove_function('my_counter') - con.create_function('my_counter', count, side_effects=True) - res = con.sql('select my_counter() from range(10)').fetchall() + con.remove_function("my_counter") + con.create_function("my_counter", count, side_effects=True) + res = con.sql("select my_counter() from range(10)").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)] - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_np_nan(self, duckdb_type, udf_type): def return_np_nan(): - if udf_type == 'native': + if udf_type == "native": return np.nan else: import pyarrow as pa @@ -212,18 +233,16 @@ def return_np_nan(): return pa.chunked_array([[np.nan]], type=pa.float64()) con = duckdb.connect() - con.create_function('return_np_nan', return_np_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type) + con.create_function("return_np_nan", return_np_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type) - res = con.sql('select return_np_nan()').fetchall() + res = con.sql("select return_np_nan()").fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): - import cmath - - if 
udf_type == 'native': + if udf_type == "native": return cmath.nan else: import pyarrow as pa @@ -232,15 +251,15 @@ def return_math_nan(): con = duckdb.connect() con.create_function( - 'return_math_nan', return_math_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type + "return_math_nan", return_math_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type ) - res = con.sql('select return_math_nan()').fetchall() + res = con.sql("select return_math_nan()").fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) @pytest.mark.parametrize( - 'data_type', + "data_type", [ TINYINT, SMALLINT, @@ -262,13 +281,13 @@ def return_math_nan(): BLOB, INTERVAL, BOOLEAN, - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - duckdb.list_type('VARCHAR'), + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + duckdb.list_type("VARCHAR"), ], ) def test_return_null(self, data_type, udf_type): def return_null(): - if udf_type == 'native': + if udf_type == "native": return None else: import pyarrow as pa @@ -276,23 +295,23 @@ def return_null(): return pa.nulls(1) con = duckdb.connect() - con.create_function('return_null', return_null, None, data_type, null_handling='special', type=udf_type) - rel = con.sql('select return_null() as x') + con.create_function("return_null", return_null, None, data_type, null_handling="special", type=udf_type) + rel = con.sql("select return_null() as x") assert rel.types[0] == data_type - assert rel.fetchall()[0][0] == None + assert rel.fetchall()[0][0] is None def test_udf_transaction_interaction(self): def func(x: int) -> int: return x con = duckdb.connect() - rel = con.sql('select 42') + rel = con.sql("select 42") # Using fetchone keeps the result open, with a transaction rel.fetchone() - con.create_function('func', func) + con.create_function("func", func) rel.fetchall() - res = con.sql('select func(5)').fetchall() + res = con.sql("select func(5)").fetchall() assert res == [(5,)] diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 5773c474..46f932a1 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -1,49 +1,44 @@ -import duckdb -import os import pytest +import duckdb +from duckdb.typing import BIGINT, INTEGER, VARCHAR + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -from typing import Union -import pyarrow.compute as pc -import uuid -import datetime -from duckdb.typing import * - -class TestPyArrowUDF(object): +class TestPyArrowUDF: def test_basic_use(self): def plus_one(x): - table = pa.lib.Table.from_arrays([x], names=['c0']) + pa.lib.Table.from_arrays([x], names=["c0"]) import pandas as pd df = pd.DataFrame(x.to_pandas()) - df['c0'] = df['c0'] + 1 + df["c0"] = df["c0"] + 1 return pa.lib.Table.from_pandas(df) con = duckdb.connect() - con.create_function('plus_one', plus_one, [BIGINT], BIGINT, type='arrow') - assert [(6,)] == con.sql('select plus_one(5)').fetchall() + con.create_function("plus_one", plus_one, [BIGINT], BIGINT, type="arrow") + assert con.sql("select plus_one(5)").fetchall() == [(6,)] - range_table = con.table_function('range', [5000]) - res = con.sql('select plus_one(i) from range_table tbl(i)').fetchall() + range_table = con.table_function("range", [5000]) # noqa: F841 + res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() assert len(res) == 5000 vector_size = duckdb.__standard_vector_size__ - res = con.sql(f'select i, 
plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})') + res = con.sql(f"select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})") assert len(res) == (vector_size * 11) # NOTE: This only works up to duckdb.__standard_vector_size__, # because we process up to STANDARD_VECTOR_SIZE tuples at a time def test_sort_table(self): def sort_table(x): - table = pa.lib.Table.from_arrays([x], names=['c0']) + table = pa.lib.Table.from_arrays([x], names=["c0"]) sorted_table = table.sort_by([("c0", "ascending")]) return sorted_table con = duckdb.connect() - con.create_function('sort_table', sort_table, [BIGINT], BIGINT, type='arrow') + con.create_function("sort_table", sort_table, [BIGINT], BIGINT, type="arrow") res = con.sql("select 100-i as original, sort_table(original) from range(100) tbl(i)").fetchall() assert res[0] == (100, 1) @@ -51,13 +46,14 @@ def test_varargs(self): def variable_args(*args): # We return a chunked array here, but internally we convert this into a Table if len(args) == 0: - raise ValueError("Expected at least one argument") + msg = "Expected at least one argument" + raise ValueError(msg) for item in args: return item con = duckdb.connect() # This function takes any number of arguments, returning the first column - con.create_function('varargs', variable_args, None, BIGINT, type='arrow') + con.create_function("varargs", variable_args, None, BIGINT, type="arrow") res = con.sql("""select varargs(5, '3', '2', 1, 0.12345)""").fetchall() assert res == [(5,)] @@ -70,7 +66,7 @@ def takes_string(col): con = duckdb.connect() # The return type of the function is set to BIGINT, but it takes a VARCHAR - con.create_function('pyarrow_string_to_num', takes_string, [VARCHAR], BIGINT, type='arrow') + con.create_function("pyarrow_string_to_num", takes_string, [VARCHAR], BIGINT, type="arrow") # Successful conversion res = con.sql("""select pyarrow_string_to_num('5')""").fetchall() @@ -84,51 +80,49 @@ def returns_two_columns(col): import pandas as pd # Return a pyarrow table consisting of two columns - return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5, 4, 3], 'b': ['test', 'quack', 'duckdb']})) + return pa.lib.Table.from_pandas(pd.DataFrame({"a": [5, 4, 3], "b": ["test", "quack", "duckdb"]})) con = duckdb.connect() # Scalar functions only return a single value per tuple - con.create_function('two_columns', returns_two_columns, [BIGINT], BIGINT, type='arrow') + con.create_function("two_columns", returns_two_columns, [BIGINT], BIGINT, type="arrow") with pytest.raises( duckdb.InvalidInputException, - match='The returned table from a pyarrow scalar udf should only contain one column, found 2', + match="The returned table from a pyarrow scalar udf should only contain one column, found 2", ): - res = con.sql("""select two_columns(5)""").fetchall() + con.sql("""select two_columns(5)""").fetchall() def test_return_none(self): - def returns_none(col): + def returns_none(col) -> None: return None con = duckdb.connect() - con.create_function('will_crash', returns_none, [BIGINT], BIGINT, type='arrow') + con.create_function("will_crash", returns_none, [BIGINT], BIGINT, type="arrow") with pytest.raises(duckdb.Error, match="""Could not convert the result into an Arrow Table"""): - res = con.sql("""select will_crash(5)""").fetchall() + con.sql("""select will_crash(5)""").fetchall() def test_empty_result(self): def return_empty(col): # Always returns an empty table - return pa.lib.Table.from_arrays([[]], names=['c0']) + return 
pa.lib.Table.from_arrays([[]], names=["c0"]) con = duckdb.connect() - con.create_function('empty_result', return_empty, [BIGINT], BIGINT, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='Returned pyarrow table should have 1 tuples, found 0'): - res = con.sql("""select empty_result(5)""").fetchall() + con.create_function("empty_result", return_empty, [BIGINT], BIGINT, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="Returned pyarrow table should have 1 tuples, found 0"): + con.sql("""select empty_result(5)""").fetchall() def test_excessive_result(self): def return_too_many(col): # Always returns a table consisting of 5 tuples - return pa.lib.Table.from_arrays([[5, 4, 3, 2, 1]], names=['c0']) + return pa.lib.Table.from_arrays([[5, 4, 3, 2, 1]], names=["c0"]) con = duckdb.connect() - con.create_function('too_many_tuples', return_too_many, [BIGINT], BIGINT, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='Returned pyarrow table should have 1 tuples, found 5'): - res = con.sql("""select too_many_tuples(5)""").fetchall() + con.create_function("too_many_tuples", return_too_many, [BIGINT], BIGINT, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="Returned pyarrow table should have 1 tuples, found 5"): + con.sql("""select too_many_tuples(5)""").fetchall() def test_arrow_side_effects(self, duckdb_cursor): - import random as r - def random_arrow(x): - if not hasattr(random_arrow, 'data'): + if not hasattr(random_arrow, "data"): random_arrow.data = 0 input = x.to_pylist() @@ -158,17 +152,17 @@ def return_struct(col): ).fetch_arrow_table() con = duckdb.connect() - struct_type = con.struct_type({'a': BIGINT, 'b': VARCHAR, 'c': con.list_type(BIGINT)}) - con.create_function('return_struct', return_struct, [BIGINT], struct_type, type='arrow') + struct_type = con.struct_type({"a": BIGINT, "b": VARCHAR, "c": con.list_type(BIGINT)}) + con.create_function("return_struct", return_struct, [BIGINT], struct_type, type="arrow") res = con.sql("""select return_struct(5)""").fetchall() - assert res == [({'a': 5, 'b': 'test', 'c': [5, 3, 2]},)] + assert res == [({"a": 5, "b": "test", "c": [5, 3, 2]},)] def test_multiple_chunks(self): def return_unmodified(col): return col con = duckdb.connect() - con.create_function('unmodified', return_unmodified, [BIGINT], BIGINT, type='arrow') + con.create_function("unmodified", return_unmodified, [BIGINT], BIGINT, type="arrow") res = con.sql( """ select unmodified(i) from range(5000) tbl(i) @@ -176,19 +170,19 @@ def return_unmodified(col): ).fetchall() assert len(res) == 5000 - assert res == con.sql('select * from range(5000)').fetchall() + assert res == con.sql("select * from range(5000)").fetchall() def test_inferred(self): def func(x: int) -> int: import pandas as pd - df = pd.DataFrame({'c0': x}) - df['c0'] = df['c0'] ** 2 + df = pd.DataFrame({"c0": x}) + df["c0"] = df["c0"] ** 2 return pa.lib.Table.from_pandas(df) con = duckdb.connect() - con.create_function('inferred', func, type='arrow') - res = con.sql('select inferred(42)').fetchall() + con.create_function("inferred", func, type="arrow") + res = con.sql("select inferred(42)").fetchall() assert res == [(1764,)] def test_nulls(self): @@ -196,27 +190,27 @@ def return_five(x): import pandas as pd length = len(x) - return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5 for _ in range(length)]})) + return pa.lib.Table.from_pandas(pd.DataFrame({"a": [5 for _ in range(length)]})) con = duckdb.connect() - con.create_function('return_five', 
return_five, [BIGINT], BIGINT, null_handling='special', type='arrow') - res = con.sql('select return_five(NULL) from range(10)').fetchall() + con.create_function("return_five", return_five, [BIGINT], BIGINT, null_handling="special", type="arrow") + res = con.sql("select return_five(NULL) from range(10)").fetchall() # without 'special' null handling these would all be NULL assert res == [(5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,)] con = duckdb.connect() - con.create_function('return_five', return_five, [BIGINT], BIGINT, null_handling='default', type='arrow') - res = con.sql('select return_five(NULL) from range(10)').fetchall() + con.create_function("return_five", return_five, [BIGINT], BIGINT, null_handling="default", type="arrow") + res = con.sql("select return_five(NULL) from range(10)").fetchall() # Because we didn't specify 'special' null handling, these are all NULL assert res == [(None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,)] def test_struct_with_non_inlined_string(self, duckdb_cursor): def func(data): - return pa.array([{'x': 1, 'y': 'this is not an inlined string'}] * data.length()) + return pa.array([{"x": 1, "y": "this is not an inlined string"}] * data.length()) duckdb_cursor.create_function( name="func", function=func, return_type="STRUCT(x integer, y varchar)", type="arrow", side_effects=False ) res = duckdb_cursor.sql("select func(1).y").fetchone() - assert res == ('this is not an inlined string',) + assert res == ("this is not an inlined string",) diff --git a/tests/fast/udf/test_scalar_native.py b/tests/fast/udf/test_scalar_native.py index df58f6a4..64ea5b5b 100644 --- a/tests/fast/udf/test_scalar_native.py +++ b/tests/fast/udf/test_scalar_native.py @@ -1,36 +1,46 @@ -import duckdb -import os -import pandas as pd import pytest -from duckdb.typing import * - - -class TestNativeUDF(object): +import duckdb +from duckdb.typing import ( + BIGINT, + HUGEINT, + INTEGER, + SMALLINT, + TINYINT, + UBIGINT, + UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, + VARCHAR, +) + + +class TestNativeUDF: def test_default_conn(self): def passthrough(x): return x - duckdb.create_function('default_conn_passthrough', passthrough, [BIGINT], BIGINT) - res = duckdb.sql('select default_conn_passthrough(5)').fetchall() + duckdb.create_function("default_conn_passthrough", passthrough, [BIGINT], BIGINT) + res = duckdb.sql("select default_conn_passthrough(5)").fetchall() assert res == [(5,)] def test_basic_use(self): def plus_one(x): - if x == None or x > 50: + if x is None or x > 50: return x return x + 1 con = duckdb.connect() - con.create_function('plus_one', plus_one, [BIGINT], BIGINT) - assert [(6,)] == con.sql('select plus_one(5)').fetchall() + con.create_function("plus_one", plus_one, [BIGINT], BIGINT) + assert con.sql("select plus_one(5)").fetchall() == [(6,)] - range_table = con.table_function('range', [5000]) - res = con.sql('select plus_one(i) from range_table tbl(i)').fetchall() + range_table = con.table_function("range", [5000]) # noqa: F841 + res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() assert len(res) == 5000 vector_size = duckdb.__standard_vector_size__ - res = con.sql(f'select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})') + res = con.sql(f"select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})") assert len(res) == (vector_size * 11) def test_passthrough(self): @@ -38,10 +48,10 @@ def passthrough(x): return x con = 
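# --- Editorial sketch, not part of the patch ---------------------------------
# A minimal 'arrow'-type scalar UDF: the function receives a pyarrow
# array/chunked array and must return an arrow array of equal length, so
# pyarrow.compute does the vectorised work (function name is hypothetical):
#
#   import duckdb
#   import pyarrow.compute as pc
#   from duckdb.typing import BIGINT
#
#   def add_one(col):
#       return pc.add(col, 1)
#
#   con = duckdb.connect()
#   con.create_function("add_one", add_one, [BIGINT], BIGINT, type="arrow")
#   assert con.sql("select add_one(41)").fetchall() == [(42,)]
# ------------------------------------------------------------------------------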
duckdb.connect() - con.create_function('passthrough', passthrough, [BIGINT], BIGINT) + con.create_function("passthrough", passthrough, [BIGINT], BIGINT) assert ( - con.sql('select passthrough(i) from range(5000) tbl(i)').fetchall() - == con.sql('select * from range(5000)').fetchall() + con.sql("select passthrough(i) from range(5000) tbl(i)").fetchall() + == con.sql("select * from range(5000)").fetchall() ) def test_execute(self): @@ -49,8 +59,8 @@ def func(x): return x % 2 con = duckdb.connect() - con.create_function('modulo_op', func, [BIGINT], TINYINT) - res = con.execute('select modulo_op(?)', [5]).fetchall() + con.create_function("modulo_op", func, [BIGINT], TINYINT) + res = con.execute("select modulo_op(?)", [5]).fetchall() assert res == [(1,)] def test_cast_output(self): @@ -58,7 +68,7 @@ def takes_string(x): return x con = duckdb.connect() - con.create_function('casts_from_string', takes_string, [VARCHAR], BIGINT) + con.create_function("casts_from_string", takes_string, [VARCHAR], BIGINT) res = con.sql("select casts_from_string('42')").fetchall() assert res == [(42,)] @@ -71,13 +81,13 @@ def concatenate(a: str, b: str): return a + b con = duckdb.connect() - con.create_function('py_concatenate', concatenate, None, VARCHAR) + con.create_function("py_concatenate", concatenate, None, VARCHAR) res = con.sql( """ select py_concatenate('5','3'); """ ).fetchall() - assert res[0][0] == '53' + assert res[0][0] == "53" def test_detected_return_type(self): def add_nums(*args) -> int: @@ -87,7 +97,7 @@ def add_nums(*args) -> int: return sum con = duckdb.connect() - con.create_function('add_nums', add_nums) + con.create_function("add_nums", add_nums) res = con.sql( """ select add_nums(5,3,2,1); @@ -101,34 +111,34 @@ def variable_args(*args): return amount con = duckdb.connect() - con.create_function('varargs', variable_args, None, BIGINT) + con.create_function("varargs", variable_args, None, BIGINT) res = con.sql("""select varargs('5', '3', '2', 1, 0.12345)""").fetchall() assert res == [(5,)] def test_return_incorrectly_typed_object(self): def returns_duckdb() -> int: - return 'duckdb' + return "duckdb" con = duckdb.connect() - con.create_function('fastest_database_in_the_west', returns_duckdb) + con.create_function("fastest_database_in_the_west", returns_duckdb) with pytest.raises( duckdb.InvalidInputException, match="Failed to cast value: Could not convert string 'duckdb' to INT64" ): - res = con.sql('select fastest_database_in_the_west()').fetchall() + con.sql("select fastest_database_in_the_west()").fetchall() def test_nulls(self): def five_if_null(x): - if x == None: + if x is None: return 5 return x con = duckdb.connect() - con.create_function('null_test', five_if_null, [BIGINT], BIGINT, null_handling="SPECIAL") - res = con.sql('select null_test(NULL)').fetchall() + con.create_function("null_test", five_if_null, [BIGINT], BIGINT, null_handling="SPECIAL") + res = con.sql("select null_test(NULL)").fetchall() assert res == [(5,)] @pytest.mark.parametrize( - 'pair', + "pair", [ (TINYINT, -129), (TINYINT, 128), @@ -159,26 +169,24 @@ def return_overflow(): return overflowing_value con = duckdb.connect() - con.create_function('return_overflow', return_overflow, None, duckdb_type) + con.create_function("return_overflow", return_overflow, None, duckdb_type) + rel = con.sql("select return_overflow()") with pytest.raises(duckdb.InvalidInputException): - rel = con.sql('select return_overflow()') - res = rel.fetchall() - print(duckdb_type) - print(res) + rel.fetchall() def test_structs(self): def 
add_extra_column(original): - original['a'] = 200 - original['c'] = 0 + original["a"] = 200 + original["c"] = 0 return original con = duckdb.connect() - range_table = con.table_function('range', [5000]) + range_table = con.table_function("range", [5000]) # noqa: F841 con.create_function( "append_field", add_extra_column, - [duckdb.struct_type({'a': BIGINT, 'b': BIGINT})], - duckdb.struct_type({'a': BIGINT, 'b': BIGINT, 'c': BIGINT}), + [duckdb.struct_type({"a": BIGINT, "b": BIGINT})], + duckdb.struct_type({"a": BIGINT, "b": BIGINT, "c": BIGINT}), ) res = con.sql( @@ -188,7 +196,8 @@ def add_extra_column(original): ) # added extra column to the struct assert len(res.fetchone()[0].keys()) == 3 - # FIXME: this is needed, otherwise the old transaction is still active when we try to start a new transaction inside of 'create_function', which means the call would fail + # TODO: this is needed, otherwise the old transaction is still active when we try # noqa: TD002, TD003 + # to start a new transaction inside of 'create_function', which means the call would fail res.fetchall() def swap_keys(dict): @@ -205,17 +214,17 @@ def swap_keys(dict): return result con.create_function( - 'swap_keys', + "swap_keys", swap_keys, - [con.struct_type({'a': BIGINT, 'b': VARCHAR})], - con.struct_type({'a': VARCHAR, 'b': BIGINT}), + [con.struct_type({"a": BIGINT, "b": VARCHAR})], + con.struct_type({"a": VARCHAR, "b": BIGINT}), ) res = con.sql( """ select swap_keys({'a': 42, 'b': 'answer_to_life'}) """ ).fetchall() - assert res == [({'a': 'answer_to_life', 'b': 42},)] + assert res == [({"a": "answer_to_life", "b": 42},)] def test_struct_different_field_order(self, duckdb_cursor): def example(): diff --git a/tests/fast/udf/test_transactionality.py b/tests/fast/udf/test_transactionality.py index 50286e8e..acad21ef 100644 --- a/tests/fast/udf/test_transactionality.py +++ b/tests/fast/udf/test_transactionality.py @@ -1,9 +1,10 @@ -import duckdb import pytest +import duckdb + -class TestUDFTransactionality(object): - @pytest.mark.xfail(reason='fetchone() does not realize the stream result was closed before completion') +class TestUDFTransactionality: + @pytest.mark.xfail(reason="fetchone() does not realize the stream result was closed before completion") def test_type_coverage(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from range(4096)") res = rel.fetchone() @@ -12,7 +13,7 @@ def test_type_coverage(self, duckdb_cursor): def my_func(x: str) -> int: return int(x) - duckdb_cursor.create_function('test', my_func) + duckdb_cursor.create_function("test", my_func) - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res = rel.fetchone() diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py index 40bde07b..2831f852 100644 --- a/tests/slow/test_h2oai_arrow.py +++ b/tests/slow/test_h2oai_arrow.py @@ -1,19 +1,21 @@ -import duckdb -import os import math -from pytest import mark, fixture, importorskip +from pathlib import Path + +import pytest + +import duckdb -read_csv = importorskip('pyarrow.csv').read_csv -requests = importorskip('requests') -requests_adapters = importorskip('requests.adapters') -urllib3_util = importorskip('urllib3.util') -np = importorskip('numpy') +read_csv = pytest.importorskip("pyarrow.csv").read_csv +requests = pytest.importorskip("requests") +requests_adapters = pytest.importorskip("requests.adapters") +urllib3_util = pytest.importorskip("urllib3.util") +np = 
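# --- Editorial sketch, not part of the patch ---------------------------------
# The swap_keys test above shows that STRUCT values cross the UDF boundary as
# Python dicts. A condensed version of that round trip:
#
#   import duckdb
#   from duckdb.typing import BIGINT, VARCHAR
#
#   def swap(d):
#       return {"a": d["b"], "b": d["a"]}
#
#   con = duckdb.connect()
#   con.create_function(
#       "swap",
#       swap,
#       [con.struct_type({"a": BIGINT, "b": VARCHAR})],
#       con.struct_type({"a": VARCHAR, "b": BIGINT}),
#   )
#   res = con.sql("select swap({'a': 42, 'b': 'answer'})").fetchall()
#   assert res == [({"a": "answer", "b": 42},)]
# ------------------------------------------------------------------------------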
pytest.importorskip("numpy") def group_by_q1(con): con.execute("CREATE TABLE ans AS SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1") res = con.execute("SELECT COUNT(*), sum(v1)::varchar AS v1 FROM ans").fetchall() - assert res == [(96, '28498857')] + assert res == [(96, "28498857")] con.execute("DROP TABLE ans") @@ -55,7 +57,7 @@ def group_by_q5(con): def group_by_q6(con): con.execute( - "CREATE TABLE ans AS SELECT id4, id5, quantile_cont(v3, 0.5) AS median_v3, stddev(v3) AS sd_v3 FROM x GROUP BY id4, id5;" + "CREATE TABLE ans AS SELECT id4, id5, quantile_cont(v3, 0.5) AS median_v3, stddev(v3) AS sd_v3 FROM x GROUP BY id4, id5;" # noqa: E501 ) res = con.execute("SELECT COUNT(*), sum(median_v3) AS median_v3, sum(sd_v3) AS sd_v3 FROM ans").fetchall() assert res[0][0] == 9216 @@ -74,7 +76,7 @@ def group_by_q7(con): def group_by_q8(con): con.execute( - "CREATE TABLE ans AS SELECT id6, v3 AS largest2_v3 FROM (SELECT id6, v3, row_number() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2" + "CREATE TABLE ans AS SELECT id6, v3 AS largest2_v3 FROM (SELECT id6, v3, row_number() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2" # noqa: E501 ) res = con.execute("SELECT count(*), sum(largest2_v3) AS largest2_v3 FROM ans").fetchall() assert res[0][0] == 190002 @@ -92,7 +94,7 @@ def group_by_q9(con): def group_by_q10(con): con.execute( - "CREATE TABLE ans AS SELECT id1, id2, id3, id4, id5, id6, sum(v3) AS v3, count(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6;" + "CREATE TABLE ans AS SELECT id1, id2, id3, id4, id5, id6, sum(v3) AS v3, count(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6;" # noqa: E501 ) res = con.execute("SELECT sum(v3) AS v3, sum(count) AS count FROM ans;").fetchall() assert math.floor(res[0][0]) == 474969574 @@ -111,7 +113,7 @@ def join_by_q1(con): def join_by_q2(con): con.execute( - "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x JOIN medium USING (id2);" + "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x JOIN medium USING (id2);" # noqa: E501 ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 8998412 @@ -122,7 +124,7 @@ def join_by_q2(con): def join_by_q3(con): con.execute( - "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x LEFT JOIN medium USING (id2);" + "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x LEFT JOIN medium USING (id2);" # noqa: E501 ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 10000000 @@ -133,7 +135,7 @@ def join_by_q3(con): def join_by_q4(con): con.execute( - "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id2 AS medium_id2, medium.id4 AS medium_id4, v2 FROM x JOIN medium USING (id5);" + "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id2 AS medium_id2, medium.id4 AS medium_id4, v2 FROM x JOIN medium USING (id5);" # noqa: E501 ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 8998412 @@ -144,7 +146,7 @@ def join_by_q4(con): def join_by_q5(con): con.execute( - "CREATE TABLE ans AS SELECT x.*, big.id1 AS big_id1, big.id2 AS 
big_id2, big.id4 AS big_id4, big.id5 AS big_id5, big.id6 AS big_id6, v2 FROM x JOIN big USING (id3);" + "CREATE TABLE ans AS SELECT x.*, big.id1 AS big_id1, big.id2 AS big_id2, big.id4 AS big_id4, big.id5 AS big_id5, big.id6 AS big_id6, v2 FROM x JOIN big USING (id3);" # noqa: E501 ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 9000000 @@ -153,9 +155,9 @@ def join_by_q5(con): con.execute("DROP TABLE ans") -class TestH2OAIArrow(object): - @mark.parametrize( - 'function', +class TestH2OAIArrow: + @pytest.mark.parametrize( + "function", [ group_by_q1, group_by_q2, @@ -169,15 +171,15 @@ class TestH2OAIArrow(object): group_by_q10, ], ) - @mark.parametrize('threads', [1, 4]) - @mark.usefixtures('group_by_data') + @pytest.mark.parametrize("threads", [1, 4]) + @pytest.mark.usefixtures("group_by_data") def test_group_by(self, threads, function, group_by_data): group_by_data.execute(f"PRAGMA threads={threads}") function(group_by_data) - @mark.parametrize('threads', [1, 4]) - @mark.parametrize( - 'function', + @pytest.mark.parametrize("threads", [1, 4]) + @pytest.mark.parametrize( + "function", [ join_by_q1, join_by_q2, @@ -186,72 +188,72 @@ def test_group_by(self, threads, function, group_by_data): join_by_q5, ], ) - @mark.usefixtures('large_data') + @pytest.mark.usefixtures("large_data") def test_join(self, threads, function, large_data): large_data.execute(f"PRAGMA threads={threads}") function(large_data) -@fixture(scope="module") +@pytest.fixture(scope="module") def arrow_dataset_register(): - """Single fixture to download files and register them on the given connection""" + """Single fixture to download files and register them on the given connection.""" session = requests.Session() retries = urllib3_util.Retry( - allowed_methods={'GET'}, # only retry on GETs (all we do) + allowed_methods={"GET"}, # only retry on GETs (all we do) total=None, # disable to make the below take effect redirect=10, # Don't follow more than 10 redirects in a row connect=3, # try 3 times before giving up on connection errors read=3, # try 3 times before giving up on read errors status=3, # try 3 times before giving up on status errors (see forcelist below) - status_forcelist=[429] + [status for status in range(500, 512)], + status_forcelist=[429, *list(range(500, 512))], other=0, # whatever else may cause an error should break backoff_factor=0.1, # [0.0s, 0.2s, 0.4s] raise_on_redirect=True, # raise exception when redirect error retries are exhausted raise_on_status=True, # raise exception when status error retries are exhausted respect_retry_after_header=True, # respect Retry-After headers ) - session.mount('https://', requests_adapters.HTTPAdapter(max_retries=retries)) - saved_filenames = set() + session.mount("https://", requests_adapters.HTTPAdapter(max_retries=retries)) + saved_filepaths = set() - def _register(url, filename, con, tablename): + def _register(url, filename, con, tablename) -> None: r = session.get(url) - with open(filename, 'wb') as f: - f.write(r.content) + filepath = Path(filename) + filepath.write_bytes(r.content) con.register(tablename, read_csv(filename)) - saved_filenames.add(filename) + saved_filepaths.add(filepath) yield _register - for filename in saved_filenames: - os.remove(filename) + for filepath in saved_filepaths: + filepath.unlink() session.close() -@fixture(scope="module") +@pytest.fixture(scope="module") def large_data(arrow_dataset_register): con = duckdb.connect() arrow_dataset_register( - 
'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz', - 'J1_1e7_NA_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz", + "J1_1e7_NA_0_0.csv.gz", con, "x", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz', - 'J1_1e7_1e1_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz", + "J1_1e7_1e1_0_0.csv.gz", con, "small", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz', - 'J1_1e7_1e4_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz", + "J1_1e7_1e4_0_0.csv.gz", con, "medium", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz', - 'J1_1e7_1e7_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz", + "J1_1e7_1e7_0_0.csv.gz", con, "big", ) @@ -259,12 +261,12 @@ def large_data(arrow_dataset_register): con.close() -@fixture(scope="module") +@pytest.fixture(scope="module") def group_by_data(arrow_dataset_register): con = duckdb.connect() arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz', - 'G1_1e7_1e2_5_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz", + "G1_1e7_1e2_5_0.csv.gz", con, "x", ) diff --git a/tests/spark_namespace/__init__.py b/tests/spark_namespace/__init__.py index 11f91af2..3c8057ca 100644 --- a/tests/spark_namespace/__init__.py +++ b/tests/spark_namespace/__init__.py @@ -1,4 +1,3 @@ import os -import sys USE_ACTUAL_SPARK = os.getenv("USE_ACTUAL_SPARK") == "true" diff --git a/tests/spark_namespace/sql/__init__.py b/tests/spark_namespace/sql/__init__.py index 67075e3c..6557be5a 100644 --- a/tests/spark_namespace/sql/__init__.py +++ b/tests/spark_namespace/sql/__init__.py @@ -4,3 +4,5 @@ from pyspark.sql import SparkSession else: from duckdb.experimental.spark.sql import SparkSession + +__all__ = ["SparkSession"] diff --git a/tests/stubs/mypy.ini b/tests/stubs/mypy.ini deleted file mode 100644 index ff840b82..00000000 --- a/tests/stubs/mypy.ini +++ /dev/null @@ -1,14 +0,0 @@ -[mypy] -mypy_path = duckdb -[mypy-fsspec] -ignore_missing_imports = True -[mypy-pandas] -ignore_missing_imports = True -[mypy-polars] -ignore_missing_imports = True -[mypy-pyarrow] -ignore_missing_imports = True -[mypy-pyarrow.lib] -ignore_missing_imports = True -[mypy-torch] -ignore_missing_imports = True \ No newline at end of file diff --git a/tests/stubs/test_stubs.py b/tests/stubs/test_stubs.py deleted file mode 100644 index 2f178bcc..00000000 --- a/tests/stubs/test_stubs.py +++ /dev/null @@ -1,25 +0,0 @@ -import os - -from mypy import stubtest - -MYPY_INI_PATH = os.path.join(os.path.dirname(__file__), 'mypy.ini') - - -def test_generated_stubs(): - skip_stubs_errors = ['pybind11', 'git_revision', 'is inconsistent, metaclass differs'] - - options = stubtest.parse_options(['duckdb', '--mypy-config-file', MYPY_INI_PATH]) - stubtest.test_stubs(options) - - broken_stubs = [ - error.get_description() - for error in stubtest.test_module('duckdb') - if not any(skip in error.get_description() for skip in skip_stubs_errors) - ] - - if not broken_stubs: - return - print("Stubs must be updated, either add them to skip_stubs_errors or update __init__.pyi accordingly") - print(broken_stubs) - - assert not broken_stubs
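For readers following the UDF test changes above, here is a minimal, self-contained sketch of the scalar-UDF API those tests exercise (connection, typed registration, NULL handling). It mirrors the calls shown in the diff; the function and registration names are illustrative only, not part of the library.

```python
# Minimal sketch of the scalar-UDF pattern exercised by the tests above.
# The names "null_test" and five_if_null are illustrative only.
import duckdb
from duckdb.typing import BIGINT


def five_if_null(x):
    # With null_handling="special", NULL inputs reach the Python function as None
    # instead of being short-circuited to NULL by the engine.
    return 5 if x is None else x


con = duckdb.connect()
con.create_function("null_test", five_if_null, [BIGINT], BIGINT, null_handling="special")
print(con.sql("select null_test(NULL), null_test(42)").fetchall())  # expected: [(5, 42)]
```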
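The return_overflow change above moves `con.sql()` outside `pytest.raises` because relations are evaluated lazily: the cast failure only surfaces when results are fetched. A short sketch of that behavior, under the same assumption the test makes (an out-of-range return value for the declared type raises `InvalidInputException`):

```python
# Sketch: DuckDB relations are lazy, so a UDF casting error surfaces at fetch time.
import duckdb
from duckdb.typing import TINYINT


def return_overflow():
    return 128  # out of range for TINYINT


con = duckdb.connect()
con.create_function("return_overflow", return_overflow, None, TINYINT)
rel = con.sql("select return_overflow()")  # no error yet: nothing has executed
try:
    rel.fetchall()                         # the out-of-range cast is raised here
except duckdb.InvalidInputException as exc:
    print(exc)
```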
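The swap_keys hunk above also shows struct-typed UDFs. Below is a hedged sketch of that registration pattern with the same field layout as the test; the function body is my own minimal implementation that satisfies the test's assertion, not a copy of the original.

```python
# Sketch of the struct-typed UDF registration shown in the swap_keys test above.
import duckdb
from duckdb.typing import BIGINT, VARCHAR


def swap_keys(row: dict) -> dict:
    # Re-pair the struct's keys with its values in reverse order,
    # so {'a': 42, 'b': 'answer'} becomes {'a': 'answer', 'b': 42}.
    keys, values = list(row.keys()), list(row.values())
    return {keys[i]: values[-(i + 1)] for i in range(len(keys))}


con = duckdb.connect()
con.create_function(
    "swap_keys",
    swap_keys,
    [con.struct_type({"a": BIGINT, "b": VARCHAR})],
    con.struct_type({"a": VARCHAR, "b": BIGINT}),
)
print(con.sql("select swap_keys({'a': 42, 'b': 'answer_to_life'})").fetchall())
# expected: [({'a': 'answer_to_life', 'b': 42},)]
```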
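Finally, the arrow_dataset_register fixture above wires a retrying HTTP session before downloading the benchmark CSVs and now writes them via pathlib. A standalone sketch of that pattern follows; the target URL and file name are placeholders, only the retry wiring and the write_bytes call are the point.

```python
# Sketch of the retrying download session used by arrow_dataset_register above.
from pathlib import Path

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

session = requests.Session()
retries = Retry(
    allowed_methods={"GET"},                   # only retry idempotent GETs
    total=None,                                # let the per-category limits below decide
    connect=3, read=3, status=3,               # three attempts per failure category
    status_forcelist=[429, *range(500, 512)],  # retry on throttling and 5xx responses
    backoff_factor=0.1,                        # sleeps of 0.0s, 0.2s, 0.4s between attempts
    raise_on_status=True,
)
session.mount("https://", HTTPAdapter(max_retries=retries))

response = session.get("https://example.com/dataset.csv.gz")  # placeholder URL
Path("dataset.csv.gz").write_bytes(response.content)
```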