From 521e47e74179137b6aeff4b5008de34b2602c0dd Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Sep 2025 14:41:48 +0200 Subject: [PATCH 001/135] separate the pyarrow filter pushdown to separate file, cleanup direct imports --- Makefile | 4 + scripts/cache_data.json | 86 ++++- scripts/generate_import_cache_cpp.py | 8 +- scripts/generate_import_cache_json.py | 1 - scripts/imports.py | 16 + src/duckdb_py/arrow/CMakeLists.txt | 3 +- src/duckdb_py/arrow/arrow_array_stream.cpp | 340 +----------------- .../arrow/pyarrow_filter_pushdown.cpp | 336 +++++++++++++++++ .../arrow/arrow_array_stream.hpp | 5 - .../arrow/pyarrow_filter_pushdown.hpp | 26 ++ .../import_cache/modules/pyarrow_module.hpp | 17 +- src/duckdb_py/pyrelation/initialize.cpp | 22 +- src/duckdb_py/typing/pytype.cpp | 9 +- tests/fast/api/test_dbapi10.py | 23 +- .../relational_api/test_rapi_description.py | 2 +- tests/fast/udf/test_remove_function.py | 4 +- 16 files changed, 530 insertions(+), 372 deletions(-) create mode 100644 Makefile create mode 100644 src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp create mode 100644 src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..07008f11 --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ +PYTHON ?= python3 + +format-main: + $(PYTHON) external/duckdb/scripts/format.py main --fix --noconfirm \ No newline at end of file diff --git a/scripts/cache_data.json b/scripts/cache_data.json index 640052cd..3dd9a1f1 100644 --- a/scripts/cache_data.json +++ b/scripts/cache_data.json @@ -7,7 +7,19 @@ "pyarrow.dataset", "pyarrow.Table", "pyarrow.RecordBatchReader", - "pyarrow.ipc" + "pyarrow.ipc", + "pyarrow.scalar", + "pyarrow.date32", + "pyarrow.time64", + "pyarrow.timestamp", + "pyarrow.uint8", + "pyarrow.uint16", + "pyarrow.uint32", + "pyarrow.uint64", + "pyarrow.binary_view", + "pyarrow.decimal32", + "pyarrow.decimal64", + "pyarrow.decimal128" ] }, "pyarrow.dataset": { @@ -709,5 +721,77 @@ "name": "duckdb_source", "children": [], "required": false + }, + "pyarrow.scalar": { + "type": "attribute", + "full_path": "pyarrow.scalar", + "name": "scalar", + "children": [] + }, + "pyarrow.date32": { + "type": "attribute", + "full_path": "pyarrow.date32", + "name": "date32", + "children": [] + }, + "pyarrow.time64": { + "type": "attribute", + "full_path": "pyarrow.time64", + "name": "time64", + "children": [] + }, + "pyarrow.timestamp": { + "type": "attribute", + "full_path": "pyarrow.timestamp", + "name": "timestamp", + "children": [] + }, + "pyarrow.uint8": { + "type": "attribute", + "full_path": "pyarrow.uint8", + "name": "uint8", + "children": [] + }, + "pyarrow.uint16": { + "type": "attribute", + "full_path": "pyarrow.uint16", + "name": "uint16", + "children": [] + }, + "pyarrow.uint32": { + "type": "attribute", + "full_path": "pyarrow.uint32", + "name": "uint32", + "children": [] + }, + "pyarrow.uint64": { + "type": "attribute", + "full_path": "pyarrow.uint64", + "name": "uint64", + "children": [] + }, + "pyarrow.binary_view": { + "type": "attribute", + "full_path": "pyarrow.binary_view", + "name": "binary_view", + "children": [] + }, + "pyarrow.decimal32": { + "type": "attribute", + "full_path": "pyarrow.decimal32", + "name": "decimal32", + "children": [] + }, + "pyarrow.decimal64": { + "type": "attribute", + "full_path": "pyarrow.decimal64", + "name": "decimal64", + "children": [] + }, + "pyarrow.decimal128": { + "type": "attribute", + "full_path": "pyarrow.decimal128", + "name": "decimal128", + "children": [] } } \ No 
newline at end of file diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index f902c5a5..07744e37 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -182,7 +182,7 @@ def to_string(self): for file in files: content = file.to_string() - path = f'src/include/duckdb_python/import_cache/modules/{file.file_name}' + path = f'src/duckdb_py/include/duckdb_python/import_cache/modules/{file.file_name}' import_cache_path = os.path.join(script_dir, '..', path) with open(import_cache_path, "w") as f: f.write(content) @@ -237,7 +237,9 @@ def get_root_modules(files: List[ModuleFile]): """ -import_cache_path = os.path.join(script_dir, '..', 'src/include/duckdb_python/import_cache/python_import_cache.hpp') +import_cache_path = os.path.join( + script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache.hpp' +) with open(import_cache_path, "w") as f: f.write(import_cache_file) @@ -252,7 +254,7 @@ def get_module_file_path_includes(files: List[ModuleFile]): module_includes = get_module_file_path_includes(files) modules_header = os.path.join( - script_dir, '..', 'src/include/duckdb_python/import_cache/python_import_cache_modules.hpp' + script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache_modules.hpp' ) with open(modules_header, "w") as f: f.write(module_includes) diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 7a59e6b7..40e6a773 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -170,7 +170,6 @@ def update_json(existing: dict, new: dict) -> dict: # If both values are dictionaries, update recursively. if isinstance(new_value, dict) and isinstance(old_value, dict): - print(key) updated = update_json(old_value, new_value) existing[key] = updated else: diff --git a/scripts/imports.py b/scripts/imports.py index 6b035768..c51f53b7 100644 --- a/scripts/imports.py +++ b/scripts/imports.py @@ -6,6 +6,22 @@ pyarrow.Table pyarrow.RecordBatchReader pyarrow.ipc.MessageReader +pyarrow.scalar +pyarrow.date32 +pyarrow.time64 +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.uint8 +pyarrow.uint16 +pyarrow.uint32 +pyarrow.uint64 +pyarrow.binary_view +pyarrow.decimal32 +pyarrow.decimal64 +pyarrow.decimal128 import pandas diff --git a/src/duckdb_py/arrow/CMakeLists.txt b/src/duckdb_py/arrow/CMakeLists.txt index 29b188c6..9a9188b8 100644 --- a/src/duckdb_py/arrow/CMakeLists.txt +++ b/src/duckdb_py/arrow/CMakeLists.txt @@ -1,4 +1,5 @@ # this is used for clang-tidy checks -add_library(python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp) +add_library(python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp + pyarrow_filter_pushdown.cpp) target_link_libraries(python_arrow PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index 533c31ed..f9cfd1bb 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -1,22 +1,15 @@ #include "duckdb_python/arrow/arrow_array_stream.hpp" +#include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/planner/filter/in_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyrelation.hpp" +#include 
"duckdb_python/pyresult.hpp" +#include "duckdb/function/table/arrow.hpp" #include "duckdb/common/assert.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/limits.hpp" #include "duckdb/main/client_config.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" -#include "duckdb/planner/table_filter.hpp" - -#include "duckdb_python/pyconnection/pyconnection.hpp" -#include "duckdb_python/pyrelation.hpp" -#include "duckdb_python/pyresult.hpp" -#include "duckdb/function/table/arrow.hpp" namespace duckdb { @@ -56,8 +49,8 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(DBConfig &config, } if (has_filter) { - auto filter = TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col, - client_properties, arrow_table); + auto filter = PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map, + filter_to_col, client_properties, arrow_table); if (!filter.is(py::none())) { kwargs["filter"] = filter; } @@ -171,323 +164,4 @@ void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowS GetSchemaInternal(arrow_obj_handle, schema); } -string ConvertTimestampUnit(ArrowDateTimeType unit) { - switch (unit) { - case ArrowDateTimeType::MICROSECONDS: - return "us"; - case ArrowDateTimeType::MILLISECONDS: - return "ms"; - case ArrowDateTimeType::NANOSECONDS: - return "ns"; - case ArrowDateTimeType::SECONDS: - return "s"; - default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); - } -} - -int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_type) { - auto input = timestamp_t(base_value); - if (!Timestamp::IsFinite(input)) { - return base_value; - } - - switch (datetime_type) { - case ArrowDateTimeType::MICROSECONDS: - return Timestamp::GetEpochMicroSeconds(input); - case ArrowDateTimeType::MILLISECONDS: - return Timestamp::GetEpochMs(input); - case ArrowDateTimeType::NANOSECONDS: - return Timestamp::GetEpochNanoSeconds(input); - case ArrowDateTimeType::SECONDS: - return Timestamp::GetEpochSeconds(input); - default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampTZValue"); - } -} - -py::object GetScalar(Value &constant, const string &timezone_config, const ArrowType &type) { - py::object scalar = py::module_::import("pyarrow").attr("scalar"); - auto &import_cache = *DuckDBPyConnection::ImportCache(); - py::object dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); - py::object scalar_value; - switch (constant.type().id()) { - case LogicalTypeId::BOOLEAN: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::TINYINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::SMALLINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::INTEGER: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::BIGINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::DATE: { - py::object date_type = py::module_::import("pyarrow").attr("date32"); - return dataset_scalar(scalar(constant.GetValue(), date_type())); - } - case LogicalTypeId::TIME: { - py::object date_type = py::module_::import("pyarrow").attr("time64"); - return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); - } - case LogicalTypeId::TIMESTAMP: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return 
dataset_scalar(scalar(constant.GetValue(), date_type("us"))); - } - case LogicalTypeId::TIMESTAMP_MS: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("ms"))); - } - case LogicalTypeId::TIMESTAMP_NS: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("ns"))); - } - case LogicalTypeId::TIMESTAMP_SEC: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("s"))); - } - case LogicalTypeId::TIMESTAMP_TZ: { - auto &datetime_info = type.GetTypeInfo(); - auto base_value = constant.GetValue(); - auto arrow_datetime_type = datetime_info.GetDateTimeType(); - auto time_unit_string = ConvertTimestampUnit(arrow_datetime_type); - auto converted_value = ConvertTimestampTZValue(base_value, arrow_datetime_type); - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(converted_value, date_type(time_unit_string, py::arg("tz") = timezone_config))); - } - case LogicalTypeId::UTINYINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint8"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::USMALLINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint16"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::UINTEGER: { - py::object integer_type = py::module_::import("pyarrow").attr("uint32"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::UBIGINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint64"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::FLOAT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::DOUBLE: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::VARCHAR: - return dataset_scalar(constant.ToString()); - case LogicalTypeId::BLOB: { - if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { - py::object binary_view_type = py::module_::import("pyarrow").attr("binary_view"); - return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); - } - return dataset_scalar(py::bytes(constant.GetValueUnsafe())); - } - case LogicalTypeId::DECIMAL: { - py::object decimal_type; - auto &datetime_info = type.GetTypeInfo(); - auto bit_width = datetime_info.GetBitWidth(); - switch (bit_width) { - case DecimalBitWidth::DECIMAL_32: - decimal_type = py::module_::import("pyarrow").attr("decimal32"); - break; - case DecimalBitWidth::DECIMAL_64: - decimal_type = py::module_::import("pyarrow").attr("decimal64"); - break; - case DecimalBitWidth::DECIMAL_128: - decimal_type = py::module_::import("pyarrow").attr("decimal128"); - break; - default: - throw NotImplementedException("Unsupported precision for Arrow Decimal Type."); - } - - uint8_t width; - uint8_t scale; - constant.type().GetDecimalProperties(width, scale); - // pyarrow only allows 'decimal.Decimal' to be used to construct decimal scalars such as 0.05 - auto val = import_cache.decimal.Decimal()(constant.ToString()); - return dataset_scalar( - scalar(std::move(val), decimal_type(py::arg("precision") = width, py::arg("scale") = scale))); - } - default: - throw NotImplementedException("Unimplemented type \"%s\" for Arrow Filter Pushdown", - 
constant.type().ToString()); - } -} - -py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, - const ArrowType &type) { - auto &import_cache = *DuckDBPyConnection::ImportCache(); - py::object field = import_cache.pyarrow.dataset().attr("field"); - switch (filter.filter_type) { - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = filter.Cast(); - auto constant_field = field(py::tuple(py::cast(column_ref))); - auto constant_value = GetScalar(constant_filter.constant, timezone_config, type); - - bool is_nan = false; - auto &constant = constant_filter.constant; - auto &constant_type = constant.type(); - if (constant_type.id() == LogicalTypeId::FLOAT) { - is_nan = Value::IsNan(constant.GetValue()); - } else if (constant_type.id() == LogicalTypeId::DOUBLE) { - is_nan = Value::IsNan(constant.GetValue()); - } - - // Special handling for NaN comparisons (to explicitly violate IEEE-754) - if (is_nan) { - switch (constant_filter.comparison_type) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return constant_field.attr("is_nan")(); - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_NOTEQUAL: - return constant_field.attr("is_nan")().attr("__invert__")(); - case ExpressionType::COMPARE_GREATERTHAN: - // Nothing is greater than NaN - return import_cache.pyarrow.dataset().attr("scalar")(false); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // Everything is less than or equal to NaN - return import_cache.pyarrow.dataset().attr("scalar")(true); - default: - throw NotImplementedException("Unsupported comparison type (%s) for NaN values", - EnumUtil::ToString(constant_filter.comparison_type)); - } - } - - switch (constant_filter.comparison_type) { - case ExpressionType::COMPARE_EQUAL: - return constant_field.attr("__eq__")(constant_value); - case ExpressionType::COMPARE_LESSTHAN: - return constant_field.attr("__lt__")(constant_value); - case ExpressionType::COMPARE_GREATERTHAN: - return constant_field.attr("__gt__")(constant_value); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return constant_field.attr("__le__")(constant_value); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return constant_field.attr("__ge__")(constant_value); - case ExpressionType::COMPARE_NOTEQUAL: - return constant_field.attr("__ne__")(constant_value); - default: - throw NotImplementedException("Comparison Type %s can't be an Arrow Scan Pushdown Filter", - EnumUtil::ToString(constant_filter.comparison_type)); - } - } - //! We do not pushdown is null yet - case TableFilterType::IS_NULL: { - auto constant_field = field(py::tuple(py::cast(column_ref))); - return constant_field.attr("is_null")(); - } - case TableFilterType::IS_NOT_NULL: { - auto constant_field = field(py::tuple(py::cast(column_ref))); - return constant_field.attr("is_valid")(); - } - //! 
We do not pushdown or conjunctions yet - case TableFilterType::CONJUNCTION_OR: { - auto &or_filter = filter.Cast(); - py::object expression = py::none(); - for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { - auto &child_filter = *or_filter.child_filters[i]; - py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); - if (child_expression.is(py::none())) { - continue; - } - if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__or__")(child_expression); - } - } - return expression; - } - case TableFilterType::CONJUNCTION_AND: { - auto &and_filter = filter.Cast(); - py::object expression = py::none(); - for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { - auto &child_filter = *and_filter.child_filters[i]; - py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); - if (child_expression.is(py::none())) { - continue; - } - if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__and__")(child_expression); - } - } - return expression; - } - case TableFilterType::STRUCT_EXTRACT: { - auto &struct_filter = filter.Cast(); - auto &child_name = struct_filter.child_name; - auto &struct_type_info = type.GetTypeInfo(); - auto &struct_child_type = struct_type_info.GetChild(struct_filter.child_idx); - - column_ref.push_back(child_name); - auto child_expr = TransformFilterRecursive(*struct_filter.child_filter, std::move(column_ref), timezone_config, - struct_child_type); - return child_expr; - } - case TableFilterType::OPTIONAL_FILTER: { - auto &optional_filter = filter.Cast(); - if (!optional_filter.child_filter) { - return py::none(); - } - return TransformFilterRecursive(*optional_filter.child_filter, column_ref, timezone_config, type); - } - case TableFilterType::IN_FILTER: { - auto &in_filter = filter.Cast(); - ConjunctionOrFilter or_filter; - value_set_t unique_values; - for (const auto &value : in_filter.values) { - if (unique_values.find(value) == unique_values.end()) { - unique_values.insert(value); - } - } - for (const auto &value : unique_values) { - or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); - } - return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); - } - case TableFilterType::DYNAMIC_FILTER: { - //! 
Ignore dynamic filters for now, not necessary for correctness - return py::none(); - } - default: - throw NotImplementedException("Pushdown Filter Type %s is not currently supported in PyArrow Scans", - EnumUtil::ToString(filter.filter_type)); - } -} - -py::object PythonTableArrowArrayStreamFactory::TransformFilter(TableFilterSet &filter_collection, - std::unordered_map &columns, - unordered_map filter_to_col, - const ClientProperties &config, - const ArrowTableSchema &arrow_table) { - auto &filters_map = filter_collection.filters; - - py::object expression = py::none(); - for (auto &it : filters_map) { - auto column_idx = it.first; - auto &column_name = columns[column_idx]; - - vector column_ref; - column_ref.push_back(column_name); - - D_ASSERT(columns.find(column_idx) != columns.end()); - - auto &arrow_type = arrow_table.GetColumns().at(filter_to_col.at(column_idx)); - py::object child_expression = TransformFilterRecursive(*it.second, column_ref, config.time_zone, *arrow_type); - if (child_expression.is(py::none())) { - continue; - } else if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__and__")(child_expression); - } - } - return expression; -} - } // namespace duckdb diff --git a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp new file mode 100644 index 00000000..66a6e3fa --- /dev/null +++ b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp @@ -0,0 +1,336 @@ +#include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" + +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/planner/filter/in_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/table_filter.hpp" + +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyrelation.hpp" +#include "duckdb_python/pyresult.hpp" +#include "duckdb/function/table/arrow.hpp" + +namespace duckdb { + +string ConvertTimestampUnit(ArrowDateTimeType unit) { + switch (unit) { + case ArrowDateTimeType::MICROSECONDS: + return "us"; + case ArrowDateTimeType::MILLISECONDS: + return "ms"; + case ArrowDateTimeType::NANOSECONDS: + return "ns"; + case ArrowDateTimeType::SECONDS: + return "s"; + default: + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); + } +} + +int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_type) { + auto input = timestamp_t(base_value); + if (!Timestamp::IsFinite(input)) { + return base_value; + } + + switch (datetime_type) { + case ArrowDateTimeType::MICROSECONDS: + return Timestamp::GetEpochMicroSeconds(input); + case ArrowDateTimeType::MILLISECONDS: + return Timestamp::GetEpochMs(input); + case ArrowDateTimeType::NANOSECONDS: + return Timestamp::GetEpochNanoSeconds(input); + case ArrowDateTimeType::SECONDS: + return Timestamp::GetEpochSeconds(input); + default: + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampTZValue"); + } +} + +py::object GetScalar(Value &constant, const string &timezone_config, const ArrowType &type) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + auto scalar = import_cache.pyarrow.scalar(); + py::handle dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); + + switch (constant.type().id()) { + case LogicalTypeId::BOOLEAN: + 
return dataset_scalar(constant.GetValue()); + case LogicalTypeId::TINYINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::SMALLINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::INTEGER: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::BIGINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::DATE: { + py::handle date_type = import_cache.pyarrow.date32(); + return dataset_scalar(scalar(constant.GetValue(), date_type())); + } + case LogicalTypeId::TIME: { + py::handle date_type = import_cache.pyarrow.time64(); + return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); + } + case LogicalTypeId::TIMESTAMP: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); + } + case LogicalTypeId::TIMESTAMP_MS: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("ms"))); + } + case LogicalTypeId::TIMESTAMP_NS: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("ns"))); + } + case LogicalTypeId::TIMESTAMP_SEC: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("s"))); + } + case LogicalTypeId::TIMESTAMP_TZ: { + auto &datetime_info = type.GetTypeInfo(); + auto base_value = constant.GetValue(); + auto arrow_datetime_type = datetime_info.GetDateTimeType(); + auto time_unit_string = ConvertTimestampUnit(arrow_datetime_type); + auto converted_value = ConvertTimestampTZValue(base_value, arrow_datetime_type); + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(converted_value, date_type(time_unit_string, py::arg("tz") = timezone_config))); + } + case LogicalTypeId::UTINYINT: { + py::handle integer_type = import_cache.pyarrow.uint8(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::USMALLINT: { + py::handle integer_type = import_cache.pyarrow.uint16(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::UINTEGER: { + py::handle integer_type = import_cache.pyarrow.uint32(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::UBIGINT: { + py::handle integer_type = import_cache.pyarrow.uint64(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::FLOAT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::DOUBLE: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::VARCHAR: + return dataset_scalar(constant.ToString()); + case LogicalTypeId::BLOB: { + if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { + py::handle binary_view_type = import_cache.pyarrow.binary_view(); + return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); + } + return dataset_scalar(py::bytes(constant.GetValueUnsafe())); + } + case LogicalTypeId::DECIMAL: { + py::handle decimal_type; + auto &datetime_info = type.GetTypeInfo(); + auto bit_width = datetime_info.GetBitWidth(); + switch (bit_width) { + case DecimalBitWidth::DECIMAL_32: + decimal_type = import_cache.pyarrow.decimal32(); + break; + case DecimalBitWidth::DECIMAL_64: + decimal_type = import_cache.pyarrow.decimal64(); + break; + case DecimalBitWidth::DECIMAL_128: + decimal_type = import_cache.pyarrow.decimal128(); + 
break; + default: + throw NotImplementedException("Unsupported precision for Arrow Decimal Type."); + } + + uint8_t width; + uint8_t scale; + constant.type().GetDecimalProperties(width, scale); + // pyarrow only allows 'decimal.Decimal' to be used to construct decimal scalars such as 0.05 + auto val = import_cache.decimal.Decimal()(constant.ToString()); + return dataset_scalar( + scalar(std::move(val), decimal_type(py::arg("precision") = width, py::arg("scale") = scale))); + } + default: + throw NotImplementedException("Unimplemented type \"%s\" for Arrow Filter Pushdown", + constant.type().ToString()); + } +} + +py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, + const ArrowType &type) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + py::object field = import_cache.pyarrow.dataset().attr("field"); + switch (filter.filter_type) { + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + auto constant_field = field(py::tuple(py::cast(column_ref))); + auto constant_value = GetScalar(constant_filter.constant, timezone_config, type); + + bool is_nan = false; + auto &constant = constant_filter.constant; + auto &constant_type = constant.type(); + if (constant_type.id() == LogicalTypeId::FLOAT) { + is_nan = Value::IsNan(constant.GetValue()); + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { + is_nan = Value::IsNan(constant.GetValue()); + } + + // Special handling for NaN comparisons (to explicitly violate IEEE-754) + if (is_nan) { + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("is_nan")(); + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("is_nan")().attr("__invert__")(); + case ExpressionType::COMPARE_GREATERTHAN: + // Nothing is greater than NaN + return import_cache.pyarrow.dataset().attr("scalar")(false); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // Everything is less than or equal to NaN + return import_cache.pyarrow.dataset().attr("scalar")(true); + default: + throw NotImplementedException("Unsupported comparison type (%s) for NaN values", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } + + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + return constant_field.attr("__eq__")(constant_value); + case ExpressionType::COMPARE_LESSTHAN: + return constant_field.attr("__lt__")(constant_value); + case ExpressionType::COMPARE_GREATERTHAN: + return constant_field.attr("__gt__")(constant_value); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return constant_field.attr("__le__")(constant_value); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("__ge__")(constant_value); + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("__ne__")(constant_value); + default: + throw NotImplementedException("Comparison Type %s can't be an Arrow Scan Pushdown Filter", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } + //! We do not pushdown is null yet + case TableFilterType::IS_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_null")(); + } + case TableFilterType::IS_NOT_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_valid")(); + } + //! 
We do not pushdown or conjunctions yet + case TableFilterType::CONJUNCTION_OR: { + auto &or_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { + auto &child_filter = *or_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__or__")(child_expression); + } + } + return expression; + } + case TableFilterType::CONJUNCTION_AND: { + auto &and_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { + auto &child_filter = *and_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; + } + case TableFilterType::STRUCT_EXTRACT: { + auto &struct_filter = filter.Cast(); + auto &child_name = struct_filter.child_name; + auto &struct_type_info = type.GetTypeInfo(); + auto &struct_child_type = struct_type_info.GetChild(struct_filter.child_idx); + + column_ref.push_back(child_name); + auto child_expr = TransformFilterRecursive(*struct_filter.child_filter, std::move(column_ref), timezone_config, + struct_child_type); + return child_expr; + } + case TableFilterType::OPTIONAL_FILTER: { + auto &optional_filter = filter.Cast(); + if (!optional_filter.child_filter) { + return py::none(); + } + return TransformFilterRecursive(*optional_filter.child_filter, column_ref, timezone_config, type); + } + case TableFilterType::IN_FILTER: { + auto &in_filter = filter.Cast(); + ConjunctionOrFilter or_filter; + value_set_t unique_values; + for (const auto &value : in_filter.values) { + if (unique_values.find(value) == unique_values.end()) { + unique_values.insert(value); + } + } + for (const auto &value : unique_values) { + or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); + } + return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); + } + case TableFilterType::DYNAMIC_FILTER: { + //! 
Ignore dynamic filters for now, not necessary for correctness + return py::none(); + } + default: + throw NotImplementedException("Pushdown Filter Type %s is not currently supported in PyArrow Scans", + EnumUtil::ToString(filter.filter_type)); + } +} + +py::object PyArrowFilterPushdown::TransformFilter(TableFilterSet &filter_collection, + unordered_map &columns, + unordered_map filter_to_col, + const ClientProperties &config, const ArrowTableSchema &arrow_table) { + auto &filters_map = filter_collection.filters; + + py::object expression = py::none(); + for (auto &it : filters_map) { + auto column_idx = it.first; + auto &column_name = columns[column_idx]; + + vector column_ref; + column_ref.push_back(column_name); + + D_ASSERT(columns.find(column_idx) != columns.end()); + + auto &arrow_type = arrow_table.GetColumns().at(filter_to_col.at(column_idx)); + py::object child_expression = TransformFilterRecursive(*it.second, column_ref, config.time_zone, *arrow_type); + if (child_expression.is(py::none())) { + continue; + } else if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; +} + +} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp index 7eb6d20b..a5895b4a 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp @@ -86,11 +86,6 @@ class PythonTableArrowArrayStreamFactory { DBConfig &config; private: - //! We transform a TableFilterSet to an Arrow Expression Object - static py::object TransformFilter(TableFilterSet &filters, std::unordered_map &columns, - unordered_map filter_to_col, - const ClientProperties &client_properties, const ArrowTableSchema &arrow_table); - static py::object ProduceScanner(DBConfig &config, py::object &arrow_scanner, py::handle &arrow_obj_handle, ArrowStreamParameters ¶meters, const ClientProperties &client_properties); }; diff --git a/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp new file mode 100644 index 00000000..4cc85a47 --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/arrow/pyarrow_filter_pushdown.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/main/client_properties.hpp" +#include "duckdb_python/pybind11/pybind_wrapper.hpp" + +namespace duckdb { + +struct PyArrowFilterPushdown { + static py::object TransformFilter(TableFilterSet &filter_collection, unordered_map &columns, + unordered_map filter_to_col, const ClientProperties &config, + const ArrowTableSchema &arrow_table); +}; + +} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp index ccd8a16d..d3331565 100644 --- a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp +++ 
b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp @@ -56,7 +56,10 @@ struct PyarrowCacheItem : public PythonImportCacheItem { public: PyarrowCacheItem() : PythonImportCacheItem("pyarrow"), dataset(), Table("Table", this), - RecordBatchReader("RecordBatchReader", this), ipc(this) { + RecordBatchReader("RecordBatchReader", this), ipc(this), scalar("scalar", this), date32("date32", this), + time64("time64", this), timestamp("timestamp", this), uint8("uint8", this), uint16("uint16", this), + uint32("uint32", this), uint64("uint64", this), binary_view("binary_view", this), + decimal32("decimal32", this), decimal64("decimal64", this), decimal128("decimal128", this) { } ~PyarrowCacheItem() override { } @@ -65,6 +68,18 @@ struct PyarrowCacheItem : public PythonImportCacheItem { PythonImportCacheItem Table; PythonImportCacheItem RecordBatchReader; PyarrowIpcCacheItem ipc; + PythonImportCacheItem scalar; + PythonImportCacheItem date32; + PythonImportCacheItem time64; + PythonImportCacheItem timestamp; + PythonImportCacheItem uint8; + PythonImportCacheItem uint16; + PythonImportCacheItem uint32; + PythonImportCacheItem uint64; + PythonImportCacheItem binary_view; + PythonImportCacheItem decimal32; + PythonImportCacheItem decimal64; + PythonImportCacheItem decimal128; }; } // namespace duckdb diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 7992cc17..cd1f042c 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -61,8 +61,8 @@ static void InitializeConsumers(py::class_ &m) { py::arg("date_as_object") = false) .def("fetch_df_chunk", &DuckDBPyRelation::FetchDFChunk, "Execute and fetch a chunk of the rows", py::arg("vectors_per_chunk") = 1, py::kw_only(), py::arg("date_as_object") = false) - .def("arrow", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", - py::arg("batch_size") = 1000000) + .def("arrow", &DuckDBPyRelation::ToRecordBatch, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) .def("fetch_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) .def("to_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", @@ -80,16 +80,16 @@ static void InitializeConsumers(py::class_ &m) { py::arg("requested_schema") = py::none()); m.def("fetch_record_batch", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) - .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, + .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) - .def("record_batch", - [](pybind11::object &self, idx_t rows_per_batch) - { - PyErr_WarnEx(PyExc_DeprecationWarning, - "record_batch() is deprecated, use fetch_record_batch() instead.", - 0); - return self.attr("fetch_record_batch")(rows_per_batch); - }, py::arg("batch_size") = 1000000); + .def( + "record_batch", + [](pybind11::object &self, idx_t rows_per_batch) { + PyErr_WarnEx(PyExc_DeprecationWarning, + "record_batch() is deprecated, use fetch_record_batch() instead.", 0); + return self.attr("fetch_record_batch")(rows_per_batch); + }, + py::arg("batch_size") = 1000000); } static void InitializeAggregates(py::class_ &m) { diff --git 
a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 009e3dab..449c4c7d 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -326,8 +326,10 @@ void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); - type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); - type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), + py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), + py::is_operator()); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { @@ -347,7 +349,8 @@ void DuckDBPyType::Initialize(py::handle &m) { return make_shared_ptr(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); - type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), py::is_operator()); + type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), + py::is_operator()); py::implicitly_convertible(); py::implicitly_convertible(); diff --git a/tests/fast/api/test_dbapi10.py b/tests/fast/api/test_dbapi10.py index 1fbde602..0ab69e0b 100644 --- a/tests/fast/api/test_dbapi10.py +++ b/tests/fast/api/test_dbapi10.py @@ -12,7 +12,12 @@ class TestCursorDescription(object): ["SELECT * FROM timestamps", "t", "TIMESTAMP", datetime], ["SELECT DATE '1992-09-20' AS date_col;", "date_col", "DATE", date], ["SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BLOB", bytes], - ["SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", "struct_col", "STRUCT(x INTEGER, y INTEGER, z INTEGER)", dict], + [ + "SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", + "struct_col", + "STRUCT(x INTEGER, y INTEGER, z INTEGER)", + dict, + ], ["SELECT [1, 2, 3] AS list_col", "list_col", "INTEGER[]", list], ["SELECT 'Frank' AS str_col", "str_col", "VARCHAR", str], ["SELECT [1, 2, 3]::JSON AS json_col", "json_col", "JSON", str], @@ -32,15 +37,15 @@ def test_description_comparisons(self): NUMBER = duckdb.NUMBER DATETIME = duckdb.DATETIME - assert(types[1] == STRING) - assert(STRING == types[1]) - assert(types[0] != STRING) - assert((types[1] != STRING) == False) - assert((STRING != types[1]) == False) + assert types[1] == STRING + assert STRING == types[1] + assert types[0] != STRING + assert (types[1] != STRING) == False + assert (STRING != types[1]) == False - assert(types[1] in [STRING]) - assert(types[1] in [STRING, NUMBER]) - assert(types[1] not in [NUMBER, DATETIME]) + assert types[1] in [STRING] + assert types[1] in [STRING, NUMBER] + assert types[1] not in [NUMBER, DATETIME] def test_none_description(self, duckdb_empty_cursor): assert duckdb_empty_cursor.description is None diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 01c8a460..41813d94 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ 
b/tests/fast/relational_api/test_rapi_description.py @@ -10,7 +10,7 @@ def test_rapi_description(self, duckdb_cursor): types = [x[1] for x in desc] assert names == ['a', 'b'] assert types == ['INTEGER', 'BIGINT'] - assert (all([x == duckdb.NUMBER for x in types])) + assert all([x == duckdb.NUMBER for x in types]) def test_rapi_describe(self, duckdb_cursor): np = pytest.importorskip("numpy") diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index 15dd6b2b..e67045c4 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -51,9 +51,7 @@ def func(x: int) -> int: """ Error: Catalog Error: Scalar Function with name func does not exist! """ - with pytest.raises( - duckdb.CatalogException, match='Scalar Function with name func does not exist!' - ): + with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): res = rel.fetchall() def test_use_after_remove_and_recreation(self): From b64b3442b023377fdb0e229c87d921f5ead22c97 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Sep 2025 15:29:04 +0200 Subject: [PATCH 002/135] remove Makefile --- Makefile | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 Makefile diff --git a/Makefile b/Makefile deleted file mode 100644 index 07008f11..00000000 --- a/Makefile +++ /dev/null @@ -1,4 +0,0 @@ -PYTHON ?= python3 - -format-main: - $(PYTHON) external/duckdb/scripts/format.py main --fix --noconfirm \ No newline at end of file From 9916b996e52f91416bffb854187d243a846709ec Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Sep 2025 10:14:10 +0200 Subject: [PATCH 003/135] add hash method --- src/duckdb_py/include/duckdb_python/pytype.hpp | 1 + src/duckdb_py/typing/pytype.cpp | 5 +++++ tests/fast/test_type.py | 14 ++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/src/duckdb_py/include/duckdb_python/pytype.hpp b/src/duckdb_py/include/duckdb_python/pytype.hpp index a6e13dfd..6d1e8074 100644 --- a/src/duckdb_py/include/duckdb_python/pytype.hpp +++ b/src/duckdb_py/include/duckdb_python/pytype.hpp @@ -30,6 +30,7 @@ class DuckDBPyType : public enable_shared_from_this { public: bool Equals(const shared_ptr &other) const; + ssize_t Hash() const; bool EqualsString(const string &type_str) const; shared_ptr GetAttribute(const string &name) const; py::list Children() const; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 009e3dab..01357ad3 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -46,6 +46,10 @@ bool DuckDBPyType::Equals(const shared_ptr &other) const { return type == other->type; } +ssize_t DuckDBPyType::Hash() const { + return py::hash(py::str(ToString())); +} + bool DuckDBPyType::EqualsString(const string &type_str) const { return StringUtil::CIEquals(type.ToString(), type_str); } @@ -328,6 +332,7 @@ void DuckDBPyType::Initialize(py::handle &m) { type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); + type_module.def("__hash__", &DuckDBPyType::Hash, "Hashes the type, equal to stringifying+hashing"); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); 
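Note on the filter-pushdown refactor in patch 001 above: the TransformFilter/GetScalar code moved into pyarrow_filter_pushdown.cpp builds a pyarrow.dataset expression from DuckDB table filters, and ProduceScanner hands that expression to the pyarrow Scanner via kwargs["filter"]; the same patch also replaces repeated py::module_::import("pyarrow") calls with import-cache lookups. The snippet below is a rough Python-side sketch (not part of the diff) of the kind of expression those code paths assemble; the column name "ts" and the literal values are hypothetical examples chosen only for illustration.

import pyarrow as pa
import pyarrow.dataset as ds

# Column reference, mirroring field(py::tuple(py::cast(column_ref))) in the C++ code.
ts_field = ds.field("ts")

# Typed literal, mirroring GetScalar(): a typed pa.scalar wrapped in a dataset scalar expression.
literal = ds.scalar(pa.scalar(1_600_000_000_000_000, pa.timestamp("us")))

# Comparison, mirroring constant_field.attr("__gt__")(constant_value).
expr = ts_field > literal

# The resulting expression is what ends up in kwargs["filter"] for the pyarrow Scanner.
table = pa.table({"ts": pa.array([1_500_000_000_000_000, 1_700_000_000_000_000], type=pa.timestamp("us"))})
print(ds.dataset(table).to_table(filter=expr))

The sketch keeps only the single-comparison case; the C++ code additionally combines child expressions with __and__/__or__ for conjunction filters and special-cases NaN comparisons, as shown in the hunks above.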
type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index 6f648179..c5a62694 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -214,6 +214,20 @@ def test_struct_from_dict(self): res = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) assert res == 'STRUCT(a VARCHAR, b VARCHAR)[]' + def test_hash_method(self): + type1 = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) + type2 = duckdb.list_type({'b': VARCHAR, 'a': VARCHAR}) + type3 = VARCHAR + + type_set = set() + type_set.add(type1) + type_set.add(type2) + type_set.add(type3) + + type_set.add(type1) + expected = ['STRUCT(a VARCHAR, b VARCHAR)[]', 'STRUCT(b VARCHAR, a VARCHAR)[]', 'VARCHAR'] + assert sorted([str(x) for x in list(type_set)]) == expected + # NOTE: we can support this, but I don't think going through hoops for an outdated version of python is worth it @pytest.mark.skipif(sys.version_info < (3, 9), reason="python3.7 does not store Optional[..] in a recognized way") def test_optional(self): From b348aa679bf819c83cc890515b40651023658665 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 17 Sep 2025 10:45:41 +0200 Subject: [PATCH 004/135] Packaging workflow should respect the 'minimal' input param --- .github/workflows/packaging.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 507c7bda..16771deb 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -74,7 +74,7 @@ jobs: name: Build and test releases uses: ./.github/workflows/packaging_wheels.yml with: - minimal: false + minimal: ${{ inputs.minimal }} testsuite: all duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} duckdb-sha: ${{ inputs.duckdb-sha }} From ed6b2c5ac4b9f39905867affaf9299cd8a8448db Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Sep 2025 12:23:48 +0200 Subject: [PATCH 005/135] avoid collision with windows define --- src/duckdb_py/include/duckdb_python/pytype.hpp | 2 +- src/duckdb_py/typing/pytype.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/duckdb_py/include/duckdb_python/pytype.hpp b/src/duckdb_py/include/duckdb_python/pytype.hpp index 6d1e8074..fced489e 100644 --- a/src/duckdb_py/include/duckdb_python/pytype.hpp +++ b/src/duckdb_py/include/duckdb_python/pytype.hpp @@ -30,7 +30,7 @@ class DuckDBPyType : public enable_shared_from_this { public: bool Equals(const shared_ptr &other) const; - ssize_t Hash() const; + ssize_t HashType() const; bool EqualsString(const string &type_str) const; shared_ptr GetAttribute(const string &name) const; py::list Children() const; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 35ff81a9..91a95d91 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -46,7 +46,7 @@ bool DuckDBPyType::Equals(const shared_ptr &other) const { return type == other->type; } -ssize_t DuckDBPyType::Hash() const { +ssize_t DuckDBPyType::HashType() const { return py::hash(py::str(ToString())); } @@ -334,7 +334,7 @@ void DuckDBPyType::Initialize(py::handle &m) { py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); - type_module.def("__hash__", &DuckDBPyType::Hash, "Hashes the type, equal to stringifying+hashing"); + type_module.def("__hash__", &DuckDBPyType::HashType, "Hashes 
the type, equal to stringifying+hashing"); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { From f3ab971c029624ce4efc20cb5cfa7730cd21afb4 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Sep 2025 16:36:40 +0200 Subject: [PATCH 006/135] third attempt at making windows happy --- src/duckdb_py/include/duckdb_python/pytype.hpp | 1 - src/duckdb_py/typing/pytype.cpp | 8 +++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/duckdb_py/include/duckdb_python/pytype.hpp b/src/duckdb_py/include/duckdb_python/pytype.hpp index fced489e..a6e13dfd 100644 --- a/src/duckdb_py/include/duckdb_python/pytype.hpp +++ b/src/duckdb_py/include/duckdb_python/pytype.hpp @@ -30,7 +30,6 @@ class DuckDBPyType : public enable_shared_from_this { public: bool Equals(const shared_ptr &other) const; - ssize_t HashType() const; bool EqualsString(const string &type_str) const; shared_ptr GetAttribute(const string &name) const; py::list Children() const; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 91a95d91..f04c14ba 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -46,10 +46,6 @@ bool DuckDBPyType::Equals(const shared_ptr &other) const { return type == other->type; } -ssize_t DuckDBPyType::HashType() const { - return py::hash(py::str(ToString())); -} - bool DuckDBPyType::EqualsString(const string &type_str) const { return StringUtil::CIEquals(type.ToString(), type_str); } @@ -334,7 +330,9 @@ void DuckDBPyType::Initialize(py::handle &m) { py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); - type_module.def("__hash__", &DuckDBPyType::HashType, "Hashes the type, equal to stringifying+hashing"); + type_module.def("__hash__", [](const DuckDBPyType &type) { + return py::hash(py::str(type.ToString())); + }); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { From 7d6f47011a3e82cabbb6901df1c0002fbf7aaa67 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 19:30:17 +0200 Subject: [PATCH 007/135] ruff conf: exclude pyi from linting --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6291b811..4da79b50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -320,6 +320,7 @@ fixable = ["ALL"] exclude = ['external/duckdb'] [tool.ruff.lint] +exclude = ['*.pyi'] select = [ "ANN", # flake8-annotations "B", # flake8-bugbear From db57889ae1f0d644c63c4ea10ccc0fc484038744 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 19:33:50 +0200 Subject: [PATCH 008/135] Ruff UP006: dont use typing module for dict and list typehints --- duckdb/experimental/spark/conf.py | 6 +-- .../spark/errors/exceptions/base.py | 6 +-- duckdb/experimental/spark/errors/utils.py | 2 +- duckdb/experimental/spark/sql/_typing.py | 2 +- duckdb/experimental/spark/sql/catalog.py | 8 ++-- duckdb/experimental/spark/sql/column.py | 4 +- duckdb/experimental/spark/sql/dataframe.py | 30 ++++++------ duckdb/experimental/spark/sql/functions.py | 2 +- duckdb/experimental/spark/sql/group.py | 6 +-- duckdb/experimental/spark/sql/readwriter.py | 8 ++-- 
duckdb/experimental/spark/sql/session.py | 2 +- duckdb/experimental/spark/sql/type_utils.py | 6 +-- duckdb/experimental/spark/sql/types.py | 46 +++++++++---------- duckdb/value/constant/__init__.py | 4 +- duckdb_packaging/build_backend.py | 6 +-- duckdb_packaging/pypi_cleanup.py | 6 +-- .../generate_connection_wrapper_methods.py | 2 +- scripts/generate_import_cache_cpp.py | 14 +++--- scripts/generate_import_cache_json.py | 8 ++-- scripts/get_cpp_methods.py | 4 +- sqllogic/conftest.py | 6 +-- tests/fast/test_filesystem.py | 2 +- tests/fast/test_multithread.py | 4 +- 23 files changed, 92 insertions(+), 92 deletions(-) diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index 11680a9a..a04c993b 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -12,20 +12,20 @@ def contains(self, key: str) -> bool: def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: raise ContributionsAcceptedError - def getAll(self) -> List[Tuple[str, str]]: + def getAll(self) -> list[tuple[str, str]]: raise ContributionsAcceptedError def set(self, key: str, value: str) -> "SparkConf": raise ContributionsAcceptedError - def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf": + def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": raise ContributionsAcceptedError def setAppName(self, value: str) -> "SparkConf": raise ContributionsAcceptedError def setExecutorEnv( - self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[List[Tuple[str, str]]] = None + self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[list[tuple[str, str]]] = None ) -> "SparkConf": raise ContributionsAcceptedError diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 21dba03b..80e91170 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -13,7 +13,7 @@ def __init__( # The error class, decides the message format, must be one of the valid options listed in 'error_classes.py' error_class: Optional[str] = None, # The dictionary listing the arguments specified in the message (or the error_class) - message_parameters: Optional[Dict[str, str]] = None, + message_parameters: Optional[dict[str, str]] = None, ): # `message` vs `error_class` & `message_parameters` are mutually exclusive. assert (message is not None and (error_class is None and message_parameters is None)) or ( @@ -24,7 +24,7 @@ def __init__( if message is None: self.message = self.error_reader.get_error_message( - cast(str, error_class), cast(Dict[str, str], message_parameters) + cast(str, error_class), cast(dict[str, str], message_parameters) ) else: self.message = message @@ -45,7 +45,7 @@ def getErrorClass(self) -> Optional[str]: """ return self.error_class - def getMessageParameters(self) -> Optional[Dict[str, str]]: + def getMessageParameters(self) -> Optional[dict[str, str]]: """ Returns a message parameters as a dictionary. 
diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index a375c0c7..3ef418bd 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -29,7 +29,7 @@ class ErrorClassesReader: def __init__(self) -> None: self.error_info_map = ERROR_CLASSES_MAP - def get_error_message(self, error_class: str, message_parameters: Dict[str, str]) -> str: + def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: """ Returns the completed error message by applying message parameters to the message template. """ diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index 7b1f9ad1..645b60bb 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -57,7 +57,7 @@ float, ) -RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], types.Row) +RowLike = TypeVar("RowLike", list[Any], tuple[Any, ...], types.Row) SQLBatchedUDFType = Literal[100] diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index ebedb1a1..d3b857fb 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -36,7 +36,7 @@ class Catalog: def __init__(self, session: SparkSession): self._session = session - def listDatabases(self) -> List[Database]: + def listDatabases(self) -> list[Database]: res = self._session.conn.sql('select database_name from duckdb_databases()').fetchall() def transform_to_database(x) -> Database: @@ -45,7 +45,7 @@ def transform_to_database(x) -> Database: databases = [transform_to_database(x) for x in res] return databases - def listTables(self) -> List[Table]: + def listTables(self) -> list[Table]: res = self._session.conn.sql('select table_name, database_name, sql, temporary from duckdb_tables()').fetchall() def transform_to_table(x) -> Table: @@ -54,7 +54,7 @@ def transform_to_table(x) -> Table: tables = [transform_to_table(x) for x in res] return tables - def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Column]: + def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: query = f""" select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}' """ @@ -68,7 +68,7 @@ def transform_to_column(x) -> Column: columns = [transform_to_column(x) for x in res] return columns - def listFunctions(self, dbName: Optional[str] = None) -> List[Function]: + def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: raise NotImplementedError def setCurrentDatabase(self, dbName: str) -> None: diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 5f0b2b99..de0c95f8 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -234,10 +234,10 @@ def cast(self, dataType: Union[DataType, str]) -> "Column": def isin(self, *cols: Any) -> "Column": if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list - cols = cast(Tuple, cols[0]) + cols = cast(tuple, cols[0]) cols = cast( - Tuple, + tuple, [_get_expr(c) for c in cols], ) return Column(self.expr.isin(*cols)) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index a81a423b..54c220eb 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -143,7 +143,7 @@ def 
withColumn(self, columnName: str, col: Column) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": + def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": """ Returns a new :class:`DataFrame` by adding multiple columns or replacing the existing columns that have the same names. @@ -218,7 +218,7 @@ def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame": + def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": """ Returns a new :class:`DataFrame` by renaming multiple columns. This is a no-op if the schema doesn't contain the given column names. @@ -356,7 +356,7 @@ def transform( return result def sort( - self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any + self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any ) -> "DataFrame": """Returns a new :class:`DataFrame` sorted by the specified column(s). @@ -487,7 +487,7 @@ def sort( orderBy = sort - def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: + def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: if n is None: rs = self.head(1) return rs[0] if rs else None @@ -495,7 +495,7 @@ def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: first = head - def take(self, num: int) -> List[Row]: + def take(self, num: int) -> list[Row]: return self.limit(num).collect() def filter(self, condition: "ColumnOrName") -> "DataFrame": @@ -579,7 +579,7 @@ def select(self, *cols) -> "DataFrame": return DataFrame(rel, self.session) @property - def columns(self) -> List[str]: + def columns(self) -> list[str]: """Returns all column names as a list. Examples @@ -589,12 +589,12 @@ def columns(self) -> List[str]: """ return [f.name for f in self.schema.fields] - def _ipython_key_completions_(self) -> List[str]: + def _ipython_key_completions_(self) -> list[str]: # Provides tab-completion for column names in PySpark DataFrame # when accessed in bracket notation, e.g. df['] return self.columns - def __dir__(self) -> List[str]: + def __dir__(self) -> list[str]: out = set(super().__dir__()) out.update(c for c in self.columns if c.isidentifier() and not iskeyword(c)) return sorted(out) @@ -602,7 +602,7 @@ def __dir__(self) -> List[str]: def join( self, other: "DataFrame", - on: Optional[Union[str, List[str], Column, List[Column]]] = None, + on: Optional[Union[str, list[str], Column, list[Column]]] = None, how: Optional[str] = None, ) -> "DataFrame": """Joins with another :class:`DataFrame`, using the given join expression. @@ -704,7 +704,7 @@ def join( assert isinstance( on[0], Expression ), "on should be Column or list of Column" - on = reduce(lambda x, y: x.__and__(y), cast(List[Expression], on)) + on = reduce(lambda x, y: x.__and__(y), cast(list[Expression], on)) if on is None and how is None: @@ -893,11 +893,11 @@ def __getitem__(self, item: Union[int, str]) -> Column: ... @overload - def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame": + def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... def __getitem__( - self, item: Union[int, str, Column, List, Tuple] + self, item: Union[int, str, Column, list, tuple] ) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. 
@@ -942,7 +942,7 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload - def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": + def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": ... def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] @@ -1259,7 +1259,7 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": """ return DataFrame(self.relation.except_(other.relation), self.session) - def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": + def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -1391,7 +1391,7 @@ def toDF(self, *cols) -> "DataFrame": new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) - def collect(self) -> List[Row]: + def collect(self) -> list[Row]: columns = self.relation.columns result = self.relation.fetchall() diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index fecada95..78b14de7 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -111,7 +111,7 @@ def struct(*cols: Column) -> Column: def array( - *cols: Union["ColumnOrName", Union[List["ColumnOrName"], Tuple["ColumnOrName", ...]]] + *cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]] ) -> Column: """Creates a new array column. diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index e6e99beb..ad7e7e2a 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -177,7 +177,7 @@ def avg(self, *cols: str) -> DataFrame: if len(columns) == 0: schema = self._df.schema # Take only the numeric types of the relation - columns: List[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)] + columns: list[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)] return _api_internal(self, "avg", *columns) @df_varargs_api @@ -312,10 +312,10 @@ def agg(self, *exprs: Column) -> DataFrame: ... @overload - def agg(self, __exprs: Dict[str, str]) -> DataFrame: + def agg(self, __exprs: dict[str, str]) -> DataFrame: ... - def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: + def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: """Compute aggregates and returns the result as a :class:`DataFrame`. 
The available aggregate functions can be: diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 990201cf..6c8b5e7d 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -26,7 +26,7 @@ def parquet( self, path: str, mode: Optional[str] = None, - partitionBy: Union[str, List[str], None] = None, + partitionBy: Union[str, list[str], None] = None, compression: Optional[str] = None, ) -> None: relation = self.dataframe.relation @@ -94,7 +94,7 @@ def __init__(self, session: "SparkSession"): def load( self, - path: Optional[Union[str, List[str]]] = None, + path: Optional[Union[str, list[str]]] = None, format: Optional[str] = None, schema: Optional[Union[StructType, str]] = None, **options: OptionalPrimitiveType, @@ -131,7 +131,7 @@ def load( def csv( self, - path: Union[str, List[str]], + path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, sep: Optional[str] = None, encoding: Optional[str] = None, @@ -263,7 +263,7 @@ def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame def json( self, - path: Union[str, List[str]], + path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, primitivesAsString: Optional[Union[bool, str]] = None, prefersDecimal: Optional[Union[bool, str]] = None, diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index d3cfaa68..91f9cc0e 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -126,7 +126,7 @@ def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> D def createDataFrame( self, data: Union["PandasDataFrame", Iterable[Any]], - schema: Optional[Union[StructType, List[str]]] = None, + schema: Optional[Union[StructType, list[str]]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, ) -> DataFrame: diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index a17d0f53..ecccc014 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -79,7 +79,7 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType: return ArrayType(convert_type(children[0][1])) # TODO: add support for 'union' if id == 'struct': - children: List[Tuple[str, DuckDBPyType]] = dtype.children + children: list[tuple[str, DuckDBPyType]] = dtype.children fields = [StructField(x[0], convert_type(x[1])) for x in children] return StructType(fields) if id == 'map': @@ -92,7 +92,7 @@ def convert_type(dtype: DuckDBPyType) -> DataType: if id in ['list', 'struct', 'map', 'array']: return convert_nested_type(dtype) if id == 'decimal': - children: List[Tuple[str, DuckDBPyType]] = dtype.children + children: list[tuple[str, DuckDBPyType]] = dtype.children precision = cast(int, children[0][1]) scale = cast(int, children[1][1]) return DecimalType(precision, scale) @@ -100,6 +100,6 @@ def convert_type(dtype: DuckDBPyType) -> DataType: return spark_type() -def duckdb_to_spark_schema(names: List[str], types: List[DuckDBPyType]) -> StructType: +def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types])] return StructType(fields) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 13cd8480..d4dcbd9a 100644 --- 
a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -92,7 +92,7 @@ def typeName(cls) -> str: def simpleString(self) -> str: return self.typeName() - def jsonValue(self) -> Union[str, Dict[str, Any]]: + def jsonValue(self) -> Union[str, dict[str, Any]]: raise ContributionsAcceptedError def json(self) -> str: @@ -124,9 +124,9 @@ def fromInternal(self, obj: Any) -> Any: class DataTypeSingleton(type): """Metaclass for DataType""" - _instances: ClassVar[Dict[Type["DataTypeSingleton"], "DataTypeSingleton"]] = {} + _instances: ClassVar[dict[type["DataTypeSingleton"], "DataTypeSingleton"]] = {} - def __call__(cls: Type[T]) -> T: # type: ignore[override] + def __call__(cls: type[T]) -> T: # type: ignore[override] if cls not in cls._instances: # type: ignore[attr-defined] cls._instances[cls] = super(DataTypeSingleton, cls).__call__() # type: ignore[misc, attr-defined] return cls._instances[cls] # type: ignore[attr-defined] @@ -603,12 +603,12 @@ def __repr__(self) -> str: def needConversion(self) -> bool: return self.elementType.needConversion() - def toInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: + def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: if not self.needConversion(): return obj return obj and [self.elementType.toInternal(v) for v in obj] - def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: + def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: if not self.needConversion(): return obj return obj and [self.elementType.fromInternal(v) for v in obj] @@ -670,12 +670,12 @@ def __repr__(self) -> str: def needConversion(self) -> bool: return self.keyType.needConversion() or self.valueType.needConversion() - def toInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: + def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: if not self.needConversion(): return obj return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) for k, v in obj.items()) - def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: + def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: if not self.needConversion(): return obj return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()) @@ -710,7 +710,7 @@ def __init__( name: str, dataType: DataType, nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ): super().__init__(dataType.duckdb_type) assert isinstance(dataType, DataType), "dataType %s should be an instance of %s" % ( @@ -776,7 +776,7 @@ class StructType(DataType): def _update_internal_duckdb_type(self): self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) - def __init__(self, fields: Optional[List[StructField]] = None): + def __init__(self, fields: Optional[list[StructField]] = None): if not fields: self.fields = [] self.names = [] @@ -795,7 +795,7 @@ def add( field: str, data_type: Union[str, DataType], nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> "StructType": ... @@ -808,7 +808,7 @@ def add( field: Union[str, StructField], data_type: Optional[Union[str, DataType]] = None, nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> "StructType": """ Construct a :class:`StructType` by adding new elements to it, to define the schema. 
@@ -900,7 +900,7 @@ def __repr__(self) -> str: def __contains__(self, item: Any) -> bool: return item in self.names - def extract_types_and_names(self) -> Tuple[List[str], List[str]]: + def extract_types_and_names(self) -> tuple[list[str], list[str]]: names = [] types = [] for f in self.fields: @@ -908,7 +908,7 @@ def extract_types_and_names(self) -> Tuple[List[str], List[str]]: names.append(f.name) return (types, names) - def fieldNames(self) -> List[str]: + def fieldNames(self) -> list[str]: """ Returns all field names in a list. @@ -924,7 +924,7 @@ def needConversion(self) -> bool: # We need convert Row()/namedtuple into tuple() return True - def toInternal(self, obj: Tuple) -> Tuple: + def toInternal(self, obj: tuple) -> tuple: if obj is None: return @@ -956,14 +956,14 @@ def toInternal(self, obj: Tuple) -> Tuple: else: raise ValueError("Unexpected tuple %r with StructType" % obj) - def fromInternal(self, obj: Tuple) -> "Row": + def fromInternal(self, obj: tuple) -> "Row": if obj is None: return if isinstance(obj, Row): # it's already converted by pickler return obj - values: Union[Tuple, List] + values: Union[tuple, list] if self._needSerializeAnyField: # Only calling fromInternal function for fields that need conversion values = [f.fromInternal(v) if c else v for f, v, c in zip(self.fields, obj, self._needConversion)] @@ -1052,7 +1052,7 @@ def __eq__(self, other: Any) -> bool: return type(self) == type(other) -_atomic_types: List[Type[DataType]] = [ +_atomic_types: list[type[DataType]] = [ StringType, BinaryType, BooleanType, @@ -1068,14 +1068,14 @@ def __eq__(self, other: Any) -> bool: TimestampNTZType, NullType, ] -_all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) +_all_atomic_types: dict[str, type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) -_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [ +_complex_types: list[type[Union[ArrayType, MapType, StructType]]] = [ ArrayType, MapType, StructType, ] -_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict( +_all_complex_types: dict[str, type[Union[ArrayType, MapType, StructType]]] = dict( (v.typeName(), v) for v in _complex_types ) @@ -1084,7 +1084,7 @@ def __eq__(self, other: Any) -> bool: _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") -def _create_row(fields: Union["Row", List[str]], values: Union[Tuple[Any, ...], List[Any]]) -> "Row": +def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], list[Any]]) -> "Row": row = Row(*values) row.__fields__ = fields return row @@ -1166,7 +1166,7 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # create row class or objects return tuple.__new__(cls, args) - def asDict(self, recursive: bool = False) -> Dict[str, Any]: + def asDict(self, recursive: bool = False) -> dict[str, Any]: """ Return as a dict @@ -1260,7 +1260,7 @@ def __setattr__(self, key: Any, value: Any) -> None: def __reduce__( self, - ) -> Union[str, Tuple[Any, ...]]: + ) -> Union[str, tuple[Any, ...]]: """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) diff --git a/duckdb/value/constant/__init__.py b/duckdb/value/constant/__init__.py index da2004b9..0a5a62c0 100644 --- a/duckdb/value/constant/__init__.py +++ b/duckdb/value/constant/__init__.py @@ -210,7 +210,7 @@ def __init__(self, object: Any, child_type: DuckDBPyType): class 
StructValue(Value): - def __init__(self, object: Any, children: Dict[str, DuckDBPyType]): + def __init__(self, object: Any, children: dict[str, DuckDBPyType]): import duckdb struct_type = duckdb.struct_type(children) @@ -226,7 +226,7 @@ def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType class UnionType(Value): - def __init__(self, object: Any, members: Dict[str, DuckDBPyType]): + def __init__(self, object: Any, members: dict[str, DuckDBPyType]): import duckdb union_type = duckdb.union_type(members) diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index d96a4847..de1a9535 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -126,7 +126,7 @@ def _read_duckdb_long_version() -> str: def _skbuild_config_add( - key: str, value: Union[List, str], config_settings: Dict[str, Union[List[str],str]], fail_if_exists: bool=False + key: str, value: Union[list, str], config_settings: dict[str, Union[list[str],str]], fail_if_exists: bool=False ): """Add or modify a configuration setting for scikit-build-core. @@ -178,7 +178,7 @@ def _skbuild_config_add( ) -def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[List[str],str]]] = None) -> str: +def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str],str]]] = None) -> str: """Build a source distribution using the DuckDB submodule. This function extracts the DuckDB version from either the git submodule and saves it @@ -208,7 +208,7 @@ def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[ def build_wheel( wheel_directory: str, - config_settings: Optional[Dict[str, Union[List[str],str]]] = None, + config_settings: Optional[dict[str, Union[list[str],str]]] = None, metadata_directory: Optional[str] = None, ) -> str: """Build a wheel from either git submodule or extracted sdist sources. 
diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 81d4c8e0..8236dd1d 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -290,7 +290,7 @@ def _execute_cleanup(self, http_session: Session) -> int: logging.info(f"Successfully cleaned up {len(versions_to_delete)} development versions") return 0 - def _fetch_released_versions(self, http_session: Session) -> Set[str]: + def _fetch_released_versions(self, http_session: Session) -> set[str]: """Fetch package release information from PyPI API.""" logging.debug(f"Fetching package information for '{self._package}'") @@ -330,7 +330,7 @@ def _parse_dev_version(self, version: str) -> tuple[str, int]: raise PyPICleanupError(f"Invalid dev version '{version}'") return match.group("version"), int(match.group("dev_id")) - def _determine_versions_to_delete(self, versions: Set[str]) -> Set[str]: + def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: """Determine which package versions should be deleted.""" logging.debug("Analyzing versions to determine cleanup candidates") @@ -488,7 +488,7 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp raise AuthenticationError("Two-factor authentication failed after all attempts") - def _delete_versions(self, http_session: Session, versions_to_delete: Set[str]) -> None: + def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) -> None: """Delete the specified package versions.""" logging.info(f"Starting deletion of {len(versions_to_delete)} development versions") diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index 7be7256c..af5ad4ac 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -71,7 +71,7 @@ def is_py_kwargs(method): return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True -def remove_section(content, start_marker, end_marker) -> Tuple[List[str], List[str]]: +def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[str]]: start_index = -1 end_index = -1 for i, line in enumerate(content): diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index 07744e37..f1f9d983 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -16,7 +16,7 @@ # deal with leaf nodes?? 
Those are just PythonImportCacheItem def get_class_name(path: str) -> str: - parts: List[str] = path.replace('_', '').split('.') + parts: list[str] = path.replace('_', '').split('.') parts = [x.title() for x in parts] return ''.join(parts) + 'CacheItem' @@ -31,7 +31,7 @@ def get_variable_name(name: str) -> str: return name -def collect_items_of_module(module: dict, collection: Dict): +def collect_items_of_module(module: dict, collection: dict): global json_data children = module['children'] collection[module['full_path']] = module @@ -122,8 +122,8 @@ def to_string(self): """ -def collect_classes(items: Dict) -> List: - output: List = [] +def collect_classes(items: dict) -> list: + output: list = [] for item in items.values(): if item['children'] == []: continue @@ -174,7 +174,7 @@ def to_string(self): return string -files: List[ModuleFile] = [] +files: list[ModuleFile] = [] for name, value in json_data.items(): if value['full_path'] != value['name']: continue @@ -188,7 +188,7 @@ def to_string(self): f.write(content) -def get_root_modules(files: List[ModuleFile]): +def get_root_modules(files: list[ModuleFile]): modules = [] for file in files: name = file.module['name'] @@ -244,7 +244,7 @@ def get_root_modules(files: List[ModuleFile]): f.write(import_cache_file) -def get_module_file_path_includes(files: List[ModuleFile]): +def get_module_file_path_includes(files: list[ModuleFile]): includes = [] for file in files: includes.append(f'#include "duckdb_python/import_cache/modules/{file.file_name}"') diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 40e6a773..53d98c57 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -4,7 +4,7 @@ from typing import List, Dict, Union import json -lines: List[str] = [file for file in open(f'{script_dir}/imports.py').read().split('\n') if file != ''] +lines: list[str] = [file for file in open(f'{script_dir}/imports.py').read().split('\n') if file != ''] class ImportCacheAttribute: @@ -13,7 +13,7 @@ def __init__(self, full_path: str): self.type = "attribute" self.name = parts[-1] self.full_path = full_path - self.children: Dict[str, "ImportCacheAttribute"] = {} + self.children: dict[str, "ImportCacheAttribute"] = {} def has_item(self, item_name: str) -> bool: return item_name in self.children @@ -46,7 +46,7 @@ def __init__(self, full_path): self.type = "module" self.name = parts[-1] self.full_path = full_path - self.items: Dict[str, Union[ImportCacheAttribute, "ImportCacheModule"]] = {} + self.items: dict[str, Union[ImportCacheAttribute, "ImportCacheModule"]] = {} def add_item(self, item: Union[ImportCacheAttribute, "ImportCacheModule"]): assert self.full_path != item.full_path @@ -79,7 +79,7 @@ def root_module(self) -> bool: class ImportCacheGenerator: def __init__(self): - self.modules: Dict[str, ImportCacheModule] = {} + self.modules: dict[str, ImportCacheModule] = {} def add_module(self, path: str): assert path.startswith('import') diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index e784d054..9f86b4cb 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -16,7 +16,7 @@ def __init__(self, name: str, proto: str): class ConnectionMethod: - def __init__(self, name: str, params: List[FunctionParam], is_void: bool): + def __init__(self, name: str, params: list[FunctionParam], is_void: bool): self.name = name self.params = params self.is_void = is_void @@ -49,7 +49,7 @@ def on_class_method(self, state, node): self.methods_dict[name] 
= ConnectionMethod(name, params, is_void) -def get_methods(class_name: str) -> Dict[str, ConnectionMethod]: +def get_methods(class_name: str) -> dict[str, ConnectionMethod]: CLASSES = { "DuckDBPyConnection": os.path.join( scripts_folder, diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 73219e0d..64ad8edc 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -90,7 +90,7 @@ def get_test_id(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Confi return str(path.relative_to(root_dir.parent)) -def get_test_marks(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Config) -> typing.List[typing.Any]: +def get_test_marks(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Config) -> list[typing.Any]: # Tests are tagged with the their category (i.e., name of their parent directory) category = path.parent.name @@ -142,7 +142,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): if metafunc.definition.name != SQLLOGIC_TEST_CASE_NAME: return - test_dirs: typing.List[pathlib.Path] = metafunc.config.getoption("test_dirs") + test_dirs: list[pathlib.Path] = metafunc.config.getoption("test_dirs") test_glob: typing.Optional[pathlib.Path] = metafunc.config.getoption("path") parameters = [] @@ -165,7 +165,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): metafunc.parametrize(SQLLOGIC_TEST_PARAMETER, parameters) -def determine_test_offsets(config: pytest.Config, num_tests: int) -> typing.Tuple[int, int]: +def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, int]: """ If start_offset and end_offset are specified, then these are used. start_offset defaults to 0. end_offset defaults to and is capped to the last test index. diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index eaa86398..195de165 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -20,7 +20,7 @@ logging.basicConfig(level=logging.DEBUG) -def intercept(monkeypatch: MonkeyPatch, obj: object, name: str) -> List[str]: +def intercept(monkeypatch: MonkeyPatch, obj: object, name: str) -> list[str]: error_occurred = [] orig = getattr(obj, name) diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index 1ffdfc25..4b470b84 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -20,7 +20,7 @@ def connect_duck(duckdb_conn): assert out == [(42,), (84,), (None,), (128,)] -def everything_succeeded(results: List[bool]): +def everything_succeeded(results: list[bool]): return all([result == True for result in results]) @@ -501,7 +501,7 @@ def test_description(self, duckdb_cursor, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_cursor(self, duckdb_cursor, pandas): - def only_some_succeed(results: List[bool]): + def only_some_succeed(results: list[bool]): if not any([result == True for result in results]): return False if all([result == True for result in results]): From c23b65d51776a181fbf2e40e72ee741da571df8c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:12:46 +0200 Subject: [PATCH 009/135] Ruff ANN204: return type annotations --- duckdb/bytes_io_wrapper.py | 4 +- duckdb/experimental/spark/_globals.py | 10 +-- duckdb/experimental/spark/conf.py | 2 +- duckdb/experimental/spark/context.py | 2 +- .../spark/errors/exceptions/base.py | 2 +- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/catalog.py | 2 +- duckdb/experimental/spark/sql/column.py | 4 +- duckdb/experimental/spark/sql/conf.py 
| 2 +- duckdb/experimental/spark/sql/dataframe.py | 4 +- duckdb/experimental/spark/sql/group.py | 6 +- duckdb/experimental/spark/sql/readwriter.py | 4 +- duckdb/experimental/spark/sql/session.py | 4 +- duckdb/experimental/spark/sql/streaming.py | 4 +- duckdb/experimental/spark/sql/types.py | 70 +++++++++---------- duckdb/experimental/spark/sql/udf.py | 2 +- duckdb/query_graph/__main__.py | 2 +- duckdb/value/constant/__init__.py | 66 ++++++++--------- duckdb_packaging/pypi_cleanup.py | 4 +- scripts/generate_import_cache_cpp.py | 4 +- scripts/generate_import_cache_json.py | 6 +- scripts/get_cpp_methods.py | 10 +-- sqllogic/test_sqllogic.py | 2 +- tests/conftest.py | 14 ++-- tests/fast/api/test_fsspec.py | 2 +- tests/fast/api/test_read_csv.py | 10 +-- tests/fast/arrow/test_arrow_extensions.py | 4 +- tests/fast/arrow/test_arrow_list.py | 2 +- tests/fast/arrow/test_arrow_pycapsule.py | 8 +-- tests/fast/arrow/test_dataset.py | 4 +- .../fast/pandas/test_df_object_resolution.py | 8 +-- tests/fast/pandas/test_pandas_types.py | 2 +- tests/fast/test_expression.py | 2 +- tests/fast/test_multithread.py | 2 +- tests/fast/udf/test_scalar.py | 6 +- 35 files changed, 143 insertions(+), 139 deletions(-) diff --git a/duckdb/bytes_io_wrapper.py b/duckdb/bytes_io_wrapper.py index 829b69cd..0957652b 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -1,5 +1,5 @@ from io import StringIO, TextIOBase -from typing import Union +from typing import Any, Union """ BSD 3-Clause License @@ -48,7 +48,7 @@ def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") # overflow to the front of the bytestring the next time reading is performed self.overflow = b"" - def __getattr__(self, attr: str): + def __getattr__(self, attr: str) -> Any: return getattr(self.buffer, attr) def read(self, n: Union[int, None] = -1) -> bytes: diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index c43287e6..be16be41 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -32,6 +32,8 @@ def foo(arg=pyducdkb.spark._NoValue): Note that this approach is taken after from NumPy. 
""" +from typing import Type + __ALL__ = ["_NoValue"] @@ -54,23 +56,23 @@ class _NoValueType: __instance = None - def __new__(cls): + def __new__(cls) -> '_NoValueType': # ensure that only one instance exists if not cls.__instance: cls.__instance = super(_NoValueType, cls).__new__(cls) return cls.__instance # Make the _NoValue instance falsey - def __nonzero__(self): + def __nonzero__(self) -> bool: return False __bool__ = __nonzero__ # needed for python 2 to preserve identity through a pickle - def __reduce__(self): + def __reduce__(self) -> tuple[Type, tuple]: return (self.__class__, ()) - def __repr__(self): + def __repr__(self) -> str: return "" diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index a04c993b..79706781 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -3,7 +3,7 @@ class SparkConf: - def __init__(self): + def __init__(self) -> None: raise NotImplementedError def contains(self, key: str) -> bool: diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index a2e7c78f..95227add 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -7,7 +7,7 @@ class SparkContext: - def __init__(self, master: str): + def __init__(self, master: str) -> None: self._connection = duckdb.connect(':memory:') # This aligns the null ordering with Spark. self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'") diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 80e91170..fcdce827 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -14,7 +14,7 @@ def __init__( error_class: Optional[str] = None, # The dictionary listing the arguments specified in the message (or the error_class) message_parameters: Optional[dict[str, str]] = None, - ): + ) -> None: # `message` vs `error_class` & `message_parameters` are mutually exclusive. assert (message is not None and (error_class is None and message_parameters is None)) or ( message is None and (error_class is not None and message_parameters is not None) diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 7cb47650..21668cf5 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -5,7 +5,7 @@ class ContributionsAcceptedError(NotImplementedError): feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb """ - def __init__(self, message=None): + def __init__(self, message=None) -> None: doc = self.__class__.__doc__ if message: doc = message + '\n' + doc diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index d3b857fb..0cd790f7 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -33,7 +33,7 @@ class Function(NamedTuple): class Catalog: - def __init__(self, session: SparkSession): + def __init__(self, session: SparkSession) -> None: self._session = session def listDatabases(self) -> list[Database]: diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index de0c95f8..0dd86178 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -95,11 +95,11 @@ class Column: .. 
versionadded:: 1.3.0 """ - def __init__(self, expr: Expression): + def __init__(self, expr: Expression) -> None: self.expr = expr # arithmetic operators - def __neg__(self): + def __neg__(self) -> 'Column': return Column(-self.expr) # `and`, `or`, `not` cannot be overloaded in Python, diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index 98b773fb..8e30d7ca 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ b/duckdb/experimental/spark/sql/conf.py @@ -4,7 +4,7 @@ class RuntimeConfig: - def __init__(self, connection: DuckDBPyConnection): + def __init__(self, connection: DuckDBPyConnection) -> None: self._connection = connection def set(self, key: str, value: str) -> None: diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 54c220eb..42a5b8f0 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -37,7 +37,7 @@ class DataFrame: - def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession"): + def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: self.relation = relation self.session = session self._schema = None @@ -870,7 +870,7 @@ def limit(self, num: int) -> "DataFrame": rel = self.relation.limit(num) return DataFrame(rel, self.session) - def __contains__(self, item: str): + def __contains__(self, item: str) -> bool: """ Check if the :class:`DataFrame` contains a column by the name of `item` """ diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index ad7e7e2a..4c4d5bb6 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -53,7 +53,7 @@ def _api(self: "GroupedData", *cols: str) -> DataFrame: class Grouping: - def __init__(self, *cols: "ColumnOrName", **kwargs): + def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: self._type = "" self._cols = [_to_column_expr(x) for x in cols] if 'special' in kwargs: @@ -66,7 +66,7 @@ def get_columns(self) -> str: columns = ",".join([str(x) for x in self._cols]) return columns - def __str__(self): + def __str__(self) -> str: columns = self.get_columns() if self._type: return self._type + '(' + columns + ')' @@ -80,7 +80,7 @@ class GroupedData: """ - def __init__(self, grouping: Grouping, df: DataFrame): + def __init__(self, grouping: Grouping, df: DataFrame) -> None: self._grouping = grouping self._df = df self.session: SparkSession = df.session diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 6c8b5e7d..6e8c72c6 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -15,7 +15,7 @@ class DataFrameWriter: - def __init__(self, dataframe: "DataFrame"): + def __init__(self, dataframe: "DataFrame") -> None: self.dataframe = dataframe def saveAsTable(self, table_name: str) -> None: @@ -89,7 +89,7 @@ def csv( class DataFrameReader: - def __init__(self, session: "SparkSession"): + def __init__(self, session: "SparkSession") -> None: self.session = session def load( diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index 91f9cc0e..744a77e8 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -45,7 +45,7 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType): class SparkSession: - def __init__(self, context: SparkContext): + def __init__(self, 
context: SparkContext) -> None: self.conn = context.connection self._context = context self._conf = RuntimeConfig(self.conn) @@ -258,7 +258,7 @@ def version(self) -> str: return '1.0.0' class Builder: - def __init__(self): + def __init__(self) -> None: pass def master(self, name: str) -> "SparkSession.Builder": diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 5414344f..cda80602 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -10,7 +10,7 @@ class DataStreamWriter: - def __init__(self, dataframe: "DataFrame"): + def __init__(self, dataframe: "DataFrame") -> None: self.dataframe = dataframe def toTable(self, table_name: str) -> None: @@ -19,7 +19,7 @@ def toTable(self, table_name: str) -> None: class DataStreamReader: - def __init__(self, session: "SparkSession"): + def __init__(self, session: "SparkSession") -> None: self.session = session def load( diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index d4dcbd9a..4b3a4132 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -70,7 +70,7 @@ class DataType: """Base class for data types.""" - def __init__(self, duckdb_type): + def __init__(self, duckdb_type) -> None: self.duckdb_type = duckdb_type def __repr__(self) -> str: @@ -138,7 +138,7 @@ class NullType(DataType, metaclass=DataTypeSingleton): The data type representing None, used for the types that cannot be inferred. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("NULL")) @classmethod @@ -166,42 +166,42 @@ class FractionalType(NumericType): class StringType(AtomicType, metaclass=DataTypeSingleton): """String data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("VARCHAR")) class BitstringType(AtomicType, metaclass=DataTypeSingleton): """Bitstring data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BIT")) class UUIDType(AtomicType, metaclass=DataTypeSingleton): """UUID data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UUID")) class BinaryType(AtomicType, metaclass=DataTypeSingleton): """Binary (byte array) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BLOB")) class BooleanType(AtomicType, metaclass=DataTypeSingleton): """Boolean data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BOOLEAN")) class DateType(AtomicType, metaclass=DataTypeSingleton): """Date (datetime.date) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("DATE")) EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() @@ -221,7 +221,7 @@ def fromInternal(self, v: int) -> datetime.date: class TimestampType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMPTZ")) @classmethod @@ -245,7 +245,7 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with microsecond precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP")) def needConversion(self) -> bool: @@ -269,7 +269,7 @@ def fromInternal(self, 
ts: int) -> datetime.datetime: class TimestampSecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with second precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP_S")) def needConversion(self) -> bool: @@ -289,7 +289,7 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampMilisecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with milisecond precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP_MS")) def needConversion(self) -> bool: @@ -309,7 +309,7 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampNanosecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with nanosecond precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP_NS")) def needConversion(self) -> bool: @@ -346,7 +346,7 @@ class DecimalType(FractionalType): the number of digits on right side of dot. (default: 0) """ - def __init__(self, precision: int = 10, scale: int = 0): + def __init__(self, precision: int = 10, scale: int = 0) -> None: super().__init__(duckdb.decimal_type(precision, scale)) self.precision = precision self.scale = scale @@ -362,21 +362,21 @@ def __repr__(self) -> str: class DoubleType(FractionalType, metaclass=DataTypeSingleton): """Double data type, representing double precision floats.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("DOUBLE")) class FloatType(FractionalType, metaclass=DataTypeSingleton): """Float data type, representing single precision floats.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("FLOAT")) class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TINYINT")) def simpleString(self) -> str: @@ -386,7 +386,7 @@ def simpleString(self) -> str: class UnsignedByteType(IntegralType): """Unsigned byte data type, i.e. a unsigned integer in a single byte.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UTINYINT")) def simpleString(self) -> str: @@ -396,7 +396,7 @@ def simpleString(self) -> str: class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("SMALLINT")) def simpleString(self) -> str: @@ -406,7 +406,7 @@ def simpleString(self) -> str: class UnsignedShortType(IntegralType): """Unsigned short data type, i.e. a unsigned 16-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("USMALLINT")) def simpleString(self) -> str: @@ -416,7 +416,7 @@ def simpleString(self) -> str: class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("INTEGER")) def simpleString(self) -> str: @@ -426,7 +426,7 @@ def simpleString(self) -> str: class UnsignedIntegerType(IntegralType): """Unsigned int data type, i.e. 
a unsigned 32-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UINTEGER")) def simpleString(self) -> str: @@ -440,7 +440,7 @@ class LongType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BIGINT")) def simpleString(self) -> str: @@ -454,7 +454,7 @@ class UnsignedLongType(IntegralType): please use :class:`HugeIntegerType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UBIGINT")) def simpleString(self) -> str: @@ -468,7 +468,7 @@ class HugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("HUGEINT")) def simpleString(self) -> str: @@ -482,7 +482,7 @@ class UnsignedHugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UHUGEINT")) def simpleString(self) -> str: @@ -492,7 +492,7 @@ def simpleString(self) -> str: class TimeType(IntegralType): """Time (datetime.time) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMETZ")) def simpleString(self) -> str: @@ -502,7 +502,7 @@ def simpleString(self) -> str: class TimeNTZType(IntegralType): """Time (datetime.time) data type without timezone information.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIME")) def simpleString(self) -> str: @@ -526,7 +526,7 @@ class DayTimeIntervalType(AtomicType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) - def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): + def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: super().__init__(DuckDBPyType("INTERVAL")) if startField is None and endField is None: # Default matched to scala side. 
@@ -585,7 +585,7 @@ class ArrayType(DataType): False """ - def __init__(self, elementType: DataType, containsNull: bool = True): + def __init__(self, elementType: DataType, containsNull: bool = True) -> None: super().__init__(duckdb.list_type(elementType.duckdb_type)) assert isinstance(elementType, DataType), "elementType %s should be an instance of %s" % ( elementType, @@ -640,7 +640,7 @@ class MapType(DataType): False """ - def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True): + def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True) -> None: super().__init__(duckdb.map_type(keyType.duckdb_type, valueType.duckdb_type)) assert isinstance(keyType, DataType), "keyType %s should be an instance of %s" % ( keyType, @@ -711,7 +711,7 @@ def __init__( dataType: DataType, nullable: bool = True, metadata: Optional[dict[str, Any]] = None, - ): + ) -> None: super().__init__(dataType.duckdb_type) assert isinstance(dataType, DataType), "dataType %s should be an instance of %s" % ( dataType, @@ -776,7 +776,7 @@ class StructType(DataType): def _update_internal_duckdb_type(self): self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) - def __init__(self, fields: Optional[list[StructField]] = None): + def __init__(self, fields: Optional[list[StructField]] = None) -> None: if not fields: self.fields = [] self.names = [] @@ -973,7 +973,7 @@ def fromInternal(self, obj: tuple) -> "Row": class UnionType(DataType): - def __init__(self): + def __init__(self) -> None: raise ContributionsAcceptedError @@ -983,7 +983,7 @@ class UserDefinedType(DataType): .. note:: WARN: Spark Internal Use Only """ - def __init__(self): + def __init__(self) -> None: raise ContributionsAcceptedError @classmethod diff --git a/duckdb/experimental/spark/sql/udf.py b/duckdb/experimental/spark/sql/udf.py index 61d3bee9..389d43ab 100644 --- a/duckdb/experimental/spark/sql/udf.py +++ b/duckdb/experimental/spark/sql/udf.py @@ -11,7 +11,7 @@ class UDFRegistration: - def __init__(self, sparkSession: "SparkSession"): + def __init__(self, sparkSession: "SparkSession") -> None: self.sparkSession = sparkSession def register( diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index 26038a6f..eab68179 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -95,7 +95,7 @@ def combine_timing(l: object, r: object) -> object: class AllTimings: - def __init__(self): + def __init__(self) -> None: self.phase_to_timings = {} def add_node_timing(self, node_timing: NodeTiming): diff --git a/duckdb/value/constant/__init__.py b/duckdb/value/constant/__init__.py index 0a5a62c0..fb7d7284 100644 --- a/duckdb/value/constant/__init__.py +++ b/duckdb/value/constant/__init__.py @@ -32,7 +32,7 @@ class Value: - def __init__(self, object: Any, type: DuckDBPyType): + def __init__(self, object: Any, type: DuckDBPyType) -> None: self.object = object self.type = type @@ -44,12 +44,12 @@ def __repr__(self) -> str: class NullValue(Value): - def __init__(self): + def __init__(self) -> None: super().__init__(None, SQLNULL) class BooleanValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BOOLEAN) @@ -57,22 +57,22 @@ def __init__(self, object: Any): class UnsignedBinaryValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UTINYINT) class UnsignedShortValue(Value): - def __init__(self, 
object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, USMALLINT) class UnsignedIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UINTEGER) class UnsignedLongValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UBIGINT) @@ -80,32 +80,32 @@ def __init__(self, object: Any): class BinaryValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TINYINT) class ShortValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, SMALLINT) class IntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, INTEGER) class LongValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BIGINT) class HugeIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, HUGEINT) class UnsignedHugeIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UHUGEINT) @@ -113,17 +113,17 @@ def __init__(self, object: Any): class FloatValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, FLOAT) class DoubleValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, DOUBLE) class DecimalValue(Value): - def __init__(self, object: Any, width: int, scale: int): + def __init__(self, object: Any, width: int, scale: int) -> None: import duckdb decimal_type = duckdb.decimal_type(width, scale) @@ -134,22 +134,22 @@ def __init__(self, object: Any, width: int, scale: int): class StringValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, VARCHAR) class UUIDValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UUID) class BitValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BIT) class BlobValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BLOB) @@ -157,52 +157,52 @@ def __init__(self, object: Any): class DateValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, DATE) class IntervalValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, INTERVAL) class TimestampValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP) class TimestampSecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_S) class TimestampMilisecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_MS) class TimestampNanosecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_NS) class TimestampTimeZoneValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_TZ) class TimeValue(Value): - 
def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIME) class TimeTimeZoneValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIME_TZ) class ListValue(Value): - def __init__(self, object: Any, child_type: DuckDBPyType): + def __init__(self, object: Any, child_type: DuckDBPyType) -> None: import duckdb list_type = duckdb.list_type(child_type) @@ -210,7 +210,7 @@ def __init__(self, object: Any, child_type: DuckDBPyType): class StructValue(Value): - def __init__(self, object: Any, children: dict[str, DuckDBPyType]): + def __init__(self, object: Any, children: dict[str, DuckDBPyType]) -> None: import duckdb struct_type = duckdb.struct_type(children) @@ -218,7 +218,7 @@ def __init__(self, object: Any, children: dict[str, DuckDBPyType]): class MapValue(Value): - def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType): + def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType) -> None: import duckdb map_type = duckdb.map_type(key_type, value_type) @@ -226,7 +226,7 @@ def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType class UnionType(Value): - def __init__(self, object: Any, members: dict[str, DuckDBPyType]): + def __init__(self, object: Any, members: dict[str, DuckDBPyType]) -> None: import duckdb union_type = duckdb.union_type(members) diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 8236dd1d..031adf94 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -183,7 +183,7 @@ class CsrfParser(HTMLParser): Based on pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) """ - def __init__(self, target, contains_input=None): + def __init__(self, target, contains_input=None) -> None: super().__init__() self._target = target self._contains_input = contains_input @@ -223,7 +223,7 @@ class PyPICleanup: """Main class for performing PyPI package cleanup operations.""" def __init__(self, index_url: str, do_delete: bool, max_dev_releases: int=_DEFAULT_MAX_NIGHTLIES, - username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None): + username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None) -> None: parsed_url = urlparse(index_url) self._index_url = parsed_url.geturl().rstrip('/') self._index_host = parsed_url.hostname diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index f1f9d983..f03d8d89 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -40,7 +40,7 @@ def collect_items_of_module(module: dict, collection: dict): class CacheItem: - def __init__(self, module: dict, items): + def __init__(self, module: dict, items) -> None: self.name = module['name'] self.module = module self.items = items @@ -132,7 +132,7 @@ def collect_classes(items: dict) -> list: class ModuleFile: - def __init__(self, module: dict): + def __init__(self, module: dict) -> None: self.module = module self.file_name = get_filename(module['name']) self.items = {} diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 53d98c57..2df33b24 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -8,7 +8,7 @@ class ImportCacheAttribute: - def __init__(self, full_path: str): + def __init__(self, full_path: str) -> None: parts = full_path.split('.') 
self.type = "attribute" self.name = parts[-1] @@ -41,7 +41,7 @@ def populate_json(self, json_data: dict): class ImportCacheModule: - def __init__(self, full_path): + def __init__(self, full_path) -> None: parts = full_path.split('.') self.type = "module" self.name = parts[-1] @@ -78,7 +78,7 @@ def root_module(self) -> bool: class ImportCacheGenerator: - def __init__(self): + def __init__(self) -> None: self.modules: dict[str, ImportCacheModule] = {} def add_module(self, path: str): diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index 9f86b4cb..97b28af3 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -4,30 +4,30 @@ import cxxheaderparser.parser import cxxheaderparser.visitor import cxxheaderparser.preprocessor -from typing import List, Dict +from typing import List, Dict, Callable scripts_folder = os.path.dirname(os.path.abspath(__file__)) class FunctionParam: - def __init__(self, name: str, proto: str): + def __init__(self, name: str, proto: str) -> None: self.proto = proto self.name = name class ConnectionMethod: - def __init__(self, name: str, params: list[FunctionParam], is_void: bool): + def __init__(self, name: str, params: list[FunctionParam], is_void: bool) -> None: self.name = name self.params = params self.is_void = is_void class Visitor: - def __init__(self, class_name: str): + def __init__(self, class_name: str) -> None: self.methods_dict = {} self.class_name = class_name - def __getattr__(self, name): + def __getattr__(self, name) -> Callable[[...], bool]: return lambda *state: True def on_class_start(self, state): diff --git a/sqllogic/test_sqllogic.py b/sqllogic/test_sqllogic.py index ee7426cd..4e7cead0 100644 --- a/sqllogic/test_sqllogic.py +++ b/sqllogic/test_sqllogic.py @@ -39,7 +39,7 @@ def sigquit_handler(signum, frame): class SQLLogicTestExecutor(SQLLogicRunner): - def __init__(self, test_directory: str, build_directory: Optional[str] = None): + def __init__(self, test_directory: str, build_directory: Optional[str] = None) -> None: super().__init__(build_directory) self.test_directory = test_directory # TODO: get this from the `duckdb` package diff --git a/tests/conftest.py b/tests/conftest.py index ce2d0e68..b9950ee7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import os +from typing import Any + import pytest import shutil from os.path import abspath, join, dirname, normpath @@ -121,12 +123,12 @@ def arrow_pandas_df(*args, **kwargs): class NumpyPandas: - def __init__(self): + def __init__(self) -> None: self.backend = 'numpy_nullable' self.DataFrame = numpy_pandas_df self.pandas = import_pandas() - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: return getattr(self.pandas, name) @@ -156,11 +158,11 @@ def convert_and_equal(df1, df2, **kwargs): class ArrowMockTesting: - def __init__(self): + def __init__(self) -> None: self.testing = import_pandas().testing self.assert_frame_equal = convert_and_equal - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: return getattr(self.testing, name) @@ -168,7 +170,7 @@ def __getattr__(self, name: str): # Assert equal does the opposite, turning all pyarrow backed dataframes into numpy backed ones # this is done because we don't produce pyarrow backed dataframes yet class ArrowPandas: - def __init__(self): + def __init__(self) -> None: self.pandas = import_pandas() if pandas_2_or_higher() and pyarrow_dtypes_enabled: self.backend = 'pyarrow' @@ -179,7 +181,7 @@ def __init__(self): self.DataFrame = 
self.pandas.DataFrame self.testing = ArrowMockTesting() - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: return getattr(self.pandas, name) diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index 0a289972..a878fda5 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -44,7 +44,7 @@ def modified(self, path): def _open(self, path, **kwargs): return io.BytesIO(self._data[path]) - def __init__(self): + def __init__(self) -> None: super().__init__() self._data = {"a": parquet_data, "b": parquet_data} diff --git a/tests/fast/api/test_read_csv.py b/tests/fast/api/test_read_csv.py index 1a297109..7337515d 100644 --- a/tests/fast/api/test_read_csv.py +++ b/tests/fast/api/test_read_csv.py @@ -327,7 +327,7 @@ def test_filelike_exception(self, duckdb_cursor): _ = pytest.importorskip("fsspec") class ReadError: - def __init__(self): + def __init__(self) -> None: pass def read(self, amount=-1): @@ -337,7 +337,7 @@ def seek(self, loc): return 0 class SeekError: - def __init__(self): + def __init__(self) -> None: pass def read(self, amount=-1): @@ -359,7 +359,7 @@ def test_filelike_custom(self, duckdb_cursor): _ = pytest.importorskip("fsspec") class CustomIO: - def __init__(self): + def __init__(self) -> None: self.loc = 0 pass @@ -398,11 +398,11 @@ def test_internal_object_filesystem_cleanup(self, duckdb_cursor): class CountedObject(StringIO): instance_count = 0 - def __init__(self, str): + def __init__(self, str) -> None: CountedObject.instance_count += 1 super().__init__(str) - def __del__(self): + def __del__(self) -> None: CountedObject.instance_count -= 1 def scoped_objects(duckdb_cursor): diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index 9180fa90..95a2108a 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -116,10 +116,10 @@ def test_function(x): def test_unimplemented_extension(self, duckdb_cursor): class MyType(pa.ExtensionType): - def __init__(self): + def __init__(self) -> None: pa.ExtensionType.__init__(self, pa.binary(5), "pedro.binary") - def __arrow_ext_serialize__(self): + def __arrow_ext_serialize__(self) -> bytes: return b'' @classmethod diff --git a/tests/fast/arrow/test_arrow_list.py b/tests/fast/arrow/test_arrow_list.py index e2449fd3..556f614a 100644 --- a/tests/fast/arrow/test_arrow_list.py +++ b/tests/fast/arrow/test_arrow_list.py @@ -41,7 +41,7 @@ def create_and_register_comparison_result(column_list, duckdb_cursor): class ListGenerationResult: - def __init__(self, list, list_view): + def __init__(self, list, list_view) -> None: self.list = list self.list_view = list_view diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index c293344d..8310c58b 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -17,11 +17,11 @@ def polars_supports_capsule(): class TestArrowPyCapsule(object): def test_polars_pycapsule_scan(self, duckdb_cursor): class MyObject: - def __init__(self, obj): + def __init__(self, obj) -> None: self.obj = obj self.count = 0 - def __arrow_c_stream__(self, requested_schema=None): + def __arrow_c_stream__(self, requested_schema=None) -> object: self.count += 1 return self.obj.__arrow_c_stream__(requested_schema=requested_schema) @@ -71,11 +71,11 @@ def test_automatic_reexecution(self, duckdb_cursor): def test_consumer_interface_roundtrip(self, duckdb_cursor): def create_table(): class 
MyTable: - def __init__(self, rel, conn): + def __init__(self, rel, conn) -> None: self.rel = rel self.conn = conn - def __arrow_c_stream__(self, requested_schema=None): + def __arrow_c_stream__(self, requested_schema=None) -> object: return self.rel.__arrow_c_stream__(requested_schema=requested_schema) conn = duckdb.connect() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 2f3d7a53..521ec8f7 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -102,7 +102,7 @@ class CustomDataset(pyarrow.dataset.Dataset): SCHEMA = pyarrow.schema([pyarrow.field("a", pyarrow.int64(), True), pyarrow.field("b", pyarrow.float64(), True)]) DATA = pyarrow.Table.from_arrays([pyarrow.array(range(100)), pyarrow.array(np.arange(100) * 1.0)], schema=SCHEMA) - def __init__(self): + def __init__(self) -> None: pass def scanner(self, **kwargs): @@ -114,7 +114,7 @@ def schema(self): class CustomScanner(pyarrow.dataset.Scanner): - def __init__(self, filter=None, columns=None, **kwargs): + def __init__(self, filter=None, columns=None, **kwargs) -> None: self.filter = filter self.columns = columns self.kwargs = kwargs diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index ed89f324..d54db072 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ b/tests/fast/pandas/test_df_object_resolution.py @@ -30,10 +30,10 @@ def create_trailing_non_null(size): class IntString: - def __init__(self, value: int): + def __init__(self, value: int) -> None: self.value = value - def __str__(self): + def __str__(self) -> str: return str(self.value) @@ -48,11 +48,11 @@ def ConvertStringToDecimal(data: list, pandas): class ObjectPair: - def __init__(self, obj1, obj2): + def __init__(self, obj1, obj2) -> None: self.first = obj1 self.second = obj2 - def __repr__(self): + def __repr__(self) -> str: return str([self.first, self.second]) diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index aeb33ea4..b21c7f14 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -185,7 +185,7 @@ def test_pandas_encoded_utf8(self, duckdb_cursor): ) def test_producing_nullable_dtypes(self, duckdb_cursor, dtype): class Input: - def __init__(self, value, expected_dtype): + def __init__(self, value, expected_dtype) -> None: self.value = value self.expected_dtype = expected_dtype diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index 289d88a9..e0f830c5 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -987,7 +987,7 @@ def test_aggregate_error(self): ): class MyClass: - def __init__(self): + def __init__(self) -> None: pass res = rel.aggregate([MyClass()]).fetchone()[0] diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index 4b470b84..ad2d56fd 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -25,7 +25,7 @@ def everything_succeeded(results: list[bool]): class DuckDBThreaded: - def __init__(self, duckdb_insert_thread_count, thread_function, pandas): + def __init__(self, duckdb_insert_thread_count, thread_function, pandas) -> None: self.duckdb_insert_thread_count = duckdb_insert_thread_count self.threads = [] self.thread_function = thread_function diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index 61648c20..8e0eb8b1 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -4,7 
+4,7 @@ pd = pytest.importorskip("pandas") pa = pytest.importorskip('pyarrow', '18.0.0') -from typing import Union +from typing import Union, Any import pyarrow.compute as pc import uuid import datetime @@ -156,10 +156,10 @@ def test_non_callable(self): con.create_function('func', 5, [BIGINT], BIGINT, type='arrow') class MyCallable: - def __init__(self): + def __init__(self) -> None: pass - def __call__(self, x): + def __call__(self, x) -> Any: return x my_callable = MyCallable() From 83cdb9cfc329053705f4b0d93b4ee20dedcf0ff5 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:24:46 +0200 Subject: [PATCH 010/135] Ruff config: line-length to 120 and fixable no longer top level --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4da79b50..a53f9eb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -312,14 +312,14 @@ branch = true source = ["duckdb"] [tool.ruff] -line-length = 88 +line-length = 120 indent-width = 4 target-version = "py39" fix = true -fixable = ["ALL"] exclude = ['external/duckdb'] [tool.ruff.lint] +fixable = ["ALL"] exclude = ['*.pyi'] select = [ "ANN", # flake8-annotations From d10c477dbe2744ce92d6a152f47a163938b7d110 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:26:45 +0200 Subject: [PATCH 011/135] Ruff format fixes --- duckdb/__init__.py | 571 +++++----- duckdb/__init__.pyi | 784 ++++++++++---- duckdb/bytes_io_wrapper.py | 1 - duckdb/experimental/__init__.py | 1 + duckdb/experimental/spark/_globals.py | 2 +- duckdb/experimental/spark/_typing.py | 9 +- duckdb/experimental/spark/context.py | 2 +- duckdb/experimental/spark/errors/__init__.py | 1 + .../spark/errors/exceptions/base.py | 2 + duckdb/experimental/spark/errors/utils.py | 3 +- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/_typing.py | 19 +- duckdb/experimental/spark/sql/catalog.py | 8 +- duckdb/experimental/spark/sql/column.py | 10 +- duckdb/experimental/spark/sql/dataframe.py | 178 ++-- duckdb/experimental/spark/sql/functions.py | 995 +++++++++++------- duckdb/experimental/spark/sql/group.py | 57 +- duckdb/experimental/spark/sql/readwriter.py | 80 +- duckdb/experimental/spark/sql/session.py | 24 +- duckdb/experimental/spark/sql/streaming.py | 2 +- duckdb/experimental/spark/sql/type_utils.py | 70 +- duckdb/experimental/spark/sql/types.py | 54 +- duckdb/filesystem.py | 3 +- duckdb/functional/__init__.py | 18 +- duckdb/polars_io.py | 60 +- duckdb/query_graph/__main__.py | 104 +- duckdb/typing/__init__.py | 4 +- duckdb/typing/__init__.pyi | 6 +- duckdb/value/constant/__init__.pyi | 7 +- duckdb_packaging/_versioning.py | 32 +- duckdb_packaging/build_backend.py | 20 +- duckdb_packaging/pypi_cleanup.py | 173 ++- duckdb_packaging/setuptools_scm_version.py | 15 +- scripts/generate_connection_code.py | 2 +- scripts/generate_connection_methods.py | 66 +- scripts/generate_connection_stubs.py | 34 +- .../generate_connection_wrapper_methods.py | 132 +-- scripts/generate_connection_wrapper_stubs.py | 48 +- scripts/generate_import_cache_cpp.py | 96 +- scripts/generate_import_cache_json.py | 26 +- sqllogic/conftest.py | 6 +- sqllogic/skipped_tests.py | 76 +- sqllogic/test_sqllogic.py | 14 +- tests/conftest.py | 50 +- .../test_pandas_categorical_coverage.py | 16 +- tests/extensions/json/test_read_json.py | 106 +- tests/extensions/test_extensions_loading.py | 22 +- tests/extensions/test_httpfs.py | 30 +- tests/fast/adbc/test_adbc.py | 34 +- tests/fast/adbc/test_statement_bind.py | 32 +- 
tests/fast/api/test_3324.py | 2 +- tests/fast/api/test_3654.py | 8 +- tests/fast/api/test_3728.py | 4 +- tests/fast/api/test_6315.py | 6 +- tests/fast/api/test_attribute_getter.py | 28 +- tests/fast/api/test_config.py | 58 +- tests/fast/api/test_connection_close.py | 6 +- tests/fast/api/test_cursor.py | 22 +- tests/fast/api/test_dbapi00.py | 36 +- tests/fast/api/test_dbapi01.py | 18 +- tests/fast/api/test_dbapi04.py | 2 +- tests/fast/api/test_dbapi05.py | 24 +- tests/fast/api/test_dbapi07.py | 4 +- tests/fast/api/test_dbapi08.py | 4 +- tests/fast/api/test_dbapi09.py | 8 +- tests/fast/api/test_dbapi12.py | 48 +- tests/fast/api/test_dbapi13.py | 4 +- tests/fast/api/test_dbapi_fetch.py | 88 +- tests/fast/api/test_duckdb_connection.py | 86 +- tests/fast/api/test_duckdb_execute.py | 14 +- tests/fast/api/test_duckdb_query.py | 58 +- tests/fast/api/test_explain.py | 24 +- tests/fast/api/test_fsspec.py | 2 +- tests/fast/api/test_insert_into.py | 14 +- tests/fast/api/test_join.py | 22 +- tests/fast/api/test_native_tz.py | 36 +- tests/fast/api/test_query_interrupt.py | 2 +- tests/fast/api/test_read_csv.py | 394 +++---- tests/fast/api/test_relation_to_view.py | 28 +- tests/fast/api/test_streaming_result.py | 12 +- tests/fast/api/test_to_csv.py | 132 +-- tests/fast/api/test_to_parquet.py | 65 +- .../api/test_with_propagating_exceptions.py | 12 +- tests/fast/arrow/parquet_write_roundtrip.py | 38 +- tests/fast/arrow/test_10795.py | 6 +- tests/fast/arrow/test_12384.py | 10 +- tests/fast/arrow/test_14344.py | 2 +- tests/fast/arrow/test_2426.py | 6 +- tests/fast/arrow/test_5547.py | 2 +- tests/fast/arrow/test_6584.py | 2 +- tests/fast/arrow/test_6796.py | 4 +- tests/fast/arrow/test_7652.py | 4 +- tests/fast/arrow/test_7699.py | 2 +- tests/fast/arrow/test_arrow_batch_index.py | 8 +- tests/fast/arrow/test_arrow_binary_view.py | 4 +- tests/fast/arrow/test_arrow_case_sensitive.py | 16 +- tests/fast/arrow/test_arrow_decimal_32_64.py | 20 +- tests/fast/arrow/test_arrow_extensions.py | 123 ++- tests/fast/arrow/test_arrow_fetch.py | 6 +- .../arrow/test_arrow_fetch_recordbatch.py | 26 +- tests/fast/arrow/test_arrow_fixed_binary.py | 6 +- tests/fast/arrow/test_arrow_ipc.py | 8 +- tests/fast/arrow/test_arrow_list.py | 22 +- tests/fast/arrow/test_arrow_offsets.py | 128 +-- tests/fast/arrow/test_arrow_pycapsule.py | 6 +- .../arrow/test_arrow_recordbatchreader.py | 36 +- .../fast/arrow/test_arrow_replacement_scan.py | 20 +- .../fast/arrow/test_arrow_run_end_encoding.py | 150 ++- tests/fast/arrow/test_arrow_scanner.py | 20 +- tests/fast/arrow/test_arrow_string_view.py | 16 +- tests/fast/arrow/test_arrow_types.py | 8 +- tests/fast/arrow/test_arrow_union.py | 14 +- tests/fast/arrow/test_arrow_version_format.py | 12 +- tests/fast/arrow/test_buffer_size_option.py | 2 +- tests/fast/arrow/test_dataset.py | 12 +- tests/fast/arrow/test_date.py | 20 +- tests/fast/arrow/test_dictionary_arrow.py | 56 +- tests/fast/arrow/test_filter_pushdown.py | 179 ++-- tests/fast/arrow/test_integration.py | 52 +- tests/fast/arrow/test_interval.py | 46 +- tests/fast/arrow/test_large_offsets.py | 4 +- tests/fast/arrow/test_large_string.py | 2 +- tests/fast/arrow/test_multiple_reads.py | 4 +- tests/fast/arrow/test_nested_arrow.py | 58 +- tests/fast/arrow/test_parallel.py | 12 +- tests/fast/arrow/test_polars.py | 84 +- tests/fast/arrow/test_progress.py | 16 +- tests/fast/arrow/test_time.py | 70 +- tests/fast/arrow/test_timestamp_timezone.py | 28 +- tests/fast/arrow/test_timestamps.py | 56 +- tests/fast/arrow/test_tpch.py | 8 +- 
tests/fast/arrow/test_unregister.py | 16 +- tests/fast/arrow/test_view.py | 6 +- tests/fast/numpy/test_numpy_new_path.py | 14 +- tests/fast/pandas/test_2304.py | 72 +- tests/fast/pandas/test_append_df.py | 44 +- tests/fast/pandas/test_bug2281.py | 6 +- tests/fast/pandas/test_bug5922.py | 14 +- tests/fast/pandas/test_copy_on_write.py | 10 +- .../pandas/test_create_table_from_pandas.py | 4 +- tests/fast/pandas/test_date_as_datetime.py | 8 +- tests/fast/pandas/test_datetime_time.py | 20 +- tests/fast/pandas/test_datetime_timestamp.py | 52 +- tests/fast/pandas/test_df_analyze.py | 22 +- .../fast/pandas/test_df_object_resolution.py | 396 +++---- tests/fast/pandas/test_df_recursive_nested.py | 68 +- tests/fast/pandas/test_fetch_df_chunk.py | 14 +- tests/fast/pandas/test_fetch_nested.py | 8 +- .../fast/pandas/test_implicit_pandas_scan.py | 10 +- tests/fast/pandas/test_import_cache.py | 18 +- tests/fast/pandas/test_issue_1767.py | 4 +- tests/fast/pandas/test_limit.py | 10 +- tests/fast/pandas/test_pandas_arrow.py | 94 +- tests/fast/pandas/test_pandas_category.py | 56 +- tests/fast/pandas/test_pandas_enum.py | 8 +- tests/fast/pandas/test_pandas_limit.py | 8 +- tests/fast/pandas/test_pandas_na.py | 22 +- tests/fast/pandas/test_pandas_object.py | 62 +- tests/fast/pandas/test_pandas_string.py | 23 +- tests/fast/pandas/test_pandas_timestamp.py | 16 +- tests/fast/pandas/test_pandas_types.py | 110 +- tests/fast/pandas/test_pandas_unregister.py | 12 +- tests/fast/pandas/test_pandas_update.py | 12 +- .../fast/pandas/test_parallel_pandas_scan.py | 52 +- .../pandas/test_partitioned_pandas_scan.py | 4 +- tests/fast/pandas/test_progress_bar.py | 16 +- .../test_pyarrow_projection_pushdown.py | 4 +- tests/fast/pandas/test_same_name.py | 50 +- tests/fast/pandas/test_stride.py | 22 +- tests/fast/pandas/test_timedelta.py | 20 +- tests/fast/pandas/test_timestamp.py | 34 +- tests/fast/relational_api/test_groupings.py | 6 +- tests/fast/relational_api/test_joins.py | 52 +- tests/fast/relational_api/test_pivot.py | 2 +- .../relational_api/test_rapi_aggregations.py | 4 +- tests/fast/relational_api/test_rapi_close.py | 162 +-- .../relational_api/test_rapi_description.py | 26 +- .../relational_api/test_rapi_functions.py | 4 +- tests/fast/relational_api/test_rapi_query.py | 60 +- .../fast/relational_api/test_rapi_windows.py | 18 +- .../relational_api/test_table_function.py | 6 +- tests/fast/spark/test_replace_column_value.py | 22 +- tests/fast/spark/test_replace_empty_value.py | 32 +- tests/fast/spark/test_spark_catalog.py | 30 +- tests/fast/spark/test_spark_column.py | 20 +- tests/fast/spark/test_spark_dataframe.py | 134 +-- tests/fast/spark/test_spark_dataframe_sort.py | 20 +- .../fast/spark/test_spark_drop_duplicates.py | 34 +- tests/fast/spark/test_spark_except.py | 1 - tests/fast/spark/test_spark_filter.py | 74 +- .../fast/spark/test_spark_functions_array.py | 82 +- .../fast/spark/test_spark_functions_base64.py | 2 +- tests/fast/spark/test_spark_functions_date.py | 46 +- tests/fast/spark/test_spark_functions_hex.py | 4 +- .../test_spark_functions_miscellaneous.py | 30 +- tests/fast/spark/test_spark_functions_null.py | 6 +- .../spark/test_spark_functions_numeric.py | 6 +- .../fast/spark/test_spark_functions_string.py | 164 +-- tests/fast/spark/test_spark_group_by.py | 12 +- tests/fast/spark/test_spark_intersect.py | 2 - tests/fast/spark/test_spark_join.py | 254 ++--- tests/fast/spark/test_spark_order_by.py | 94 +- .../fast/spark/test_spark_pandas_dataframe.py | 12 +- tests/fast/spark/test_spark_readcsv.py | 4 +- 
tests/fast/spark/test_spark_readjson.py | 4 +- tests/fast/spark/test_spark_readparquet.py | 4 +- tests/fast/spark/test_spark_session.py | 16 +- tests/fast/spark/test_spark_to_csv.py | 40 +- tests/fast/spark/test_spark_transform.py | 12 +- tests/fast/spark/test_spark_types.py | 90 +- tests/fast/spark/test_spark_udf.py | 1 - tests/fast/spark/test_spark_union.py | 32 +- tests/fast/spark/test_spark_union_by_name.py | 32 +- tests/fast/spark/test_spark_with_column.py | 26 +- .../spark/test_spark_with_column_renamed.py | 56 +- tests/fast/spark/test_spark_with_columns.py | 22 +- .../spark/test_spark_with_columns_renamed.py | 38 +- tests/fast/sqlite/test_types.py | 20 +- tests/fast/test_alex_multithread.py | 32 +- tests/fast/test_all_types.py | 298 +++--- tests/fast/test_case_alias.py | 12 +- tests/fast/test_context_manager.py | 2 +- tests/fast/test_duckdb_api.py | 2 +- tests/fast/test_expression.py | 212 ++-- tests/fast/test_filesystem.py | 104 +- tests/fast/test_get_table_names.py | 44 +- tests/fast/test_import_export.py | 12 +- tests/fast/test_insert.py | 16 +- tests/fast/test_many_con_same_file.py | 10 +- tests/fast/test_map.py | 102 +- tests/fast/test_metatransaction.py | 4 +- tests/fast/test_multi_statement.py | 20 +- tests/fast/test_multithread.py | 134 +-- tests/fast/test_non_default_conn.py | 28 +- tests/fast/test_parameter_list.py | 8 +- tests/fast/test_parquet.py | 50 +- tests/fast/test_pypi_cleanup.py | 320 ++++-- tests/fast/test_pytorch.py | 16 +- tests/fast/test_relation.py | 282 ++--- tests/fast/test_relation_dependency_leak.py | 14 +- tests/fast/test_replacement_scan.py | 100 +- tests/fast/test_result.py | 34 +- tests/fast/test_runtime_error.py | 44 +- tests/fast/test_sql_expression.py | 28 +- tests/fast/test_string_annotation.py | 12 +- tests/fast/test_tf.py | 16 +- tests/fast/test_transaction.py | 16 +- tests/fast/test_type.py | 170 +-- tests/fast/test_type_explicit.py | 7 +- tests/fast/test_unicode.py | 8 +- tests/fast/test_value.py | 30 +- tests/fast/test_versioning.py | 46 +- tests/fast/test_windows_abs_path.py | 20 +- tests/fast/types/test_blob.py | 4 +- tests/fast/types/test_datetime_datetime.py | 20 +- tests/fast/types/test_decimal.py | 14 +- tests/fast/types/test_hugeint.py | 6 +- tests/fast/types/test_nan.py | 36 +- tests/fast/types/test_nested.py | 12 +- tests/fast/types/test_numpy.py | 8 +- tests/fast/types/test_object_int.py | 54 +- tests/fast/types/test_time_tz.py | 2 +- tests/fast/types/test_unsigned.py | 6 +- tests/fast/udf/test_null_filtering.py | 78 +- tests/fast/udf/test_remove_function.py | 48 +- tests/fast/udf/test_scalar.py | 114 +- tests/fast/udf/test_scalar_arrow.py | 76 +- tests/fast/udf/test_scalar_native.py | 68 +- tests/fast/udf/test_transactionality.py | 6 +- tests/slow/test_h2oai_arrow.py | 50 +- tests/stubs/test_stubs.py | 8 +- 271 files changed, 6750 insertions(+), 6184 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index b5e994fa..bf50be5b 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -7,16 +7,15 @@ # duckdb.__version__ returns the version of the distribution package, i.e. 
the pypi version __version__ = version("duckdb") + # version() is a more human friendly formatted version string of both the distribution package and the bundled duckdb def version(): return f"{__version__} (with duckdb {duckdb_version})" -_exported_symbols = ['__version__', 'version'] -_exported_symbols.extend([ - "typing", - "functional" -]) +_exported_symbols = ["__version__", "version"] + +_exported_symbols.extend(["typing", "functional"]) class DBAPITypeObject: def __init__(self, types: list[typing.DuckDBPyType]) -> None: @@ -69,7 +68,7 @@ def __repr__(self): ExplainType, StatementType, ExpectedResultType, - CSVLineTerminator, + CSVLineTerminator, PythonExceptionHandling, RenderMode, Expression, @@ -81,217 +80,205 @@ def __repr__(self): StarExpression, FunctionExpression, CaseExpression, - SQLExpression + SQLExpression, ) -_exported_symbols.extend([ - "DuckDBPyRelation", - "DuckDBPyConnection", - "ExplainType", - "PythonExceptionHandling", - "Expression", - "ConstantExpression", - "ColumnExpression", - "DefaultExpression", - "CoalesceOperator", - "LambdaExpression", - "StarExpression", - "FunctionExpression", - "CaseExpression", - "SQLExpression" -]) -# These are overloaded twice, we define them inside of C++ so pybind can deal with it -_exported_symbols.extend([ - 'df', - 'arrow' -]) -from _duckdb import ( - df, - arrow +_exported_symbols.extend( + [ + "DuckDBPyRelation", + "DuckDBPyConnection", + "ExplainType", + "PythonExceptionHandling", + "Expression", + "ConstantExpression", + "ColumnExpression", + "DefaultExpression", + "CoalesceOperator", + "LambdaExpression", + "StarExpression", + "FunctionExpression", + "CaseExpression", + "SQLExpression", + ] ) +# These are overloaded twice, we define them inside of C++ so pybind can deal with it +_exported_symbols.extend(["df", "arrow"]) +from _duckdb import df, arrow + # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_methods.py. # Do not edit this section manually, your changes will be overwritten! 
# START OF CONNECTION WRAPPER from _duckdb import ( - cursor, - register_filesystem, - unregister_filesystem, - list_filesystems, - filesystem_is_registered, - create_function, - remove_function, - sqltype, - dtype, - type, - array_type, - list_type, - union_type, - string_type, - enum_type, - decimal_type, - struct_type, - row_type, - map_type, - duplicate, - execute, - executemany, - close, - interrupt, - query_progress, - fetchone, - fetchmany, - fetchall, - fetchnumpy, - fetchdf, - fetch_df, - df, - fetch_df_chunk, - pl, - fetch_arrow_table, - arrow, - fetch_record_batch, - torch, - tf, - begin, - commit, - rollback, - checkpoint, - append, - register, - unregister, - table, - view, - values, - table_function, - read_json, - extract_statements, - sql, - query, - from_query, - read_csv, - from_csv_auto, - from_df, - from_arrow, - from_parquet, - read_parquet, - from_parquet, - read_parquet, - get_table_names, - install_extension, - load_extension, - project, - distinct, - write_csv, - aggregate, - alias, - filter, - limit, - order, - query_df, - description, - rowcount, + cursor, + register_filesystem, + unregister_filesystem, + list_filesystems, + filesystem_is_registered, + create_function, + remove_function, + sqltype, + dtype, + type, + array_type, + list_type, + union_type, + string_type, + enum_type, + decimal_type, + struct_type, + row_type, + map_type, + duplicate, + execute, + executemany, + close, + interrupt, + query_progress, + fetchone, + fetchmany, + fetchall, + fetchnumpy, + fetchdf, + fetch_df, + df, + fetch_df_chunk, + pl, + fetch_arrow_table, + arrow, + fetch_record_batch, + torch, + tf, + begin, + commit, + rollback, + checkpoint, + append, + register, + unregister, + table, + view, + values, + table_function, + read_json, + extract_statements, + sql, + query, + from_query, + read_csv, + from_csv_auto, + from_df, + from_arrow, + from_parquet, + read_parquet, + from_parquet, + read_parquet, + get_table_names, + install_extension, + load_extension, + project, + distinct, + write_csv, + aggregate, + alias, + filter, + limit, + order, + query_df, + description, + rowcount, ) -_exported_symbols.extend([ - 'cursor', - 'register_filesystem', - 'unregister_filesystem', - 'list_filesystems', - 'filesystem_is_registered', - 'create_function', - 'remove_function', - 'sqltype', - 'dtype', - 'type', - 'array_type', - 'list_type', - 'union_type', - 'string_type', - 'enum_type', - 'decimal_type', - 'struct_type', - 'row_type', - 'map_type', - 'duplicate', - 'execute', - 'executemany', - 'close', - 'interrupt', - 'query_progress', - 'fetchone', - 'fetchmany', - 'fetchall', - 'fetchnumpy', - 'fetchdf', - 'fetch_df', - 'df', - 'fetch_df_chunk', - 'pl', - 'fetch_arrow_table', - 'arrow', - 'fetch_record_batch', - 'torch', - 'tf', - 'begin', - 'commit', - 'rollback', - 'checkpoint', - 'append', - 'register', - 'unregister', - 'table', - 'view', - 'values', - 'table_function', - 'read_json', - 'extract_statements', - 'sql', - 'query', - 'from_query', - 'read_csv', - 'from_csv_auto', - 'from_df', - 'from_arrow', - 'from_parquet', - 'read_parquet', - 'from_parquet', - 'read_parquet', - 'get_table_names', - 'install_extension', - 'load_extension', - 'project', - 'distinct', - 'write_csv', - 'aggregate', - 'alias', - 'filter', - 'limit', - 'order', - 'query_df', - 'description', - 'rowcount', -]) +_exported_symbols.extend( + [ + "cursor", + "register_filesystem", + "unregister_filesystem", + "list_filesystems", + "filesystem_is_registered", + "create_function", + "remove_function", + 
"sqltype", + "dtype", + "type", + "array_type", + "list_type", + "union_type", + "string_type", + "enum_type", + "decimal_type", + "struct_type", + "row_type", + "map_type", + "duplicate", + "execute", + "executemany", + "close", + "interrupt", + "query_progress", + "fetchone", + "fetchmany", + "fetchall", + "fetchnumpy", + "fetchdf", + "fetch_df", + "df", + "fetch_df_chunk", + "pl", + "fetch_arrow_table", + "arrow", + "fetch_record_batch", + "torch", + "tf", + "begin", + "commit", + "rollback", + "checkpoint", + "append", + "register", + "unregister", + "table", + "view", + "values", + "table_function", + "read_json", + "extract_statements", + "sql", + "query", + "from_query", + "read_csv", + "from_csv_auto", + "from_df", + "from_arrow", + "from_parquet", + "read_parquet", + "from_parquet", + "read_parquet", + "get_table_names", + "install_extension", + "load_extension", + "project", + "distinct", + "write_csv", + "aggregate", + "alias", + "filter", + "limit", + "order", + "query_df", + "description", + "rowcount", + ] +) # END OF CONNECTION WRAPPER # Enums -from _duckdb import ( - ANALYZE, - DEFAULT, - RETURN_NULL, - STANDARD, - COLUMNS, - ROWS -) -_exported_symbols.extend([ - "ANALYZE", - "DEFAULT", - "RETURN_NULL", - "STANDARD" -]) +from _duckdb import ANALYZE, DEFAULT, RETURN_NULL, STANDARD, COLUMNS, ROWS + +_exported_symbols.extend(["ANALYZE", "DEFAULT", "RETURN_NULL", "STANDARD"]) # read-only properties @@ -310,25 +297,28 @@ def __repr__(self): string_const, threadsafety, token_type, - tokenize + tokenize, +) + +_exported_symbols.extend( + [ + "__standard_vector_size__", + "__interactive__", + "__jupyter__", + "__formatted_python_version__", + "apilevel", + "comment", + "identifier", + "keyword", + "numeric_const", + "operator", + "paramstyle", + "string_const", + "threadsafety", + "token_type", + "tokenize", + ] ) -_exported_symbols.extend([ - "__standard_vector_size__", - "__interactive__", - "__jupyter__", - "__formatted_python_version__", - "apilevel", - "comment", - "identifier", - "keyword", - "numeric_const", - "operator", - "paramstyle", - "string_const", - "threadsafety", - "token_type", - "tokenize" -]) from _duckdb import ( @@ -337,11 +327,13 @@ def __repr__(self): set_default_connection, ) -_exported_symbols.extend([ - "connect", - "default_connection", - "set_default_connection", -]) +_exported_symbols.extend( + [ + "connect", + "default_connection", + "set_default_connection", + ] +) # Exceptions from _duckdb import ( @@ -374,40 +366,43 @@ def __repr__(self): ParserException, SyntaxException, SequenceException, - Warning + Warning, +) + +_exported_symbols.extend( + [ + "Error", + "DataError", + "ConversionException", + "OutOfRangeException", + "TypeMismatchException", + "FatalException", + "IntegrityError", + "ConstraintException", + "InternalError", + "InternalException", + "InterruptException", + "NotSupportedError", + "NotImplementedException", + "OperationalError", + "ConnectionException", + "IOException", + "HTTPException", + "OutOfMemoryException", + "SerializationException", + "TransactionException", + "PermissionException", + "ProgrammingError", + "BinderException", + "CatalogException", + "InvalidInputException", + "InvalidTypeException", + "ParserException", + "SyntaxException", + "SequenceException", + "Warning", + ] ) -_exported_symbols.extend([ - "Error", - "DataError", - "ConversionException", - "OutOfRangeException", - "TypeMismatchException", - "FatalException", - "IntegrityError", - "ConstraintException", - "InternalError", - "InternalException", - 
"InterruptException", - "NotSupportedError", - "NotImplementedException", - "OperationalError", - "ConnectionException", - "IOException", - "HTTPException", - "OutOfMemoryException", - "SerializationException", - "TransactionException", - "PermissionException", - "ProgrammingError", - "BinderException", - "CatalogException", - "InvalidInputException", - "InvalidTypeException", - "ParserException", - "SyntaxException", - "SequenceException", - "Warning" -]) # Value from duckdb.value.constant import ( @@ -441,35 +436,37 @@ def __repr__(self): TimeTimeZoneValue, ) -_exported_symbols.extend([ - "Value", - "NullValue", - "BooleanValue", - "UnsignedBinaryValue", - "UnsignedShortValue", - "UnsignedIntegerValue", - "UnsignedLongValue", - "BinaryValue", - "ShortValue", - "IntegerValue", - "LongValue", - "HugeIntegerValue", - "FloatValue", - "DoubleValue", - "DecimalValue", - "StringValue", - "UUIDValue", - "BitValue", - "BlobValue", - "DateValue", - "IntervalValue", - "TimestampValue", - "TimestampSecondValue", - "TimestampMilisecondValue", - "TimestampNanosecondValue", - "TimestampTimeZoneValue", - "TimeValue", - "TimeTimeZoneValue", -]) +_exported_symbols.extend( + [ + "Value", + "NullValue", + "BooleanValue", + "UnsignedBinaryValue", + "UnsignedShortValue", + "UnsignedIntegerValue", + "UnsignedLongValue", + "BinaryValue", + "ShortValue", + "IntegerValue", + "LongValue", + "HugeIntegerValue", + "FloatValue", + "DoubleValue", + "DecimalValue", + "StringValue", + "UUIDValue", + "BitValue", + "BlobValue", + "DateValue", + "IntervalValue", + "TimestampValue", + "TimestampSecondValue", + "TimestampMilisecondValue", + "TimestampNanosecondValue", + "TimestampTimeZoneValue", + "TimeValue", + "TimeTimeZoneValue", + ] +) __all__ = _exported_symbols diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index 8f27e5e3..0c597d11 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -41,6 +41,7 @@ from duckdb.value.constant import ( # We also run this in python3.7, where this is needed from typing_extensions import Literal + # stubgen override - missing import of Set from typing import Any, ClassVar, Set, Optional, Callable from io import StringIO, TextIOBase @@ -48,11 +49,13 @@ from pathlib import Path from typing import overload, Dict, List, Union, Tuple import pandas + # stubgen override - unfortunately we need this for version checks import sys import fsspec import pyarrow.lib import polars + # stubgen override - This should probably not be exposed apilevel: str comment: token_type @@ -78,15 +81,10 @@ __jupyter__: bool __formatted_python_version__: str class BinderException(ProgrammingError): ... - class CatalogException(ProgrammingError): ... - class ConnectionException(OperationalError): ... - class ConstraintException(IntegrityError): ... - class ConversionException(DataError): ... - class DataError(Error): ... class ExplainType: @@ -204,46 +202,37 @@ class Statement: class Expression: def __init__(self, *args, **kwargs) -> None: ... def __neg__(self) -> "Expression": ... - def __add__(self, expr: "Expression") -> "Expression": ... def __radd__(self, expr: "Expression") -> "Expression": ... - def __sub__(self, expr: "Expression") -> "Expression": ... def __rsub__(self, expr: "Expression") -> "Expression": ... - def __mul__(self, expr: "Expression") -> "Expression": ... def __rmul__(self, expr: "Expression") -> "Expression": ... - def __div__(self, expr: "Expression") -> "Expression": ... def __rdiv__(self, expr: "Expression") -> "Expression": ... 
- def __truediv__(self, expr: "Expression") -> "Expression": ... def __rtruediv__(self, expr: "Expression") -> "Expression": ... - def __floordiv__(self, expr: "Expression") -> "Expression": ... def __rfloordiv__(self, expr: "Expression") -> "Expression": ... - def __mod__(self, expr: "Expression") -> "Expression": ... def __rmod__(self, expr: "Expression") -> "Expression": ... - def __pow__(self, expr: "Expression") -> "Expression": ... def __rpow__(self, expr: "Expression") -> "Expression": ... - def __and__(self, expr: "Expression") -> "Expression": ... def __rand__(self, expr: "Expression") -> "Expression": ... def __or__(self, expr: "Expression") -> "Expression": ... def __ror__(self, expr: "Expression") -> "Expression": ... def __invert__(self) -> "Expression": ... - - def __eq__(# type: ignore[override] - self, expr: "Expression") -> "Expression": ... - def __ne__(# type: ignore[override] - self, expr: "Expression") -> "Expression": ... + def __eq__( # type: ignore[override] + self, expr: "Expression" + ) -> "Expression": ... + def __ne__( # type: ignore[override] + self, expr: "Expression" + ) -> "Expression": ... def __gt__(self, expr: "Expression") -> "Expression": ... def __ge__(self, expr: "Expression") -> "Expression": ... def __lt__(self, expr: "Expression") -> "Expression": ... def __le__(self, expr: "Expression") -> "Expression": ... - def show(self) -> None: ... def __repr__(self) -> str: ... def get_name(self) -> str: ... @@ -291,7 +280,18 @@ class DuckDBPyConnection: def unregister_filesystem(self, name: str) -> None: ... def list_filesystems(self) -> list: ... def filesystem_is_registered(self, name: str) -> bool: ... - def create_function(self, name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False) -> DuckDBPyConnection: ... + def create_function( + self, + name: str, + function: function, + parameters: Optional[List[DuckDBPyType]] = None, + return_type: Optional[DuckDBPyType] = None, + *, + type: Optional[PythonUDFType] = PythonUDFType.NATIVE, + null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, + exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, + side_effects: bool = False, + ) -> DuckDBPyConnection: ... def remove_function(self, name: str) -> DuckDBPyConnection: ... def sqltype(self, type_str: str) -> DuckDBPyType: ... def dtype(self, type_str: str) -> DuckDBPyType: ... @@ -334,21 +334,152 @@ class DuckDBPyConnection: def unregister(self, view_name: str) -> DuckDBPyConnection: ... def table(self, table_name: str) -> DuckDBPyRelation: ... def view(self, view_name: str) -> DuckDBPyRelation: ... - def values(self, *args: Union[List[Any],Expression, Tuple[Expression]]) -> DuckDBPyRelation: ... + def values(self, *args: Union[List[Any], Expression, Tuple[Expression]]) -> DuckDBPyRelation: ... def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ... 
- def read_json(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, compression: Optional[str] = None, maximum_object_size: Optional[int] = None, ignore_errors: Optional[bool] = None, convert_strings_to_integers: Optional[bool] = None, field_appearance_threshold: Optional[float] = None, map_inference_threshold: Optional[int] = None, maximum_sample_files: Optional[int] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... + def read_json( + self, + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + columns: Optional[Dict[str, str]] = None, + sample_size: Optional[int] = None, + maximum_depth: Optional[int] = None, + records: Optional[str] = None, + format: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + compression: Optional[str] = None, + maximum_object_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + convert_strings_to_integers: Optional[bool] = None, + field_appearance_threshold: Optional[float] = None, + map_inference_threshold: Optional[int] = None, + maximum_sample_files: Optional[int] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + ) -> DuckDBPyRelation: ... def extract_statements(self, query: str) -> List[Statement]: ... def sql(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... 
- def read_csv(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... - def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... 
+ def read_csv( + self, + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + ) -> DuckDBPyRelation: ... + def from_csv_auto( + self, + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + ) -> DuckDBPyRelation: ... def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... - def from_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... 
- def read_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... + def from_parquet( + self, + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + ) -> DuckDBPyRelation: ... + def read_parquet( + self, + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + ) -> DuckDBPyRelation: ... def get_table_names(self, query: str, *, qualified: bool = False) -> Set[str]: ... - def install_extension(self, extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None) -> None: ... + def install_extension( + self, + extension: str, + *, + force_install: bool = False, + repository: Optional[str] = None, + repository_url: Optional[str] = None, + version: Optional[str] = None, + ) -> None: ... def load_extension(self, extension: str) -> None: ... # END OF CONNECTION METHODS @@ -359,19 +490,41 @@ class DuckDBPyRelation: def __init__(self, *args, **kwargs) -> None: ... def __contains__(self, name: str) -> bool: ... def aggregate(self, aggr_expr: str, group_expr: str = ...) -> DuckDBPyRelation: ... - def apply(self, function_name: str, function_aggr: str, group_expr: str = ..., function_parameter: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - + def apply( + self, + function_name: str, + function_aggr: str, + group_expr: str = ..., + function_parameter: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... def cume_dist(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def dense_rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def percent_rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def rank_dense(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def row_number(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... - - def lag(self, column: str, window_spec: str, offset: int, default_value: str, ignore_nulls: bool, projected_columns: str = ...) -> DuckDBPyRelation: ... - def lead(self, column: str, window_spec: str, offset: int, default_value: str, ignore_nulls: bool, projected_columns: str = ...) -> DuckDBPyRelation: ... - def nth_value(self, column: str, window_spec: str, offset: int, ignore_nulls: bool = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - + def lag( + self, + column: str, + window_spec: str, + offset: int, + default_value: str, + ignore_nulls: bool, + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def lead( + self, + column: str, + window_spec: str, + offset: int, + default_value: str, + ignore_nulls: bool, + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def nth_value( + self, column: str, window_spec: str, offset: int, ignore_nulls: bool = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... def value_counts(self, column: str, groups: str = ...) 
-> DuckDBPyRelation: ... def geomean(self, column: str, groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... def first(self, column: str, groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... @@ -380,41 +533,119 @@ class DuckDBPyRelation: def last_value(self, column: str, window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... def mode(self, aggregation_columns: str, group_columns: str = ...) -> DuckDBPyRelation: ... def n_tile(self, window_spec: str, num_buckets: int, projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile_cont(self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile_disc(self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... + def quantile_cont( + self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def quantile_disc( + self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... def sum(self, sum_aggr: str, group_expr: str = ...) -> DuckDBPyRelation: ... - - def any_value(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arg_max(self, arg_column: str, value_column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arg_min(self, arg_column: str, value_column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def avg(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_and(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_or(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_xor(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bitstring_agg(self, column: str, min: Optional[int], max: Optional[int], groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bool_and(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bool_or(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def count(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def favg(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def fsum(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def histogram(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def max(self, max_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def min(self, min_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... 
- def mean(self, mean_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def median(self, median_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def product(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile(self, q: str, quantile_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def std(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev_pop(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev_samp(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def string_agg(self, column: str, sep: str = ..., groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var_pop(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var_samp(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def variance(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def list(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - + def any_value( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def arg_max( + self, + arg_column: str, + value_column: str, + groups: str = ..., + window_spec: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def arg_min( + self, + arg_column: str, + value_column: str, + groups: str = ..., + window_spec: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def avg( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bit_and( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bit_or( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bit_xor( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bitstring_agg( + self, + column: str, + min: Optional[int], + max: Optional[int], + groups: str = ..., + window_spec: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def bool_and( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bool_or( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def count( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... 
+ def favg( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def fsum( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def histogram( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def max( + self, max_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def min( + self, min_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def mean( + self, mean_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def median( + self, median_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def product( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def quantile( + self, q: str, quantile_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def std( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def stddev( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def stddev_pop( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def stddev_samp( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def string_agg( + self, column: str, sep: str = ..., groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def var( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def var_pop( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def var_samp( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def variance( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def list( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... def arrow(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: ... def create(self, table_name: str) -> None: ... @@ -424,7 +655,7 @@ class DuckDBPyRelation: def distinct(self) -> DuckDBPyRelation: ... def except_(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def execute(self, *args, **kwargs) -> DuckDBPyRelation: ... - def explain(self, type: Optional[Literal['standard', 'analyze'] | int] = 'standard') -> str: ... + def explain(self, type: Optional[Literal["standard", "analyze"] | int] = "standard") -> str: ... def fetchall(self) -> List[Any]: ... def fetchmany(self, size: int = ...) -> List[Any]: ... def fetchnumpy(self) -> dict: ... @@ -437,7 +668,9 @@ class DuckDBPyRelation: def update(self, set: Dict[str, Expression], condition: Optional[Expression] = None) -> None: ... 
def insert_into(self, table_name: str) -> None: ... def intersect(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... - def join(self, other_rel: DuckDBPyRelation, condition: Union[str, Expression], how: str = ...) -> DuckDBPyRelation: ... + def join( + self, other_rel: DuckDBPyRelation, condition: Union[str, Expression], how: str = ... + ) -> DuckDBPyRelation: ... def cross(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def limit(self, n: int, offset: int = ...) -> DuckDBPyRelation: ... def map(self, map_function: function, schema: Optional[Dict[str, DuckDBPyType]] = None) -> DuckDBPyRelation: ... @@ -448,46 +681,55 @@ class DuckDBPyRelation: def pl(self, rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... def query(self, virtual_table_name: str, sql_query: str) -> DuckDBPyRelation: ... def record_batch(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... - def fetch_record_batch(self, rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... + def fetch_record_batch( + self, rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ... + ) -> pyarrow.lib.RecordBatchReader: ... def select_types(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def select_dtypes(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def set_alias(self, alias: str) -> DuckDBPyRelation: ... - def show(self, max_width: Optional[int] = None, max_rows: Optional[int] = None, max_col_width: Optional[int] = None, null_value: Optional[str] = None, render_mode: Optional[RenderMode] = None) -> None: ... + def show( + self, + max_width: Optional[int] = None, + max_rows: Optional[int] = None, + max_col_width: Optional[int] = None, + null_value: Optional[str] = None, + render_mode: Optional[RenderMode] = None, + ) -> None: ... def sql_query(self) -> str: ... def to_arrow_table(self, batch_size: int = ...) -> pyarrow.lib.Table: ... def to_csv( - self, - file_name: str, - sep: Optional[str] = None, - na_rep: Optional[str] = None, - header: Optional[bool] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - quoting: Optional[str | int] = None, - encoding: Optional[str] = None, - compression: Optional[str] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - partition_by: Optional[List[str]] = None + self, + file_name: str, + sep: Optional[str] = None, + na_rep: Optional[str] = None, + header: Optional[bool] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + quoting: Optional[str | int] = None, + encoding: Optional[str] = None, + compression: Optional[str] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + partition_by: Optional[List[str]] = None, ) -> None: ... def to_df(self, *args, **kwargs) -> pandas.DataFrame: ... 
def to_parquet( - self, - file_name: str, - compression: Optional[str] = None, - field_ids: Optional[dict | str] = None, - row_group_size_bytes: Optional[int | str] = None, - row_group_size: Optional[int] = None, - partition_by: Optional[List[str]] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - append: Optional[bool] = None + self, + file_name: str, + compression: Optional[str] = None, + field_ids: Optional[dict | str] = None, + row_group_size_bytes: Optional[int | str] = None, + row_group_size: Optional[int] = None, + partition_by: Optional[List[str]] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + append: Optional[bool] = None, ) -> None: ... def fetch_df_chunk(self, vectors_per_chunk: int = 1, *, date_as_object: bool = False) -> pandas.DataFrame: ... def to_table(self, table_name: str) -> None: ... @@ -497,37 +739,37 @@ class DuckDBPyRelation: def union(self, union_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def unique(self, unique_aggr: str) -> DuckDBPyRelation: ... def write_csv( - self, - file_name: str, - sep: Optional[str] = None, - na_rep: Optional[str] = None, - header: Optional[bool] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - quoting: Optional[str | int] = None, - encoding: Optional[str] = None, - compression: Optional[str] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - partition_by: Optional[List[str]] = None + self, + file_name: str, + sep: Optional[str] = None, + na_rep: Optional[str] = None, + header: Optional[bool] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + quoting: Optional[str | int] = None, + encoding: Optional[str] = None, + compression: Optional[str] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + partition_by: Optional[List[str]] = None, ) -> None: ... def write_parquet( - self, - file_name: str, - compression: Optional[str] = None, - field_ids: Optional[dict | str] = None, - row_group_size_bytes: Optional[int | str] = None, - row_group_size: Optional[int] = None, - partition_by: Optional[List[str]] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - append: Optional[bool] = None + self, + file_name: str, + compression: Optional[str] = None, + field_ids: Optional[dict | str] = None, + row_group_size_bytes: Optional[int | str] = None, + row_group_size: Optional[int] = None, + partition_by: Optional[List[str]] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + append: Optional[bool] = None, ) -> None: ... def __len__(self) -> int: ... @property @@ -546,7 +788,6 @@ class DuckDBPyRelation: def types(self) -> List[DuckDBPyType]: ... class Error(Exception): ... 
- class FatalException(Error): ... class HTTPException(IOException): @@ -556,51 +797,31 @@ class HTTPException(IOException): headers: Dict[str, str] class IOException(OperationalError): ... - class IntegrityError(Error): ... - class InternalError(Error): ... - class InternalException(InternalError): ... - class InterruptException(Error): ... - class InvalidInputException(ProgrammingError): ... - class InvalidTypeException(ProgrammingError): ... - class NotImplementedException(NotSupportedError): ... - class NotSupportedError(Error): ... - class OperationalError(Error): ... - class OutOfMemoryException(OperationalError): ... - class OutOfRangeException(DataError): ... - class ParserException(ProgrammingError): ... - class PermissionException(Error): ... - class ProgrammingError(Error): ... - class SequenceException(Error): ... - class SerializationException(OperationalError): ... - class SyntaxException(ProgrammingError): ... - class TransactionException(OperationalError): ... - class TypeMismatchException(DataError): ... - class Warning(Exception): ... class token_type: # stubgen override - these make mypy sad - #__doc__: ClassVar[str] = ... # read-only - #__members__: ClassVar[dict] = ... # read-only + # __doc__: ClassVar[str] = ... # read-only + # __members__: ClassVar[dict] = ... # read-only __entries: ClassVar[dict] = ... comment: ClassVar[token_type] = ... identifier: ClassVar[token_type] = ... @@ -640,7 +861,18 @@ def register_filesystem(filesystem: fsspec.AbstractFileSystem, *, connection: Du def unregister_filesystem(name: str, *, connection: DuckDBPyConnection = ...) -> None: ... def list_filesystems(*, connection: DuckDBPyConnection = ...) -> list: ... def filesystem_is_registered(name: str, *, connection: DuckDBPyConnection = ...) -> bool: ... -def create_function(name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... +def create_function( + name: str, + function: function, + parameters: Optional[List[DuckDBPyType]] = None, + return_type: Optional[DuckDBPyType] = None, + *, + type: Optional[PythonUDFType] = PythonUDFType.NATIVE, + null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, + exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, + side_effects: bool = False, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyConnection: ... def remove_function(name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def sqltype(type_str: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def dtype(type_str: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... @@ -649,14 +881,24 @@ def array_type(type: DuckDBPyType, size: int, *, connection: DuckDBPyConnection def list_type(type: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def union_type(members: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def string_type(collation: str = "", *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def enum_type(name: str, type: DuckDBPyType, values: List[Any], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... 
+def enum_type( + name: str, type: DuckDBPyType, values: List[Any], *, connection: DuckDBPyConnection = ... +) -> DuckDBPyType: ... def decimal_type(width: int, scale: int, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def struct_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def row_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... +def struct_type( + fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ... +) -> DuckDBPyType: ... +def row_type( + fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ... +) -> DuckDBPyType: ... def map_type(key: DuckDBPyType, value: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def duplicate(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def execute(query: object, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def executemany(query: object, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... +def execute( + query: object, parameters: object = None, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyConnection: ... +def executemany( + query: object, parameters: object = None, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyConnection: ... def close(*, connection: DuckDBPyConnection = ...) -> None: ... def interrupt(*, connection: DuckDBPyConnection = ...) -> None: ... def query_progress(*, connection: DuckDBPyConnection = ...) -> float: ... @@ -667,10 +909,16 @@ def fetchnumpy(*, connection: DuckDBPyConnection = ...) -> dict: ... def fetchdf(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... def fetch_df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... def df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def fetch_df_chunk(vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def pl(rows_per_batch: int = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... +def fetch_df_chunk( + vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection = ... +) -> pandas.DataFrame: ... +def pl( + rows_per_batch: int = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection = ... +) -> polars.DataFrame: ... def fetch_arrow_table(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... -def fetch_record_batch(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... +def fetch_record_batch( + rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ... +) -> pyarrow.lib.RecordBatchReader: ... def arrow(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... def torch(*, connection: DuckDBPyConnection = ...) -> dict: ... def tf(*, connection: DuckDBPyConnection = ...) -> dict: ... @@ -678,36 +926,212 @@ def begin(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def commit(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def rollback(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... 
def checkpoint(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def append(table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... +def append( + table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection = ... +) -> DuckDBPyConnection: ... def register(view_name: str, python_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def unregister(view_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def table(table_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def view(view_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def values(*args: Union[List[Any],Expression, Tuple[Expression]], connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def table_function(name: str, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_json(path_or_buffer: Union[str, StringIO, TextIOBase], *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, compression: Optional[str] = None, maximum_object_size: Optional[int] = None, ignore_errors: Optional[bool] = None, convert_strings_to_integers: Optional[bool] = None, field_appearance_threshold: Optional[float] = None, map_inference_threshold: Optional[int] = None, maximum_sample_files: Optional[int] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def values( + *args: Union[List[Any], Expression, Tuple[Expression]], connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def table_function( + name: str, parameters: object = None, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def read_json( + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + columns: Optional[Dict[str, str]] = None, + sample_size: Optional[int] = None, + maximum_depth: Optional[int] = None, + records: Optional[str] = None, + format: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + compression: Optional[str] = None, + maximum_object_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + convert_strings_to_integers: Optional[bool] = None, + field_appearance_threshold: Optional[float] = None, + map_inference_threshold: Optional[int] = None, + maximum_sample_files: Optional[int] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def extract_statements(query: str, *, connection: DuckDBPyConnection = ...) -> List[Statement]: ... -def sql(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... 
-def from_query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def sql( + query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def query( + query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def from_query( + query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... 
+def read_csv( + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... +def from_csv_auto( + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def from_arrow(arrow_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... 
-def from_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def from_parquet( + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... +def read_parquet( + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def get_table_names(query: str, *, qualified: bool = False, connection: DuckDBPyConnection = ...) -> Set[str]: ... -def install_extension(extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None, connection: DuckDBPyConnection = ...) -> None: ... +def install_extension( + extension: str, + *, + force_install: bool = False, + repository: Optional[str] = None, + repository_url: Optional[str] = None, + version: Optional[str] = None, + connection: DuckDBPyConnection = ..., +) -> None: ... def load_extension(extension: str, *, connection: DuckDBPyConnection = ...) -> None: ... -def project(df: pandas.DataFrame, *args: str, groups: str = "", connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def project( + df: pandas.DataFrame, *args: str, groups: str = "", connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... def distinct(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def write_csv(df: pandas.DataFrame, filename: str, *, sep: Optional[str] = None, na_rep: Optional[str] = None, header: Optional[bool] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, quoting: Optional[str | int] = None, encoding: Optional[str] = None, compression: Optional[str] = None, overwrite: Optional[bool] = None, per_thread_output: Optional[bool] = None, use_tmp_file: Optional[bool] = None, partition_by: Optional[List[str]] = None, write_partition_columns: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> None: ... -def aggregate(df: pandas.DataFrame, aggr_expr: str | List[Expression], group_expr: str = "", *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... 
+def write_csv( + df: pandas.DataFrame, + filename: str, + *, + sep: Optional[str] = None, + na_rep: Optional[str] = None, + header: Optional[bool] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + quoting: Optional[str | int] = None, + encoding: Optional[str] = None, + compression: Optional[str] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + partition_by: Optional[List[str]] = None, + write_partition_columns: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> None: ... +def aggregate( + df: pandas.DataFrame, + aggr_expr: str | List[Expression], + group_expr: str = "", + *, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def alias(df: pandas.DataFrame, alias: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def filter(df: pandas.DataFrame, filter_expr: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def limit(df: pandas.DataFrame, n: int, offset: int = 0, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def limit( + df: pandas.DataFrame, n: int, offset: int = 0, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... def order(df: pandas.DataFrame, order_expr: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def query_df(df: pandas.DataFrame, virtual_table_name: str, sql_query: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def query_df( + df: pandas.DataFrame, virtual_table_name: str, sql_query: str, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... def description(*, connection: DuckDBPyConnection = ...) -> Optional[List[Any]]: ... def rowcount(*, connection: DuckDBPyConnection = ...) -> int: ... + # END OF CONNECTION WRAPPER diff --git a/duckdb/bytes_io_wrapper.py b/duckdb/bytes_io_wrapper.py index 0957652b..763fd8b7 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -63,4 +63,3 @@ def read(self, n: Union[int, None] = -1) -> bytes: to_return = combined_bytestring[:n] self.overflow = combined_bytestring[n:] return to_return - diff --git a/duckdb/experimental/__init__.py b/duckdb/experimental/__init__.py index 0ab3305b..a88a6170 100644 --- a/duckdb/experimental/__init__.py +++ b/duckdb/experimental/__init__.py @@ -1,2 +1,3 @@ from . import spark + __all__ = spark.__all__ diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index be16be41..d6a02326 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -56,7 +56,7 @@ class _NoValueType: __instance = None - def __new__(cls) -> '_NoValueType': + def __new__(cls) -> "_NoValueType": # ensure that only one instance exists if not cls.__instance: cls.__instance = super(_NoValueType, cls).__new__(cls) diff --git a/duckdb/experimental/spark/_typing.py b/duckdb/experimental/spark/_typing.py index 0c06fed5..251ef695 100644 --- a/duckdb/experimental/spark/_typing.py +++ b/duckdb/experimental/spark/_typing.py @@ -30,17 +30,14 @@ class SupportsIAdd(Protocol): - def __iadd__(self, other: "SupportsIAdd") -> "SupportsIAdd": - ... + def __iadd__(self, other: "SupportsIAdd") -> "SupportsIAdd": ... class SupportsOrdering(Protocol): - def __lt__(self, other: "SupportsOrdering") -> bool: - ... + def __lt__(self, other: "SupportsOrdering") -> bool: ... 
-class SizedIterable(Protocol, Sized, Iterable[T_co]): - ... +class SizedIterable(Protocol, Sized, Iterable[T_co]): ... S = TypeVar("S", bound=SupportsOrdering) diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index 95227add..dd4b016c 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -8,7 +8,7 @@ class SparkContext: def __init__(self, master: str) -> None: - self._connection = duckdb.connect(':memory:') + self._connection = duckdb.connect(":memory:") # This aligns the null ordering with Spark. self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'") diff --git a/duckdb/experimental/spark/errors/__init__.py b/duckdb/experimental/spark/errors/__init__.py index 5f2af443..6aac49d7 100644 --- a/duckdb/experimental/spark/errors/__init__.py +++ b/duckdb/experimental/spark/errors/__init__.py @@ -18,6 +18,7 @@ """ PySpark exceptions. """ + from .exceptions.base import ( # noqa: F401 PySparkException, AnalysisException, diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index fcdce827..48a3ea95 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -2,6 +2,7 @@ from ..utils import ErrorClassesReader + class PySparkException(Exception): """ Base Exception for handling errors generated from PySpark. @@ -79,6 +80,7 @@ def __str__(self) -> str: else: return self.message + class AnalysisException(PySparkException): """ Failed to analyze a SQL query plan. diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index 3ef418bd..f1b37f75 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -37,8 +37,7 @@ def get_error_message(self, error_class: str, message_parameters: dict[str, str] # Verify message parameters. message_parameters_from_template = re.findall("<([a-zA-Z0-9_-]+)>", message_template) assert set(message_parameters_from_template) == set(message_parameters), ( - f"Undefined error message parameter for error class: {error_class}. " - f"Parameters: {message_parameters}" + f"Undefined error message parameter for error class: {error_class}. Parameters: {message_parameters}" ) table = str.maketrans("<>", "{}") diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 21668cf5..60495d88 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -8,7 +8,7 @@ class ContributionsAcceptedError(NotImplementedError): def __init__(self, message=None) -> None: doc = self.__class__.__doc__ if message: - doc = message + '\n' + doc + doc = message + "\n" + doc super().__init__(doc) diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index 645b60bb..b5a8b079 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -25,6 +25,7 @@ TypeVar, Union, ) + try: from typing import Literal, Protocol except ImportError: @@ -63,18 +64,15 @@ class SupportsOpen(Protocol): - def open(self, partition_id: int, epoch_id: int) -> bool: - ... + def open(self, partition_id: int, epoch_id: int) -> bool: ... class SupportsProcess(Protocol): - def process(self, row: types.Row) -> None: - ... + def process(self, row: types.Row) -> None: ... class SupportsClose(Protocol): - def close(self, error: Exception) -> None: - ... 
+ def close(self, error: Exception) -> None: ... class UserDefinedFunctionLike(Protocol): @@ -83,11 +81,8 @@ class UserDefinedFunctionLike(Protocol): deterministic: bool @property - def returnType(self) -> types.DataType: - ... + def returnType(self) -> types.DataType: ... - def __call__(self, *args: ColumnOrName) -> Column: - ... + def __call__(self, *args: ColumnOrName) -> Column: ... - def asNondeterministic(self) -> "UserDefinedFunctionLike": - ... + def asNondeterministic(self) -> "UserDefinedFunctionLike": ... diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index 0cd790f7..3cc96f45 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -37,19 +37,19 @@ def __init__(self, session: SparkSession) -> None: self._session = session def listDatabases(self) -> list[Database]: - res = self._session.conn.sql('select database_name from duckdb_databases()').fetchall() + res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall() def transform_to_database(x) -> Database: - return Database(name=x[0], description=None, locationUri='') + return Database(name=x[0], description=None, locationUri="") databases = [transform_to_database(x) for x in res] return databases def listTables(self) -> list[Table]: - res = self._session.conn.sql('select table_name, database_name, sql, temporary from duckdb_tables()').fetchall() + res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall() def transform_to_table(x) -> Table: - return Table(name=x[0], database=x[1], description=x[2], tableType='', isTemporary=x[3]) + return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3]) tables = [transform_to_table(x) for x in res] return tables diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 0dd86178..f78b31ae 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -99,7 +99,7 @@ def __init__(self, expr: Expression) -> None: self.expr = expr # arithmetic operators - def __neg__(self) -> 'Column': + def __neg__(self) -> "Column": return Column(-self.expr) # `and`, `or`, `not` cannot be overloaded in Python, @@ -161,8 +161,8 @@ def __getitem__(self, k: Any) -> "Column": Examples -------- - >>> df = spark.createDataFrame([('abcedfg', {"key": "value"})], ["l", "d"]) - >>> df.select(df.l[slice(1, 3)], df.d['key']).show() + >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) + >>> df.select(df.l[slice(1, 3)], df.d["key"]).show() +------------------+------+ |substring(l, 1, 3)|d[key]| +------------------+------+ @@ -196,7 +196,7 @@ def __getattr__(self, item: Any) -> "Column": Examples -------- - >>> df = spark.createDataFrame([('abcedfg', {"key": "value"})], ["l", "d"]) + >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) >>> df.select(df.d.key).show() +------+ |d[key]| @@ -347,7 +347,6 @@ def __ne__( # type: ignore[override] nulls_first = _unary_op("nulls_first") nulls_last = _unary_op("nulls_last") - def asc_nulls_first(self) -> "Column": return self.asc().nulls_first() @@ -365,4 +364,3 @@ def isNull(self) -> "Column": def isNotNull(self) -> "Column": return Column(self.expr.isnotnull()) - diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 42a5b8f0..19f5576b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ 
b/duckdb/experimental/spark/sql/dataframe.py @@ -170,7 +170,7 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": Examples -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - >>> df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}).show() + >>> df.withColumns({"age2": df.age + 2, "age3": df.age + 3}).show() +---+-----+----+----+ |age| name|age2|age3| +---+-----+----+----+ @@ -248,8 +248,8 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": Examples -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - >>> df = df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}) - >>> df.withColumnsRenamed({'age2': 'age4', 'age3': 'age5'}).show() + >>> df = df.withColumns({"age2": df.age + 2, "age3": df.age + 3}) + >>> df.withColumnsRenamed({"age2": "age4", "age3": "age5"}).show() +---+-----+----+----+ |age| name|age4|age5| +---+-----+----+----+ @@ -265,9 +265,7 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": unknown_columns = set(colsMap.keys()) - set(self.relation.columns) if unknown_columns: - raise ValueError( - f"DataFrame does not contain column(s): {', '.join(unknown_columns)}" - ) + raise ValueError(f"DataFrame does not contain column(s): {', '.join(unknown_columns)}") # Compute this only once old_column_names = list(colsMap.keys()) @@ -289,11 +287,7 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - - - def transform( - self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any - ) -> "DataFrame": + def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame": """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations. .. versionadded:: 3.0.0 @@ -325,10 +319,8 @@ def transform( >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"]) >>> def cast_all_to_int(input_df): ... return input_df.select([col(col_name).cast("int") for col_name in input_df.columns]) - ... >>> def sort_columns_asc(input_df): ... return input_df.select(*sorted(input_df.columns)) - ... >>> df.transform(cast_all_to_int).transform(sort_columns_asc).show() +-----+---+ |float|int| @@ -338,8 +330,9 @@ def transform( +-----+---+ >>> def add_n(input_df, n): - ... return input_df.select([(col(col_name) + n).alias(col_name) - ... for col_name in input_df.columns]) + ... return input_df.select( + ... [(col(col_name) + n).alias(col_name) for col_name in input_df.columns] + ... ) >>> df.transform(add_n, 1).transform(add_n, n=10).show() +---+-----+ |int|float| @@ -350,14 +343,11 @@ def transform( """ result = func(self, *args, **kwargs) assert isinstance(result, DataFrame), ( - "Func returned an instance of type [%s], " - "should have been DataFrame." % type(result) + "Func returned an instance of type [%s], should have been DataFrame." % type(result) ) return result - def sort( - self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any - ) -> "DataFrame": + def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any) -> "DataFrame": """Returns a new :class:`DataFrame` sorted by the specified column(s). Parameters @@ -380,8 +370,7 @@ def sort( Examples -------- >>> from pyspark.sql.functions import desc, asc - >>> df = spark.createDataFrame([ - ... 
(2, "Alice"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) Sort the DataFrame in ascending order. @@ -419,8 +408,9 @@ def sort( Specify multiple columns - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] + ... ) >>> df.orderBy(desc("age"), "name").show() +---+-----+ |age| name| @@ -516,8 +506,7 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) Filter by :class:`Column` instances. @@ -568,13 +557,9 @@ def select(self, *cols) -> "DataFrame": if len(cols) == 1: cols = cols[0] if isinstance(cols, list): - projections = [ - x.expr if isinstance(x, Column) else ColumnExpression(x) for x in cols - ] + projections = [x.expr if isinstance(x, Column) else ColumnExpression(x) for x in cols] else: - projections = [ - cols.expr if isinstance(cols, Column) else ColumnExpression(cols) - ] + projections = [cols.expr if isinstance(cols, Column) else ColumnExpression(cols)] rel = self.relation.select(*projections) return DataFrame(rel, self.session) @@ -636,22 +621,24 @@ def join( >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")]).toDF("age", "name") >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) >>> df3 = spark.createDataFrame([Row(age=2, name="Alice"), Row(age=5, name="Bob")]) - >>> df4 = spark.createDataFrame([ - ... Row(age=10, height=80, name="Alice"), - ... Row(age=5, height=None, name="Bob"), - ... Row(age=None, height=None, name="Tom"), - ... Row(age=None, height=None, name=None), - ... ]) + >>> df4 = spark.createDataFrame( + ... [ + ... Row(age=10, height=80, name="Alice"), + ... Row(age=5, height=None, name="Bob"), + ... Row(age=None, height=None, name="Tom"), + ... Row(age=None, height=None, name=None), + ... ] + ... ) Inner join on columns (default) - >>> df.join(df2, 'name').select(df.name, df2.height).show() + >>> df.join(df2, "name").select(df.name, df2.height).show() +----+------+ |name|height| +----+------+ | Bob| 85| +----+------+ - >>> df.join(df4, ['name', 'age']).select(df.name, df.age).show() + >>> df.join(df4, ["name", "age"]).select(df.name, df.age).show() +----+---+ |name|age| +----+---+ @@ -660,8 +647,9 @@ def join( Outer join for both DataFrames on the 'name' column. - >>> df.join(df2, df.name == df2.name, 'outer').select( - ... df.name, df2.height).sort(desc("name")).show() + >>> df.join(df2, df.name == df2.name, "outer").select(df.name, df2.height).sort( + ... desc("name") + ... ).show() +-----+------+ | name|height| +-----+------+ @@ -669,7 +657,7 @@ def join( |Alice| NULL| | NULL| 80| +-----+------+ - >>> df.join(df2, 'name', 'outer').select('name', 'height').sort(desc("name")).show() + >>> df.join(df2, "name", "outer").select("name", "height").sort(desc("name")).show() +-----+------+ | name|height| +-----+------+ @@ -680,11 +668,9 @@ def join( Outer join for both DataFrams with multiple columns. - >>> df.join( - ... df3, - ... [df.name == df3.name, df.age == df3.age], - ... 'outer' - ... ).select(df.name, df3.age).show() + >>> df.join(df3, [df.name == df3.name, df.age == df3.age], "outer").select( + ... df.name, df3.age + ... 
).show() +-----+---+ | name|age| +-----+---+ @@ -701,12 +687,9 @@ def join( on = [_to_column_expr(x) for x in on] # & all the Expressions together to form one Expression - assert isinstance( - on[0], Expression - ), "on should be Column or list of Column" + assert isinstance(on[0], Expression), "on should be Column or list of Column" on = reduce(lambda x, y: x.__and__(y), cast(list[Expression], on)) - if on is None and how is None: result = self.relation.join(other.relation) else: @@ -765,10 +748,8 @@ def crossJoin(self, other: "DataFrame") -> "DataFrame": Examples -------- >>> from pyspark.sql import Row - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) - >>> df2 = spark.createDataFrame( - ... [Row(height=80, name="Tom"), Row(height=85, name="Bob")]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) >>> df.crossJoin(df2.select("height")).select("age", "name", "height").show() +---+-----+------+ |age| name|height| @@ -799,13 +780,13 @@ def alias(self, alias: str) -> "DataFrame": Examples -------- >>> from pyspark.sql.functions import col, desc - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") - >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') - >>> joined_df.select( - ... "df_as1.name", "df_as2.name", "df_as2.age").sort(desc("df_as1.name")).show() + >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), "inner") + >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").sort( + ... desc("df_as1.name") + ... ).show() +-----+-----+---+ | name| name|age| +-----+-----+---+ @@ -853,8 +834,7 @@ def limit(self, num: int) -> "DataFrame": Examples -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df.limit(1).show() +---+----+ |age|name| @@ -889,25 +869,21 @@ def schema(self) -> StructType: return self._schema @overload - def __getitem__(self, item: Union[int, str]) -> Column: - ... + def __getitem__(self, item: Union[int, str]) -> Column: ... @overload - def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": - ... + def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... - def __getitem__( - self, item: Union[int, str, Column, list, tuple] - ) -> Union[Column, "DataFrame"]: + def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. 
Examples -------- - >>> df.select(df['age']).collect() + >>> df.select(df["age"]).collect() [Row(age=2), Row(age=5)] - >>> df[ ["name", "age"]].collect() + >>> df[["name", "age"]].collect() [Row(name='Alice', age=2), Row(name='Bob', age=5)] - >>> df[ df.age > 3 ].collect() + >>> df[df.age > 3].collect() [Row(age=5, name='Bob')] >>> df[df[0] > 3].collect() [Row(age=5, name='Bob')] @@ -932,18 +908,14 @@ def __getattr__(self, name: str) -> Column: [Row(age=2), Row(age=5)] """ if name not in self.relation.columns: - raise AttributeError( - "'%s' object has no attribute '%s'" % (self.__class__.__name__, name) - ) + raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) return Column(duckdb.ColumnExpression(self.relation.alias, name)) @overload - def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": - ... + def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload - def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": - ... + def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": ... def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] """Groups the :class:`DataFrame` using the specified columns, @@ -966,8 +938,9 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] + ... ) Empty grouping columns triggers a global aggregation. @@ -1073,9 +1046,7 @@ def union(self, other: "DataFrame") -> "DataFrame": unionAll = union - def unionByName( - self, other: "DataFrame", allowMissingColumns: bool = False - ) -> "DataFrame": + def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame": """Returns a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`. @@ -1244,7 +1215,8 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": Examples -------- >>> df1 = spark.createDataFrame( - ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"] + ... ) >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) >>> df1.exceptAll(df2).show() +---+---+ @@ -1284,11 +1256,13 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": Examples -------- >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([ - ... Row(name='Alice', age=5, height=80), - ... Row(name='Alice', age=5, height=80), - ... Row(name='Alice', age=10, height=80) - ... ]) + >>> df = spark.createDataFrame( + ... [ + ... Row(name="Alice", age=5, height=80), + ... Row(name="Alice", age=5, height=80), + ... Row(name="Alice", age=10, height=80), + ... ] + ... ) Deduplicate the same rows. @@ -1302,7 +1276,7 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": Deduplicate values on 'name' and 'height' columns. 
- >>> df.dropDuplicates(['name', 'height']).show() + >>> df.dropDuplicates(["name", "height"]).show() +-----+---+------+ | name|age|height| +-----+---+------+ @@ -1311,7 +1285,7 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": """ if subset: rn_col = f"tmp_col_{uuid.uuid1().hex}" - subset_str = ', '.join([f'"{c}"' for c in subset]) + subset_str = ", ".join([f'"{c}"' for c in subset]) window_spec = f"OVER(PARTITION BY {subset_str}) AS {rn_col}" df = DataFrame(self.relation.row_number(window_spec, "*"), self.session) return df.filter(f"{rn_col} = 1").drop(rn_col) @@ -1320,7 +1294,6 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": drop_duplicates = dropDuplicates - def distinct(self) -> "DataFrame": """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -1331,8 +1304,7 @@ def distinct(self) -> "DataFrame": Examples -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) Return the number of distinct rows in the :class:`DataFrame` @@ -1352,8 +1324,7 @@ def count(self) -> int: Examples -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) Return the number of rows in the :class:`DataFrame`. @@ -1369,8 +1340,7 @@ def _cast_types(self, *types) -> "DataFrame": assert types_count == len(existing_columns) cast_expressions = [ - f"{existing}::{target_type} as {existing}" - for existing, target_type in zip(existing_columns, types) + f"{existing}::{target_type} as {existing}" for existing, target_type in zip(existing_columns, types) ] cast_expressions = ", ".join(cast_expressions) new_rel = self.relation.project(cast_expressions) @@ -1380,14 +1350,10 @@ def toDF(self, *cols) -> "DataFrame": existing_columns = self.relation.columns column_count = len(cols) if column_count != len(existing_columns): - raise PySparkValueError( - message="Provided column names and number of columns in the DataFrame don't match" - ) + raise PySparkValueError(message="Provided column names and number of columns in the DataFrame don't match") existing_columns = [ColumnExpression(x) for x in existing_columns] - projections = [ - existing.alias(new) for existing, new in zip(existing_columns, cols) - ] + projections = [existing.alias(new) for existing, new in zip(existing_columns, cols)] new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 78b14de7..dfcf7e2e 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -11,6 +11,7 @@ LambdaExpression, SQLExpression, ) + if TYPE_CHECKING: from .dataframe import DataFrame @@ -105,14 +106,10 @@ def _inner_expr_or_val(val): def struct(*cols: Column) -> Column: - return Column( - FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols]) - ) + return Column(FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols])) -def array( - *cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]] -) -> Column: +def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]]) -> Column: """Creates a new array column. .. 
versionadded:: 1.4.0 @@ -134,11 +131,11 @@ def array( Examples -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) - >>> df.select(array('age', 'age').alias("arr")).collect() + >>> df.select(array("age", "age").alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] >>> df.select(array([df.age, df.age]).alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] - >>> df.select(array('age', 'age').alias("col")).printSchema() + >>> df.select(array("age", "age").alias("col")).printSchema() root |-- col: array (nullable = false) | |-- element: long (containsNull = true) @@ -167,6 +164,7 @@ def _to_column_expr(col: ColumnOrName) -> Expression: message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, ) + def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: r"""Replace all substrings of the specified string value that match regexp with rep. @@ -174,8 +172,8 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum Examples -------- - >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() + >>> df = spark.createDataFrame([("100-200",)], ["str"]) + >>> df.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() [Row(d='-----')] """ return _invoke_function( @@ -187,9 +185,7 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum ) -def slice( - x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] -) -> Column: +def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]) -> Column: """ Collection function: returns an array containing all the elements in `x` from index `start` (array indices start at 1, or from the end if `start` is negative) with the specified `length`. @@ -215,7 +211,7 @@ def slice( Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() [Row(sliced=[2, 3]), Row(sliced=[5])] """ @@ -301,9 +297,7 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(1, "Bob"), - ... (0, None), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(1, "Bob"), (0, None), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_first(df1.name)).show() +---+-----+ |age| name| @@ -339,9 +333,7 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_last(df1.name)).show() +---+-----+ |age| name| @@ -414,9 +406,7 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_first(df1.name)).show() +---+-----+ |age| name| @@ -452,9 +442,7 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... 
(2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_last(df1.name)).show() +---+-----+ |age| name| @@ -484,16 +472,22 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", 3,)], ['a', 'b']) - >>> df.select(left(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... 3, + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(left(df.a, df.b).alias("r")).collect() [Row(r='Spa')] """ len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( - FunctionExpression( - "array_slice", _to_column_expr(str), ConstantExpression(0), len - ) + FunctionExpression("array_slice", _to_column_expr(str), ConstantExpression(0), len) ) ) @@ -514,23 +508,27 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", 3,)], ['a', 'b']) - >>> df.select(right(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... 3, + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(right(df.a, df.b).alias("r")).collect() [Row(r='SQL')] """ len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( - FunctionExpression( - "array_slice", _to_column_expr(str), -len, ConstantExpression(-1) - ) + FunctionExpression("array_slice", _to_column_expr(str), -len, ConstantExpression(-1)) ) ) -def levenshtein( - left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None -) -> Column: +def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None) -> Column: """Computes the Levenshtein distance of the two given strings. .. versionadded:: 1.5.0 @@ -558,10 +556,18 @@ def levenshtein( Examples -------- - >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) - >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + >>> df0 = spark.createDataFrame( + ... [ + ... ( + ... "kitten", + ... "sitting", + ... ) + ... ], + ... ["l", "r"], + ... ) + >>> df0.select(levenshtein("l", "r").alias("d")).collect() [Row(d=3)] - >>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect() + >>> df0.select(levenshtein("l", "r", 2).alias("d")).collect() [Row(d=-1)] """ distance = _invoke_function_over_columns("levenshtein", left, right) @@ -569,7 +575,9 @@ def levenshtein( return distance else: distance = _to_column_expr(distance) - return Column(CaseExpression(distance <= ConstantExpression(threshold), distance).otherwise(ConstantExpression(-1))) + return Column( + CaseExpression(distance <= ConstantExpression(threshold), distance).otherwise(ConstantExpression(-1)) + ) def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: @@ -597,8 +605,13 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... 
) + >>> df.select(lpad(df.s, 6, "#").alias("s")).collect() [Row(s='##abcd')] """ return _invoke_function("lpad", _to_column_expr(col), ConstantExpression(len), ConstantExpression(pad)) @@ -629,8 +642,13 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(rpad(df.s, 6, "#").alias("s")).collect() [Row(s='abcd##')] """ return _invoke_function("rpad", _to_column_expr(col), ConstantExpression(len), ConstantExpression(pad)) @@ -702,12 +720,14 @@ def asin(col: "ColumnOrName") -> Column: """ col = _to_column_expr(col) # FIXME: ConstantExpression(float("nan")) gives NULL and not NaN - return Column(CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise(FunctionExpression("asin", col))) + return Column( + CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise( + FunctionExpression("asin", col) + ) + ) -def like( - str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None -) -> Column: +def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: """ Returns true if str matches `pattern` with `escape`, null if any arguments are null, false otherwise. @@ -728,15 +748,14 @@ def like( Examples -------- - >>> df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - >>> df.select(like(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + >>> df.select(like(df.a, df.b).alias("r")).collect() [Row(r=True)] >>> df = spark.createDataFrame( - ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], - ... ['a', 'b'] + ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"] ... ) - >>> df.select(like(df.a, df.b, lit('/')).alias('r')).collect() + >>> df.select(like(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] """ if escapeChar is None: @@ -746,9 +765,7 @@ def like( return _invoke_function("like_escape", _to_column_expr(str), _to_column_expr(pattern), escapeChar) -def ilike( - str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None -) -> Column: +def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: """ Returns true if str matches `pattern` with `escape` case-insensitively, null if any arguments are null, false otherwise. @@ -769,15 +786,14 @@ def ilike( Examples -------- - >>> df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - >>> df.select(ilike(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + >>> df.select(ilike(df.a, df.b).alias("r")).collect() [Row(r=True)] >>> df = spark.createDataFrame( - ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], - ... ['a', 'b'] + ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"] ... 
) - >>> df.select(ilike(df.a, df.b, lit('/')).alias('r')).collect() + >>> df.select(ilike(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] """ if escapeChar is None: @@ -805,8 +821,8 @@ def array_agg(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) - >>> df.agg(array_agg('c').alias('r')).collect() + >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) + >>> df.agg(array_agg("c").alias("r")).collect() [Row(r=[1, 1, 2])] """ return _invoke_function_over_columns("list", col) @@ -838,8 +854,8 @@ def collect_list(col: "ColumnOrName") -> Column: Examples -------- - >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) - >>> df2.agg(collect_list('age')).collect() + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ("age",)) + >>> df2.agg(collect_list("age")).collect() [Row(collect_list(age)=[2, 5, 5])] """ return array_agg(col) @@ -874,15 +890,13 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")]) >>> df.select(array_append(df.c1, df.c2)).collect() [Row(array_append(c1, c2)=['b', 'a', 'c', 'c'])] - >>> df.select(array_append(df.c1, 'x')).collect() + >>> df.select(array_append(df.c1, "x")).collect() [Row(array_append(c1, x)=['b', 'a', 'c', 'x'])] """ return _invoke_function("list_append", _to_column_expr(col), _get_expr(value)) -def array_insert( - arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any -) -> Column: +def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column: """ Collection function: adds an item into a given array at a specified array index. Array indices start at 1, or start from the end if index is negative. @@ -913,12 +927,11 @@ def array_insert( Examples -------- >>> df = spark.createDataFrame( - ... [(['a', 'b', 'c'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')], - ... ['data', 'pos', 'val'] + ... [(["a", "b", "c"], 2, "d"), (["c", "b", "a"], -2, "d")], ["data", "pos", "val"] ... 
) - >>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() + >>> df.select(array_insert(df.data, df.pos.cast("integer"), df.val).alias("data")).collect() [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'b', 'd', 'a'])] - >>> df.select(array_insert(df.data, 5, 'hello').alias('data')).collect() + >>> df.select(array_insert(df.data, 5, "hello").alias("data")).collect() [Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])] """ pos = _get_expr(pos) @@ -944,9 +957,7 @@ def array_insert( FunctionExpression( "list_resize", FunctionExpression("list_value", None), - FunctionExpression( - "subtract", FunctionExpression("abs", pos), list_length_plus_1 - ), + FunctionExpression("subtract", FunctionExpression("abs", pos), list_length_plus_1), ), arr, ), @@ -964,9 +975,7 @@ def array_insert( "list_slice", list_, 1, - CaseExpression( - pos_is_positive, FunctionExpression("subtract", pos, 1) - ).otherwise(pos), + CaseExpression(pos_is_positive, FunctionExpression("subtract", pos, 1)).otherwise(pos), ), # Here we insert the value at the specified position FunctionExpression("list_value", _get_expr(value)), @@ -975,9 +984,7 @@ def array_insert( FunctionExpression( "list_slice", list_, - CaseExpression(pos_is_positive, pos).otherwise( - FunctionExpression("add", pos, 1) - ), + CaseExpression(pos_is_positive, pos).otherwise(FunctionExpression("add", pos, 1)), -1, ), ) @@ -1002,7 +1009,7 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ["data"]) >>> df.select(array_contains(df.data, "a")).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] >>> df.select(array_contains(df.data, lit("a"))).collect() @@ -1033,7 +1040,7 @@ def array_distinct(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ["data"]) >>> df.select(array_distinct(df.data)).collect() [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] """ @@ -1125,11 +1132,13 @@ def array_max(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_max(df.data).alias('max')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) + >>> df.select(array_max(df.data).alias("max")).collect() [Row(max=3), Row(max=10)] """ - return _invoke_function("array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(-1)) + return _invoke_function( + "array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(-1) + ) def array_min(col: "ColumnOrName") -> Column: @@ -1153,11 +1162,13 @@ def array_min(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_min(df.data).alias('min')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) + >>> df.select(array_min(df.data).alias("min")).collect() [Row(min=1), Row(min=-1)] """ - return _invoke_function("array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(1)) + return _invoke_function( + "array_extract", 
_to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(1) + ) def avg(col: "ColumnOrName") -> Column: @@ -1311,11 +1322,17 @@ def median(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([ - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("Java", 2012, 22000), ("dotNET", 2012, 10000), - ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], - ... schema=("course", "year", "earnings")) + >>> df = spark.createDataFrame( + ... [ + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("Java", 2012, 22000), + ... ("dotNET", 2012, 10000), + ... ("dotNET", 2013, 48000), + ... ("Java", 2013, 30000), + ... ], + ... schema=("course", "year", "earnings"), + ... ) >>> df.groupby("course").agg(median("earnings")).show() +------+----------------+ |course|median(earnings)| @@ -1349,11 +1366,17 @@ def mode(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([ - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], - ... schema=("course", "year", "earnings")) + >>> df = spark.createDataFrame( + ... [ + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("dotNET", 2013, 48000), + ... ("Java", 2013, 30000), + ... ], + ... schema=("course", "year", "earnings"), + ... ) >>> df.groupby("course").agg(mode("year")).show() +------+----------+ |course|mode(year)| @@ -1416,14 +1439,12 @@ def any_value(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(None, 1), - ... ("a", 2), - ... ("a", 3), - ... ("b", 8), - ... ("b", 2)], ["c1", "c2"]) - >>> df.select(any_value('c1'), any_value('c2')).collect() + >>> df = spark.createDataFrame( + ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] + ... ) + >>> df.select(any_value("c1"), any_value("c2")).collect() [Row(any_value(c1)=None, any_value(c2)=1)] - >>> df.select(any_value('c1', True), any_value('c2', True)).collect() + >>> df.select(any_value("c1", True), any_value("c2", True)).collect() [Row(any_value(c1)='a', any_value(c2)=1)] """ return _invoke_function_over_columns("any_value", col) @@ -1486,8 +1507,8 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C Examples -------- - >>> df = spark.createDataFrame([1,2,2,3], "INT") - >>> df.agg(approx_count_distinct("value").alias('distinct_values')).show() + >>> df = spark.createDataFrame([1, 2, 2, 3], "INT") + >>> df.agg(approx_count_distinct("value").alias("distinct_values")).show() +---------------+ |distinct_values| +---------------+ @@ -1567,7 +1588,6 @@ def transform( >>> def alternate(x, i): ... return when(i % 2 == 0, x).otherwise(-x) - ... 
>>> df.select(transform("values", alternate).alias("alternated")).show() +--------------+ | alternated| @@ -1602,8 +1622,8 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": Examples -------- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() + >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) + >>> df.select(concat_ws("-", df.s, df.d).alias("s")).collect() [Row(s='abcd-123')] """ cols = [_to_column_expr(expr) for expr in cols] @@ -1788,7 +1808,7 @@ def isnan(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df = spark.createDataFrame([(1.0, float("nan")), (float("nan"), 2.0)], ("a", "b")) >>> df.select("a", "b", isnan("a").alias("r1"), isnan(df.b).alias("r2")).show() +---+---+-----+-----+ | a| b| r1| r2| @@ -1845,7 +1865,7 @@ def isnotnull(col: "ColumnOrName") -> Column: Examples -------- >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) - >>> df.select(isnotnull(df.e).alias('r')).collect() + >>> df.select(isnotnull(df.e).alias("r")).collect() [Row(r=False), Row(r=True)] """ return Column(_to_column_expr(col).isnotnull()) @@ -1862,8 +1882,20 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str Examples -------- - >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"]) - >>> df.select(equal_null(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... None, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(equal_null(df.a, df.b).alias("r")).collect() [Row(r=True), Row(r=False)] """ if isinstance(col1, str): @@ -1872,7 +1904,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: if isinstance(col2, str): col2 = col(col2) - return nvl((col1 == col2) | ((col1.isNull() & col2.isNull())), lit(False)) + return nvl((col1 == col2) | (col1.isNull() & col2.isNull()), lit(False)) def flatten(col: "ColumnOrName") -> Column: @@ -1898,7 +1930,7 @@ def flatten(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) >>> df.show(truncate=False) +------------------------+ |data | @@ -1906,7 +1938,7 @@ def flatten(col: "ColumnOrName") -> Column: |[[1, 2, 3], [4, 5], [6]]| |[NULL, [4, 5]] | +------------------------+ - >>> df.select(flatten(df.data).alias('r')).show() + >>> df.select(flatten(df.data).alias("r")).show() +------------------+ | r| +------------------+ @@ -1916,11 +1948,7 @@ def flatten(col: "ColumnOrName") -> Column: """ col = _to_column_expr(col) contains_null = _list_contains_null(col) - return Column( - CaseExpression(contains_null, None).otherwise( - FunctionExpression("flatten", col) - ) - ) + return Column(CaseExpression(contains_null, None).otherwise(FunctionExpression("flatten", col))) def array_compact(col: "ColumnOrName") -> Column: @@ -1945,7 +1973,7 @@ def array_compact(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) + >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) >>> df.select(array_compact(df.data)).collect() [Row(array_compact(data)=[1, 2, 3]), Row(array_compact(data)=[4, 5, 4])] """ @@ -1977,11 
+2005,13 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) >>> df.select(array_remove(df.data, 1)).collect() [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] """ - return _invoke_function("list_filter", _to_column_expr(col), LambdaExpression("x", ColumnExpression("x") != ConstantExpression(element))) + return _invoke_function( + "list_filter", _to_column_expr(col), LambdaExpression("x", ColumnExpression("x") != ConstantExpression(element)) + ) def last_day(date: "ColumnOrName") -> Column: @@ -2005,14 +2035,13 @@ def last_day(date: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-10',)], ['d']) - >>> df.select(last_day(df.d).alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-10",)], ["d"]) + >>> df.select(last_day(df.d).alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] """ return _invoke_function("last_day", _to_column_expr(date)) - def sqrt(col: "ColumnOrName") -> Column: """ Computes the square root of the specified float value. @@ -2129,7 +2158,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> a = range(20) >>> b = [2 * x for x in range(20)] >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(corr("a", "b").alias('c')).collect() + >>> df.agg(corr("a", "b").alias("c")).collect() [Row(c=1.0)] """ return _invoke_function_over_columns("corr", col1, col2) @@ -2243,7 +2272,7 @@ def positive(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ['v']) + >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) >>> df.select(positive("v").alias("p")).show() +---+ | p| @@ -2303,7 +2332,14 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("aa%d%s", 123, "cc",)], ["a", "b", "c"] + ... [ + ... ( + ... "aa%d%s", + ... 123, + ... "cc", + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.printf("a", "b", "c")).show() +---------------+ |printf(a, b, c)| @@ -2335,9 +2371,9 @@ def product(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3) - >>> prods = df.groupBy('mod3').agg(product('x').alias('product')) - >>> prods.orderBy('mod3').show() + >>> df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) + >>> prods = df.groupBy("mod3").agg(product("x").alias("product")) + >>> prods.orderBy("mod3").show() +----+-------+ |mod3|product| +----+-------+ @@ -2375,7 +2411,7 @@ def rand(seed: Optional[int] = None) -> Column: Examples -------- >>> from pyspark.sql import functions as sf - >>> spark.range(0, 2, 1, 1).withColumn('rand', sf.rand(seed=42) * 3).show() + >>> spark.range(0, 2, 1, 1).withColumn("rand", sf.rand(seed=42) * 3).show() +---+------------------+ | id| rand| +---+------------------+ @@ -2409,9 +2445,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.lit(r'(\d+)'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.lit(r"(\d+)")) + ... 
).show() +------------------+ |REGEXP(str, (\d+))| +------------------+ @@ -2419,9 +2455,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.lit(r'\d{2}b'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.lit(r"\d{2}b")) + ... ).show() +-------------------+ |REGEXP(str, \d{2}b)| +-------------------+ @@ -2429,9 +2465,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +-------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.col("regexp"))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.col("regexp")) + ... ).show() +-------------------+ |REGEXP(str, regexp)| +-------------------+ @@ -2462,11 +2498,11 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - >>> df.select(regexp_count('str', lit(r'\d+')).alias('d')).collect() + >>> df.select(regexp_count("str", lit(r"\d+")).alias("d")).collect() [Row(d=3)] - >>> df.select(regexp_count('str', lit(r'mmm')).alias('d')).collect() + >>> df.select(regexp_count("str", lit(r"mmm")).alias("d")).collect() [Row(d=0)] - >>> df.select(regexp_count("str", col("regexp")).alias('d')).collect() + >>> df.select(regexp_count("str", col("regexp")).alias("d")).collect() [Row(d=3)] """ return _invoke_function_over_columns("len", _invoke_function_over_columns("regexp_extract_all", str, regexp)) @@ -2497,22 +2533,22 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: Examples -------- - >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() + >>> df = spark.createDataFrame([("100-200",)], ["str"]) + >>> df.select(regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() [Row(d='100')] - >>> df = spark.createDataFrame([('foo',)], ['str']) - >>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect() + >>> df = spark.createDataFrame([("foo",)], ["str"]) + >>> df.select(regexp_extract("str", r"(\d+)", 1).alias("d")).collect() [Row(d='')] - >>> df = spark.createDataFrame([('aaaac',)], ['str']) - >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() + >>> df = spark.createDataFrame([("aaaac",)], ["str"]) + >>> df.select(regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() [Row(d='')] """ - return _invoke_function("regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx)) + return _invoke_function( + "regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx) + ) -def regexp_extract_all( - str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None -) -> Column: +def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: r"""Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. 
@@ -2535,18 +2571,20 @@ def regexp_extract_all( Examples -------- >>> df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)')).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)")).alias("d")).collect() [Row(d=['100', '300'])] - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)'), 1).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)"), 1).alias("d")).collect() [Row(d=['100', '300'])] - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)'), 2).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)"), 2).alias("d")).collect() [Row(d=['200', '400'])] - >>> df.select(regexp_extract_all('str', col("regexp")).alias('d')).collect() + >>> df.select(regexp_extract_all("str", col("regexp")).alias("d")).collect() [Row(d=['100', '300'])] """ if idx is None: idx = 1 - return _invoke_function("regexp_extract_all", _to_column_expr(str), _to_column_expr(regexp), ConstantExpression(idx)) + return _invoke_function( + "regexp_extract_all", _to_column_expr(str), _to_column_expr(regexp), ConstantExpression(idx) + ) def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: @@ -2569,9 +2607,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.lit(r'(\d+)'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.lit(r"(\d+)")) + ... ).show() +-----------------------+ |REGEXP_LIKE(str, (\d+))| +-----------------------+ @@ -2579,9 +2617,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +-----------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.lit(r'\d{2}b'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.lit(r"\d{2}b")) + ... ).show() +------------------------+ |REGEXP_LIKE(str, \d{2}b)| +------------------------+ @@ -2589,9 +2627,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +------------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.col("regexp"))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.col("regexp")) + ... 
).show() +------------------------+ |REGEXP_LIKE(str, regexp)| +------------------------+ @@ -2622,14 +2660,20 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - >>> df.select(regexp_substr('str', lit(r'\d+')).alias('d')).collect() + >>> df.select(regexp_substr("str", lit(r"\d+")).alias("d")).collect() [Row(d='1')] - >>> df.select(regexp_substr('str', lit(r'mmm')).alias('d')).collect() + >>> df.select(regexp_substr("str", lit(r"mmm")).alias("d")).collect() [Row(d=None)] - >>> df.select(regexp_substr("str", col("regexp")).alias('d')).collect() + >>> df.select(regexp_substr("str", col("regexp")).alias("d")).collect() [Row(d='1')] """ - return Column(FunctionExpression("nullif", FunctionExpression("regexp_extract", _to_column_expr(str), _to_column_expr(regexp)), ConstantExpression(""))) + return Column( + FunctionExpression( + "nullif", + FunctionExpression("regexp_extract", _to_column_expr(str), _to_column_expr(regexp)), + ConstantExpression(""), + ) + ) def repeat(col: "ColumnOrName", n: int) -> Column: @@ -2655,16 +2699,19 @@ def repeat(col: "ColumnOrName", n: int) -> Column: Examples -------- - >>> df = spark.createDataFrame([('ab',)], ['s',]) - >>> df.select(repeat(df.s, 3).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("ab",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(repeat(df.s, 3).alias("s")).collect() [Row(s='ababab')] """ return _invoke_function("repeat", _to_column_expr(col), ConstantExpression(n)) -def sequence( - start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None -) -> Column: +def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None) -> Column: """ Generate a sequence of integers from `start` to `stop`, incrementing by `step`. If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, @@ -2691,11 +2738,11 @@ def sequence( Examples -------- - >>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) - >>> df1.select(sequence('C1', 'C2').alias('r')).collect() + >>> df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) + >>> df1.select(sequence("C1", "C2").alias("r")).collect() [Row(r=[-2, -1, 0, 1, 2])] - >>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) - >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect() + >>> df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) + >>> df2.select(sequence("C1", "C2", "C3").alias("r")).collect() [Row(r=[4, 2, 0, -2, -4])] """ if step is None: @@ -2726,10 +2773,7 @@ def sign(col: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select( - ... sf.sign(sf.lit(-5)), - ... sf.sign(sf.lit(6)) - ... ).show() + >>> spark.range(1).select(sf.sign(sf.lit(-5)), sf.sign(sf.lit(6))).show() +--------+-------+ |sign(-5)|sign(6)| +--------+-------+ @@ -2761,10 +2805,7 @@ def signum(col: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select( - ... sf.signum(sf.lit(-5)), - ... sf.signum(sf.lit(6)) - ... 
).show() + >>> spark.range(1).select(sf.signum(sf.lit(-5)), sf.signum(sf.lit(6))).show() +----------+---------+ |SIGNUM(-5)|SIGNUM(6)| +----------+---------+ @@ -2824,7 +2865,7 @@ def skewness(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) + >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) >>> df.select(skewness(df.c)).first() Row(skewness(c)=0.70710...) """ @@ -2855,7 +2896,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['c']) + >>> df = spark.createDataFrame([("abcd",)], ["c"]) >>> df.select(encode("c", "UTF-8")).show() +----------------+ |encode(c, UTF-8)| @@ -2885,24 +2926,20 @@ def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ['a', 'b']) - >>> df.select(find_in_set(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ["a", "b"]) + >>> df.select(find_in_set(df.a, df.b).alias("r")).collect() [Row(r=3)] """ str_array = _to_column_expr(str_array) str = _to_column_expr(str) return Column( - CaseExpression( - FunctionExpression("contains", str, ConstantExpression(",")), 0 - ).otherwise( + CaseExpression(FunctionExpression("contains", str, ConstantExpression(",")), 0).otherwise( CoalesceOperator( FunctionExpression( - "list_position", - FunctionExpression("string_split", str_array, ConstantExpression(",")), - str + "list_position", FunctionExpression("string_split", str_array, ConstantExpression(",")), str ), # If the element cannot be found, list_position returns null but we want to return 0 - ConstantExpression(0) + ConstantExpression(0), ) ) ) @@ -3018,7 +3055,6 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: return _invoke_function_over_columns("last", col) - def greatest(*cols: "ColumnOrName") -> Column: """ Returns the greatest value of the list of column names, skipping null values. @@ -3041,7 +3077,7 @@ def greatest(*cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] """ @@ -3075,7 +3111,7 @@ def least(*cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() [Row(least=1)] """ @@ -3203,12 +3239,20 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: Examples -------- - >>> df = spark.createDataFrame([("SSparkSQLS", "SL", )], ['a', 'b']) - >>> df.select(btrim(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "SSparkSQLS", + ... "SL", + ... ) + ... ], + ... ["a", "b"], + ... 
) + >>> df.select(btrim(df.a, df.b).alias("r")).collect() [Row(r='parkSQ')] - >>> df = spark.createDataFrame([(" SparkSQL ",)], ['a']) - >>> df.select(btrim(df.a).alias('r')).collect() + >>> df = spark.createDataFrame([(" SparkSQL ",)], ["a"]) + >>> df.select(btrim(df.a).alias("r")).collect() [Row(r='SparkSQL')] """ if trim is not None: @@ -3234,11 +3278,27 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark",)], ["a", "b"]) - >>> df.select(endswith(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... "Spark", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(endswith(df.a, df.b).alias("r")).collect() [Row(r=False)] - >>> df = spark.createDataFrame([("414243", "4243",)], ["e", "f"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4243", + ... ) + ... ], + ... ["e", "f"], + ... ) >>> df = df.select(to_binary("e").alias("e"), to_binary("f").alias("f")) >>> df.printSchema() root @@ -3271,11 +3331,27 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark",)], ["a", "b"]) - >>> df.select(startswith(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... "Spark", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(startswith(df.a, df.b).alias("r")).collect() [Row(r=True)] - >>> df = spark.createDataFrame([("414243", "4142",)], ["e", "f"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4142", + ... ) + ... ], + ... ["e", "f"], + ... ) >>> df = df.select(to_binary("e").alias("e"), to_binary("f").alias("f")) >>> df.printSchema() root @@ -3313,7 +3389,7 @@ def length(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() + >>> spark.createDataFrame([("ABC ",)], ["a"]).select(length("a").alias("length")).collect() [Row(length=4)] """ return _invoke_function_over_columns("length", col) @@ -3351,7 +3427,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1| | 2| +--------------+ - >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() + >>> cDf.select("*", coalesce(cDf["a"], lit(0.0))).show() +----+----+----------------+ | a| b|coalesce(a, 0.0)| +----+----+----------------+ @@ -3375,8 +3451,20 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str Examples -------- - >>> df = spark.createDataFrame([(None, 8,), (1, 9,)], ["a", "b"]) - >>> df.select(nvl(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... 8, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] """ @@ -3397,8 +3485,22 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co Examples -------- - >>> df = spark.createDataFrame([(None, 8, 6,), (1, 9, 9,)], ["a", "b", "c"]) - >>> df.select(nvl2(df.a, df.b, df.c).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... 8, + ... 6, + ... ), + ... ( + ... 1, + ... 9, + ... 9, + ... ), + ... ], + ... ["a", "b", "c"], + ... 
) + >>> df.select(nvl2(df.a, df.b, df.c).alias("r")).collect() [Row(r=6), Row(r=9)] """ col1 = _to_column_expr(col1) @@ -3443,8 +3545,20 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"]) - >>> df.select(nullif(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... None, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(nullif(df.a, df.b).alias("r")).collect() [Row(r=None), Row(r=1)] """ return _invoke_function_over_columns("nullif", col1, col2) @@ -3470,7 +3584,7 @@ def md5(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + >>> spark.createDataFrame([("ABC",)], ["a"]).select(md5("a").alias("hash")).collect() [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] """ return _invoke_function_over_columns("md5", col) @@ -3517,9 +3631,7 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: if numBits == 256: return _invoke_function_over_columns("sha256", col) - raise ContributionsAcceptedError( - "SHA-224, SHA-384, and SHA-512 are not supported yet." - ) + raise ContributionsAcceptedError("SHA-224, SHA-384, and SHA-512 are not supported yet.") def curdate() -> Column: @@ -3537,7 +3649,7 @@ def curdate() -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP + >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ @@ -3565,7 +3677,7 @@ def current_date() -> Column: Examples -------- >>> df = spark.range(1) - >>> df.select(current_date()).show() # doctest: +SKIP + >>> df.select(current_date()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ @@ -3589,7 +3701,7 @@ def now() -> Column: Examples -------- >>> df = spark.range(1) - >>> df.select(now()).show(truncate=False) # doctest: +SKIP + >>> df.select(now()).show(truncate=False) # doctest: +SKIP +-----------------------+ |now() | +-----------------------+ @@ -3598,6 +3710,7 @@ def now() -> Column: """ return _invoke_function("now") + def desc(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the descending order of the given column name. @@ -3634,6 +3747,7 @@ def desc(col: "ColumnOrName") -> Column: """ return Column(_to_column_expr(col).desc()) + def asc(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the ascending order of the given column name. @@ -3685,6 +3799,7 @@ def asc(col: "ColumnOrName") -> Column: """ return Column(_to_column_expr(col).asc()) + def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: """ Returns timestamp truncated to the unit specified by the format. 
@@ -3700,10 +3815,10 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t']) - >>> df.select(date_trunc('year', df.t).alias('year')).collect() + >>> df = spark.createDataFrame([("1997-02-28 05:02:11",)], ["t"]) + >>> df.select(date_trunc("year", df.t).alias("year")).collect() [Row(year=datetime.datetime(1997, 1, 1, 0, 0))] - >>> df.select(date_trunc('mon', df.t).alias('month')).collect() + >>> df.select(date_trunc("mon", df.t).alias("month")).collect() [Row(month=datetime.datetime(1997, 2, 1, 0, 0))] """ format = format.lower() @@ -3740,14 +3855,14 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... date_part(lit('YEAR'), 'ts').alias('year'), - ... date_part(lit('month'), 'ts').alias('month'), - ... date_part(lit('WEEK'), 'ts').alias('week'), - ... date_part(lit('D'), 'ts').alias('day'), - ... date_part(lit('M'), 'ts').alias('minute'), - ... date_part(lit('S'), 'ts').alias('second') + ... date_part(lit("YEAR"), "ts").alias("year"), + ... date_part(lit("month"), "ts").alias("month"), + ... date_part(lit("WEEK"), "ts").alias("week"), + ... date_part(lit("D"), "ts").alias("day"), + ... date_part(lit("M"), "ts").alias("minute"), + ... date_part(lit("S"), "ts").alias("second"), ... ).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3775,14 +3890,14 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... extract(lit('YEAR'), 'ts').alias('year'), - ... extract(lit('month'), 'ts').alias('month'), - ... extract(lit('WEEK'), 'ts').alias('week'), - ... extract(lit('D'), 'ts').alias('day'), - ... extract(lit('M'), 'ts').alias('minute'), - ... extract(lit('S'), 'ts').alias('second') + ... extract(lit("YEAR"), "ts").alias("year"), + ... extract(lit("month"), "ts").alias("month"), + ... extract(lit("WEEK"), "ts").alias("week"), + ... extract(lit("D"), "ts").alias("day"), + ... extract(lit("M"), "ts").alias("minute"), + ... extract(lit("S"), "ts").alias("second"), ... ).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3811,14 +3926,14 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... datepart(lit('YEAR'), 'ts').alias('year'), - ... datepart(lit('month'), 'ts').alias('month'), - ... datepart(lit('WEEK'), 'ts').alias('week'), - ... datepart(lit('D'), 'ts').alias('day'), - ... datepart(lit('M'), 'ts').alias('minute'), - ... datepart(lit('S'), 'ts').alias('second') + ... datepart(lit("YEAR"), "ts").alias("year"), + ... datepart(lit("month"), "ts").alias("month"), + ... datepart(lit("WEEK"), "ts").alias("week"), + ... datepart(lit("D"), "ts").alias("day"), + ... datepart(lit("M"), "ts").alias("minute"), + ... datepart(lit("S"), "ts").alias("second"), ... 
).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3854,15 +3969,19 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) - >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + >>> df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) + >>> df.select( + ... "*", sf.date_diff(sf.col("d1").cast("DATE"), sf.col("d2").cast("DATE")) + ... ).show() +----------+----------+-----------------+ | d1| d2|date_diff(d1, d2)| +----------+----------+-----------------+ |2015-04-08|2015-05-10| -32| +----------+----------+-----------------+ - >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + >>> df.select( + ... "*", sf.date_diff(sf.col("d1").cast("DATE"), sf.col("d2").cast("DATE")) + ... ).show() +----------+----------+-----------------+ | d1| d2|date_diff(d2, d1)| +----------+----------+-----------------+ @@ -3893,8 +4012,8 @@ def year(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(year('dt').alias('year')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(year("dt").alias("year")).collect() [Row(year=2015)] """ return _invoke_function_over_columns("year", col) @@ -3921,8 +4040,8 @@ def quarter(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(quarter('dt').alias('quarter')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(quarter("dt").alias("quarter")).collect() [Row(quarter=2)] """ return _invoke_function_over_columns("quarter", col) @@ -3949,8 +4068,8 @@ def month(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(month('dt').alias('month')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(month("dt").alias("month")).collect() [Row(month=4)] """ return _invoke_function_over_columns("month", col) @@ -3978,8 +4097,8 @@ def dayofweek(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofweek('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofweek("dt").alias("day")).collect() [Row(day=4)] """ return _invoke_function_over_columns("dayofweek", col) + lit(1) @@ -4003,8 +4122,8 @@ def day(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(day('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(day("dt").alias("day")).collect() [Row(day=8)] """ return _invoke_function_over_columns("day", col) @@ -4031,8 +4150,8 @@ def dayofmonth(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofmonth('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofmonth("dt").alias("day")).collect() [Row(day=8)] """ return day(col) @@ -4059,8 +4178,8 @@ def dayofyear(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> 
df.select(dayofyear('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofyear("dt").alias("day")).collect() [Row(day=98)] """ return _invoke_function_over_columns("dayofyear", col) @@ -4088,8 +4207,8 @@ def hour(col: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(hour('ts').alias('hour')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(hour("ts").alias("hour")).collect() [Row(hour=13)] """ return _invoke_function_over_columns("hour", col) @@ -4117,8 +4236,8 @@ def minute(col: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(minute('ts').alias('minute')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(minute("ts").alias("minute")).collect() [Row(minute=8)] """ return _invoke_function_over_columns("minute", col) @@ -4146,8 +4265,8 @@ def second(col: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(second('ts').alias('second')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(second("ts").alias("second")).collect() [Row(second=15)] """ return _invoke_function_over_columns("second", col) @@ -4176,8 +4295,8 @@ def weekofyear(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(weekofyear(df.dt).alias('week')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(weekofyear(df.dt).alias("week")).collect() [Row(week=15)] """ return _invoke_function_over_columns("weekofyear", col) @@ -4267,7 +4386,7 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column: -------- >>> from pyspark.sql.functions import call_udf, col >>> from pyspark.sql.types import IntegerType, StringType - >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "c")],["id", "name"]) + >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "name"]) >>> _ = spark.udf.register("intX2", lambda i: i * 2, IntegerType()) >>> df.select(call_function("intX2", "id")).show() +---------+ @@ -4338,7 +4457,7 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> a = [1] * 10 >>> b = [1] * 10 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(covar_pop("a", "b").alias('c')).collect() + >>> df.agg(covar_pop("a", "b").alias("c")).collect() [Row(c=0.0)] """ return _invoke_function_over_columns("covar_pop", col1, col2) @@ -4370,7 +4489,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> a = [1] * 10 >>> b = [1] * 10 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(covar_samp("a", "b").alias('c')).collect() + >>> df.agg(covar_samp("a", "b").alias("c")).collect() [Row(c=0.0)] """ return _invoke_function_over_columns("covar_samp", col1, col2) @@ -4429,8 +4548,8 @@ def factorial(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(5,)], ['n']) - >>> df.select(factorial(df.n).alias('f')).collect() + >>> df = spark.createDataFrame([(5,)], ["n"]) + >>> df.select(factorial(df.n).alias("f")).collect() 
[Row(f=120)] """ return _invoke_function_over_columns("factorial", col) @@ -4456,8 +4575,8 @@ def log2(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(log2('a').alias('log2')).show() + >>> df = spark.createDataFrame([(4,)], ["a"]) + >>> df.select(log2("a").alias("log2")).show() +----+ |log2| +----+ @@ -4484,8 +4603,8 @@ def ln(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(ln('a')).show() + >>> df = spark.createDataFrame([(4,)], ["a"]) + >>> df.select(ln("a")).show() +------------------+ | ln(a)| +------------------+ @@ -4525,7 +4644,6 @@ def degrees(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("degrees", col) - def radians(col: "ColumnOrName") -> Column: """ Converts an angle measured in degrees to an approximately equivalent angle @@ -4616,10 +4734,12 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] >>> df.select(atan2(lit(1), lit(2))).first() Row(ATAN2(1, 2)=0.46364...) """ + def lit_or_column(x: Union["ColumnOrName", float]) -> Column: if isinstance(x, (int, float)): return lit(x) return x + return _invoke_function_over_columns("atan2", lit_or_column(col1), lit_or_column(col2)) @@ -4676,7 +4796,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: Examples -------- - >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ["a"]).select(round("a", 0).alias("r")).collect() [Row(r=3.0)] """ return _invoke_function_over_columns("round", col, lit(scale)) @@ -4706,7 +4826,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: Examples -------- - >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ["a"]).select(bround("a", 0).alias("r")).collect() [Row(r=2.0)] """ return _invoke_function_over_columns("round_even", col, lit(scale)) @@ -4743,7 +4863,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: Examples -------- - >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) >>> df.select(get(df.data, 1)).show() +------------+ |get(data, 1)| @@ -4806,7 +4926,7 @@ def initcap(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() + >>> spark.createDataFrame([("ab cd",)], ["a"]).select(initcap("a").alias("v")).collect() [Row(v='Ab Cd')] """ return Column( @@ -4814,18 +4934,14 @@ def initcap(col: "ColumnOrName") -> Column: "array_to_string", FunctionExpression( "list_transform", - FunctionExpression( - "string_split", _to_column_expr(col), ConstantExpression(" ") - ), + FunctionExpression("string_split", _to_column_expr(col), ConstantExpression(" ")), LambdaExpression( "x", FunctionExpression( "concat", FunctionExpression( "upper", - FunctionExpression( - "array_extract", ColumnExpression("x"), 1 - ), + FunctionExpression("array_extract", ColumnExpression("x"), 1), ), FunctionExpression("array_slice", ColumnExpression("x"), 2, -1), ), @@ -4858,7 +4974,7 @@ def octet_length(col: "ColumnOrName") -> Column: Examples -------- >>> from pyspark.sql.functions import octet_length - >>> spark.createDataFrame([('cat',), ( '\U0001F408',)], ['cat']) \\ + >>> spark.createDataFrame([('cat',), ( '\U0001f408',)], ['cat']) \\ ... 
.select(octet_length('cat')).collect() [Row(octet_length(cat)=3), Row(octet_length(cat)=4)] """ @@ -4886,7 +5002,7 @@ def hex(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + >>> spark.createDataFrame([("ABC", 3)], ["a", "b"]).select(hex("a"), hex("b")).collect() [Row(hex(a)='414243', hex(b)='3')] """ return _invoke_function_over_columns("hex", col) @@ -4913,7 +5029,7 @@ def unhex(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + >>> spark.createDataFrame([("414243",)], ["a"]).select(unhex("a")).collect() [Row(unhex(a)=bytearray(b'ABC'))] """ return _invoke_function_over_columns("unhex", col) @@ -4950,7 +5066,7 @@ def base64(col: "ColumnOrName") -> Column: |UGFuZGFzIEFQSQ==| +----------------+ """ - if isinstance(col,str): + if isinstance(col, str): col = Column(ColumnExpression(col)) return _invoke_function_over_columns("base64", col.cast("BLOB")) @@ -4976,9 +5092,7 @@ def unbase64(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame(["U3Bhcms=", - ... "UHlTcGFyaw==", - ... "UGFuZGFzIEFQSQ=="], "STRING") + >>> df = spark.createDataFrame(["U3Bhcms=", "UHlTcGFyaw==", "UGFuZGFzIEFQSQ=="], "STRING") >>> df.select(unbase64("value")).show() +--------------------+ | unbase64(value)| @@ -5016,21 +5130,19 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col Examples -------- - >>> df = spark.createDataFrame([('2015-04-08', 2)], ['dt', 'add']) - >>> df.select(add_months(df.dt, 1).alias('next_month')).collect() + >>> df = spark.createDataFrame([("2015-04-08", 2)], ["dt", "add"]) + >>> df.select(add_months(df.dt, 1).alias("next_month")).collect() [Row(next_month=datetime.date(2015, 5, 8))] - >>> df.select(add_months(df.dt, df.add.cast('integer')).alias('next_month')).collect() + >>> df.select(add_months(df.dt, df.add.cast("integer")).alias("next_month")).collect() [Row(next_month=datetime.date(2015, 6, 8))] - >>> df.select(add_months('dt', -2).alias('prev_month')).collect() + >>> df.select(add_months("dt", -2).alias("prev_month")).collect() [Row(prev_month=datetime.date(2015, 2, 8))] """ months = ConstantExpression(months) if isinstance(months, int) else _to_column_expr(months) return _invoke_function("date_add", _to_column_expr(start), FunctionExpression("to_months", months)).cast("date") -def array_join( - col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None -) -> Column: +def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None) -> Column: """ Concatenates the elements of `column` using the `delimiter`. Null values are replaced with `null_replacement` if set, otherwise they are ignored. 
@@ -5056,7 +5168,7 @@ def array_join( Examples -------- - >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) >>> df.select(array_join(df.data, ",").alias("joined")).collect() [Row(joined='a,b,c'), Row(joined='a')] >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() @@ -5065,7 +5177,14 @@ def array_join( col = _to_column_expr(col) if null_replacement is not None: col = FunctionExpression( - "list_transform", col, LambdaExpression("x", CaseExpression(ColumnExpression("x").isnull(), ConstantExpression(null_replacement)).otherwise(ColumnExpression("x"))) + "list_transform", + col, + LambdaExpression( + "x", + CaseExpression(ColumnExpression("x").isnull(), ConstantExpression(null_replacement)).otherwise( + ColumnExpression("x") + ), + ), ) return _invoke_function("array_to_string", col, ConstantExpression(delimiter)) @@ -5099,11 +5218,15 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) >>> df.select(array_position(df.data, "a")).collect() [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] """ - return Column(CoalesceOperator(_to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0))) + return Column( + CoalesceOperator( + _to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0) + ) + ) def array_prepend(col: "ColumnOrName", value: Any) -> Column: @@ -5128,7 +5251,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] """ @@ -5158,8 +5281,8 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu Examples -------- - >>> df = spark.createDataFrame([('ab',)], ['data']) - >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + >>> df = spark.createDataFrame([("ab",)], ["data"]) + >>> df.select(array_repeat(df.data, 3).alias("r")).collect() [Row(r=['ab', 'ab', 'ab'])] """ count = lit(count) if isinstance(count, int) else count @@ -5185,15 +5308,14 @@ def array_size(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) - >>> df.select(array_size(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) + >>> df.select(array_size(df.data).alias("r")).collect() [Row(r=3), Row(r=None)] """ return _invoke_function_over_columns("len", col) -def array_sort( - col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None -) -> Column: + +def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None) -> Column: """ Collection function: sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. 
@@ -5224,14 +5346,20 @@ def array_sort( Examples -------- - >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) - >>> df.select(array_sort(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) + >>> df.select(array_sort(df.data).alias("r")).collect() [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] - >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data']) - >>> df.select(array_sort( - ... "data", - ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x)) - ... ).alias("r")).collect() + >>> df = spark.createDataFrame( + ... [(["foo", "foobar", None, "bar"],), (["foo"],), ([],)], ["data"] + ... ) + >>> df.select( + ... array_sort( + ... "data", + ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise( + ... length(y) - length(x) + ... ), + ... ).alias("r") + ... ).collect() [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] """ if comparator is not None: @@ -5267,10 +5395,10 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) - >>> df.select(sort_array(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) + >>> df.select(sort_array(df.data).alias("r")).collect() [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] - >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() + >>> df.select(sort_array(df.data, asc=False).alias("r")).collect() [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ if asc: @@ -5317,10 +5445,15 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: Examples -------- - >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',]) - >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("oneAtwoBthreeC",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(split(df.s, "[ABC]", 2).alias("s")).collect() [Row(s=['one', 'twoBthreeC'])] - >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect() + >>> df.select(split(df.s, "[ABC]", -1).alias("s")).collect() [Row(s=['one', 'two', 'three', ''])] """ if limit > 0: @@ -5351,8 +5484,17 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO Examples -------- - >>> df = spark.createDataFrame([("11.12.13", ".", 3,)], ["a", "b", "c"]) - >>> df.select(split_part(df.a, df.b, df.c).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "11.12.13", + ... ".", + ... 3, + ... ) + ... ], + ... ["a", "b", "c"], + ... 
) + >>> df.select(split_part(df.a, df.b, df.c).alias("r")).collect() [Row(r='13')] """ src = _to_column_expr(src) @@ -5360,7 +5502,11 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO partNum = _to_column_expr(partNum) part = FunctionExpression("split_part", src, delimiter, partNum) - return Column(CaseExpression(src.isnull() | delimiter.isnull() | partNum.isnull(), ConstantExpression(None)).otherwise(CaseExpression(delimiter == ConstantExpression(""), ConstantExpression("")).otherwise(part))) + return Column( + CaseExpression(src.isnull() | delimiter.isnull() | partNum.isnull(), ConstantExpression(None)).otherwise( + CaseExpression(delimiter == ConstantExpression(""), ConstantExpression("")).otherwise(part) + ) + ) def stddev_samp(col: "ColumnOrName") -> Column: @@ -5427,6 +5573,7 @@ def stddev(col: "ColumnOrName") -> Column: """ return stddev_samp(col) + def std(col: "ColumnOrName") -> Column: """ Aggregate function: alias for stddev_samp. @@ -5600,8 +5747,8 @@ def weekday(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(weekday('dt').alias('day')).show() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(weekday("dt").alias("day")).show() +---+ |day| +---+ @@ -5634,6 +5781,7 @@ def zeroifnull(col: "ColumnOrName") -> Column: """ return coalesce(col, lit(0)) + def _to_date_or_timestamp(col: "ColumnOrName", spark_datatype: _types.DataType, format: Optional[str] = None) -> Column: if format is not None: raise ContributionsAcceptedError( @@ -5670,12 +5818,12 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_date(df.t).alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_date(df.t).alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_date(df.t, "yyyy-MM-dd HH:mm:ss").alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] """ return _to_date_or_timestamp(col, _types.DateType(), format) @@ -5708,12 +5856,12 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_timestamp(df.t).alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_timestamp(df.t).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_timestamp(df.t, "yyyy-MM-dd HH:mm:ss").alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] """ return _to_date_or_timestamp(col, _types.TimestampNTZType(), format) @@ -5739,12 +5887,12 @@ def to_timestamp_ltz( Examples -------- >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) - >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias('r')).collect() + >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() ... 
# doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) - >>> df.select(to_timestamp_ltz(df.e).alias('r')).collect() + >>> df.select(to_timestamp_ltz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] """ @@ -5771,12 +5919,12 @@ def to_timestamp_ntz( Examples -------- >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias('r')).collect() + >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_timestamp_ntz(df.e).alias('r')).collect() + >>> df.select(to_timestamp_ntz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] """ @@ -5797,20 +5945,19 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non format to use to convert timestamp values. Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(try_to_timestamp(df.t).alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(try_to_timestamp(df.t).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - >>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect() + >>> df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] """ if format is None: - format = lit(['%Y-%m-%d', '%Y-%m-%d %H:%M:%S']) + format = lit(["%Y-%m-%d", "%Y-%m-%d %H:%M:%S"]) return _invoke_function_over_columns("try_strptime", col, format) -def substr( - str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None -) -> Column: + +def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None) -> Column: """ Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of byte array that starts at `pos` and is of length `len`. @@ -5830,7 +5977,14 @@ def substr( -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... [ + ... ( + ... "Spark SQL", + ... 5, + ... 1, + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.substr("a", "b", "c")).show() +---------------+ |substr(a, b, c)| @@ -5840,7 +5994,14 @@ def substr( >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... [ + ... ( + ... "Spark SQL", + ... 5, + ... 1, + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.substr("a", "b")).show() +------------------------+ |substr(a, b, 2147483647)| @@ -5855,7 +6016,10 @@ def substr( def _unix_diff(col: "ColumnOrName", part: str) -> Column: - return _invoke_function_over_columns("date_diff", lit(part), lit("1970-01-01 00:00:00+00:00").cast("timestamp"), col) + return _invoke_function_over_columns( + "date_diff", lit(part), lit("1970-01-01 00:00:00+00:00").cast("timestamp"), col + ) + def unix_date(col: "ColumnOrName") -> Column: """Returns the number of days since 1970-01-01. 
@@ -5865,8 +6029,8 @@ def unix_date(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('1970-01-02',)], ['t']) - >>> df.select(unix_date(to_date(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("1970-01-02",)], ["t"]) + >>> df.select(unix_date(to_date(df.t)).alias("n")).collect() [Row(n=1)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5881,8 +6045,8 @@ def unix_micros(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_micros(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_micros(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000000)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5898,8 +6062,8 @@ def unix_millis(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_millis(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_millis(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5915,8 +6079,8 @@ def unix_seconds(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_seconds(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_seconds(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5941,7 +6105,7 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ["x", "y"]) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() [Row(overlap=True), Row(overlap=False)] """ @@ -5952,21 +6116,19 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: a2_has_null = _list_contains_null(a2) return Column( - CaseExpression( - FunctionExpression("list_has_any", a1, a2), ConstantExpression(True) - ).otherwise( + CaseExpression(FunctionExpression("list_has_any", a1, a2), ConstantExpression(True)).otherwise( CaseExpression( - (FunctionExpression("len", a1) > 0) & (FunctionExpression("len", a2) > 0) & (a1_has_null | a2_has_null), ConstantExpression(None) - ).otherwise(ConstantExpression(False))) + (FunctionExpression("len", a1) > 0) & (FunctionExpression("len", a2) > 0) & (a1_has_null | a2_has_null), + ConstantExpression(None), + ).otherwise(ConstantExpression(False)) + ) ) def _list_contains_null(c: ColumnExpression) -> Expression: return FunctionExpression( "list_contains", - FunctionExpression( - "list_transform", c, LambdaExpression("x", ColumnExpression("x").isnull()) - ), + FunctionExpression("list_transform", c, LambdaExpression("x", ColumnExpression("x").isnull())), True, ) @@ -5995,8 +6157,10 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: Examples -------- >>> 
from pyspark.sql.functions import arrays_zip - >>> df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3']) - >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')) + >>> df = spark.createDataFrame( + ... [([1, 2, 3], [2, 4, 6], [3, 6])], ["vals1", "vals2", "vals3"] + ... ) + >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias("zipped")) >>> df.show(truncate=False) +------------------------------------+ |zipped | @@ -6039,8 +6203,13 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: substring of given value. Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(substring(df.s, 1, 2).alias("s")).collect() [Row(s='ab')] """ return _invoke_function( @@ -6065,10 +6234,18 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: The input column or strings to find, may be NULL. Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ['a', 'b']) - >>> df.select(contains(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ["a", "b"]) + >>> df.select(contains(df.a, df.b).alias("r")).collect() [Row(r=True)] - >>> df = spark.createDataFrame([("414243", "4243",)], ["c", "d"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4243", + ... ) + ... ], + ... ["c", "d"], + ... ) >>> df = df.select(to_binary("c").alias("c"), to_binary("d").alias("d")) >>> df.printSchema() root @@ -6100,15 +6277,16 @@ def reverse(col: "ColumnOrName") -> Column: array of elements in reverse order. Examples -------- - >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) - >>> df.select(reverse(df.data).alias('s')).collect() + >>> df = spark.createDataFrame([("Spark SQL",)], ["data"]) + >>> df.select(reverse(df.data).alias("s")).collect() [Row(s='LQS krapS')] - >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) - >>> df.select(reverse(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ["data"]) + >>> df.select(reverse(df.data).alias("r")).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] """ return _invoke_function("reverse", _to_column_expr(col)) + def concat(*cols: "ColumnOrName") -> Column: """ Concatenates multiple input columns together into a single column. @@ -6129,13 +6307,15 @@ def concat(*cols: "ColumnOrName") -> Column: :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter Examples -------- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df = df.select(concat(df.s, df.d).alias('s')) + >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) + >>> df = df.select(concat(df.s, df.d).alias("s")) >>> df.collect() [Row(s='abcd123')] >>> df DataFrame[s: string] - >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df = spark.createDataFrame( + ... [([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ["a", "b", "c"] + ... ) >>> df = df.select(concat(df.a, df.b, df.c).alias("arr")) >>> df.collect() [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] @@ -6174,12 +6354,18 @@ def instr(str: "ColumnOrName", substr: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(instr(df.s, 'b').alias('s')).collect() + >>> df = spark.createDataFrame( + ... 
[("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(instr(df.s, "b").alias("s")).collect() [Row(s=2)] """ return _invoke_function("instr", _to_column_expr(str), ConstantExpression(substr)) + def expr(str: str) -> Column: """Parses the expression string into the column that it represents @@ -6211,6 +6397,7 @@ def expr(str: str) -> Column: """ return Column(SQLExpression(str)) + def broadcast(df: "DataFrame") -> "DataFrame": """ The broadcast function in Spark is used to optimize joins by broadcasting a smaller diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index 4c4d5bb6..29210e29 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -30,6 +30,7 @@ __all__ = ["GroupedData", "Grouping"] + def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: expressions = ",".join(list(cols)) group_by = str(self._grouping) if self._grouping else "" @@ -42,6 +43,7 @@ def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: ) return DataFrame(jdf, self.session) + def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]: def _api(self: "GroupedData", *cols: str) -> DataFrame: name = f.__name__ @@ -56,8 +58,8 @@ class Grouping: def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: self._type = "" self._cols = [_to_column_expr(x) for x in cols] - if 'special' in kwargs: - special = kwargs['special'] + if "special" in kwargs: + special = kwargs["special"] accepted_special = ["cube", "rollup"] assert special in accepted_special self._type = special @@ -69,7 +71,7 @@ def get_columns(self) -> str: def __str__(self) -> str: columns = self.get_columns() if self._type: - return self._type + '(' + columns + ')' + return self._type + "(" + columns + ")" return columns @@ -94,7 +96,8 @@ def count(self) -> DataFrame: Examples -------- >>> df = spark.createDataFrame( - ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] + ... ) >>> df.show() +---+-----+ |age| name| @@ -115,7 +118,7 @@ def count(self) -> DataFrame: | Bob| 2| +-----+-----+ """ - return _api_internal(self, "count").withColumnRenamed('count_star()', 'count') + return _api_internal(self, "count").withColumnRenamed("count_star()", "count") @df_varargs_api def mean(self, *cols: str) -> DataFrame: @@ -141,9 +144,10 @@ def avg(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -156,7 +160,7 @@ def avg(self, *cols: str) -> DataFrame: Group-by name, and calculate the mean of the age in each group. - >>> df.groupBy("name").avg('age').sort("name").show() + >>> df.groupBy("name").avg("age").sort("name").show() +-----+--------+ | name|avg(age)| +-----+--------+ @@ -166,7 +170,7 @@ def avg(self, *cols: str) -> DataFrame: Calculate the mean of the age and height in all data. - >>> df.groupBy().avg('age', 'height').show() + >>> df.groupBy().avg("age", "height").show() +--------+-----------+ |avg(age)|avg(height)| +--------+-----------+ @@ -186,9 +190,10 @@ def max(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... 
(2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -230,9 +235,10 @@ def min(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -274,9 +280,10 @@ def sum(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -308,12 +315,10 @@ def sum(self, *cols: str) -> DataFrame: """ @overload - def agg(self, *exprs: Column) -> DataFrame: - ... + def agg(self, *exprs: Column) -> DataFrame: ... @overload - def agg(self, __exprs: dict[str, str]) -> DataFrame: - ... + def agg(self, __exprs: dict[str, str]) -> DataFrame: ... def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: """Compute aggregates and returns the result as a :class:`DataFrame`. @@ -357,7 +362,8 @@ def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: >>> from pyspark.sql import functions as F >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( - ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] + ... ) >>> df.show() +---+-----+ |age| name| @@ -393,10 +399,9 @@ def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: Same as above but uses pandas UDF. - >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP + >>> @pandas_udf("int", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() - ... >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show() # doctest: +SKIP +-----+------------+ | name|min_udf(age)| diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 6e8c72c6..18095ab6 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -328,9 +328,9 @@ def json( >>> import tempfile >>> with tempfile.TemporaryDirectory() as d: ... # Write a DataFrame into a JSON file - ... spark.createDataFrame( - ... [{"age": 100, "name": "Hyukjin Kwon"}] - ... ).write.mode("overwrite").format("json").save(d) + ... spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode( + ... "overwrite" + ... ).format("json").save(d) ... ... # Read the JSON file as a DataFrame. ... 
spark.read.json(d).show() @@ -344,98 +344,62 @@ def json( if schema is not None: raise ContributionsAcceptedError("The 'schema' option is not supported") if primitivesAsString is not None: - raise ContributionsAcceptedError( - "The 'primitivesAsString' option is not supported" - ) + raise ContributionsAcceptedError("The 'primitivesAsString' option is not supported") if prefersDecimal is not None: - raise ContributionsAcceptedError( - "The 'prefersDecimal' option is not supported" - ) + raise ContributionsAcceptedError("The 'prefersDecimal' option is not supported") if allowComments is not None: - raise ContributionsAcceptedError( - "The 'allowComments' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowComments' option is not supported") if allowUnquotedFieldNames is not None: - raise ContributionsAcceptedError( - "The 'allowUnquotedFieldNames' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowUnquotedFieldNames' option is not supported") if allowSingleQuotes is not None: - raise ContributionsAcceptedError( - "The 'allowSingleQuotes' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowSingleQuotes' option is not supported") if allowNumericLeadingZero is not None: - raise ContributionsAcceptedError( - "The 'allowNumericLeadingZero' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowNumericLeadingZero' option is not supported") if allowBackslashEscapingAnyCharacter is not None: - raise ContributionsAcceptedError( - "The 'allowBackslashEscapingAnyCharacter' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowBackslashEscapingAnyCharacter' option is not supported") if mode is not None: raise ContributionsAcceptedError("The 'mode' option is not supported") if columnNameOfCorruptRecord is not None: - raise ContributionsAcceptedError( - "The 'columnNameOfCorruptRecord' option is not supported" - ) + raise ContributionsAcceptedError("The 'columnNameOfCorruptRecord' option is not supported") if dateFormat is not None: raise ContributionsAcceptedError("The 'dateFormat' option is not supported") if timestampFormat is not None: - raise ContributionsAcceptedError( - "The 'timestampFormat' option is not supported" - ) + raise ContributionsAcceptedError("The 'timestampFormat' option is not supported") if multiLine is not None: raise ContributionsAcceptedError("The 'multiLine' option is not supported") if allowUnquotedControlChars is not None: - raise ContributionsAcceptedError( - "The 'allowUnquotedControlChars' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowUnquotedControlChars' option is not supported") if lineSep is not None: raise ContributionsAcceptedError("The 'lineSep' option is not supported") if samplingRatio is not None: - raise ContributionsAcceptedError( - "The 'samplingRatio' option is not supported" - ) + raise ContributionsAcceptedError("The 'samplingRatio' option is not supported") if dropFieldIfAllNull is not None: - raise ContributionsAcceptedError( - "The 'dropFieldIfAllNull' option is not supported" - ) + raise ContributionsAcceptedError("The 'dropFieldIfAllNull' option is not supported") if encoding is not None: raise ContributionsAcceptedError("The 'encoding' option is not supported") if locale is not None: raise ContributionsAcceptedError("The 'locale' option is not supported") if pathGlobFilter is not None: - raise ContributionsAcceptedError( - "The 'pathGlobFilter' option is not supported" - ) + raise ContributionsAcceptedError("The 
'pathGlobFilter' option is not supported") if recursiveFileLookup is not None: - raise ContributionsAcceptedError( - "The 'recursiveFileLookup' option is not supported" - ) + raise ContributionsAcceptedError("The 'recursiveFileLookup' option is not supported") if modifiedBefore is not None: - raise ContributionsAcceptedError( - "The 'modifiedBefore' option is not supported" - ) + raise ContributionsAcceptedError("The 'modifiedBefore' option is not supported") if modifiedAfter is not None: - raise ContributionsAcceptedError( - "The 'modifiedAfter' option is not supported" - ) + raise ContributionsAcceptedError("The 'modifiedAfter' option is not supported") if allowNonNumericNumbers is not None: - raise ContributionsAcceptedError( - "The 'allowNonNumericNumbers' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowNonNumericNumbers' option is not supported") if isinstance(path, str): path = [path] - if isinstance(path, list): + if isinstance(path, list): if len(path) == 1: rel = self.session.conn.read_json(path[0]) from .dataframe import DataFrame df = DataFrame(rel, self.session) return df - raise PySparkNotImplementedError( - message="Only a single path is supported for now" - ) + raise PySparkNotImplementedError(message="Only a single path is supported for now") else: raise PySparkTypeError( error_class="NOT_STR_OR_LIST_OF_RDD", diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index 744a77e8..c83c7e82 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -16,10 +16,7 @@ from .streaming import DataStreamReader import duckdb -from ..errors import ( - PySparkTypeError, - PySparkValueError -) +from ..errors import PySparkTypeError, PySparkValueError from ..errors.error_classes import * @@ -53,11 +50,12 @@ def __init__(self, context: SparkContext) -> None: def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame: try: import pandas + has_pandas = True except ImportError: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - unique_name = f'pyspark_pandas_df_{uuid.uuid1()}' + unique_name = f"pyspark_pandas_df_{uuid.uuid1()}" self.conn.register(unique_name, data) return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self) @@ -73,9 +71,9 @@ def verify_tuple_integrity(tuples): error_class="LENGTH_SHOULD_BE_THE_SAME", message_parameters={ "arg1": f"data{i}", - "arg2": f"data{i+1}", + "arg2": f"data{i + 1}", "arg1_length": str(expected_length), - "arg2_length": str(actual_length) + "arg2_length": str(actual_length), }, ) @@ -86,13 +84,13 @@ def verify_tuple_integrity(tuples): def construct_query(tuples) -> str: def construct_values_list(row, start_param_idx): parameter_count = len(row) - parameters = [f'${x+start_param_idx}' for x in range(parameter_count)] - parameters = '(' + ', '.join(parameters) + ')' + parameters = [f"${x + start_param_idx}" for x in range(parameter_count)] + parameters = "(" + ", ".join(parameters) + ")" return parameters row_size = len(tuples[0]) values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)] - values_list = ', '.join(values_list) + values_list = ", ".join(values_list) query = f""" select * from (values {values_list}) @@ -175,7 +173,7 @@ def createDataFrame( if is_empty: rel = df.relation # Add impossible where clause - rel = rel.filter('1=0') + rel = rel.filter("1=0") df = DataFrame(rel, self) # Cast to types @@ -203,7 +201,7 @@ def range( end = start 
start = 0 - return DataFrame(self.conn.table_function("range", parameters=[start, end, step]),self) + return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self) def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: if kwargs: @@ -255,7 +253,7 @@ def udf(self) -> UDFRegistration: @property def version(self) -> str: - return '1.0.0' + return "1.0.0" class Builder: def __init__(self) -> None: diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index cda80602..4dcba01f 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -27,7 +27,7 @@ def load( path: Optional[str] = None, format: Optional[str] = None, schema: Union[StructType, str, None] = None, - **options: OptionalPrimitiveType + **options: OptionalPrimitiveType, ) -> "DataFrame": from duckdb.experimental.spark.sql.dataframe import DataFrame diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index ecccc014..f8c8ce4f 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -36,62 +36,62 @@ ) _sqltype_to_spark_class = { - 'boolean': BooleanType, - 'utinyint': UnsignedByteType, - 'tinyint': ByteType, - 'usmallint': UnsignedShortType, - 'smallint': ShortType, - 'uinteger': UnsignedIntegerType, - 'integer': IntegerType, - 'ubigint': UnsignedLongType, - 'bigint': LongType, - 'hugeint': HugeIntegerType, - 'uhugeint': UnsignedHugeIntegerType, - 'varchar': StringType, - 'blob': BinaryType, - 'bit': BitstringType, - 'uuid': UUIDType, - 'date': DateType, - 'time': TimeNTZType, - 'time with time zone': TimeType, - 'timestamp': TimestampNTZType, - 'timestamp with time zone': TimestampType, - 'timestamp_ms': TimestampNanosecondNTZType, - 'timestamp_ns': TimestampMilisecondNTZType, - 'timestamp_s': TimestampSecondNTZType, - 'interval': DayTimeIntervalType, - 'list': ArrayType, - 'struct': StructType, - 'map': MapType, + "boolean": BooleanType, + "utinyint": UnsignedByteType, + "tinyint": ByteType, + "usmallint": UnsignedShortType, + "smallint": ShortType, + "uinteger": UnsignedIntegerType, + "integer": IntegerType, + "ubigint": UnsignedLongType, + "bigint": LongType, + "hugeint": HugeIntegerType, + "uhugeint": UnsignedHugeIntegerType, + "varchar": StringType, + "blob": BinaryType, + "bit": BitstringType, + "uuid": UUIDType, + "date": DateType, + "time": TimeNTZType, + "time with time zone": TimeType, + "timestamp": TimestampNTZType, + "timestamp with time zone": TimestampType, + "timestamp_ms": TimestampNanosecondNTZType, + "timestamp_ns": TimestampMilisecondNTZType, + "timestamp_s": TimestampSecondNTZType, + "interval": DayTimeIntervalType, + "list": ArrayType, + "struct": StructType, + "map": MapType, # union # enum # null (???) 
- 'float': FloatType, - 'double': DoubleType, - 'decimal': DecimalType, + "float": FloatType, + "double": DoubleType, + "decimal": DecimalType, } def convert_nested_type(dtype: DuckDBPyType) -> DataType: id = dtype.id - if id == 'list' or id == 'array': + if id == "list" or id == "array": children = dtype.children return ArrayType(convert_type(children[0][1])) # TODO: add support for 'union' - if id == 'struct': + if id == "struct": children: list[tuple[str, DuckDBPyType]] = dtype.children fields = [StructField(x[0], convert_type(x[1])) for x in children] return StructType(fields) - if id == 'map': + if id == "map": return MapType(convert_type(dtype.key), convert_type(dtype.value)) raise NotImplementedError def convert_type(dtype: DuckDBPyType) -> DataType: id = dtype.id - if id in ['list', 'struct', 'map', 'array']: + if id in ["list", "struct", "map", "array"]: return convert_nested_type(dtype) - if id == 'decimal': + if id == "decimal": children: list[tuple[str, DuckDBPyType]] = dtype.children precision = cast(int, children[0][1]) scale = cast(int, children[1][1]) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 4b3a4132..81293caf 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -632,11 +632,9 @@ class MapType(DataType): Examples -------- - >>> (MapType(StringType(), IntegerType()) - ... == MapType(StringType(), IntegerType(), True)) + >>> (MapType(StringType(), IntegerType()) == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType(), IntegerType(), False) - ... == MapType(StringType(), FloatType())) + >>> (MapType(StringType(), IntegerType(), False) == MapType(StringType(), FloatType())) False """ @@ -697,11 +695,9 @@ class StructField(DataType): Examples -------- - >>> (StructField("f1", StringType(), True) - ... == StructField("f1", StringType(), True)) + >>> (StructField("f1", StringType(), True) == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType(), True) - ... == StructField("f2", StringType(), True)) + >>> (StructField("f1", StringType(), True) == StructField("f2", StringType(), True)) False """ @@ -743,7 +739,7 @@ def fromInternal(self, obj: T) -> T: return self.dataType.fromInternal(obj) def typeName(self) -> str: # type: ignore[override] - raise TypeError("StructField does not have typeName. " "Use typeName on its type explicitly instead.") + raise TypeError("StructField does not have typeName. Use typeName on its type explicitly instead.") class StructType(DataType): @@ -767,8 +763,9 @@ class StructType(DataType): >>> struct1 == struct2 True >>> struct1 = StructType([StructField("f1", StringType(), True)]) - >>> struct2 = StructType([StructField("f1", StringType(), True), - ... StructField("f2", IntegerType(), False)]) + >>> struct2 = StructType( + ... [StructField("f1", StringType(), True), StructField("f2", IntegerType(), False)] + ... ) >>> struct1 == struct2 False """ @@ -796,12 +793,10 @@ def add( data_type: Union[str, DataType], nullable: bool = True, metadata: Optional[dict[str, Any]] = None, - ) -> "StructType": - ... + ) -> "StructType": ... @overload - def add(self, field: StructField) -> "StructType": - ... + def add(self, field: StructField) -> "StructType": ... def add( self, @@ -1091,7 +1086,6 @@ def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], class Row(tuple): - """ A row in :class:`DataFrame`. 
The fields in it can be accessed: @@ -1115,13 +1109,13 @@ class Row(tuple): >>> row = Row(name="Alice", age=11) >>> row Row(name='Alice', age=11) - >>> row['name'], row['age'] + >>> row["name"], row["age"] ('Alice', 11) >>> row.name, row.age ('Alice', 11) - >>> 'name' in row + >>> "name" in row True - >>> 'wrong_key' in row + >>> "wrong_key" in row False Row also can be used to create another Row like class, then it @@ -1130,9 +1124,9 @@ class Row(tuple): >>> Person = Row("name", "age") >>> Person - >>> 'name' in Person + >>> "name" in Person True - >>> 'wrong_key' in Person + >>> "wrong_key" in Person False >>> Person("Alice", 11) Row(name='Alice', age=11) @@ -1147,16 +1141,14 @@ class Row(tuple): """ @overload - def __new__(cls, *args: str) -> "Row": - ... + def __new__(cls, *args: str) -> "Row": ... @overload - def __new__(cls, **kwargs: Any) -> "Row": - ... + def __new__(cls, **kwargs: Any) -> "Row": ... def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": if args and kwargs: - raise ValueError("Can not use both args " "and kwargs to create Row") + raise ValueError("Can not use both args and kwargs to create Row") if kwargs: # create row objects row = tuple.__new__(cls, list(kwargs.values())) @@ -1185,12 +1177,12 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]: Examples -------- - >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} + >>> Row(name="Alice", age=11).asDict() == {"name": "Alice", "age": 11} True - >>> row = Row(key=1, value=Row(name='a', age=2)) - >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)} + >>> row = Row(key=1, value=Row(name="a", age=2)) + >>> row.asDict() == {"key": 1, "value": Row(name="a", age=2)} True - >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} + >>> row.asDict(True) == {"key": 1, "value": {"name": "a", "age": 2}} True """ if not hasattr(self, "__fields__"): @@ -1223,7 +1215,7 @@ def __call__(self, *args: Any) -> "Row": """create new Row object""" if len(args) > len(self): raise ValueError( - "Can not create Row with fields %s, expected %d values " "but got %s" % (self, len(self), args) + "Can not create Row with fields %s, expected %d values but got %s" % (self, len(self), args) ) return _create_row(self, args) diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index fbef757d..ea4ba540 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -3,13 +3,14 @@ from .bytes_io_wrapper import BytesIOWrapper from io import TextIOBase + def is_file_like(obj): # We only care that we can read from the file return hasattr(obj, "read") and hasattr(obj, "seek") class ModifiedMemoryFileSystem(MemoryFileSystem): - protocol = ('DUCKDB_INTERNAL_OBJECTSTORE',) + protocol = ("DUCKDB_INTERNAL_OBJECTSTORE",) # defer to the original implementation that doesn't hardcode the protocol _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index ac4a6495..90c2a561 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,17 +1,3 @@ -from _duckdb.functional import ( - FunctionNullHandling, - PythonUDFType, - SPECIAL, - DEFAULT, - NATIVE, - ARROW -) +from _duckdb.functional import FunctionNullHandling, PythonUDFType, SPECIAL, DEFAULT, NATIVE, ARROW -__all__ = [ - "FunctionNullHandling", - "PythonUDFType", - "SPECIAL", - "DEFAULT", - "NATIVE", - "ARROW" -] +__all__ = ["FunctionNullHandling", "PythonUDFType", "SPECIAL", "DEFAULT", "NATIVE", "ARROW"] diff --git 
a/duckdb/polars_io.py b/duckdb/polars_io.py index d8d4cfe9..ef87f03a 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -8,13 +8,14 @@ from decimal import Decimal import datetime + def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: """ Convert a Polars predicate expression to a DuckDB-compatible SQL expression. - + Parameters: predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5) - + Returns: SQLExpression: A DuckDB SQL expression string equivalent. None: If conversion fails. @@ -25,7 +26,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: """ # Serialize the Polars expression tree to JSON tree = json.loads(predicate.meta.serialize(format="json")) - + try: # Convert the tree to SQL sql_filter = _pl_tree_to_sql(tree) @@ -38,7 +39,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: def _pl_operation_to_sql(op: str) -> str: """ Map Polars binary operation strings to SQL equivalents. - + Example: >>> _pl_operation_to_sql("Eq") '=' @@ -73,13 +74,13 @@ def _escape_sql_identifier(identifier: str) -> str: def _pl_tree_to_sql(tree: dict) -> str: """ Recursively convert a Polars expression tree (as JSON) to a SQL string. - + Parameters: tree (dict): JSON-deserialized expression tree from Polars - + Returns: str: SQL expression string - + Example: Input tree: { @@ -97,13 +98,15 @@ def _pl_tree_to_sql(tree: dict) -> str: if node_type == "BinaryExpr": # Binary expressions: left OP right return ( - "(" + - " ".join(( - _pl_tree_to_sql(subtree['left']), - _pl_operation_to_sql(subtree['op']), - _pl_tree_to_sql(subtree['right']) - )) + - ")" + "(" + + " ".join( + ( + _pl_tree_to_sql(subtree["left"]), + _pl_operation_to_sql(subtree["op"]), + _pl_tree_to_sql(subtree["right"]), + ) + ) + + ")" ) if node_type == "Column": # A reference to a column name @@ -147,20 +150,30 @@ def _pl_tree_to_sql(tree: dict) -> str: # Decimal support if dtype.startswith("{'Decimal'") or dtype == "Decimal": - decimal_value = value['Decimal'] + decimal_value = value["Decimal"] decimal_value = Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]) return str(decimal_value) # Datetime with microseconds since epoch if dtype.startswith("{'Datetime'") or dtype == "Datetime": - micros = value['Datetime'][0] + micros = value["Datetime"][0] dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC) return f"'{str(dt_timestamp)}'::TIMESTAMP" # Match simple numeric/boolean types - if dtype in ("Int8", "Int16", "Int32", "Int64", - "UInt8", "UInt16", "UInt32", "UInt64", - "Float32", "Float64", "Boolean"): + if dtype in ( + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "Float32", + "Float64", + "Boolean", + ): return str(value[dtype]) # Time type @@ -168,9 +181,7 @@ def _pl_tree_to_sql(tree: dict) -> str: nanoseconds = value["Time"] seconds = nanoseconds // 1_000_000_000 microseconds = (nanoseconds % 1_000_000_000) // 1_000 - dt_time = (datetime.datetime.min + datetime.timedelta( - seconds=seconds, microseconds=microseconds - )).time() + dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time() return f"'{dt_time}'::TIME" # Date type @@ -182,7 +193,7 @@ def _pl_tree_to_sql(tree: dict) -> str: # Binary type if dtype == "Binary": binary_data = bytes(value["Binary"]) - escaped = ''.join(f'\\x{b:02x}' for b in binary_data) + escaped = "".join(f"\\x{b:02x}" for b in binary_data) return f"'{escaped}'::BLOB" # String type @@ 
-191,15 +202,16 @@ def _pl_tree_to_sql(tree: dict) -> str: string_val = value.get("StringOwned", value.get("String", None)) return f"'{string_val}'" - raise NotImplementedError(f"Unsupported scalar type {str(dtype)}, with value {value}") raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") + def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: """ A polars IO plugin for DuckDB. """ + def source_generator( with_columns: Optional[list[str]], predicate: Optional[pl.Expr], diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index eab68179..aa67b42f 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -77,7 +77,6 @@ class NodeTiming: - def __init__(self, phase: str, time: float) -> object: self.phase = phase self.time = time @@ -94,7 +93,6 @@ def combine_timing(l: object, r: object) -> object: class AllTimings: - def __init__(self) -> None: self.phase_to_timings = {} @@ -128,37 +126,38 @@ def open_utf8(fpath: str, flags: str) -> object: def get_child_timings(top_node: object, query_timings: object) -> str: - node_timing = NodeTiming(top_node['operator_type'], float(top_node['operator_timing'])) + node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"])) query_timings.add_node_timing(node_timing) - for child in top_node['children']: + for child in top_node["children"]: get_child_timings(child, query_timings) def get_pink_shade_hex(fraction: float): fraction = max(0, min(1, fraction)) - + # Define the RGB values for very light pink (almost white) and dark pink light_pink = (255, 250, 250) # Very light pink - dark_pink = (255, 20, 147) # Dark pink - + dark_pink = (255, 20, 147) # Dark pink + # Calculate the RGB values for the given fraction r = int(light_pink[0] + (dark_pink[0] - light_pink[0]) * fraction) g = int(light_pink[1] + (dark_pink[1] - light_pink[1]) * fraction) b = int(light_pink[2] + (dark_pink[2] - light_pink[2]) * fraction) - + # Return as hexadecimal color code return f"#{r:02x}{g:02x}{b:02x}" + def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: - node_style = f"background-color: {get_pink_shade_hex(float(result)/cpu_time)};" + node_style = f"background-color: {get_pink_shade_hex(float(result) / cpu_time)};" - body = f"" - body += "
" + body = f'' + body += '
' new_name = "BRIDGE" if (name == "INVALID") else name.replace("_", " ") formatted_num = f"{float(result):.4f}" body += f"

{new_name}

time: {formatted_num} seconds

" - body += f" {extra_info} " - if (width > 0): + body += f' {extra_info} ' + if width > 0: body += f"

cardinality: {card}

" body += f"

estimate: {est}

" body += f"

width: {width} bytes

" @@ -174,26 +173,31 @@ def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: extra_info = "" estimate = 0 - for key in json_graph['extra_info']: - value = json_graph['extra_info'][key] - if (key == "Estimated Cardinality"): + for key in json_graph["extra_info"]: + value = json_graph["extra_info"][key] + if key == "Estimated Cardinality": estimate = int(value) else: extra_info += f"{key}: {value}
" cardinality = json_graph["operator_cardinality"] - width = int(json_graph["result_set_size"]/max(1,cardinality)) + width = int(json_graph["result_set_size"] / max(1, cardinality)) # get rid of some typically long names extra_info = re.sub(r"__internal_\s*", "__", extra_info) extra_info = re.sub(r"compress_integral\s*", "compress", extra_info) - node_body = get_node_body(json_graph["operator_type"], - json_graph["operator_timing"], - cpu_time, cardinality, estimate, width, - re.sub(r",\s*", ", ", extra_info)) + node_body = get_node_body( + json_graph["operator_type"], + json_graph["operator_timing"], + cpu_time, + cardinality, + estimate, + width, + re.sub(r",\s*", ", ", extra_info), + ) children_html = "" - if len(json_graph['children']) >= 1: + if len(json_graph["children"]) >= 1: children_html += "
    " for child in json_graph["children"]: children_html += generate_tree_recursive(child, cpu_time) @@ -205,7 +209,7 @@ def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: def generate_timing_html(graph_json: object, query_timings: object) -> object: json_graph = json.loads(graph_json) gather_timing_information(json_graph, query_timings) - total_time = float(json_graph.get('operator_timing') or json_graph.get('latency')) + total_time = float(json_graph.get("operator_timing") or json_graph.get("latency")) table_head = """ @@ -242,12 +246,12 @@ def generate_timing_html(graph_json: object, query_timings: object) -> object: def generate_tree_html(graph_json: object) -> str: json_graph = json.loads(graph_json) - cpu_time = float(json_graph['cpu_time']) - tree_prefix = "
\n"
+    cpu_time = float(json_graph["cpu_time"])
+    tree_prefix = '\n'
     tree_suffix = "
      " # first level of json is general overview # FIXME: make sure json output first level always has only 1 level - tree_body = generate_tree_recursive(json_graph['children'][0], cpu_time) + tree_body = generate_tree_recursive(json_graph["children"][0], cpu_time) return tree_prefix + tree_body + tree_suffix @@ -256,39 +260,32 @@ def generate_ipython(json_input: str) -> str: html_output = generate_html(json_input, False) - return HTML(("\n" - " ${CSS}\n" - " ${LIBRARIES}\n" - "
      \n" - " ${CHART_SCRIPT}\n" - " ").replace("${CSS}", html_output['css']).replace('${CHART_SCRIPT}', - html_output['chart_script']).replace( - '${LIBRARIES}', html_output['libraries'])) + return HTML( + ('\n ${CSS}\n ${LIBRARIES}\n
      \n ${CHART_SCRIPT}\n ') + .replace("${CSS}", html_output["css"]) + .replace("${CHART_SCRIPT}", html_output["chart_script"]) + .replace("${LIBRARIES}", html_output["libraries"]) + ) def generate_style_html(graph_json: str, include_meta_info: bool) -> None: - treeflex_css = "\n" + treeflex_css = '\n' css = "\n" - return { - 'treeflex_css': treeflex_css, - 'duckdb_css': css, - 'libraries': '', - 'chart_script': '' - } + return {"treeflex_css": treeflex_css, "duckdb_css": css, "libraries": "", "chart_script": ""} def gather_timing_information(json: str, query_timings: object) -> None: # add up all of the times # measure each time as a percentage of the total time. # then you can return a list of [phase, time, percentage] - get_child_timings(json['children'][0], query_timings) + get_child_timings(json["children"][0], query_timings) def translate_json_to_html(input_file: str, output_file: str) -> None: query_timings = AllTimings() - with open_utf8(input_file, 'r') as f: + with open_utf8(input_file, "r") as f: text = f.read() html_output = generate_style_html(text, True) @@ -317,10 +314,10 @@ def translate_json_to_html(input_file: str, output_file: str) -> None: """ - html = html.replace("${TREEFLEX_CSS}", html_output['treeflex_css']) - html = html.replace("${DUCKDB_CSS}", html_output['duckdb_css']) + html = html.replace("${TREEFLEX_CSS}", html_output["treeflex_css"]) + html = html.replace("${DUCKDB_CSS}", html_output["duckdb_css"]) html = html.replace("${TIMING_TABLE}", timing_table) - html = html.replace('${TREE}', tree_output) + html = html.replace("${TREE}", tree_output) f.write(html) @@ -329,11 +326,12 @@ def main() -> None: print("Please use python3") exit(1) parser = argparse.ArgumentParser( - prog='Query Graph Generator', - description='Given a json profile output, generate a html file showing the query graph and timings of operators') - parser.add_argument('profile_input', help='profile input in json') - parser.add_argument('--out', required=False, default=False) - parser.add_argument('--open', required=False, action='store_true', default=True) + prog="Query Graph Generator", + description="Given a json profile output, generate a html file showing the query graph and timings of operators", + ) + parser.add_argument("profile_input", help="profile input in json") + parser.add_argument("--out", required=False, default=False) + parser.add_argument("--open", required=False, action="store_true", default=True) args = parser.parse_args() input = args.profile_input @@ -356,8 +354,8 @@ def main() -> None: translate_json_to_html(input, output) if open_output: - webbrowser.open('file://' + os.path.abspath(output), new=2) + webbrowser.open("file://" + os.path.abspath(output), new=2) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/duckdb/typing/__init__.py b/duckdb/typing/__init__.py index d0e95b50..33cf4cd7 100644 --- a/duckdb/typing/__init__.py +++ b/duckdb/typing/__init__.py @@ -26,7 +26,7 @@ USMALLINT, UTINYINT, UUID, - VARCHAR + VARCHAR, ) __all__ = [ @@ -57,5 +57,5 @@ "USMALLINT", "UTINYINT", "UUID", - "VARCHAR" + "VARCHAR", ] diff --git a/duckdb/typing/__init__.pyi b/duckdb/typing/__init__.pyi index 69435c05..8a3cef79 100644 --- a/duckdb/typing/__init__.pyi +++ b/duckdb/typing/__init__.pyi @@ -32,5 +32,7 @@ class DuckDBPyType: def __init__(self, type_str: str, connection: DuckDBPyConnection = ...) -> None: ... def __repr__(self) -> str: ... def __eq__(self, other) -> bool: ... 
- def __getattr__(self, name: str): DuckDBPyType - def __getitem__(self, name: str): DuckDBPyType \ No newline at end of file + def __getattr__(self, name: str): + DuckDBPyType + def __getitem__(self, name: str): + DuckDBPyType diff --git a/duckdb/value/constant/__init__.pyi b/duckdb/value/constant/__init__.pyi index 8cea58cf..f5190345 100644 --- a/duckdb/value/constant/__init__.pyi +++ b/duckdb/value/constant/__init__.pyi @@ -54,9 +54,9 @@ class DoubleValue(Value): def __repr__(self) -> str: ... class DecimalValue(Value): - def __init__(self, object: Any, width: int, scale: int) -> None: ... - def __repr__(self) -> str: ... - + def __init__(self, object: Any, width: int, scale: int) -> None: ... + def __repr__(self) -> str: ... + class StringValue(Value): def __init__(self, object: Any) -> None: ... def __repr__(self) -> str: ... @@ -109,7 +109,6 @@ class TimeTimeZoneValue(Value): def __init__(self, object: Any) -> None: ... def __repr__(self) -> str: ... - class Value: def __init__(self, object: Any, type: DuckDBPyType) -> None: ... def __repr__(self) -> str: ... diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index ca8e7716..3709dac0 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -5,13 +5,16 @@ - Git tag creation and management - Version parsing and validation """ + import pathlib import subprocess from typing import Optional import re -VERSION_RE = re.compile(r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$") +VERSION_RE = re.compile( + r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$" +) def parse_version(version: str) -> tuple[int, int, int, int, int]: @@ -67,12 +70,12 @@ def git_tag_to_pep440(git_tag: str) -> str: PEP440 version string (e.g., "1.3.1", "1.3.1.post1") """ # Remove 'v' prefix if present - version = git_tag[1:] if git_tag.startswith('v') else git_tag + version = git_tag[1:] if git_tag.startswith("v") else git_tag if "-post" in version: - assert 'rc' not in version + assert "rc" not in version version = version.replace("-post", ".post") - elif '-rc' in version: + elif "-rc" in version: version = version.replace("-rc", "rc") return version @@ -87,10 +90,10 @@ def pep440_to_git_tag(version: str) -> str: Returns: Git tag format (e.g., "v1.3.1-post1") """ - if '.post' in version: - assert 'rc' not in version + if ".post" in version: + assert "rc" not in version version = version.replace(".post", "-post") - elif 'rc' in version: + elif "rc" in version: version = version.replace("rc", "-rc") return f"v{version}" @@ -104,12 +107,7 @@ def get_current_version() -> Optional[str]: """ try: # Get the latest tag - result = subprocess.run( - ["git", "describe", "--tags", "--abbrev=0"], - capture_output=True, - text=True, - check=True - ) + result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"], capture_output=True, text=True, check=True) tag = result.stdout.strip() return git_tag_to_pep440(tag) except subprocess.CalledProcessError: @@ -156,18 +154,18 @@ def get_git_describe(repo_path: Optional[pathlib.Path] = None, since_major=False Git describe output or None if no tags exist """ cwd = repo_path if repo_path is not None else None - pattern="v*.*.*" + pattern = "v*.*.*" if since_major: - pattern="v*.0.0" + pattern = "v*.0.0" elif since_minor: - pattern="v*.*.0" + pattern = "v*.*.0" try: result = subprocess.run( ["git", "describe", "--tags", "--long", "--match", pattern], capture_output=True, text=True, check=True, - cwd=cwd + cwd=cwd, ) 
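As a concrete illustration of the tag mapping described by the _versioning.py docstrings above: the sketch below is standalone and not part of the patch, and it assumes the helpers behave exactly as their docstrings and bodies show.

# Minimal sketch of the documented git-tag <-> PEP 440 mapping; mirrors the
# behaviour of duckdb_packaging/_versioning.py as shown in the hunks above.
def example_git_tag_to_pep440(git_tag: str) -> str:
    version = git_tag[1:] if git_tag.startswith("v") else git_tag
    if "-post" in version:
        return version.replace("-post", ".post")  # "v1.3.1-post1" -> "1.3.1.post1"
    if "-rc" in version:
        return version.replace("-rc", "rc")  # "v1.4.0-rc2" -> "1.4.0rc2"
    return version  # "v1.3.1" -> "1.3.1"

assert example_git_tag_to_pep440("v1.3.1") == "1.3.1"
assert example_git_tag_to_pep440("v1.3.1-post1") == "1.3.1.post1"
assert example_git_tag_to_pep440("v1.4.0-rc2") == "1.4.0rc2"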
result.check_returncode() return result.stdout.strip() diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index de1a9535..b9a005db 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -12,6 +12,7 @@ Also see https://peps.python.org/pep-0517/#in-tree-build-backends. """ + import sys import os import subprocess @@ -39,7 +40,7 @@ _FORCED_PEP440_VERSION = forced_version_from_env() -def _log(msg: str, is_error: bool=False) -> None: +def _log(msg: str, is_error: bool = False) -> None: """Log a message with build backend prefix. Args: @@ -84,9 +85,9 @@ def _duckdb_submodule_path() -> Path: cur_module_reponame = None cur_module_path = None elif line.strip().startswith("path"): - cur_module_path = line.split('=')[-1].strip() + cur_module_path = line.split("=")[-1].strip() elif line.strip().startswith("url"): - basename = os.path.basename(line.split('=')[-1].strip()) + basename = os.path.basename(line.split("=")[-1].strip()) cur_module_reponame = basename[:-4] if basename.endswith(".git") else basename if cur_module_reponame is not None and cur_module_path is not None: modules[cur_module_reponame] = cur_module_path @@ -115,7 +116,7 @@ def _version_file_path() -> Path: return package_dir / _DUCKDB_VERSION_FILENAME -def _write_duckdb_long_version(long_version: str)-> None: +def _write_duckdb_long_version(long_version: str) -> None: """Write the given version string to a file in the same directory as this module.""" _version_file_path().write_text(long_version, encoding="utf-8") @@ -126,7 +127,7 @@ def _read_duckdb_long_version() -> str: def _skbuild_config_add( - key: str, value: Union[list, str], config_settings: dict[str, Union[list[str],str]], fail_if_exists: bool=False + key: str, value: Union[list, str], config_settings: dict[str, Union[list[str], str]], fail_if_exists: bool = False ): """Add or modify a configuration setting for scikit-build-core. @@ -178,7 +179,7 @@ def _skbuild_config_add( ) -def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str],str]]] = None) -> str: +def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str], str]]] = None) -> str: """Build a source distribution using the DuckDB submodule. This function extracts the DuckDB version from either the git submodule and saves it @@ -207,9 +208,9 @@ def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[ def build_wheel( - wheel_directory: str, - config_settings: Optional[dict[str, Union[list[str],str]]] = None, - metadata_directory: Optional[str] = None, + wheel_directory: str, + config_settings: Optional[dict[str, Union[list[str], str]]] = None, + metadata_directory: Optional[str] = None, ) -> str: """Build a wheel from either git submodule or extracted sdist sources. @@ -246,7 +247,6 @@ def build_wheel( else: _log("No explicit DuckDB submodule version provided. 
Letting CMake figure it out.") - return skbuild_build_wheel(wheel_directory, config_settings=config_settings, metadata_directory=metadata_directory) diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 031adf94..80073c0e 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -28,8 +28,8 @@ from requests.exceptions import RequestException from urllib3 import Retry -_PYPI_URL_PROD = 'https://pypi.org/' -_PYPI_URL_TEST = 'https://test.pypi.org/' +_PYPI_URL_PROD = "https://pypi.org/" +_PYPI_URL_TEST = "https://test.pypi.org/" _DEFAULT_MAX_NIGHTLIES = 2 _LOGIN_RETRY_ATTEMPTS = 3 _LOGIN_RETRY_DELAY = 5 @@ -50,88 +50,70 @@ def create_argument_parser() -> argparse.ArgumentParser: * Keep the configured amount of dev releases per version, and remove older dev releases """, epilog="Environment variables required (unless --dry-run): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP", - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument( - "--dry-run", - action="store_true", - help="Show what would be deleted but don't actually do it" - ) + parser.add_argument("--dry-run", action="store_true", help="Show what would be deleted but don't actually do it") host_group = parser.add_mutually_exclusive_group(required=True) - host_group.add_argument( - "--prod", - action="store_true", - help="Use production PyPI (pypi.org)" - ) - host_group.add_argument( - "--test", - action="store_true", - help="Use test PyPI (test.pypi.org)" - ) + host_group.add_argument("--prod", action="store_true", help="Use production PyPI (pypi.org)") + host_group.add_argument("--test", action="store_true", help="Use test PyPI (test.pypi.org)") parser.add_argument( - "-m", "--max-nightlies", + "-m", + "--max-nightlies", type=int, default=_DEFAULT_MAX_NIGHTLIES, - help=f"Max number of nightlies of unreleased versions (default={_DEFAULT_MAX_NIGHTLIES})" + help=f"Max number of nightlies of unreleased versions (default={_DEFAULT_MAX_NIGHTLIES})", ) - parser.add_argument( - "-u", "--username", - type=validate_username, - help="PyPI username (required unless --dry-run)" - ) + parser.add_argument("-u", "--username", type=validate_username, help="PyPI username (required unless --dry-run)") - parser.add_argument( - "-v", "--verbose", - action="store_true", - help="Enable verbose debug logging" - ) + parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose debug logging") return parser + class PyPICleanupError(Exception): """Base exception for PyPI cleanup operations.""" + pass class AuthenticationError(PyPICleanupError): """Raised when authentication fails.""" + pass class ValidationError(PyPICleanupError): """Raised when input validation fails.""" + pass def setup_logging(verbose: bool = False) -> None: """Configure logging with appropriate level and format.""" level = logging.DEBUG if verbose else logging.INFO - logging.basicConfig( - level=level, - format='%(asctime)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) + logging.basicConfig(level=level, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S") def validate_username(value: str) -> str: """Validate and sanitize username input.""" if not value or not value.strip(): raise argparse.ArgumentTypeError("Username cannot be empty") - + username = value.strip() if len(username) > 100: # Reasonable limit raise argparse.ArgumentTypeError("Username too long (max 100 characters)") - + # Basic 
validation - PyPI usernames are alphanumeric with limited special chars - if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$', username): + if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$", username): raise argparse.ArgumentTypeError("Invalid username format") - + return username + @contextlib.contextmanager def session_with_retries() -> Generator[Session, None, None]: """Create a requests session with retry strategy for ephemeral errors.""" @@ -154,19 +136,20 @@ def session_with_retries() -> Generator[Session, None, None]: session.mount("https://", adapter) yield session + def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: """Load credentials from environment variables.""" if dry_run: return None, None - - password = os.getenv('PYPI_CLEANUP_PASSWORD') - otp = os.getenv('PYPI_CLEANUP_OTP') - + + password = os.getenv("PYPI_CLEANUP_PASSWORD") + otp = os.getenv("PYPI_CLEANUP_OTP") + if not password: raise ValidationError("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") if not otp: raise ValidationError("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") - + return password, otp @@ -174,15 +157,17 @@ def validate_arguments(args: argparse.Namespace) -> None: """Validate parsed arguments.""" if not args.dry_run and not args.username: raise ValidationError("--username is required when not in dry-run mode") - + if args.max_nightlies < 0: raise ValidationError("--max-nightlies must be non-negative") + class CsrfParser(HTMLParser): """HTML parser to extract CSRF tokens from PyPI forms. - + Based on pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) """ + def __init__(self, target, contains_input=None) -> None: super().__init__() self._target = target @@ -222,24 +207,31 @@ def handle_endtag(self, tag): class PyPICleanup: """Main class for performing PyPI package cleanup operations.""" - def __init__(self, index_url: str, do_delete: bool, max_dev_releases: int=_DEFAULT_MAX_NIGHTLIES, - username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None) -> None: + def __init__( + self, + index_url: str, + do_delete: bool, + max_dev_releases: int = _DEFAULT_MAX_NIGHTLIES, + username: Optional[str] = None, + password: Optional[str] = None, + otp: Optional[str] = None, + ) -> None: parsed_url = urlparse(index_url) - self._index_url = parsed_url.geturl().rstrip('/') + self._index_url = parsed_url.geturl().rstrip("/") self._index_host = parsed_url.hostname self._do_delete = do_delete self._max_dev_releases = max_dev_releases self._username = username self._password = password self._otp = otp - self._package = 'duckdb' + self._package = "duckdb" self._dev_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.dev(?P\d+)$") self._rc_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.rc\d+$") self._stable_version_pattern = re.compile(r"^\d+\.\d+\.\d+(\.post\d+)?$") def run(self) -> int: """Execute the cleanup process. 
- + Returns: int: Exit code (0 for success, non-zero for failure) """ @@ -268,17 +260,17 @@ def _execute_cleanup(self, http_session: Session) -> int: if not versions: logging.info(f"No releases found for {self._package}") return 0 - + # Determine versions to delete versions_to_delete = self._determine_versions_to_delete(versions) if not versions_to_delete: logging.info("No versions to delete (no stale rc's or dev releases)") return 0 - + logging.warning(f"Found {len(versions_to_delete)} versions to clean up:") for version in sorted(versions_to_delete): logging.warning(version) - + if not self._do_delete: logging.info("Dry run complete - no packages were deleted") return 0 @@ -286,14 +278,14 @@ def _execute_cleanup(self, http_session: Session) -> int: # Perform authentication and deletion self._authenticate(http_session) self._delete_versions(http_session, versions_to_delete) - + logging.info(f"Successfully cleaned up {len(versions_to_delete)} development versions") return 0 - + def _fetch_released_versions(self, http_session: Session) -> set[str]: """Fetch package release information from PyPI API.""" logging.debug(f"Fetching package information for '{self._package}'") - + try: req = http_session.get(f"{self._index_url}/pypi/{self._package}/json") req.raise_for_status() @@ -392,12 +384,12 @@ def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: logging.warning(f"Found version string(s) in an unsupported format: {unknown_versions}") return versions_to_delete - + def _authenticate(self, http_session: Session) -> None: """Authenticate with PyPI.""" if not self._username or not self._password: raise AuthenticationError("Username and password are required for authentication") - + logging.info(f"Authenticating user '{self._username}' with PyPI") try: @@ -408,12 +400,12 @@ def _authenticate(self, http_session: Session) -> None: if login_response.url.startswith(f"{self._index_url}/account/two-factor/"): logging.debug("Two-factor authentication required") self._handle_two_factor_auth(http_session, login_response) - + logging.info("Authentication successful") except RequestException as e: raise AuthenticationError(f"Network error during authentication: {e}") from e - + def _get_csrf_token(self, http_session: Session, form_action: str) -> str: """Extract CSRF token from a form page.""" resp = http_session.get(f"{self._index_url}{form_action}") @@ -423,23 +415,19 @@ def _get_csrf_token(self, http_session: Session, form_action: str) -> str: if not parser.csrf: raise AuthenticationError(f"No CSRF token found in {form_action}") return parser.csrf - + def _perform_login(self, http_session: Session) -> requests.Response: """Perform the initial login with username/password.""" # Get login form and CSRF token csrf_token = self._get_csrf_token(http_session, "/account/login/") - login_data = { - "csrf_token": csrf_token, - "username": self._username, - "password": self._password - } + login_data = {"csrf_token": csrf_token, "username": self._username, "password": self._password} response = http_session.post( f"{self._index_url}/account/login/", data=login_data, - headers={"referer": f"{self._index_url}/account/login/"} + headers={"referer": f"{self._index_url}/account/login/"}, ) response.raise_for_status() @@ -448,16 +436,16 @@ def _perform_login(self, http_session: Session) -> requests.Response: raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") return response - + def _handle_two_factor_auth(self, http_session: Session, response: requests.Response) 
-> None: """Handle two-factor authentication.""" if not self._otp: raise AuthenticationError("Two-factor authentication required but no OTP secret provided") - + two_factor_url = response.url - form_action = two_factor_url[len(self._index_url):] + form_action = two_factor_url[len(self._index_url) :] csrf_token = self._get_csrf_token(http_session, form_action) - + # Try authentication with retries for attempt in range(_LOGIN_RETRY_ATTEMPTS): try: @@ -467,7 +455,7 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp auth_response = http_session.post( two_factor_url, data={"csrf_token": csrf_token, "method": "totp", "totp_value": auth_code}, - headers={"referer": two_factor_url} + headers={"referer": two_factor_url}, ) auth_response.raise_for_status() @@ -479,19 +467,19 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp if attempt < _LOGIN_RETRY_ATTEMPTS - 1: logging.debug(f"2FA code rejected, retrying in {_LOGIN_RETRY_DELAY} seconds...") time.sleep(_LOGIN_RETRY_DELAY) - + except RequestException as e: if attempt == _LOGIN_RETRY_ATTEMPTS - 1: raise AuthenticationError(f"Network error during 2FA: {e}") from e logging.debug(f"Network error during 2FA attempt {attempt + 1}, retrying...") time.sleep(_LOGIN_RETRY_DELAY) - + raise AuthenticationError("Two-factor authentication failed after all attempts") - + def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) -> None: """Delete the specified package versions.""" logging.info(f"Starting deletion of {len(versions_to_delete)} development versions") - + failed_deletions = list() for version in sorted(versions_to_delete): try: @@ -501,24 +489,24 @@ def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) # Continue with other versions rather than failing completely logging.error(f"Failed to delete version {version}: {e}") failed_deletions.append(version) - + if failed_deletions: raise PyPICleanupError( f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" ) - + def _delete_single_version(self, http_session: Session, version: str) -> None: """Delete a single package version.""" # Safety check if not self._is_dev_version(version) or self._is_rc_version(version): raise PyPICleanupError(f"Refusing to delete non-[dev|rc] version: {version}") - + logging.debug(f"Deleting {self._package} version {version}") - + # Get deletion form and CSRF token form_action = f"/manage/project/{self._package}/release/{version}/" form_url = f"{self._index_url}{form_action}" - + csrf_token = self._get_csrf_token(http_session, form_action) # Submit deletion request @@ -528,7 +516,7 @@ def _delete_single_version(self, http_session: Session, version: str) -> None: "csrf_token": csrf_token, "confirm_delete_version": version, }, - headers={"referer": form_url} + headers={"referer": form_url}, ) delete_response.raise_for_status() @@ -537,26 +525,27 @@ def main() -> int: """Main entry point for the script.""" parser = create_argument_parser() args = parser.parse_args() - + # Setup logging setup_logging(args.verbose) - + try: # Validate arguments validate_arguments(args) - + # Load credentials password, otp = load_credentials(args.dry_run) - + # Determine PyPI URL pypi_url = _PYPI_URL_PROD if args.prod else _PYPI_URL_TEST - + # Create and run cleanup - cleanup = PyPICleanup(pypi_url, not args.dry_run, args.max_nightlies, username=args.username, - password=password, otp=otp) - + cleanup = PyPICleanup( + pypi_url, not args.dry_run, 
args.max_nightlies, username=args.username, password=password, otp=otp + ) + return cleanup.run() - + except ValidationError as e: logging.error(f"Configuration error: {e}") return 2 diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 27bedd24..217b2ffe 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -21,9 +21,10 @@ def _main_branch_versioning(): - from_env = os.getenv('MAIN_BRANCH_VERSIONING') + from_env = os.getenv("MAIN_BRANCH_VERSIONING") return from_env == "1" if from_env is not None else MAIN_BRANCH_VERSIONING + def version_scheme(version: Any) -> str: """ setuptools_scm version scheme that matches DuckDB's original behavior. @@ -65,13 +66,13 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: # Otherwise we're at a distance and / or dirty, and need to bump if post != 0: # We're developing on top of a post-release - return f"{format_version(major, minor, patch, post=post+1)}.dev{distance}" + return f"{format_version(major, minor, patch, post=post + 1)}.dev{distance}" elif rc != 0: # We're developing on top of an rc - return f"{format_version(major, minor, patch, rc=rc+1)}.dev{distance}" + return f"{format_version(major, minor, patch, rc=rc + 1)}.dev{distance}" elif _main_branch_versioning(): - return f"{format_version(major, minor+1, 0)}.dev{distance}" - return f"{format_version(major, minor, patch+1)}.dev{distance}" + return f"{format_version(major, minor + 1, 0)}.dev{distance}" + return f"{format_version(major, minor, patch + 1)}.dev{distance}" def forced_version_from_env(): @@ -117,9 +118,9 @@ def _git_describe_override_to_pep_440(override_value: str) -> str: version, distance, commit_hash = match.groups() # Convert version format to PEP440 format (v1.3.1-post1 -> 1.3.1.post1) - if '-post' in version: + if "-post" in version: version = version.replace("-post", ".post") - elif '-rc' in version: + elif "-rc" in version: version = version.replace("-rc", "rc") # Bump version and format according to PEP440 diff --git a/scripts/generate_connection_code.py b/scripts/generate_connection_code.py index 3737f83a..8e2bace9 100644 --- a/scripts/generate_connection_code.py +++ b/scripts/generate_connection_code.py @@ -3,7 +3,7 @@ import generate_connection_wrapper_methods import generate_connection_wrapper_stubs -if __name__ == '__main__': +if __name__ == "__main__": generate_connection_methods.generate() generate_connection_stubs.generate() generate_connection_wrapper_methods.generate() diff --git a/scripts/generate_connection_methods.py b/scripts/generate_connection_methods.py index c1f01e54..a48b6142 100644 --- a/scripts/generate_connection_methods.py +++ b/scripts/generate_connection_methods.py @@ -13,23 +13,23 @@ def is_py_kwargs(method): - return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True + return "kwargs_as_dict" in method and method["kwargs_as_dict"] == True def is_py_args(method): - if 'args' not in method: + if "args" not in method: return False - args = method['args'] + args = method["args"] if len(args) == 0: return False - if args[0]['name'] != '*args': + if args[0]["name"] != "*args": return False return True def generate(): # Read the PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, 'r') as source_file: + with open(PYCONNECTION_SOURCE, "r") as source_file: source_code = source_file.readlines() start_index = -1 @@ -52,16 +52,16 @@ def generate(): # ---- Generate the definition code from the json ---- # Read 
the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) DEFAULT_ARGUMENT_MAP = { - 'True': 'true', - 'False': 'false', - 'None': 'py::none()', - 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', - 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', - 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', + "True": "true", + "False": "false", + "None": "py::none()", + "PythonUDFType.NATIVE": "PythonUDFType::NATIVE", + "PythonExceptionHandling.DEFAULT": "PythonExceptionHandling::FORWARD_ERROR", + "FunctionNullHandling.DEFAULT": "FunctionNullHandling::DEFAULT_NULL_HANDLING", } def map_default(val): @@ -72,61 +72,61 @@ def map_default(val): def create_arguments(arguments) -> list: result = [] for arg in arguments: - if arg['name'] == '*args': + if arg["name"] == "*args": break - argument = f"py::arg(\"{arg['name']}\")" - if 'allow_none' in arg: - value = str(arg['allow_none']).lower() + argument = f'py::arg("{arg["name"]}")' + if "allow_none" in arg: + value = str(arg["allow_none"]).lower() argument += f".none({value})" # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) + if "default" in arg: + default = map_default(arg["default"]) argument += f" = {default}" result.append(argument) return result def create_definition(name, method) -> str: - definition = f"m.def(\"{name}\"" + definition = f'm.def("{name}"' definition += ", " - definition += f"""&DuckDBPyConnection::{method['function']}""" + definition += f"""&DuckDBPyConnection::{method["function"]}""" definition += ", " - definition += f"\"{method['docs']}\"" - if 'args' in method and not is_py_args(method): + definition += f'"{method["docs"]}"' + if "args" in method and not is_py_args(method): definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: + arguments = create_arguments(method["args"]) + definition += ", ".join(arguments) + if "kwargs" in method: definition += ", " if is_py_kwargs(method): definition += "py::kw_only()" else: definition += "py::kw_only(), " - arguments = create_arguments(method['kwargs']) - definition += ', '.join(arguments) + arguments = create_arguments(method["kwargs"]) + definition += ", ".join(arguments) definition += ");" return definition body = [] for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] + names = [method["name"]] for name in names: body.append(create_definition(name, method)) # ---- End of generation code ---- - with_newlines = ['\t' + x + '\n' for x in body] + with_newlines = ["\t" + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, 'w') as source_file: + with open(PYCONNECTION_SOURCE, "w") as source_file: source_file.write("".join(new_content)) -if __name__ == '__main__': +if __name__ == "__main__": raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") # generate() diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index fbb66c21..e3831173 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ 
-12,7 +12,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'r') as source_file: + with open(DUCKDB_STUBS_FILE, "r") as source_file: source_code = source_file.readlines() start_index = -1 @@ -35,7 +35,7 @@ def generate(): # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) body = [] @@ -45,8 +45,8 @@ def create_arguments(arguments) -> list: for arg in arguments: argument = f"{arg['name']}: {arg['type']}" # Add the default argument if present - if 'default' in arg: - default = arg['default'] + if "default" in arg: + default = arg["default"] argument += f" = {default}" result.append(argument) return result @@ -57,13 +57,13 @@ def create_definition(name, method, overloaded: bool) -> str: else: definition: str = "" definition += f"def {name}(" - arguments = ['self'] - if 'args' in method: - arguments.extend(create_arguments(method['args'])) - if 'kwargs' in method: - if not any(x.startswith('*') for x in arguments): + arguments = ["self"] + if "args" in method: + arguments.extend(create_arguments(method["args"])) + if "kwargs" in method: + if not any(x.startswith("*") for x in arguments): arguments.append("*") - arguments.extend(create_arguments(method['kwargs'])) + arguments.extend(create_arguments(method["kwargs"])) definition += ", ".join(arguments) definition += ")" definition += f" -> {method['return']}: ..." @@ -71,28 +71,28 @@ def create_definition(name, method, overloaded: bool) -> str: # We have "duplicate" methods, which are overloaded. # We keep note of them to add the @overload decorator. - overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m["name"], list)} for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] + names = [method["name"]] for name in names: body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- - with_newlines = [' ' + x + '\n' for x in body] + with_newlines = [" " + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'w') as source_file: + with open(DUCKDB_STUBS_FILE, "w") as source_file: source_file.write("".join(new_content)) -if __name__ == '__main__': +if __name__ == "__main__": raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") # generate() diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index af5ad4ac..45ac45cc 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -40,16 +40,16 @@ INIT_PY_END = "# END OF CONNECTION WRAPPER" # Read the JSON file -with open(WRAPPER_JSON_PATH, 'r') as json_file: +with open(WRAPPER_JSON_PATH, "r") as json_file: wrapper_methods = json.load(json_file) # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) 
-READONLY_PROPERTY_NAMES = ['description', 'rowcount'] +READONLY_PROPERTY_NAMES = ["description", "rowcount"] # These methods are not directly DuckDBPyConnection methods, # they first call 'FromDF' and then call a method on the created DuckDBPyRelation -SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] +SPECIAL_METHOD_NAMES = [x["name"] for x in wrapper_methods if x["name"] not in READONLY_PROPERTY_NAMES] RETRIEVE_CONN_FROM_DICT = """auto connection_arg = kwargs.contains("conn") ? kwargs["conn"] : py::none(); auto conn = py::cast>(connection_arg); @@ -57,18 +57,18 @@ def is_py_args(method): - if 'args' not in method: + if "args" not in method: return False - args = method['args'] + args = method["args"] if len(args) == 0: return False - if args[0]['name'] != '*args': + if args[0]["name"] != "*args": return False return True def is_py_kwargs(method): - return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True + return "kwargs_as_dict" in method and method["kwargs_as_dict"] == True def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[str]]: @@ -94,33 +94,33 @@ def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[s def generate(): # Read the DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, 'r') as source_file: + with open(DUCKDB_PYTHON_SOURCE, "r") as source_file: source_code = source_file.readlines() start_section, end_section = remove_section(source_code, START_MARKER, END_MARKER) # Read the DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, 'r') as source_file: + with open(DUCKDB_INIT_FILE, "r") as source_file: source_code = source_file.readlines() py_start, py_end = remove_section(source_code, INIT_PY_START, INIT_PY_END) # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) # Collect the definitions from the pyconnection.hpp header - cpp_connection_defs = get_methods('DuckDBPyConnection') - cpp_relation_defs = get_methods('DuckDBPyRelation') + cpp_connection_defs = get_methods("DuckDBPyConnection") + cpp_relation_defs = get_methods("DuckDBPyRelation") DEFAULT_ARGUMENT_MAP = { - 'True': 'true', - 'False': 'false', - 'None': 'py::none()', - 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', - 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', - 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', + "True": "true", + "False": "false", + "None": "py::none()", + "PythonUDFType.NATIVE": "PythonUDFType::NATIVE", + "PythonExceptionHandling.DEFAULT": "PythonExceptionHandling::FORWARD_ERROR", + "FunctionNullHandling.DEFAULT": "FunctionNullHandling::DEFAULT_NULL_HANDLING", } def map_default(val): @@ -131,16 +131,16 @@ def map_default(val): def create_arguments(arguments) -> list: result = [] for arg in arguments: - if arg['name'] == '*args': + if arg["name"] == "*args": # py::args() should not have a corresponding py::arg() continue - argument = f"py::arg(\"{arg['name']}\")" - if 'allow_none' in arg: - value = str(arg['allow_none']).lower() + argument = f'py::arg("{arg["name"]}")' + if "allow_none" in arg: + value = str(arg["allow_none"]).lower() argument += f".none({value})" # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) + if "default" in arg: + default = map_default(arg["default"]) argument += f" = {default}" 
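To make the binding generator above concrete, here is a self-contained restatement of the create_arguments logic from this hunk; it is illustrative only, and the input entry is hypothetical, merely shaped like the entries in connection_methods.json.

# Standalone restatement of create_arguments as shown above; the input entry
# below is made up for illustration, not taken from connection_methods.json.
DEFAULT_ARGUMENT_MAP = {"True": "true", "False": "false", "None": "py::none()"}


def example_create_arguments(arguments: list) -> list:
    result = []
    for arg in arguments:
        if arg["name"] == "*args":
            continue  # *args is covered by py::args(), it gets no py::arg()
        argument = f'py::arg("{arg["name"]}")'
        if "allow_none" in arg:
            argument += f".none({str(arg['allow_none']).lower()})"
        if "default" in arg:
            argument += f" = {DEFAULT_ARGUMENT_MAP.get(arg['default'], arg['default'])}"
        result.append(argument)
    return result


print(example_create_arguments([
    {"name": "query", "type": "str"},
    {"name": "alias", "type": "str", "default": "None", "allow_none": True},
]))
# -> ['py::arg("query")', 'py::arg("alias").none(true) = py::none()']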
result.append(argument) return result @@ -148,11 +148,11 @@ def create_arguments(arguments) -> list: def get_lambda_definition(name, method, definition: ConnectionMethod) -> str: param_definitions = [] if name in SPECIAL_METHOD_NAMES: - param_definitions.append('const PandasDataFrame &df') + param_definitions.append("const PandasDataFrame &df") param_definitions.extend([x.proto for x in definition.params]) if not is_py_kwargs(method): - param_definitions.append('shared_ptr conn = nullptr') + param_definitions.append("shared_ptr conn = nullptr") param_definitions = ", ".join(param_definitions) param_names = [x.name for x in definition.params] @@ -160,73 +160,73 @@ def get_lambda_definition(name, method, definition: ConnectionMethod) -> str: function_name = definition.name if name in SPECIAL_METHOD_NAMES: - function_name = 'FromDF(df)->' + function_name + function_name = "FromDF(df)->" + function_name format_dict = { - 'param_definitions': param_definitions, - 'opt_retrieval': '', - 'opt_return': '' if definition.is_void else 'return ', - 'function_name': function_name, - 'parameter_names': param_names, + "param_definitions": param_definitions, + "opt_retrieval": "", + "opt_return": "" if definition.is_void else "return ", + "function_name": function_name, + "parameter_names": param_names, } if is_py_kwargs(method): - format_dict['opt_retrieval'] += RETRIEVE_CONN_FROM_DICT + format_dict["opt_retrieval"] += RETRIEVE_CONN_FROM_DICT return LAMBDA_FORMAT.format_map(format_dict) def create_definition(name, method, lambda_def) -> str: - definition = f"m.def(\"{name}\"" + definition = f'm.def("{name}"' definition += ", " definition += lambda_def definition += ", " - definition += f"\"{method['docs']}\"" - if 'args' in method and not is_py_args(method): + definition += f'"{method["docs"]}"' + if "args" in method and not is_py_args(method): definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: + arguments = create_arguments(method["args"]) + definition += ", ".join(arguments) + if "kwargs" in method: definition += ", " if is_py_kwargs(method): definition += "py::kw_only()" else: definition += "py::kw_only(), " - arguments = create_arguments(method['kwargs']) - definition += ', '.join(arguments) + arguments = create_arguments(method["kwargs"]) + definition += ", ".join(arguments) definition += ");" return definition body = [] all_names = [] for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'Optional[DuckDBPyConnection]', 'default': 'None'}) + names = [method["name"]] + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "Optional[DuckDBPyConnection]", "default": "None"}) for name in names: - function_name = method['function'] + function_name = method["function"] cpp_definition = cpp_connection_defs[function_name] lambda_def = get_lambda_definition(name, method, cpp_definition) body.append(create_definition(name, method, lambda_def)) all_names.append(name) for method in wrapper_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] - if 'kwargs' not in method: - method['kwargs'] = [] - 
method['kwargs'].append({'name': 'connection', 'type': 'Optional[DuckDBPyConnection]', 'default': 'None'}) + names = [method["name"]] + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "Optional[DuckDBPyConnection]", "default": "None"}) for name in names: - function_name = method['function'] + function_name = method["function"] if name in SPECIAL_METHOD_NAMES: cpp_definition = cpp_relation_defs[function_name] - if 'args' not in method: - method['args'] = [] - method['args'].insert(0, {'name': 'df', 'type': 'DataFrame'}) + if "args" not in method: + method["args"] = [] + method["args"].insert(0, {"name": "df", "type": "DataFrame"}) else: cpp_definition = cpp_connection_defs[function_name] lambda_def = get_lambda_definition(name, method, cpp_definition) @@ -235,24 +235,24 @@ def create_definition(name, method, lambda_def) -> str: # ---- End of generation code ---- - with_newlines = ['\t' + x + '\n' for x in body] + with_newlines = ["\t" + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, 'w') as source_file: + with open(DUCKDB_PYTHON_SOURCE, "w") as source_file: source_file.write("".join(new_content)) - item_list = '\n'.join([f'\t{name},' for name in all_names]) - str_item_list = '\n'.join([f"\t'{name}'," for name in all_names]) - imports = PY_INIT_FORMAT.format(item_list=item_list, str_item_list=str_item_list).split('\n') - imports = [x + '\n' for x in imports] + item_list = "\n".join([f"\t{name}," for name in all_names]) + str_item_list = "\n".join([f"\t'{name}'," for name in all_names]) + imports = PY_INIT_FORMAT.format(item_list=item_list, str_item_list=str_item_list).split("\n") + imports = [x + "\n" for x in imports] init_py_content = py_start + imports + py_end # Write out the modified DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, 'w') as source_file: + with open(DUCKDB_INIT_FILE, "w") as source_file: source_file.write("".join(init_py_content)) -if __name__ == '__main__': +if __name__ == "__main__": # raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") generate() diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 62c60a84..02e36c4e 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -13,7 +13,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'r') as source_file: + with open(DUCKDB_STUBS_FILE, "r") as source_file: source_code = source_file.readlines() start_index = -1 @@ -38,10 +38,10 @@ def generate(): methods = [] # Read the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) - with open(WRAPPER_JSON_PATH, 'r') as json_file: + with open(WRAPPER_JSON_PATH, "r") as json_file: wrapper_methods = json.load(json_file) methods.extend(connection_methods) @@ -49,19 +49,19 @@ def generate(): # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) - READONLY_PROPERTY_NAMES = ['description', 'rowcount'] + READONLY_PROPERTY_NAMES = ["description", "rowcount"] # These methods are not directly 
DuckDBPyConnection methods, # they first call 'from_df' and then call a method on the created DuckDBPyRelation - SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] + SPECIAL_METHOD_NAMES = [x["name"] for x in wrapper_methods if x["name"] not in READONLY_PROPERTY_NAMES] def create_arguments(arguments) -> list: result = [] for arg in arguments: argument = f"{arg['name']}: {arg['type']}" # Add the default argument if present - if 'default' in arg: - default = arg['default'] + if "default" in arg: + default = arg["default"] argument += f" = {default}" result.append(argument) return result @@ -74,49 +74,49 @@ def create_definition(name, method, overloaded: bool) -> str: definition += f"def {name}(" arguments = [] if name in SPECIAL_METHOD_NAMES: - arguments.append('df: pandas.DataFrame') - if 'args' in method: - arguments.extend(create_arguments(method['args'])) - if 'kwargs' in method: - if not any(x.startswith('*') for x in arguments): + arguments.append("df: pandas.DataFrame") + if "args" in method: + arguments.extend(create_arguments(method["args"])) + if "kwargs" in method: + if not any(x.startswith("*") for x in arguments): arguments.append("*") - arguments.extend(create_arguments(method['kwargs'])) - definition += ', '.join(arguments) + arguments.extend(create_arguments(method["kwargs"])) + definition += ", ".join(arguments) definition += ")" definition += f" -> {method['return']}: ..." return definition # We have "duplicate" methods, which are overloaded. # We keep note of them to add the @overload decorator. - overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m["name"], list)} body = [] for method in methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] + names = [method["name"]] # Artificially add 'connection' keyword argument - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'}) + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "DuckDBPyConnection", "default": "..."}) for name in names: body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- - with_newlines = [x + '\n' for x in body] + with_newlines = [x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'w') as source_file: + with open(DUCKDB_STUBS_FILE, "w") as source_file: source_file.write("".join(new_content)) -if __name__ == '__main__': +if __name__ == "__main__": raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") # generate() diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index f03d8d89..8a4b0c36 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -16,97 +16,97 @@ # deal with leaf nodes?? 
Those are just PythonImportCacheItem def get_class_name(path: str) -> str: - parts: list[str] = path.replace('_', '').split('.') + parts: list[str] = path.replace("_", "").split(".") parts = [x.title() for x in parts] - return ''.join(parts) + 'CacheItem' + return "".join(parts) + "CacheItem" def get_filename(name: str) -> str: - return name.replace('_', '').lower() + '_module.hpp' + return name.replace("_", "").lower() + "_module.hpp" def get_variable_name(name: str) -> str: - if name in ['short', 'ushort']: - return name + '_' + if name in ["short", "ushort"]: + return name + "_" return name def collect_items_of_module(module: dict, collection: dict): global json_data - children = module['children'] - collection[module['full_path']] = module + children = module["children"] + collection[module["full_path"]] = module for child in children: collect_items_of_module(json_data[child], collection) class CacheItem: def __init__(self, module: dict, items) -> None: - self.name = module['name'] + self.name = module["name"] self.module = module self.items = items - self.class_name = get_class_name(module['full_path']) + self.class_name = get_class_name(module["full_path"]) def get_full_module_path(self): - if self.module['type'] != 'module': - return '' - full_path = self.module['full_path'] + if self.module["type"] != "module": + return "" + full_path = self.module["full_path"] return f""" public: \tstatic constexpr const char *Name = "{full_path}"; """ def get_optionally_required(self): - if 'required' not in self.module: - return '' + if "required" not in self.module: + return "" string = f""" protected: \tbool IsRequired() const override final {{ -\t\treturn {str(self.module['required']).lower()}; +\t\treturn {str(self.module["required"]).lower()}; \t}} """ return string def get_variables(self): variables = [] - for key in self.module['children']: + for key in self.module["children"]: item = self.items[key] - name = item['name'] + name = item["name"] var_name = get_variable_name(name) - if item['children'] == []: - class_name = 'PythonImportCacheItem' + if item["children"] == []: + class_name = "PythonImportCacheItem" else: - class_name = get_class_name(item['full_path']) - variables.append(f'\t{class_name} {var_name};') - return '\n'.join(variables) + class_name = get_class_name(item["full_path"]) + variables.append(f"\t{class_name} {var_name};") + return "\n".join(variables) def get_initializer(self): variables = [] - for key in self.module['children']: + for key in self.module["children"]: item = self.items[key] - name = item['name'] + name = item["name"] var_name = get_variable_name(name) - if item['children'] == []: + if item["children"] == []: initialization = f'{var_name}("{name}", this)' variables.append(initialization) else: - if item['type'] == 'module': - arguments = '' + if item["type"] == "module": + arguments = "" else: - arguments = 'this' - initialization = f'{var_name}({arguments})' + arguments = "this" + initialization = f"{var_name}({arguments})" variables.append(initialization) - if self.module['type'] != 'module': + if self.module["type"] != "module": constructor_params = f'"{self.name}"' - constructor_params += ', parent' + constructor_params += ", parent" else: - full_path = self.module['full_path'] + full_path = self.module["full_path"] constructor_params = f'"{full_path}"' - return f'PythonImportCacheItem({constructor_params}), ' + ', '.join(variables) + '{}' + return f"PythonImportCacheItem({constructor_params}), " + ", ".join(variables) + "{}" def 
get_constructor(self): - if self.module['type'] == 'module': - return f'{self.class_name}()' - return f'{self.class_name}(optional_ptr parent)' + if self.module["type"] == "module": + return f"{self.class_name}()" + return f"{self.class_name}(optional_ptr parent)" def to_string(self): return f""" @@ -125,7 +125,7 @@ def to_string(self): def collect_classes(items: dict) -> list: output: list = [] for item in items.values(): - if item['children'] == []: + if item["children"] == []: continue output.append(CacheItem(item, items)) return output @@ -134,7 +134,7 @@ def collect_classes(items: dict) -> list: class ModuleFile: def __init__(self, module: dict) -> None: self.module = module - self.file_name = get_filename(module['name']) + self.file_name = get_filename(module["name"]) self.items = {} collect_items_of_module(module, self.items) self.classes = collect_classes(self.items) @@ -144,7 +144,7 @@ def get_classes(self): classes = [] for item in self.classes: classes.append(item.to_string()) - return ''.join(classes) + return "".join(classes) def to_string(self): string = f""" @@ -176,13 +176,13 @@ def to_string(self): files: list[ModuleFile] = [] for name, value in json_data.items(): - if value['full_path'] != value['name']: + if value["full_path"] != value["name"]: continue files.append(ModuleFile(value)) for file in files: content = file.to_string() - path = f'src/duckdb_py/include/duckdb_python/import_cache/modules/{file.file_name}' + path = f"src/duckdb_py/include/duckdb_python/import_cache/modules/{file.file_name}" import_cache_path = os.path.join(script_dir, '..', path) with open(import_cache_path, "w") as f: f.write(content) @@ -191,10 +191,10 @@ def to_string(self): def get_root_modules(files: list[ModuleFile]): modules = [] for file in files: - name = file.module['name'] + name = file.module["name"] class_name = get_class_name(name) - modules.append(f'\t{class_name} {name};') - return '\n'.join(modules) + modules.append(f"\t{class_name} {name};") + return "\n".join(modules) # Generate the python_import_cache.hpp file @@ -237,9 +237,7 @@ def get_root_modules(files: list[ModuleFile]): """ -import_cache_path = os.path.join( - script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache.hpp' -) +import_cache_path = os.path.join(script_dir, "..", "src/duckdb_py/include/duckdb_python/import_cache/python_import_cache.hpp") with open(import_cache_path, "w") as f: f.write(import_cache_file) @@ -248,13 +246,13 @@ def get_module_file_path_includes(files: list[ModuleFile]): includes = [] for file in files: includes.append(f'#include "duckdb_python/import_cache/modules/{file.file_name}"') - return '\n'.join(includes) + return "\n".join(includes) module_includes = get_module_file_path_includes(files) modules_header = os.path.join( - script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache_modules.hpp' + script_dir, "..", "src/duckdb_py/include/duckdb_python/import_cache/python_import_cache_modules.hpp" ) with open(modules_header, "w") as f: f.write(module_includes) diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 2df33b24..099db841 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -4,12 +4,12 @@ from typing import List, Dict, Union import json -lines: list[str] = [file for file in open(f'{script_dir}/imports.py').read().split('\n') if file != ''] +lines: list[str] = [file for file in open(f"{script_dir}/imports.py").read().split("\n") if file != 
""] class ImportCacheAttribute: def __init__(self, full_path: str) -> None: - parts = full_path.split('.') + parts = full_path.split(".") self.type = "attribute" self.name = parts[-1] self.full_path = full_path @@ -42,7 +42,7 @@ def populate_json(self, json_data: dict): class ImportCacheModule: def __init__(self, full_path) -> None: - parts = full_path.split('.') + parts = full_path.split(".") self.type = "module" self.name = parts[-1] self.full_path = full_path @@ -82,27 +82,27 @@ def __init__(self) -> None: self.modules: dict[str, ImportCacheModule] = {} def add_module(self, path: str): - assert path.startswith('import') + assert path.startswith("import") path = path[7:] module = ImportCacheModule(path) self.modules[module.full_path] = module # Add it to the parent module if present - parts = path.split('.') + parts = path.split(".") if len(parts) == 1: return # This works back from the furthest child module to the top level module child_module = module for i in range(1, len(parts)): - parent_path = '.'.join(parts[: len(parts) - i]) + parent_path = ".".join(parts[: len(parts) - i]) parent_module = self.add_or_get_module(parent_path) parent_module.add_item(child_module) child_module = parent_module def add_or_get_module(self, module_name: str) -> ImportCacheModule: if module_name not in self.modules: - self.add_module(f'import {module_name}') + self.add_module(f"import {module_name}") return self.get_module(module_name) def get_module(self, module_name: str) -> ImportCacheModule: @@ -111,13 +111,13 @@ def get_module(self, module_name: str) -> ImportCacheModule: return self.modules[module_name] def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttribute]: - parts = item_name.split('.') + parts = item_name.split(".") if len(parts) == 1: return self.get_module(item_name) parent = self.get_module(parts[0]) for i in range(1, len(parts)): - child_path = '.'.join(parts[: i + 1]) + child_path = ".".join(parts[: i + 1]) if parent.has_item(child_path): parent = parent.get_item(child_path) else: @@ -127,8 +127,8 @@ def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttrib return parent def add_attribute(self, path: str): - assert not path.startswith('import') - parts = path.split('.') + assert not path.startswith("import") + parts = path.split(".") assert len(parts) >= 2 self.get_item(path) @@ -145,9 +145,9 @@ def to_json(self): generator = ImportCacheGenerator() for line in lines: - if line.startswith('#'): + if line.startswith("#"): continue - if line.startswith('import'): + if line.startswith("import"): generator.add_module(line) else: generator.add_attribute(line) diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 64ad8edc..b8d913ea 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -10,7 +10,7 @@ SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" SQLLOGIC_TEST_PARAMETER = "test_script_path" -DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / 'external' / 'duckdb').resolve() +DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / "external" / "duckdb").resolve() def pytest_addoption(parser: pytest.Parser): @@ -65,8 +65,8 @@ def pytest_keyboard_interrupt(excinfo: pytest.ExceptionInfo): # Ensure all tests are properly cleaned up on keyboard interrupt from .test_sqllogic import test_sqllogic - if hasattr(test_sqllogic, 'executor') and test_sqllogic.executor: - if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, 'connection'): + if hasattr(test_sqllogic, "executor") and test_sqllogic.executor: + if 
test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, "connection"): test_sqllogic.executor.database.connection.interrupt() test_sqllogic.executor.cleanup() test_sqllogic.executor = None diff --git a/sqllogic/skipped_tests.py b/sqllogic/skipped_tests.py index 39269c42..485ed9b9 100644 --- a/sqllogic/skipped_tests.py +++ b/sqllogic/skipped_tests.py @@ -1,42 +1,42 @@ SKIPPED_TESTS = set( [ - 'test/sql/timezone/disable_timestamptz_casts.test', # <-- ICU extension is always loaded - 'test/sql/copy/return_stats_truncate.test', # <-- handling was changed - 'test/sql/copy/return_stats.test', # <-- handling was changed - 'test/sql/copy/parquet/writer/skip_empty_write.test', # <-- handling was changed - 'test/sql/types/map/map_empty.test', - 'test/extension/wrong_function_type.test', # <-- JSON is always loaded - 'test/sql/insert/test_insert_invalid.test', # <-- doesn't parse properly - 'test/sql/cast/cast_error_location.test', # <-- python exception doesn't contain error location yet - 'test/sql/pragma/test_query_log.test', # <-- query_log gets filled with NULL when con.query(...) is used - 'test/sql/json/table/read_json_objects.test', # <-- Python client is always loaded with JSON available - 'test/sql/copy/csv/zstd_crash.test', # <-- Python client is always loaded with Parquet available - 'test/sql/error/extension_function_error.test', # <-- Python client is always loaded with TPCH available - 'test/optimizer/joins/tpcds_nofail.test', # <-- Python client is always loaded with TPCDS available - 'test/sql/settings/errors_as_json.test', # <-- errors_as_json not currently supported in Python - 'test/sql/parallelism/intraquery/depth_first_evaluation_union_and_join.test', # <-- Python client is always loaded with TPCDS available - 'test/sql/types/timestamp/test_timestamp_tz.test', # <-- Python client is always loaded wih ICU available - making the TIMESTAMPTZ::DATE cast pass - 'test/sql/parser/invisible_spaces.test', # <-- Parser is getting tripped up on the invisible spaces - 'test/sql/copy/csv/code_cov/csv_state_machine_invalid_utf.test', # <-- ConversionException is empty, see Python Mega Issue (duckdb-internal #1488) - 'test/sql/copy/csv/test_csv_timestamp_tz.test', # <-- ICU is always loaded - 'test/fuzzer/duckfuzz/duck_fuzz_column_binding_tests.test', # <-- ICU is always loaded - 'test/sql/pragma/test_custom_optimizer_profiling.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement - 'test/sql/pragma/test_custom_profiling_settings.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement - 'test/sql/copy/csv/test_copy.test', # JSON is always loaded - 'test/sql/copy/csv/test_timestamptz_12926.test', # ICU is always loaded - 'test/fuzzer/pedro/in_clause_optimization_error.test', # error message differs due to a different execution path - 'test/sql/order/test_limit_parameter.test', # error message differs due to a different execution path - 'test/sql/catalog/test_set_search_path.test', # current_query() is not the same - 'test/sql/catalog/table/create_table_parameters.test', # prepared statement error quirks - 'test/sql/pragma/profiling/test_custom_profiling_rows_scanned.test', # we perform additional queries that mess with the expected metrics - 'test/sql/pragma/profiling/test_custom_profiling_disable_metrics.test', # we perform additional queries that mess with the expected metrics - 
'test/sql/pragma/profiling/test_custom_profiling_result_set_size.test', # we perform additional queries that mess with the expected metrics - 'test/sql/pragma/profiling/test_custom_profiling_result_set_size.test', # we perform additional queries that mess with the expected metrics - 'test/sql/cte/materialized/materialized_cte_modifiers.test', # problems connected to auto installing tpcds from remote - 'test/sql/tpcds/dsdgen_readonly.test', # problems connected to auto installing tpcds from remote - 'test/sql/tpcds/tpcds_sf0.test', # problems connected to auto installing tpcds from remote - 'test/sql/optimizer/plan/test_filter_pushdown_materialized_cte.test', # problems connected to auto installing tpcds from remote - 'test/sql/explain/test_explain_analyze.test', # unknown problem with changes in API - 'test/sql/pragma/profiling/test_profiling_all.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/timezone/disable_timestamptz_casts.test", # <-- ICU extension is always loaded + "test/sql/copy/return_stats_truncate.test", # <-- handling was changed + "test/sql/copy/return_stats.test", # <-- handling was changed + "test/sql/copy/parquet/writer/skip_empty_write.test", # <-- handling was changed + "test/sql/types/map/map_empty.test", + "test/extension/wrong_function_type.test", # <-- JSON is always loaded + "test/sql/insert/test_insert_invalid.test", # <-- doesn't parse properly + "test/sql/cast/cast_error_location.test", # <-- python exception doesn't contain error location yet + "test/sql/pragma/test_query_log.test", # <-- query_log gets filled with NULL when con.query(...) is used + "test/sql/json/table/read_json_objects.test", # <-- Python client is always loaded with JSON available + "test/sql/copy/csv/zstd_crash.test", # <-- Python client is always loaded with Parquet available + "test/sql/error/extension_function_error.test", # <-- Python client is always loaded with TPCH available + "test/optimizer/joins/tpcds_nofail.test", # <-- Python client is always loaded with TPCDS available + "test/sql/settings/errors_as_json.test", # <-- errors_as_json not currently supported in Python + "test/sql/parallelism/intraquery/depth_first_evaluation_union_and_join.test", # <-- Python client is always loaded with TPCDS available + "test/sql/types/timestamp/test_timestamp_tz.test", # <-- Python client is always loaded wih ICU available - making the TIMESTAMPTZ::DATE cast pass + "test/sql/parser/invisible_spaces.test", # <-- Parser is getting tripped up on the invisible spaces + "test/sql/copy/csv/code_cov/csv_state_machine_invalid_utf.test", # <-- ConversionException is empty, see Python Mega Issue (duckdb-internal #1488) + "test/sql/copy/csv/test_csv_timestamp_tz.test", # <-- ICU is always loaded + "test/fuzzer/duckfuzz/duck_fuzz_column_binding_tests.test", # <-- ICU is always loaded + "test/sql/pragma/test_custom_optimizer_profiling.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/pragma/test_custom_profiling_settings.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/copy/csv/test_copy.test", # JSON is always loaded + "test/sql/copy/csv/test_timestamptz_12926.test", # ICU is always loaded + "test/fuzzer/pedro/in_clause_optimization_error.test", # error message differs due to a different execution path + "test/sql/order/test_limit_parameter.test", # error 
message differs due to a different execution path + "test/sql/catalog/test_set_search_path.test", # current_query() is not the same + "test/sql/catalog/table/create_table_parameters.test", # prepared statement error quirks + "test/sql/pragma/profiling/test_custom_profiling_rows_scanned.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_disable_metrics.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_result_set_size.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_result_set_size.test", # we perform additional queries that mess with the expected metrics + "test/sql/cte/materialized/materialized_cte_modifiers.test", # problems connected to auto installing tpcds from remote + "test/sql/tpcds/dsdgen_readonly.test", # problems connected to auto installing tpcds from remote + "test/sql/tpcds/tpcds_sf0.test", # problems connected to auto installing tpcds from remote + "test/sql/optimizer/plan/test_filter_pushdown_materialized_cte.test", # problems connected to auto installing tpcds from remote + "test/sql/explain/test_explain_analyze.test", # unknown problem with changes in API + "test/sql/pragma/profiling/test_profiling_all.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement ] ) diff --git a/sqllogic/test_sqllogic.py b/sqllogic/test_sqllogic.py index 4e7cead0..6f55e931 100644 --- a/sqllogic/test_sqllogic.py +++ b/sqllogic/test_sqllogic.py @@ -6,7 +6,7 @@ import sys from typing import Any, Generator, Optional -sys.path.append(str(pathlib.Path(__file__).parent.parent / 'external' / 'duckdb' / 'scripts')) +sys.path.append(str(pathlib.Path(__file__).parent.parent / "external" / "duckdb" / "scripts")) from sqllogictest import ( SQLParserException, SQLLogicParser, @@ -24,8 +24,8 @@ def sigquit_handler(signum, frame): # Access the executor from the test_sqllogic function - if hasattr(test_sqllogic, 'executor') and test_sqllogic.executor: - if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, 'connection'): + if hasattr(test_sqllogic, "executor") and test_sqllogic.executor: + if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, "connection"): test_sqllogic.executor.database.connection.interrupt() test_sqllogic.executor.cleanup() test_sqllogic.executor = None @@ -85,13 +85,13 @@ def execute_test(self, test: SQLLogicTest) -> ExecuteResult: self.original_sqlite_test = self.test.is_sqlite_test() # Top level keywords - keywords = {'__TEST_DIR__': self.get_test_directory(), '__WORKING_DIRECTORY__': os.getcwd()} + keywords = {"__TEST_DIR__": self.get_test_directory(), "__WORKING_DIRECTORY__": os.getcwd()} def update_value(_: SQLLogicContext) -> Generator[Any, Any, Any]: # Yield once to represent one iteration, do not touch the keywords yield None - self.database = SQLLogicDatabase(':memory:', None) + self.database = SQLLogicDatabase(":memory:", None) pool = self.database.connect() context = SQLLogicContext(pool, self, test.statements, keywords, update_value) pool.initialize_connection(context, pool.get_connection()) @@ -126,7 +126,7 @@ def update_value(_: SQLLogicContext) -> Generator[Any, Any, Any]: def cleanup(self): if self.database: - if hasattr(self.database, 'connection'): + if hasattr(self.database, "connection"): self.database.connection.interrupt() 
self.database.reset() self.database = None @@ -160,6 +160,6 @@ def test_sqllogic(test_script_path: pathlib.Path, pytestconfig: pytest.Config, t test_sqllogic.executor = None -if __name__ == '__main__': +if __name__ == "__main__": # Pass all arguments including the script name to pytest sys.exit(pytest.main(sys.argv)) diff --git a/tests/conftest.py b/tests/conftest.py index b9950ee7..d69cdfce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,11 +11,11 @@ try: # need to ignore warnings that might be thrown deep inside pandas's import tree (from dateutil in this case) - warnings.simplefilter(action='ignore', category=DeprecationWarning) - pandas = import_module('pandas') + warnings.simplefilter(action="ignore", category=DeprecationWarning) + pandas = import_module("pandas") warnings.resetwarnings() - pyarrow_dtype = getattr(pandas, 'ArrowDtype', None) + pyarrow_dtype = getattr(pandas, "ArrowDtype", None) except ImportError: pandas = None pyarrow_dtype = None @@ -65,7 +65,7 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope="function") def duckdb_empty_cursor(request): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() return cursor @@ -99,7 +99,7 @@ def makeTimeSeries(nper=None, freq: Frequency = "B", name=None) -> Series: def pandas_2_or_higher(): from packaging.version import Version - return Version(import_pandas().__version__) >= Version('2.0.0') + return Version(import_pandas().__version__) >= Version("2.0.0") def pandas_supports_arrow_backend(): @@ -124,7 +124,7 @@ def arrow_pandas_df(*args, **kwargs): class NumpyPandas: def __init__(self) -> None: - self.backend = 'numpy_nullable' + self.backend = "numpy_nullable" self.DataFrame = numpy_pandas_df self.pandas = import_pandas() @@ -173,11 +173,11 @@ class ArrowPandas: def __init__(self) -> None: self.pandas = import_pandas() if pandas_2_or_higher() and pyarrow_dtypes_enabled: - self.backend = 'pyarrow' + self.backend = "pyarrow" self.DataFrame = arrow_pandas_df else: # For backwards compatible reasons, just mock regular pandas - self.backend = 'numpy_nullable' + self.backend = "numpy_nullable" self.DataFrame = self.pandas.DataFrame self.testing = ArrowMockTesting() @@ -187,7 +187,7 @@ def __getattr__(self, name: str) -> Any: @pytest.fixture(scope="function") def require(): - def _require(extension_name, db_name=''): + def _require(extension_name, db_name=""): # Paths to search for extensions build = normpath(join(dirname(__file__), "../../../build/")) @@ -199,11 +199,11 @@ def _require(extension_name, db_name=''): ] # DUCKDB_PYTHON_TEST_EXTENSION_PATH can be used to add a path for the extension test to search for extensions - if 'DUCKDB_PYTHON_TEST_EXTENSION_PATH' in os.environ: - env_extension_path = os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_PATH') - env_extension_path = env_extension_path.rstrip('/') - extension_search_patterns.append(env_extension_path + '/*/*.duckdb_extension') - extension_search_patterns.append(env_extension_path + '/*.duckdb_extension') + if "DUCKDB_PYTHON_TEST_EXTENSION_PATH" in os.environ: + env_extension_path = os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_PATH") + env_extension_path = env_extension_path.rstrip("/") + extension_search_patterns.append(env_extension_path + "/*/*.duckdb_extension") + extension_search_patterns.append(env_extension_path + "/*.duckdb_extension") extension_paths_found = [] for pattern in extension_search_patterns: @@ -215,39 +215,39 @@ def _require(extension_name, db_name=''): for path in extension_paths_found: 
print(path) if path.endswith(extension_name + ".duckdb_extension"): - conn = duckdb.connect(db_name, config={'allow_unsigned_extensions': 'true'}) + conn = duckdb.connect(db_name, config={"allow_unsigned_extensions": "true"}) conn.execute(f"LOAD '{path}'") return conn - pytest.skip(f'could not load {extension_name}') + pytest.skip(f"could not load {extension_name}") return _require # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def spark(): from spark_namespace import USE_ACTUAL_SPARK - if not hasattr(spark, 'session'): + if not hasattr(spark, "session"): # Cache the import from spark_namespace.sql import SparkSession as session spark.session = session - return spark.session.builder.appName('pyspark').getOrCreate() + return spark.session.builder.appName("pyspark").getOrCreate() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def duckdb_cursor(): - connection = duckdb.connect('') + connection = duckdb.connect("") yield connection connection.close() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def integers(duckdb_cursor): cursor = duckdb_cursor - cursor.execute('CREATE TABLE integers (i integer)') + cursor.execute("CREATE TABLE integers (i integer)") cursor.execute( """ INSERT INTO integers VALUES @@ -268,10 +268,10 @@ def integers(duckdb_cursor): cursor.execute("drop table integers") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def timestamps(duckdb_cursor): cursor = duckdb_cursor - cursor.execute('CREATE TABLE timestamps (t timestamp)') + cursor.execute("CREATE TABLE timestamps (t timestamp)") cursor.execute("INSERT INTO timestamps VALUES ('1992-10-03 18:34:45'), ('2010-01-01 00:00:01'), (NULL)") yield cursor.execute("drop table timestamps") diff --git a/tests/coverage/test_pandas_categorical_coverage.py b/tests/coverage/test_pandas_categorical_coverage.py index e20afa72..15eee10a 100644 --- a/tests/coverage/test_pandas_categorical_coverage.py +++ b/tests/coverage/test_pandas_categorical_coverage.py @@ -15,17 +15,17 @@ def check_create_table(category, pandas): conn.execute("PRAGMA enable_verification") df_in = pandas.DataFrame( { - 'x': pandas.Categorical(category, ordered=True), - 'y': pandas.Categorical(category, ordered=True), - 'z': category, + "x": pandas.Categorical(category, ordered=True), + "y": pandas.Categorical(category, ordered=True), + "z": category, } ) - category.append('bla') + category.append("bla") df_in_diff = pandas.DataFrame( { - 'k': pandas.Categorical(category, ordered=True), + "k": pandas.Categorical(category, ordered=True), } ) @@ -44,7 +44,7 @@ def check_create_table(category, pandas): conn.execute("INSERT INTO t1 VALUES ('2','2','2')") res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall() - assert res == [('1',)] + assert res == [("1",)] res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x) order by t1.x").fetchall() assert res == conn.execute("SELECT x FROM t1 order by t1.x").fetchall() @@ -70,14 +70,14 @@ def check_create_table(category, pandas): # TODO: extend tests with ArrowPandas class TestCategory(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint16(self, duckdb_cursor, pandas): category = [] for i in range(300): category.append(str(i)) check_create_table(category, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) 
+ @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint32(self, duckdb_cursor, pandas): category = [] for i in range(70000): diff --git a/tests/extensions/json/test_read_json.py b/tests/extensions/json/test_read_json.py index 48590175..f0fd809f 100644 --- a/tests/extensions/json/test_read_json.py +++ b/tests/extensions/json/test_read_json.py @@ -10,50 +10,50 @@ def TestFile(name): import os - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', name) + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", name) return filename class TestReadJSON(object): def test_read_json_columns(self): - rel = duckdb.read_json(TestFile('example.json'), columns={'id': 'integer', 'name': 'varchar'}) + rel = duckdb.read_json(TestFile("example.json"), columns={"id": "integer", "name": "varchar"}) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_auto(self): - rel = duckdb.read_json(TestFile('example.json')) + rel = duckdb.read_json(TestFile("example.json")) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_maximum_depth(self): - rel = duckdb.read_json(TestFile('example.json'), maximum_depth=4) + rel = duckdb.read_json(TestFile("example.json"), maximum_depth=4) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_sample_size(self): - rel = duckdb.read_json(TestFile('example.json'), sample_size=2) + rel = duckdb.read_json(TestFile("example.json"), sample_size=2) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_format(self): # Wrong option with pytest.raises(duckdb.BinderException, match="format must be one of .* not 'test'"): - rel = duckdb.read_json(TestFile('example.json'), format='test') + rel = duckdb.read_json(TestFile("example.json"), format="test") - rel = duckdb.read_json(TestFile('example.json'), format='unstructured') + rel = duckdb.read_json(TestFile("example.json"), format="unstructured") res = rel.fetchone() print(res) assert res == ( [ - {'id': 1, 'name': 'O Brother, Where Art Thou?'}, - {'id': 2, 'name': 'Home for the Holidays'}, - {'id': 3, 'name': 'The Firm'}, - {'id': 4, 'name': 'Broadcast News'}, - {'id': 5, 'name': 'Raising Arizona'}, + {"id": 1, "name": "O Brother, Where Art Thou?"}, + {"id": 2, "name": "Home for the Holidays"}, + {"id": 3, "name": "The Firm"}, + {"id": 4, "name": "Broadcast News"}, + {"id": 5, "name": "Raising Arizona"}, ], ) @@ -63,13 +63,13 @@ def test_read_filelike(self, duckdb_cursor): duckdb_cursor.execute("set threads=1") string = StringIO("""{"id":1,"name":"O Brother, Where Art Thou?"}\n{"id":2,"name":"Home for the Holidays"}""") res = duckdb_cursor.read_json(string).fetchall() - assert res == [(1, 'O Brother, Where Art Thou?'), (2, 'Home for the Holidays')] + assert res == [(1, "O Brother, Where Art Thou?"), (2, "Home for the Holidays")] string1 = StringIO("""{"id":1,"name":"O Brother, Where Art Thou?"}""") string2 = StringIO("""{"id":2,"name":"Home for the Holidays"}""") res = duckdb_cursor.read_json([string1, string2], filename=True).fetchall() - assert res[0][1] == 'O Brother, Where Art Thou?' - assert res[1][1] == 'Home for the Holidays' + assert res[0][1] == "O Brother, Where Art Thou?" 
+ assert res[1][1] == "Home for the Holidays" # filenames are different assert res[0][2] != res[1][2] @@ -77,51 +77,51 @@ def test_read_filelike(self, duckdb_cursor): def test_read_json_records(self): # Wrong option with pytest.raises(duckdb.BinderException, match="""read_json requires "records" to be one of"""): - rel = duckdb.read_json(TestFile('example.json'), records='none') + rel = duckdb.read_json(TestFile("example.json"), records="none") - rel = duckdb.read_json(TestFile('example.json'), records='true') + rel = duckdb.read_json(TestFile("example.json"), records="true") res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") @pytest.mark.parametrize( - 'option', + "option", [ - ('filename', True), - ('filename', 'test'), - ('date_format', '%m-%d-%Y'), - ('date_format', '%m-%d-%y'), - ('date_format', '%d-%m-%Y'), - ('date_format', '%d-%m-%y'), - ('date_format', '%Y-%m-%d'), - ('date_format', '%y-%m-%d'), - ('timestamp_format', '%H:%M:%S%y-%m-%d'), - ('compression', 'AUTO_DETECT'), - ('compression', 'UNCOMPRESSED'), - ('maximum_object_size', 5), - ('ignore_errors', False), - ('ignore_errors', True), - ('convert_strings_to_integers', False), - ('convert_strings_to_integers', True), - ('field_appearance_threshold', 0.534), - ('map_inference_threshold', 34234), - ('maximum_sample_files', 5), - ('hive_partitioning', True), - ('hive_partitioning', False), - ('union_by_name', True), - ('union_by_name', False), - ('hive_types_autocast', False), - ('hive_types_autocast', True), - ('hive_types', {'id': 'INTEGER', 'name': 'VARCHAR'}), + ("filename", True), + ("filename", "test"), + ("date_format", "%m-%d-%Y"), + ("date_format", "%m-%d-%y"), + ("date_format", "%d-%m-%Y"), + ("date_format", "%d-%m-%y"), + ("date_format", "%Y-%m-%d"), + ("date_format", "%y-%m-%d"), + ("timestamp_format", "%H:%M:%S%y-%m-%d"), + ("compression", "AUTO_DETECT"), + ("compression", "UNCOMPRESSED"), + ("maximum_object_size", 5), + ("ignore_errors", False), + ("ignore_errors", True), + ("convert_strings_to_integers", False), + ("convert_strings_to_integers", True), + ("field_appearance_threshold", 0.534), + ("map_inference_threshold", 34234), + ("maximum_sample_files", 5), + ("hive_partitioning", True), + ("hive_partitioning", False), + ("union_by_name", True), + ("union_by_name", False), + ("hive_types_autocast", False), + ("hive_types_autocast", True), + ("hive_types", {"id": "INTEGER", "name": "VARCHAR"}), ], ) def test_read_json_options(self, duckdb_cursor, option): keyword_arguments = dict() option_name, option_value = option keyword_arguments[option_name] = option_value - if option_name == 'hive_types': - with pytest.raises(duckdb.InvalidInputException, match=r'Unknown hive_type:'): - rel = duckdb_cursor.read_json(TestFile('example.json'), **keyword_arguments) + if option_name == "hive_types": + with pytest.raises(duckdb.InvalidInputException, match=r"Unknown hive_type:"): + rel = duckdb_cursor.read_json(TestFile("example.json"), **keyword_arguments) else: - rel = duckdb_cursor.read_json(TestFile('example.json'), **keyword_arguments) + rel = duckdb_cursor.read_json(TestFile("example.json"), **keyword_arguments) res = rel.fetchall() diff --git a/tests/extensions/test_extensions_loading.py b/tests/extensions/test_extensions_loading.py index 2b4eab0c..f35366ba 100644 --- a/tests/extensions/test_extensions_loading.py +++ b/tests/extensions/test_extensions_loading.py @@ -13,9 +13,9 @@ def test_extension_loading(require): - if not 
os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False): + if not os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED", False): return - extensions_list = ['json', 'excel', 'httpfs', 'tpch', 'tpcds', 'icu', 'fts'] + extensions_list = ["json", "excel", "httpfs", "tpch", "tpcds", "icu", "fts"] for extension in extensions_list: connection = require(extension) assert connection is not None @@ -26,16 +26,16 @@ def test_install_non_existent_extension(): conn.execute("set custom_extension_repository = 'http://example.com'") with raises(duckdb.IOException) as exc: - conn.install_extension('non-existent') + conn.install_extension("non-existent") if not isinstance(exc, duckdb.HTTPException): - pytest.skip(reason='This test does not throw an HTTPException, only an IOException') + pytest.skip(reason="This test does not throw an HTTPException, only an IOException") value = exc.value assert value.status_code == 404 - assert value.reason == 'Not Found' - assert 'Example Domain' in value.body - assert 'Content-Length' in value.headers + assert value.reason == "Not Found" + assert "Example Domain" in value.body + assert "Content-Length" in value.headers def test_install_misuse_errors(duckdb_cursor): @@ -43,17 +43,17 @@ def test_install_misuse_errors(duckdb_cursor): duckdb.InvalidInputException, match="Both 'repository' and 'repository_url' are set which is not allowed, please pick one or the other", ): - duckdb_cursor.install_extension('name', repository='hello', repository_url='hello.com') + duckdb_cursor.install_extension("name", repository="hello", repository_url="hello.com") with pytest.raises( duckdb.InvalidInputException, match="The provided 'repository' or 'repository_url' can not be empty!" ): - duckdb_cursor.install_extension('name', repository_url='') + duckdb_cursor.install_extension("name", repository_url="") with pytest.raises( duckdb.InvalidInputException, match="The provided 'repository' or 'repository_url' can not be empty!" 
): - duckdb_cursor.install_extension('name', repository='') + duckdb_cursor.install_extension("name", repository="") with pytest.raises(duckdb.InvalidInputException, match="The provided 'version' can not be empty!"): - duckdb_cursor.install_extension('name', version='') + duckdb_cursor.install_extension("name", version="") diff --git a/tests/extensions/test_httpfs.py b/tests/extensions/test_httpfs.py index 6366e07f..866491f0 100644 --- a/tests/extensions/test_httpfs.py +++ b/tests/extensions/test_httpfs.py @@ -9,33 +9,33 @@ # FIXME: we can add a custom command line argument to pytest to provide an extension directory # We can use that instead of checking this environment variable inside of conftest.py's 'require' method pytestmark = mark.skipif( - not os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False), - reason='DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED is not set', + not os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED", False), + reason="DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED is not set", ) class TestHTTPFS(object): def test_read_json_httpfs(self, require): - connection = require('httpfs') + connection = require("httpfs") try: - res = connection.read_json('https://jsonplaceholder.typicode.com/todos') + res = connection.read_json("https://jsonplaceholder.typicode.com/todos") assert len(res.types) == 4 except duckdb.Error as e: - if '403' in e: + if "403" in e: pytest.skip(reason="Test is flaky, sometimes returns 403") else: pytest.fail(str(e)) def test_s3fs(self, require): - connection = require('httpfs') + connection = require("httpfs") rel = connection.read_csv(f"s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) res = rel.fetchone() assert res == (1, 0, datetime.date(1965, 2, 28), 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 6, 0, 0, 0, 0) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_httpfs(self, require, pandas): - connection = require('httpfs') + connection = require("httpfs") try: connection.execute( "SELECT id, first_name, last_name FROM PARQUET_SCAN('https://raw.githubusercontent.com/duckdb/duckdb/main/data/parquet-testing/userdata1.parquet') LIMIT 3;" @@ -52,15 +52,15 @@ def test_httpfs(self, require, pandas): result_df = connection.fetchdf() exp_result = pandas.DataFrame( { - 'id': pandas.Series([1, 2, 3], dtype="int32"), - 'first_name': ['Amanda', 'Albert', 'Evelyn'], - 'last_name': ['Jordan', 'Freeman', 'Morgan'], + "id": pandas.Series([1, 2, 3], dtype="int32"), + "first_name": ["Amanda", "Albert", "Evelyn"], + "last_name": ["Jordan", "Freeman", "Morgan"], } ) pandas.testing.assert_frame_equal(result_df, exp_result) def test_http_exception(self, require): - connection = require('httpfs') + connection = require("httpfs") # Read from a bogus HTTPS url, assert that it errors with a non-successful status code with raises(duckdb.HTTPException) as exc: @@ -68,15 +68,15 @@ def test_http_exception(self, require): value = exc.value assert value.status_code != 200 - assert value.body == '' - assert 'Content-Length' in value.headers + assert value.body == "" + assert "Content-Length" in value.headers def test_fsspec_priority(self, require): pytest.importorskip("fsspec") pytest.importorskip("gscfs") import fsspec - connection = require('httpfs') + connection = require("httpfs") gcs = fsspec.filesystem("gcs") connection.register_filesystem(gcs) diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py index 663563cf..80b6b385 100644 --- a/tests/fast/adbc/test_adbc.py +++ 
b/tests/fast/adbc/test_adbc.py @@ -47,7 +47,7 @@ def test_connection_get_table_types(duck_conn): with duck_conn.cursor() as cursor: # Test Default Schema cursor.execute("CREATE TABLE tableschema (ints BIGINT)") - assert duck_conn.adbc_get_table_types() == ['BASE TABLE'] + assert duck_conn.adbc_get_table_types() == ["BASE TABLE"] def test_connection_get_objects(duck_conn): @@ -124,7 +124,7 @@ def test_commit(tmp_path): # This errors because the table does not exist with pytest.raises( adbc_driver_manager_lib.InternalError, - match=r'Table with name ingest does not exist!', + match=r"Table with name ingest does not exist!", ): cur.execute("SELECT count(*) from ingest") @@ -138,7 +138,7 @@ def test_commit(tmp_path): ) as conn: with conn.cursor() as cur: cur.execute("SELECT count(*) from ingest") - assert cur.fetch_arrow_table().to_pydict() == {'count_star()': [4]} + assert cur.fetch_arrow_table().to_pydict() == {"count_star()": [4]} def test_connection_get_table_schema(duck_conn): @@ -310,17 +310,17 @@ def test_large_chunk(tmp_path): with conn.cursor() as cur: cur.adbc_ingest("ingest", table, "create") cur.execute("SELECT count(*) from ingest") - assert cur.fetch_arrow_table().to_pydict() == {'count_star()': [30_000]} + assert cur.fetch_arrow_table().to_pydict() == {"count_star()": [30_000]} def test_dictionary_data(tmp_path): - data = ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] + data = ["apple", "banana", "apple", "orange", "banana", "banana"] dict_type = pyarrow.dictionary(index_type=pyarrow.int32(), value_type=pyarrow.string()) dict_array = pyarrow.array(data, type=dict_type) # Wrap in a table - table = pyarrow.table({'fruits': dict_array}) + table = pyarrow.table({"fruits": dict_array}) db = os.path.join(tmp_path, "tmp.db") if os.path.exists(db): os.remove(db) @@ -335,7 +335,7 @@ def test_dictionary_data(tmp_path): cur.adbc_ingest("ingest", table, "create") cur.execute("from ingest") assert cur.fetch_arrow_table().to_pydict() == { - 'fruits': ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] + "fruits": ["apple", "banana", "apple", "orange", "banana", "banana"] } @@ -361,36 +361,36 @@ def test_ree_data(tmp_path): cur.adbc_ingest("ingest", table, "create") cur.execute("from ingest") assert cur.fetch_arrow_table().to_pydict() == { - 'fruits': ['apple', 'apple', 'apple', 'banana', 'banana', 'orange'] + "fruits": ["apple", "apple", "apple", "banana", "banana", "orange"] } def sorted_get_objects(catalogs): res = [] - for catalog in sorted(catalogs, key=lambda cat: cat['catalog_name']): + for catalog in sorted(catalogs, key=lambda cat: cat["catalog_name"]): new_catalog = { - "catalog_name": catalog['catalog_name'], + "catalog_name": catalog["catalog_name"], "catalog_db_schemas": [], } - for db_schema in sorted(catalog['catalog_db_schemas'] or [], key=lambda sch: sch['db_schema_name']): + for db_schema in sorted(catalog["catalog_db_schemas"] or [], key=lambda sch: sch["db_schema_name"]): new_db_schema = { - "db_schema_name": db_schema['db_schema_name'], + "db_schema_name": db_schema["db_schema_name"], "db_schema_tables": [], } - for table in sorted(db_schema['db_schema_tables'] or [], key=lambda tab: tab['table_name']): + for table in sorted(db_schema["db_schema_tables"] or [], key=lambda tab: tab["table_name"]): new_table = { - "table_name": table['table_name'], - "table_type": table['table_type'], + "table_name": table["table_name"], + "table_type": table["table_type"], "table_columns": [], "table_constraints": [], } - for column in sorted(table['table_columns'] or [], 
key=lambda col: col['ordinal_position']): + for column in sorted(table["table_columns"] or [], key=lambda col: col["ordinal_position"]): new_table["table_columns"].append(column) - for constraint in sorted(table['table_constraints'] or [], key=lambda con: con['constraint_name']): + for constraint in sorted(table["table_constraints"] or [], key=lambda con: con["constraint_name"]): new_table["table_constraints"].append(constraint) new_db_schema["db_schema_tables"].append(new_table) diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index 5e9d7d45..d1919cb1 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -70,30 +70,30 @@ def test_bind_single_row(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] _bind(statement, data) res, _ = statement.execute_query() table = _import(res).read_all() - result = table['i'] + result = table["i"] assert result.num_chunks == 1 result_values = result.chunk(0) assert result_values == expected_result def test_multiple_parameters(self): int_data = pa.array([5]) - varchar_data = pa.array(['not a short string']) + varchar_data = pa.array(["not a short string"]) bool_data = pa.array([True]) # Create the schema - schema = pa.schema([('a', pa.int64()), ('b', pa.string()), ('c', pa.bool_())]) + schema = pa.schema([("a", pa.int64()), ("b", pa.string()), ("c", pa.bool_())]) # Create the PyArrow table expected_res = pa.Table.from_arrays([int_data, varchar_data, bool_data], schema=schema) data = pa.record_batch( - [[5], ['not a short string'], [True]], + [[5], ["not a short string"], [True]], names=["ints", "strings", "bools"], ) @@ -105,7 +105,7 @@ def test_multiple_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0', '1', '2'] + assert schema.names == ["0", "1", "2"] _bind(statement, data) res, _ = statement.execute_query() @@ -115,14 +115,14 @@ def test_multiple_parameters(self): def test_bind_composite_type(self): data_dict = { - 'field1': pa.array([10], type=pa.int64()), - 'field2': pa.array([3.14], type=pa.float64()), - 'field3': pa.array(['example with long string'], type=pa.string()), + "field1": pa.array([10], type=pa.int64()), + "field2": pa.array([3.14], type=pa.float64()), + "field3": pa.array(["example with long string"], type=pa.string()), } # Create the StructArray struct_array = pa.StructArray.from_arrays(arrays=data_dict.values(), names=data_dict.keys()) - schema = pa.schema([(name, array.type) for name, array in zip(['a'], [struct_array])]) + schema = pa.schema([(name, array.type) for name, array in zip(["a"], [struct_array])]) # Create the RecordBatch record_batch = pa.RecordBatch.from_arrays([struct_array], schema=schema) @@ -135,18 +135,18 @@ def test_bind_composite_type(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] _bind(statement, record_batch) res, _ = statement.execute_query() table = _import(res).read_all() - result = table['a'] + result = table["a"] result = result.chunk(0) assert result == struct_array def test_too_many_parameters(self): data = pa.record_batch( - [[12423], ['not a short string']], + [[12423], ["not a short string"]], names=["ints", "strings"], ) @@ -158,7 +158,7 @@ def test_too_many_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert 
schema.names == ['0'] + assert schema.names == ["0"] array = adbc_driver_manager.ArrowArrayHandle() schema = adbc_driver_manager.ArrowSchemaHandle() @@ -174,7 +174,7 @@ def test_too_many_parameters(self): def test_not_enough_parameters(self): data = pa.record_batch( - [['not a short string']], + [["not a short string"]], names=["strings"], ) @@ -186,7 +186,7 @@ def test_not_enough_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0', '1'] + assert schema.names == ["0", "1"] array = adbc_driver_manager.ArrowArrayHandle() schema = adbc_driver_manager.ArrowSchemaHandle() diff --git a/tests/fast/api/test_3324.py b/tests/fast/api/test_3324.py index e8f6085f..f3cd235b 100644 --- a/tests/fast/api/test_3324.py +++ b/tests/fast/api/test_3324.py @@ -27,4 +27,4 @@ def test_3324(self, duckdb_cursor): ).fetch_df() with pytest.raises(duckdb.BinderException, match="Unexpected prepared parameter"): - duckdb_cursor.execute("""execute v1(?)""", ('test1',)).fetch_df() + duckdb_cursor.execute("""execute v1(?)""", ("test1",)).fetch_df() diff --git a/tests/fast/api/test_3654.py b/tests/fast/api/test_3654.py index e63f0cd1..8fad47e6 100644 --- a/tests/fast/api/test_3654.py +++ b/tests/fast/api/test_3654.py @@ -11,11 +11,11 @@ class Test3654(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_pandas(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1, 1, 2], + "id": [1, 1, 2], } ) con = duckdb.connect() @@ -24,14 +24,14 @@ def test_3654_pandas(self, duckdb_cursor, pandas): print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(1,), (1,), (2,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_arrow(self, duckdb_cursor, pandas): if not can_run: return df1 = pandas.DataFrame( { - 'id': [1, 1, 2], + "id": [1, 1, 2], } ) table = pa.Table.from_pandas(df1) diff --git a/tests/fast/api/test_3728.py b/tests/fast/api/test_3728.py index 2df3c156..37b50ee6 100644 --- a/tests/fast/api/test_3728.py +++ b/tests/fast/api/test_3728.py @@ -14,6 +14,6 @@ def test_3728_describe_enum(self, duckdb_cursor): # This fails with "RuntimeError: Not implemented Error: unsupported type: mood" assert cursor.table("person").execute().description == [ - ('name', 'VARCHAR', None, None, None, None, None), - ('current_mood', "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), + ("name", "VARCHAR", None, None, None, None, None), + ("current_mood", "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), ] diff --git a/tests/fast/api/test_6315.py b/tests/fast/api/test_6315.py index e8eaff59..b9e7c0cf 100644 --- a/tests/fast/api/test_6315.py +++ b/tests/fast/api/test_6315.py @@ -9,15 +9,15 @@ def test_6315(self, duckdb_cursor): rv.fetchall() desc = rv.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] # description of relation rel = c.sql("select * from sqlite_master where type = 'table'") desc = rel.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] rel.fetchall() desc = rel.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names 
== ["type", "name", "tbl_name", "rootpage", "sql"] diff --git a/tests/fast/api/test_attribute_getter.py b/tests/fast/api/test_attribute_getter.py index 958e8892..eda6845a 100644 --- a/tests/fast/api/test_attribute_getter.py +++ b/tests/fast/api/test_attribute_getter.py @@ -11,43 +11,43 @@ class TestGetAttribute(object): def test_basic_getattr(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") assert rel.a.fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] assert rel.b.fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] assert rel.c.fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] def test_basic_getitem(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') - assert rel['a'].fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] - assert rel['b'].fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] - assert rel['c'].fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") + assert rel["a"].fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] + assert rel["b"].fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] + assert rel["c"].fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] def test_getitem_nonexistant(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") with pytest.raises(AttributeError): - rel['d'] + rel["d"] def test_getattr_nonexistant(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") with pytest.raises(AttributeError): rel.d def test_getattr_collision(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as df from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as df from range(100) tbl(i)") # 'df' also exists as a method on DuckDBPyRelation assert rel.df.__class__ != duckdb.DuckDBPyRelation def test_getitem_collision(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as df from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as df from range(100) tbl(i)") # this case is not an issue on __getitem__ - assert rel['df'].__class__ == duckdb.DuckDBPyRelation + assert rel["df"].__class__ == duckdb.DuckDBPyRelation def test_getitem_struct(self, duckdb_cursor): rel = duckdb_cursor.sql("select {'a':5, 'b':6} as a, 5 as b") - assert rel['a']['a'].fetchall()[0][0] == 5 - assert rel['a']['b'].fetchall()[0][0] == 6 + assert rel["a"]["a"].fetchall()[0][0] == 5 + assert rel["a"]["b"].fetchall()[0][0] == 6 def test_getattr_struct(self, duckdb_cursor): rel = duckdb_cursor.sql("select {'a':5, 'b':6} as a, 5 as b") @@ -56,7 +56,7 @@ def test_getattr_struct(self, duckdb_cursor): def test_getattr_spaces(self, duckdb_cursor): rel = duckdb_cursor.sql('select 42 as "hello world"') - assert rel['hello world'].fetchall()[0][0] == 42 + assert rel["hello world"].fetchall()[0][0] == 42 def test_getattr_doublequotes(self, duckdb_cursor): rel = duckdb_cursor.sql('select 1 as "tricky"", ""quotes", 2 as tricky, 3 as quotes') diff --git a/tests/fast/api/test_config.py b/tests/fast/api/test_config.py index 5db5f77b..4a0a0445 100644 --- 
a/tests/fast/api/test_config.py +++ b/tests/fast/api/test_config.py @@ -9,54 +9,54 @@ class TestDBConfig(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_default_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3]}) - con = duckdb.connect(':memory:', config={'default_order': 'desc'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3]}) + con = duckdb.connect(":memory:", config={"default_order": "desc"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_null_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3, None]}) - con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3, None]}) + con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(1,), (2,), (3,), (None,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_options(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3, None]}) - con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last', 'default_order': 'desc'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3, None]}) + con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last", "default_order": "desc"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,), (None,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_external_access(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3]}) + df = pandas.DataFrame({"a": [1, 2, 3]}) # this works (replacement scan) - con_regular = duckdb.connect(':memory:', config={}) - con_regular.execute('select * from df') + con_regular = duckdb.connect(":memory:", config={}) + con_regular.execute("select * from df") # disable external access: this also disables pandas replacement scans - con = duckdb.connect(':memory:', config={'enable_external_access': False}) + con = duckdb.connect(":memory:", config={"enable_external_access": False}) # this should fail query_failed = False try: - con.execute('select * from df').fetchall() + con.execute("select * from df").fetchall() except: query_failed = True assert query_failed == True def test_extension_setting(self): - repository = os.environ.get('LOCAL_EXTENSION_REPO') + repository = os.environ.get("LOCAL_EXTENSION_REPO") if not repository: return - con = duckdb.connect(config={"TimeZone": "UTC", 'autoinstall_extension_repository': repository}) - assert 'UTC' == con.sql("select current_setting('TimeZone')").fetchone()[0] + con = duckdb.connect(config={"TimeZone": "UTC", "autoinstall_extension_repository": repository}) + assert "UTC" == con.sql("select current_setting('TimeZone')").fetchone()[0] def test_unrecognized_option(self, duckdb_cursor): success = True try: - con_regular = duckdb.connect(':memory:', 
config={'thisoptionisprobablynotthere': '42'}) + con_regular = duckdb.connect(":memory:", config={"thisoptionisprobablynotthere": "42"}) except: success = False assert success == False @@ -64,27 +64,27 @@ def test_unrecognized_option(self, duckdb_cursor): def test_incorrect_parameter(self, duckdb_cursor): success = True try: - con_regular = duckdb.connect(':memory:', config={'default_null_order': '42'}) + con_regular = duckdb.connect(":memory:", config={"default_null_order": "42"}) except: success = False assert success == False def test_user_agent_default(self, duckdb_cursor): - con_regular = duckdb.connect(':memory:') + con_regular = duckdb.connect(":memory:") regex = re.compile("duckdb/.* python/.*") # Expands to: SELECT * FROM pragma_user_agent() assert regex.match(con_regular.sql("PRAGMA user_agent").fetchone()[0]) is not None custom_user_agent = con_regular.sql("SELECT current_setting('custom_user_agent')").fetchone() - assert custom_user_agent[0] == '' + assert custom_user_agent[0] == "" def test_user_agent_custom(self, duckdb_cursor): - con_regular = duckdb.connect(':memory:', config={'custom_user_agent': 'CUSTOM_STRING'}) + con_regular = duckdb.connect(":memory:", config={"custom_user_agent": "CUSTOM_STRING"}) regex = re.compile("duckdb/.* python/.* CUSTOM_STRING") assert regex.match(con_regular.sql("PRAGMA user_agent").fetchone()[0]) is not None custom_user_agent = con_regular.sql("SELECT current_setting('custom_user_agent')").fetchone() - assert custom_user_agent[0] == 'CUSTOM_STRING' + assert custom_user_agent[0] == "CUSTOM_STRING" def test_secret_manager_option(self, duckdb_cursor): - con = duckdb.connect(':memory:', config={'allow_persistent_secrets': False}) - result = con.execute('select count(*) from duckdb_secrets()').fetchall() + con = duckdb.connect(":memory:", config={"allow_persistent_secrets": False}) + result = con.execute("select count(*) from duckdb_secrets()").fetchall() assert result == [(0,)] diff --git a/tests/fast/api/test_connection_close.py b/tests/fast/api/test_connection_close.py index e7a47404..f71a02bb 100644 --- a/tests/fast/api/test_connection_close.py +++ b/tests/fast/api/test_connection_close.py @@ -54,7 +54,7 @@ def test_get_closed_default_conn(self, duckdb_cursor): duckdb.close() # 'duckdb.close()' closes this connection, because we explicitly set it as the default - with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection Error: Connection already closed"): con.sql("select 42").fetchall() default_con = duckdb.default_connection() @@ -65,11 +65,11 @@ def test_get_closed_default_conn(self, duckdb_cursor): duckdb.sql("select 42").fetchall() # Show that the 'default_con' is still closed - with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection Error: Connection already closed"): default_con.sql("select 42").fetchall() duckdb.close() # This also does not error because we silently receive a new connection - con2 = duckdb.connect(':default:') + con2 = duckdb.connect(":default:") con2.sql("select 42").fetchall() diff --git a/tests/fast/api/test_cursor.py b/tests/fast/api/test_cursor.py index 9510fbd9..69c3fe79 100644 --- a/tests/fast/api/test_cursor.py +++ b/tests/fast/api/test_cursor.py @@ -7,7 +7,7 @@ class TestDBAPICursor(object): def test_cursor_basic(self): # Create a connection - con = duckdb.connect(':memory:') + con = 
duckdb.connect(":memory:") # Then create a cursor on the connection cursor = con.cursor() # Use the cursor for queries @@ -15,14 +15,14 @@ def test_cursor_basic(self): assert res == [([1, 2, 3, None, 4],)] def test_cursor_preexisting(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.execute("create table tbl as select i a, i+1 b, i+2 c from range(5) tbl(i)") cursor = con.cursor() res = cursor.execute("select * from tbl").fetchall() assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_after_creation(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # First create the cursor cursor = con.cursor() # Then create table on the source connection @@ -31,7 +31,7 @@ def test_cursor_after_creation(self): assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_mixed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # First create the cursor cursor = con.cursor() # Then create table on the cursor @@ -43,7 +43,7 @@ def test_cursor_mixed(self): assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_temp_schema_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.execute("create temp table tbl as select * from range(100)") other_cursor = con.cursor() @@ -54,7 +54,7 @@ def test_cursor_temp_schema_closed(self): res = other_cursor.execute("select * from tbl").fetchall() def test_cursor_temp_schema_open(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.execute("create temp table tbl as select * from range(100)") other_cursor = con.cursor() @@ -65,7 +65,7 @@ def test_cursor_temp_schema_open(self): res = other_cursor.execute("select * from tbl").fetchall() def test_cursor_temp_schema_both(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor1 = con.cursor() cursor2 = con.cursor() cursor3 = con.cursor() @@ -92,23 +92,23 @@ def test_cursor_timezone(self): # Because the 'timezone' setting was not explicitly set for the connection # the setting of the DBConfig is used instead res = con1.execute("SELECT make_timestamptz(2000,01,20,03,30,59)").fetchone() - assert str(res) == '(datetime.datetime(2000, 1, 20, 3, 30, 59, tzinfo=),)' + assert str(res) == "(datetime.datetime(2000, 1, 20, 3, 30, 59, tzinfo=),)" def test_cursor_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.close() with pytest.raises(duckdb.ConnectionException): cursor = con.cursor() def test_cursor_used_after_connection_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() con.close() with pytest.raises(duckdb.ConnectionException): cursor.execute("select [1,2,3,4]") def test_cursor_used_after_close(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.close() with pytest.raises(duckdb.ConnectionException): diff --git a/tests/fast/api/test_dbapi00.py b/tests/fast/api/test_dbapi00.py index 815a81b9..38d87887 100644 --- a/tests/fast/api/test_dbapi00.py +++ b/tests/fast/api/test_dbapi00.py @@ -12,7 +12,7 @@ def assert_result_equal(result): class TestSimpleDBAPI(object): def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert_result_equal(result) @@ -20,7 
+20,7 @@ def test_fetchmany_default(self, duckdb_cursor, integers): # Get truth-value truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) - duckdb_cursor.execute('Select * from integers') + duckdb_cursor.execute("Select * from integers") # by default 'size' is 1 arraysize = 1 list_of_results = [] @@ -40,7 +40,7 @@ def test_fetchmany_default(self, duckdb_cursor, integers): def test_fetchmany(self, duckdb_cursor, integers): # Get truth value truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) - duckdb_cursor.execute('select * from integers') + duckdb_cursor.execute("select * from integers") list_of_results = [] arraysize = 3 expected_iteration_count = 1 + (int)(truth_value / arraysize) + (1 if truth_value % arraysize else 0) @@ -63,8 +63,8 @@ def test_fetchmany(self, duckdb_cursor, integers): assert len(res) == 0 def test_fetchmany_too_many(self, duckdb_cursor, integers): - truth_value = len(duckdb_cursor.execute('select * from integers').fetchall()) - duckdb_cursor.execute('select * from integers') + truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) + duckdb_cursor.execute("select * from integers") res = duckdb_cursor.fetchmany(truth_value * 5) assert len(res) == truth_value assert_result_equal(res) @@ -74,48 +74,48 @@ def test_fetchmany_too_many(self, duckdb_cursor, integers): assert len(res) == 0 def test_numpy_selection(self, duckdb_cursor, integers, timestamps): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchnumpy() arr = numpy.ma.masked_array(numpy.arange(11)) arr.mask = [False] * 10 + [True] - numpy.testing.assert_array_equal(result['i'], arr, "Incorrect result returned") - duckdb_cursor.execute('SELECT * FROM timestamps') + numpy.testing.assert_array_equal(result["i"], arr, "Incorrect result returned") + duckdb_cursor.execute("SELECT * FROM timestamps") result = duckdb_cursor.fetchnumpy() - arr = numpy.array(['1992-10-03 18:34:45', '2010-01-01 00:00:01', None], dtype="datetime64[ms]") + arr = numpy.array(["1992-10-03 18:34:45", "2010-01-01 00:00:01", None], dtype="datetime64[ms]") arr = numpy.ma.masked_array(arr) arr.mask = [False, False, True] - numpy.testing.assert_array_equal(result['t'], arr, "Incorrect result returned") + numpy.testing.assert_array_equal(result["t"], arr, "Incorrect result returned") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_selection(self, duckdb_cursor, pandas, integers, timestamps): import datetime from packaging.version import Version # I don't know when this exactly changed, but 2.0.3 does not support this, recent versions do - if Version(pandas.__version__) <= Version('2.0.3'): + if Version(pandas.__version__) <= Version("2.0.3"): pytest.skip("The resulting dtype is 'object' when given a Series with dtype Int32DType") - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchdf() array = numpy.ma.masked_array(numpy.arange(11)) array.mask = [False] * 10 + [True] - arr = {'i': pandas.Series(array.data, dtype=pandas.Int32Dtype)} - arr['i'][array.mask] = pandas.NA + arr = {"i": pandas.Series(array.data, dtype=pandas.Int32Dtype)} + arr["i"][array.mask] = pandas.NA arr = pandas.DataFrame(arr) pandas.testing.assert_frame_equal(result, arr) - duckdb_cursor.execute('SELECT * FROM timestamps') + duckdb_cursor.execute("SELECT * 
FROM timestamps") result = duckdb_cursor.fetchdf() df = pandas.DataFrame( { - 't': pandas.Series( + "t": pandas.Series( data=[ datetime.datetime(year=1992, month=10, day=3, hour=18, minute=34, second=45), datetime.datetime(year=2010, month=1, day=1, hour=0, minute=0, second=1), None, ], - dtype='datetime64[us]', + dtype="datetime64[us]", ) } ) diff --git a/tests/fast/api/test_dbapi01.py b/tests/fast/api/test_dbapi01.py index dd0d2b4e..f7f00a10 100644 --- a/tests/fast/api/test_dbapi01.py +++ b/tests/fast/api/test_dbapi01.py @@ -6,8 +6,8 @@ class TestMultipleResultSets(object): def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert result == [ (0,), @@ -24,18 +24,18 @@ def test_regular_selection(self, duckdb_cursor, integers): ], "Incorrect result returned" def test_numpy_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchnumpy() expected = numpy.ma.masked_array(numpy.arange(11), mask=([False] * 10 + [True])) - numpy.testing.assert_array_equal(result['i'], expected) + numpy.testing.assert_array_equal(result["i"], expected) def test_numpy_materialized(self, duckdb_cursor, integers): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE integers (i integer)') - cursor.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + cursor.execute("CREATE TABLE integers (i integer)") + cursor.execute("INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)") rel = connection.table("integers") res = rel.aggregate("sum(i)").execute().fetchnumpy() - assert res['sum(i)'][0] == 45 + assert res["sum(i)"][0] == 45 diff --git a/tests/fast/api/test_dbapi04.py b/tests/fast/api/test_dbapi04.py index b2c9173a..1125f819 100644 --- a/tests/fast/api/test_dbapi04.py +++ b/tests/fast/api/test_dbapi04.py @@ -3,7 +3,7 @@ class TestSimpleDBAPI(object): def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert result == [ (0,), diff --git a/tests/fast/api/test_dbapi05.py b/tests/fast/api/test_dbapi05.py index 0de217f2..234fb2ec 100644 --- a/tests/fast/api/test_dbapi05.py +++ b/tests/fast/api/test_dbapi05.py @@ -3,7 +3,7 @@ class TestSimpleDBAPI(object): def test_prepare(self, duckdb_cursor): - result = duckdb_cursor.execute('SELECT CAST(? AS INTEGER), CAST(? AS INTEGER)', ['42', '84']).fetchall() + result = duckdb_cursor.execute("SELECT CAST(? AS INTEGER), CAST(? 
AS INTEGER)", ["42", "84"]).fetchall() assert result == [ ( 42, @@ -15,26 +15,26 @@ def test_prepare(self, duckdb_cursor): # from python docs c.execute( - '''CREATE TABLE stocks - (date text, trans text, symbol text, qty real, price real)''' + """CREATE TABLE stocks + (date text, trans text, symbol text, qty real, price real)""" ) c.execute("INSERT INTO stocks VALUES ('2006-01-05','BUY','RHAT',100,35.14)") - t = ('RHAT',) - result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() + t = ("RHAT",) + result = c.execute("SELECT COUNT(*) FROM stocks WHERE symbol=?", t).fetchone() assert result == (1,) - t = ['RHAT'] - result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() + t = ["RHAT"] + result = c.execute("SELECT COUNT(*) FROM stocks WHERE symbol=?", t).fetchone() assert result == (1,) # Larger example that inserts many records at a time purchases = [ - ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ("2006-03-28", "BUY", "IBM", 1000, 45.00), + ("2006-04-05", "BUY", "MSFT", 1000, 72.00), + ("2006-04-06", "SELL", "IBM", 500, 53.00), ] - c.executemany('INSERT INTO stocks VALUES (?,?,?,?,?)', purchases) + c.executemany("INSERT INTO stocks VALUES (?,?,?,?,?)", purchases) - result = c.execute('SELECT count(*) FROM stocks').fetchone() + result = c.execute("SELECT count(*) FROM stocks").fetchone() assert result == (4,) diff --git a/tests/fast/api/test_dbapi07.py b/tests/fast/api/test_dbapi07.py index 7792b8de..238f30fc 100644 --- a/tests/fast/api/test_dbapi07.py +++ b/tests/fast/api/test_dbapi07.py @@ -7,10 +7,10 @@ class TestNumpyTimestampMilliseconds(object): def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchnumpy() - assert res['test_time'] == numpy.datetime64('2019-11-26 21:11:42.501') + assert res["test_time"] == numpy.datetime64("2019-11-26 21:11:42.501") class TestTimestampMilliseconds(object): def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchone()[0] - assert res == datetime.strptime('2019-11-26 21:11:42.501', '%Y-%m-%d %H:%M:%S.%f') + assert res == datetime.strptime("2019-11-26 21:11:42.501", "%Y-%m-%d %H:%M:%S.%f") diff --git a/tests/fast/api/test_dbapi08.py b/tests/fast/api/test_dbapi08.py index a81acfd1..457a9e78 100644 --- a/tests/fast/api/test_dbapi08.py +++ b/tests/fast/api/test_dbapi08.py @@ -6,7 +6,7 @@ class TestType(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_fetchdf(self, pandas): con = duckdb.connect() con.execute("CREATE TABLE items(item VARCHAR)") @@ -14,7 +14,7 @@ def test_fetchdf(self, pandas): res = con.execute("SELECT item FROM items").fetchdf() assert isinstance(res, pandas.core.frame.DataFrame) - df = pandas.DataFrame({'item': ['jeans', '', None]}) + df = pandas.DataFrame({"item": ["jeans", "", None]}) print(res) print(df) diff --git a/tests/fast/api/test_dbapi09.py b/tests/fast/api/test_dbapi09.py index dde8ebff..538e7fc3 100644 --- a/tests/fast/api/test_dbapi09.py +++ b/tests/fast/api/test_dbapi09.py @@ -12,11 +12,11 @@ def test_fetchall_date(self, duckdb_cursor): def test_fetchnumpy_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchnumpy() - arr = numpy.array(['2020-01-10'], dtype="datetime64[s]") + arr = 
numpy.array(["2020-01-10"], dtype="datetime64[s]") arr = numpy.ma.masked_array(arr) - numpy.testing.assert_array_equal(res['test_date'], arr) + numpy.testing.assert_array_equal(res["test_date"], arr) def test_fetchdf_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchdf() - ser = pandas.Series(numpy.array(['2020-01-10'], dtype="datetime64[us]"), name="test_date") - pandas.testing.assert_series_equal(res['test_date'], ser) + ser = pandas.Series(numpy.array(["2020-01-10"], dtype="datetime64[us]"), name="test_date") + pandas.testing.assert_series_equal(res["test_date"], ser) diff --git a/tests/fast/api/test_dbapi12.py b/tests/fast/api/test_dbapi12.py index 78881f5e..833d231c 100644 --- a/tests/fast/api/test_dbapi12.py +++ b/tests/fast/api/test_dbapi12.py @@ -10,45 +10,45 @@ def test_readonly(self, duckdb_cursor): def test_rel(rel, duckdb_cursor): res = ( - rel.filter('i < 3') - .order('j') - .project('i') - .union(rel.filter('i > 2').project('i')) - .join(rel.set_alias('a1'), 'i') - .project('CAST(i as BIGINT) i, j') - .order('i') + rel.filter("i < 3") + .order("j") + .project("i") + .union(rel.filter("i > 2").project("i")) + .join(rel.set_alias("a1"), "i") + .project("CAST(i as BIGINT) i, j") + .order("i") ) pd.testing.assert_frame_equal(res.to_df(), test_df) res3 = duckdb_cursor.from_df(res.to_df()).to_df() pd.testing.assert_frame_equal(res3, test_df) - df_sql = res.query('x', 'select CAST(i as BIGINT) i, j from x') + df_sql = res.query("x", "select CAST(i as BIGINT) i, j from x") pd.testing.assert_frame_equal(df_sql.df(), test_df) - res2 = res.aggregate('i, count(j) as cj', 'i').order('i') + res2 = res.aggregate("i, count(j) as cj", "i").order("i") cmp_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "cj": [1, 1, 1]}) pd.testing.assert_frame_equal(res2.to_df(), cmp_df) - duckdb_cursor.execute('DROP TABLE IF EXISTS a2') - rel.create('a2') - rel_a2 = duckdb_cursor.table('a2').project('CAST(i as BIGINT) i, j').to_df() + duckdb_cursor.execute("DROP TABLE IF EXISTS a2") + rel.create("a2") + rel_a2 = duckdb_cursor.table("a2").project("CAST(i as BIGINT) i, j").to_df() pd.testing.assert_frame_equal(rel_a2, test_df) - duckdb_cursor.execute('DROP TABLE IF EXISTS a3') - duckdb_cursor.execute('CREATE TABLE a3 (i INTEGER, j STRING)') - rel.insert_into('a3') - rel_a3 = duckdb_cursor.table('a3').project('CAST(i as BIGINT) i, j').to_df() + duckdb_cursor.execute("DROP TABLE IF EXISTS a3") + duckdb_cursor.execute("CREATE TABLE a3 (i INTEGER, j STRING)") + rel.insert_into("a3") + rel_a3 = duckdb_cursor.table("a3").project("CAST(i as BIGINT) i, j").to_df() pd.testing.assert_frame_equal(rel_a3, test_df) - duckdb_cursor.execute('CREATE TABLE a (i INTEGER, j STRING)') + duckdb_cursor.execute("CREATE TABLE a (i INTEGER, j STRING)") duckdb_cursor.execute("INSERT INTO a VALUES (1, 'one'), (2, 'two'), (3, 'three')") - duckdb_cursor.execute('CREATE VIEW v AS SELECT * FROM a') + duckdb_cursor.execute("CREATE VIEW v AS SELECT * FROM a") - duckdb_cursor.execute('CREATE TEMPORARY TABLE at_ (i INTEGER)') - duckdb_cursor.execute('CREATE TEMPORARY VIEW vt AS SELECT * FROM at_') + duckdb_cursor.execute("CREATE TEMPORARY TABLE at_ (i INTEGER)") + duckdb_cursor.execute("CREATE TEMPORARY VIEW vt AS SELECT * FROM at_") - rel_a = duckdb_cursor.table('a') - rel_v = duckdb_cursor.view('v') + rel_a = duckdb_cursor.table("a") + rel_v = duckdb_cursor.view("v") # rel_at = duckdb_cursor.table('at') # rel_vt = duckdb_cursor.view('vt') @@ -59,8 +59,8 @@ def test_rel(rel, duckdb_cursor): 
test_rel(rel_df, duckdb_cursor) def test_fromquery(self, duckdb_cursor): - assert duckdb.from_query('select 42').fetchone()[0] == 42 - assert duckdb_cursor.query('select 43').fetchone()[0] == 43 + assert duckdb.from_query("select 42").fetchone()[0] == 42 + assert duckdb_cursor.query("select 43").fetchone()[0] == 43 # assert duckdb_cursor.from_query('select 44').execute().fetchone()[0] == 44 # assert duckdb_cursor.from_query('select 45').execute().fetchone()[0] == 45 diff --git a/tests/fast/api/test_dbapi13.py b/tests/fast/api/test_dbapi13.py index fb7fbaa8..ffdb4884 100644 --- a/tests/fast/api/test_dbapi13.py +++ b/tests/fast/api/test_dbapi13.py @@ -14,9 +14,9 @@ def test_fetchnumpy_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchnumpy() arr = numpy.array([datetime.time(13, 6, 40)], dtype="object") arr = numpy.ma.masked_array(arr) - numpy.testing.assert_array_equal(res['test_time'], arr) + numpy.testing.assert_array_equal(res["test_time"], arr) def test_fetchdf_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchdf() ser = pandas.Series(numpy.array([datetime.time(13, 6, 40)], dtype="object"), name="test_time") - pandas.testing.assert_series_equal(res['test_time'], ser) + pandas.testing.assert_series_equal(res["test_time"], ser) diff --git a/tests/fast/api/test_dbapi_fetch.py b/tests/fast/api/test_dbapi_fetch.py index 6eda4b9d..9c47c54c 100644 --- a/tests/fast/api/test_dbapi_fetch.py +++ b/tests/fast/api/test_dbapi_fetch.py @@ -8,21 +8,21 @@ class TestDBApiFetch(object): def test_multiple_fetch_one(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchone() == (42,) assert c.fetchone() is None assert c.fetchone() is None def test_multiple_fetch_all(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchall() == [(42,)] assert c.fetchall() == [] assert c.fetchall() == [] def test_multiple_fetch_many(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchmany(1000) == [(42,)] assert c.fetchmany(1000) == [] assert c.fetchmany(1000) == [] @@ -30,8 +30,8 @@ def test_multiple_fetch_many(self, duckdb_cursor): def test_multiple_fetch_df(self, duckdb_cursor): pd = pytest.importorskip("pandas") con = duckdb.connect() - c = con.execute('SELECT 42::BIGINT AS a') - pd.testing.assert_frame_equal(c.df(), pd.DataFrame.from_dict({'a': [42]})) + c = con.execute("SELECT 42::BIGINT AS a") + pd.testing.assert_frame_equal(c.df(), pd.DataFrame.from_dict({"a": [42]})) assert c.df() is None assert c.df() is None @@ -39,36 +39,36 @@ def test_multiple_fetch_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") arrow = pytest.importorskip("pyarrow") con = duckdb.connect() - c = con.execute('SELECT 42::BIGINT AS a') + c = con.execute("SELECT 42::BIGINT AS a") table = c.fetch_arrow_table() df = table.to_pandas() - pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict({'a': [42]})) + pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict({"a": [42]})) assert c.fetch_arrow_table() is None assert c.fetch_arrow_table() is None def test_multiple_close(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") c.close() c.close() c.close() - with pytest.raises(duckdb.InvalidInputException, match='No open result set'): + with pytest.raises(duckdb.InvalidInputException, 
match="No open result set"): c.fetchall() def test_multiple_fetch_all_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT 42') + res = duckdb_cursor.query("SELECT 42") assert res.fetchall() == [(42,)] assert res.fetchall() == [(42,)] assert res.fetchall() == [(42,)] def test_multiple_fetch_many_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT 42') + res = duckdb_cursor.query("SELECT 42") assert res.fetchmany(10000) == [(42,)] assert res.fetchmany(10000) == [] assert res.fetchmany(10000) == [] def test_fetch_one_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT * FROM range(3)') + res = duckdb_cursor.query("SELECT * FROM range(3)") assert res.fetchone() == (0,) assert res.fetchone() == (1,) assert res.fetchone() == (2,) @@ -86,40 +86,40 @@ def test_fetch_one_relation(self, duckdb_cursor): assert res.fetchone() is None @pytest.mark.parametrize( - 'test_case', + "test_case", [ - (False, 'BOOLEAN', False), - (-128, 'TINYINT', -128), - (-32768, 'SMALLINT', -32768), - (-2147483648, 'INTEGER', -2147483648), - (-9223372036854775808, 'BIGINT', -9223372036854775808), - (-170141183460469231731687303715884105728, 'HUGEINT', -170141183460469231731687303715884105728), - (0, 'UTINYINT', 0), - (0, 'USMALLINT', 0), - (0, 'UINTEGER', 0), - (0, 'UBIGINT', 0), - (0, 'UHUGEINT', 0), - (1.3423423767089844, 'FLOAT', 1.3423424), - (1.3423424, 'DOUBLE', 1.3423424), - (Decimal('1.342342'), 'DECIMAL(10, 6)', 1.342342), - ('hello', "ENUM('world', 'hello')", 'hello'), - ('🦆🦆🦆🦆🦆🦆', 'VARCHAR', '🦆🦆🦆🦆🦆🦆'), - (b'thisisalongblob\x00withnullbytes', 'BLOB', 'thisisalongblob\\x00withnullbytes'), - ('0010001001011100010101011010111', 'BITSTRING', '0010001001011100010101011010111'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP', '290309-12-22 (BC) 00:00:00'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP_MS', '290309-12-22 (BC) 00:00:00'), - (datetime.datetime(1677, 9, 22, 0, 0), 'TIMESTAMP_NS', '1677-09-22 00:00:00'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP_S', '290309-12-22 (BC) 00:00:00'), - ('290309-12-22 (BC) 00:00:30+00', 'TIMESTAMPTZ', '290309-12-22 (BC) 00:17:30+00:17'), + (False, "BOOLEAN", False), + (-128, "TINYINT", -128), + (-32768, "SMALLINT", -32768), + (-2147483648, "INTEGER", -2147483648), + (-9223372036854775808, "BIGINT", -9223372036854775808), + (-170141183460469231731687303715884105728, "HUGEINT", -170141183460469231731687303715884105728), + (0, "UTINYINT", 0), + (0, "USMALLINT", 0), + (0, "UINTEGER", 0), + (0, "UBIGINT", 0), + (0, "UHUGEINT", 0), + (1.3423423767089844, "FLOAT", 1.3423424), + (1.3423424, "DOUBLE", 1.3423424), + (Decimal("1.342342"), "DECIMAL(10, 6)", 1.342342), + ("hello", "ENUM('world', 'hello')", "hello"), + ("🦆🦆🦆🦆🦆🦆", "VARCHAR", "🦆🦆🦆🦆🦆🦆"), + (b"thisisalongblob\x00withnullbytes", "BLOB", "thisisalongblob\\x00withnullbytes"), + ("0010001001011100010101011010111", "BITSTRING", "0010001001011100010101011010111"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP", "290309-12-22 (BC) 00:00:00"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP_MS", "290309-12-22 (BC) 00:00:00"), + (datetime.datetime(1677, 9, 22, 0, 0), "TIMESTAMP_NS", "1677-09-22 00:00:00"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP_S", "290309-12-22 (BC) 00:00:00"), + ("290309-12-22 (BC) 00:00:30+00", "TIMESTAMPTZ", "290309-12-22 (BC) 00:17:30+00:17"), ( datetime.time(0, 0, tzinfo=datetime.timezone(datetime.timedelta(seconds=57599))), - 'TIMETZ', - '00:00:00+15:59:59', + "TIMETZ", + "00:00:00+15:59:59", ), - ('5877642-06-25 (BC)', 'DATE', '5877642-06-25 (BC)'), - 
(UUID('cd57dfbd-d65f-4e15-991e-2a92e74b9f79'), 'UUID', 'cd57dfbd-d65f-4e15-991e-2a92e74b9f79'), - (datetime.timedelta(days=90), 'INTERVAL', '3 months'), - ('🦆🦆🦆🦆🦆🦆', 'UNION(a int, b bool, c varchar)', '🦆🦆🦆🦆🦆🦆'), + ("5877642-06-25 (BC)", "DATE", "5877642-06-25 (BC)"), + (UUID("cd57dfbd-d65f-4e15-991e-2a92e74b9f79"), "UUID", "cd57dfbd-d65f-4e15-991e-2a92e74b9f79"), + (datetime.timedelta(days=90), "INTERVAL", "3 months"), + ("🦆🦆🦆🦆🦆🦆", "UNION(a int, b bool, c varchar)", "🦆🦆🦆🦆🦆🦆"), ], ) def test_fetch_dict_coverage(self, duckdb_cursor, test_case): @@ -138,7 +138,7 @@ def test_fetch_dict_coverage(self, duckdb_cursor, test_case): print(res[0].keys()) assert res[0][python_key] == -2147483648 - @pytest.mark.parametrize('test_case', ['VARCHAR[]']) + @pytest.mark.parametrize("test_case", ["VARCHAR[]"]) def test_fetch_dict_key_not_hashable(self, duckdb_cursor, test_case): key_type = test_case query = f""" @@ -153,4 +153,4 @@ def test_fetch_dict_key_not_hashable(self, duckdb_cursor, test_case): select a from map_cte; """ res = duckdb_cursor.sql(query).fetchone() - assert 'key' in res[0].keys() + assert "key" in res[0].keys() diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 4cb565c1..4b0dc4d6 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -9,9 +9,9 @@ def is_dunder_method(method_name: str) -> bool: if len(method_name) < 4: return False - if method_name.startswith('_pybind11'): + if method_name.startswith("_pybind11"): return True - return method_name[:2] == '__' and method_name[:-3:-1] == '__' + return method_name[:2] == "__" and method_name[:-3:-1] == "__" @pytest.fixture(scope="session") @@ -23,32 +23,32 @@ def tmp_database(tmp_path_factory): # This file contains tests for DuckDBPyConnection methods, # wrapped by the 'duckdb' module, to execute with the 'default_connection' class TestDuckDBConnection(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append(self, pandas): duckdb.execute("Create table integers (i integer)") df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) - duckdb.append('integers', df_in) - assert duckdb.execute('select count(*) from integers').fetchone()[0] == 5 + duckdb.append("integers", df_in) + assert duckdb.execute("select count(*) from integers").fetchone()[0] == 5 # cleanup duckdb.execute("drop table integers") def test_default_connection_from_connect(self): - duckdb.sql('create or replace table connect_default_connect (i integer)') - con = duckdb.connect(':default:') - con.sql('select i from connect_default_connect') - duckdb.sql('drop table connect_default_connect') + duckdb.sql("create or replace table connect_default_connect (i integer)") + con = duckdb.connect(":default:") + con.sql("select i from connect_default_connect") + duckdb.sql("drop table connect_default_connect") with pytest.raises(duckdb.Error): - con.sql('select i from connect_default_connect') + con.sql("select i from connect_default_connect") # not allowed with additional options with pytest.raises( - duckdb.InvalidInputException, match='Default connection fetching is only allowed without additional options' + duckdb.InvalidInputException, match="Default connection fetching is only allowed without additional options" ): - con = duckdb.connect(':default:', read_only=True) + con = duckdb.connect(":default:", read_only=True) def test_arrow(self): pyarrow = 
pytest.importorskip("pyarrow") @@ -114,7 +114,7 @@ def test_readonly_properties(self): duckdb.execute("select 42") description = duckdb.description() rowcount = duckdb.rowcount() - assert description == [('42', 'INTEGER', None, None, None, None, None)] + assert description == [("42", "INTEGER", None, None, None, None, None)] assert rowcount == -1 def test_execute(self): @@ -124,29 +124,29 @@ def test_executemany(self): # executemany does not keep an open result set # TODO: shouldn't we also have a version that executes a query multiple times with different parameters, returning all of the results? duckdb.execute("create table tbl (i integer, j varchar)") - duckdb.executemany("insert into tbl VALUES (?, ?)", [(5, 'test'), (2, 'duck'), (42, 'quack')]) + duckdb.executemany("insert into tbl VALUES (?, ?)", [(5, "test"), (2, "duck"), (42, "quack")]) res = duckdb.table("tbl").fetchall() - assert res == [(5, 'test'), (2, 'duck'), (42, 'quack')] + assert res == [(5, "test"), (2, "duck"), (42, "quack")] duckdb.execute("drop table tbl") def test_pystatement(self): - with pytest.raises(duckdb.ParserException, match='seledct'): - statements = duckdb.extract_statements('seledct 42; select 21') + with pytest.raises(duckdb.ParserException, match="seledct"): + statements = duckdb.extract_statements("seledct 42; select 21") - statements = duckdb.extract_statements('select $1; select 21') + statements = duckdb.extract_statements("select $1; select 21") assert len(statements) == 2 - assert statements[0].query == 'select $1' + assert statements[0].query == "select $1" assert statements[0].type == duckdb.StatementType.SELECT - assert statements[0].named_parameters == set('1') + assert statements[0].named_parameters == set("1") assert statements[0].expected_result_type == [duckdb.ExpectedResultType.QUERY_RESULT] - assert statements[1].query == ' select 21' + assert statements[1].query == " select 21" assert statements[1].type == duckdb.StatementType.SELECT assert statements[1].named_parameters == set() with pytest.raises( duckdb.InvalidInputException, - match='Please provide either a DuckDBPyStatement or a string representing the query', + match="Please provide either a DuckDBPyStatement or a string representing the query", ): rel = duckdb.query(statements) @@ -158,23 +158,23 @@ def test_pystatement(self): with pytest.raises( duckdb.InvalidInputException, - match='Values were not provided for the following prepared statement parameters: 1', + match="Values were not provided for the following prepared statement parameters: 1", ): duckdb.execute(statements[0]) - assert duckdb.execute(statements[0], {'1': 42}).fetchall() == [(42,)] + assert duckdb.execute(statements[0], {"1": 42}).fetchall() == [(42,)] duckdb.execute("create table tbl(a integer)") - statements = duckdb.extract_statements('insert into tbl select $1') + statements = duckdb.extract_statements("insert into tbl select $1") assert statements[0].expected_result_type == [ duckdb.ExpectedResultType.CHANGED_ROWS, duckdb.ExpectedResultType.QUERY_RESULT, ] with pytest.raises( - duckdb.InvalidInputException, match='executemany requires a non-empty list of parameter sets to be provided' + duckdb.InvalidInputException, match="executemany requires a non-empty list of parameter sets to be provided" ): duckdb.executemany(statements[0]) duckdb.executemany(statements[0], [(21,), (22,), (23,)]) - assert duckdb.table('tbl').fetchall() == [(21,), (22,), (23,)] + assert duckdb.table("tbl").fetchall() == [(21,), (22,), (23,)] duckdb.execute("drop table tbl") def 
test_fetch_arrow_table(self): @@ -188,18 +188,18 @@ def test_fetch_arrow_table(self): duckdb.execute("Insert Into test values ('" + str(i) + "')") duckdb.execute("Insert Into test values ('5000')") duckdb.execute("Insert Into test values ('6000')") - sql = ''' + sql = """ SELECT a, COUNT(*) AS repetitions FROM test GROUP BY a - ''' + """ result_df = duckdb.execute(sql).df() arrow_table = duckdb.execute(sql).fetch_arrow_table() arrow_df = arrow_table.to_pandas() - assert result_df['repetitions'].sum() == arrow_df['repetitions'].sum() + assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() duckdb.execute("drop table test") def test_fetch_df(self): @@ -213,10 +213,10 @@ def test_fetch_df_chunk(self): duckdb.execute("CREATE table t as select range a from range(3000);") query = duckdb.execute("SELECT a FROM t") cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == 2048 cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 2048 + assert cur_chunk["a"][0] == 2048 assert len(cur_chunk) == 952 duckdb.execute("DROP TABLE t") @@ -247,11 +247,11 @@ def test_fetchnumpy(self): numpy = pytest.importorskip("numpy") duckdb.execute("SELECT BLOB 'hello'") results = duckdb.fetchall() - assert results[0][0] == b'hello' + assert results[0][0] == b"hello" duckdb.execute("SELECT BLOB 'hello' AS a") results = duckdb.fetchnumpy() - assert results['a'] == numpy.array([b'hello'], dtype=object) + assert results["a"] == numpy.array([b"hello"], dtype=object) def test_fetchone(self): assert (0,) == duckdb.execute("select * from range(5)").fetchone() @@ -288,11 +288,11 @@ def test_register(self): def test_register_relation(self): con = duckdb.connect() - rel = con.sql('select [5,4,3]') + rel = con.sql("select [5,4,3]") con.register("relation", rel) con.sql("create table tbl as select * from relation") - assert con.table('tbl').fetchall() == [([5, 4, 3],)] + assert con.table("tbl").fetchall() == [([5, 4, 3],)] def test_unregister_problematic_behavior(self, duckdb_cursor): # We have a VIEW called 'vw' in the Catalog @@ -302,33 +302,33 @@ def test_unregister_problematic_behavior(self, duckdb_cursor): # Create a registered object called 'vw' arrow_result = duckdb_cursor.execute("select 42").fetch_arrow_table() with pytest.raises(duckdb.CatalogException, match='View with name "vw" already exists'): - duckdb_cursor.register('vw', arrow_result) + duckdb_cursor.register("vw", arrow_result) # Temporary views take precedence over registered objects assert duckdb_cursor.execute("select * from vw").fetchone() == (0,) # Decide that we're done with this registered object.. 
- duckdb_cursor.unregister('vw') + duckdb_cursor.unregister("vw") # This should not have affected the existing view: assert duckdb_cursor.execute("select * from vw").fetchone() == (0,) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_out_of_scope(self, pandas): def temporary_scope(): # Create a connection, we will return this con = duckdb.connect() # Create a dataframe - df = pandas.DataFrame({'a': [1, 2, 3]}) + df = pandas.DataFrame({"a": [1, 2, 3]}) # The dataframe has to be registered as well # making sure it does not go out of scope con.register("df", df) - rel = con.sql('select * from df') + rel = con.sql("select * from df") con.register("relation", rel) return con con = temporary_scope() - res = con.sql('select * from relation').fetchall() + res = con.sql("select * from relation").fetchall() print(res) def test_table(self): diff --git a/tests/fast/api/test_duckdb_execute.py b/tests/fast/api/test_duckdb_execute.py index fba01a0c..a025fc42 100644 --- a/tests/fast/api/test_duckdb_execute.py +++ b/tests/fast/api/test_duckdb_execute.py @@ -4,8 +4,8 @@ class TestDuckDBExecute(object): def test_execute_basic(self, duckdb_cursor): - duckdb_cursor.execute('create table t as select 5') - res = duckdb_cursor.table('t').fetchall() + duckdb_cursor.execute("create table t as select 5") + res = duckdb_cursor.table("t").fetchall() assert res == [(5,)] def test_execute_many_basic(self, duckdb_cursor): @@ -19,11 +19,11 @@ def test_execute_many_basic(self, duckdb_cursor): """, (99,), ) - res = duckdb_cursor.table('t').fetchall() + res = duckdb_cursor.table("t").fetchall() assert res == [(99,)] @pytest.mark.parametrize( - 'rowcount', + "rowcount", [ 50, 2048, @@ -53,7 +53,7 @@ def test_execute_many_error(self, duckdb_cursor): # Prepared parameter used in a statement that is not the last with pytest.raises( - duckdb.NotImplementedException, match='Prepared parameters are only supported for the last statement' + duckdb.NotImplementedException, match="Prepared parameters are only supported for the last statement" ): duckdb_cursor.execute( """ @@ -73,11 +73,11 @@ def to_insert_from_generator(what): gen = to_insert_from_generator(to_insert) duckdb_cursor.execute("CREATE TABLE unittest_generator (a INTEGER);") duckdb_cursor.executemany("INSERT into unittest_generator (a) VALUES (?)", gen) - assert duckdb_cursor.table('unittest_generator').fetchall() == [(1,), (2,), (3,)] + assert duckdb_cursor.table("unittest_generator").fetchall() == [(1,), (2,), (3,)] def test_execute_multiple_statements(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [5, 6, 7, 8]}) + df = pd.DataFrame({"a": [5, 6, 7, 8]}) sql = """ select * from df; select * from VALUES (1),(2),(3),(4) t(a); diff --git a/tests/fast/api/test_duckdb_query.py b/tests/fast/api/test_duckdb_query.py index 43f36603..2ecfd8f3 100644 --- a/tests/fast/api/test_duckdb_query.py +++ b/tests/fast/api/test_duckdb_query.py @@ -7,38 +7,38 @@ class TestDuckDBQuery(object): def test_duckdb_query(self, duckdb_cursor): # we can use duckdb_cursor.sql to run both DDL statements and select statements - duckdb_cursor.sql('create view v1 as select 42 i') - rel = duckdb_cursor.sql('select * from v1') + duckdb_cursor.sql("create view v1 as select 42 i") + rel = duckdb_cursor.sql("select * from v1") assert rel.fetchall()[0][0] == 42 # also multiple statements - duckdb_cursor.sql('create view v2 as select i*2 j from v1; create view v3 as select j 
* 2 from v2;') - rel = duckdb_cursor.sql('select * from v3') + duckdb_cursor.sql("create view v2 as select i*2 j from v1; create view v3 as select j * 2 from v2;") + rel = duckdb_cursor.sql("select * from v3") assert rel.fetchall()[0][0] == 168 # we can run multiple select statements - we get only the last result - res = duckdb_cursor.sql('select 42; select 84;').fetchall() + res = duckdb_cursor.sql("select 42; select 84;").fetchall() assert res == [(84,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_duckdb_from_query_multiple_statements(self, pandas): - tst_df = pandas.DataFrame({'a': [1, 23, 3, 5]}) + tst_df = pandas.DataFrame({"a": [1, 23, 3, 5]}) res = duckdb.sql( - ''' + """ select 42; select * from tst_df union all select * from tst_df; - ''' + """ ).fetchall() assert res == [(1,), (23,), (3,), (5,), (1,), (23,), (3,), (5,)] def test_duckdb_query_empty_result(self): con = duckdb.connect() # show tables on empty connection does not produce any tuples - res = con.query('show tables').fetchall() + res = con.query("show tables").fetchall() assert res == [] def test_parametrized_explain(self, duckdb_cursor): @@ -57,7 +57,7 @@ def test_parametrized_explain(self, duckdb_cursor): duckdb_cursor.execute(query, params) results = duckdb_cursor.fetchall() - assert 'EXPLAIN_ANALYZE' in results[0][1] + assert "EXPLAIN_ANALYZE" in results[0][1] def test_named_param(self): con = duckdb.connect() @@ -83,7 +83,7 @@ def test_named_param(self): from range(100) tbl(i) """, - {'param': 5, 'other_param': 10}, + {"param": 5, "other_param": 10}, ).fetchall() assert res == original_res @@ -95,14 +95,14 @@ def test_named_param_not_dict(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: name1, name2, name3", ): - con.execute("select $name1, $name2, $name3", ['name1', 'name2', 'name3']) + con.execute("select $name1, $name2, $name3", ["name1", "name2", "name3"]) def test_named_param_basic(self): con = duckdb.connect() - res = con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'name3': 'a'}).fetchall() + res = con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3, "name3": "a"}).fetchall() assert res == [ - (5, 3, 'a'), + (5, 3, "a"), ] def test_named_param_not_exhaustive(self): @@ -112,7 +112,7 @@ def test_named_param_not_exhaustive(self): duckdb.InvalidInputException, match="Invalid Input Error: Values were not provided for the following prepared statement parameters: name3", ): - con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3}) + con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3}) def test_named_param_excessive(self): con = duckdb.connect() @@ -121,7 +121,7 @@ def test_named_param_excessive(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: name3", ): - con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'not_a_named_param': 5}) + con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3, "not_a_named_param": 5}) def test_named_param_not_named(self): con = duckdb.connect() @@ -130,7 +130,7 @@ def test_named_param_not_named(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: 1, 2", ): - con.execute("select $1, $1, $2", {'name1': 5, 'name2': 3}) + con.execute("select $1, $1, $2", {"name1": 5, "name2": 3}) def 
test_named_param_mixed(self): con = duckdb.connect() @@ -138,13 +138,13 @@ def test_named_param_mixed(self): with pytest.raises( duckdb.NotImplementedException, match="Mixing named and positional parameters is not supported yet" ): - con.execute("select $name1, $1, $2", {'name1': 5, 'name2': 3}) + con.execute("select $name1, $1, $2", {"name1": 5, "name2": 3}) def test_named_param_strings_with_dollarsign(self): con = duckdb.connect() - res = con.execute("select '$name1', $name1, $name1, '$name1'", {'name1': 5}).fetchall() - assert res == [('$name1', 5, 5, '$name1')] + res = con.execute("select '$name1', $name1, $name1, '$name1'", {"name1": 5}).fetchall() + assert res == [("$name1", 5, 5, "$name1")] def test_named_param_case_insensivity(self): con = duckdb.connect() @@ -153,10 +153,10 @@ def test_named_param_case_insensivity(self): """ select $NaMe1, $NAME2, $name3 """, - {'name1': 5, 'nAmE2': 3, 'NAME3': 'a'}, + {"name1": 5, "nAmE2": 3, "NAME3": "a"}, ).fetchall() assert res == [ - (5, 3, 'a'), + (5, 3, "a"), ] def test_named_param_keyword(self): @@ -176,16 +176,16 @@ def test_conversion_from_tuple(self): assert result == [([21, 22, 42],)] # If wrapped in a Value, it can convert to a struct - result = con.execute("select $1", [Value(('a', 21, True), {'a': str, 'b': int, 'c': bool})]).fetchall() - assert result == [({'a': 'a', 'b': 21, 'c': True},)] + result = con.execute("select $1", [Value(("a", 21, True), {"a": str, "b": int, "c": bool})]).fetchall() + assert result == [({"a": "a", "b": 21, "c": True},)] # If the amount of items in the tuple and the children of the struct don't match # we throw an error with pytest.raises( duckdb.InvalidInputException, - match='Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children', + match="Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children", ): - result = con.execute("select $1", [Value(('a', 21, True), {'a': str, 'b': int})]).fetchall() + result = con.execute("select $1", [Value(("a", 21, True), {"a": str, "b": int})]).fetchall() # If we try to create anything other than a STRUCT or a LIST out of the tuple, we throw an error with pytest.raises(duckdb.InvalidInputException, match="Can't convert tuple to a Value of type VARCHAR"): @@ -194,12 +194,12 @@ def test_conversion_from_tuple(self): def test_column_name_behavior(self, duckdb_cursor): _ = pytest.importorskip("pandas") - expected_names = ['one', 'ONE_1'] + expected_names = ["one", "ONE_1"] df = duckdb_cursor.execute('select 1 as one, 2 as "ONE"').fetchdf() assert expected_names == list(df.columns) - duckdb_cursor.register('tbl', df) + duckdb_cursor.register("tbl", df) df = duckdb_cursor.execute("select * from tbl").fetchdf() assert expected_names == list(df.columns) diff --git a/tests/fast/api/test_explain.py b/tests/fast/api/test_explain.py index 73c198b9..feedc134 100644 --- a/tests/fast/api/test_explain.py +++ b/tests/fast/api/test_explain.py @@ -4,40 +4,40 @@ class TestExplain(object): def test_explain_basic(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain() + res = duckdb_cursor.sql("select 42").explain() assert isinstance(res, str) def test_explain_standard(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain('standard') + res = duckdb_cursor.sql("select 42").explain("standard") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain('STANDARD') + res = duckdb_cursor.sql("select 42").explain("STANDARD") assert isinstance(res, str) - res 
= duckdb_cursor.sql('select 42').explain(duckdb.STANDARD) + res = duckdb_cursor.sql("select 42").explain(duckdb.STANDARD) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.ExplainType.STANDARD) + res = duckdb_cursor.sql("select 42").explain(duckdb.ExplainType.STANDARD) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(0) + res = duckdb_cursor.sql("select 42").explain(0) assert isinstance(res, str) def test_explain_analyze(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain('analyze') + res = duckdb_cursor.sql("select 42").explain("analyze") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain('ANALYZE') + res = duckdb_cursor.sql("select 42").explain("ANALYZE") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.ExplainType.ANALYZE) + res = duckdb_cursor.sql("select 42").explain(duckdb.ExplainType.ANALYZE) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(1) + res = duckdb_cursor.sql("select 42").explain(1) assert isinstance(res, str) def test_explain_df(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [42]}) - res = duckdb_cursor.sql('select * from df').explain('ANALYZE') + df = pd.DataFrame({"a": [42]}) + res = duckdb_cursor.sql("select * from df").explain("ANALYZE") assert isinstance(res, str) diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index a878fda5..7b797598 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -49,7 +49,7 @@ def __init__(self) -> None: self._data = {"a": parquet_data, "b": parquet_data} fsspec.register_implementation("deadlock", TestFileSystem, clobber=True) - fs = fsspec.filesystem('deadlock') + fs = fsspec.filesystem("deadlock") duckdb_cursor.register_filesystem(fs) result = duckdb_cursor.read_parquet(file_globs=["deadlock://a", "deadlock://b"], union_by_name=True) diff --git a/tests/fast/api/test_insert_into.py b/tests/fast/api/test_insert_into.py index e6d4c6ba..2537c182 100644 --- a/tests/fast/api/test_insert_into.py +++ b/tests/fast/api/test_insert_into.py @@ -7,22 +7,22 @@ class TestInsertInto(object): def test_insert_into_schema(self, duckdb_cursor): # open connection con = duckdb.connect() - con.execute('CREATE SCHEMA s') - con.execute('CREATE TABLE s.t (id INTEGER PRIMARY KEY)') + con.execute("CREATE SCHEMA s") + con.execute("CREATE TABLE s.t (id INTEGER PRIMARY KEY)") # make relation - df = DataFrame([1], columns=['id']) + df = DataFrame([1], columns=["id"]) rel = con.from_df(df) - rel.insert_into('s.t') + rel.insert_into("s.t") assert con.execute("select * from s.t").fetchall() == [(1,)] # This should fail since this will go to default schema with pytest.raises(duckdb.CatalogException): - rel.insert_into('t') + rel.insert_into("t") # If we add t in the default schema it should work. 
- con.execute('CREATE TABLE t (id INTEGER PRIMARY KEY)') - rel.insert_into('t') + con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)") + rel.insert_into("t") assert con.execute("select * from t").fetchall() == [(1,)] diff --git a/tests/fast/api/test_join.py b/tests/fast/api/test_join.py index 7d7f45c2..5e2a148f 100644 --- a/tests/fast/api/test_join.py +++ b/tests/fast/api/test_join.py @@ -8,7 +8,7 @@ def test_alias_from_sql(self): rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") rel2 = con.sql("SELECT 1 AS col1, 3 AS col3") - rel = con.sql('select * from rel1 JOIN rel2 USING (col1)') + rel = con.sql("select * from rel1 JOIN rel2 USING (col1)") rel.show() res = rel.fetchall() assert res == [(1, 2, 3)] @@ -19,27 +19,27 @@ def test_relational_join(self): rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") rel2 = con.sql("SELECT 1 AS col1, 3 AS col3") - rel = rel1.join(rel2, 'col1') + rel = rel1.join(rel2, "col1") res = rel.fetchall() assert res == [(1, 2, 3)] def test_relational_join_alias_collision(self): con = duckdb.connect() - rel1 = con.sql("SELECT 1 AS col1, 2 AS col2").set_alias('a') - rel2 = con.sql("SELECT 1 AS col1, 3 AS col3").set_alias('a') + rel1 = con.sql("SELECT 1 AS col1, 2 AS col2").set_alias("a") + rel2 = con.sql("SELECT 1 AS col1, 3 AS col3").set_alias("a") - with pytest.raises(duckdb.InvalidInputException, match='Both relations have the same alias'): - rel = rel1.join(rel2, 'col1') + with pytest.raises(duckdb.InvalidInputException, match="Both relations have the same alias"): + rel = rel1.join(rel2, "col1") def test_relational_join_with_condition(self): con = duckdb.connect() - rel1 = con.sql("SELECT 1 AS col1, 2 AS col2", alias='rel1') - rel2 = con.sql("SELECT 1 AS col1, 3 AS col3", alias='rel2') + rel1 = con.sql("SELECT 1 AS col1, 2 AS col2", alias="rel1") + rel2 = con.sql("SELECT 1 AS col1, 3 AS col3", alias="rel2") # This makes a USING clause, which is kind of unexpected behavior - rel = rel1.join(rel2, 'rel1.col1 = rel2.col1') + rel = rel1.join(rel2, "rel1.col1 = rel2.col1") rel.show() res = rel.fetchall() assert res == [(1, 2, 1, 3)] @@ -49,8 +49,8 @@ def test_deduplicated_bindings(self, duckdb_cursor): duckdb_cursor.execute("create table old as select * from (values ('42', 1), ('21', 2)) t(a, b)") duckdb_cursor.execute("create table old_1 as select * from (values ('42', 3), ('21', 4)) t(a, b)") - old = duckdb_cursor.table('old') - old_1 = duckdb_cursor.table('old_1') + old = duckdb_cursor.table("old") + old_1 = duckdb_cursor.table("old_1") join_one = old.join(old_1, "old.a == old_1.a") join_two = old.join(old_1, "old.a == old_1.a") diff --git a/tests/fast/api/test_native_tz.py b/tests/fast/api/test_native_tz.py index 6098ca08..f4a9d716 100644 --- a/tests/fast/api/test_native_tz.py +++ b/tests/fast/api/test_native_tz.py @@ -8,7 +8,7 @@ pa = pytest.importorskip("pyarrow") from packaging.version import Version -filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'tz.parquet') +filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "data", "tz.parquet") class TestNativeTimeZone(object): @@ -16,20 +16,20 @@ def test_native_python_timestamp_timezone(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].tzinfo.zone == "America/Los_Angeles" res = duckdb_cursor.execute(f"select TimeRecStart as tz from 
'{filename}'").fetchall()[0] assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].tzinfo.zone == "America/Los_Angeles" res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchmany(1)[0] assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].tzinfo.zone == "America/Los_Angeles" duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() assert res[0].hour == 21 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'UTC' + assert res[0].tzinfo.zone == "UTC" def test_native_python_time_timezone(self, duckdb_cursor): res = duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").fetchone() @@ -41,33 +41,33 @@ def test_native_python_time_timezone(self, duckdb_cursor): def test_pandas_timestamp_timezone(self, duckdb_cursor): res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").df() - assert res.dtypes["tz"].tz.zone == 'America/Los_Angeles' - assert res['tz'][0].hour == 14 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert res["tz"][0].hour == 14 and res["tz"][0].minute == 52 duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").df() - assert res['tz'][0].hour == 21 and res['tz'][0].minute == 52 + assert res["tz"][0].hour == 21 and res["tz"][0].minute == 52 def test_pandas_timestamp_time(self, duckdb_cursor): with pytest.raises( - duckdb.NotImplementedException, match="Not implemented Error: Unsupported type \"TIME WITH TIME ZONE\"" + duckdb.NotImplementedException, match='Not implemented Error: Unsupported type "TIME WITH TIME ZONE"' ): duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").df() @pytest.mark.skipif( - Version(pa.__version__) < Version('15.0.0'), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" + Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) def test_arrow_timestamp_timezone(self, duckdb_cursor): res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") table = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table() res = table.to_pandas() - assert res.dtypes["tz"].tz.zone == 'America/Los_Angeles' - assert res['tz'][0].hour == 14 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert res["tz"][0].hour == 14 and res["tz"][0].minute == 52 duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table().to_pandas() - assert res.dtypes["tz"].tz.zone == 'UTC' - assert res['tz'][0].hour == 21 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "UTC" + assert res["tz"][0].hour == 21 and res["tz"][0].minute == 52 def test_arrow_timestamp_time(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") @@ -81,8 +81,8 @@ def test_arrow_timestamp_time(self, duckdb_cursor): .fetch_arrow_table() .to_pandas() ) - assert res1['tz'][0].hour == 14 and res1['tz'][0].minute == 52 - assert res2['tz'][0].hour == res2['tz'][0].hour and res2['tz'][0].minute == res1['tz'][0].minute + assert res1["tz"][0].hour == 14 and res1["tz"][0].minute == 52 + assert 
res2["tz"][0].hour == res2["tz"][0].hour and res2["tz"][0].minute == res1["tz"][0].minute duckdb_cursor.execute("SET timezone='UTC';") res1 = ( @@ -95,5 +95,5 @@ def test_arrow_timestamp_time(self, duckdb_cursor): .fetch_arrow_table() .to_pandas() ) - assert res1['tz'][0].hour == 21 and res1['tz'][0].minute == 52 - assert res2['tz'][0].hour == res2['tz'][0].hour and res2['tz'][0].minute == res1['tz'][0].minute + assert res1["tz"][0].hour == 21 and res1["tz"][0].minute == 52 + assert res2["tz"][0].hour == res2["tz"][0].hour and res2["tz"][0].minute == res1["tz"][0].minute diff --git a/tests/fast/api/test_query_interrupt.py b/tests/fast/api/test_query_interrupt.py index 6334e475..e6d2b998 100644 --- a/tests/fast/api/test_query_interrupt.py +++ b/tests/fast/api/test_query_interrupt.py @@ -25,7 +25,7 @@ def test_query_interruption(self): # Start the thread thread.start() try: - res = con.execute('select count(*) from range(100000000000)').fetchall() + res = con.execute("select count(*) from range(100000000000)").fetchall() except RuntimeError: # If this is not reached, we could not cancel the query before it completed # indicating that the query interruption functionality is broken diff --git a/tests/fast/api/test_read_csv.py b/tests/fast/api/test_read_csv.py index 7337515d..dff90869 100644 --- a/tests/fast/api/test_read_csv.py +++ b/tests/fast/api/test_read_csv.py @@ -11,7 +11,7 @@ def TestFile(name): import os - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', name) + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "data", name) return filename @@ -35,262 +35,262 @@ def create_temp_csv(tmp_path): class TestReadCSV(object): def test_using_connection_wrapper(self): - rel = duckdb.read_csv(TestFile('category.csv')) + rel = duckdb.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_using_connection_wrapper_with_keyword(self): - rel = duckdb.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + rel = duckdb.read_csv(TestFile("category.csv"), dtype={"category_id": "string"}) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_no_options(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv')) + rel = duckdb_cursor.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_dtype(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype={"category_id": "string"}) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_dtype_as_list(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['string']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype=["string"]) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) - rel = 
duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['double']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype=["double"]) res = rel.fetchone() print(res) - assert res == (1.0, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1.0, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_sep(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), sep=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), sep=" ") res = rel.fetchone() print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + assert res == ("1|Action|2006-02-15", datetime.time(4, 46, 27)) def test_delimiter(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), delimiter=" ") res = rel.fetchone() print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + assert res == ("1|Action|2006-02-15", datetime.time(4, 46, 27)) def test_delimiter_and_sep(self, duckdb_cursor): with pytest.raises(duckdb.InvalidInputException, match="read_csv takes either 'delimiter' or 'sep', not both"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ", sep=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), delimiter=" ", sep=" ") def test_header_true(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv')) + rel = duckdb_cursor.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) @pytest.mark.skip(reason="Issue #6011 needs to be fixed first, header=False doesn't work correctly") def test_header_false(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), header=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), header=False) def test_na_values(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values='Action') + rel = duckdb_cursor.read_csv(TestFile("category.csv"), na_values="Action") res = rel.fetchone() print(res) assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_na_values_list(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values=['Action', 'Animation']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), na_values=["Action", "Animation"]) res = rel.fetchone() assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) res = rel.fetchone() assert res == (2, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_skiprows(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), skiprows=1) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), skiprows=1) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) # We want to detect this at bind time def test_compression_wrong(self, duckdb_cursor): with pytest.raises(duckdb.Error, match="Input is not a GZIP stream"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), compression='gzip') + rel = duckdb_cursor.read_csv(TestFile("category.csv"), compression="gzip") def test_quotechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quotechar="", header=False) + rel = duckdb_cursor.read_csv(TestFile("unquote_without_delimiter.csv"), quotechar="", 
header=False) res = rel.fetchone() print(res) assert res == ('"AAA"BB',) def test_quote(self, duckdb_cursor): with pytest.raises( - duckdb.Error, match="The methods read_csv and read_csv_auto do not have the \"quote\" argument." + duckdb.Error, match='The methods read_csv and read_csv_auto do not have the "quote" argument.' ): - rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quote="", header=False) + rel = duckdb_cursor.read_csv(TestFile("unquote_without_delimiter.csv"), quote="", header=False) def test_escapechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), escapechar=";", header=False) + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), escapechar=";", header=False) res = rel.limit(1, 1).fetchone() print(res) - assert res == ('345', 'TEST6', '"text""2""text"') + assert res == ("345", "TEST6", '"text""2""text"') def test_encoding_wrong(self, duckdb_cursor): with pytest.raises( duckdb.BinderException, match="Copy is only supported for UTF-8 encoded files, ENCODING 'UTF-8'" ): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding=";") + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), encoding=";") def test_encoding_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding="UTF-8") + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), encoding="UTF-8") res = rel.limit(1, 1).fetchone() print(res) - assert res == (345, 'TEST6', 'text"2"text') + assert res == (345, "TEST6", 'text"2"text') def test_date_format_as_datetime(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv')) + rel = duckdb_cursor.read_csv(TestFile("datetime.csv")) res = rel.fetchone() print(res) assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_date_format_as_date(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), date_format='%Y-%m-%d') + rel = duckdb_cursor.read_csv(TestFile("datetime.csv"), date_format="%Y-%m-%d") res = rel.fetchone() print(res) assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_timestamp_format(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), timestamp_format='%Y-%m-%d %H:%M:%S') + rel = duckdb_cursor.read_csv(TestFile("datetime.csv"), timestamp_format="%Y-%m-%d %H:%M:%S") res = rel.fetchone() assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_sample_size_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('problematic.csv'), sample_size=-1) + rel = duckdb_cursor.read_csv(TestFile("problematic.csv"), sample_size=-1) res = rel.fetchone() print(res) - assert res == ('1', '1', '1') + assert res == ("1", "1", "1") def test_all_varchar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), all_varchar=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), all_varchar=True) res = rel.fetchone() print(res) - assert res == ('1', 'Action', '2006-02-15 04:46:27') + assert res == ("1", "Action", "2006-02-15 04:46:27") def test_null_padding(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb_cursor.read_csv(TestFile("nullpadding.csv"), null_padding=False, header=False) res = 
rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb_cursor.read_csv(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb.read_csv(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb.read_csv(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb_cursor.from_csv_auto(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb_cursor.from_csv_auto(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] def test_normalize_names(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), normalize_names=False) df = rel.df() column_names = list(df.columns.values) # The names are not normalized, so they are capitalized - assert 'CATEGORY_ID' in column_names + assert "CATEGORY_ID" in column_names - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), normalize_names=True) df = rel.df() column_names = list(df.columns.values) # The capitalized names are normalized to lowercase instead - assert 'CATEGORY_ID' not in column_names + assert "CATEGORY_ID" not in column_names def test_filename(self, 
duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), filename=False) df = rel.df() column_names = list(df.columns.values) # The filename is not included in the returned columns - assert 'filename' not in column_names + assert "filename" not in column_names - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), filename=True) df = rel.df() column_names = list(df.columns.values) # The filename is included in the returned columns - assert 'filename' in column_names + assert "filename" in column_names def test_read_pathlib_path(self, duckdb_cursor): pathlib = pytest.importorskip("pathlib") - path = pathlib.Path(TestFile('category.csv')) + path = pathlib.Path(TestFile("category.csv")) rel = duckdb_cursor.read_csv(path) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_read_filelike(self, duckdb_cursor): pytest.importorskip("fsspec") string = StringIO("c1,c2,c3\na,b,c") res = duckdb_cursor.read_csv(string).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_read_filelike_rel_out_of_scope(self, duckdb_cursor): _ = pytest.importorskip("fsspec") @@ -321,7 +321,7 @@ def test_filelike_bytesio(self, duckdb_cursor): _ = pytest.importorskip("fsspec") string = BytesIO(b"c1,c2,c3\na,b,c") res = duckdb_cursor.read_csv(string).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_filelike_exception(self, duckdb_cursor): _ = pytest.importorskip("fsspec") @@ -341,7 +341,7 @@ def __init__(self) -> None: pass def read(self, amount=-1): - return b'test' + return b"test" def seek(self, loc): raise ValueError(loc) @@ -377,7 +377,7 @@ def read(self, amount=-1): obj = CustomIO() res = duckdb_cursor.read_csv(obj).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_filelike_non_readable(self, duckdb_cursor): _ = pytest.importorskip("fsspec") @@ -410,9 +410,9 @@ def scoped_objects(duckdb_cursor): rel1 = duckdb_cursor.read_csv(obj) assert rel1.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 1 @@ -421,9 +421,9 @@ def scoped_objects(duckdb_cursor): rel2 = duckdb_cursor.read_csv(obj) assert rel2.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 2 @@ -432,9 +432,9 @@ def scoped_objects(duckdb_cursor): rel3 = duckdb_cursor.read_csv(obj) assert rel3.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 3 @@ -448,24 +448,24 @@ def test_read_csv_glob(self, tmp_path, create_temp_csv): # Use the temporary file paths to read CSV files con = duckdb.connect() - rel = con.read_csv(f'{tmp_path}/file*.csv') + rel = con.read_csv(f"{tmp_path}/file*.csv") res = con.sql("select * from rel order by all").fetchall() assert res == [(1,), (2,), (3,), (4,), (5,), (6,)] @pytest.mark.xfail(condition=platform.system() == "Emscripten", reason="time zones not working") def test_read_csv_combined(self, duckdb_cursor): - CSV_FILE = TestFile('stress_test.csv') + CSV_FILE = TestFile("stress_test.csv") COLUMNS = { - 'result': 'VARCHAR', - 'table': 'BIGINT', - '_time': 'TIMESTAMPTZ', - '_measurement': 'VARCHAR', - 'bench_test': 'VARCHAR', - 'flight_id': 'VARCHAR', - 'flight_status': 'VARCHAR', - 
'log_level': 'VARCHAR', - 'sys_uuid': 'VARCHAR', - 'message': 'VARCHAR', + "result": "VARCHAR", + "table": "BIGINT", + "_time": "TIMESTAMPTZ", + "_measurement": "VARCHAR", + "bench_test": "VARCHAR", + "flight_id": "VARCHAR", + "flight_status": "VARCHAR", + "log_level": "VARCHAR", + "sys_uuid": "VARCHAR", + "message": "VARCHAR", } rel = duckdb.read_csv(CSV_FILE, skiprows=1, delimiter=",", quotechar='"', escapechar="\\", dtype=COLUMNS) @@ -483,39 +483,39 @@ def test_read_csv_combined(self, duckdb_cursor): def test_read_csv_names(self, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") con = duckdb.connect() - rel = con.read_csv(str(file), names=['a', 'b', 'c']) - assert rel.columns == ['a', 'b', 'c', 'four'] + rel = con.read_csv(str(file), names=["a", "b", "c"]) + assert rel.columns == ["a", "b", "c", "four"] with pytest.raises(duckdb.InvalidInputException, match="read_csv only accepts 'names' as a list of strings"): rel = con.read_csv(file, names=True) with pytest.raises(duckdb.InvalidInputException, match="not possible to detect the CSV Header"): - rel = con.read_csv(file, names=['a', 'b', 'c', 'd', 'e']) + rel = con.read_csv(file, names=["a", "b", "c", "d", "e"]) # Duplicates are not okay with pytest.raises(duckdb.BinderException, match="names must have unique values"): - rel = con.read_csv(file, names=['a', 'b', 'a', 'b']) - assert rel.columns == ['a', 'b', 'a', 'b'] + rel = con.read_csv(file, names=["a", "b", "a", "b"]) + assert rel.columns == ["a", "b", "a", "b"] def test_read_csv_names_mixed_with_dtypes(self, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") con = duckdb.connect() rel = con.read_csv( file, - names=['a', 'b', 'c'], + names=["a", "b", "c"], dtype={ - 'a': int, - 'b': bool, - 'c': str, + "a": int, + "b": bool, + "c": str, }, ) - assert rel.columns == ['a', 'b', 'c', 'four'] - assert rel.types == ['BIGINT', 'BOOLEAN', 'VARCHAR', 'BIGINT'] + assert rel.columns == ["a", "b", "c", "four"] + assert rel.types == ["BIGINT", "BOOLEAN", "VARCHAR", "BIGINT"] # dtypes and names dont match # FIXME: seems the order columns are named in this error is non-deterministic @@ -524,23 +524,23 @@ def test_read_csv_names_mixed_with_dtypes(self, tmp_path): with pytest.raises(duckdb.BinderException, match=expected_error): rel = con.read_csv( file, - names=['a', 'b', 'c'], + names=["a", "b", "c"], dtype={ - 'd': int, - 'e': bool, - 'f': str, + "d": int, + "e": bool, + "f": str, }, ) def test_read_csv_multi_file(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file1.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") file2 = tmp_path / "file2.csv" - file2.write_text('one,two,three,four\n5,6,7,8\n5,6,7,8\n5,6,7,8') + file2.write_text("one,two,three,four\n5,6,7,8\n5,6,7,8\n5,6,7,8") file3 = tmp_path / "file3.csv" - file3.write_text('one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12') + file3.write_text("one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12") con = duckdb.connect() files = [str(file1), str(file2), str(file3)] @@ -562,72 +562,72 @@ def test_read_csv_empty_list(self): con = duckdb.connect() files = [] with pytest.raises( - duckdb.InvalidInputException, match='Please provide a non-empty list of paths or file-like objects' + duckdb.InvalidInputException, 
match="Please provide a non-empty list of paths or file-like objects" ): rel = con.read_csv(files) res = rel.fetchall() def test_read_auto_detect(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4') + file1.write_text("one|two|three|four\n1|2|3|4") con = duckdb.connect() - rel = con.read_csv(str(file1), columns={'a': 'VARCHAR'}, auto_detect=False, header=False) - assert rel.fetchall() == [('one|two|three|four',), ('1|2|3|4',)] + rel = con.read_csv(str(file1), columns={"a": "VARCHAR"}, auto_detect=False, header=False) + assert rel.fetchall() == [("one|two|three|four",), ("1|2|3|4",)] def test_read_csv_list_invalid_path(self, tmp_path): con = duckdb.connect() file1 = tmp_path / "file1.csv" - file1.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file1.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") file3 = tmp_path / "file3.csv" - file3.write_text('one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12') + file3.write_text("one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12") - files = [str(file1), 'not_valid_path', str(file3)] + files = [str(file1), "not_valid_path", str(file3)] with pytest.raises(duckdb.IOException, match='No files found that match the pattern "not_valid_path"'): rel = con.read_csv(files) res = rel.fetchall() @pytest.mark.parametrize( - 'options', + "options", [ - {'lineterminator': '\\n'}, - {'lineterminator': 'LINE_FEED'}, - {'lineterminator': CSVLineTerminator.LINE_FEED}, - {'columns': {'id': 'INTEGER', 'name': 'INTEGER', 'c': 'integer', 'd': 'INTEGER'}}, - {'auto_type_candidates': ['INTEGER', 'INTEGER']}, - {'max_line_size': 10000}, - {'ignore_errors': True}, - {'ignore_errors': False}, - {'store_rejects': True}, - {'store_rejects': False}, - {'rejects_table': 'my_rejects_table'}, - {'rejects_scan': 'my_rejects_scan'}, - {'rejects_table': 'my_rejects_table', 'rejects_limit': 50}, - {'force_not_null': ['one', 'two']}, - {'buffer_size': 2097153}, - {'decimal': '.'}, - {'allow_quoted_nulls': True}, - {'allow_quoted_nulls': False}, - {'filename': True}, - {'filename': 'test'}, - {'hive_partitioning': True}, - {'hive_partitioning': False}, - {'union_by_name': True}, - {'union_by_name': False}, - {'hive_types_autocast': False}, - {'hive_types_autocast': True}, - {'hive_types': {'one': 'INTEGER', 'two': 'VARCHAR'}}, + {"lineterminator": "\\n"}, + {"lineterminator": "LINE_FEED"}, + {"lineterminator": CSVLineTerminator.LINE_FEED}, + {"columns": {"id": "INTEGER", "name": "INTEGER", "c": "integer", "d": "INTEGER"}}, + {"auto_type_candidates": ["INTEGER", "INTEGER"]}, + {"max_line_size": 10000}, + {"ignore_errors": True}, + {"ignore_errors": False}, + {"store_rejects": True}, + {"store_rejects": False}, + {"rejects_table": "my_rejects_table"}, + {"rejects_scan": "my_rejects_scan"}, + {"rejects_table": "my_rejects_table", "rejects_limit": 50}, + {"force_not_null": ["one", "two"]}, + {"buffer_size": 2097153}, + {"decimal": "."}, + {"allow_quoted_nulls": True}, + {"allow_quoted_nulls": False}, + {"filename": True}, + {"filename": "test"}, + {"hive_partitioning": True}, + {"hive_partitioning": False}, + {"union_by_name": True}, + {"union_by_name": False}, + {"hive_types_autocast": False}, + {"hive_types_autocast": True}, + {"hive_types": {"one": "INTEGER", "two": "VARCHAR"}}, ], ) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Skipping on Windows because of lineterminator option") def test_read_csv_options(self, duckdb_cursor, options, tmp_path): file = tmp_path / "file.csv" - 
file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") print(options) - if 'hive_types' in options: - with pytest.raises(duckdb.InvalidInputException, match=r'Unknown hive_type:'): + if "hive_types" in options: + with pytest.raises(duckdb.InvalidInputException, match=r"Unknown hive_type:"): rel = duckdb_cursor.read_csv(file, **options) else: rel = duckdb_cursor.read_csv(file, **options) @@ -635,73 +635,73 @@ def test_read_csv_options(self, duckdb_cursor, options, tmp_path): def test_read_comment(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4#|5|6\n#bla\n1|2|3|4\n') + file1.write_text("one|two|three|four\n1|2|3|4#|5|6\n#bla\n1|2|3|4\n") con = duckdb.connect() - rel = con.read_csv(str(file1), columns={'a': 'VARCHAR'}, auto_detect=False, header=False, comment='#') - assert rel.fetchall() == [('one|two|three|four',), ('1|2|3|4',), ('1|2|3|4',)] + rel = con.read_csv(str(file1), columns={"a": "VARCHAR"}, auto_detect=False, header=False, comment="#") + assert rel.fetchall() == [("one|two|three|four",), ("1|2|3|4",), ("1|2|3|4",)] def test_read_enum(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('feelings\nhappy\nsad\nangry\nhappy\n') + file1.write_text("feelings\nhappy\nsad\nangry\nhappy\n") con = duckdb.connect() con.execute("CREATE TYPE mood AS ENUM ('happy', 'sad', 'angry')") - rel = con.read_csv(str(file1), dtype=['mood']) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), dtype=["mood"]) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] - rel = con.read_csv(str(file1), dtype={'feelings': 'mood'}) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), dtype={"feelings": "mood"}) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] - rel = con.read_csv(str(file1), columns={'feelings': 'mood'}) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), columns={"feelings": "mood"}) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), columns={'feelings': 'mood_2'}) + rel = con.read_csv(str(file1), columns={"feelings": "mood_2"}) with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), dtype={'feelings': 'mood_2'}) + rel = con.read_csv(str(file1), dtype={"feelings": "mood_2"}) with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), dtype=['mood_2']) + rel = con.read_csv(str(file1), dtype=["mood_2"]) def test_strict_mode(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4\n1|2|3|4|5\n1|2|3|4\n') + file1.write_text("one|two|three|four\n1|2|3|4\n1|2|3|4|5\n1|2|3|4\n") con = duckdb.connect() with pytest.raises(duckdb.InvalidInputException, match="CSV Error on Line"): rel = con.read_csv( str(file1), header=True, - delimiter='|', - columns={'a': 'INTEGER', 'b': 'INTEGER', 'c': 'INTEGER', 'd': 'INTEGER'}, + delimiter="|", + columns={"a": "INTEGER", "b": "INTEGER", "c": "INTEGER", "d": "INTEGER"}, auto_detect=False, ) rel.fetchall() rel = con.read_csv( str(file1), header=True, - delimiter='|', + delimiter="|", strict_mode=False, - 
columns={'a': 'INTEGER', 'b': 'INTEGER', 'c': 'INTEGER', 'd': 'INTEGER'}, + columns={"a": "INTEGER", "b": "INTEGER", "c": "INTEGER", "d": "INTEGER"}, auto_detect=False, ) assert rel.fetchall() == [(1, 2, 3, 4), (1, 2, 3, 4), (1, 2, 3, 4)] def test_union_by_name(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4') + file1.write_text("one|two|three|four\n1|2|3|4") file1 = tmp_path / "file2.csv" - file1.write_text('two|three|four|five\n2|3|4|5') + file1.write_text("two|three|four|five\n2|3|4|5") con = duckdb.connect() file_path = tmp_path / "file*.csv" rel = con.read_csv(file_path, union_by_name=True) - assert rel.columns == ['one', 'two', 'three', 'four', 'five'] + assert rel.columns == ["one", "two", "three", "four", "five"] assert rel.fetchall() == [(1, 2, 3, 4, None), (None, 2, 3, 4, 5)] def test_thousands_separator(self, tmp_path): @@ -709,27 +709,27 @@ def test_thousands_separator(self, tmp_path): file.write_text('money\n"10,000.23"\n"1,000,000,000.01"') con = duckdb.connect() - rel = con.read_csv(file, thousands=',') + rel = con.read_csv(file, thousands=",") assert rel.fetchall() == [(10000.23,), (1000000000.01,)] with pytest.raises( duckdb.BinderException, match="Unsupported parameter for THOUSANDS: should be max one character" ): - con.read_csv(file, thousands=',,,') + con.read_csv(file, thousands=",,,") def test_skip_comment_option(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('skip this line\n# comment\nx,y,z\n1,2,3\n4,5,6') + file1.write_text("skip this line\n# comment\nx,y,z\n1,2,3\n4,5,6") con = duckdb.connect() - rel = con.read_csv(file1, comment='#', skiprows=1, all_varchar=True) - assert rel.columns == ['x', 'y', 'z'] - assert rel.fetchall() == [('1', '2', '3'), ('4', '5', '6')] + rel = con.read_csv(file1, comment="#", skiprows=1, all_varchar=True) + assert rel.columns == ["x", "y", "z"] + assert rel.fetchall() == [("1", "2", "3"), ("4", "5", "6")] def test_files_to_sniff_option(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('bar,baz\n2025-05-12,baz') + file1.write_text("bar,baz\n2025-05-12,baz") file2 = tmp_path / "file2.csv" - file2.write_text('bar,baz\nbar,baz') + file2.write_text("bar,baz\nbar,baz") file_path = tmp_path / "file*.csv" con = duckdb.connect() @@ -737,4 +737,4 @@ def test_files_to_sniff_option(self, tmp_path): rel = con.read_csv(file_path, files_to_sniff=1) rel.fetchall() rel = con.read_csv(file_path, files_to_sniff=-1) - assert rel.fetchall() == [('2025-05-12', 'baz'), ('bar', 'baz')] + assert rel.fetchall() == [("2025-05-12", "baz"), ("bar", "baz")] diff --git a/tests/fast/api/test_relation_to_view.py b/tests/fast/api/test_relation_to_view.py index f4a43d54..31a19d54 100644 --- a/tests/fast/api/test_relation_to_view.py +++ b/tests/fast/api/test_relation_to_view.py @@ -4,27 +4,27 @@ class TestRelationToView(object): def test_values_to_view(self, duckdb_cursor): - rel = duckdb_cursor.values(['test', 'this is a long string']) + rel = duckdb_cursor.values(["test", "this is a long string"]) res = rel.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] - rel.to_view('vw1') + rel.to_view("vw1") - view = duckdb_cursor.table('vw1') + view = duckdb_cursor.table("vw1") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] def test_relation_to_view(self, duckdb_cursor): rel = duckdb_cursor.sql("select 'test', 'this is a long string'") res = 
rel.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] - rel.to_view('vw1') + rel.to_view("vw1") - view = duckdb_cursor.table('vw1') + view = duckdb_cursor.table("vw1") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] def test_registered_relation(self, duckdb_cursor): rel = duckdb_cursor.sql("select 'test', 'this is a long string'") @@ -33,12 +33,12 @@ def test_registered_relation(self, duckdb_cursor): # Register on a different connection is not allowed with pytest.raises( duckdb.InvalidInputException, - match='was created by another Connection and can therefore not be used by this Connection', + match="was created by another Connection and can therefore not be used by this Connection", ): - con.register('cross_connection', rel) + con.register("cross_connection", rel) # Register on the same connection just creates a view - duckdb_cursor.register('same_connection', rel) - view = duckdb_cursor.table('same_connection') + duckdb_cursor.register("same_connection", rel) + view = duckdb_cursor.table("same_connection") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] diff --git a/tests/fast/api/test_streaming_result.py b/tests/fast/api/test_streaming_result.py index e51f62e4..739fd17a 100644 --- a/tests/fast/api/test_streaming_result.py +++ b/tests/fast/api/test_streaming_result.py @@ -5,7 +5,7 @@ class TestStreamingResult(object): def test_fetch_one(self, duckdb_cursor): # fetch one - res = duckdb_cursor.sql('SELECT * FROM range(100000)') + res = duckdb_cursor.sql("SELECT * FROM range(100000)") result = [] while len(result) < 5000: tpl = res.fetchone() @@ -24,7 +24,7 @@ def test_fetch_one(self, duckdb_cursor): def test_fetch_many(self, duckdb_cursor): # fetch many - res = duckdb_cursor.sql('SELECT * FROM range(100000)') + res = duckdb_cursor.sql("SELECT * FROM range(100000)") result = [] while len(result) < 5000: tpl = res.fetchmany(10) @@ -45,11 +45,11 @@ def test_record_batch_reader(self, duckdb_cursor): pytest.importorskip("pyarrow") pytest.importorskip("pyarrow.dataset") # record batch reader - res = duckdb_cursor.sql('SELECT * FROM range(100000) t(i)') + res = duckdb_cursor.sql("SELECT * FROM range(100000) t(i)") reader = res.fetch_arrow_reader(batch_size=16_384) result = [] for batch in reader: - result += batch.to_pydict()['i'] + result += batch.to_pydict()["i"] assert result == list(range(100000)) # record batch reader with error @@ -60,9 +60,9 @@ def test_record_batch_reader(self, duckdb_cursor): reader = res.fetch_arrow_reader(batch_size=16_384) def test_9801(self, duckdb_cursor): - duckdb_cursor.execute('CREATE TABLE test(id INTEGER , name VARCHAR NOT NULL);') + duckdb_cursor.execute("CREATE TABLE test(id INTEGER , name VARCHAR NOT NULL);") - words = ['aaaaaaaaaaaaaaaaaaaaaaa', 'bbbb', 'ccccccccc', 'ííííííííí'] + words = ["aaaaaaaaaaaaaaaaaaaaaaa", "bbbb", "ccccccccc", "ííííííííí"] lines = [(i, words[i % 4]) for i in range(1000)] duckdb_cursor.executemany("INSERT INTO TEST (id, name) VALUES (?, ?)", lines) diff --git a/tests/fast/api/test_to_csv.py b/tests/fast/api/test_to_csv.py index e48ae1b8..5f8000a9 100644 --- a/tests/fast/api/test_to_csv.py +++ b/tests/fast/api/test_to_csv.py @@ -9,10 +9,10 @@ class TestToCSV(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def 
test_basic_to_csv(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -20,21 +20,21 @@ def test_basic_to_csv(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_sep(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, sep=',') + rel.to_csv(temp_file_name, sep=",") - csv_rel = duckdb.read_csv(temp_file_name, sep=',') + csv_rel = duckdb.read_csv(temp_file_name, sep=",") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_na_rep(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, na_rep="test") @@ -42,10 +42,10 @@ def test_to_csv_na_rep(self, pandas): csv_rel = duckdb.read_csv(temp_file_name, na_values="test") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_header(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -53,18 +53,18 @@ def test_to_csv_header(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quotechar(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, quotechar='\'', sep=',') + rel.to_csv(temp_file_name, quotechar="'", sep=",") - csv_rel = duckdb.read_csv(temp_file_name, sep=',', quotechar='\'') + csv_rel = duckdb.read_csv(temp_file_name, sep=",", quotechar="'") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_escapechar(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( @@ -76,11 +76,11 @@ def 
test_to_csv_escapechar(self, pandas): } ) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, quotechar='"', escapechar='!') - csv_rel = duckdb.read_csv(temp_file_name, quotechar='"', escapechar='!') + rel.to_csv(temp_file_name, quotechar='"', escapechar="!") + csv_rel = duckdb.read_csv(temp_file_name, quotechar='"', escapechar="!") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_date_format(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame(getTimeSeriesData()) @@ -93,82 +93,82 @@ def test_to_csv_date_format(self, pandas): assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_timestamp_format(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, timestamp_format='%m/%d/%Y') + rel.to_csv(temp_file_name, timestamp_format="%m/%d/%Y") - csv_rel = duckdb.read_csv(temp_file_name, timestamp_format='%m/%d/%Y') + csv_rel = duckdb.read_csv(temp_file_name, timestamp_format="%m/%d/%Y") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_off(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=None) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_on(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting="force") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_quote_all(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=csv.QUOTE_ALL) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), 
ArrowPandas()]) def test_to_csv_encoding_incorrect(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) with pytest.raises( duckdb.InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8" ): rel.to_csv(temp_file_name, encoding="nope") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_encoding_correct(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, encoding="UTF-8") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_compression_gzip(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, compression="gzip") csv_rel = duckdb.read_csv(temp_file_name, compression="gzip") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_partition(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -178,23 +178,23 @@ def test_to_csv_partition(self, pandas): rel = duckdb.from_df(df) rel.to_csv(temp_file_name, header=True, partition_by=["c_category"]) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);""" ) expected = [ - (True, 1.0, 42.0, 'a', 'a'), - (False, 3.2, None, 'b,c', 'a'), - (True, 3.0, 123.0, 'e', 'b'), - (True, 4.0, 321.0, 'f', 'b'), + (True, 1.0, 42.0, "a", "a"), + (False, 3.2, None, "b,c", "a"), + (True, 3.0, 123.0, "e", "b"), + (True, 4.0, 321.0, "f", "b"), ] assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_partition_with_columns_written(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -205,17 +205,17 @@ def test_to_csv_partition_with_columns_written(self, pandas): res = duckdb.sql("FROM rel order by all") rel.to_csv(temp_file_name, header=True, partition_by=["c_category"], write_partition_columns=True) csv_rel = 
duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;""" ) assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -226,24 +226,24 @@ def test_to_csv_overwrite(self, pandas): rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"]) # csv to be overwritten rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"], overwrite=True) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);""" ) # When partition columns are read from directory names, column order become different from original expected = [ - ('c', True, 1.0, 42.0, 'a', 'a'), - ('c', False, 3.2, None, 'b,c', 'a'), - ('d', True, 3.0, 123.0, 'e', 'b'), - ('d', True, 4.0, 321.0, 'f', 'b'), + ("c", True, 1.0, 42.0, "a", "a"), + ("c", False, 3.2, None, "b,c", "a"), + ("d", True, 3.0, 123.0, "e", "b"), + ("d", True, 4.0, 321.0, "f", "b"), ] assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite_with_columns_written(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -258,18 +258,18 @@ def test_to_csv_overwrite_with_columns_written(self, pandas): temp_file_name, header=True, partition_by=["c_category_1"], overwrite=True, write_partition_columns=True ) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;""" ) res = duckdb.sql("FROM rel order by all") assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite_not_enabled(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -281,14 +281,14 @@ def test_to_csv_overwrite_not_enabled(self, pandas): with pytest.raises(duckdb.IOException, 
match="OVERWRITE"): rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"]) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_per_thread_output(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) num_threads = duckdb.sql("select current_setting('threads')").fetchone()[0] - print('num_threads:', num_threads) + print("num_threads:", num_threads) df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -297,16 +297,16 @@ def test_to_csv_per_thread_output(self, pandas): ) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, header=True, per_thread_output=True) - csv_rel = duckdb.read_csv(f'{temp_file_name}/*.csv', header=True) + csv_rel = duckdb.read_csv(f"{temp_file_name}/*.csv", header=True) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_use_tmp_file(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], diff --git a/tests/fast/api/test_to_parquet.py b/tests/fast/api/test_to_parquet.py index d778aba3..c13ac011 100644 --- a/tests/fast/api/test_to_parquet.py +++ b/tests/fast/api/test_to_parquet.py @@ -13,7 +13,7 @@ class TestToParquet(object): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_basic_to_parquet(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + df = pd.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name) @@ -24,7 +24,7 @@ def test_basic_to_parquet(self, pd): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_compression_gzip(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, compression="gzip") csv_rel = duckdb.read_parquet(temp_file_name, compression="gzip") @@ -32,37 +32,32 @@ def test_compression_gzip(self, pd): def test_field_ids_auto(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - rel = duckdb.sql('''SELECT {i: 128} AS my_struct''') - rel.to_parquet(temp_file_name, field_ids='auto') + rel = duckdb.sql("""SELECT {i: 128} AS my_struct""") + rel.to_parquet(temp_file_name, field_ids="auto") parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() def test_field_ids(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - rel = duckdb.sql('''SELECT 1 as i, {j: 128} AS my_struct''') - rel.to_parquet(temp_file_name, field_ids=dict(i=42, my_struct={'__duckdb_field_id': 43, 'j': 44})) + 
rel = duckdb.sql("""SELECT 1 as i, {j: 128} AS my_struct""") + rel.to_parquet(temp_file_name, field_ids=dict(i=42, my_struct={"__duckdb_field_id": 43, "j": 44})) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() - assert ( - [('duckdb_schema', None), ('i', 42), ('my_struct', 43), ('j', 44)] - == duckdb.sql( - f''' + assert [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] == duckdb.sql( + f""" select name,field_id from parquet_schema('{temp_file_name}') - ''' - ) - .execute() - .fetchall() - ) + """ + ).execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('row_group_size_bytes', [122880 * 1024, '2MB']) + @pytest.mark.parametrize("row_group_size_bytes", [122880 * 1024, "2MB"]) def test_row_group_size_bytes(self, pd, row_group_size_bytes): con = duckdb.connect() con.execute("SET preserve_insertion_order=false;") temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = con.from_df(df) rel.to_parquet(temp_file_name, row_group_size_bytes=row_group_size_bytes) parquet_rel = con.read_parquet(temp_file_name) @@ -71,21 +66,21 @@ def test_row_group_size_bytes(self, pd, row_group_size_bytes): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_row_group_size(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, row_group_size=122880) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('write_columns', [None, True, False]) + @pytest.mark.parametrize("write_columns", [None, True, False]) def test_partition(self, pd, write_columns): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -95,14 +90,14 @@ def test_partition(self, pd, write_columns): assert result.execute().fetchall() == expected @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('write_columns', [None, True, False]) + @pytest.mark.parametrize("write_columns", [None, True, False]) def test_overwrite(self, pd, write_columns): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -120,7 +115,7 @@ def test_use_tmp_file(self, pd): { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -133,17 +128,17 @@ def test_use_tmp_file(self, pd): def test_per_thread_output(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) num_threads = duckdb.sql("select 
current_setting('threads')").fetchone()[0] - print('threads:', num_threads) + print("threads:", num_threads) df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, per_thread_output=True) - result = duckdb.read_parquet(f'{temp_file_name}/*.parquet') + result = duckdb.read_parquet(f"{temp_file_name}/*.parquet") assert rel.execute().fetchall() == result.execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @@ -153,27 +148,27 @@ def test_append(self, pd): { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) - rel.to_parquet(temp_file_name, partition_by=['category']) + rel.to_parquet(temp_file_name, partition_by=["category"]) df_to_append = pd.DataFrame( { "name": ["random"], "float": [420], - "category": ['a'], + "category": ["a"], } ) rel_to_append = duckdb.from_df(df_to_append) - rel_to_append.to_parquet(temp_file_name, partition_by=['category'], append=True) + rel_to_append.to_parquet(temp_file_name, partition_by=["category"], append=True) result = duckdb.sql(f"FROM read_parquet('{temp_file_name}/*/*.parquet', hive_partitioning=TRUE) ORDER BY name") result.show() expected = [ - ('asuka', 23.0, 'b'), - ('kaworu', 340.0, 'c'), - ('random', 420.0, 'a'), - ('rei', 321.0, 'a'), - ('shinji', 123.0, 'a'), + ("asuka", 23.0, "b"), + ("kaworu", 340.0, "c"), + ("random", 420.0, "a"), + ("rei", 321.0, "a"), + ("shinji", 123.0, "a"), ] assert result.execute().fetchall() == expected diff --git a/tests/fast/api/test_with_propagating_exceptions.py b/tests/fast/api/test_with_propagating_exceptions.py index e9cfb3c0..8613d6f4 100644 --- a/tests/fast/api/test_with_propagating_exceptions.py +++ b/tests/fast/api/test_with_propagating_exceptions.py @@ -7,12 +7,12 @@ def test_with(self): # Should propagate exception raised in the 'with duckdb.connect() ..' 
with pytest.raises(duckdb.ParserException, match="syntax error at or near *"): with duckdb.connect() as con: - print('before') - con.execute('invalid') - print('after') + print("before") + con.execute("invalid") + print("after") # Does not raise an exception with duckdb.connect() as con: - print('before') - con.execute('select 1') - print('after') + print("before") + con.execute("select 1") + print("after") diff --git a/tests/fast/arrow/parquet_write_roundtrip.py b/tests/fast/arrow/parquet_write_roundtrip.py index 093040c0..5dbf3949 100644 --- a/tests/fast/arrow/parquet_write_roundtrip.py +++ b/tests/fast/arrow/parquet_write_roundtrip.py @@ -17,13 +17,13 @@ def parquet_types_test(type_list): sql_type = type_pair[2] add_cast = len(type_pair) > 3 and type_pair[3] add_sql_cast = len(type_pair) > 4 and type_pair[4] - df = pandas.DataFrame.from_dict({'val': numpy.array(value_list, dtype=numpy_type)}) + df = pandas.DataFrame.from_dict({"val": numpy.array(value_list, dtype=numpy_type)}) duckdb_cursor = duckdb.connect() duckdb_cursor.execute(f"CREATE TABLE tmp AS SELECT val::{sql_type} val FROM df") duckdb_cursor.execute(f"COPY tmp TO '{temp_name}' (FORMAT PARQUET)") read_df = pandas.read_parquet(temp_name) if add_cast: - read_df['val'] = read_df['val'].astype(numpy_type) + read_df["val"] = read_df["val"].astype(numpy_type) assert df.equals(read_df) read_from_duckdb = duckdb_cursor.execute(f"SELECT * FROM parquet_scan('{temp_name}')").df() @@ -40,16 +40,16 @@ def parquet_types_test(type_list): class TestParquetRoundtrip(object): def test_roundtrip_numeric(self, duckdb_cursor): type_list = [ - ([-(2**7), 0, 2**7 - 1], numpy.int8, 'TINYINT'), - ([-(2**15), 0, 2**15 - 1], numpy.int16, 'SMALLINT'), - ([-(2**31), 0, 2**31 - 1], numpy.int32, 'INTEGER'), - ([-(2**63), 0, 2**63 - 1], numpy.int64, 'BIGINT'), - ([0, 42, 2**8 - 1], numpy.uint8, 'UTINYINT'), - ([0, 42, 2**16 - 1], numpy.uint16, 'USMALLINT'), - ([0, 42, 2**32 - 1], numpy.uint32, 'UINTEGER', False, True), - ([0, 42, 2**64 - 1], numpy.uint64, 'UBIGINT'), - ([0, 0.5, -0.5], numpy.float32, 'REAL'), - ([0, 0.5, -0.5], numpy.float64, 'DOUBLE'), + ([-(2**7), 0, 2**7 - 1], numpy.int8, "TINYINT"), + ([-(2**15), 0, 2**15 - 1], numpy.int16, "SMALLINT"), + ([-(2**31), 0, 2**31 - 1], numpy.int32, "INTEGER"), + ([-(2**63), 0, 2**63 - 1], numpy.int64, "BIGINT"), + ([0, 42, 2**8 - 1], numpy.uint8, "UTINYINT"), + ([0, 42, 2**16 - 1], numpy.uint16, "USMALLINT"), + ([0, 42, 2**32 - 1], numpy.uint32, "UINTEGER", False, True), + ([0, 42, 2**64 - 1], numpy.uint64, "UBIGINT"), + ([0, 0.5, -0.5], numpy.float32, "REAL"), + ([0, 0.5, -0.5], numpy.float64, "DOUBLE"), ] parquet_types_test(type_list) @@ -61,15 +61,15 @@ def test_roundtrip_timestamp(self, duckdb_cursor): datetime.datetime(1992, 7, 9, 7, 5, 33), ] type_list = [ - (date_time_list, 'datetime64[ns]', 'TIMESTAMP_NS'), - (date_time_list, 'datetime64[us]', 'TIMESTAMP'), - (date_time_list, 'datetime64[ms]', 'TIMESTAMP_MS'), - (date_time_list, 'datetime64[s]', 'TIMESTAMP_S'), - (date_time_list, 'datetime64[D]', 'DATE', True), + (date_time_list, "datetime64[ns]", "TIMESTAMP_NS"), + (date_time_list, "datetime64[us]", "TIMESTAMP"), + (date_time_list, "datetime64[ms]", "TIMESTAMP_MS"), + (date_time_list, "datetime64[s]", "TIMESTAMP_S"), + (date_time_list, "datetime64[D]", "DATE", True), ] parquet_types_test(type_list) def test_roundtrip_varchar(self, duckdb_cursor): - varchar_list = ['hello', 'this is a very long string', 'hello', None] - type_list = [(varchar_list, object, 'VARCHAR')] + varchar_list = ["hello", "this 
is a very long string", "hello", None] + type_list = [(varchar_list, object, "VARCHAR")] parquet_types_test(type_list) diff --git a/tests/fast/arrow/test_10795.py b/tests/fast/arrow/test_10795.py index 043bf4ff..5503e529 100644 --- a/tests/fast/arrow/test_10795.py +++ b/tests/fast/arrow/test_10795.py @@ -1,12 +1,12 @@ import duckdb import pytest -pyarrow = pytest.importorskip('pyarrow') +pyarrow = pytest.importorskip("pyarrow") -@pytest.mark.parametrize('arrow_large_buffer_size', [True, False]) +@pytest.mark.parametrize("arrow_large_buffer_size", [True, False]) def test_10795(arrow_large_buffer_size): conn = duckdb.connect() conn.sql(f"set arrow_large_buffer_size={arrow_large_buffer_size}") arrow = conn.sql("select map(['non-inlined string', 'test', 'duckdb'], [42, 1337, 123]) as map").to_arrow_table() - assert arrow.to_pydict() == {'map': [[('non-inlined string', 42), ('test', 1337), ('duckdb', 123)]]} + assert arrow.to_pydict() == {"map": [[("non-inlined string", 42), ("test", 1337), ("duckdb", 123)]]} diff --git a/tests/fast/arrow/test_12384.py b/tests/fast/arrow/test_12384.py index af9c8ed2..d2d4a7fc 100644 --- a/tests/fast/arrow/test_12384.py +++ b/tests/fast/arrow/test_12384.py @@ -2,17 +2,17 @@ import pytest import os -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") def test_10795(): - arrow_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'arrow_table') - with pa.memory_map(arrow_filename, 'r') as source: + arrow_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "arrow_table") + with pa.memory_map(arrow_filename, "r") as source: reader = pa.ipc.RecordBatchFileReader(source) taxi_fhvhv_arrow = reader.read_all() - con = duckdb.connect(database=':memory:') + con = duckdb.connect(database=":memory:") con.execute("SET TimeZone='UTC';") - con.register('taxi_fhvhv', taxi_fhvhv_arrow) + con.register("taxi_fhvhv", taxi_fhvhv_arrow) res = con.execute( "SELECT PULocationID, pickup_datetime FROM taxi_fhvhv WHERE pickup_datetime >= '2023-01-01T00:00:00-05:00' AND PULocationID = 244" ).fetchall() diff --git a/tests/fast/arrow/test_14344.py b/tests/fast/arrow/test_14344.py index 522228c0..86f8728b 100644 --- a/tests/fast/arrow/test_14344.py +++ b/tests/fast/arrow/test_14344.py @@ -22,4 +22,4 @@ def test_14344(duckdb_cursor): USING (foo) """ ).fetchall() - assert res == [('123',)] + assert res == [("123",)] diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index cdef8da7..6d760500 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -22,15 +22,15 @@ def test_2426(self, duckdb_cursor): con.execute("Insert Into test values ('" + str(i) + "')") con.execute("Insert Into test values ('5000')") con.execute("Insert Into test values ('6000')") - sql = ''' + sql = """ SELECT a, COUNT(*) AS repetitions FROM test GROUP BY a - ''' + """ result_df = con.execute(sql).df() arrow_table = con.execute(sql).fetch_arrow_table() arrow_df = arrow_table.to_pandas() - assert result_df['repetitions'].sum() == arrow_df['repetitions'].sum() + assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() diff --git a/tests/fast/arrow/test_5547.py b/tests/fast/arrow/test_5547.py index b27b29b2..eb77ab83 100644 --- a/tests/fast/arrow/test_5547.py +++ b/tests/fast/arrow/test_5547.py @@ -3,7 +3,7 @@ from pandas.testing import assert_frame_equal import pytest -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") def test_5547(): diff --git 
a/tests/fast/arrow/test_6584.py b/tests/fast/arrow/test_6584.py index 9a6241f9..6f96bf2d 100644 --- a/tests/fast/arrow/test_6584.py +++ b/tests/fast/arrow/test_6584.py @@ -2,7 +2,7 @@ import duckdb import pytest -pyarrow = pytest.importorskip('pyarrow') +pyarrow = pytest.importorskip("pyarrow") def f(cur, i, data): diff --git a/tests/fast/arrow/test_6796.py b/tests/fast/arrow/test_6796.py index 6690f22c..ef464f49 100644 --- a/tests/fast/arrow/test_6796.py +++ b/tests/fast/arrow/test_6796.py @@ -2,10 +2,10 @@ import pytest from conftest import NumpyPandas, ArrowPandas -pyarrow = pytest.importorskip('pyarrow') +pyarrow = pytest.importorskip("pyarrow") -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_6796(pandas): conn = duckdb.connect() input_df = pandas.DataFrame({"foo": ["bar"]}) diff --git a/tests/fast/arrow/test_7652.py b/tests/fast/arrow/test_7652.py index afe3b738..857d871d 100644 --- a/tests/fast/arrow/test_7652.py +++ b/tests/fast/arrow/test_7652.py @@ -9,7 +9,7 @@ class Test7652(object): def test_7652(self, duckdb_cursor): - temp_file_name = tempfile.NamedTemporaryFile(suffix='.parquet').name + temp_file_name = tempfile.NamedTemporaryFile(suffix=".parquet").name # Generate a list of values that aren't uniform in changes. generated_list = [1, 0, 2] @@ -17,7 +17,7 @@ def test_7652(self, duckdb_cursor): print(f"Min value: {min(generated_list)} max value: {max(generated_list)}") # Convert list of values to a PyArrow table with a single column. - fake_table = pa.Table.from_arrays([pa.array(generated_list, pa.int64())], names=['n0']) + fake_table = pa.Table.from_arrays([pa.array(generated_list, pa.int64())], names=["n0"]) # Write that column with DELTA_BINARY_PACKED encoding with pq.ParquetWriter( diff --git a/tests/fast/arrow/test_7699.py b/tests/fast/arrow/test_7699.py index c8c234ef..a4de66b9 100644 --- a/tests/fast/arrow/test_7699.py +++ b/tests/fast/arrow/test_7699.py @@ -22,4 +22,4 @@ def test_7699(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from df1234") res = rel.fetchall() - assert res == [('K',), ('L',), ('K',), ('L',), ('M',)] + assert res == [("K",), ("L",), ("K",), ("L",), ("M",)] diff --git a/tests/fast/arrow/test_arrow_batch_index.py b/tests/fast/arrow/test_arrow_batch_index.py index dadf6f89..a8dc2c7f 100644 --- a/tests/fast/arrow/test_arrow_batch_index.py +++ b/tests/fast/arrow/test_arrow_batch_index.py @@ -9,13 +9,13 @@ class TestArrowBatchIndex(object): def test_arrow_batch_index(self, duckdb_cursor): con = duckdb.connect() - df = con.execute('SELECT * FROM range(10000000) t(i)').df() + df = con.execute("SELECT * FROM range(10000000) t(i)").df() arrow_tbl = pa.Table.from_pandas(df) - con.execute('CREATE TABLE tbl AS SELECT * FROM arrow_tbl') + con.execute("CREATE TABLE tbl AS SELECT * FROM arrow_tbl") - result = con.execute('SELECT * FROM tbl LIMIT 5').fetchall() + result = con.execute("SELECT * FROM tbl LIMIT 5").fetchall() assert [x[0] for x in result] == [0, 1, 2, 3, 4] - result = con.execute('SELECT * FROM tbl LIMIT 5 OFFSET 777778').fetchall() + result = con.execute("SELECT * FROM tbl LIMIT 5 OFFSET 777778").fetchall() assert [x[0] for x in result] == [777778, 777779, 777780, 777781, 777782] diff --git a/tests/fast/arrow/test_arrow_binary_view.py b/tests/fast/arrow/test_arrow_binary_view.py index 7d9d0afc..31107f67 100644 --- a/tests/fast/arrow/test_arrow_binary_view.py +++ b/tests/fast/arrow/test_arrow_binary_view.py @@ -8,7 +8,7 @@ class 
TestArrowBinaryView(object): def test_arrow_binary_view(self, duckdb_cursor): con = duckdb.connect() tab = pa.table({"x": pa.array([b"abc", b"thisisaverybigbinaryyaymorethanfifteen", None], pa.binary_view())}) - assert con.execute("FROM tab").fetchall() == [(b'abc',), (b'thisisaverybigbinaryyaymorethanfifteen',), (None,)] + assert con.execute("FROM tab").fetchall() == [(b"abc",), (b"thisisaverybigbinaryyaymorethanfifteen",), (None,)] # By default we won't export a view assert not con.execute("FROM tab").fetch_arrow_table().equals(tab) # We do the binary view from 1.4 onwards @@ -16,5 +16,5 @@ def test_arrow_binary_view(self, duckdb_cursor): assert con.execute("FROM tab").fetch_arrow_table().equals(tab) assert con.execute("FROM tab where x = 'thisisaverybigbinaryyaymorethanfifteen'").fetchall() == [ - (b'thisisaverybigbinaryyaymorethanfifteen',) + (b"thisisaverybigbinaryyaymorethanfifteen",) ] diff --git a/tests/fast/arrow/test_arrow_case_sensitive.py b/tests/fast/arrow/test_arrow_case_sensitive.py index 6106cc75..ef60046a 100644 --- a/tests/fast/arrow/test_arrow_case_sensitive.py +++ b/tests/fast/arrow/test_arrow_case_sensitive.py @@ -7,18 +7,18 @@ class TestArrowCaseSensitive(object): def test_arrow_case_sensitive(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['A1', 'a1']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["A1", "a1"]) - duckdb_cursor.register('arrow_tbl', arrow_table) - assert duckdb_cursor.table("arrow_tbl").columns == ['A1', 'a1_1'] + duckdb_cursor.register("arrow_tbl", arrow_table) + assert duckdb_cursor.table("arrow_tbl").columns == ["A1", "a1_1"] assert duckdb_cursor.execute("select A1 from arrow_tbl;").fetchall() == [(1,)] assert duckdb_cursor.execute("select a1_1 from arrow_tbl;").fetchall() == [(1000,)] - assert arrow_table.column_names == ['A1', 'a1'] + assert arrow_table.column_names == ["A1", "a1"] def test_arrow_case_sensitive_repeated(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[1]], ['A1', 'a1_1', 'a1']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[1]], ["A1", "a1_1", "a1"]) - duckdb_cursor.register('arrow_tbl', arrow_table) - assert duckdb_cursor.table("arrow_tbl").columns == ['A1', 'a1_1', 'a1_2'] - assert arrow_table.column_names == ['A1', 'a1_1', 'a1'] + duckdb_cursor.register("arrow_tbl", arrow_table) + assert duckdb_cursor.table("arrow_tbl").columns == ["A1", "a1_1", "a1_2"] + assert arrow_table.column_names == ["A1", "a1_1", "a1"] diff --git a/tests/fast/arrow/test_arrow_decimal_32_64.py b/tests/fast/arrow/test_arrow_decimal_32_64.py index 4a960454..39b6e43a 100644 --- a/tests/fast/arrow/test_arrow_decimal_32_64.py +++ b/tests/fast/arrow/test_arrow_decimal_32_64.py @@ -8,7 +8,7 @@ class TestArrowDecimalTypes(object): def test_decimal_32(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute('SET arrow_output_version = 1.5') + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_32 = pa.Table.from_pylist( [ {"data": Decimal("100.20")}, @@ -20,10 +20,10 @@ def test_decimal_32(self, duckdb_cursor): ) # Test scan assert duckdb_cursor.execute("FROM decimal_32").fetchall() == [ - (Decimal('100.20'),), - (Decimal('110.21'),), - (Decimal('31.20'),), - (Decimal('500.20'),), + (Decimal("100.20"),), + (Decimal("110.21"),), + (Decimal("31.20"),), + (Decimal("500.20"),), ] # Test filter 
pushdown assert duckdb_cursor.execute("SELECT COUNT(*) FROM decimal_32 where data > 100 and data < 200 ").fetchall() == [ @@ -37,7 +37,7 @@ def test_decimal_32(self, duckdb_cursor): def test_decimal_64(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute('SET arrow_output_version = 1.5') + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_64 = pa.Table.from_pylist( [ {"data": Decimal("1000.231")}, @@ -50,10 +50,10 @@ def test_decimal_64(self, duckdb_cursor): # Test scan assert duckdb_cursor.execute("FROM decimal_64").fetchall() == [ - (Decimal('1000.231'),), - (Decimal('1100.231'),), - (Decimal('999999999999.231'),), - (Decimal('500.200'),), + (Decimal("1000.231"),), + (Decimal("1100.231"),), + (Decimal("999999999999.231"),), + (Decimal("500.200"),), ] # Test Filter pushdown diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index 95a2108a..43c995bb 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -5,11 +5,10 @@ from uuid import UUID import datetime -pa = pytest.importorskip('pyarrow', '18.0.0') +pa = pytest.importorskip("pyarrow", "18.0.0") class TestCanonicalExtensionTypes(object): - def test_uuid(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") @@ -17,9 +16,9 @@ def test_uuid(self): storage_array = pa.array([uuid.uuid4().bytes for _ in range(4)], pa.binary(16)) storage_array = pa.uuid().wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['uuid_col']) + arrow_table = pa.Table.from_arrays([storage_array], names=["uuid_col"]) - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() assert duck_arrow.equals(arrow_table) @@ -30,14 +29,14 @@ def test_uuid_from_duck(self): arrow_table = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() assert arrow_table.to_pylist() == [ - {'uuid': UUID('00000000-0000-0000-0000-000000000000')}, - {'uuid': UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')}, - {'uuid': None}, + {"uuid": UUID("00000000-0000-0000-0000-000000000000")}, + {"uuid": UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")}, + {"uuid": None}, ] assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -45,8 +44,8 @@ def test_uuid_from_duck(self): "select '00000000-0000-0000-0000-000000000100'::UUID as uuid" ).fetch_arrow_table() - assert arrow_table.to_pylist() == [{'uuid': UUID('00000000-0000-0000-0000-000000000100')}] - assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(UUID('00000000-0000-0000-0000-000000000100'),)] + assert arrow_table.to_pylist() == [{"uuid": UUID("00000000-0000-0000-0000-000000000100")}] + assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(UUID("00000000-0000-0000-0000-000000000100"),)] def test_json(self, duckdb_cursor): data = {"name": "Pedro", "age": 28, "car": "VW Fox"} @@ -56,10 +55,10 @@ def test_json(self, duckdb_cursor): storage_array = pa.array([json_string], pa.string()) - arrow_table = pa.Table.from_arrays([storage_array], names=['json_col']) + arrow_table = pa.Table.from_arrays([storage_array], names=["json_col"]) duckdb_cursor.execute("SET arrow_lossless_conversion = 
true") - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() assert duck_arrow.equals(arrow_table) @@ -70,8 +69,8 @@ def test_uuid_no_def(self): res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() res_duck = duckdb_cursor.execute("from res_arrow").fetchall() assert res_duck == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -79,15 +78,15 @@ def test_uuid_no_def_lossless(self): duckdb_cursor = duckdb.connect() res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() assert res_arrow.to_pylist() == [ - {'uuid': '00000000-0000-0000-0000-000000000000'}, - {'uuid': 'ffffffff-ffff-ffff-ffff-ffffffffffff'}, - {'uuid': None}, + {"uuid": "00000000-0000-0000-0000-000000000000"}, + {"uuid": "ffffffff-ffff-ffff-ffff-ffffffffffff"}, + {"uuid": None}, ] res_duck = duckdb_cursor.execute("from res_arrow").fetchall() assert res_duck == [ - ('00000000-0000-0000-0000-000000000000',), - ('ffffffff-ffff-ffff-ffff-ffffffffffff',), + ("00000000-0000-0000-0000-000000000000",), + ("ffffffff-ffff-ffff-ffff-ffffffffffff",), (None,), ] @@ -98,8 +97,8 @@ def test_uuid_no_def_stream(self): res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_record_batch() res_duck = duckdb.execute("from res_arrow").fetchall() assert res_duck == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -109,9 +108,9 @@ def test_function(x): return x con = duckdb.connect() - con.create_function('test', test_function, ['UUID'], 'UUID', type='arrow') + con.create_function("test", test_function, ["UUID"], "UUID", type="arrow") - rel = con.sql("select ? as x", params=[uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')]) + rel = con.sql("select ? 
as x", params=[uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")]) rel.project("test(x) from t").fetchall() def test_unimplemented_extension(self, duckdb_cursor): @@ -120,51 +119,51 @@ def __init__(self) -> None: pa.ExtensionType.__init__(self, pa.binary(5), "pedro.binary") def __arrow_ext_serialize__(self) -> bytes: - return b'' + return b"" @classmethod def __arrow_ext_deserialize__(cls, storage_type, serialized): return UuidTypeWrong() - storage_array = pa.array(['pedro'], pa.binary(5)) + storage_array = pa.array(["pedro"], pa.binary(5)) my_type = MyType() storage_array = my_type.wrap_array(storage_array) age_array = pa.array([29], pa.int32()) - arrow_table = pa.Table.from_arrays([storage_array, age_array], names=['pedro_pedro_pedro', 'age']) + arrow_table = pa.Table.from_arrays([storage_array, age_array], names=["pedro_pedro_pedro", "age"]) - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() - assert duckdb_cursor.execute('FROM duck_arrow').fetchall() == [(b'pedro', 29)] + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() + assert duckdb_cursor.execute("FROM duck_arrow").fetchall() == [(b"pedro", 29)] def test_hugeint(self): con = duckdb.connect() con.execute("SET arrow_lossless_conversion = true") - storage_array = pa.array([b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff'], pa.binary(16)) + storage_array = pa.array([b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"], pa.binary(16)) hugeint_type = pa.opaque(pa.binary(16), "hugeint", "DuckDB") storage_array = hugeint_type.wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['numbers']) + arrow_table = pa.Table.from_arrays([storage_array], names=["numbers"]) - assert con.execute('FROM arrow_table').fetchall() == [(-1,)] + assert con.execute("FROM arrow_table").fetchall() == [(-1,)] - assert con.execute('FROM arrow_table').fetch_arrow_table().equals(arrow_table) + assert con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) con.execute("SET arrow_lossless_conversion = false") - assert not con.execute('FROM arrow_table').fetch_arrow_table().equals(arrow_table) + assert not con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) def test_uhugeint(self, duckdb_cursor): - storage_array = pa.array([b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff'], pa.binary(16)) + storage_array = pa.array([b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"], pa.binary(16)) uhugeint_type = pa.opaque(pa.binary(16), "uhugeint", "DuckDB") storage_array = uhugeint_type.wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['numbers']) + arrow_table = pa.Table.from_arrays([storage_array], names=["numbers"]) - assert duckdb_cursor.execute('FROM arrow_table').fetchall() == [(340282366920938463463374607431768211455,)] + assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(340282366920938463463374607431768211455,)] def test_bit(self): con = duckdb.connect() @@ -176,18 +175,18 @@ def test_bit(self): res_bit = con.execute("SELECT '0101011'::BIT str FROM range(5) tbl(i)").fetch_arrow_table() assert con.execute("FROM res_blob").fetchall() == [ - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), + (b"\x01\xab",), + (b"\x01\xab",), + (b"\x01\xab",), + (b"\x01\xab",), + (b"\x01\xab",), ] assert con.execute("FROM res_bit").fetchall() == [ - ('0101011',), - ('0101011',), - ('0101011',), - ('0101011',), - 
('0101011',), + ("0101011",), + ("0101011",), + ("0101011",), + ("0101011",), + ("0101011",), ] def test_timetz(self): @@ -209,12 +208,12 @@ def test_bignum(self): res_bignum = con.execute( "SELECT '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368'::bignum a FROM range(1) tbl(i)" ).fetch_arrow_table() - assert res_bignum.column("a").type.type_name == 'bignum' + assert res_bignum.column("a").type.type_name == "bignum" assert res_bignum.column("a").type.vendor_name == "DuckDB" assert con.execute("FROM res_bignum").fetchall() == [ ( - '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368', + "179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368", ) ] @@ -235,9 +234,9 @@ def test_extension_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array( [ - b'\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', - b'\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', - b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', + b"\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", + b"\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", + b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", ], pa.binary(16), ) @@ -245,7 +244,7 @@ def test_extension_dictionary(self, duckdb_cursor): dictionary = uhugeint_type.wrap_array(dictionary) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [ (340282366920938463463374607431768211200,), @@ -263,13 +262,13 @@ def test_boolean(self): con.execute("SET arrow_lossless_conversion = true") storage_array = pa.array([-1, 0, 1, 2, None], pa.int8()) bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), storage_array) - arrow_table = pa.Table.from_arrays([bool8_array], names=['bool8']) - assert con.execute('FROM arrow_table').fetchall() == [(True,), (False,), (True,), (True,), (None,)] - result_table = con.execute('FROM arrow_table').fetch_arrow_table() + arrow_table = pa.Table.from_arrays([bool8_array], names=["bool8"]) + assert con.execute("FROM arrow_table").fetchall() == [(True,), (False,), (True,), (True,), (None,)] + result_table = con.execute("FROM arrow_table").fetch_arrow_table() res_storage_array = pa.array([1, 0, 1, 1, None], pa.int8()) res_bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), res_storage_array) - res_arrow_table = pa.Table.from_arrays([res_bool8_array], names=['bool8']) + res_arrow_table = pa.Table.from_arrays([res_bool8_array], names=["bool8"]) assert result_table.equals(res_arrow_table) @@ -279,7 +278,7 @@ def 
test_accept_malformed_complex_json(self, duckdb_cursor): pa.binary(), metadata={ "ARROW:extension:name": "foofyfoo", - "ARROW:extension:metadata": 'this is not valid json', + "ARROW:extension:metadata": "this is not valid json", }, ) schema = pa.schema([field]) @@ -296,7 +295,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): pa.binary(), metadata={ "ARROW:extension:name": "arrow.opaque", - "ARROW:extension:metadata": 'this is not valid json', + "ARROW:extension:metadata": "this is not valid json", }, ) schema = pa.schema([field]) @@ -337,9 +336,9 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): schema=schema, ) assert duckdb_cursor.sql("""DESCRIBE FROM bignum_table;""").fetchone() == ( - 'bignum_value', - 'BIGNUM', - 'YES', + "bignum_value", + "BIGNUM", + "YES", None, None, None, diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index 04a34595..a969da21 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -83,8 +83,8 @@ def test_to_arrow_chunk_size(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") - relation = duckdb_cursor.table('t') + relation = duckdb_cursor.table("t") arrow_tbl = relation.fetch_arrow_table() - assert arrow_tbl['a'].num_chunks == 1 + assert arrow_tbl["a"].num_chunks == 1 arrow_tbl = relation.fetch_arrow_table(2048) - assert arrow_tbl['a'].num_chunks == 2 + assert arrow_tbl["a"].num_chunks == 2 diff --git a/tests/fast/arrow/test_arrow_fetch_recordbatch.py b/tests/fast/arrow/test_arrow_fetch_recordbatch.py index 24d7c2c7..8915d886 100644 --- a/tests/fast/arrow/test_arrow_fetch_recordbatch.py +++ b/tests/fast/arrow/test_arrow_fetch_recordbatch.py @@ -1,7 +1,7 @@ import duckdb import pytest -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") class TestArrowFetchRecordBatch(object): @@ -12,7 +12,7 @@ def test_record_batch_next_batch_numeric(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -38,7 +38,7 @@ def test_record_batch_next_batch_bool(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -63,7 +63,7 @@ def test_record_batch_next_batch_varchar(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range::varchar a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -90,7 +90,7 @@ def test_record_batch_next_batch_struct(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + 
assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -115,7 +115,7 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select [i,i+1] as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -141,7 +141,7 @@ def test_record_batch_next_batch_map(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select map([i], [i+1]) as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -169,7 +169,7 @@ def test_record_batch_next_batch_with_null(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -224,15 +224,15 @@ def test_record_batch_next_batch_multiple_vectors_per_chunk_error(self, duckdb_c duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(5000);") query = duckdb_cursor.execute("SELECT a FROM t") - with pytest.raises(RuntimeError, match='Approximate Batch Size of Record Batch MUST be higher than 0'): + with pytest.raises(RuntimeError, match="Approximate Batch Size of Record Batch MUST be higher than 0"): record_batch_reader = query.fetch_record_batch(0) - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): record_batch_reader = query.fetch_record_batch(-1) def test_record_batch_reader_from_relation(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") - relation = duckdb_cursor.table('t') + relation = duckdb_cursor.table("t") record_batch_reader = relation.record_batch() chunk = record_batch_reader.read_next_batch() assert len(chunk) == 3000 @@ -249,7 +249,7 @@ def test_record_coverage(self, duckdb_cursor): def test_record_batch_query_error(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select 'foo' as a;") - with pytest.raises(duckdb.ConversionException, match='Conversion Error'): + with pytest.raises(duckdb.ConversionException, match="Conversion Error"): # 'execute' materializes the result, causing the error directly query = duckdb_cursor.execute("SELECT cast(a as double) FROM t") record_batch_reader = query.fetch_record_batch(1024) @@ -282,7 +282,7 @@ def test_many_chunk_sizes(self): record_batch_reader = query.fetch_record_batch(i) num_loops = int(object_size / i) for j in range(num_loops): - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == i remainder = object_size 
% i diff --git a/tests/fast/arrow/test_arrow_fixed_binary.py b/tests/fast/arrow/test_arrow_fixed_binary.py index aa0047a8..cec8d520 100644 --- a/tests/fast/arrow/test_arrow_fixed_binary.py +++ b/tests/fast/arrow/test_arrow_fixed_binary.py @@ -7,8 +7,8 @@ class TestArrowFixedBinary(object): def test_arrow_fixed_binary(self, duckdb_cursor): ids = [ None, - b'\x66\x4d\xf4\xae\xb1\x5c\xb0\x4a\xdd\x5d\x1d\x54', - b'\x66\x4d\xf4\xf0\xa3\xfc\xec\x5b\x26\x81\x4e\x1d', + b"\x66\x4d\xf4\xae\xb1\x5c\xb0\x4a\xdd\x5d\x1d\x54", + b"\x66\x4d\xf4\xf0\xa3\xfc\xec\x5b\x26\x81\x4e\x1d", ] id_array = pa.array(ids, type=pa.binary(12)) @@ -18,4 +18,4 @@ def test_arrow_fixed_binary(self, duckdb_cursor): SELECT lower(hex(id)) as id FROM arrow_table """ ).fetchall() - assert res == [(None,), ('664df4aeb15cb04add5d1d54',), ('664df4f0a3fcec5b26814e1d',)] + assert res == [(None,), ("664df4aeb15cb04add5d1d54",), ("664df4f0a3fcec5b26814e1d",)] diff --git a/tests/fast/arrow/test_arrow_ipc.py b/tests/fast/arrow/test_arrow_ipc.py index 1d71eaa4..24718bbc 100644 --- a/tests/fast/arrow/test_arrow_ipc.py +++ b/tests/fast/arrow/test_arrow_ipc.py @@ -1,14 +1,14 @@ import pytest import duckdb -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") -ipc = pytest.importorskip('pyarrow.ipc') +ipc = pytest.importorskip("pyarrow.ipc") def get_record_batch(): - data = [pa.array([1, 2, 3, 4]), pa.array(['foo', 'bar', 'baz', None]), pa.array([True, None, False, True])] - return pa.record_batch(data, names=['f0', 'f1', 'f2']) + data = [pa.array([1, 2, 3, 4]), pa.array(["foo", "bar", "baz", None]), pa.array([True, None, False, True])] + return pa.record_batch(data, names=["f0", "f1", "f2"]) class TestArrowIPCExtension(object): diff --git a/tests/fast/arrow/test_arrow_list.py b/tests/fast/arrow/test_arrow_list.py index 556f614a..47b8cb2a 100644 --- a/tests/fast/arrow/test_arrow_list.py +++ b/tests/fast/arrow/test_arrow_list.py @@ -21,7 +21,7 @@ def create_and_register_arrow_table(column_list, duckdb_cursor): def create_and_register_comparison_result(column_list, duckdb_cursor): - columns = ",".join([f'{name} {dtype}' for (name, dtype, _) in column_list]) + columns = ",".join([f"{name} {dtype}" for (name, dtype, _) in column_list]) column_amount = len(column_list) assert column_amount row_amount = len(column_list[0][2]) @@ -31,7 +31,7 @@ def create_and_register_comparison_result(column_list, duckdb_cursor): inserted_values.append(column_list[col][2][row]) inserted_values = tuple(inserted_values) - column_format = ",".join(['?' for _ in range(column_amount)]) + column_format = ",".join(["?" 
for _ in range(column_amount)]) row_format = ",".join([f"({column_format})" for _ in range(row_amount)]) query = f"""CREATE TABLE test ({columns}); INSERT INTO test VALUES {row_format}; @@ -73,7 +73,7 @@ def generate_list(child_size) -> ListGenerationResult: # Create a regular ListArray list_arr = pa.ListArray.from_arrays(offsets=offsets, values=input, mask=pa.array(mask, type=pa.bool_())) - if not hasattr(pa, 'ListViewArray'): + if not hasattr(pa, "ListViewArray"): return ListGenerationResult(list_arr, None) lists = list(reversed(lists)) @@ -102,13 +102,13 @@ def test_regular_list(self, duckdb_cursor): create_and_register_arrow_table( [ - ('a', list_type, data), + ("a", list_type, data), ], duckdb_cursor, ) create_and_register_comparison_result( [ - ('a', 'FLOAT[]', data), + ("a", "FLOAT[]", data), ], duckdb_cursor, ) @@ -125,26 +125,26 @@ def test_fixedsize_list(self, duckdb_cursor): create_and_register_arrow_table( [ - ('a', list_type, data), + ("a", list_type, data), ], duckdb_cursor, ) create_and_register_comparison_result( [ - ('a', f'FLOAT[{list_size}]', data), + ("a", f"FLOAT[{list_size}]", data), ], duckdb_cursor, ) check_equal(duckdb_cursor) - @pytest.mark.skipif(not hasattr(pa, 'ListViewArray'), reason='The pyarrow version does not support ListViewArrays') - @pytest.mark.parametrize('child_size', [100000]) + @pytest.mark.skipif(not hasattr(pa, "ListViewArray"), reason="The pyarrow version does not support ListViewArrays") + @pytest.mark.parametrize("child_size", [100000]) def test_list_view(self, duckdb_cursor, child_size): res = generate_list(child_size) - list_tbl = pa.Table.from_arrays([res.list], ['x']) - list_view_tbl = pa.Table.from_arrays([res.list_view], ['x']) + list_tbl = pa.Table.from_arrays([res.list], ["x"]) + list_view_tbl = pa.Table.from_arrays([res.list_view], ["x"]) assert res.list_view.to_pylist() == res.list.to_pylist() original = duckdb_cursor.query("select * from list_tbl").fetchall() diff --git a/tests/fast/arrow/test_arrow_offsets.py b/tests/fast/arrow/test_arrow_offsets.py index 6bc94530..0ddc0f7d 100644 --- a/tests/fast/arrow/test_arrow_offsets.py +++ b/tests/fast/arrow/test_arrow_offsets.py @@ -62,7 +62,7 @@ def decimal_value(value, precision, scale): val = str(value) actual_width = precision - scale if len(val) > actual_width: - return decimal.Decimal('9' * actual_width) + return decimal.Decimal("9" * actual_width) return decimal.Decimal(val) @@ -76,7 +76,7 @@ def expected_result(col1_null, col2_null, expected): null_test_parameters = lambda: mark.parametrize( - ['col1_null', 'col2_null'], [(False, True), (True, False), (True, True), (False, False)] + ["col1_null", "col2_null"], [(False, True), (True, False), (True, True), (False, False)] ) @@ -100,10 +100,10 @@ def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, '131072') + assert res == expected_result(col1_null, col2_null, "131072") @null_test_parameters() def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): @@ -126,7 +126,7 @@ def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, True) @@ -158,7 +158,7 @@ def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, 
SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -167,8 +167,8 @@ def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): enum_type = pa.dictionary(pa.int64(), pa.utf8()) - tuples = ['red' for i in range(MAGIC_ARRAY_SIZE)] - tuples[-1] = 'green' + tuples = ["red" for i in range(MAGIC_ARRAY_SIZE)] + tuples[-1] = "green" if col1_null: tuples[-1] = None @@ -177,7 +177,7 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): struct_tuples[-1] = None arrow_table = pa.Table.from_pydict( - {'col1': pa.array(tuples, enum_type), 'col2': pa.array(struct_tuples, pa.struct({"a": enum_type}))}, + {"col1": pa.array(tuples, enum_type), "col2": pa.array(struct_tuples, pa.struct({"a": enum_type}))}, schema=pa.schema([("col1", enum_type), ("col2", pa.struct({"a": enum_type}))]), ) res = duckdb_cursor.sql( @@ -185,10 +185,10 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, 'green') + assert res == expected_result(col1_null, col2_null, "green") @null_test_parameters() def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): @@ -209,24 +209,24 @@ def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, b'131072') + assert res == expected_result(col1_null, col2_null, b"131072") @null_test_parameters() @pytest.mark.parametrize( ["constructor", "unit", "expected"], [ - (pa_time32(), 'ms', datetime.time(0, 2, 11, 72000)), - (pa_time32(), 's', datetime.time(23, 59, 59)), - (pa_time64(), 'ns', datetime.time(0, 0, 0, 131)), - (pa_time64(), 'us', datetime.time(0, 0, 0, 131072)), + (pa_time32(), "ms", datetime.time(0, 2, 11, 72000)), + (pa_time32(), "s", datetime.time(23, 59, 59)), + (pa_time64(), "ns", datetime.time(0, 0, 0, 131)), + (pa_time64(), "us", datetime.time(0, 0, 0, 131072)), ], ) def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - if unit == 's': + if unit == "s": # FIXME: We limit the size because we don't support time values > 24 hours size = 86400 # The amount of seconds in a day @@ -247,7 +247,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -282,7 +282,7 @@ def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converte SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -291,10 +291,10 @@ def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converte @pytest.mark.parametrize( ["constructor", "unit", "expected"], [ - (pa_duration(), 'ms', datetime.timedelta(seconds=131, microseconds=72000)), - (pa_duration(), 's', datetime.timedelta(days=1, seconds=44672)), - (pa_duration(), 'ns', 
datetime.timedelta(microseconds=131)), - (pa_duration(), 'us', datetime.timedelta(microseconds=131072)), + (pa_duration(), "ms", datetime.timedelta(seconds=131, microseconds=72000)), + (pa_duration(), "s", datetime.timedelta(days=1, seconds=44672)), + (pa_duration(), "ns", datetime.timedelta(microseconds=131)), + (pa_duration(), "us", datetime.timedelta(microseconds=131072)), ], ) def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): @@ -317,7 +317,7 @@ def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, co SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -326,10 +326,10 @@ def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, co @pytest.mark.parametrize( ["constructor", "unit", "expected"], [ - (pa_timestamp(), 'ms', datetime.datetime(1970, 1, 1, 0, 2, 11, 72000, tzinfo=pytz.utc)), - (pa_timestamp(), 's', datetime.datetime(1970, 1, 2, 12, 24, 32, 0, tzinfo=pytz.utc)), - (pa_timestamp(), 'ns', datetime.datetime(1970, 1, 1, 0, 0, 0, 131, tzinfo=pytz.utc)), - (pa_timestamp(), 'us', datetime.datetime(1970, 1, 1, 0, 0, 0, 131072, tzinfo=pytz.utc)), + (pa_timestamp(), "ms", datetime.datetime(1970, 1, 1, 0, 2, 11, 72000, tzinfo=pytz.utc)), + (pa_timestamp(), "s", datetime.datetime(1970, 1, 2, 12, 24, 32, 0, tzinfo=pytz.utc)), + (pa_timestamp(), "ns", datetime.datetime(1970, 1, 1, 0, 0, 0, 131, tzinfo=pytz.utc)), + (pa_timestamp(), "us", datetime.datetime(1970, 1, 1, 0, 0, 0, 131072, tzinfo=pytz.utc)), ], ) def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): @@ -346,7 +346,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected arrow_table = pa.Table.from_pydict( {"col1": col1, "col2": col2}, schema=pa.schema( - [("col1", constructor(unit, 'UTC')), ("col2", pa.struct({"a": constructor(unit, 'UTC')}))] + [("col1", constructor(unit, "UTC")), ("col2", pa.struct({"a": constructor(unit, "UTC")}))] ), ) @@ -355,7 +355,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -379,23 +379,23 @@ def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, b'131072') + assert res == expected_result(col1_null, col2_null, b"131072") @null_test_parameters() @pytest.mark.parametrize( ["precision_scale", "expected"], [ - ((38, 37), decimal.Decimal('9.0000000000000000000000000000000000000')), - ((38, 24), decimal.Decimal('131072.000000000000000000000000')), - ((18, 14), decimal.Decimal('9999.00000000000000')), - ((18, 5), decimal.Decimal('131072.00000')), - ((9, 7), decimal.Decimal('99.0000000')), - ((9, 3), decimal.Decimal('131072.000')), - ((4, 2), decimal.Decimal('99.00')), - ((4, 0), decimal.Decimal('9999')), + ((38, 37), decimal.Decimal("9.0000000000000000000000000000000000000")), + ((38, 24), decimal.Decimal("131072.000000000000000000000000")), + ((18, 14), decimal.Decimal("9999.00000000000000")), + ((18, 5), decimal.Decimal("131072.00000")), + ((9, 7), decimal.Decimal("99.0000000")), + ((9, 3), 
decimal.Decimal("131072.000")), + ((4, 2), decimal.Decimal("99.00")), + ((4, 0), decimal.Decimal("9999")), ], ) def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_null, col2_null): @@ -420,7 +420,7 @@ def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_ SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -443,16 +443,16 @@ def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else '131072' + res1 = None if col1_null else "131072" if col2_null: res2 = None elif col1_null: res2 = [None, None, None] else: - res2 = ['131072', '131072', '131072'] + res2 = ["131072", "131072", "131072"] assert res == [(res1, res2)] @null_test_parameters() @@ -473,16 +473,16 @@ def test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else '131072' + res1 = None if col1_null else "131072" if col2_null: res2 = None elif col1_null: res2 = (None, None, None) else: - res2 = ('131072', '131072', '131072') + res2 = ("131072", "131072", "131072") assert res == [(res1, res2)] @null_test_parameters() @@ -504,16 +504,16 @@ def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else b'131072' + res1 = None if col1_null else b"131072" if col2_null: res2 = None elif col1_null: res2 = (None, None, None) else: - res2 = (b'131072', b'131073', b'131074') + res2 = (b"131072", b"131073", b"131074") assert res == [(res1, res2)] @null_test_parameters() @@ -535,16 +535,16 @@ def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else b'131072' + res1 = None if col1_null else b"131072" if col2_null: res2 = None elif col1_null: res2 = [None, None, None] else: - res2 = [b'131072', b'131073', b'131074'] + res2 = [b"131072", b"131073", b"131074"] assert res == [(res1, res2)] @null_test_parameters() @@ -566,7 +566,7 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() res1 = None if col1_null else 131072 @@ -578,7 +578,7 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): res2 = [[131072, 131072, 131072], [], None, [131072]] assert res == [(res1, res2)] - @pytest.mark.parametrize('col1_null', [True, False]) + @pytest.mark.parametrize("col1_null", [True, False]) def test_list_of_struct(self, duckdb_cursor, col1_null): # One single tuple containing a very big list tuples = [{"a": i} for i in range(0, MAGIC_ARRAY_SIZE)] @@ -599,19 +599,19 @@ def test_list_of_struct(self, duckdb_cursor, col1_null): res = res[0][0] for i, x in enumerate(res[:-1]): assert x.__class__ == dict - assert x['a'] == i + assert x["a"] == i if col1_null: assert res[-1] == None 
else: - assert res[-1]['a'] == len(res) - 1 + assert res[-1]["a"] == len(res) - 1 - @pytest.mark.parametrize(['outer_null', 'inner_null'], [(True, False), (False, True)]) + @pytest.mark.parametrize(["outer_null", "inner_null"], [(True, False), (False, True)]) def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): tuples = [[[{"a": str(i), "b": None, "c": [i]}]] for i in range(MAGIC_ARRAY_SIZE)] if outer_null: tuples[-1] = None else: - inner = [[{"a": 'aaaaaaaaaaaaaaa', "b": 'test', "c": [1, 2, 3]}] for _ in range(MAGIC_ARRAY_SIZE)] + inner = [[{"a": "aaaaaaaaaaaaaaa", "b": "test", "c": [1, 2, 3]}] for _ in range(MAGIC_ARRAY_SIZE)] if inner_null: inner[-1] = None tuples[-1] = inner @@ -635,7 +635,7 @@ def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): f""" SELECT col1 - FROM arrow_table OFFSET {MAGIC_ARRAY_SIZE-1} + FROM arrow_table OFFSET {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() if outer_null: @@ -646,7 +646,7 @@ def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): else: assert res[-1][-1][-1] == 131072 - @pytest.mark.parametrize('col1_null', [True, False]) + @pytest.mark.parametrize("col1_null", [True, False]) def test_struct_of_list(self, duckdb_cursor, col1_null): # All elements are of size 1 tuples = [{"a": [str(i)]} for i in range(MAGIC_ARRAY_SIZE)] @@ -664,13 +664,13 @@ def test_struct_of_list(self, duckdb_cursor, col1_null): f""" SELECT col1 - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchone() if col1_null: assert res[0] == None else: - assert res[0]['a'][-1] == '131072' + assert res[0]["a"][-1] == "131072" def test_bools_with_offset(self, duckdb_cursor): bools = [False, False, False, False, True, False, False, False, False, False] diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index 8310c58b..6df5053f 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -8,11 +8,11 @@ def polars_supports_capsule(): from packaging.version import Version - return Version(pl.__version__) >= Version('1.4.1') + return Version(pl.__version__) >= Version("1.4.1") @pytest.mark.skipif( - not polars_supports_capsule(), reason='Polars version does not support the Arrow PyCapsule interface' + not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface" ) class TestArrowPyCapsule(object): def test_polars_pycapsule_scan(self, duckdb_cursor): @@ -25,7 +25,7 @@ def __arrow_c_stream__(self, requested_schema=None) -> object: self.count += 1 return self.obj.__arrow_c_stream__(requested_schema=requested_schema) - df = pl.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) obj = MyObject(df) # Call the __arrow_c_stream__ from within DuckDB diff --git a/tests/fast/arrow/test_arrow_recordbatchreader.py b/tests/fast/arrow/test_arrow_recordbatchreader.py index 0f8a701d..a9523d43 100644 --- a/tests/fast/arrow/test_arrow_recordbatchreader.py +++ b/tests/fast/arrow/test_arrow_recordbatchreader.py @@ -10,11 +10,10 @@ class TestArrowRecordBatchReader(object): def test_parallel_reader(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") 
userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -31,19 +30,16 @@ def test_parallel_reader(self, duckdb_cursor): rel = duckdb_conn.from_arrow(reader) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) # The reader is already consumed so this should be 0 - assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 - ) + assert rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 0 def test_parallel_reader_replacement_scans(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -59,23 +55,22 @@ def test_parallel_reader_replacement_scans(self, duckdb_cursor): assert ( duckdb_conn.execute( - "select count(*) r1 from reader where first_name=\'Jose\' and salary > 134708.82" + "select count(*) r1 from reader where first_name='Jose' and salary > 134708.82" ).fetchone()[0] == 12 ) assert ( duckdb_conn.execute( - "select count(*) r2 from reader where first_name=\'Jose\' and salary > 134708.82" + "select count(*) r2 from reader where first_name='Jose' and salary > 134708.82" ).fetchone()[0] == 0 ) def test_parallel_reader_register(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -92,21 +87,16 @@ def test_parallel_reader_register(self, duckdb_cursor): duckdb_conn.register("bla", reader) assert ( - duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ - 0 - ] + duckdb_conn.execute("select count(*) from bla where first_name='Jose' and salary > 134708.82").fetchone()[0] == 12 ) assert ( - duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ - 0 - ] + duckdb_conn.execute("select count(*) from bla where first_name='Jose' and salary > 134708.82").fetchone()[0] == 0 ) def test_parallel_reader_default_conn(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -123,9 +113,7 @@ def test_parallel_reader_default_conn(self, duckdb_cursor): rel = duckdb.from_arrow(reader) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) # The reader is already consumed so this should be 0 - assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 - ) + assert rel.filter("first_name='Jose' and salary > 
134708.82").aggregate("count(*)").execute().fetchone()[0] == 0 diff --git a/tests/fast/arrow/test_arrow_replacement_scan.py b/tests/fast/arrow/test_arrow_replacement_scan.py index a02bac10..f2a9c13b 100644 --- a/tests/fast/arrow/test_arrow_replacement_scan.py +++ b/tests/fast/arrow/test_arrow_replacement_scan.py @@ -10,8 +10,7 @@ class TestArrowReplacementScan(object): def test_arrow_table_replacement_scan(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) df = userdata_parquet_table.to_pandas() @@ -22,11 +21,11 @@ def test_arrow_table_replacement_scan(self, duckdb_cursor): assert con.execute("select count(*) from df").fetchone() == (1000,) @pytest.mark.skipif( - not hasattr(pa.Table, '__arrow_c_stream__'), - reason='This version of pyarrow does not support the Arrow Capsule Interface', + not hasattr(pa.Table, "__arrow_c_stream__"), + reason="This version of pyarrow does not support the Arrow Capsule Interface", ) def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): - tbl = pa.Table.from_pydict({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9]}) + tbl = pa.Table.from_pydict({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9]}) capsule = tbl.__arrow_c_stream__() rel = duckdb_cursor.sql("select * from capsule") @@ -36,13 +35,13 @@ def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from capsule where a > 3 and a < 5") assert rel.fetchall() == [(4,)] - tbl = pa.Table.from_pydict({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9], 'd': [10, 11, 12]}) + tbl = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [10, 11, 12]}) capsule = tbl.__arrow_c_stream__() rel = duckdb_cursor.sql("select b, d from capsule") assert rel.fetchall() == [(i, i + 6) for i in range(4, 7)] - with pytest.raises(duckdb.InvalidInputException, match='The ArrowArrayStream was already released'): + with pytest.raises(duckdb.InvalidInputException, match="The ArrowArrayStream was already released"): rel = duckdb_cursor.sql("select b, d from capsule") schema_obj = tbl.schema @@ -53,19 +52,18 @@ def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): rel = duckdb_cursor.sql("select b, d from schema_capsule") def test_arrow_table_replacement_scan_view(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) con = duckdb.connect() con.execute("create view x as select * from userdata_parquet_table") del userdata_parquet_table - with pytest.raises(duckdb.CatalogException, match='Table with name userdata_parquet_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name userdata_parquet_table does not exist"): assert con.execute("select count(*) from x").fetchone() def test_arrow_dataset_replacement_scan(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) userdata_parquet_dataset = ds.dataset(parquet_filename) diff --git 
a/tests/fast/arrow/test_arrow_run_end_encoding.py b/tests/fast/arrow/test_arrow_run_end_encoding.py index 6315d1b7..c6f9fad5 100644 --- a/tests/fast/arrow/test_arrow_run_end_encoding.py +++ b/tests/fast/arrow/test_arrow_run_end_encoding.py @@ -3,7 +3,7 @@ import pandas as pd import duckdb -pa = pytest.importorskip("pyarrow", '21.0.0', reason="Needs pyarrow >= 21") +pa = pytest.importorskip("pyarrow", "21.0.0", reason="Needs pyarrow >= 21") pc = pytest.importorskip("pyarrow.compute") @@ -25,14 +25,14 @@ def create_list(offsets, values): def list_constructors(): result = [] result.append(create_list) - if hasattr(pa, 'ListViewArray'): + if hasattr(pa, "ListViewArray"): result.append(create_list_view) return result class TestArrowREE(object): @pytest.mark.parametrize( - 'query', + "query", [ """ select @@ -46,22 +46,22 @@ class TestArrowREE(object): """, ], ) - @pytest.mark.parametrize('run_length', [4, 1, 10, 1000, 2048, 3000]) - @pytest.mark.parametrize('size', [100, 10000]) + @pytest.mark.parametrize("run_length", [4, 1, 10, 1000, 2048, 3000]) + @pytest.mark.parametrize("size", [100, 10000]) @pytest.mark.parametrize( - 'value_type', - ['UTINYINT', 'USMALLINT', 'UINTEGER', 'UBIGINT', 'TINYINT', 'SMALLINT', 'INTEGER', 'BIGINT', 'HUGEINT'], + "value_type", + ["UTINYINT", "USMALLINT", "UINTEGER", "UBIGINT", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "HUGEINT"], ) def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, size, value_type): - if value_type == 'UTINYINT': + if value_type == "UTINYINT": if size > 255: size = 255 - if value_type == 'TINYINT': + if value_type == "TINYINT": if size > 127: size = 127 query = query.format(run_length, value_type, size) rel = duckdb_cursor.sql(query) - array = rel.fetch_arrow_table()['ree'] + array = rel.fetch_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) @@ -72,31 +72,31 @@ def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, assert res == expected @pytest.mark.parametrize( - ['dbtype', 'val1', 'val2'], + ["dbtype", "val1", "val2"], [ - ('TINYINT', '(-128)', '127'), - ('SMALLINT', '(-32768)', '32767'), - ('INTEGER', '(-2147483648)', '2147483647'), - ('BIGINT', '(-9223372036854775808)', '9223372036854775807'), - ('UTINYINT', '0', '255'), - ('USMALLINT', '0', '65535'), - ('UINTEGER', '0', '4294967295'), - ('UBIGINT', '0', '18446744073709551615'), - ('BOOL', 'true', 'false'), - ('VARCHAR', "'test'", "'this is a long string'"), - ('BLOB', "'\\xE0\\x9F\\x98\\x84'", "'\\xF0\\x9F\\xA6\\x86'"), - ('DATE', "'1992-03-27'", "'2204-11-01'"), - ('TIME', "'01:02:03'", "'23:41:35'"), - ('TIMESTAMP_S', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP_MS', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP_NS', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('DECIMAL(4,2)', "'12.23'", "'99.99'"), - ('DECIMAL(7,6)', "'1.234234'", "'0.000001'"), - ('DECIMAL(14,7)', "'134523.234234'", "'999999.000001'"), - ('DECIMAL(28,1)', "'12345678910111234123456789.1'", "'999999999999999999999999999.9'"), - ('UUID', "'10acd298-15d7-417c-8b59-eabb5a2bacab'", "'eeccb8c5-9943-b2bb-bb5e-222f4e14b687'"), - ('BIT', "'01010101010000'", "'01010100010101010101010101111111111'"), + ("TINYINT", "(-128)", "127"), + ("SMALLINT", "(-32768)", "32767"), + ("INTEGER", "(-2147483648)", "2147483647"), + ("BIGINT", "(-9223372036854775808)", "9223372036854775807"), + ("UTINYINT", "0", 
"255"), + ("USMALLINT", "0", "65535"), + ("UINTEGER", "0", "4294967295"), + ("UBIGINT", "0", "18446744073709551615"), + ("BOOL", "true", "false"), + ("VARCHAR", "'test'", "'this is a long string'"), + ("BLOB", "'\\xE0\\x9F\\x98\\x84'", "'\\xF0\\x9F\\xA6\\x86'"), + ("DATE", "'1992-03-27'", "'2204-11-01'"), + ("TIME", "'01:02:03'", "'23:41:35'"), + ("TIMESTAMP_S", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP_MS", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP_NS", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("DECIMAL(4,2)", "'12.23'", "'99.99'"), + ("DECIMAL(7,6)", "'1.234234'", "'0.000001'"), + ("DECIMAL(14,7)", "'134523.234234'", "'999999.000001'"), + ("DECIMAL(28,1)", "'12345678910111234123456789.1'", "'999999999999999999999999999.9'"), + ("UUID", "'10acd298-15d7-417c-8b59-eabb5a2bacab'", "'eeccb8c5-9943-b2bb-bb5e-222f4e14b687'"), + ("BIT", "'01010101010000'", "'01010100010101010101010101111111111'"), ], ) @pytest.mark.parametrize( @@ -107,7 +107,7 @@ def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, ], ) def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter): - if dbtype in ['BIT', 'UUID']: + if dbtype in ["BIT", "UUID"]: pytest.skip("BIT and UUID are currently broken (FIXME)") projection = "a, b, ree" query = """ @@ -135,25 +135,25 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) # Create an Arrow Table from the table arrow_conversion = rel.fetch_arrow_table() arrays = { - 'ree': arrow_conversion['ree'], - 'a': arrow_conversion['a'], - 'b': arrow_conversion['b'], + "ree": arrow_conversion["ree"], + "a": arrow_conversion["a"], + "b": arrow_conversion["b"], } encoded_arrays = { - 'ree': pc.run_end_encode(arrays['ree']), - 'a': pc.run_end_encode(arrays['a']), - 'b': pc.run_end_encode(arrays['b']), + "ree": pc.run_end_encode(arrays["ree"]), + "a": pc.run_end_encode(arrays["a"]), + "b": pc.run_end_encode(arrays["b"]), } schema = pa.schema( [ - ("ree", encoded_arrays['ree'].type), - ("a", encoded_arrays['a'].type), - ("b", encoded_arrays['b'].type), + ("ree", encoded_arrays["ree"].type), + ("a", encoded_arrays["a"].type), + ("b", encoded_arrays["b"].type), ] ) - tbl = pa.Table.from_arrays([encoded_arrays['ree'], encoded_arrays['a'], encoded_arrays['b']], schema=schema) + tbl = pa.Table.from_arrays([encoded_arrays["ree"], encoded_arrays["a"], encoded_arrays["b"]], schema=schema) # Scan the Arrow Table and verify that the results are the same res = duckdb_cursor.sql("select {} from tbl where {}".format(projection, filter)).fetchall() @@ -161,8 +161,8 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) def test_arrow_ree_empty_table(self, duckdb_cursor): duckdb_cursor.query("create table tbl (ree integer)") - rel = duckdb_cursor.table('tbl') - array = rel.fetch_arrow_table()['ree'] + rel = duckdb_cursor.table("tbl") + array = rel.fetch_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) @@ -172,7 +172,7 @@ def test_arrow_ree_empty_table(self, duckdb_cursor): res = duckdb_cursor.sql("select * from pa_res").fetchall() assert res == expected - @pytest.mark.parametrize('projection', ['*', 'a, c, b', 'ree, a, b, c', 'c, b, a, ree', 'c', 'b, ree, c, a']) + @pytest.mark.parametrize("projection", ["*", "a, c, b", "ree, a, b, c", "c, b, a, ree", "c", "b, ree, c, a"]) def test_arrow_ree_projections(self, 
duckdb_cursor, projection): # Create the schema duckdb_cursor.query( @@ -199,28 +199,28 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): ) # Fetch the result as an Arrow Table - result = duckdb_cursor.table('tbl').fetch_arrow_table() + result = duckdb_cursor.table("tbl").fetch_arrow_table() # Turn 'ree' into a run-end-encoded array and reconstruct a table from it arrays = { - 'ree': pc.run_end_encode(result['ree']), - 'a': result['a'], - 'b': result['b'], - 'c': result['c'], + "ree": pc.run_end_encode(result["ree"]), + "a": result["a"], + "b": result["b"], + "c": result["c"], } schema = pa.schema( [ - ("ree", arrays['ree'].type), - ("a", arrays['a'].type), - ("b", arrays['b'].type), - ("c", arrays['c'].type), + ("ree", arrays["ree"].type), + ("a", arrays["a"].type), + ("b", arrays["b"].type), + ("c", arrays["c"].type), ] ) - arrow_tbl = pa.Table.from_arrays([arrays['ree'], arrays['a'], arrays['b'], arrays['c']], schema=schema) + arrow_tbl = pa.Table.from_arrays([arrays["ree"], arrays["a"], arrays["b"], arrays["c"]], schema=schema) # Verify that the array is run end encoded - ar_type = arrow_tbl['ree'].type + ar_type = arrow_tbl["ree"].type assert pa.types.is_run_end_encoded(ar_type) == True # Scan the arrow table, making projections that don't cover the entire table @@ -229,9 +229,7 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): res = duckdb_cursor.query( """ select {} from arrow_tbl - """.format( - projection - ) + """.format(projection) ).fetch_arrow_table() # Verify correctness by fetching from the original table and the constructed result @@ -239,7 +237,7 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): actual = duckdb_cursor.query("select {} from res".format(projection)).fetchall() assert expected == actual - @pytest.mark.parametrize('create_list', list_constructors()) + @pytest.mark.parametrize("create_list", list_constructors()) def test_arrow_ree_list(self, duckdb_cursor, create_list): size = 1000 duckdb_cursor.query( @@ -248,9 +246,7 @@ def test_arrow_ree_list(self, duckdb_cursor, create_list): as select i // 4 as ree, FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -281,7 +277,7 @@ def test_arrow_ree_list(self, duckdb_cursor, create_list): structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() assert arrow_tbl.to_pylist() == result.to_pylist() @@ -317,7 +313,7 @@ def test_arrow_ree_struct(self, duckdb_cursor): structured_chunks = [pa.StructArray.from_arrays([y for y in x], names=names) for x in zipped] structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() expected = duckdb_cursor.query("select {'ree': ree, 'a': a, 'b': b, 'c': c} as s from tbl").fetchall() @@ -337,9 +333,7 @@ def test_arrow_ree_union(self, duckdb_cursor): i % 2 == 0 as b, i::VARCHAR as c FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -368,7 +362,7 @@ def test_arrow_ree_union(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = 
pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Recreate the same result set @@ -395,9 +389,7 @@ def test_arrow_ree_map(self, duckdb_cursor): i // 4 as ree, i as a, FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -431,7 +423,7 @@ def test_arrow_ree_map(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Verify that the resulting scan is the same as the input @@ -446,9 +438,7 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): as select i // 4 as ree, FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -473,7 +463,7 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Verify that the resulting scan is the same as the input diff --git a/tests/fast/arrow/test_arrow_scanner.py b/tests/fast/arrow/test_arrow_scanner.py index 6d74ddb5..2e8b1296 100644 --- a/tests/fast/arrow/test_arrow_scanner.py +++ b/tests/fast/arrow/test_arrow_scanner.py @@ -22,7 +22,7 @@ def test_parallel_scanner(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -33,13 +33,13 @@ def test_parallel_scanner(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) rel = duckdb_conn.from_arrow(arrow_scanner) - assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 + assert rel.aggregate("count(*)").execute().fetchone()[0] == 12 def test_parallel_scanner_replacement_scans(self, duckdb_cursor): if not can_run: @@ -48,7 +48,7 @@ def test_parallel_scanner_replacement_scans(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -59,7 +59,7 @@ def test_parallel_scanner_replacement_scans(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) @@ -72,7 +72,7 @@ def test_parallel_scanner_register(self, duckdb_cursor): duckdb_conn = duckdb.connect() 
duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -83,7 +83,7 @@ def test_parallel_scanner_register(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) @@ -95,7 +95,7 @@ def test_parallel_scanner_default_conn(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -106,10 +106,10 @@ def test_parallel_scanner_default_conn(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) rel = duckdb.from_arrow(arrow_scanner) - assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 + assert rel.aggregate("count(*)").execute().fetchone()[0] == 12 diff --git a/tests/fast/arrow/test_arrow_string_view.py b/tests/fast/arrow/test_arrow_string_view.py index fc4bbd40..a1b46e5b 100644 --- a/tests/fast/arrow/test_arrow_string_view.py +++ b/tests/fast/arrow/test_arrow_string_view.py @@ -2,10 +2,10 @@ import pytest from packaging import version -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") pytestmark = pytest.mark.skipif( - not hasattr(pa, 'string_view'), reason="This version of PyArrow does not support StringViews" + not hasattr(pa, "string_view"), reason="This version of PyArrow does not support StringViews" ) @@ -20,7 +20,7 @@ def RoundTripStringView(query, array): # Generate an arrow table # Create a field for the array with a specific data type - field = pa.field('str_val', pa.string_view()) + field = pa.field("str_val", pa.string_view()) # Create a schema for the table using the field schema = pa.schema([field]) @@ -103,26 +103,26 @@ def test_not_inlined_string_view(self): # Test Over-Vector Size def test_large_string_view_inlined(self): - RoundTripDuckDBInternal('''select * from (SELECT i::varchar str FROM range(10000) tbl(i)) order by str''') + RoundTripDuckDBInternal("""select * from (SELECT i::varchar str FROM range(10000) tbl(i)) order by str""") def test_large_string_view_inlined_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_large_string_view_not_inlined(self): RoundTripDuckDBInternal( - '''select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def 
test_large_string_view_not_inlined_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_large_string_view_mixed_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_multiple_data_buffers(self): diff --git a/tests/fast/arrow/test_arrow_types.py b/tests/fast/arrow/test_arrow_types.py index 97f747ef..f2bf71c7 100644 --- a/tests/fast/arrow/test_arrow_types.py +++ b/tests/fast/arrow/test_arrow_types.py @@ -17,7 +17,7 @@ def test_null_type(self, duckdb_cursor): inputs = [pa.array([None, None, None], type=pa.null())] arrow_table = pa.Table.from_arrays(inputs, schema=schema) - assert rel['data'] == arrow_table['data'] + assert rel["data"] == arrow_table["data"] def test_invalid_struct(self, duckdb_cursor): empty_struct_type = pa.struct([]) @@ -27,7 +27,7 @@ def test_invalid_struct(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([empty_array], schema=pa.schema([("data", empty_struct_type)])) with pytest.raises( duckdb.InvalidInputException, - match='Attempted to convert a STRUCT with no fields to DuckDB which is not supported', + match="Attempted to convert a STRUCT with no fields to DuckDB which is not supported", ): duckdb_cursor.sql("select * from arrow_table").fetchall() @@ -39,9 +39,9 @@ def test_invalid_union(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([sparse_union_array], schema=pa.schema([("data", sparse_union_array.type)])) with pytest.raises( duckdb.InvalidInputException, - match='Attempted to convert a UNION with no fields to DuckDB which is not supported', + match="Attempted to convert a UNION with no fields to DuckDB which is not supported", ): - duckdb_cursor.register('invalid_union', arrow_table) + duckdb_cursor.register("invalid_union", arrow_table) res = duckdb_cursor.sql("select * from invalid_union").fetchall() print(res) diff --git a/tests/fast/arrow/test_arrow_union.py b/tests/fast/arrow/test_arrow_union.py index 1d853a1b..c0a5d568 100644 --- a/tests/fast/arrow/test_arrow_union.py +++ b/tests/fast/arrow/test_arrow_union.py @@ -1,13 +1,13 @@ from pytest import importorskip -importorskip('pyarrow') +importorskip("pyarrow") import duckdb from pyarrow import scalar, string, large_string, list_, int32, types def test_nested(duckdb_cursor): - res = run(duckdb_cursor, 'select 42::UNION(name VARCHAR, attr UNION(age INT, veteran BOOL)) as res') + res = run(duckdb_cursor, "select 42::UNION(name VARCHAR, attr UNION(age INT, veteran BOOL)) as res") assert types.is_union(res.type) assert res.value.value == scalar(42, type=int32()) @@ -16,14 +16,14 @@ def test_union_contains_nested_data(duckdb_cursor): _ = importorskip("pyarrow", minversion="11") res = run(duckdb_cursor, "select ['hello']::UNION(first_name VARCHAR, middle_names VARCHAR[]) as res") assert types.is_union(res.type) - assert res.value == scalar(['hello'], type=list_(string())) + assert res.value == scalar(["hello"], 
type=list_(string())) def test_unions_inside_lists_structs_maps(duckdb_cursor): res = run(duckdb_cursor, "select [union_value(name := 'Frank')] as res") assert types.is_list(res.type) assert types.is_union(res.type.value_type) - assert res[0].value == scalar('Frank', type=string()) + assert res[0].value == scalar("Frank", type=string()) def test_unions_with_struct(duckdb_cursor): @@ -38,13 +38,13 @@ def test_unions_with_struct(duckdb_cursor): """ ) - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") arrow = rel.fetch_arrow_table() duckdb_cursor.execute("create table other as select * from arrow") - rel2 = duckdb_cursor.table('other') + rel2 = duckdb_cursor.table("other") res = rel2.fetchall() - assert res == [({'a': 42, 'b': True},)] + assert res == [({"a": 42, "b": True},)] def run(conn, query): diff --git a/tests/fast/arrow/test_arrow_version_format.py b/tests/fast/arrow/test_arrow_version_format.py index ff8699eb..fd169ce0 100644 --- a/tests/fast/arrow/test_arrow_version_format.py +++ b/tests/fast/arrow/test_arrow_version_format.py @@ -32,20 +32,20 @@ def test_decimal_v1_5(self, duckdb_cursor): ) col_type = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table().schema.field("data").type assert col_type.bit_width == 64 and pa.types.is_decimal(col_type) - for version in ['1.0', '1.1', '1.2', '1.3', '1.4']: + for version in ["1.0", "1.1", "1.2", "1.3", "1.4"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") result = duckdb_cursor.execute("FROM decimal_32").fetch_arrow_table() col_type = result.schema.field("data").type assert col_type.bit_width == 128 and pa.types.is_decimal(col_type) assert result.to_pydict() == { - 'data': [Decimal('100.20'), Decimal('110.21'), Decimal('31.20'), Decimal('500.20')] + "data": [Decimal("100.20"), Decimal("110.21"), Decimal("31.20"), Decimal("500.20")] } result = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table() col_type = result.schema.field("data").type assert col_type.bit_width == 128 and pa.types.is_decimal(col_type) assert result.to_pydict() == { - 'data': [Decimal('1000.231'), Decimal('1100.231'), Decimal('999999999999.231'), Decimal('500.200')] + "data": [Decimal("1000.231"), Decimal("1100.231"), Decimal("999999999999.231"), Decimal("500.200")] } def test_invalide_opt(self, duckdb_cursor): @@ -63,14 +63,14 @@ def test_view_v1_4(self, duckdb_cursor): col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type assert pa.types.is_list_view(col_type) - for version in ['1.0', '1.1', '1.2', '1.3']: + for version in ["1.0", "1.1", "1.2", "1.3"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_list_view(col_type) - for version in ['1.4', '1.5']: + for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert pa.types.is_string_view(col_type) @@ -80,7 +80,7 @@ def test_view_v1_4(self, duckdb_cursor): duckdb_cursor.execute("SET produce_arrow_string_view=False") duckdb_cursor.execute("SET arrow_output_list_view=False") - for version in ['1.4', '1.5']: + for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET 
arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) diff --git a/tests/fast/arrow/test_buffer_size_option.py b/tests/fast/arrow/test_buffer_size_option.py index 46047e21..7d5131e5 100644 --- a/tests/fast/arrow/test_buffer_size_option.py +++ b/tests/fast/arrow/test_buffer_size_option.py @@ -34,7 +34,7 @@ def just_return(x): return x con = duckdb.connect() - con.create_function('just_return', just_return, [VARCHAR], VARCHAR, type='arrow') + con.create_function("just_return", just_return, [VARCHAR], VARCHAR, type="arrow") res = con.query("select just_return('bla')").fetch_arrow_table() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 521ec8f7..8ec0094e 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -14,7 +14,7 @@ def test_parallel_dataset(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -28,7 +28,7 @@ def test_parallel_dataset(self, duckdb_cursor): rel = duckdb_conn.from_arrow(userdata_parquet_dataset) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) def test_parallel_dataset_register(self, duckdb_cursor): @@ -36,7 +36,7 @@ def test_parallel_dataset_register(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -61,7 +61,7 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -79,7 +79,7 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): arrow_table = record_batch_reader.read_all() # reorder since order of rows isn't deterministic - df = userdata_parquet_dataset.to_table().to_pandas().sort_values('id').reset_index(drop=True) + df = userdata_parquet_dataset.to_table().to_pandas().sort_values("id").reset_index(drop=True) # turn it into an arrow table arrow_table_2 = pyarrow.Table.from_pandas(df) result_1 = duckdb_conn.execute("select * from arrow_table order by all").fetchall() @@ -94,7 +94,7 @@ def test_ducktyping(self, duckdb_cursor): query = duckdb_conn.execute("SELECT b FROM dataset WHERE a < 5") record_batch_reader = query.fetch_record_batch(2048) arrow_table = record_batch_reader.read_all() - assert arrow_table.equals(CustomDataset.DATA[:5].select(['b'])) + assert arrow_table.equals(CustomDataset.DATA[:5].select(["b"])) class CustomDataset(pyarrow.dataset.Dataset): diff --git 
a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 316fc689..9649ffa6 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -18,30 +18,30 @@ def test_date_types(self, duckdb_cursor): return data = (pa.array([1000 * 60 * 60 * 24], type=pa.date64()), pa.array([1], type=pa.date32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['b'] - assert rel['b'] == arrow_table['b'] + assert rel["a"] == arrow_table["b"] + assert rel["b"] == arrow_table["b"] def test_date_null(self, duckdb_cursor): if not can_run: return data = (pa.array([None], type=pa.date64()), pa.array([None], type=pa.date32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['b'] - assert rel['b'] == arrow_table['b'] + assert rel["a"] == arrow_table["b"] + assert rel["b"] == arrow_table["b"] def test_max_date(self, duckdb_cursor): if not can_run: return data = (pa.array([2147483647], type=pa.date32()), pa.array([2147483647], type=pa.date32())) - result = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + result = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) data = ( pa.array([2147483647 * (1000 * 60 * 60 * 24)], type=pa.date64()), pa.array([2147483647], type=pa.date32()), ) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] - assert rel['b'] == result['b'] + assert rel["a"] == result["a"] + assert rel["b"] == result["b"] diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index 823d6b05..e4319f7c 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -17,7 +17,7 @@ def test_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)] @@ -27,14 +27,14 @@ def test_dictionary(self, duckdb_cursor): indices = pa.array(indices_list) dictionary = pa.array([10, 100, None, 999999]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,), (999999,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, pa.array(indices_list)], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, pa.array(indices_list)], ["a", "b"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(10, 0), (100, 1), (10, 0), (100, 1), (None, 2), (100, 1), (10, 0), (None, 2), (999999, 3)] * 10000 assert rel.execute().fetchall() == result @@ -43,7 +43,7 @@ def 
test_dictionary_null_index(self, duckdb_cursor): indices = pa.array([None, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [(None,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)] @@ -51,7 +51,7 @@ def test_dictionary_null_index(self, duckdb_cursor): indices = pa.array([None, 1, None, 1, 2, 1, 0]) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] @@ -61,19 +61,19 @@ def test_dictionary_null_index(self, duckdb_cursor): indices = pa.array(indices_list * 1000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 1000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 1000 assert rel.execute().fetchall() == result @pytest.mark.parametrize( - 'element', + "element", [ # list """ @@ -110,7 +110,7 @@ def test_dictionary_null_index(self, duckdb_cursor): ], ) @pytest.mark.parametrize( - 'count', + "count", [ 1, 10, @@ -123,14 +123,14 @@ def test_dictionary_null_index(self, duckdb_cursor): 5000, ], ) - @pytest.mark.parametrize('query', ["select {} as a from range({})", "select [{} for x in range({})] as a"]) + @pytest.mark.parametrize("query", ["select {} as a from range({})", "select [{} for x in range({})] as a"]) def test_dictionary_roundtrip(self, query, element, duckdb_cursor, count): query = query.format(element, count) original_rel = duckdb_cursor.sql(query) expected = original_rel.fetchall() arrow_res = original_rel.fetch_arrow_table() - roundtrip_rel = duckdb_cursor.sql('select * from arrow_res') + roundtrip_rel = duckdb_cursor.sql("select * from arrow_res") actual = roundtrip_rel.fetchall() assert expected == actual assert original_rel.columns == roundtrip_rel.columns @@ -142,14 +142,14 @@ def test_dictionary_batches(self, duckdb_cursor): indices = pa.array(indices_list * 10000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", 
"b"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000 @@ -157,14 +157,14 @@ def test_dictionary_batches(self, duckdb_cursor): def test_dictionary_lifetime(self, duckdb_cursor): tables = [] - expected = '' + expected = "" for i in range(100): if i % 3 == 0: - input = 'ABCD' * 17000 + input = "ABCD" * 17000 elif i % 3 == 1: - input = 'FOOO' * 17000 + input = "FOOO" * 17000 else: - input = 'BARR' * 17000 + input = "BARR" * 17000 expected += input array = pa.array( input, @@ -186,14 +186,14 @@ def test_dictionary_batches_parallel(self, duckdb_cursor): indices = pa.array(indices_list * 10000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000 @@ -214,7 +214,7 @@ def test_dictionary_index_types(self, duckdb_cursor): for index_type in index_types: dict_array = pa.DictionaryArray.from_arrays(index_type, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result @@ -222,17 +222,17 @@ def test_dictionary_index_types(self, duckdb_cursor): def test_dictionary_strings(self, duckdb_cursor): indices_list = [None, 0, 1, 2, 3, 4, None] indices = pa.array(indices_list * 1000) - dictionary = pa.array(['Matt Daaaaaaaaamon', 'Alec Baldwin', 'Sean Penn', 'Tim Robbins', 'Samuel L. Jackson']) + dictionary = pa.array(["Matt Daaaaaaaaamon", "Alec Baldwin", "Sean Penn", "Tim Robbins", "Samuel L. Jackson"]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [ (None,), - ('Matt Daaaaaaaaamon',), - ('Alec Baldwin',), - ('Sean Penn',), - ('Tim Robbins',), - ('Samuel L. Jackson',), + ("Matt Daaaaaaaaamon",), + ("Alec Baldwin",), + ("Sean Penn",), + ("Tim Robbins",), + ("Samuel L. 
Jackson",), (None,), ] * 1000 assert rel.execute().fetchall() == result @@ -249,7 +249,7 @@ def test_dictionary_timestamps(self, duckdb_cursor): ] ) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) print(rel.execute().fetchall()) expected = [ diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index dffa9631..026b52f4 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -17,7 +17,7 @@ def create_pyarrow_pandas(rel): if not pandas_supports_arrow_backend(): pytest.skip(reason="Pandas version doesn't support 'pyarrow' backend") - return rel.df().convert_dtypes(dtype_backend='pyarrow') + return rel.df().convert_dtypes(dtype_backend="pyarrow") def create_pyarrow_table(rel): @@ -34,7 +34,7 @@ def test_decimal_filter_pushdown(duckdb_cursor): np = pytest.importorskip("numpy") np.random.seed(10) - df = pl.DataFrame({'x': pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) + df = pl.DataFrame({"x": pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) query = """ SELECT @@ -179,34 +179,33 @@ def string_check_or_pushdown(connection, tbl_name, create_table): class TestArrowFilterPushdown(object): - @pytest.mark.parametrize( - 'data_type', + "data_type", [ - 'TINYINT', - 'SMALLINT', - 'INTEGER', - 'BIGINT', - 'UTINYINT', - 'USMALLINT', - 'UINTEGER', - 'UBIGINT', - 'FLOAT', - 'DOUBLE', - 'HUGEINT', - 'DECIMAL(4,1)', - 'DECIMAL(9,1)', - 'DECIMAL(18,4)', - 'DECIMAL(30,12)', + "TINYINT", + "SMALLINT", + "INTEGER", + "BIGINT", + "UTINYINT", + "USMALLINT", + "UINTEGER", + "UBIGINT", + "FLOAT", + "DOUBLE", + "HUGEINT", + "DECIMAL(4,1)", + "DECIMAL(9,1)", + "DECIMAL(18,4)", + "DECIMAL(30,12)", ], ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_numeric(self, data_type, duckdb_cursor, create_table): tbl_name = "tbl" numeric_operators(duckdb_cursor, data_type, tbl_name, create_table) numeric_check_or_pushdown(duckdb_cursor, tbl_name, create_table) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_varchar(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -259,7 +258,7 @@ def test_filter_pushdown_varchar(self, duckdb_cursor, create_table): # More complex tests for OR pushed down on string string_check_or_pushdown(duckdb_cursor, "test_varchar", create_table) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_bool(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -294,7 +293,7 @@ def test_filter_pushdown_bool(self, duckdb_cursor, create_table): # Try Or assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = True or b = True").fetchone()[0] == 3 - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_time(self, duckdb_cursor, 
create_table): duckdb_cursor.execute( """ @@ -352,7 +351,7 @@ def test_filter_pushdown_time(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -422,7 +421,7 @@ def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -494,18 +493,18 @@ def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) @pytest.mark.parametrize( - ['data_type', 'value'], + ["data_type", "value"], [ - ['TINYINT', 127], - ['SMALLINT', 32767], - ['INTEGER', 2147483647], - ['BIGINT', 9223372036854775807], - ['UTINYINT', 255], - ['USMALLINT', 65535], - ['UINTEGER', 4294967295], - ['UBIGINT', 18446744073709551615], + ["TINYINT", 127], + ["SMALLINT", 32767], + ["INTEGER", 2147483647], + ["BIGINT", 9223372036854775807], + ["UTINYINT", 255], + ["USMALLINT", 65535], + ["UINTEGER", 4294967295], + ["UBIGINT", 18446744073709551615], ], ) def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_table): @@ -514,9 +513,9 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ CREATE TABLE tbl as select {value}::{data_type} as i """ ) - expected = duckdb_cursor.table('tbl').fetchall() + expected = duckdb_cursor.table("tbl").fetchall() filter = "i > 0" - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") arrow_table = create_table(rel) actual = duckdb_cursor.sql(f"select * from arrow_table where {filter}").fetchall() assert expected == actual @@ -529,7 +528,7 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ assert expected == actual @pytest.mark.skipif( - Version(pa.__version__) < Version('15.0.0'), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" + Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) def test_9371(self, duckdb_cursor, tmp_path): import datetime @@ -546,7 +545,7 @@ def test_9371(self, duckdb_cursor, tmp_path): # Example data dt = datetime.datetime(2023, 8, 29, 1, tzinfo=datetime.timezone.utc) - my_arrow_table = pa.Table.from_pydict({'ts': [dt, dt, dt], 'value': [1, 2, 3]}) + my_arrow_table = pa.Table.from_pydict({"ts": [dt, dt, dt], "value": [1, 2, 3]}) df = my_arrow_table.to_pandas() df = df.set_index("ts") # SET INDEX! 
(It all works correctly when the index is not set) df.to_parquet(str(file_path)) @@ -557,7 +556,7 @@ def test_9371(self, duckdb_cursor, tmp_path): expected = [(1, dt), (2, dt), (3, dt)] assert output == expected - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_date(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -617,15 +616,15 @@ def test_filter_pushdown_date(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_blob(self, duckdb_cursor, create_table): import pandas df = pandas.DataFrame( { - 'a': [bytes([1]), bytes([2]), bytes([3]), None], - 'b': [bytes([1]), bytes([2]), bytes([3]), None], - 'c': [bytes([1]), bytes([2]), bytes([3]), None], + "a": [bytes([1]), bytes([2]), bytes([3]), None], + "b": [bytes([1]), bytes([2]), bytes([3]), None], + "c": [bytes([1]), bytes([2]), bytes([3]), None], } ) rel = duckdb.from_df(df) @@ -660,7 +659,7 @@ def test_filter_pushdown_blob(self, duckdb_cursor, create_table): duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '\x01' or b = '\x02'").fetchone()[0] == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table, create_pyarrow_dataset]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table, create_pyarrow_dataset]) def test_filter_pushdown_no_projection(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -685,7 +684,7 @@ def test_filter_pushdown_no_projection(self, duckdb_cursor, create_table): assert duckdb_cursor.execute("SELECT * FROM arrow_table VALUES where a = 1").fetchall() == [(1, 1, 1)] - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): import pandas @@ -697,12 +696,12 @@ def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): df2 = pandas.DataFrame(np.random.randn(date2.shape[0], 5), columns=list("ABCDE")) df2["date"] = date2 - data1 = tmp_path / 'data1.parquet' - data2 = tmp_path / 'data2.parquet' + data1 = tmp_path / "data1.parquet" + data2 = tmp_path / "data2.parquet" duckdb_cursor.execute(f"copy (select * from df1) to '{data1.as_posix()}'") duckdb_cursor.execute(f"copy (select * from df2) to '{data2.as_posix()}'") - glob_pattern = tmp_path / 'data*.parquet' + glob_pattern = tmp_path / "data*.parquet" table = duckdb_cursor.read_parquet(glob_pattern.as_posix()).fetch_arrow_table() output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df() @@ -710,7 +709,7 @@ def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): pandas.testing.assert_frame_equal(expected_df, output_df) # https://github.com/duckdb/duckdb/pull/4817/files#r1339973721 - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_column_removal(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -738,7 +737,7 @@ def test_filter_column_removal(self, duckdb_cursor, create_table): assert not match 
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -768,7 +767,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): ).fetchall() input = query_res[0][1] - if 'PANDAS_SCAN' in input: + if "PANDAS_SCAN" in input: pytest.skip(reason="This version of pandas does not produce an Arrow object") match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a<2.*", input, flags=re.DOTALL) assert match @@ -809,7 +808,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): assert not match @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -838,15 +837,15 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): ).fetchall() input = query_res[0][1] - if 'PANDAS_SCAN' in input: + if "PANDAS_SCAN" in input: pytest.skip(reason="This version of pandas does not produce an Arrow object") match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a\.b<2.*", input, flags=re.DOTALL) assert match # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.b < 2").fetchone()[0] == { - 'a': {'b': 1, 'c': False}, - 'd': {'e': 2, 'f': 'foo'}, + "a": {"b": 1, "c": False}, + "d": {"e": 2, "f": "foo"}, } query_res = duckdb_cursor.execute( @@ -866,8 +865,8 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT COUNT(*) FROM arrow_table WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == 1 assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == { - 'a': {'b': None, 'c': True}, - 'd': {'e': 5, 'f': 'qux'}, + "a": {"b": None, "c": True}, + "d": {"e": 5, "f": "qux"}, } query_res = duckdb_cursor.execute( @@ -887,8 +886,8 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.d.f = 'bar'").fetchone()[0] == { - 'a': {'b': 3, 'c': True}, - 'd': {'e': 4, 'f': 'bar'}, + "a": {"b": 3, "c": True}, + "d": {"e": 4, "f": "bar"}, } def test_filter_pushdown_not_supported(self): @@ -899,21 +898,21 @@ def test_filter_pushdown_not_supported(self): arrow_tbl = con.execute("FROM T").fetch_arrow_table() # No projection just unsupported filter - assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, "3", 3, 3)] # No projection unsupported + supported filter - assert con.execute("from arrow_tbl where c < 4 and a > 2").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where c < 4 and a > 2").fetchall() == [(3, "3", 3, 3)] # No projection supported + unsupported + supported filter - assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == 
[(3, "3", 3, 3)] assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '0' ").fetchall() == [] # Projection with unsupported filter column + unsupported + supported filter - assert con.execute("select c, b from arrow_tbl where c < 4 and b == '3' and a > 2 ").fetchall() == [(3, '3')] - assert con.execute("select c, b from arrow_tbl where a > 2 and c < 4 and b == '3'").fetchall() == [(3, '3')] + assert con.execute("select c, b from arrow_tbl where c < 4 and b == '3' and a > 2 ").fetchall() == [(3, "3")] + assert con.execute("select c, b from arrow_tbl where a > 2 and c < 4 and b == '3'").fetchall() == [(3, "3")] # Projection without unsupported filter column + unsupported + supported filter - assert con.execute("select a, b from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, '3')] + assert con.execute("select a, b from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, "3")] # Lets also experiment with multiple unpush-able filters con.execute( @@ -924,7 +923,7 @@ def test_filter_pushdown_not_supported(self): assert con.execute( "select a, b from arrow_tbl where a > 2 and c < 40 and b == '28' and g > 15 and e < 30" - ).fetchall() == [(28, '28')] + ).fetchall() == [(28, "28")] def test_join_filter_pushdown(self, duckdb_cursor): duckdb_conn = duckdb.connect() @@ -951,18 +950,18 @@ def test_in_filter_pushdown(self, duckdb_cursor): def test_pushdown_of_optional_filter(self, duckdb_cursor): cardinality_table = pa.Table.from_pydict( { - 'column_name': [ - 'id', - 'product_code', - 'price', - 'quantity', - 'category', - 'is_available', - 'rating', - 'discount', - 'color', + "column_name": [ + "id", + "product_code", + "price", + "quantity", + "category", + "is_available", + "rating", + "discount", + "color", ], - 'cardinality': [100, 100, 100, 45, 5, 3, 6, 39, 5], + "cardinality": [100, 100, 100, 45, 5, 3, 6, 39, 5], } ) @@ -976,15 +975,15 @@ def test_pushdown_of_optional_filter(self, duckdb_cursor): ) res = result.fetchall() assert res == [ - ('is_available', 3), - ('category', 5), - ('color', 5), - ('rating', 6), - ('discount', 39), - ('quantity', 45), - ('id', 100), - ('product_code', 100), - ('price', 100), + ("is_available", 3), + ("category", 5), + ("color", 5), + ("rating", 6), + ("discount", 39), + ("quantity", 45), + ("id", 100), + ("product_code", 100), + ("price", 100), ] # DuckDB intentionally violates IEEE-754 when it comes to NaNs, ensuring a total ordering where NaN is the greatest value @@ -1002,11 +1001,11 @@ def test_nan_filter_pushdown(self, duckdb_cursor): ) def assert_equal_results(con, arrow_table, query): - duckdb_res = con.sql(query.format(table='test')).fetchall() - arrow_res = con.sql(query.format(table='arrow_table')).fetchall() + duckdb_res = con.sql(query.format(table="test")).fetchall() + arrow_res = con.sql(query.format(table="arrow_table")).fetchall() assert len(duckdb_res) == len(arrow_res) - arrow_table = duckdb_cursor.table('test').fetch_arrow_table() + arrow_table = duckdb_cursor.table("test").fetch_arrow_table() assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a > 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a >= 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a < 'NaN'::FLOAT") diff --git a/tests/fast/arrow/test_integration.py b/tests/fast/arrow/test_integration.py index d9006758..6ab7350d 100644 --- a/tests/fast/arrow/test_integration.py +++ b/tests/fast/arrow/test_integration.py @@ -10,8 +10,8 
@@ class TestArrowIntegration(object): def test_parquet_roundtrip(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" # TODO timestamp @@ -35,8 +35,8 @@ def test_parquet_roundtrip(self, duckdb_cursor): assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True) def test_unsigned_roundtrip(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'unsigned.parquet') - cols = 'a, b, c, d' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "unsigned.parquet") + cols = "a, b, c, d" unsigned_parquet_table = pq.read_table(parquet_filename) unsigned_parquet_table.validate(full=True) @@ -82,16 +82,16 @@ def test_decimals_roundtrip(self, duckdb_cursor): "SELECT typeof(a), typeof(b), typeof(c),typeof(d) from testarrow" ).fetchone() - assert arrow_result[0] == 'DECIMAL(4,2)' - assert arrow_result[1] == 'DECIMAL(9,2)' - assert arrow_result[2] == 'DECIMAL(18,2)' - assert arrow_result[3] == 'DECIMAL(30,2)' + assert arrow_result[0] == "DECIMAL(4,2)" + assert arrow_result[1] == "DECIMAL(9,2)" + assert arrow_result[2] == "DECIMAL(18,2)" + assert arrow_result[3] == "DECIMAL(30,2)" # Lets also test big number comming from arrow land data = pa.array(np.array([9999999999999999999999999999999999]), type=pa.decimal128(38, 0)) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("bigdecimal") - result = duckdb_cursor.execute('select * from bigdecimal') + result = duckdb_cursor.execute("select * from bigdecimal") assert result.fetchone()[0] == 9999999999999999999999999999999999 def test_intervals_roundtrip(self, duckdb_cursor): @@ -110,9 +110,9 @@ def test_intervals_roundtrip(self, duckdb_cursor): arr = [expected_value] data = pa.array(arr, pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervaltbl") - duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()["a"] assert duck_arrow_tbl[0].value == expected_value @@ -120,7 +120,7 @@ def test_intervals_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a INTERVAL)") duckdb_cursor.execute("INSERT INTO test VALUES (INTERVAL 1 YEAR + INTERVAL 1 DAY + INTERVAL 1 SECOND)") expected_value = pa.MonthDayNano([12, 1, 1000000000]) - duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()['a'] + duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()["a"] assert duck_tbl_arrow[0].value.months == expected_value.months assert duck_tbl_arrow[0].value.days == expected_value.days assert duck_tbl_arrow[0].value.nanoseconds == expected_value.nanoseconds @@ -140,9 +140,9 @@ def test_null_intervals_roundtrip(self, duckdb_cursor): ) arr = [None, expected_value] data = pa.array(arr, pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) 
duckdb_cursor.from_arrow(arrow_tbl).create("intervalnulltbl") - duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()['a'] + duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()["a"] assert duckdb_tbl_arrow[0].value == None assert duckdb_tbl_arrow[1].value == expected_value @@ -154,9 +154,9 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): second_value = pa.MonthDayNano([90, 12, 0]) dictionary = pa.array([first_value, second_value, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) duckdb_cursor.from_arrow(arrow_table).create("dictionarytbl") - duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()['a'] + duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()["a"] assert duckdb_tbl_arrow[0].value == first_value assert duckdb_tbl_arrow[1].value == second_value @@ -170,7 +170,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): # List query = duckdb_cursor.sql( "SELECT a from (select list_value(INTERVAL 3 MONTHS, INTERVAL 5 DAYS, INTERVAL 10 SECONDS, NULL) as a) as t" - ).fetch_arrow_table()['a'] + ).fetch_arrow_table()["a"] assert query[0][0].value == pa.MonthDayNano([3, 0, 0]) assert query[0][1].value == pa.MonthDayNano([0, 5, 0]) assert query[0][2].value == pa.MonthDayNano([0, 0, 10000000000]) @@ -180,25 +180,25 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): query = "SELECT a from (SELECT STRUCT_PACK(a := INTERVAL 1 MONTHS, b := INTERVAL 10 DAYS, c:= INTERVAL 20 SECONDS) as a) as t" true_answer = duckdb_cursor.sql(query).fetchall() from_arrow = duckdb_cursor.from_arrow(duckdb_cursor.sql(query).fetch_arrow_table()).fetchall() - assert true_answer[0][0]['a'] == from_arrow[0][0]['a'] - assert true_answer[0][0]['b'] == from_arrow[0][0]['b'] - assert true_answer[0][0]['c'] == from_arrow[0][0]['c'] + assert true_answer[0][0]["a"] == from_arrow[0][0]["a"] + assert true_answer[0][0]["b"] == from_arrow[0][0]["b"] + assert true_answer[0][0]["c"] == from_arrow[0][0]["c"] def test_min_max_interval_roundtrip(self, duckdb_cursor): interval_min_value = pa.MonthDayNano([0, 0, 0]) interval_max_value = pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) data = pa.array([interval_min_value, interval_max_value], pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervalminmaxtbl") - duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()["a"] assert duck_arrow_tbl[0].value == pa.MonthDayNano([0, 0, 0]) assert duck_arrow_tbl[1].value == pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) def test_duplicate_column_names(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df_a = pd.DataFrame({'join_key': [1, 2, 3], 'col_a': ['a', 'b', 'c']}) - df_b = pd.DataFrame({'join_key': [1, 3, 4], 'col_a': ['x', 'y', 'z']}) + df_a = pd.DataFrame({"join_key": [1, 2, 3], "col_a": ["a", "b", "c"]}) + df_b = pd.DataFrame({"join_key": [1, 3, 4], "col_a": ["x", "y", "z"]}) res = duckdb_cursor.execute( """ @@ -210,7 +210,7 @@ def test_duplicate_column_names(self, duckdb_cursor): table1.join_key = table2.join_key """ ).fetch_arrow_table() - assert res.schema.names == ['join_key', 'col_a', 'join_key', 'col_a'] + assert 
res.schema.names == ["join_key", "col_a", "join_key", "col_a"] def test_strings_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a varchar)") diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index a548818f..32b7fa64 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -17,45 +17,45 @@ def test_duration_types(self, duckdb_cursor): if not can_run: return expected_arrow = pa.Table.from_arrays( - [pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())], ['a'] + [pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())], ["a"] ) data = ( - pa.array([1000000000], type=pa.duration('ns')), - pa.array([1000000], type=pa.duration('us')), - pa.array([1000], pa.duration('ms')), - pa.array([1], pa.duration('s')), + pa.array([1000000000], type=pa.duration("ns")), + pa.array([1000000], type=pa.duration("us")), + pa.array([1000], pa.duration("ms")), + pa.array([1], pa.duration("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == expected_arrow['a'] - assert rel['b'] == expected_arrow['a'] - assert rel['c'] == expected_arrow['a'] - assert rel['d'] == expected_arrow['a'] + assert rel["a"] == expected_arrow["a"] + assert rel["b"] == expected_arrow["a"] + assert rel["c"] == expected_arrow["a"] + assert rel["d"] == expected_arrow["a"] def test_duration_null(self, duckdb_cursor): if not can_run: return - expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())], ['a']) + expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())], ["a"]) data = ( - pa.array([None], type=pa.duration('ns')), - pa.array([None], type=pa.duration('us')), - pa.array([None], pa.duration('ms')), - pa.array([None], pa.duration('s')), + pa.array([None], type=pa.duration("ns")), + pa.array([None], type=pa.duration("us")), + pa.array([None], pa.duration("ms")), + pa.array([None], pa.duration("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == expected_arrow['a'] - assert rel['b'] == expected_arrow['a'] - assert rel['c'] == expected_arrow['a'] - assert rel['d'] == expected_arrow['a'] + assert rel["a"] == expected_arrow["a"] + assert rel["b"] == expected_arrow["a"] + assert rel["c"] == expected_arrow["a"] + assert rel["d"] == expected_arrow["a"] def test_duration_overflow(self, duckdb_cursor): if not can_run: return # Only seconds can overflow - data = pa.array([9223372036854775807], pa.duration('s')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([9223372036854775807], pa.duration("s")) + arrow_table = pa.Table.from_arrays([data], ["a"]) - with pytest.raises(duckdb.ConversionException, match='Could not convert Interval to Microsecond'): + with pytest.raises(duckdb.ConversionException, match="Could not convert Interval to Microsecond"): arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() diff --git a/tests/fast/arrow/test_large_offsets.py b/tests/fast/arrow/test_large_offsets.py index 1bcdd1b7..dccfa101 100644 --- 
a/tests/fast/arrow/test_large_offsets.py +++ b/tests/fast/arrow/test_large_offsets.py @@ -18,7 +18,7 @@ def test_large_lists(self, duckdb_cursor): tbl = pa.Table.from_pydict(dict(col=ary)) with pytest.raises( duckdb.InvalidInputException, - match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', + match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.", ): res = duckdb_cursor.sql("SELECT col FROM tbl").fetch_arrow_table() @@ -34,7 +34,7 @@ def test_large_maps(self, duckdb_cursor): with pytest.raises( duckdb.InvalidInputException, - match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', + match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.", ): arrow_map = duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() diff --git a/tests/fast/arrow/test_large_string.py b/tests/fast/arrow/test_large_string.py index 4836048d..308785af 100644 --- a/tests/fast/arrow/test_large_string.py +++ b/tests/fast/arrow/test_large_string.py @@ -22,4 +22,4 @@ def test_large_string_type(self, duckdb_cursor): rel = duckdb.from_arrow(arrow_table) res = rel.execute().fetchall() - assert res == [('foo',), ('baaaar',), ('b',)] + assert res == [("foo",), ("baaaar",), ("b",)] diff --git a/tests/fast/arrow/test_multiple_reads.py b/tests/fast/arrow/test_multiple_reads.py index 935a8a9c..36fb8f59 100644 --- a/tests/fast/arrow/test_multiple_reads.py +++ b/tests/fast/arrow/test_multiple_reads.py @@ -14,8 +14,8 @@ class TestArrowReads(object): def test_multiple_queries_same_relation(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename) userdata_parquet_table.validate(full=True) diff --git a/tests/fast/arrow/test_nested_arrow.py b/tests/fast/arrow/test_nested_arrow.py index 693a5155..a906324f 100644 --- a/tests/fast/arrow/test_nested_arrow.py +++ b/tests/fast/arrow/test_nested_arrow.py @@ -16,13 +16,13 @@ def compare_results(duckdb_cursor, query): def arrow_to_pandas(duckdb_cursor, query): - return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()['a'].values.tolist() + return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()["a"].values.tolist() def get_use_list_view_options(): result = [] result.append(False) - if hasattr(pa, 'ListViewArray'): + if hasattr(pa, "ListViewArray"): result.append(True) return result @@ -32,7 +32,7 @@ def test_lists_basic(self, duckdb_cursor): # Test Constant List query = ( duckdb_cursor.query("SELECT a from (select list_value(3,5,10) as a) as t") - .fetch_arrow_table()['a'] + .fetch_arrow_table()["a"] .to_numpy() ) assert query[0][0] == 3 @@ -40,32 +40,32 @@ def test_lists_basic(self, duckdb_cursor): assert query[0][2] == 10 # Empty List - query = duckdb_cursor.query("SELECT a from (select list_value() as a) as 
t").fetch_arrow_table()['a'].to_numpy() + query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").fetch_arrow_table()["a"].to_numpy() assert len(query[0]) == 0 # Test Constant List With Null query = ( duckdb_cursor.query("SELECT a from (select list_value(3,NULL) as a) as t") - .fetch_arrow_table()['a'] + .fetch_arrow_table()["a"] .to_numpy() ) assert query[0][0] == 3 assert np.isnan(query[0][1]) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_list_types(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") # Large Lists data = pa.array([[1], None, [2]], type=pa.large_list(pa.int64())) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [([1],), (None,), ([2],)] # Fixed Size Lists data = pa.array([[1], None, [2]], type=pa.list_(pa.int64(), 1)) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [((1,),), (None,), ((2,),)] @@ -76,27 +76,27 @@ def test_list_types(self, duckdb_cursor, use_list_view): pa.array([[1], None, [2]], type=pa.large_list(pa.int64())), pa.array([[1, 2, 3], None, [2, 1]], type=pa.list_(pa.int64())), ] - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ["a", "b", "c"]) rel = duckdb_cursor.from_arrow(arrow_table) - res = rel.project('a').execute().fetchall() + res = rel.project("a").execute().fetchall() assert res == [((1,),), (None,), ((2,),)] - res = rel.project('b').execute().fetchall() + res = rel.project("b").execute().fetchall() assert res == [([1],), (None,), ([2],)] - res = rel.project('c').execute().fetchall() + res = rel.project("c").execute().fetchall() assert res == [([1, 2, 3],), (None,), ([2, 1],)] # Struct Holding different List Types - struct = [pa.StructArray.from_arrays(data, ['fixed', 'large', 'normal'])] - arrow_table = pa.Table.from_arrays(struct, ['a']) + struct = [pa.StructArray.from_arrays(data, ["fixed", "large", "normal"])] + arrow_table = pa.Table.from_arrays(struct, ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [ - ({'fixed': (1,), 'large': [1], 'normal': [1, 2, 3]},), - ({'fixed': None, 'large': None, 'normal': None},), - ({'fixed': (2,), 'large': [2], 'normal': [2, 1]},), + ({"fixed": (1,), "large": [1], "normal": [1, 2, 3]},), + ({"fixed": None, "large": None, "normal": None},), + ({"fixed": (2,), "large": [2], "normal": [2, 1]},), ] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) @pytest.mark.skip(reason="FIXME: this fails on CI") def test_lists_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -132,8 +132,8 @@ def test_lists_roundtrip(self, duckdb_cursor, use_list_view): compare_results( duckdb_cursor, - '''SELECT grp,lst,cs FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as cs - from (SELECT a%4 as grp, list(a order by a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T order by all;''', + """SELECT grp,lst,cs FROM (select grp, lst, 
case when grp>1 then lst else list_value(null) end as cs + from (SELECT a%4 as grp, list(a order by a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T order by all;""", ) # Tests for converting multiple lists to/from Arrow with NULL values and/or strings compare_results( @@ -141,7 +141,7 @@ def test_lists_roundtrip(self, duckdb_cursor, use_list_view): "SELECT list(st order by st) from (select i, case when i%10 then NULL else i::VARCHAR end as st from range(1000) tbl(i)) as t group by i%5 order by all", ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_struct_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -156,7 +156,7 @@ def test_struct_roundtrip(self, duckdb_cursor, use_list_view): "SELECT a from (SELECT STRUCT_PACK(a := LIST_VALUE(1,2,3), b := i) as a FROM range(10000) tbl(i)) as t", ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -185,13 +185,13 @@ def test_map_roundtrip(self, duckdb_cursor, use_list_view): "SELECT m from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10000) tbl(i) group by i%5 order by all) as lst_tbl) as T", ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_arrow_to_duckdb(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") map_type = pa.map_(pa.int32(), pa.int32()) values = [[(3, 12), (3, 21)], [(5, 42)]] - arrow_table = pa.table({'detail': pa.array(values, map_type)}) + arrow_table = pa.table({"detail": pa.array(values, map_type)}) with pytest.raises( duckdb.InvalidInputException, match="Arrow map contains duplicate key, which isn't supported by DuckDB map type", @@ -201,11 +201,11 @@ def test_map_arrow_to_duckdb(self, duckdb_cursor, use_list_view): def test_null_map_arrow_to_duckdb(self, duckdb_cursor): map_type = pa.map_(pa.int32(), pa.int32()) values = [None, [(5, 42)]] - arrow_table = pa.table({'detail': pa.array(values, map_type)}) + arrow_table = pa.table({"detail": pa.array(values, map_type)}) res = duckdb_cursor.sql("select * from arrow_table").fetchall() assert res == [(None,), ({5: 42},)] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_arrow_to_pandas(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") assert arrow_to_pandas( @@ -215,16 +215,16 @@ def test_map_arrow_to_pandas(self, duckdb_cursor, use_list_view): assert arrow_to_pandas( duckdb_cursor, "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t", - ) == [[('Jon Lajoie', 10), ('Backstreet Boys', 9), ('Tenacious D', 10)]] + ) == [[("Jon Lajoie", 10), ("Backstreet Boys", 9), ("Tenacious D", 10)]] assert arrow_to_pandas( duckdb_cursor, "SELECT a from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(a)" ) == [[(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)]] assert arrow_to_pandas( duckdb_cursor, "SELECT 
MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a", - ) == [[({'i': 1, 'j': 2}, {'i': 1, 'j': 2}), ({'i': 3, 'j': 4}, {'i': 3, 'j': 4})]] + ) == [[({"i": 1, "j": 2}, {"i": 1, "j": 2}), ({"i": 3, "j": 4}, {"i": 3, "j": 4})]] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_frankstein_nested(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") diff --git a/tests/fast/arrow/test_parallel.py b/tests/fast/arrow/test_parallel.py index 2609d1ae..c768a1dd 100644 --- a/tests/fast/arrow/test_parallel.py +++ b/tests/fast/arrow/test_parallel.py @@ -19,7 +19,7 @@ def test_parallel_run(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") data = pyarrow.array(np.random.randint(800, size=1000000), type=pyarrow.int32()) - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(10000)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(10000)) rel = duckdb_conn.from_arrow(tbl) # Also test multiple reads assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000000 @@ -32,17 +32,17 @@ def test_parallel_types_and_different_batches(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename) for i in [7, 51, 99, 100, 101, 500, 1000, 2000]: data = pyarrow.array(np.arange(3, 7), type=pyarrow.int32()) - tbl = pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel_id = duckdb_conn.from_arrow(tbl) userdata_parquet_table2 = pyarrow.Table.from_batches(userdata_parquet_table.to_batches(i)) rel = duckdb_conn.from_arrow(userdata_parquet_table2) - result = rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)') + result = rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)") assert result.execute().fetchone()[0] == 4 def test_parallel_fewer_batches_than_threads(self, duckdb_cursor): @@ -53,7 +53,7 @@ def test_parallel_fewer_batches_than_threads(self, duckdb_cursor): duckdb_conn.execute("PRAGMA verify_parallelism") data = pyarrow.array(np.random.randint(800, size=1000), type=pyarrow.int32()) - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(2)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(2)) rel = duckdb_conn.from_arrow(tbl) # Also test multiple reads assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000 diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 87e2f726..a4e94d18 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -31,21 +31,21 @@ def test_polars(self, duckdb_cursor): } ) # scan plus return a polars dataframe - polars_result = duckdb_cursor.sql('SELECT * FROM df').pl() + 
polars_result = duckdb_cursor.sql("SELECT * FROM df").pl() pl_testing.assert_frame_equal(df, polars_result) # now do the same for a lazy dataframe lazy_df = df.lazy() - lazy_result = duckdb_cursor.sql('SELECT * FROM lazy_df').pl() + lazy_result = duckdb_cursor.sql("SELECT * FROM lazy_df").pl() pl_testing.assert_frame_equal(df, lazy_result) con = duckdb.connect() - con_result = con.execute('SELECT * FROM df').pl() + con_result = con.execute("SELECT * FROM df").pl() pl_testing.assert_frame_equal(df, con_result) def test_execute_polars(self, duckdb_cursor): res1 = duckdb_cursor.execute("SELECT 1 AS a, 2 AS a").pl() - assert res1.columns == ['a', 'a_1'] + assert res1.columns == ["a", "a_1"] def test_register_polars(self, duckdb_cursor): con = duckdb.connect() @@ -58,21 +58,21 @@ def test_register_polars(self, duckdb_cursor): } ) # scan plus return a polars dataframe - con.register('polars_df', df) - polars_result = con.execute('select * from polars_df').pl() + con.register("polars_df", df) + polars_result = con.execute("select * from polars_df").pl() pl_testing.assert_frame_equal(df, polars_result) - con.unregister('polars_df') - with pytest.raises(duckdb.CatalogException, match='Table with name polars_df does not exist'): + con.unregister("polars_df") + with pytest.raises(duckdb.CatalogException, match="Table with name polars_df does not exist"): con.execute("SELECT * FROM polars_df;").pl() - con.register('polars_df', df.lazy()) - polars_result = con.execute('select * from polars_df').pl() + con.register("polars_df", df.lazy()) + polars_result = con.execute("select * from polars_df").pl() pl_testing.assert_frame_equal(df, polars_result) def test_empty_polars_dataframe(self, duckdb_cursor): polars_empty_df = pl.DataFrame() with pytest.raises( - duckdb.InvalidInputException, match='Provided table/dataframe must have at least one column' + duckdb.InvalidInputException, match="Provided table/dataframe must have at least one column" ): duckdb_cursor.sql("from polars_empty_df") @@ -82,7 +82,7 @@ def test_polars_from_json(self, duckdb_cursor): duckdb_cursor.sql("set arrow_lossless_conversion=false") string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""") res = duckdb_cursor.read_json(string).pl() - assert str(res['entry'][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" + assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" @pytest.mark.skipif( not hasattr(pl.exceptions, "PanicException"), reason="Polars has no PanicException in this version" @@ -93,13 +93,13 @@ def test_polars_from_json_error(self, duckdb_cursor): duckdb_cursor.sql("set arrow_lossless_conversion=true") string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""") res = duckdb_cursor.read_json(string).pl() - assert duckdb_cursor.execute("FROM res").fetchall() == [([{'content': {'ManagedSystem': {'test': None}}}],)] + assert duckdb_cursor.execute("FROM res").fetchall() == [([{"content": {"ManagedSystem": {"test": None}}}],)] def test_polars_from_json_error(self, duckdb_cursor): conn = duckdb.connect() my_table = conn.query("select 'x' my_str").pl() my_res = duckdb.query("select my_str from my_table where my_str != 'y'") - assert my_res.fetchall() == [('x',)] + assert my_res.fetchall() == [("x",)] def test_polars_lazy_from_conn(self, duckdb_cursor): duckdb_conn = duckdb.connect() @@ -107,7 +107,7 @@ def test_polars_lazy_from_conn(self, duckdb_cursor): result = duckdb_conn.execute("SELECT 42 as bla") lazy_df = result.pl(lazy=True) - assert 
lazy_df.collect().to_dicts() == [{'bla': 42}] + assert lazy_df.collect().to_dicts() == [{"bla": 42}] def test_polars_lazy(self, duckdb_cursor): con = duckdb.connect() @@ -118,18 +118,18 @@ def test_polars_lazy(self, duckdb_cursor): assert isinstance(lazy_df, pl.LazyFrame) assert lazy_df.collect().to_dicts() == [ - {'a': 'Pedro', 'b': 32}, - {'a': 'Mark', 'b': 31}, - {'a': 'Thijs', 'b': 29}, + {"a": "Pedro", "b": 32}, + {"a": "Mark", "b": 31}, + {"a": "Thijs", "b": 29}, ] - assert lazy_df.select('a').collect().to_dicts() == [{'a': 'Pedro'}, {'a': 'Mark'}, {'a': 'Thijs'}] - assert lazy_df.limit(1).collect().to_dicts() == [{'a': 'Pedro', 'b': 32}] + assert lazy_df.select("a").collect().to_dicts() == [{"a": "Pedro"}, {"a": "Mark"}, {"a": "Thijs"}] + assert lazy_df.limit(1).collect().to_dicts() == [{"a": "Pedro", "b": 32}] assert lazy_df.filter(pl.col("b") < 32).collect().to_dicts() == [ - {'a': 'Mark', 'b': 31}, - {'a': 'Thijs', 'b': 29}, + {"a": "Mark", "b": 31}, + {"a": "Thijs", "b": 29}, ] - assert lazy_df.filter(pl.col("b") < 32).select('a').collect().to_dicts() == [{'a': 'Mark'}, {'a': 'Thijs'}] + assert lazy_df.filter(pl.col("b") < 32).select("a").collect().to_dicts() == [{"a": "Mark"}, {"a": "Thijs"}] def test_polars_column_with_tricky_name(self, duckdb_cursor): # Test that a polars DataFrame with a column name that is non standard still works @@ -162,23 +162,23 @@ def test_polars_column_with_tricky_name(self, duckdb_cursor): assert result.to_dicts() == [{'"xy"': 1}] @pytest.mark.parametrize( - 'data_type', + "data_type", [ - 'TINYINT', - 'SMALLINT', - 'INTEGER', - 'BIGINT', - 'UTINYINT', - 'USMALLINT', - 'UINTEGER', - 'UBIGINT', - 'FLOAT', - 'DOUBLE', - 'HUGEINT', - 'DECIMAL(4,1)', - 'DECIMAL(9,1)', - 'DECIMAL(18,4)', - 'DECIMAL(30,12)', + "TINYINT", + "SMALLINT", + "INTEGER", + "BIGINT", + "UTINYINT", + "USMALLINT", + "UINTEGER", + "UBIGINT", + "FLOAT", + "DOUBLE", + "HUGEINT", + "DECIMAL(4,1)", + "DECIMAL(9,1)", + "DECIMAL(18,4)", + "DECIMAL(30,12)", ], ) def test_polars_lazy_pushdown_numeric(self, data_type, duckdb_cursor): @@ -524,9 +524,9 @@ def test_polars_lazy_pushdown_blob(self, duckdb_cursor): df = pandas.DataFrame( { - 'a': [bytes([1]), bytes([2]), bytes([3]), None], - 'b': [bytes([1]), bytes([2]), bytes([3]), None], - 'c': [bytes([1]), bytes([2]), bytes([3]), None], + "a": [bytes([1]), bytes([2]), bytes([3]), None], + "b": [bytes([1]), bytes([2]), bytes([3]), None], + "c": [bytes([1]), bytes([2]), bytes([3]), None], } ) duck_tbl = duckdb.from_df(df) diff --git a/tests/fast/arrow/test_progress.py b/tests/fast/arrow/test_progress.py index c20ebe51..6f056937 100644 --- a/tests/fast/arrow/test_progress.py +++ b/tests/fast/arrow/test_progress.py @@ -8,7 +8,7 @@ class TestProgressBarArrow(object): def test_progress_arrow(self): - if os.name == 'nt': + if os.name == "nt": return np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") @@ -18,9 +18,9 @@ def test_progress_arrow(self): duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == 49999995000000 # Multiple Threads duckdb_conn.execute("PRAGMA threads=4") @@ -28,9 +28,9 @@ def test_progress_arrow(self): assert result.execute().fetchone()[0] == 49999995000000 # More than one batch - tbl = 
pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(100)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(100)) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == 49999995000000 # Single Thread @@ -40,7 +40,7 @@ def test_progress_arrow(self): assert py_res == 49999995000000 def test_progress_arrow_empty(self): - if os.name == 'nt': + if os.name == "nt": return np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") @@ -50,7 +50,7 @@ def test_progress_arrow_empty(self): duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == None diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index 726b0f6a..e7c4404e 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -18,60 +18,60 @@ def test_time_types(self, duckdb_cursor): return data = ( - pa.array([1], type=pa.time32('s')), - pa.array([1000], type=pa.time32('ms')), - pa.array([1000000], pa.time64('us')), - pa.array([1000000000], pa.time64('ns')), + pa.array([1], type=pa.time32("s")), + pa.array([1000], type=pa.time32("ms")), + pa.array([1000000], pa.time64("us")), + pa.array([1000000000], pa.time64("ns")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['c'] - assert rel['b'] == arrow_table['c'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['c'] + assert rel["a"] == arrow_table["c"] + assert rel["b"] == arrow_table["c"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["c"] def test_time_null(self, duckdb_cursor): if not can_run: return data = ( - pa.array([None], type=pa.time32('s')), - pa.array([None], type=pa.time32('ms')), - pa.array([None], pa.time64('us')), - pa.array([None], pa.time64('ns')), + pa.array([None], type=pa.time32("s")), + pa.array([None], type=pa.time32("ms")), + pa.array([None], pa.time64("us")), + pa.array([None], pa.time64("ns")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['c'] - assert rel['b'] == arrow_table['c'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['c'] + assert rel["a"] == arrow_table["c"] + assert rel["b"] == arrow_table["c"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["c"] def test_max_times(self, duckdb_cursor): if not can_run: return - data = pa.array([2147483647000000], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647000000], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) # Max Sec - data = pa.array([2147483647], type=pa.time32('s')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647], type=pa.time32("s")) + arrow_table = 
pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] + assert rel["a"] == result["a"] # Max MSec - data = pa.array([2147483647000], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) - data = pa.array([2147483647], type=pa.time32('ms')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647000], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) + data = pa.array([2147483647], type=pa.time32("ms")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] + assert rel["a"] == result["a"] # Max NSec - data = pa.array([9223372036854774], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) - data = pa.array([9223372036854774000], type=pa.time64('ns')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([9223372036854774], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) + data = pa.array([9223372036854774000], type=pa.time64("ns")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - print(rel['a']) - print(result['a']) - assert rel['a'] == result['a'] + print(rel["a"]) + print(result["a"]) + assert rel["a"] == result["a"] diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 4fdadf49..08816be1 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -3,7 +3,7 @@ import datetime import pytz -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") def generate_table(current_time, precision, timezone): @@ -13,30 +13,30 @@ def generate_table(current_time, precision, timezone): return pa.Table.from_arrays(inputs, schema=schema) -timezones = ['UTC', 'BET', 'CET', 'Asia/Kathmandu'] +timezones = ["UTC", "BET", "CET", "Asia/Kathmandu"] class TestArrowTimestampsTimezone(object): def test_timestamp_timezone(self, duckdb_cursor): - precisions = ['us', 's', 'ns', 'ms'] + precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59, tzinfo=pytz.UTC) con = duckdb.connect() con.execute("SET TimeZone = 'UTC'") for precision in precisions: - arrow_table = generate_table(current_time, precision, 'UTC') + arrow_table = generate_table(current_time, precision, "UTC") res_utc = con.from_arrow(arrow_table).execute().fetchall() assert res_utc[0][0] == current_time def test_timestamp_timezone_overflow(self, duckdb_cursor): - precisions = ['s', 'ms'] + precisions = ["s", "ms"] current_time = 9223372036854775807 for precision in precisions: - with pytest.raises(duckdb.ConversionException, match='Could not convert'): - arrow_table = generate_table(current_time, precision, 'UTC') + with pytest.raises(duckdb.ConversionException, match="Could not convert"): + arrow_table = generate_table(current_time, precision, "UTC") res_utc = duckdb.from_arrow(arrow_table).execute().fetchall() def test_timestamp_tz_to_arrow(self, duckdb_cursor): - precisions = ['us', 's', 'ns', 'ms'] + precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) con = duckdb.connect() for precision in precisions: @@ -44,16 +44,16 @@ def test_timestamp_tz_to_arrow(self, duckdb_cursor): con.execute("SET TimeZone = '" + timezone + "'") arrow_table = generate_table(current_time, precision, timezone) res = 
con.from_arrow(arrow_table).fetch_arrow_table() - assert res[0].type == pa.timestamp('us', tz=timezone) - assert res == generate_table(current_time, 'us', timezone) + assert res[0].type == pa.timestamp("us", tz=timezone) + assert res == generate_table(current_time, "us", timezone) def test_timestamp_tz_with_null(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') + rel = con.table("t") arrow_tbl = rel.fetch_arrow_table() - con.register('t2', arrow_tbl) + con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() @@ -61,8 +61,8 @@ def test_timestamp_stream(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') + rel = con.table("t") arrow_tbl = rel.record_batch().read_all() - con.register('t2', arrow_tbl) + con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() diff --git a/tests/fast/arrow/test_timestamps.py b/tests/fast/arrow/test_timestamps.py index c2529c83..684a333c 100644 --- a/tests/fast/arrow/test_timestamps.py +++ b/tests/fast/arrow/test_timestamps.py @@ -17,61 +17,61 @@ def test_timestamp_types(self, duckdb_cursor): if not can_run: return data = ( - pa.array([datetime.datetime.now()], type=pa.timestamp('ns')), - pa.array([datetime.datetime.now()], type=pa.timestamp('us')), - pa.array([datetime.datetime.now()], pa.timestamp('ms')), - pa.array([datetime.datetime.now()], pa.timestamp('s')), + pa.array([datetime.datetime.now()], type=pa.timestamp("ns")), + pa.array([datetime.datetime.now()], type=pa.timestamp("us")), + pa.array([datetime.datetime.now()], pa.timestamp("ms")), + pa.array([datetime.datetime.now()], pa.timestamp("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['a'] - assert rel['b'] == arrow_table['b'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['d'] + assert rel["a"] == arrow_table["a"] + assert rel["b"] == arrow_table["b"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_timestamp_nulls(self, duckdb_cursor): if not can_run: return data = ( - pa.array([None], type=pa.timestamp('ns')), - pa.array([None], type=pa.timestamp('us')), - pa.array([None], pa.timestamp('ms')), - pa.array([None], pa.timestamp('s')), + pa.array([None], type=pa.timestamp("ns")), + pa.array([None], type=pa.timestamp("us")), + pa.array([None], pa.timestamp("ms")), + pa.array([None], pa.timestamp("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['a'] - assert rel['b'] == arrow_table['b'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['d'] + assert rel["a"] == arrow_table["a"] + assert rel["b"] == arrow_table["b"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_timestamp_overflow(self, 
duckdb_cursor): if not can_run: return data = ( - pa.array([9223372036854775807], pa.timestamp('s')), - pa.array([9223372036854775807], pa.timestamp('ms')), - pa.array([9223372036854775807], pa.timestamp('us')), + pa.array([9223372036854775807], pa.timestamp("s")), + pa.array([9223372036854775807], pa.timestamp("ms")), + pa.array([9223372036854775807], pa.timestamp("us")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ["a", "b", "c"]) arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert arrow_from_duck['a'] == arrow_table['a'] - assert arrow_from_duck['b'] == arrow_table['b'] - assert arrow_from_duck['c'] == arrow_table['c'] + assert arrow_from_duck["a"] == arrow_table["a"] + assert arrow_from_duck["b"] == arrow_table["b"] + assert arrow_from_duck["c"] == arrow_table["c"] expected = (datetime.datetime(9999, 12, 31, 23, 59, 59, 999999),) duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('a::TIMESTAMP_US') + res = duck_rel.project("a::TIMESTAMP_US") result = res.fetchone() assert result == expected duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('b::TIMESTAMP_US') + res = duck_rel.project("b::TIMESTAMP_US") result = res.fetchone() assert result == expected duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('c::TIMESTAMP_NS') + res = duck_rel.project("c::TIMESTAMP_NS") result = res.fetchone() assert result == expected diff --git a/tests/fast/arrow/test_tpch.py b/tests/fast/arrow/test_tpch.py index ff4a0445..d5d13b20 100644 --- a/tests/fast/arrow/test_tpch.py +++ b/tests/fast/arrow/test_tpch.py @@ -24,7 +24,7 @@ def check_result(result, answers): db_result = result.fetchone() cq_results = q_res.split("|") # The end of the rows, continue - if cq_results == [''] and str(db_result) == 'None' or str(db_result[0]) == 'None': + if cq_results == [""] and str(db_result) == "None" or str(db_result[0]) == "None": continue ans_result = [munge(cell) for cell in cq_results] db_result = [munge(cell) for cell in db_result] @@ -39,7 +39,7 @@ def test_tpch_arrow(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() @@ -69,7 +69,7 @@ def test_tpch_arrow_01(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() @@ -97,7 +97,7 @@ def test_tpch_arrow_batch(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() diff --git a/tests/fast/arrow/test_unregister.py b/tests/fast/arrow/test_unregister.py index c63ef0d6..8ff37b5a 100644 --- a/tests/fast/arrow/test_unregister.py +++ b/tests/fast/arrow/test_unregister.py @@ -17,8 +17,8 @@ class TestArrowUnregister(object): def test_arrow_unregister1(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) connection = duckdb.connect(":memory:") @@ -26,9 +26,9 @@ def test_arrow_unregister1(self, duckdb_cursor): arrow_table_2 = connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.unregister("arrow_table") - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() - with pytest.raises(duckdb.CatalogException, match='View with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="View with name arrow_table does not exist"): connection.execute("DROP VIEW arrow_table;") connection.execute("DROP VIEW IF EXISTS arrow_table;") @@ -40,8 +40,8 @@ def test_arrow_unregister2(self, duckdb_cursor): os.remove(db) connection = duckdb.connect(db) - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) connection.register("arrow_table", arrow_table_obj) connection.unregister("arrow_table") # Attempting to unregister. @@ -49,7 +49,7 @@ def test_arrow_unregister2(self, duckdb_cursor): # Reconnecting while Arrow Table still in mem. connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.close() del arrow_table_obj @@ -57,6 +57,6 @@ def test_arrow_unregister2(self, duckdb_cursor): # Reconnecting after Arrow Table is freed. 
connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.close() diff --git a/tests/fast/arrow/test_view.py b/tests/fast/arrow/test_view.py index 54acb336..7f1410aa 100644 --- a/tests/fast/arrow/test_view.py +++ b/tests/fast/arrow/test_view.py @@ -8,9 +8,9 @@ class TestArrowView(object): def test_arrow_view(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pa.parquet.read_table(parquet_filename) userdata_parquet_table.validate(full=True) - duckdb_cursor.from_arrow(userdata_parquet_table).create_view('arrow_view') - assert duckdb_cursor.execute("PRAGMA show_tables").fetchone() == ('arrow_view',) + duckdb_cursor.from_arrow(userdata_parquet_table).create_view("arrow_view") + assert duckdb_cursor.execute("PRAGMA show_tables").fetchone() == ("arrow_view",) assert duckdb_cursor.execute("select avg(salary)::INT from arrow_view").fetchone()[0] == 149005 diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index 4267085c..b872d4d9 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -28,11 +28,11 @@ def test_scan_numpy(self, duckdb_cursor): z = np.array(["zzz", "xxx"]) res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [('zzz',), ('xxx',)] + assert res == [("zzz",), ("xxx",)] z = [np.array(["zzz", "xxx"]), np.array([1, 2])] res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [('zzz', 1), ('xxx', 2)] + assert res == [("zzz", 1), ("xxx", 2)] # test ndarray with dtype = object (python dict) z = [] @@ -41,9 +41,9 @@ def test_scan_numpy(self, duckdb_cursor): z = np.array(z) res = duckdb_cursor.sql("select * from z").fetchall() assert res == [ - ({'3': 0},), - ({'2': 1},), - ({'1': 2},), + ({"3": 0},), + ({"2": 1},), + ({"1": 2},), ] # test timedelta @@ -74,12 +74,12 @@ def test_scan_numpy(self, duckdb_cursor): # dict of mixed types z = {"z": np.array([1, 2, 3]), "x": np.array(["z", "x", "c"])} res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [(1, 'z'), (2, 'x'), (3, 'c')] + assert res == [(1, "z"), (2, "x"), (3, "c")] # list of mixed types z = [np.array([1, 2, 3]), np.array(["z", "x", "c"])] res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [(1, 'z'), (2, 'x'), (3, 'c')] + assert res == [(1, "z"), (2, "x"), (3, "c")] # currently unsupported formats, will throw duckdb.InvalidInputException diff --git a/tests/fast/pandas/test_2304.py b/tests/fast/pandas/test_2304.py index 6fc355e5..11344df8 100644 --- a/tests/fast/pandas/test_2304.py +++ b/tests/fast/pandas/test_2304.py @@ -5,37 +5,37 @@ class TestPandasMergeSameName(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_2304(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id_1': [1, 1, 1, 2, 2], - 'agedate': np.array(['2010-01-01', '2010-02-01', '2010-03-01', '2020-02-01', '2020-03-01']).astype( - 'datetime64[D]' + "id_1": [1, 1, 1, 2, 2], + "agedate": 
np.array(["2010-01-01", "2010-02-01", "2010-03-01", "2020-02-01", "2020-03-01"]).astype( + "datetime64[D]" ), - 'age': [1, 2, 3, 1, 2], - 'v': [1.1, 1.2, 1.3, 2.1, 2.2], + "age": [1, 2, 3, 1, 2], + "v": [1.1, 1.2, 1.3, 2.1, 2.2], } ) df2 = pandas.DataFrame( { - 'id_1': [1, 1, 2], - 'agedate': np.array(['2010-01-01', '2010-02-01', '2020-03-01']).astype('datetime64[D]'), - 'v2': [11.1, 11.2, 21.2], + "id_1": [1, 1, 2], + "agedate": np.array(["2010-01-01", "2010-02-01", "2020-03-01"]).astype("datetime64[D]"), + "v2": [11.1, 11.2, 21.2], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) query = """SELECT * from df1 LEFT OUTER JOIN df2 ON (df1.id_1=df2.id_1 and df1.agedate=df2.agedate) order by df1.id_1, df1.agedate, df1.age, df1.v, df2.id_1,df2.agedate,df2.v2""" result_df = con.execute(query).fetchdf() expected_result = con.execute(query).fetchall() - con.register('result_df', result_df) + con.register("result_df", result_df) rel = con.sql( """ select * from result_df order by @@ -52,32 +52,32 @@ def test_2304(self, duckdb_cursor, pandas): assert result == expected_result - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pd_names(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1, 1, 2], - 'id_1': [1, 1, 2], - 'id_3': [1, 1, 2], + "id": [1, 1, 2], + "id_1": [1, 1, 2], + "id_3": [1, 1, 2], } ) - df2 = pandas.DataFrame({'id': [1, 1, 2], 'id_1': [1, 1, 2], 'id_2': [1, 1, 1]}) + df2 = pandas.DataFrame({"id": [1, 1, 2], "id_1": [1, 1, 2], "id_2": [1, 1, 1]}) exp_result = pandas.DataFrame( { - 'id': [1, 1, 2, 1, 1], - 'id_1': [1, 1, 2, 1, 1], - 'id_3': [1, 1, 2, 1, 1], - 'id_2': [1, 1, 2, 1, 1], - 'id_1_1': [1, 1, 2, 1, 1], - 'id_2_1': [1, 1, 1, 1, 1], + "id": [1, 1, 2, 1, 1], + "id_1": [1, 1, 2, 1, 1], + "id_3": [1, 1, 2, 1, 1], + "id_2": [1, 1, 2, 1, 1], + "id_1_1": [1, 1, 2, 1, 1], + "id_2_1": [1, 1, 1, 1, 1], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) query = """SELECT * from df1 LEFT OUTER JOIN df2 ON (df1.id_1=df2.id_1)""" @@ -85,30 +85,30 @@ def test_pd_names(self, duckdb_cursor, pandas): result_df = con.execute(query).fetchdf() pandas.testing.assert_frame_equal(exp_result, result_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_repeat_name(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1], - 'id_1': [1], - 'id_2': [1], + "id": [1], + "id_1": [1], + "id_2": [1], } ) - df2 = pandas.DataFrame({'id': [1]}) + df2 = pandas.DataFrame({"id": [1]}) exp_result = pandas.DataFrame( { - 'id': [1], - 'id_1': [1], - 'id_2': [1], - 'id_3': [1], + "id": [1], + "id_1": [1], + "id_2": [1], + "id_3": [1], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) result_df = con.execute( """ diff --git a/tests/fast/pandas/test_append_df.py b/tests/fast/pandas/test_append_df.py index 18805a5a..e6d64776 100644 --- a/tests/fast/pandas/test_append_df.py +++ b/tests/fast/pandas/test_append_df.py @@ -4,35 +4,35 @@ class TestAppendDF(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_to_table_append(self, duckdb_cursor, pandas): conn = 
duckdb.connect() conn.execute("Create table integers (i integer)") df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) - conn.append('integers', df_in) - assert conn.execute('select count(*) from integers').fetchone()[0] == 5 + conn.append("integers", df_in) + assert conn.execute("select count(*) from integers").fetchone()[0] == 5 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool, c varchar)") - df_in = pandas.DataFrame({'c': ['duck', 'db'], 'b': [False, True], 'a': [4, 2]}) + df_in = pandas.DataFrame({"c": ["duck", "db"], "b": [False, True], "a": [4, 2]}) # By default we append by position, causing the following exception: with pytest.raises( duckdb.ConversionException, match="Conversion Error: Could not convert string 'duck' to INT32" ): - con.append('tbl', df_in) + con.append("tbl", df_in) # When we use 'by_name' we instead append by name - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() - assert res == [(4, False, 'duck'), (2, True, 'db')] + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() + assert res == [(4, False, "duck"), (2, True, "db")] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name_quoted(self, pandas): con = duckdb.connect() con.execute( @@ -41,32 +41,32 @@ def test_append_by_name_quoted(self, pandas): """ ) df_in = pandas.DataFrame({"needs to be quoted": [1, 2, 3]}) - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() assert res == [(1, None), (2, None), (3, None)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name_no_exact_match(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool)") - df_in = pandas.DataFrame({'c': ['a', 'b'], 'b': [True, False], 'a': [42, 1337]}) + df_in = pandas.DataFrame({"c": ["a", "b"], "b": [True, False], "a": [42, 1337]}) # Too many columns raises an error, because the columns cant be found in the targeted table with pytest.raises(duckdb.BinderException, match='Table "tbl" does not have a column with name "c"'): - con.append('tbl', df_in, by_name=True) + con.append("tbl", df_in, by_name=True) - df_in = pandas.DataFrame({'b': [False, False, False]}) + df_in = pandas.DataFrame({"b": [False, False, False]}) # Not matching all columns is not a problem, as they will be filled with NULL instead - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() # 'a' got filled by NULL automatically because it wasn't inserted into assert res == [(None, False), (None, False), (None, False)] # Empty the table con.execute("create or replace table tbl (a integer, b bool)") - df_in = pandas.DataFrame({'a': [1, 2, 3]}) - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + df_in = pandas.DataFrame({"a": [1, 2, 3]}) + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() # Also works for missing columns *after* the supplied ones assert res == [(1, None), (2, None), (3, None)] diff --git 
a/tests/fast/pandas/test_bug2281.py b/tests/fast/pandas/test_bug2281.py index 703baf4b..98a90937 100644 --- a/tests/fast/pandas/test_bug2281.py +++ b/tests/fast/pandas/test_bug2281.py @@ -8,11 +8,11 @@ class TestPandasStringNull(object): def test_pandas_string_null(self, duckdb_cursor): - csv = u'''what,is_control,is_test + csv = """what,is_control,is_test ,0,0 -foo,1,0''' +foo,1,0""" df = pd.read_csv(io.StringIO(csv)) duckdb_cursor.register("c", df) - duckdb_cursor.execute('select what, count(*) from c group by what') + duckdb_cursor.execute("select what, count(*) from c group by what") df_result = duckdb_cursor.fetchdf() assert True # Should not crash ^^ diff --git a/tests/fast/pandas/test_bug5922.py b/tests/fast/pandas/test_bug5922.py index af9be167..28daabe9 100644 --- a/tests/fast/pandas/test_bug5922.py +++ b/tests/fast/pandas/test_bug5922.py @@ -4,13 +4,13 @@ class TestPandasAcceptFloat16(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_accept_float16(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'col': [1, 2, 3]}) - df16 = df.astype({'col': 'float16'}) + df = pandas.DataFrame({"col": [1, 2, 3]}) + df16 = df.astype({"col": "float16"}) con = duckdb.connect() - con.execute('CREATE TABLE tbl AS SELECT * FROM df16') - con.execute('select * from tbl') + con.execute("CREATE TABLE tbl AS SELECT * FROM df16") + con.execute("select * from tbl") df_result = con.fetchdf() - df32 = df.astype({'col': 'float32'}) - assert (df32['col'] == df_result['col']).all() + df32 = df.astype({"col": "float32"}) + assert (df32["col"] == df_result["col"]).all() diff --git a/tests/fast/pandas/test_copy_on_write.py b/tests/fast/pandas/test_copy_on_write.py index dc484f1b..ec1b8786 100644 --- a/tests/fast/pandas/test_copy_on_write.py +++ b/tests/fast/pandas/test_copy_on_write.py @@ -2,7 +2,7 @@ import pytest # https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html -pandas = pytest.importorskip('pandas', '1.5', reason='copy_on_write does not exist in earlier versions') +pandas = pytest.importorskip("pandas", "1.5", reason="copy_on_write does not exist in earlier versions") import datetime @@ -23,9 +23,9 @@ def convert_to_result(col): class TestCopyOnWrite(object): @pytest.mark.parametrize( - 'col', + "col", [ - ['a', 'b', 'this is a long string'], + ["a", "b", "this is a long string"], [1.2334, None, 234.12], [123234, -213123, 2324234], [datetime.date(1990, 12, 7), None, datetime.date(1940, 1, 13)], @@ -37,10 +37,10 @@ def test_copy_on_write(self, col): con = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': col, + "numbers": col, } ) - rel = con.sql('select * from df_in') + rel = con.sql("select * from df_in") res = rel.fetchall() print(res) expected = convert_to_result(col) diff --git a/tests/fast/pandas/test_create_table_from_pandas.py b/tests/fast/pandas/test_create_table_from_pandas.py index 69234dc7..2194d964 100644 --- a/tests/fast/pandas/test_create_table_from_pandas.py +++ b/tests/fast/pandas/test_create_table_from_pandas.py @@ -26,12 +26,12 @@ def assert_create_register(internal_data, expected_result, data_type, pandas): class TestCreateTableFromPandas(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_integer_create_table(self, duckdb_cursor, pandas): if sys.version_info.major < 3: return # FIXME: This should work with other data types e.g., int8... 
- data_types = ['Int8', 'Int16', 'Int32', 'Int64'] + data_types = ["Int8", "Int16", "Int32", "Int64"] internal_data = [1, 2, 3, 4] expected_result = [(1,), (2,), (3,), (4,)] for data_type in data_types: diff --git a/tests/fast/pandas/test_date_as_datetime.py b/tests/fast/pandas/test_date_as_datetime.py index 038f24a8..b738b2e1 100644 --- a/tests/fast/pandas/test_date_as_datetime.py +++ b/tests/fast/pandas/test_date_as_datetime.py @@ -5,9 +5,9 @@ def run_checks(df): - assert type(df['d'][0]) is datetime.date - assert df['d'][0] == datetime.date(1992, 7, 30) - assert pd.isnull(df['d'][1]) + assert type(df["d"][0]) is datetime.date + assert df["d"][0] == datetime.date(1992, 7, 30) + assert pd.isnull(df["d"][1]) def test_date_as_datetime(): @@ -22,7 +22,7 @@ def test_date_as_datetime(): run_checks(con.execute("Select * from t").fetch_df(date_as_object=True)) # Relation Methods - rel = con.table('t') + rel = con.table("t") run_checks(rel.df(date_as_object=True)) run_checks(rel.to_df(date_as_object=True)) diff --git a/tests/fast/pandas/test_datetime_time.py b/tests/fast/pandas/test_datetime_time.py index cda96e6b..1a5a3f7a 100644 --- a/tests/fast/pandas/test_datetime_time.py +++ b/tests/fast/pandas/test_datetime_time.py @@ -8,24 +8,24 @@ class TestDateTimeTime(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_high(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(23, 1, 34.234345) AS '0'").df() data = [time(hour=23, minute=1, second=34, microsecond=234345)] - df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_low(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(00, 01, 1.000) AS '0'").df() data = [time(hour=0, minute=1, second=1)] - df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('input', ['2263-02-28', '9999-01-01']) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("input", ["2263-02-28", "9999-01-01"]) def test_pandas_datetime_big(self, pandas, input): duckdb_con = duckdb.connect() @@ -33,8 +33,8 @@ def test_pandas_datetime_big(self, pandas, input): duckdb_con.execute(f"INSERT INTO TEST VALUES ('{input}')") res = duckdb_con.execute("select * from test").df() - date_value = np.array([f'{input}'], dtype='datetime64[us]') - df = pandas.DataFrame({'date': date_value}) + date_value = np.array([f"{input}"], dtype="datetime64[us]") + df = pandas.DataFrame({"date": date_value}) pandas.testing.assert_frame_equal(res, df) def test_timezone_datetime(self): @@ -45,6 +45,6 @@ def test_timezone_datetime(self): original = dt stringified = str(dt) - original_res = con.execute('select ?::TIMESTAMPTZ', [original]).fetchone() - stringified_res = con.execute('select ?::TIMESTAMPTZ', [stringified]).fetchone() + original_res = 
+        stringified_res = con.execute("select ?::TIMESTAMPTZ", [stringified]).fetchone()
         assert original_res == stringified_res
diff --git a/tests/fast/pandas/test_datetime_timestamp.py b/tests/fast/pandas/test_datetime_timestamp.py
index e3b26501..ffc1b7d8 100644
--- a/tests/fast/pandas/test_datetime_timestamp.py
+++ b/tests/fast/pandas/test_datetime_timestamp.py
@@ -9,21 +9,21 @@ class TestDateTimeTimeStamp(object):
-    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
+    @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()])
     def test_timestamp_high(self, pandas, duckdb_cursor):
         duckdb_time = duckdb_cursor.sql("SELECT '2260-01-01 23:59:00'::TIMESTAMP AS '0'").df()
         df_in = pandas.DataFrame(
             {
                 0: pandas.Series(
                     data=[datetime.datetime(year=2260, month=1, day=1, hour=23, minute=59)],
-                    dtype='datetime64[us]',
+                    dtype="datetime64[us]",
                 )
             }
         )
         df_out = duckdb_cursor.sql("select * from df_in").df()
         pandas.testing.assert_frame_equal(df_out, duckdb_time)
-    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
+    @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()])
     def test_timestamp_low(self, pandas, duckdb_cursor):
         duckdb_time = duckdb_cursor.sql(
             """
@@ -32,27 +32,27 @@ def test_timestamp_low(self, pandas, duckdb_cursor):
         ).df()
         df_in = pandas.DataFrame(
             {
-                '0': pandas.Series(
+                "0": pandas.Series(
                     data=[
                         pandas.Timestamp(
                             datetime.datetime(year=1680, month=1, day=1, hour=23, minute=59, microsecond=234243),
-                            unit='us',
+                            unit="us",
                         )
                     ],
-                    dtype='datetime64[us]',
+                    dtype="datetime64[us]",
                 )
             }
         )
-        print('original:', duckdb_time['0'].dtype)
-        print('df_in:', df_in['0'].dtype)
+        print("original:", duckdb_time["0"].dtype)
+        print("df_in:", df_in["0"].dtype)
         df_out = duckdb_cursor.sql("select * from df_in").df()
-        print('df_out:', df_out['0'].dtype)
+        print("df_out:", df_out["0"].dtype)
         pandas.testing.assert_frame_equal(df_out, duckdb_time)
     @pytest.mark.skipif(
-        Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones"
+        Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones"
     )
-    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
+    @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()])
     def test_timestamp_timezone_regular(self, pandas, duckdb_cursor):
         duckdb_time = duckdb_cursor.sql(
             """
@@ -65,7 +65,7 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor):
         df_in = pandas.DataFrame(
             {
                 0: pandas.Series(
-                    data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype='object'
+                    data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype="object"
                 )
             }
         )
@@ -75,9 +75,9 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor):
         pandas.testing.assert_frame_equal(df_out, duckdb_time)
     @pytest.mark.skipif(
-        Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones"
+        Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones"
     )
-    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
+    @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()])
     def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor):
         duckdb_time = duckdb_cursor.sql(
             """
@@ -91,7 +91,7 @@ def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor):
         df_in = pandas.DataFrame(
             {
                 0: pandas.Series(
-
data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype="object" ) } ) @@ -99,9 +99,9 @@ def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -115,7 +115,7 @@ def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): df_in = pandas.DataFrame( { 0: pandas.Series( - data=[datetime.datetime(year=2021, month=12, day=31, hour=23, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2021, month=12, day=31, hour=23, tzinfo=timezone)], dtype="object" ) } ) @@ -123,16 +123,16 @@ def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('unit', ['ms', 'ns', 's']) + @pytest.mark.parametrize("unit", ["ms", "ns", "s"]) def test_timestamp_timezone_coverage(self, unit, duckdb_cursor): pd = pytest.importorskip("pandas") ts_df = pd.DataFrame( - {'ts': pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype=f'datetime64[{unit}]')} + {"ts": pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype=f"datetime64[{unit}]")} ) usecond_df = pd.DataFrame( - {'ts': pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype='datetime64[us]')} + {"ts": pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype="datetime64[us]")} ) query = """ @@ -142,12 +142,12 @@ def test_timestamp_timezone_coverage(self, unit, duckdb_cursor): """ duckdb_cursor.sql("set TimeZone = 'UTC'") - utc_usecond = duckdb_cursor.sql(query.format('usecond_df')).df() - utc_other = duckdb_cursor.sql(query.format('ts_df')).df() + utc_usecond = duckdb_cursor.sql(query.format("usecond_df")).df() + utc_other = duckdb_cursor.sql(query.format("ts_df")).df() duckdb_cursor.sql("set TimeZone = 'America/Los_Angeles'") - us_usecond = duckdb_cursor.sql(query.format('usecond_df')).df() - us_other = duckdb_cursor.sql(query.format('ts_df')).df() + us_usecond = duckdb_cursor.sql(query.format("usecond_df")).df() + us_other = duckdb_cursor.sql(query.format("ts_df")).df() pd.testing.assert_frame_equal(utc_usecond, utc_other) pd.testing.assert_frame_equal(us_usecond, us_other) diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 114f8e3f..8e67da4a 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -6,11 +6,11 @@ def create_generic_dataframe(data, pandas): - return pandas.DataFrame({'col0': pandas.Series(data=data, dtype='object')}) + return pandas.DataFrame({"col0": pandas.Series(data=data, dtype="object")}) class TestResolveObjectColumns(object): - @pytest.mark.parametrize('pandas', 
[NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_correct(self, duckdb_cursor, pandas): print(pandas.backend) duckdb_conn = duckdb.connect() @@ -21,7 +21,7 @@ def test_sample_low_correct(self, duckdb_cursor, pandas): duckdb_df = duckdb_conn.query("select * FROM (VALUES (1000008), (6), (9), (4), (1), (6)) as '0'").df() pandas.testing.assert_frame_equal(duckdb_df, roundtripped_df, check_dtype=False) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=2") @@ -31,9 +31,9 @@ def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): df = create_generic_dataframe(data, pandas) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() # Sample high enough to detect mismatch in types, fallback to VARCHAR - assert roundtripped_df['col0'].dtype == np.dtype('object') + assert roundtripped_df["col0"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_zero(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() # Disable dataframe analyze @@ -42,12 +42,12 @@ def test_sample_zero(self, duckdb_cursor, pandas): df = create_generic_dataframe(data, pandas) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() # Always converts to VARCHAR - if pandas.backend == 'pyarrow': - assert roundtripped_df['col0'].dtype == np.dtype('int64') + if pandas.backend == "pyarrow": + assert roundtripped_df["col0"].dtype == np.dtype("int64") else: - assert roundtripped_df['col0'].dtype == np.dtype('object') + assert roundtripped_df["col0"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_incorrect_undetected(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=1") @@ -65,10 +65,10 @@ def test_reset_analyze_sample_setting(self, duckdb_cursor): res = duckdb_cursor.execute("select current_setting('pandas_analyze_sample')").fetchall() assert res == [(1000,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_10750(self, duckdb_cursor, pandas): max_row_number = 2000 - data = {'id': [i for i in range(max_row_number + 1)], 'content': [None for _ in range(max_row_number + 1)]} + data = {"id": [i for i in range(max_row_number + 1)], "content": [None for _ in range(max_row_number + 1)]} pdf = pandas.DataFrame(data=data) duckdb_cursor.register("content", pdf) diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index d54db072..73470818 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ b/tests/fast/pandas/test_df_object_resolution.py @@ -13,7 +13,7 @@ def create_generic_dataframe(data, pandas): - return pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + return pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) def create_repeated_nulls(size): @@ -25,7 +25,7 @@ def create_repeated_nulls(size): def create_trailing_non_null(size): data = [None for _ 
in range(size - 1)] - data.append('this is a long string') + data.append("this is a long string") return data @@ -43,7 +43,7 @@ def ConvertStringToDecimal(data: list, pandas): for i in range(len(data)): if isinstance(data[i], str): data[i] = decimal.Decimal(data[i]) - data = pandas.Series(data=data, dtype='object') + data = pandas.Series(data=data, dtype="object") return data @@ -61,13 +61,13 @@ def construct_list(pair): def construct_struct(pair): - return [{'v1': pair.first}, {'v1': pair.second}] + return [{"v1": pair.first}, {"v1": pair.second}] def construct_map(pair): return [ - {'key': ['v1', 'v2'], "value": [pair.first, pair.first]}, - {'key': ['v1', 'v2'], "value": [pair.second, pair.second]}, + {"key": ["v1", "v2"], "value": [pair.first, pair.first]}, + {"key": ["v1", "v2"], "value": [pair.second, pair.second]}, ] @@ -83,157 +83,157 @@ def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, class TestResolveObjectColumns(object): # TODO: add support for ArrowPandas - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integers(self, pandas, duckdb_cursor): data = [5, 0, 3] df_in = create_generic_dataframe(data, pandas) # These are float64 because pandas would force these to be float64 even if we set them to int8, int16, int32, int64 respectively - df_expected_res = pandas.DataFrame({'0': pandas.Series(data=data, dtype='int32')}) + df_expected_res = pandas.DataFrame({"0": pandas.Series(data=data, dtype="int32")}) df_out = duckdb_cursor.sql("SELECT * FROM df_in").df() print(df_out) pandas.testing.assert_frame_equal(df_expected_res, df_out) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_correct(self, pandas, duckdb_cursor): - data = [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] - df = pandas.DataFrame({'0': pandas.Series(data=data)}) + data = [{"a": 1, "b": 3, "c": 3, "d": 7}] + df = pandas.DataFrame({"0": pandas.Series(data=data)}) duckdb_col = duckdb_cursor.sql("SELECT {a: 1, b: 3, c: 3, d: 7} as '0'").df() converted_col = duckdb_cursor.sql("SELECT * FROM df").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_different_keys(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'e': 7}], #'e' instead of 'd' as key - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "e": 7}], #'e' instead of 'd' as key + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() y = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'e'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "e"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], 
"value": [1, 3, 3, 7]}], ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_incorrect_amount_of_keys(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3}], # incorrect amount of keys - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3}], # incorrect amount of keys + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() y = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c'], 'value': [1, 3, 3]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c"], "value": [1, 3, 3]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'string'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": "string"}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'string'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], + [{"a": 1, "b": 3, "c": 3, "d": "string"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_null(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ [None], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ [None], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 
3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'test'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": "test"}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ - [{'a': '1', 'b': '3', 'c': '3', 'd': 'test'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], - [{'a': '1', 'b': '3', 'c': '3'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], + [{"a": "1", "b": "3", "c": "3", "d": "test"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], + [{"a": "1", "b": "3", "c": "3"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_correct(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x as 'a'").df() duckdb_cursor.sql( """ @@ -253,10 +253,10 @@ def test_map_correct(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('sample_size', [1, 10]) - @pytest.mark.parametrize('fill', [1000, 10000]) - @pytest.mark.parametrize('get_data', [create_repeated_nulls, create_trailing_non_null]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("sample_size", [1, 10]) + @pytest.mark.parametrize("fill", [1000, 10000]) + @pytest.mark.parametrize("get_data", [create_repeated_nulls, create_trailing_non_null]) def test_analyzing_nulls(self, pandas, duckdb_cursor, fill, sample_size, get_data): data = get_data(fill) df1 = pandas.DataFrame(data={"col1": data}) @@ -265,9 +265,9 @@ def test_analyzing_nulls(self, pandas, duckdb_cursor, fill, sample_size, get_dat pandas.testing.assert_frame_equal(df1, df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + 
@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_nested_map(self, pandas, duckdb_cursor): - df = pandas.DataFrame(data={'col1': [{'a': {'b': {'x': 'A', 'y': 'B'}}}, {'c': {'b': {'x': 'A'}}}]}) + df = pandas.DataFrame(data={"col1": [{"a": {"b": {"x": "A", "y": "B"}}}, {"c": {"b": {"x": "A"}}}]}) rel = duckdb_cursor.sql("select * from df") expected_rel = duckdb_cursor.sql( @@ -283,18 +283,18 @@ def test_nested_map(self, pandas, duckdb_cursor): expected_res = str(expected_rel) assert res == expected_res - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 'test']}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, "test"]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql( """ @@ -319,69 +319,69 @@ def test_map_value_upgrade(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_duplicate(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': ['a', 'a', 'b'], 'value': [4, 0, 4]}]]) + x = pandas.DataFrame([[{"key": ["a", "a", "b"], "value": [4, 0, 4]}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys must be unique."): duckdb_cursor.sql("select * from x").show() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': [None, 'a', 'b'], 'value': [4, 0, 4]}]]) + x = pandas.DataFrame([[{"key": [None, "a", "b"], "value": [4, 0, 4]}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_nullkeylist(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': None, 'value': None}]]) + x = pandas.DataFrame([[{"key": None, "value": None}]]) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_col = duckdb_cursor.sql("SELECT MAP(NULL, NULL) as '0'").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'a': 4, None: 0, 'c': 4}], [{'a': 4, None: 0, 'd': 4}]]) + x = pandas.DataFrame([[{"a": 4, None: 0, "c": 4}], [{"a": 4, None: 0, "d": 4}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys can 
not be NULL."): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_nullkey_coverage(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': None, 'value': None}], - [{'key': None, None: 5}], + [{"key": None, "value": None}], + [{"key": None, None: 5}], ] ) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_structs_in_nested_types(self, pandas, duckdb_cursor): # This test is testing a bug that occurred when type upgrades occurred inside nested types # STRUCT(key1 varchar) + STRUCT(key1 varchar, key2 varchar) turns into MAP # But when inside a nested structure, this upgrade did not happen properly pairs = { - 'v1': ObjectPair({'key1': 21}, {'key1': 21, 'key2': 42}), - 'v2': ObjectPair({'key1': 21}, {'key2': 21}), - 'v3': ObjectPair({'key1': 21, 'key2': 42}, {'key1': 21}), - 'v4': ObjectPair({}, {'key1': 21}), + "v1": ObjectPair({"key1": 21}, {"key1": 21, "key2": 42}), + "v2": ObjectPair({"key1": 21}, {"key2": 21}), + "v3": ObjectPair({"key1": 21, "key2": 42}, {"key1": 21}), + "v4": ObjectPair({}, {"key1": 21}), } for _, pair in pairs.items(): - check_struct_upgrade('MAP(VARCHAR, INTEGER)[]', construct_list, pair, pandas, duckdb_cursor) + check_struct_upgrade("MAP(VARCHAR, INTEGER)[]", construct_list, pair, pandas, duckdb_cursor) for key, pair in pairs.items(): - if key == 'v4': - expected_type = 'MAP(VARCHAR, MAP(VARCHAR, INTEGER))' + if key == "v4": + expected_type = "MAP(VARCHAR, MAP(VARCHAR, INTEGER))" else: - expected_type = 'STRUCT(v1 MAP(VARCHAR, INTEGER))' + expected_type = "STRUCT(v1 MAP(VARCHAR, INTEGER))" check_struct_upgrade(expected_type, construct_struct, pair, pandas, duckdb_cursor) for key, pair in pairs.items(): - check_struct_upgrade('MAP(VARCHAR, MAP(VARCHAR, INTEGER))', construct_map, pair, pandas, duckdb_cursor) + check_struct_upgrade("MAP(VARCHAR, MAP(VARCHAR, INTEGER))", construct_map, pair, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_structs_of_different_sizes(self, pandas, duckdb_cursor): # This list has both a STRUCT(v1) and a STRUCT(v1, v2) member # Those can't be combined @@ -404,9 +404,9 @@ def test_structs_of_different_sizes(self, pandas, duckdb_cursor): ) res = duckdb_cursor.query("select typeof(col) from df").fetchall() # So we fall back to converting them as VARCHAR instead - assert res == [('MAP(VARCHAR, VARCHAR)[]',), ('MAP(VARCHAR, VARCHAR)[]',)] + assert res == [("MAP(VARCHAR, VARCHAR)[]",), ("MAP(VARCHAR, VARCHAR)[]",)] - malformed_struct = duckdb.Value({"v1": 1, "v2": 2}, duckdb.struct_type({'v1': int})) + malformed_struct = duckdb.Value({"v1": 1, "v2": 2}, duckdb.struct_type({"v1": int})) with pytest.raises( duckdb.InvalidInputException, match=re.escape( @@ -416,7 +416,7 @@ def test_structs_of_different_sizes(self, pandas, duckdb_cursor): res = duckdb_cursor.execute("select $1", [malformed_struct]) print(res) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_key_conversion(self, pandas, duckdb_cursor): x = 
pandas.DataFrame( [ @@ -428,48 +428,48 @@ def test_struct_key_conversion(self, pandas, duckdb_cursor): duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_correct(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [[5], [34], [-245]]}]) + x = pandas.DataFrame([{"0": [[5], [34], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], [34], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_contains_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [[5], None, [-245]]}]) + x = pandas.DataFrame([{"0": [[5], None, [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], NULL, [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_starts_with_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [None, [5], [-245]]}]) + x = pandas.DataFrame([{"0": [None, [5], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [NULL, [5], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [['5'], [34], [-245]]}]) + x = pandas.DataFrame([{"0": [["5"], [34], [-245]]}]) duckdb_rel = duckdb_cursor.sql("select [['5'], ['34'], ['-245']] as '0'") duckdb_col = duckdb_rel.df() converted_col = duckdb_cursor.sql("select * from x").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_column_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ [[1, 25, 300]], [[500, 345, 30]], - [[50, 'a', 67]], + [[50, "a", 67]], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql( """ @@ -498,29 +498,29 @@ def test_list_column_value_upgrade(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_ubigint_object_conversion(self, pandas, duckdb_cursor): # UBIGINT + TINYINT would result in HUGEINT, but conversion to HUGEINT is not supported yet from pandas->duckdb # So this instead becomes a DOUBLE data = [18446744073709551615, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = 
duckdb_cursor.sql("select * from x").df() - if pandas.backend == 'numpy_nullable': - float64 = np.dtype('float64') - assert isinstance(converted_col['0'].dtype, float64.__class__) == True + if pandas.backend == "numpy_nullable": + float64 = np.dtype("float64") + assert isinstance(converted_col["0"].dtype, float64.__class__) == True else: - uint64 = np.dtype('uint64') - assert isinstance(converted_col['0'].dtype, uint64.__class__) == True + uint64 = np.dtype("uint64") + assert isinstance(converted_col["0"].dtype, uint64.__class__) == True - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_double_object_conversion(self, pandas, duckdb_cursor): data = [18446744073709551616, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - double_dtype = np.dtype('float64') - assert isinstance(converted_col['0'].dtype, double_dtype.__class__) == True + double_dtype = np.dtype("float64") + assert isinstance(converted_col["0"].dtype, double_dtype.__class__) == True - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="older numpy raises a warning when running with Pyodide", @@ -551,51 +551,51 @@ def test_numpy_object_with_stride(self, pandas, duckdb_cursor): (9, 18, 0), ] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numpy_stringliterals(self, pandas, duckdb_cursor): df = pandas.DataFrame({"x": list(map(np.str_, range(3)))}) res = duckdb_cursor.execute("select * from df").fetchall() - assert res == [('0',), ('1',), ('2',)] + assert res == [("0",), ("1",), ("2",)] - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integer_conversion_fail(self, pandas, duckdb_cursor): data = [2**10000, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - print(converted_col['0']) - double_dtype = np.dtype('object') - assert isinstance(converted_col['0'].dtype, double_dtype.__class__) == True + print(converted_col["0"]) + double_dtype = np.dtype("object") + assert isinstance(converted_col["0"].dtype, double_dtype.__class__) == True # Most of the time numpy.datetime64 is just a wrapper around a datetime.datetime object # But to support arbitrary precision, it can fall back to using an `int` internally - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) # Which we don't support yet + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) # Which we don't support yet def test_numpy_datetime(self, pandas, duckdb_cursor): numpy = pytest.importorskip("numpy") data = [] - data += [numpy.datetime64('2022-12-10T21:38:24.578696')] * standard_vector_size - data += [numpy.datetime64('2022-02-21T06:59:23.324812')] * standard_vector_size - data += [numpy.datetime64('1974-06-05T13:12:01.000000')] * standard_vector_size - data += [numpy.datetime64('2049-01-13T00:24:31.999999')] * standard_vector_size - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + data += 
[numpy.datetime64("2022-12-10T21:38:24.578696")] * standard_vector_size + data += [numpy.datetime64("2022-02-21T06:59:23.324812")] * standard_vector_size + data += [numpy.datetime64("1974-06-05T13:12:01.000000")] * standard_vector_size + data += [numpy.datetime64("2049-01-13T00:24:31.999999")] * standard_vector_size + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() - assert len(res['dates'].__array__()) == 4 + assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_numpy_datetime_int_internally(self, pandas, duckdb_cursor): numpy = pytest.importorskip("numpy") - data = [numpy.datetime64('2022-12-10T21:38:24.0000000000001')] - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + data = [numpy.datetime64("2022-12-10T21:38:24.0000000000001")] + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) with pytest.raises( duckdb.ConversionException, match=re.escape("Conversion Error: Unimplemented type for cast (BIGINT -> TIMESTAMP)"), ): rel = duckdb.query_df(x, "x", "create table dates as select dates::TIMESTAMP WITHOUT TIME ZONE from x") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ @@ -605,10 +605,10 @@ def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): ] ) duckdb_col = duckdb_cursor.sql("select * from x").df() - df_expected_res = pandas.DataFrame({'0': pandas.Series(['4', '2', '0'])}) + df_expected_res = pandas.DataFrame({"0": pandas.Series(["4", "2", "0"])}) pandas.testing.assert_frame_equal(duckdb_col, df_expected_res) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal(self, pandas, duckdb_cursor): # DuckDB uses DECIMAL where possible, so all the 'float' types here are actually DECIMAL reference_query = """ @@ -626,12 +626,12 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): # Because of this we need to wrap these native floats as DECIMAL for this test, to avoid these decimals being "upgraded" to DOUBLE x = pandas.DataFrame( { - '0': ConvertStringToDecimal([5, '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal( - [5002340, 13, '-12.0000000005', '7453324234.0', None, '-324234234'], pandas + "0": ConvertStringToDecimal([5, "12.0", "-123.0", "-234234.0", None, "1.234"], pandas), + "1": ConvertStringToDecimal( + [5002340, 13, "-12.0000000005", "7453324234.0", None, "-324234234"], pandas ), - '2': ConvertStringToDecimal( - ['-234234234234.0', '324234234.00000005', -128, 345345, '1E5', '1324234359'], pandas + "2": ConvertStringToDecimal( + ["-234234234234.0", "324234234.00000005", -128, 345345, "1E5", "1324234359"], pandas ), } ) @@ -640,10 +640,10 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): x = pandas.DataFrame( - {'0': [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} + {"0": [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), 
Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} ) conversion = duckdb_cursor.sql("select * from x").fetchall() print(conversion[0][0].__class__) @@ -655,12 +655,12 @@ def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): assert math.isinf(conversion[3][0]) assert math.isinf(conversion[4][0]) assert math.isinf(conversion[5][0]) - assert str(conversion) == '[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]' + assert str(conversion) == "[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]" # Test that the column 'offset' is actually used when converting, @pytest.mark.parametrize( - 'pandas', [NumpyPandas(), ArrowPandas()] + "pandas", [NumpyPandas(), ArrowPandas()] ) # and that the same 2048 (STANDARD_VECTOR_SIZE) values are not being scanned over and over again def test_multiple_chunks(self, pandas, duckdb_cursor): data = [] @@ -668,11 +668,11 @@ def test_multiple_chunks(self, pandas, duckdb_cursor): data += [datetime.date(2022, 9, 14) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 15) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 16) for x in range(standard_vector_size)] - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() - assert len(res['dates'].__array__()) == 4 + assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): duckdb_cursor.execute(f"SET GLOBAL pandas_analyze_sample=4096") duckdb_cursor.execute( @@ -683,8 +683,8 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): date_df = res.copy() # Convert the dataframe to datetime - date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert str(date_df['i'].dtype) == 'object' + date_df["i"] = pandas.to_datetime(res["i"]).dt.date + assert str(date_df["i"].dtype) == "object" expected_res = [ ( @@ -707,10 +707,10 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): assert expected_res == actual_res # Now interleave nulls into the dataframe - duckdb_cursor.execute('drop table dates') - for i in range(0, len(res['i']), 2): - res.loc[i, 'i'] = None - duckdb_cursor.execute('create table dates as select * from res') + duckdb_cursor.execute("drop table dates") + for i in range(0, len(res["i"]), 2): + res.loc[i, "i"] = None + duckdb_cursor.execute("create table dates as select * from res") expected_res = [ ( @@ -721,8 +721,8 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): ] # Convert the dataframe to datetime date_df = res.copy() - date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert str(date_df['i'].dtype) == 'object' + date_df["i"] = pandas.to_datetime(res["i"]).dt.date + assert str(date_df["i"].dtype) == "object" actual_res = duckdb_cursor.sql( """ @@ -736,47 +736,47 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): assert expected_res == actual_res - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_mixed_object_types(self, pandas, duckdb_cursor): x = pandas.DataFrame( { - 'nested': pandas.Series( - data=[{'a': 1, 'b': 2}, [5, 4, 3], {'key': [1, 2, 3], 'value': ['a', 'b', 'c']}], dtype='object' + "nested": pandas.Series( + data=[{"a": 1, "b": 2}, [5, 4, 3], {"key": [1, 2, 3], 
"value": ["a", "b", "c"]}], dtype="object" ), } ) res = duckdb_cursor.sql("select * from x").df() - assert res['nested'].dtype == np.dtype('object') + assert res["nested"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_struct(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ { # STRUCT(b STRUCT(x VARCHAR, y VARCHAR)) - 'a': {'b': {'x': 'A', 'y': 'B'}} + "a": {"b": {"x": "A", "y": "B"}} }, { # STRUCT(b STRUCT(x VARCHAR)) - 'a': {'b': {'x': 'A'}} + "a": {"b": {"x": "A"}} }, ] ) # The dataframe has incompatible struct schemas in the nested child # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) res = duckdb_cursor.sql("select * from x").fetchall() - assert res == [({'b': {'x': 'A', 'y': 'B'}},), ({'b': {'x': 'A'}},)] + assert res == [({"b": {"x": "A", "y": "B"}},), ({"b": {"x": "A"}},)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): x = pandas.DataFrame( { - 'a': [ + "a": [ [ # STRUCT(x VARCHAR, y VARCHAR)[] - {'x': 'A', 'y': 'B'}, + {"x": "A", "y": "B"}, # STRUCT(x VARCHAR)[] - {'x': 'A'}, + {"x": "A"}, ] ] } @@ -784,16 +784,16 @@ def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): # The dataframe has incompatible struct schemas in the nested child # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) res = duckdb_cursor.sql("select * from x").fetchall() - assert res == [([{'x': 'A', 'y': 'B'}, {'x': 'A'}],)] + assert res == [([{"x": "A", "y": "B"}, {"x": "A"}],)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_analyze_sample_too_small(self, pandas, duckdb_cursor): data = [1 for _ in range(9)] + [[1, 2, 3]] + [1 for _ in range(9991)] - x = pandas.DataFrame({'a': pandas.Series(data=data)}) + x = pandas.DataFrame({"a": pandas.Series(data=data)}) with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"): res = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_zero_fractional(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( data={ @@ -826,7 +826,7 @@ def test_numeric_decimal_zero_fractional(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( @@ -842,10 +842,10 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): duckdb_cursor.execute(reference_query) x = pandas.DataFrame( { - '0': ConvertStringToDecimal(['5', '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal([5002340, 13, '-12.0000000005', 7453324234, None, '-324234234'], pandas), - '2': ConvertStringToDecimal( - [-234234234234, '324234234.00000005', -128, 345345, 0, '1324234359'], pandas + "0": ConvertStringToDecimal(["5", "12.0", "-123.0", "-234234.0", None, "1.234"], pandas), + "1": ConvertStringToDecimal([5002340, 13, "-12.0000000005", 7453324234, None, "-324234234"], pandas), + 
"2": ConvertStringToDecimal( + [-234234234234, "324234234.00000005", -128, 345345, 0, "1324234359"], pandas ), } ) @@ -857,7 +857,7 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): print(conversion) @pytest.mark.parametrize( - 'pandas', [NumpyPandas(), ArrowPandas()] + "pandas", [NumpyPandas(), ArrowPandas()] ) # result: [('1E-28',), ('10000000000000000000000000.0',)] def test_numeric_decimal_combined(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( @@ -878,7 +878,7 @@ def test_numeric_decimal_combined(self, pandas, duckdb_cursor): print(conversion) # result: [('1234.0',), ('123456789.0',), ('1234567890123456789.0',), ('0.1234567890123456789',)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( data={ @@ -906,7 +906,7 @@ def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): print(reference) print(conversion) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): # The widths of these decimal values are bigger than the max supported width for DECIMAL data = [ @@ -927,7 +927,7 @@ def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): data = [ Decimal("1.234"), @@ -959,7 +959,7 @@ def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_out_of_range(self, pandas, duckdb_cursor): data = [Decimal("1.234567890123456789012345678901234567"), Decimal("123456789012345678901234567890123456.0")] decimals = pandas.DataFrame(data={"0": data}) diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index b8de512a..fb7d2ad0 100644 --- a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -12,8 +12,8 @@ def check_equal(conn, df, reference_query, data): duckdb_conn = duckdb.connect() duckdb_conn.execute(reference_query, parameters=[data]) - res = duckdb_conn.query('SELECT * FROM tbl').fetchall() - df_res = duckdb_conn.query('SELECT * FROM tbl').df() + res = duckdb_conn.query("SELECT * FROM tbl").fetchall() + df_res = duckdb_conn.query("SELECT * FROM tbl").df() out = conn.sql("SELECT * FROM df").fetchall() assert res == out @@ -24,39 +24,39 @@ def create_reference_query(): class TestDFRecursiveNested(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_structs(self, duckdb_cursor, pandas): - data = [[{'a': 5}, NULL, {'a': NULL}], NULL, [{'a': 5}, NULL, {'a': NULL}]] + data = [[{"a": 5}, NULL, {"a": NULL}], NULL, [{"a": 5}, NULL, {"a": NULL}]] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, 
Value(data, 'STRUCT(a INTEGER)[]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "STRUCT(a INTEGER)[]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_map(self, duckdb_cursor, pandas): # LIST(MAP(VARCHAR, VARCHAR)) - data = [[{5: NULL}, NULL, {}], NULL, [NULL, {3: NULL, 2: 'a', 4: NULL}, {'a': 1, 'b': 2, 'c': 3}]] + data = [[{5: NULL}, NULL, {}], NULL, [NULL, {3: NULL, 2: "a", 4: NULL}, {"a": 1, "b": 2, "c": 3}]] reference_query = create_reference_query() print(reference_query) - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'MAP(VARCHAR, VARCHAR)[][]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "MAP(VARCHAR, VARCHAR)[][]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_list(self, duckdb_cursor, pandas): # LIST(LIST(LIST(LIST(INTEGER)))) data = [[[[3, NULL, 5], NULL], NULL, [[5, -20, NULL]]], NULL, [[[NULL]], [[]], NULL]] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'INTEGER[][][][]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "INTEGER[][][][]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_struct(self, duckdb_cursor, pandas): # STRUCT(STRUCT(STRUCT(LIST))) data = { - 'A': {'a': {'1': [1, 2, 3]}, 'b': NULL, 'c': {'1': NULL}}, - 'B': {'a': {'1': [1, NULL, 3]}, 'b': NULL, 'c': {'1': NULL}}, + "A": {"a": {"1": [1, 2, 3]}, "b": NULL, "c": {"1": NULL}}, + "B": {"a": {"1": [1, NULL, 3]}, "b": NULL, "c": {"1": NULL}}, } reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) check_equal( duckdb_cursor, df, @@ -92,7 +92,7 @@ def test_recursive_struct(self, duckdb_cursor, pandas): ), ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_map(self, duckdb_cursor, pandas): # MAP( # MAP( @@ -102,42 +102,42 @@ def test_recursive_map(self, duckdb_cursor, pandas): # INTEGER # ) data = { - 'key': [ - {'key': [5, 6, 7], 'value': [{'key': [8], 'value': [NULL]}, NULL, {'key': [9], 'value': ['a']}]}, - {'key': [], 'value': []}, + "key": [ + {"key": [5, 6, 7], "value": [{"key": [8], "value": [NULL]}, NULL, {"key": [9], "value": ["a"]}]}, + {"key": [], "value": []}, ], - 'value': [1, 2], + "value": [1, 2], } reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) check_equal( - duckdb_cursor, df, reference_query, Value(data, 'MAP(MAP(INTEGER, MAP(INTEGER, VARCHAR)), INTEGER)') + duckdb_cursor, df, reference_query, Value(data, "MAP(MAP(INTEGER, MAP(INTEGER, VARCHAR)), INTEGER)") ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_stresstest(self, duckdb_cursor, pandas): data = [ { - 'a': { - 'key': [ + "a": { + "key": [ # key 1 - {'1': [5, 4, 3], '2': [8, 7, 6], '3': [1, 2, 3]}, + {"1": [5, 4, 3], "2": [8, 7, 6], "3": [1, 2, 3]}, # 
key 2 - {'1': [], '2': NULL, '3': [NULL, 0, NULL]}, + {"1": [], "2": NULL, "3": [NULL, 0, NULL]}, ], - 'value': [ + "value": [ # value 1 - [{'A': 'abc', 'B': 'def', 'C': NULL}], + [{"A": "abc", "B": "def", "C": NULL}], # value 2 [NULL], ], }, - 'b': NULL, - 'c': {'key': [], 'value': []}, + "b": NULL, + "c": {"key": [], "value": []}, } ] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) duckdb_type = """ STRUCT( a MAP( diff --git a/tests/fast/pandas/test_fetch_df_chunk.py b/tests/fast/pandas/test_fetch_df_chunk.py index 1973a729..1f2d4b1b 100644 --- a/tests/fast/pandas/test_fetch_df_chunk.py +++ b/tests/fast/pandas/test_fetch_df_chunk.py @@ -13,16 +13,16 @@ def test_fetch_df_chunk(self): # Fetch the first chunk cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == VECTOR_SIZE # Fetch the second chunk, can't be entirely filled cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == VECTOR_SIZE + assert cur_chunk["a"][0] == VECTOR_SIZE expected = size - VECTOR_SIZE assert len(cur_chunk) == expected - @pytest.mark.parametrize('size', [3000, 10000, 100000, VECTOR_SIZE - 1, VECTOR_SIZE + 1, VECTOR_SIZE]) + @pytest.mark.parametrize("size", [3000, 10000, 100000, VECTOR_SIZE - 1, VECTOR_SIZE + 1, VECTOR_SIZE]) def test_monahan(self, size): con = duckdb.connect() con.execute(f"CREATE table t as select range a from range({size});") @@ -52,12 +52,12 @@ def test_fetch_df_chunk_parameter(self): # Return 2 vectors cur_chunk = query.fetch_df_chunk(2) - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == VECTOR_SIZE * 2 # Return Default 1 vector cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == VECTOR_SIZE * 2 + assert cur_chunk["a"][0] == VECTOR_SIZE * 2 assert len(cur_chunk) == VECTOR_SIZE # Return 0 vectors @@ -69,7 +69,7 @@ def test_fetch_df_chunk_parameter(self): # Return more vectors than we have remaining cur_chunk = query.fetch_df_chunk(3) - assert cur_chunk['a'][0] == fetched + assert cur_chunk["a"][0] == fetched assert len(cur_chunk) == expected # These shouldn't throw errors (Just emmit empty chunks) @@ -88,5 +88,5 @@ def test_fetch_df_chunk_negative_parameter(self): query = con.execute("SELECT a FROM t") # Return -1 vector should not work - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): cur_chunk = query.fetch_df_chunk(-1) diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index 5727429f..e25a44ba 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -10,10 +10,10 @@ def compare_results(con, query, expected): expected = pd.DataFrame.from_dict(expected) unsorted_res = con.query(query).df() - print(unsorted_res, unsorted_res['a'][0].__class__) + print(unsorted_res, unsorted_res["a"][0].__class__) df_duck = con.query("select * from unsorted_res order by all").df() - print(df_duck, df_duck['a'][0].__class__) - print(expected, expected['a'][0].__class__) + print(df_duck, df_duck["a"][0].__class__) + print(expected, expected["a"][0].__class__) pd.testing.assert_frame_equal(df_duck, expected) @@ -147,7 +147,7 @@ def list_test_cases(): class TestFetchNested(object): - @pytest.mark.parametrize('query, expected', list_test_cases()) + @pytest.mark.parametrize("query, expected", list_test_cases()) def test_fetch_df_list(self, 
duckdb_cursor, query, expected): compare_results(duckdb_cursor, query, expected) diff --git a/tests/fast/pandas/test_implicit_pandas_scan.py b/tests/fast/pandas/test_implicit_pandas_scan.py index e6f0b9f4..2d4610ff 100644 --- a/tests/fast/pandas/test_implicit_pandas_scan.py +++ b/tests/fast/pandas/test_implicit_pandas_scan.py @@ -15,7 +15,7 @@ except: pyarrow_dtypes_enabled = False -if Version(pd.__version__) >= Version('2.0.0') and pyarrow_dtypes_enabled: +if Version(pd.__version__) >= Version("2.0.0") and pyarrow_dtypes_enabled: pyarrow_df = numpy_nullable_df.convert_dtypes(dtype_backend="pyarrow") else: # dtype_backend is not supported in pandas < 2.0.0 @@ -23,20 +23,20 @@ class TestImplicitPandasScan(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_local_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) - r1 = con.execute('select * from df').fetchdf() + r1 = con.execute("select * from df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_global_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() - r1 = con.execute(f'select * from {pandas.backend}_df').fetchdf() + r1 = con.execute(f"select * from {pandas.backend}_df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val4" assert r1["CoL2"][0] == 1.05 diff --git a/tests/fast/pandas/test_import_cache.py b/tests/fast/pandas/test_import_cache.py index 32eab7b0..6ed601c5 100644 --- a/tests/fast/pandas/test_import_cache.py +++ b/tests/fast/pandas/test_import_cache.py @@ -3,26 +3,26 @@ import pytest -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_import_cache_explicit_dtype(pandas): df = pandas.DataFrame( { - 'id': [1, 2, 3], - 'value': pandas.Series(['123.123', pandas.NaT, pandas.NA], dtype=pandas.StringDtype(storage='python')), + "id": [1, 2, 3], + "value": pandas.Series(["123.123", pandas.NaT, pandas.NA], dtype=pandas.StringDtype(storage="python")), } ) con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df['value'][1] is None - assert result_df['value'][2] is None + assert result_df["value"][1] is None + assert result_df["value"][2] is None -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_import_cache_implicit_dtype(pandas): - df = pandas.DataFrame({'id': [1, 2, 3], 'value': pandas.Series(['123.123', pandas.NaT, pandas.NA])}) + df = pandas.DataFrame({"id": [1, 2, 3], "value": pandas.Series(["123.123", pandas.NaT, pandas.NA])}) con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df['value'][1] is None - assert result_df['value'][2] is None + assert result_df["value"][1] is None + assert result_df["value"][2] is None diff --git a/tests/fast/pandas/test_issue_1767.py b/tests/fast/pandas/test_issue_1767.py index e37f19e1..27f0c2ff 100644 --- a/tests/fast/pandas/test_issue_1767.py +++ b/tests/fast/pandas/test_issue_1767.py @@ -9,7 +9,7 @@ # Join from pandas not matching identical strings #1767 class 
TestIssue1767(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_unicode_join_pandas(self, duckdb_cursor, pandas): A = pandas.DataFrame({"key": ["a", "п"]}) B = pandas.DataFrame({"key": ["a", "п"]}) @@ -18,6 +18,6 @@ def test_unicode_join_pandas(self, duckdb_cursor, pandas): q = arrow.query("""SELECT key FROM "A" FULL JOIN "B" USING ("key") ORDER BY key""") result = q.df() - d = {'key': ["a", "п"]} + d = {"key": ["a", "п"]} df = pandas.DataFrame(data=d) pandas.testing.assert_frame_equal(result, df) diff --git a/tests/fast/pandas/test_limit.py b/tests/fast/pandas/test_limit.py index 4a03c24f..460716cd 100644 --- a/tests/fast/pandas/test_limit.py +++ b/tests/fast/pandas/test_limit.py @@ -4,22 +4,22 @@ class TestLimitPandas(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_limit_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) limit_df = duckdb.limit(df_in, 2) assert len(limit_df.execute().fetchall()) == 2 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_aggregate_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( { - 'numbers': [1, 2, 2, 2], + "numbers": [1, 2, 2, 2], } ) - aggregate_df = duckdb.aggregate(df_in, 'count(numbers)', 'numbers').order('all') + aggregate_df = duckdb.aggregate(df_in, "count(numbers)", "numbers").order("all") assert aggregate_df.execute().fetchall() == [(1,), (3,)] diff --git a/tests/fast/pandas/test_pandas_arrow.py b/tests/fast/pandas/test_pandas_arrow.py index 8729362d..e1661041 100644 --- a/tests/fast/pandas/test_pandas_arrow.py +++ b/tests/fast/pandas/test_pandas_arrow.py @@ -4,7 +4,7 @@ from conftest import pandas_supports_arrow_backend -pd = pytest.importorskip("pandas", '2.0.0') +pd = pytest.importorskip("pandas", "2.0.0") import numpy as np from pandas.api.types import is_integer_dtype @@ -13,7 +13,7 @@ class TestPandasArrow(object): def test_pandas_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': pd.Series([5, 4, 3])}).convert_dtypes() + df = pd.DataFrame({"a": pd.Series([5, 4, 3])}).convert_dtypes() con = duckdb.connect() res = con.sql("select * from df").fetchall() assert res == [(5,), (4,), (3,)] @@ -21,8 +21,8 @@ def test_pandas_arrow(self, duckdb_cursor): def test_mixed_columns(self): df = pd.DataFrame( { - 'strings': pd.Series(['abc', 'DuckDB', 'quack', 'quack']), - 'timestamps': pd.Series( + "strings": pd.Series(["abc", "DuckDB", "quack", "quack"]), + "timestamps": pd.Series( [ datetime.datetime(1990, 10, 21), datetime.datetime(2023, 1, 11), @@ -30,23 +30,23 @@ def test_mixed_columns(self): datetime.datetime(1990, 10, 21), ] ), - 'objects': pd.Series([[5, 4, 3], 'test', None, {'a': 42}]), - 'integers': np.ndarray((4,), buffer=np.array([1, 2, 3, 4, 5]), offset=np.int_().itemsize, dtype=int), + "objects": pd.Series([[5, 4, 3], "test", None, {"a": 42}]), + "integers": np.ndarray((4,), buffer=np.array([1, 2, 3, 4, 5]), offset=np.int_().itemsize, dtype=int), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() with pytest.raises( - duckdb.InvalidInputException, match='The dataframe could not be converted to a pyarrow.lib.Table' + 
duckdb.InvalidInputException, match="The dataframe could not be converted to a pyarrow.lib.Table" ): - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() numpy_df = pd.DataFrame( - {'a': np.ndarray((2,), buffer=np.array([1, 2, 3]), offset=np.int_().itemsize, dtype=int)} - ).convert_dtypes(dtype_backend='numpy_nullable') + {"a": np.ndarray((2,), buffer=np.array([1, 2, 3]), offset=np.int_().itemsize, dtype=int)} + ).convert_dtypes(dtype_backend="numpy_nullable") arrow_df = pd.DataFrame( { - 'a': pd.Series( + "a": pd.Series( [ datetime.datetime(1990, 10, 21), datetime.datetime(2023, 1, 11), @@ -55,45 +55,45 @@ def test_mixed_columns(self): ] ) } - ).convert_dtypes(dtype_backend='pyarrow') - python_df = pd.DataFrame({'a': pd.Series(['test', [5, 4, 3], {'a': 42}])}).convert_dtypes() + ).convert_dtypes(dtype_backend="pyarrow") + python_df = pd.DataFrame({"a": pd.Series(["test", [5, 4, 3], {"a": 42}])}).convert_dtypes() - df = pd.concat([numpy_df['a'], arrow_df['a'], python_df['a']], axis=1, keys=['numpy', 'arrow', 'python']) - assert is_integer_dtype(df.dtypes['numpy']) - assert isinstance(df.dtypes['arrow'], pd.ArrowDtype) - assert isinstance(df.dtypes['python'], np.dtype('O').__class__) + df = pd.concat([numpy_df["a"], arrow_df["a"], python_df["a"]], axis=1, keys=["numpy", "arrow", "python"]) + assert is_integer_dtype(df.dtypes["numpy"]) + assert isinstance(df.dtypes["arrow"], pd.ArrowDtype) + assert isinstance(df.dtypes["python"], np.dtype("O").__class__) with pytest.raises( - duckdb.InvalidInputException, match='The dataframe could not be converted to a pyarrow.lib.Table' + duckdb.InvalidInputException, match="The dataframe could not be converted to a pyarrow.lib.Table" ): - res = con.sql('select * from df').fetchall() + res = con.sql("select * from df").fetchall() def test_empty_df(self): df = pd.DataFrame( { - 'string': pd.Series(data=[], dtype='string'), - 'object': pd.Series(data=[], dtype='object'), - 'Int64': pd.Series(data=[], dtype='Int64'), - 'Float64': pd.Series(data=[], dtype='Float64'), - 'bool': pd.Series(data=[], dtype='bool'), - 'datetime64[ns]': pd.Series(data=[], dtype='datetime64[ns]'), - 'datetime64[ms]': pd.Series(data=[], dtype='datetime64[ms]'), - 'datetime64[us]': pd.Series(data=[], dtype='datetime64[us]'), - 'datetime64[s]': pd.Series(data=[], dtype='datetime64[s]'), - 'category': pd.Series(data=[], dtype='category'), - 'timedelta64[ns]': pd.Series(data=[], dtype='timedelta64[ns]'), + "string": pd.Series(data=[], dtype="string"), + "object": pd.Series(data=[], dtype="object"), + "Int64": pd.Series(data=[], dtype="Int64"), + "Float64": pd.Series(data=[], dtype="Float64"), + "bool": pd.Series(data=[], dtype="bool"), + "datetime64[ns]": pd.Series(data=[], dtype="datetime64[ns]"), + "datetime64[ms]": pd.Series(data=[], dtype="datetime64[ms]"), + "datetime64[us]": pd.Series(data=[], dtype="datetime64[us]"), + "datetime64[s]": pd.Series(data=[], dtype="datetime64[s]"), + "category": pd.Series(data=[], dtype="category"), + "timedelta64[ns]": pd.Series(data=[], dtype="timedelta64[ns]"), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() assert res == [] def test_completely_null_df(self): df = pd.DataFrame( { - 'a': pd.Series( + "a": pd.Series( data=[ None, np.nan, @@ -102,35 +102,35 @@ def test_completely_null_df(self): ) 
} ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() assert res == [(None,), (None,), (None,)] def test_mixed_nulls(self): df = pd.DataFrame( { - 'float': pd.Series(data=[4.123123, None, 7.23456], dtype='Float64'), - 'int64': pd.Series(data=[-234234124, 709329413, pd.NA], dtype='Int64'), - 'bool': pd.Series(data=[np.nan, True, False], dtype='boolean'), - 'string': pd.Series(data=['NULL', None, 'quack']), - 'list[str]': pd.Series(data=[['Huey', 'Dewey', 'Louie'], [None, pd.NA, np.nan, 'DuckDB'], None]), - 'datetime64': pd.Series( + "float": pd.Series(data=[4.123123, None, 7.23456], dtype="Float64"), + "int64": pd.Series(data=[-234234124, 709329413, pd.NA], dtype="Int64"), + "bool": pd.Series(data=[np.nan, True, False], dtype="boolean"), + "string": pd.Series(data=["NULL", None, "quack"]), + "list[str]": pd.Series(data=[["Huey", "Dewey", "Louie"], [None, pd.NA, np.nan, "DuckDB"], None]), + "datetime64": pd.Series( data=[datetime.datetime(2011, 8, 16, 22, 7, 8), None, datetime.datetime(2010, 4, 26, 18, 14, 14)] ), - 'date': pd.Series(data=[datetime.date(2008, 5, 28), datetime.date(2013, 7, 14), None]), + "date": pd.Series(data=[datetime.date(2008, 5, 28), datetime.date(2013, 7, 14), None]), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchone() + res = con.sql("select * from pyarrow_df").fetchone() assert res == ( 4.123123, -234234124, None, - 'NULL', - ['Huey', 'Dewey', 'Louie'], + "NULL", + ["Huey", "Dewey", "Louie"], datetime.datetime(2011, 8, 16, 22, 7, 8), datetime.date(2008, 5, 28), ) diff --git a/tests/fast/pandas/test_pandas_category.py b/tests/fast/pandas/test_pandas_category.py index e86a97d9..4b29b3fb 100644 --- a/tests/fast/pandas/test_pandas_category.py +++ b/tests/fast/pandas/test_pandas_category.py @@ -7,7 +7,7 @@ def check_category_equal(category): df_in = pd.DataFrame( { - 'x': pd.Categorical(category, ordered=True), + "x": pd.Categorical(category, ordered=True), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() @@ -23,7 +23,7 @@ def check_create_table(category): conn = duckdb.connect() conn.execute("PRAGMA enable_verification") - df_in = pd.DataFrame({'x': pd.Categorical(category, ordered=True), 'y': pd.Categorical(category, ordered=True)}) + df_in = pd.DataFrame({"x": pd.Categorical(category, ordered=True), "y": pd.Categorical(category, ordered=True)}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() assert df_in.equals(df_out) @@ -39,7 +39,7 @@ def check_create_table(category): conn.execute("INSERT INTO t1 VALUES ('2','2')") res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall() - assert res == [('1',)] + assert res == [("1",)] res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x)").fetchall() assert res == conn.execute("SELECT x FROM t1").fetchall() @@ -56,27 +56,27 @@ def check_create_table(category): class TestCategory(object): def test_category_simple(self, duckdb_cursor): - df_in = pd.DataFrame({'float': [1.0, 2.0, 1.0], 'int': pd.Series([1, 2, 1], dtype="category")}) + df_in = pd.DataFrame({"float": [1.0, 2.0, 1.0], "int": pd.Series([1, 2, 1], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() print(duckdb.query_df(df_in, "data", 
"SELECT * FROM data").fetchall()) - print(df_out['int']) - assert numpy.all(df_out['float'] == numpy.array([1.0, 2.0, 1.0])) - assert numpy.all(df_out['int'] == numpy.array([1, 2, 1])) + print(df_out["int"]) + assert numpy.all(df_out["float"] == numpy.array([1.0, 2.0, 1.0])) + assert numpy.all(df_out["int"] == numpy.array([1, 2, 1])) def test_category_nulls(self, duckdb_cursor): - df_in = pd.DataFrame({'int': pd.Series([1, 2, None], dtype="category")}) + df_in = pd.DataFrame({"int": pd.Series([1, 2, None], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() print(duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) - assert df_out['int'][0] == 1 - assert df_out['int'][1] == 2 - assert pd.isna(df_out['int'][2]) + assert df_out["int"][0] == 1 + assert df_out["int"][1] == 2 + assert pd.isna(df_out["int"][2]) def test_category_string(self, duckdb_cursor): - check_category_equal(['foo', 'bla', 'zoo', 'foo', 'foo', 'bla']) + check_category_equal(["foo", "bla", "zoo", "foo", "foo", "bla"]) def test_category_string_null(self, duckdb_cursor): - check_category_equal(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla']) + check_category_equal(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"]) def test_category_string_null_bug_4747(self, duckdb_cursor): check_category_equal([str(i) for i in range(160)] + [None]) @@ -84,18 +84,18 @@ def test_category_string_null_bug_4747(self, duckdb_cursor): def test_categorical_fetchall(self, duckdb_cursor): df_in = pd.DataFrame( { - 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + "x": pd.Categorical(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"], ordered=True), } ) assert duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall() == [ - ('foo',), - ('bla',), + ("foo",), + ("bla",), (None,), - ('zoo',), - ('foo',), - ('foo',), + ("zoo",), + ("foo",), + ("foo",), (None,), - ('bla',), + ("bla",), ] def test_category_string_uint8(self, duckdb_cursor): @@ -105,30 +105,30 @@ def test_category_string_uint8(self, duckdb_cursor): check_create_table(category) def test_empty_categorical(self, duckdb_cursor): - empty_categoric_df = pd.DataFrame({'category': pd.Series(dtype='category')}) + empty_categoric_df = pd.DataFrame({"category": pd.Series(dtype="category")}) duckdb_cursor.execute("CREATE TABLE test AS SELECT * FROM empty_categoric_df") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [] with pytest.raises(duckdb.ConversionException, match="Could not convert string 'test' to UINT8"): duckdb_cursor.execute("insert into test VALUES('test')") duckdb_cursor.execute("insert into test VALUES(NULL)") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [(None,)] def test_category_fetch_df_chunk(self, duckdb_cursor): con = duckdb.connect() - categories = ['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'] + categories = ["foo", "bla", None, "zoo", "foo", "foo", None, "bla"] result = categories * 256 categories = result * 2 df_result = pd.DataFrame( { - 'x': pd.Categorical(result, ordered=True), + "x": pd.Categorical(result, ordered=True), } ) df_in = pd.DataFrame( { - 'x': pd.Categorical(categories, ordered=True), + "x": pd.Categorical(categories, ordered=True), } ) con.register("data", df_in) @@ -146,8 +146,8 @@ def test_category_fetch_df_chunk(self, duckdb_cursor): def test_category_mix(self, duckdb_cursor): df_in = pd.DataFrame( { - 'float': [1.0, 2.0, 
1.0, 2.0, 1.0, 2.0, 1.0, 0.0], - 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + "float": [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 0.0], + "x": pd.Categorical(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"], ordered=True), } ) diff --git a/tests/fast/pandas/test_pandas_enum.py b/tests/fast/pandas/test_pandas_enum.py index 9dc13a64..b1eb2c7f 100644 --- a/tests/fast/pandas/test_pandas_enum.py +++ b/tests/fast/pandas/test_pandas_enum.py @@ -15,7 +15,7 @@ def test_3480(self, duckdb_cursor): """ ) df = duckdb_cursor.query(f"SELECT * FROM tab LIMIT 0;").to_df() - assert df["cat"].cat.categories.equals(pd.Index(['marie', 'duchess', 'toulouse'])) + assert df["cat"].cat.categories.equals(pd.Index(["marie", "duchess", "toulouse"])) duckdb_cursor.execute("DROP TABLE tab") duckdb_cursor.execute("DROP TYPE cat") @@ -32,14 +32,14 @@ def test_3479(self, duckdb_cursor): df = pd.DataFrame( { - "cat2": pd.Series(['duchess', 'toulouse', 'marie', None, "berlioz", "o_malley"], dtype="category"), + "cat2": pd.Series(["duchess", "toulouse", "marie", None, "berlioz", "o_malley"], dtype="category"), "amt": [1, 2, 3, 4, 5, 6], } ) - duckdb_cursor.register('df', df) + duckdb_cursor.register("df", df) with pytest.raises( duckdb.ConversionException, - match='Type UINT8 with value 0 can\'t be cast because the value is out of range for the destination type UINT8', + match="Type UINT8 with value 0 can't be cast because the value is out of range for the destination type UINT8", ): duckdb_cursor.execute(f"INSERT INTO tab SELECT * FROM df;") diff --git a/tests/fast/pandas/test_pandas_limit.py b/tests/fast/pandas/test_pandas_limit.py index 506d5dd5..d551a6e4 100644 --- a/tests/fast/pandas/test_pandas_limit.py +++ b/tests/fast/pandas/test_pandas_limit.py @@ -6,9 +6,9 @@ class TestPandasLimit(object): def test_pandas_limit(self, duckdb_cursor): con = duckdb.connect() - df = con.execute('select * from range(10000000) tbl(i)').df() + df = con.execute("select * from range(10000000) tbl(i)").df() - con.execute('SET threads=8') + con.execute("SET threads=8") - limit_df = con.execute('SELECT * FROM df WHERE i=334 OR i>9967864 LIMIT 5').df() - assert list(limit_df['i']) == [334, 9967865, 9967866, 9967867, 9967868] + limit_df = con.execute("SELECT * FROM df WHERE i=334 OR i>9967864 LIMIT 5").df() + assert list(limit_df["i"]) == [334, 9967865, 9967866, 9967867, 9967868] diff --git a/tests/fast/pandas/test_pandas_na.py b/tests/fast/pandas/test_pandas_na.py index f165d180..7bc01003 100644 --- a/tests/fast/pandas/test_pandas_na.py +++ b/tests/fast/pandas/test_pandas_na.py @@ -16,20 +16,20 @@ def assert_nullness(items, null_indices): @pytest.mark.skipif(platform.system() == "Emscripten", reason="Pandas interaction is broken in Pyodide 3.11") class TestPandasNA(object): - @pytest.mark.parametrize('rows', [100, duckdb.__standard_vector_size__, 5000, 1000000]) - @pytest.mark.parametrize('pd', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("rows", [100, duckdb.__standard_vector_size__, 5000, 1000000]) + @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_pandas_string_null(self, duckdb_cursor, rows, pd): df: pd.DataFrame = pd.DataFrame(index=np.arange(rows)) df["string_column"] = pd.Series(dtype="string") e_df_rel = duckdb_cursor.from_df(df) - assert e_df_rel.types == ['VARCHAR'] + assert e_df_rel.types == ["VARCHAR"] roundtrip = e_df_rel.df() - assert roundtrip['string_column'].dtype == 'object' - expected = pd.DataFrame({'string_column': [None for _ in 
range(rows)]}) + assert roundtrip["string_column"].dtype == "object" + expected = pd.DataFrame({"string_column": [None for _ in range(rows)]}) pd.testing.assert_frame_equal(expected, roundtrip) def test_pandas_na(self, duckdb_cursor): - pd = pytest.importorskip('pandas', minversion='1.0.0', reason='Support for pandas.NA has not been added yet') + pd = pytest.importorskip("pandas", minversion="1.0.0", reason="Support for pandas.NA has not been added yet") # DataFrame containing a single pd.NA df = pd.DataFrame(pd.Series([pd.NA])) @@ -46,7 +46,7 @@ def test_pandas_na(self, duckdb_cursor): # Test if pd.NA behaves the same as np.nan once converted nan_df = pd.DataFrame( { - 'a': [ + "a": [ 1.123, 5.23234, np.nan, @@ -60,7 +60,7 @@ def test_pandas_na(self, duckdb_cursor): ) na_df = pd.DataFrame( { - 'a': [ + "a": [ 1.123, 5.23234, pd.NA, @@ -72,15 +72,15 @@ def test_pandas_na(self, duckdb_cursor): ] } ) - assert str(nan_df['a'].dtype) == 'float64' - assert str(na_df['a'].dtype) == 'object' # pd.NA values turn the column into 'object' + assert str(nan_df["a"].dtype) == "float64" + assert str(na_df["a"].dtype) == "object" # pd.NA values turn the column into 'object' nan_result = duckdb_cursor.execute("select * from nan_df").df() na_result = duckdb_cursor.execute("select * from na_df").df() pd.testing.assert_frame_equal(nan_result, na_result) # Mixed with stringified pd.NA values - na_string_df = pd.DataFrame({'a': [str(pd.NA), str(pd.NA), pd.NA, str(pd.NA), pd.NA, pd.NA, pd.NA, str(pd.NA)]}) + na_string_df = pd.DataFrame({"a": [str(pd.NA), str(pd.NA), pd.NA, str(pd.NA), pd.NA, pd.NA, pd.NA, str(pd.NA)]}) null_indices = [2, 4, 5, 6] res = duckdb_cursor.execute("select * from na_string_df").fetchall() items = [x[0] for x in [y for y in res]] diff --git a/tests/fast/pandas/test_pandas_object.py b/tests/fast/pandas/test_pandas_object.py index c00fcbc2..9e10681c 100644 --- a/tests/fast/pandas/test_pandas_object.py +++ b/tests/fast/pandas/test_pandas_object.py @@ -9,22 +9,22 @@ class TestPandasObject(object): def test_object_lotof_nulls(self): # Test mostly null column data = [None] + [1] + [None] * 10000 # Last element is 1, others are None - pandas_df = pd.DataFrame(data, columns=['c'], dtype=object) + pandas_df = pd.DataFrame(data, columns=["c"], dtype=object) con = duckdb.connect() - assert con.execute('FROM pandas_df where c is not null').fetchall() == [(1.0,)] + assert con.execute("FROM pandas_df where c is not null").fetchall() == [(1.0,)] # Test all nulls, should return varchar data = [None] * 10000 # Last element is 1, others are None - pandas_df_2 = pd.DataFrame(data, columns=['c'], dtype=object) - assert con.execute('FROM pandas_df_2 limit 1').fetchall() == [(None,)] - assert con.execute('select typeof(c) FROM pandas_df_2 limit 1').fetchall() == [('"NULL"',)] + pandas_df_2 = pd.DataFrame(data, columns=["c"], dtype=object) + assert con.execute("FROM pandas_df_2 limit 1").fetchall() == [(None,)] + assert con.execute("select typeof(c) FROM pandas_df_2 limit 1").fetchall() == [('"NULL"',)] def test_object_to_string(self, duckdb_cursor): - con = duckdb.connect(database=':memory:', read_only=False) - x = pd.DataFrame([[1, 'a', 2], [1, None, 2], [1, 1.1, 2], [1, 1.1, 2], [1, 1.1, 2]]) + con = duckdb.connect(database=":memory:", read_only=False) + x = pd.DataFrame([[1, "a", 2], [1, None, 2], [1, 1.1, 2], [1, 1.1, 2], [1, 1.1, 2]]) x = x.iloc[1:].copy() # middle col now entirely native float items - con.register('view2', x) - df = con.execute('select * from view2').fetchall() + con.register("view2", 
x) + df = con.execute("select * from view2").fetchall() assert df == [(1, None, 2), (1, 1.1, 2), (1, 1.1, 2), (1, 1.1, 2)] def test_tuple_to_list(self, duckdb_cursor): @@ -45,7 +45,7 @@ def test_tuple_to_list(self, duckdb_cursor): ) ) duckdb_cursor.execute("CREATE TABLE test as SELECT * FROM tuple_df") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [([1, 2, 3],), ([4, 5, 6],)] def test_2273(self, duckdb_cursor): @@ -56,8 +56,8 @@ def test_object_to_string_with_stride(self, duckdb_cursor): data = np.array([["a", "b", "c"], [1, 2, 3], [1, 2, 3], [11, 22, 33]]) df = pd.DataFrame(data=data[1:,], columns=data[0]) duckdb_cursor.register("object_with_strides", df) - res = duckdb_cursor.sql('select * from object_with_strides').fetchall() - assert res == [('1', '2', '3'), ('1', '2', '3'), ('11', '22', '33')] + res = duckdb_cursor.sql("select * from object_with_strides").fetchall() + assert res == [("1", "2", "3"), ("1", "2", "3"), ("11", "22", "33")] def test_2499(self, duckdb_cursor): df = pd.DataFrame( @@ -65,11 +65,11 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.881040697801939}, - {'a': 0.9922600577751953}, - {'a': 0.1589674833259317}, - {'a': 0.8928451262745073}, - {'a': 0.07022897889168278}, + {"a": 0.881040697801939}, + {"a": 0.9922600577751953}, + {"a": 0.1589674833259317}, + {"a": 0.8928451262745073}, + {"a": 0.07022897889168278}, ], dtype=object, ) @@ -77,11 +77,11 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.8759413504156746}, - {'a': 0.055784331256246156}, - {'a': 0.8605151517439655}, - {'a': 0.40807139339337695}, - {'a': 0.8429048322459952}, + {"a": 0.8759413504156746}, + {"a": 0.055784331256246156}, + {"a": 0.8605151517439655}, + {"a": 0.40807139339337695}, + {"a": 0.8429048322459952}, ], dtype=object, ) @@ -89,19 +89,19 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.9697093934032401}, - {'a': 0.9529257667149468}, - {'a': 0.21398182248591713}, - {'a': 0.6328512122275955}, - {'a': 0.5146953214092728}, + {"a": 0.9697093934032401}, + {"a": 0.9529257667149468}, + {"a": 0.21398182248591713}, + {"a": 0.6328512122275955}, + {"a": 0.5146953214092728}, ], dtype=object, ) ], ], - columns=['col'], + columns=["col"], ) - con = duckdb.connect(database=':memory:', read_only=False) - con.register('df', df) - assert con.execute('select count(*) from df').fetchone() == (3,) + con = duckdb.connect(database=":memory:", read_only=False) + con.register("df", df) + assert con.execute("select count(*) from df").fetchone() == (3,) diff --git a/tests/fast/pandas/test_pandas_string.py b/tests/fast/pandas/test_pandas_string.py index 494823ad..4bd5996d 100644 --- a/tests/fast/pandas/test_pandas_string.py +++ b/tests/fast/pandas/test_pandas_string.py @@ -5,23 +5,23 @@ class TestPandasString(object): def test_pandas_string(self, duckdb_cursor): - strings = numpy.array(['foo', 'bar', 'baz']) + strings = numpy.array(["foo", "bar", "baz"]) # https://pandas.pydata.org/pandas-docs/stable/user_guide/text.html df_in = pd.DataFrame( { - 'object': pd.Series(strings, dtype='object'), + "object": pd.Series(strings, dtype="object"), } ) # Only available in pandas 1.0.0 - if hasattr(pd, 'StringDtype'): - df_in['string'] = pd.Series(strings, dtype=pd.StringDtype()) + if hasattr(pd, "StringDtype"): + df_in["string"] = pd.Series(strings, dtype=pd.StringDtype()) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert numpy.all(df_out['object'] == strings) - if hasattr(pd, 'StringDtype'): - assert 
numpy.all(df_out['string'] == strings) + assert numpy.all(df_out["object"] == strings) + if hasattr(pd, "StringDtype"): + assert numpy.all(df_out["string"] == strings) def test_bug_2467(self, duckdb_cursor): N = 1_000_000 @@ -35,11 +35,8 @@ def test_bug_2467(self, duckdb_cursor): CREATE TABLE t1 AS SELECT * FROM df """ ) - assert ( - con.execute( - f""" + assert con.execute( + f""" SELECT count(*) from t1 """ - ).fetchall() - == [(3000000,)] - ) + ).fetchall() == [(3000000,)] diff --git a/tests/fast/pandas/test_pandas_timestamp.py b/tests/fast/pandas/test_pandas_timestamp.py index 8e17db21..835ff3af 100644 --- a/tests/fast/pandas/test_pandas_timestamp.py +++ b/tests/fast/pandas/test_pandas_timestamp.py @@ -7,30 +7,30 @@ from conftest import pandas_2_or_higher -@pytest.mark.parametrize('timezone', ['UTC', 'CET', 'Asia/Kathmandu']) +@pytest.mark.parametrize("timezone", ["UTC", "CET", "Asia/Kathmandu"]) @pytest.mark.skipif(not pandas_2_or_higher(), reason="Pandas <2.0.0 does not support timezones in the metadata string") def test_run_pandas_with_tz(timezone): con = duckdb.connect() con.execute(f"SET TimeZone = '{timezone}'") df = pandas.DataFrame( { - 'timestamp': pandas.Series( - data=[pandas.Timestamp(year=2022, month=1, day=1, hour=10, minute=15, tz=timezone, unit='us')], - dtype=f'datetime64[us, {timezone}]', + "timestamp": pandas.Series( + data=[pandas.Timestamp(year=2022, month=1, day=1, hour=10, minute=15, tz=timezone, unit="us")], + dtype=f"datetime64[us, {timezone}]", ) } ) duck_df = con.from_df(df).df() - assert duck_df['timestamp'][0] == df['timestamp'][0] + assert duck_df["timestamp"][0] == df["timestamp"][0] def test_timestamp_conversion(duckdb_cursor): - tzinfo = pandas.Timestamp('2024-01-01 00:00:00+0100', tz='Europe/Copenhagen').tzinfo + tzinfo = pandas.Timestamp("2024-01-01 00:00:00+0100", tz="Europe/Copenhagen").tzinfo ts_df = pandas.DataFrame( { "ts": [ - pandas.Timestamp('2024-01-01 00:00:00+0100', tz=tzinfo), - pandas.Timestamp('2024-01-02 00:00:00+0100', tz=tzinfo), + pandas.Timestamp("2024-01-01 00:00:00+0100", tz=tzinfo), + pandas.Timestamp("2024-01-02 00:00:00+0100", tz=tzinfo), ] } ) diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index b21c7f14..fcc63b82 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -11,7 +11,7 @@ def round_trip(data, pandas_type): df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype=pandas_type), + "object": pd.Series(data, dtype=pandas_type), } ) @@ -23,7 +23,7 @@ def round_trip(data, pandas_type): class TestNumpyNullableTypes(object): def test_pandas_numeric(self): - base_df = pd.DataFrame({'a': range(10)}) + base_df = pd.DataFrame({"a": range(10)}) data_types = [ "uint8", @@ -46,7 +46,7 @@ def test_pandas_numeric(self): "float64", ] - if version.parse(pd.__version__) >= version.parse('1.2.0'): + if version.parse(pd.__version__) >= version.parse("1.2.0"): # These DTypes where added in 1.2.0 data_types.extend(["Float32", "Float64"]) # Generate a dataframe with all the types, in the form of: @@ -59,7 +59,7 @@ def test_pandas_numeric(self): df = pd.DataFrame.from_dict(data) conn = duckdb.connect() - out_df = conn.execute('select * from df').df() + out_df = conn.execute("select * from df").df() # Verify that the types in the out_df are correct # FIXME: we don't support outputting pandas specific types (i.e UInt64) @@ -68,14 +68,14 @@ def test_pandas_numeric(self): assert str(out_df[column_name].dtype) == item.lower() def test_pandas_unsigned(self, 
duckdb_cursor): - unsigned_types = ['uint8', 'uint16', 'uint32', 'uint64'] + unsigned_types = ["uint8", "uint16", "uint32", "uint64"] data = numpy.array([0, 1, 2, 3]) for u_type in unsigned_types: round_trip(data, u_type) def test_pandas_bool(self, duckdb_cursor): data = numpy.array([True, False, False, True]) - round_trip(data, 'bool') + round_trip(data, "bool") def test_pandas_masked_float64(self, duckdb_cursor, tmp_path): pa = pytest.importorskip("pyarrow") @@ -102,85 +102,85 @@ def test_pandas_boolean(self, duckdb_cursor): data = numpy.array([True, None, pd.NA, numpy.nan, True]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='boolean'), + "object": pd.Series(data, dtype="boolean"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert pd.isna(df_out['object'][1]) - assert pd.isna(df_out['object'][2]) - assert pd.isna(df_out['object'][3]) - assert df_out['object'][4] == df_in['object'][4] + assert df_out["object"][0] == df_in["object"][0] + assert pd.isna(df_out["object"][1]) + assert pd.isna(df_out["object"][2]) + assert pd.isna(df_out["object"][3]) + assert df_out["object"][4] == df_in["object"][4] def test_pandas_float32(self, duckdb_cursor): data = numpy.array([0.1, 0.32, 0.78, numpy.nan]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='float32'), + "object": pd.Series(data, dtype="float32"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert df_out['object'][1] == df_in['object'][1] - assert df_out['object'][2] == df_in['object'][2] - assert pd.isna(df_out['object'][3]) + assert df_out["object"][0] == df_in["object"][0] + assert df_out["object"][1] == df_in["object"][1] + assert df_out["object"][2] == df_in["object"][2] + assert pd.isna(df_out["object"][3]) def test_pandas_float64(self): - data = numpy.array([0.233, numpy.nan, 3456.2341231, float('-inf'), -23424.45345, float('+inf'), 0.0000000001]) + data = numpy.array([0.233, numpy.nan, 3456.2341231, float("-inf"), -23424.45345, float("+inf"), 0.0000000001]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='float64'), + "object": pd.Series(data, dtype="float64"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() for i in range(len(data)): - if pd.isna(df_out['object'][i]): + if pd.isna(df_out["object"][i]): assert i == 1 continue - assert df_out['object'][i] == df_in['object'][i] + assert df_out["object"][i] == df_in["object"][i] def test_pandas_interval(self, duckdb_cursor): - if pd.__version__ != '1.2.4': + if pd.__version__ != "1.2.4": return data = numpy.array([2069211000000000, numpy.datetime64("NaT")]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='timedelta64[ns]'), + "object": pd.Series(data, dtype="timedelta64[ns]"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert pd.isnull(df_out['object'][1]) + assert df_out["object"][0] == df_in["object"][0] + assert pd.isnull(df_out["object"][1]) def test_pandas_encoded_utf8(self, duckdb_cursor): - data = u'\u00c3' # Unicode data - data = [data.encode('utf8')] + data = "\u00c3" # Unicode data + data = [data.encode("utf8")] expected_result = data[0] - df_in = pd.DataFrame({'object': pd.Series(data, dtype='object')}) + df_in = pd.DataFrame({"object": pd.Series(data, dtype="object")}) result = duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchone()[0] assert result == 
expected_result @pytest.mark.parametrize( - 'dtype', + "dtype", [ - 'bool', - 'utinyint', - 'usmallint', - 'uinteger', - 'ubigint', - 'tinyint', - 'smallint', - 'integer', - 'bigint', - 'float', - 'double', + "bool", + "utinyint", + "usmallint", + "uinteger", + "ubigint", + "tinyint", + "smallint", + "integer", + "bigint", + "float", + "double", ], ) def test_producing_nullable_dtypes(self, duckdb_cursor, dtype): @@ -190,19 +190,19 @@ def __init__(self, value, expected_dtype) -> None: self.expected_dtype = expected_dtype inputs = { - 'bool': Input('true', 'BooleanDtype'), - 'utinyint': Input('255', 'UInt8Dtype'), - 'usmallint': Input('65535', 'UInt16Dtype'), - 'uinteger': Input('4294967295', 'UInt32Dtype'), - 'ubigint': Input('18446744073709551615', 'UInt64Dtype'), - 'tinyint': Input('-128', 'Int8Dtype'), - 'smallint': Input('-32768', 'Int16Dtype'), - 'integer': Input('-2147483648', 'Int32Dtype'), - 'bigint': Input('-9223372036854775808', 'Int64Dtype'), - 'float': Input('268043421344044473239570760152672894976.0000000000', 'float32'), - 'double': Input( - '14303088389124869511075243108389716684037132417196499782261853698893384831666205572097390431189931733040903060865714975797777061496396865611606109149583360363636503436181348332896211726552694379264498632046075093077887837955077425420408952536212326792778411457460885268567735875437456412217418386401944141824.0000000000', - 'float64', + "bool": Input("true", "BooleanDtype"), + "utinyint": Input("255", "UInt8Dtype"), + "usmallint": Input("65535", "UInt16Dtype"), + "uinteger": Input("4294967295", "UInt32Dtype"), + "ubigint": Input("18446744073709551615", "UInt64Dtype"), + "tinyint": Input("-128", "Int8Dtype"), + "smallint": Input("-32768", "Int16Dtype"), + "integer": Input("-2147483648", "Int32Dtype"), + "bigint": Input("-9223372036854775808", "Int64Dtype"), + "float": Input("268043421344044473239570760152672894976.0000000000", "float32"), + "double": Input( + "14303088389124869511075243108389716684037132417196499782261853698893384831666205572097390431189931733040903060865714975797777061496396865611606109149583360363636503436181348332896211726552694379264498632046075093077887837955077425420408952536212326792778411457460885268567735875437456412217418386401944141824.0000000000", + "float64", ), } @@ -222,7 +222,7 @@ def __init__(self, value, expected_dtype) -> None: rel = duckdb_cursor.sql(query) # Pandas <= 2.2.3 does not convert without throwing a warning - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) with suppress(TypeError): df = rel.df() warnings.resetwarnings() @@ -231,4 +231,4 @@ def __init__(self, value, expected_dtype) -> None: expected_dtype = getattr(pd, input.expected_dtype) else: expected_dtype = numpy.dtype(input.expected_dtype) - assert isinstance(df['a'].dtype, expected_dtype) + assert isinstance(df["a"].dtype, expected_dtype) diff --git a/tests/fast/pandas/test_pandas_unregister.py b/tests/fast/pandas/test_pandas_unregister.py index 794e5910..fce8f42a 100644 --- a/tests/fast/pandas/test_pandas_unregister.py +++ b/tests/fast/pandas/test_pandas_unregister.py @@ -8,7 +8,7 @@ class TestPandasUnregister(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister1(self, duckdb_cursor, pandas): df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) connection = duckdb.connect(":memory:") @@ -16,13 +16,13 @@ def test_pandas_unregister1(self, 
duckdb_cursor, pandas): df2 = connection.execute("SELECT * FROM dataframe;").fetchdf() connection.unregister("dataframe") - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() - with pytest.raises(duckdb.CatalogException, match='View with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="View with name dataframe does not exist"): connection.execute("DROP VIEW dataframe;") connection.execute("DROP VIEW IF EXISTS dataframe;") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister2(self, duckdb_cursor, pandas): fd, db = tempfile.mkstemp() os.close(fd) @@ -39,7 +39,7 @@ def test_pandas_unregister2(self, duckdb_cursor, pandas): connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() connection.close() @@ -50,6 +50,6 @@ def test_pandas_unregister2(self, duckdb_cursor, pandas): # Reconnecting after DataFrame freed. connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() connection.close() diff --git a/tests/fast/pandas/test_pandas_update.py b/tests/fast/pandas/test_pandas_update.py index 663d6da2..86d17154 100644 --- a/tests/fast/pandas/test_pandas_update.py +++ b/tests/fast/pandas/test_pandas_update.py @@ -4,10 +4,10 @@ class TestPandasUpdateList(object): def test_pandas_update_list(self, duckdb_cursor): - duckdb_cursor = duckdb.connect(':memory:') - duckdb_cursor.execute('create table t (l int[])') - duckdb_cursor.execute('insert into t values ([1, 2]), ([3,4])') - duckdb_cursor.execute('update t set l = [5, 6]') - expected = pd.DataFrame({'l': [[5, 6], [5, 6]]}) - res = duckdb_cursor.execute('select * from t').fetchdf() + duckdb_cursor = duckdb.connect(":memory:") + duckdb_cursor.execute("create table t (l int[])") + duckdb_cursor.execute("insert into t values ([1, 2]), ([3,4])") + duckdb_cursor.execute("update t set l = [5, 6]") + expected = pd.DataFrame({"l": [[5, 6], [5, 6]]}) + res = duckdb_cursor.execute("select * from t").fetchdf() pd.testing.assert_frame_equal(expected, res) diff --git a/tests/fast/pandas/test_parallel_pandas_scan.py b/tests/fast/pandas/test_parallel_pandas_scan.py index a9fd99b9..d113bbca 100644 --- a/tests/fast/pandas/test_parallel_pandas_scan.py +++ b/tests/fast/pandas/test_parallel_pandas_scan.py @@ -24,8 +24,8 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera try: duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - duckdb_conn.register('main_table', main_table) - duckdb_conn.register('left_join_table', left_join_table) + duckdb_conn.register("main_table", main_table) + duckdb_conn.register("left_join_table", left_join_table) output_df = duckdb_conn.execute(sql).fetchdf() 
pandas.testing.assert_frame_equal(expected_df, output_df) print(output_df) @@ -36,69 +36,69 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera class TestParallelPandasScan(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_scan(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": 3}]) left_join_table = pandas.DataFrame([{"join_column": 3, "other_column": 4}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_ascii_text(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": "text"}]) left_join_table = pandas.DataFrame([{"join_column": "text", "other_column": "more text"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"mühleisen"}]) - left_join_table = pandas.DataFrame([{"join_column": u"mühleisen", "other_column": u"höhöhö"}]) + main_table = pandas.DataFrame([{"join_column": "mühleisen"}]) + left_join_table = pandas.DataFrame([{"join_column": "mühleisen", "other_column": "höhöhö"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_complex_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"鴨"}]) - left_join_table = pandas.DataFrame([{"join_column": u"鴨", "other_column": u"數據庫"}]) + main_table = pandas.DataFrame([{"join_column": "鴨"}]) + left_join_table = pandas.DataFrame([{"join_column": "鴨", "other_column": "數據庫"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_emojis(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) - left_join_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": u"🦆🍞🦆"}]) + main_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) + left_join_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": "🦆🍞🦆"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_object(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({'join_column': pandas.Series([3], dtype="Int8")}) + main_table = pandas.DataFrame({"join_column": pandas.Series([3], dtype="Int8")}) left_join_table = pandas.DataFrame( - {'join_column': pandas.Series([3], dtype="Int8"), 'other_column': pandas.Series([4], dtype="Int8")} + {"join_column": pandas.Series([3], dtype="Int8"), "other_column": pandas.Series([4], dtype="Int8")} ) expected_df = pandas.DataFrame( {"join_column": numpy.array([3], dtype=numpy.int8), "other_column": numpy.array([4], 
dtype=numpy.int8)} ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_timestamp(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({'join_column': [pandas.Timestamp('20180310T11:17:54Z')]}) + main_table = pandas.DataFrame({"join_column": [pandas.Timestamp("20180310T11:17:54Z")]}) left_join_table = pandas.DataFrame( { - 'join_column': [pandas.Timestamp('20180310T11:17:54Z')], - 'other_column': [pandas.Timestamp('20190310T11:17:54Z')], + "join_column": [pandas.Timestamp("20180310T11:17:54Z")], + "other_column": [pandas.Timestamp("20190310T11:17:54Z")], } ) expected_df = pandas.DataFrame( { - "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), - "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), + "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), + "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), } ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_empty(self, duckdb_cursor, pandas): - df_empty = pandas.DataFrame({'A': []}) + df_empty = pandas.DataFrame({"A": []}) duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - duckdb_conn.register('main_table', df_empty) - assert duckdb_conn.execute('select * from main_table').fetchall() == [] + duckdb_conn.register("main_table", df_empty) + assert duckdb_conn.execute("select * from main_table").fetchall() == [] diff --git a/tests/fast/pandas/test_partitioned_pandas_scan.py b/tests/fast/pandas/test_partitioned_pandas_scan.py index 32c5352f..d2447ef8 100644 --- a/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -8,9 +8,9 @@ class TestPartitionedPandasScan(object): def test_parallel_pandas(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) + con.register("df", df) seq_results = con.execute("SELECT SUM(i) FROM df").fetchall() diff --git a/tests/fast/pandas/test_progress_bar.py b/tests/fast/pandas/test_progress_bar.py index 241cedd6..7c1c21e1 100644 --- a/tests/fast/pandas/test_progress_bar.py +++ b/tests/fast/pandas/test_progress_bar.py @@ -8,10 +8,10 @@ class TestProgressBarPandas(object): def test_progress_pandas_single(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) - con.register('df_2', df) + con.register("df", df) + con.register("df_2", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") result = con.execute("SELECT SUM(df.i) FROM df inner join df_2 on (df.i = df_2.i)").fetchall() @@ -19,10 +19,10 @@ def test_progress_pandas_single(self, duckdb_cursor): def test_progress_pandas_parallel(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) - con.register('df_2', df) + 
con.register("df", df) + con.register("df_2", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") con.execute("PRAGMA threads=4") @@ -31,8 +31,8 @@ def test_progress_pandas_parallel(self, duckdb_cursor): def test_progress_pandas_empty(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': []}) - con.register('df', df) + df = pd.DataFrame({"i": []}) + con.register("df", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") result = con.execute("SELECT SUM(df.i) from df").fetchall() diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index e693e75c..b04f713a 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -6,7 +6,7 @@ pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") -_ = pytest.importorskip("pandas", '2.0.0') +_ = pytest.importorskip("pandas", "2.0.0") @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") @@ -16,6 +16,6 @@ def test_projection_pushdown_no_filter(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test (a INTEGER, b INTEGER, c INTEGER)") duckdb_conn.execute("INSERT INTO test VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") duck_tbl = duckdb_conn.table("test") - arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') + arrow_table = duck_tbl.df().convert_dtypes(dtype_backend="pyarrow") duckdb_conn.register("testarrowtable", arrow_table) assert duckdb_conn.execute("SELECT sum(a) FROM testarrowtable").fetchall() == [(111,)] diff --git a/tests/fast/pandas/test_same_name.py b/tests/fast/pandas/test_same_name.py index f48eb7eb..ac4f407a 100644 --- a/tests/fast/pandas/test_same_name.py +++ b/tests/fast/pandas/test_same_name.py @@ -5,76 +5,76 @@ class TestMultipleColumnsSameName(object): def test_multiple_columns_with_same_name(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) - duckdb_cursor.register('df_view', df) + duckdb_cursor.register("df_view", df) - assert duckdb_cursor.table("df_view").columns == ['a', 'a_1', 'd'] + assert duckdb_cursor.table("df_view").columns == ["a", "a_1", "d"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_multiple_columns_with_same_name_relation(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) rel = duckdb_cursor.from_df(df) assert rel.query("df_view", "DESCRIBE df_view;").fetchall() == [ - ('a', 'BIGINT', 'YES', None, None, None), - ('a_1', 'BIGINT', 'YES', None, None, None), - ('d', 'BIGINT', 'YES', None, None, None), + ("a", "BIGINT", "YES", None, None, None), + ("a_1", "BIGINT", "YES", None, None, None), + ("d", "BIGINT", "YES", None, None, None), ] assert rel.query("df_view", "select a_1 from df_view;").fetchall() == [(5,), 
(6,), (7,), (8,)] assert rel.query("df_view", "select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_multiple_columns_with_same_name_replacement_scans(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) assert duckdb_cursor.execute("select a_1 from df;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_3669(self, duckdb_cursor): - df = pd.DataFrame([(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)], columns=['a_1', 'a', 'a']) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['a_1', 'a', 'a_2'] + df = pd.DataFrame([(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)], columns=["a_1", "a", "a"]) + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["a_1", "a", "a_2"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a_1', 'a', 'a']), df.columns + assert all(df.columns == ["a_1", "a", "a"]), df.columns def test_minimally_rename(self, duckdb_cursor): df = pd.DataFrame( - [(1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15), (4, 8, 12, 16)], columns=['a_1', 'a', 'a', 'a_2'] + [(1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15), (4, 8, 12, 16)], columns=["a_1", "a", "a", "a_2"] ) - duckdb_cursor.register('df_view', df) + duckdb_cursor.register("df_view", df) rel = duckdb_cursor.table("df_view") res = rel.columns - assert res == ['a_1', 'a', 'a_2', 'a_2_1'] + assert res == ["a_1", "a", "a_2", "a_2_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a_2 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] assert duckdb_cursor.execute("select a_2_1 from df_view;").fetchall() == [(13,), (14,), (15,), (16,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a_1', 'a', 'a', 'a_2']), df.columns + assert all(df.columns == ["a_1", "a", "a", "a_2"]), df.columns def test_multiple_columns_with_same_name_2(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'a_1': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "a_1": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a_1"}) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['a', 'a_1', 'a_1_1'] + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["a", "a_1", "a_1_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), 
(11,), (12,)] def test_case_insensitive(self, duckdb_cursor): - df = pd.DataFrame({'A_1': [1, 2, 3, 4], 'a_1': [9, 10, 11, 12]}) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['A_1', 'a_1_1'] + df = pd.DataFrame({"A_1": [1, 2, 3, 4], "a_1": [9, 10, 11, 12]}) + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["A_1", "a_1_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] diff --git a/tests/fast/pandas/test_stride.py b/tests/fast/pandas/test_stride.py index 5efe8d56..1b2f5052 100644 --- a/tests/fast/pandas/test_stride.py +++ b/tests/fast/pandas/test_stride.py @@ -8,27 +8,27 @@ class TestPandasStride(object): def test_stride(self, duckdb_cursor): expected_df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() pd.testing.assert_frame_equal(expected_df, output_df) def test_stride_fp32(self, duckdb_cursor): - expected_df = pd.DataFrame(np.arange(20, dtype='float32').reshape(5, 4), columns=["a", "b", "c", "d"]) + expected_df = pd.DataFrame(np.arange(20, dtype="float32").reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert str(output_df[col].dtype) == 'float32' + assert str(output_df[col].dtype) == "float32" pd.testing.assert_frame_equal(expected_df, output_df) def test_stride_datetime(self, duckdb_cursor): - df = pd.DataFrame({'date': pd.Series(pd.date_range("2024-01-01", freq="D", periods=100))}) + df = pd.DataFrame({"date": pd.Series(pd.date_range("2024-01-01", freq="D", periods=100))}) df = df.loc[::23,] roundtrip = duckdb_cursor.sql("select * from df").df() expected = pd.DataFrame( { - 'date': [ + "date": [ datetime.datetime(2024, 1, 1), datetime.datetime(2024, 1, 24), datetime.datetime(2024, 2, 16), @@ -40,13 +40,13 @@ def test_stride_datetime(self, duckdb_cursor): pd.testing.assert_frame_equal(roundtrip, expected) def test_stride_timedelta(self, duckdb_cursor): - df = pd.DataFrame({'date': [datetime.timedelta(days=i) for i in range(100)]}) + df = pd.DataFrame({"date": [datetime.timedelta(days=i) for i in range(100)]}) df = df.loc[::23,] roundtrip = duckdb_cursor.sql("select * from df").df() expected = pd.DataFrame( { - 'date': [ + "date": [ datetime.timedelta(days=0), datetime.timedelta(days=23), datetime.timedelta(days=46), @@ -58,10 +58,10 @@ def test_stride_timedelta(self, duckdb_cursor): pd.testing.assert_frame_equal(roundtrip, expected) def test_stride_fp64(self, duckdb_cursor): - expected_df = pd.DataFrame(np.arange(20, dtype='float64').reshape(5, 4), columns=["a", "b", "c", "d"]) + expected_df = pd.DataFrame(np.arange(20, dtype="float64").reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert str(output_df[col].dtype) == 'float64' + assert str(output_df[col].dtype) == "float64" pd.testing.assert_frame_equal(expected_df, output_df) diff --git a/tests/fast/pandas/test_timedelta.py 
b/tests/fast/pandas/test_timedelta.py index 5c6aa4b9..c0afeb74 100644 --- a/tests/fast/pandas/test_timedelta.py +++ b/tests/fast/pandas/test_timedelta.py @@ -11,7 +11,7 @@ def test_timedelta_positive(self, duckdb_cursor): "SELECT '2290-01-01 23:59:00'::TIMESTAMP - '2000-01-01 23:59:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=9151574400000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) @@ -20,7 +20,7 @@ def test_timedelta_basic(self, duckdb_cursor): "SELECT '2290-08-30 23:53:40'::TIMESTAMP - '2000-02-01 01:56:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=9169797460000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) @@ -29,24 +29,24 @@ def test_timedelta_negative(self, duckdb_cursor): "SELECT '2000-01-01 23:59:00'::TIMESTAMP - '2290-01-01 23:59:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=-9151574400000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) - @pytest.mark.parametrize('days', [1, 9999]) - @pytest.mark.parametrize('seconds', [0, 60]) + @pytest.mark.parametrize("days", [1, 9999]) + @pytest.mark.parametrize("seconds", [0, 60]) @pytest.mark.parametrize( - 'microseconds', + "microseconds", [ 0, 232493, 999_999, ], ) - @pytest.mark.parametrize('milliseconds', [0, 999]) - @pytest.mark.parametrize('minutes', [0, 60]) - @pytest.mark.parametrize('hours', [0, 24]) - @pytest.mark.parametrize('weeks', [0, 51]) + @pytest.mark.parametrize("milliseconds", [0, 999]) + @pytest.mark.parametrize("minutes", [0, 60]) + @pytest.mark.parametrize("hours", [0, 24]) + @pytest.mark.parametrize("weeks", [0, 51]) @pytest.mark.skipif(platform.system() == "Emscripten", reason="Bind parameters are broken when running on Pyodide") def test_timedelta_coverage(self, duckdb_cursor, days, seconds, microseconds, milliseconds, minutes, hours, weeks): def create_duck_interval(days, seconds, microseconds, milliseconds, minutes, hours, weeks) -> str: diff --git a/tests/fast/pandas/test_timestamp.py b/tests/fast/pandas/test_timestamp.py index 0a580025..dbb7273d 100644 --- a/tests/fast/pandas/test_timestamp.py +++ b/tests/fast/pandas/test_timestamp.py @@ -8,33 +8,33 @@ class TestPandasTimestamps(object): - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_types_roundtrip(self, unit): d = { - 'time': pd.Series( + "time": pd.Series( [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit)], - dtype=f'datetime64[{unit}]', + dtype=f"datetime64[{unit}]", ) } df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() assert df_from_duck.equals(df) - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_timezone_roundtrip(self, unit): if pandas_2_or_higher(): - dtype = 
pd.core.dtypes.dtypes.DatetimeTZDtype(unit=unit, tz='UTC') - expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='us', tz='UTC') + dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit=unit, tz="UTC") + expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="us", tz="UTC") else: # Older versions of pandas only support 'ns' as timezone unit - expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='ns', tz='UTC') - dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='ns', tz='UTC') + expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="ns", tz="UTC") + dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="ns", tz="UTC") conn = duckdb.connect() conn.execute("SET TimeZone =UTC") d = { - 'time': pd.Series( - [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit, tz='UTC')], + "time": pd.Series( + [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit, tz="UTC")], dtype=dtype, ) } @@ -46,9 +46,9 @@ def test_timestamp_timezone_roundtrip(self, unit): df_from_duck = conn.from_df(df).df() assert df_from_duck.equals(expected) - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_nulls(self, unit): - d = {'time': pd.Series([pd.Timestamp(None, unit=unit)], dtype=f'datetime64[{unit}]')} + d = {"time": pd.Series([pd.Timestamp(None, unit=unit)], dtype=f"datetime64[{unit}]")} df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() assert df_from_duck.equals(df) @@ -56,10 +56,10 @@ def test_timestamp_nulls(self, unit): def test_timestamp_timedelta(self): df = pd.DataFrame( { - 'a': [pd.Timedelta(1, unit='s')], - 'b': [pd.Timedelta(None, unit='s')], - 'c': [pd.Timedelta(1, unit='us')], - 'd': [pd.Timedelta(1, unit='ms')], + "a": [pd.Timedelta(1, unit="s")], + "b": [pd.Timedelta(None, unit="s")], + "c": [pd.Timedelta(1, unit="us")], + "d": [pd.Timedelta(1, unit="ms")], } ) df_from_duck = duckdb.from_df(df).df() @@ -78,4 +78,4 @@ def test_timestamp_timezone(self, duckdb_cursor): """ ) res = rel.df() - assert res['dateTime'][0] == res['dateTime_1'][0] + assert res["dateTime"][0] == res["dateTime_1"][0] diff --git a/tests/fast/relational_api/test_groupings.py b/tests/fast/relational_api/test_groupings.py index fc81deba..b0a95410 100644 --- a/tests/fast/relational_api/test_groupings.py +++ b/tests/fast/relational_api/test_groupings.py @@ -22,7 +22,7 @@ def con(): class TestGroupings(object): def test_basic_grouping(self, con): - rel = con.table('tbl').sum("a", "b") + rel = con.table("tbl").sum("a", "b") res = rel.fetchall() assert res == [(7,), (2,), (5,)] @@ -31,7 +31,7 @@ def test_basic_grouping(self, con): assert res == res2 def test_cubed(self, con): - rel = con.table('tbl').sum("a", "CUBE (b)").order("ALL") + rel = con.table("tbl").sum("a", "CUBE (b)").order("ALL") res = rel.fetchall() assert res == [(2,), (5,), (7,), (14,)] @@ -40,7 +40,7 @@ def test_cubed(self, con): assert res == res2 def test_rollup(self, con): - rel = con.table('tbl').sum("a", "ROLLUP (b, c)").order("ALL") + rel = con.table("tbl").sum("a", "ROLLUP (b, c)").order("ALL") res = rel.fetchall() assert res == [(1,), (1,), (2,), (2,), (2,), (3,), (5,), (5,), (7,), (14,)] diff --git a/tests/fast/relational_api/test_joins.py b/tests/fast/relational_api/test_joins.py index 8eb365d5..cf3d3cf2 100644 --- a/tests/fast/relational_api/test_joins.py +++ b/tests/fast/relational_api/test_joins.py @@ -31,57 +31,57 @@ def con(): class TestRAPIJoins(object): def test_outer_join(self, con): - a = 
con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'outer') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "outer") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, None, None), (None, None, 3, 5)] def test_inner_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'inner') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "inner") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4)] def test_anti_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'anti') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "anti") res = rel.fetchall() # Only output the row(s) from A where the condition is false assert res == [(3, 2)] def test_left_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'left') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "left") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, None, None)] def test_right_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'right') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "right") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (None, None, 3, 5)] def test_semi_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'semi') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "semi") res = rel.fetchall() assert res == [(1, 1), (2, 1)] def test_cross_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') + a = con.table("tbl_a") + b = con.table("tbl_b") rel = a.cross(b) res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, 1, 4), (1, 1, 3, 5), (2, 1, 3, 5), (3, 2, 3, 5)] diff --git a/tests/fast/relational_api/test_pivot.py b/tests/fast/relational_api/test_pivot.py index d78df656..9cf91e56 100644 --- a/tests/fast/relational_api/test_pivot.py +++ b/tests/fast/relational_api/test_pivot.py @@ -27,4 +27,4 @@ def test_pivot_issue_14601(self, duckdb_cursor): export_dir = tempfile.mkdtemp() duckdb_cursor.query(f"EXPORT DATABASE '{export_dir}'") with open(os.path.join(export_dir, "schema.sql"), "r") as f: - assert 'CREATE TYPE' not in f.read() + assert "CREATE TYPE" not in f.read() diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index 29202759..3466a77a 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -269,7 +269,7 @@ def test_product(self, table): def 
test_string_agg(self, table): result = table.string_agg("s", sep="/").execute().fetchall() - expected = [('h/e/l/l/o/,/wor/ld',)] + expected = [("h/e/l/l/o/,/wor/ld",)] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) result = ( @@ -278,7 +278,7 @@ def test_string_agg(self, table): .execute() .fetchall() ) - expected = [(1, 'h/e/l'), (2, 'l/o'), (3, ',/wor/ld')] + expected = [(1, "h/e/l"), (2, "l/o"), (3, ",/wor/ld")] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) diff --git a/tests/fast/relational_api/test_rapi_close.py b/tests/fast/relational_api/test_rapi_close.py index 270c58f5..b6355167 100644 --- a/tests/fast/relational_api/test_rapi_close.py +++ b/tests/fast/relational_api/test_rapi_close.py @@ -11,153 +11,153 @@ def test_close_conn_rel(self, duckdb_cursor): rel = con.table("items") con.close() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): len(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.aggregate("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.any_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.apply("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.arg_max("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.arg_min("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetch_arrow_table() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.avg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_and("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_or("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_xor("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): 
rel.bitstring_agg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bool_and("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bool_or("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.count("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.create("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.create_view("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.cume_dist("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.dense_rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.describe() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.distinct() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.execute() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.favg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchall() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchnumpy() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchone() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.filter("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.first("") - with pytest.raises(duckdb.ConnectionException, match='Connection has 
already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.first_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fsum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.geomean("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.histogram("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.insert("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.insert_into("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.lag("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.last("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.last_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.lead("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel.limit(1)) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.list("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.map(lambda df: df['col0'].add(42).to_frame()) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): + rel.map(lambda df: df["col0"].add(42).to_frame()) + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.max("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.mean("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.median("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.min("") - with pytest.raises(duckdb.ConnectionException, match='Connection has 
already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.mode("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.n_tile("", 1) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.nth_value("", "", 1) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.order("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.percent_rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.product("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.project("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile_cont("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile_disc("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.query("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.rank_dense("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.row_number("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.std("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev_pop("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with 
pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev_samp("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.string_agg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.sum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.to_arrow_table() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.to_df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var_pop("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var_samp("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.variance("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.write_csv("") con = duckdb.connect() @@ -166,14 +166,14 @@ def test_close_conn_rel(self, duckdb_cursor): valid_rel = con.table("items") # Test these bad boys when left relation is valid - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.union(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.except_(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.intersect(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - valid_rel.join(rel.set_alias('rel'), "rel.items = valid_rel.items") + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): + valid_rel.join(rel.set_alias("rel"), "rel.items = valid_rel.items") def test_del_conn(self, duckdb_cursor): con = duckdb.connect() @@ -181,5 +181,5 @@ def test_del_conn(self, duckdb_cursor): con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") rel = con.table("items") del con - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel) diff --git 
a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 41813d94..80616132 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -4,31 +4,31 @@ class TestRAPIDescription(object): def test_rapi_description(self, duckdb_cursor): - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") desc = res.description names = [x[0] for x in desc] types = [x[1] for x in desc] - assert names == ['a', 'b'] - assert types == ['INTEGER', 'BIGINT'] + assert names == ["a", "b"] + assert types == ["INTEGER", "BIGINT"] assert all([x == duckdb.NUMBER for x in types]) def test_rapi_describe(self, duckdb_cursor): np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") duck_describe = res.describe().df() - np.testing.assert_array_equal(duck_describe['aggr'], ['count', 'mean', 'stddev', 'min', 'max', 'median']) - np.testing.assert_array_equal(duck_describe['a'], [1, 42, float('nan'), 42, 42, 42]) - np.testing.assert_array_equal(duck_describe['b'], [1, 84, float('nan'), 84, 84, 84]) + np.testing.assert_array_equal(duck_describe["aggr"], ["count", "mean", "stddev", "min", "max", "median"]) + np.testing.assert_array_equal(duck_describe["a"], [1, 42, float("nan"), 42, 42, 42]) + np.testing.assert_array_equal(duck_describe["b"], [1, 84, float("nan"), 84, 84, 84]) # now with more values res = duckdb_cursor.query( - 'select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)' + "select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)" ) duck_describe = res.describe().df() - np.testing.assert_allclose(duck_describe['i'], [5000.0, 4999.0, 2887.0400066504103, 0.0, 9998.0, 4999.0]) - np.testing.assert_allclose(duck_describe['j'], [10000.0, 49995.0, 28868.956799071675, 0.0, 99990.0, 49995.0]) - np.testing.assert_allclose(duck_describe['k'], [10000.0, 4258.3518, 2459.207430770227, 0.0, 8517.0, 4258.5]) + np.testing.assert_allclose(duck_describe["i"], [5000.0, 4999.0, 2887.0400066504103, 0.0, 9998.0, 4999.0]) + np.testing.assert_allclose(duck_describe["j"], [10000.0, 49995.0, 28868.956799071675, 0.0, 99990.0, 49995.0]) + np.testing.assert_allclose(duck_describe["k"], [10000.0, 4258.3518, 2459.207430770227, 0.0, 8517.0, 4258.5]) # describe data with other (non-numeric) types res = duckdb_cursor.query("select 'hello world' AS a, [1, 2, 3] AS b") @@ -38,8 +38,8 @@ def test_rapi_describe(self, duckdb_cursor): # describe mixed table res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b, 'hello world' AS c") duck_describe = res.describe().df() - np.testing.assert_array_equal(duck_describe['a'], [1, 42, float('nan'), 42, 42, 42]) - np.testing.assert_array_equal(duck_describe['b'], [1, 84, float('nan'), 84, 84, 84]) + np.testing.assert_array_equal(duck_describe["a"], [1, 42, float("nan"), 42, 42, 42]) + np.testing.assert_array_equal(duck_describe["b"], [1, 84, float("nan"), 84, 84, 84]) # timestamps res = duckdb_cursor.query("select timestamp '1992-01-01', date '2000-01-01'") diff --git a/tests/fast/relational_api/test_rapi_functions.py b/tests/fast/relational_api/test_rapi_functions.py index 92de4c2c..c6b1f1fa 100644 --- 
a/tests/fast/relational_api/test_rapi_functions.py +++ b/tests/fast/relational_api/test_rapi_functions.py @@ -3,10 +3,10 @@ class TestRAPIFunctions(object): def test_rapi_str_print(self, duckdb_cursor): - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") assert str(res) is not None res.show() def test_rapi_relation_sql_query(self): - res = duckdb.table_function('range', [10]) + res = duckdb.table_function("range", [10]) assert res.sql_query() == 'SELECT * FROM "range"(10)' diff --git a/tests/fast/relational_api/test_rapi_query.py b/tests/fast/relational_api/test_rapi_query.py index 92f87776..16ed326c 100644 --- a/tests/fast/relational_api/test_rapi_query.py +++ b/tests/fast/relational_api/test_rapi_query.py @@ -10,12 +10,12 @@ def tbl_table(): con.execute("drop table if exists tbl CASCADE") con.execute("create table tbl (i integer)") yield - con.execute('drop table tbl CASCADE') + con.execute("drop table tbl CASCADE") @pytest.fixture() def scoped_default(duckdb_cursor): - default = duckdb.connect(':default:') + default = duckdb.connect(":default:") duckdb.set_default_connection(duckdb_cursor) # Overwrite the default connection yield @@ -24,7 +24,7 @@ def scoped_default(duckdb_cursor): class TestRAPIQuery(object): - @pytest.mark.parametrize('steps', [1, 2, 3, 4]) + @pytest.mark.parametrize("steps", [1, 2, 3, 4]) def test_query_chain(self, steps): con = duckdb.default_connection() amount = int(1000000) @@ -36,7 +36,7 @@ def test_query_chain(self, steps): result = rel.execute() assert len(result.fetchall()) == amount - @pytest.mark.parametrize('input', [[5, 4, 3], [], [1000]]) + @pytest.mark.parametrize("input", [[5, 4, 3], [], [1000]]) def test_query_table(self, tbl_table, input): con = duckdb.default_connection() rel = con.table("tbl") @@ -98,80 +98,80 @@ def test_query_table_unrelated(self, tbl_table): def test_query_non_select_result(self, duckdb_cursor): with pytest.raises(duckdb.ParserException, match="syntax error"): - duckdb_cursor.query('selec 42') + duckdb_cursor.query("selec 42") - res = duckdb_cursor.query('explain select 42').fetchall() + res = duckdb_cursor.query("explain select 42").fetchall() assert len(res) > 0 - res = duckdb_cursor.query('describe select 42::INT AS column_name').fetchall() - assert res[0][0] == 'column_name' + res = duckdb_cursor.query("describe select 42::INT AS column_name").fetchall() + assert res[0][0] == "column_name" - res = duckdb_cursor.query('create or replace table tbl_non_select_result(i integer)') + res = duckdb_cursor.query("create or replace table tbl_non_select_result(i integer)") assert res is None - res = duckdb_cursor.query('insert into tbl_non_select_result values (42)') + res = duckdb_cursor.query("insert into tbl_non_select_result values (42)") assert res is None - res = duckdb_cursor.query('insert into tbl_non_select_result values (84) returning *').fetchall() + res = duckdb_cursor.query("insert into tbl_non_select_result values (84) returning *").fetchall() assert res == [(84,)] - res = duckdb_cursor.query('select * from tbl_non_select_result').fetchall() + res = duckdb_cursor.query("select * from tbl_non_select_result").fetchall() assert res == [(42,), (84,)] - res = duckdb_cursor.query('insert into tbl_non_select_result select * from range(10000) returning *').fetchall() + res = duckdb_cursor.query("insert into tbl_non_select_result select * from range(10000) returning *").fetchall() assert len(res) == 10000 - res = duckdb_cursor.query('show 
tables').fetchall() + res = duckdb_cursor.query("show tables").fetchall() assert len(res) > 0 - res = duckdb_cursor.query('drop table tbl_non_select_result') + res = duckdb_cursor.query("drop table tbl_non_select_result") assert res is None def test_replacement_scan_recursion(self, duckdb_cursor): depth_limit = 1000 - if sys.platform.startswith('win') or platform.system() == "Emscripten": + if sys.platform.startswith("win") or platform.system() == "Emscripten": # With the default we reach a stack overflow in the CI for windows # and also outside of it for Pyodide depth_limit = 250 duckdb_cursor.execute(f"SET max_expression_depth TO {depth_limit}") - rel = duckdb_cursor.sql('select 42 a, 21 b') - rel = duckdb_cursor.sql('select a+a a, b+b b from rel') - other_rel = duckdb_cursor.sql('select a from rel') + rel = duckdb_cursor.sql("select 42 a, 21 b") + rel = duckdb_cursor.sql("select a+a a, b+b b from rel") + other_rel = duckdb_cursor.sql("select a from rel") res = other_rel.fetchall() assert res == [(84,)] def test_set_default_connection(self, scoped_default): duckdb.sql("create table t as select 42") - assert duckdb.table('t').fetchall() == [(42,)] - con = duckdb.connect(':default:') + assert duckdb.table("t").fetchall() == [(42,)] + con = duckdb.connect(":default:") # Uses the same db as the module - assert con.table('t').fetchall() == [(42,)] + assert con.table("t").fetchall() == [(42,)] con2 = duckdb.connect() con2.sql("create table t as select 21") - assert con2.table('t').fetchall() == [(21,)] + assert con2.table("t").fetchall() == [(21,)] # Change the db used by the module duckdb.set_default_connection(con2) - with pytest.raises(duckdb.CatalogException, match='Table with name d does not exist'): - con2.table('d').fetchall() + with pytest.raises(duckdb.CatalogException, match="Table with name d does not exist"): + con2.table("d").fetchall() - assert duckdb.table('t').fetchall() == [(21,)] + assert duckdb.table("t").fetchall() == [(21,)] duckdb.sql("create table d as select [1,2,3]") - assert duckdb.table('d').fetchall() == [([1, 2, 3],)] - assert con2.table('d').fetchall() == [([1, 2, 3],)] + assert duckdb.table("d").fetchall() == [([1, 2, 3],)] + assert con2.table("d").fetchall() == [([1, 2, 3],)] def test_set_default_connection_error(self, scoped_default): - with pytest.raises(TypeError, match='Invoked with: None'): + with pytest.raises(TypeError, match="Invoked with: None"): # set_default_connection does not allow None duckdb.set_default_connection(None) - with pytest.raises(TypeError, match='Invoked with: 5'): + with pytest.raises(TypeError, match="Invoked with: 5"): duckdb.set_default_connection(5) assert duckdb.sql("select 42").fetchall() == [(42,)] diff --git a/tests/fast/relational_api/test_rapi_windows.py b/tests/fast/relational_api/test_rapi_windows.py index 7c13debc..cc58b8f1 100644 --- a/tests/fast/relational_api/test_rapi_windows.py +++ b/tests/fast/relational_api/test_rapi_windows.py @@ -429,14 +429,14 @@ def test_bitstring_agg(self, table): .fetchall() ) expected = [ - (1, '0010000000000'), - (1, '0010000000000'), - (1, '0011000000000'), - (2, '0000000000001'), - (2, '0000000000011'), - (3, '0000001000000'), - (3, '1000001000000'), - (3, '1000001000000'), + (1, "0010000000000"), + (1, "0010000000000"), + (1, "0011000000000"), + (2, "0000000000001"), + (2, "0000000000011"), + (3, "0000001000000"), + (3, "1000001000000"), + (3, "1000001000000"), ] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) @@ -619,7 +619,7 @@ def 
test_string_agg(self, table): .execute() .fetchall() ) - expected = [(1, 'e'), (1, 'e/h'), (1, 'e/h/l'), (2, 'o'), (2, 'o/l'), (3, 'wor'), (3, 'wor/,'), (3, 'wor/,/ld')] + expected = [(1, "e"), (1, "e/h"), (1, "e/h/l"), (2, "o"), (2, "o/l"), (3, "wor"), (3, "wor/,"), (3, "wor/,/ld")] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) diff --git a/tests/fast/relational_api/test_table_function.py b/tests/fast/relational_api/test_table_function.py index 4f4a1016..5748f762 100644 --- a/tests/fast/relational_api/test_table_function.py +++ b/tests/fast/relational_api/test_table_function.py @@ -7,11 +7,11 @@ class TestTableFunction(object): def test_table_function(self, duckdb_cursor): - path = os.path.join(script_path, '..', 'data/integers.csv') - rel = duckdb_cursor.table_function('read_csv', [path]) + path = os.path.join(script_path, "..", "data/integers.csv") + rel = duckdb_cursor.table_function("read_csv", [path]) res = rel.fetchall() assert res == [(1, 10, 0), (2, 50, 30)] # Provide only a string as argument, should error, needs a list with pytest.raises(duckdb.InvalidInputException, match=r"'params' has to be a list of parameters"): - rel = duckdb_cursor.table_function('read_csv', path) + rel = duckdb_cursor.table_function("read_csv", path) diff --git a/tests/fast/spark/test_replace_column_value.py b/tests/fast/spark/test_replace_column_value.py index 33940616..65ab85f1 100644 --- a/tests/fast/spark/test_replace_column_value.py +++ b/tests/fast/spark/test_replace_column_value.py @@ -13,7 +13,7 @@ def test_replace_value(self, spark): # Replace part of string with another string from spark_namespace.sql.functions import regexp_replace - df2 = df.withColumn('address', regexp_replace('address', 'Rd', 'Road')) + df2 = df.withColumn("address", regexp_replace("address", "Rd", "Road")) # Replace string column value conditionally from spark_namespace.sql.functions import when @@ -21,24 +21,24 @@ def test_replace_value(self, spark): res = df2.collect() print(res) df2 = df.withColumn( - 'address', - when(df.address.endswith('Rd'), regexp_replace(df.address, 'Rd', 'Road')) - .when(df.address.endswith('St'), regexp_replace(df.address, 'St', 'Street')) - .when(df.address.endswith('Ave'), regexp_replace(df.address, 'Ave', 'Avenue')) + "address", + when(df.address.endswith("Rd"), regexp_replace(df.address, "Rd", "Road")) + .when(df.address.endswith("St"), regexp_replace(df.address, "St", "Street")) + .when(df.address.endswith("Ave"), regexp_replace(df.address, "Ave", "Avenue")) .otherwise(df.address), ) res = df2.collect() print(res) expected = [ - Row(id=1, address='14851 Jeffrey Road', state='DE'), - Row(id=2, address='43421 Margarita Street', state='NY'), - Row(id=3, address='13111 Siemon Avenue', state='CA'), + Row(id=1, address="14851 Jeffrey Road", state="DE"), + Row(id=2, address="43421 Margarita Street", state="NY"), + Row(id=3, address="13111 Siemon Avenue", state="CA"), ] print(expected) assert res == expected # Replace all substrings of the specified string value that match regexp with rep. 
- df3 = spark.createDataFrame([('100-200',)], ['str']) - res = df3.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() - expected = [Row(d='-----')] + df3 = spark.createDataFrame([("100-200",)], ["str"]) + res = df3.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() + expected = [Row(d="-----")] print(expected) assert res == expected diff --git a/tests/fast/spark/test_replace_empty_value.py b/tests/fast/spark/test_replace_empty_value.py index 71a9f25f..aad6a43e 100644 --- a/tests/fast/spark/test_replace_empty_value.py +++ b/tests/fast/spark/test_replace_empty_value.py @@ -12,32 +12,32 @@ def test_replace_empty(self, spark): # Create the dataframe data = [("", "CA"), ("Julia", ""), ("Robert", ""), ("", "NJ")] df = spark.createDataFrame(data, ["name", "state"]) - res = df.select('name').collect() - assert res == [Row(name=''), Row(name='Julia'), Row(name='Robert'), Row(name='')] - res = df.select('state').collect() - assert res == [Row(state='CA'), Row(state=''), Row(state=''), Row(state='NJ')] + res = df.select("name").collect() + assert res == [Row(name=""), Row(name="Julia"), Row(name="Robert"), Row(name="")] + res = df.select("state").collect() + assert res == [Row(state="CA"), Row(state=""), Row(state=""), Row(state="NJ")] # Replace name # CASE WHEN "name" == '' THEN NULL ELSE "name" END from spark_namespace.sql.functions import col, when df2 = df.withColumn("name", when(col("name") == "", None).otherwise(col("name"))) - assert df2.columns == ['name', 'state'] - res = df2.select('name').collect() - assert res == [Row(name=None), Row(name='Julia'), Row(name='Robert'), Row(name=None)] + assert df2.columns == ["name", "state"] + res = df2.select("name").collect() + assert res == [Row(name=None), Row(name="Julia"), Row(name="Robert"), Row(name=None)] # Replace state + name from spark_namespace.sql.functions import col, when df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in df.columns]) - assert df2.columns == ['name', 'state'] + assert df2.columns == ["name", "state"] key_f = lambda x: x.name or x.state res = df2.sort("name", "state").collect() expected_res = [ - Row(name=None, state='CA'), - Row(name=None, state='NJ'), - Row(name='Julia', state=None), - Row(name='Robert', state=None), + Row(name=None, state="CA"), + Row(name=None, state="NJ"), + Row(name="Julia", state=None), + Row(name="Robert", state=None), ] assert res == expected_res @@ -46,15 +46,15 @@ def test_replace_empty(self, spark): from spark_namespace.sql.functions import col, when replaceCols = ["state"] - df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in replaceCols]).sort(col('state')) - assert df2.columns == ['state'] + df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in replaceCols]).sort(col("state")) + assert df2.columns == ["state"] key_f = lambda x: x.state or "" res = df2.collect() assert sorted(res, key=key_f) == sorted( [ - Row(state='CA'), - Row(state='NJ'), + Row(state="CA"), + Row(state="NJ"), Row(state=None), Row(state=None), ], diff --git a/tests/fast/spark/test_spark_catalog.py b/tests/fast/spark/test_spark_catalog.py index 7f523abd..2ecaad24 100644 --- a/tests/fast/spark/test_spark_catalog.py +++ b/tests/fast/spark/test_spark_catalog.py @@ -13,9 +13,9 @@ def test_list_databases(self, spark): assert all(isinstance(db, Database) for db in dbs) else: assert dbs == [ - Database(name='memory', description=None, locationUri=''), - Database(name='system', description=None, locationUri=''), - 
Database(name='temp', description=None, locationUri=''), + Database(name="memory", description=None, locationUri=""), + Database(name="system", description=None, locationUri=""), + Database(name="temp", description=None, locationUri=""), ] def test_list_tables(self, spark): @@ -26,31 +26,31 @@ def test_list_tables(self, spark): if not USE_ACTUAL_SPARK: # Skip this if we're using actual Spark because we can't create tables # with our setup. - spark.sql('create table tbl(a varchar)') + spark.sql("create table tbl(a varchar)") tbls = spark.catalog.listTables() assert tbls == [ Table( - name='tbl', - database='memory', - description='CREATE TABLE tbl(a VARCHAR);', - tableType='', + name="tbl", + database="memory", + description="CREATE TABLE tbl(a VARCHAR);", + tableType="", isTemporary=False, ) ] @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup") def test_list_columns(self, spark): - spark.sql('create table tbl(a varchar, b bool)') - columns = spark.catalog.listColumns('tbl') + spark.sql("create table tbl(a varchar, b bool)") + columns = spark.catalog.listColumns("tbl") assert columns == [ - Column(name='a', description=None, dataType='VARCHAR', nullable=True, isPartition=False, isBucket=False), - Column(name='b', description=None, dataType='BOOLEAN', nullable=True, isPartition=False, isBucket=False), + Column(name="a", description=None, dataType="VARCHAR", nullable=True, isPartition=False, isBucket=False), + Column(name="b", description=None, dataType="BOOLEAN", nullable=True, isPartition=False, isBucket=False), ] # FIXME: should this error instead? - non_existant_columns = spark.catalog.listColumns('none_existant') + non_existant_columns = spark.catalog.listColumns("none_existant") assert non_existant_columns == [] - spark.sql('create view vw as select * from tbl') - view_columns = spark.catalog.listColumns('vw') + spark.sql("create view vw as select * from tbl") + view_columns = spark.catalog.listColumns("vw") assert view_columns == columns diff --git a/tests/fast/spark/test_spark_column.py b/tests/fast/spark/test_spark_column.py index e56ba9ee..9ef17d95 100644 --- a/tests/fast/spark/test_spark_column.py +++ b/tests/fast/spark/test_spark_column.py @@ -18,26 +18,26 @@ def test_struct_column(self, spark): # FIXME: column names should be set explicitly using the Row, rather than letting duckdb assign defaults (col0, col1, etc..) 
if USE_ACTUAL_SPARK: - df = df.withColumn('struct', struct(df.a, df.b)) + df = df.withColumn("struct", struct(df.a, df.b)) else: - df = df.withColumn('struct', struct(df.col0, df.col1)) - assert 'struct' in df - new_col = df.schema['struct'] + df = df.withColumn("struct", struct(df.col0, df.col1)) + assert "struct" in df + new_col = df.schema["struct"] if USE_ACTUAL_SPARK: - assert 'a' in df.schema['struct'].dataType.fieldNames() - assert 'b' in df.schema['struct'].dataType.fieldNames() + assert "a" in df.schema["struct"].dataType.fieldNames() + assert "b" in df.schema["struct"].dataType.fieldNames() else: - assert 'col0' in new_col.dataType - assert 'col1' in new_col.dataType + assert "col0" in new_col.dataType + assert "col1" in new_col.dataType with pytest.raises( PySparkTypeError, match=re.escape("[NOT_COLUMN] Argument `col` should be a Column, got str.") ): - df = df.withColumn('struct', 'yes') + df = df.withColumn("struct", "yes") def test_array_column(self, spark): - df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)], ['a', 'b', 'c', 'd']) + df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)], ["a", "b", "c", "d"]) df2 = df.select( array(df["a"], df["b"]).alias("array"), diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index d88b03eb..e86995ec 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -36,9 +36,9 @@ def test_dataframe_from_list_of_tuples(self, spark): df = spark.createDataFrame(address, ["id", "address", "state"]) res = df.collect() assert res == [ - Row(id=1, address='14851 Jeffrey Rd', state='DE'), - Row(id=2, address='43421 Margarita St', state='NY'), - Row(id=3, address='13111 Siemon Ave', state='CA'), + Row(id=1, address="14851 Jeffrey Rd", state="DE"), + Row(id=2, address="43421 Margarita St", state="NY"), + Row(id=3, address="13111 Siemon Ave", state="CA"), ] # Tuples of different sizes @@ -93,9 +93,9 @@ def test_dataframe_from_list_of_tuples(self, spark): df = spark.createDataFrame(address, []) res = df.collect() assert res == [ - Row(col0=1, col1='14851 Jeffrey Rd', col2='DE'), - Row(col0=2, col1='43421 Margarita St', col2='NY'), - Row(col0=3, col1='13111 Siemon Ave', col2='DE'), + Row(col0=1, col1="14851 Jeffrey Rd", col2="DE"), + Row(col0=2, col1="43421 Margarita St", col2="NY"), + Row(col0=3, col1="13111 Siemon Ave", col2="DE"), ] # Too many column names @@ -107,17 +107,17 @@ def test_dataframe_from_list_of_tuples(self, spark): # Column names is not a list (but is iterable) if not USE_ACTUAL_SPARK: # These things do not work in Spark or throw different errors - df = spark.createDataFrame(address, {'a': 5, 'b': 6, 'c': 42}) + df = spark.createDataFrame(address, {"a": 5, "b": 6, "c": 42}) res = df.collect() assert res == [ - Row(a=1, b='14851 Jeffrey Rd', c='DE'), - Row(a=2, b='43421 Margarita St', c='NY'), - Row(a=3, b='13111 Siemon Ave', c='DE'), + Row(a=1, b="14851 Jeffrey Rd", c="DE"), + Row(a=2, b="43421 Margarita St", c="NY"), + Row(a=3, b="13111 Siemon Ave", c="DE"), ] # Column names is not a list (string, becomes a single column name) with pytest.raises(PySparkValueError, match="number of columns in the DataFrame don't match"): - df = spark.createDataFrame(address, 'a') + df = spark.createDataFrame(address, "a") with pytest.raises(TypeError, match="must be an iterable, not int"): df = spark.createDataFrame(address, 5) @@ -126,7 +126,7 @@ def test_dataframe(self, spark): # Create DataFrame df = spark.createDataFrame([("Scala", 25000), ("Spark", 35000), 
("PHP", 21000)]) res = df.collect() - assert res == [Row(col0='Scala', col1=25000), Row(col0='Spark', col1=35000), Row(col0='PHP', col1=21000)] + assert res == [Row(col0="Scala", col1=25000), Row(col0="Spark", col1=35000), Row(col0="PHP", col1=21000)] @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup") def test_writing_to_table(self, spark): @@ -136,18 +136,18 @@ def test_writing_to_table(self, spark): create table sample_table("_1" bool, "_2" integer) """ ) - spark.sql('insert into sample_table VALUES (True, 42)') + spark.sql("insert into sample_table VALUES (True, 42)") spark.table("sample_table").write.saveAsTable("sample_hive_table") df3 = spark.sql("SELECT _1,_2 FROM sample_hive_table") res = df3.collect() assert res == [Row(_1=True, _2=42)] schema = df3.schema - assert schema == StructType([StructField('_1', BooleanType(), True), StructField('_2', IntegerType(), True)]) + assert schema == StructType([StructField("_1", BooleanType(), True), StructField("_2", IntegerType(), True)]) def test_dataframe_collect(self, spark): - df = spark.createDataFrame([(42,), (21,)]).toDF('a') + df = spark.createDataFrame([(42,), (21,)]).toDF("a") res = df.collect() - assert str(res) == '[Row(a=42), Row(a=21)]' + assert str(res) == "[Row(a=42), Row(a=21)]" def test_dataframe_from_rows(self, spark): columns = ["language", "users_count"] @@ -157,17 +157,17 @@ def test_dataframe_from_rows(self, spark): df = spark.createDataFrame(rowData, columns) res = df.collect() assert res == [ - Row(language='Java', users_count='20000'), - Row(language='Python', users_count='100000'), - Row(language='Scala', users_count='3000'), + Row(language="Java", users_count="20000"), + Row(language="Python", users_count="100000"), + Row(language="Scala", users_count="3000"), ] def test_empty_df(self, spark): schema = StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ) df = spark.createDataFrame([], schema=schema) @@ -178,18 +178,18 @@ def test_empty_df(self, spark): def test_df_from_pandas(self, spark): import pandas as pd - df = spark.createDataFrame(pd.DataFrame({'a': [42, 21], 'b': [True, False]})) + df = spark.createDataFrame(pd.DataFrame({"a": [42, 21], "b": [True, False]})) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_from_struct_type(self, spark): - schema = StructType([StructField('a', LongType()), StructField('b', BooleanType())]) + schema = StructType([StructField("a", LongType()), StructField("b", BooleanType())]) df = spark.createDataFrame([(42, True), (21, False)], schema) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_from_name_list(self, spark): - df = spark.createDataFrame([(42, True), (21, False)], ['a', 'b']) + df = spark.createDataFrame([(42, True), (21, False)], ["a", "b"]) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] @@ -218,11 +218,11 @@ def test_df_creation_coverage(self, spark): df = spark.createDataFrame(data=data2, schema=schema) res = df.collect() assert res == [ - Row(firstname='James', middlename='', lastname='Smith', id='36636', gender='M', salary=3000), - Row(firstname='Michael', middlename='Rose', lastname='', id='40288', gender='M', salary=4000), - Row(firstname='Robert', middlename='', 
lastname='Williams', id='42114', gender='M', salary=4000), - Row(firstname='Maria', middlename='Anne', lastname='Jones', id='39192', gender='F', salary=4000), - Row(firstname='Jen', middlename='Mary', lastname='Brown', id='', gender='F', salary=-1), + Row(firstname="James", middlename="", lastname="Smith", id="36636", gender="M", salary=3000), + Row(firstname="Michael", middlename="Rose", lastname="", id="40288", gender="M", salary=4000), + Row(firstname="Robert", middlename="", lastname="Williams", id="42114", gender="M", salary=4000), + Row(firstname="Maria", middlename="Anne", lastname="Jones", id="39192", gender="F", salary=4000), + Row(firstname="Jen", middlename="Mary", lastname="Brown", id="", gender="F", salary=-1), ] def test_df_nested_struct(self, spark): @@ -236,18 +236,18 @@ def test_df_nested_struct(self, spark): structureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -255,24 +255,24 @@ def test_df_nested_struct(self, spark): res = df2.collect() expected_res = [ Row( - name={'firstname': 'James', 'middlename': '', 'lastname': 'Smith'}, id='36636', gender='M', salary=3100 + name={"firstname": "James", "middlename": "", "lastname": "Smith"}, id="36636", gender="M", salary=3100 ), Row( - name={'firstname': 'Michael', 'middlename': 'Rose', 'lastname': ''}, id='40288', gender='M', salary=4300 + name={"firstname": "Michael", "middlename": "Rose", "lastname": ""}, id="40288", gender="M", salary=4300 ), Row( - name={'firstname': 'Robert', 'middlename': '', 'lastname': 'Williams'}, - id='42114', - gender='M', + name={"firstname": "Robert", "middlename": "", "lastname": "Williams"}, + id="42114", + gender="M", salary=1400, ), Row( - name={'firstname': 'Maria', 'middlename': 'Anne', 'lastname': 'Jones'}, - id='39192', - gender='F', + name={"firstname": "Maria", "middlename": "Anne", "lastname": "Jones"}, + id="39192", + gender="F", salary=5500, ), - Row(name={'firstname': 'Jen', 'middlename': 'Mary', 'lastname': 'Brown'}, id='', gender='F', salary=-1), + Row(name={"firstname": "Jen", "middlename": "Mary", "lastname": "Brown"}, id="", gender="F", salary=-1), ] if USE_ACTUAL_SPARK: expected_res = [Row(name=Row(**r.name), id=r.id, gender=r.gender, salary=r.salary) for r in expected_res] @@ -281,19 +281,19 @@ def test_df_nested_struct(self, spark): assert schema == StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), True, ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -310,18 +310,18 @@ def 
test_df_columns(self, spark): structureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -339,7 +339,7 @@ def test_df_columns(self, spark): ), ).drop("id", "gender", "salary") - assert 'OtherInfo' in updatedDF.columns + assert "OtherInfo" in updatedDF.columns def test_array_and_map_type(self, spark): """Array & Map""" @@ -347,17 +347,17 @@ def test_array_and_map_type(self, spark): arrayStructureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('hobbies', ArrayType(StringType()), True), - StructField('properties', MapType(StringType(), StringType()), True), + StructField("hobbies", ArrayType(StringType()), True), + StructField("properties", MapType(StringType(), StringType()), True), ] ) diff --git a/tests/fast/spark/test_spark_dataframe_sort.py b/tests/fast/spark/test_spark_dataframe_sort.py index 20558197..db7dce4b 100644 --- a/tests/fast/spark/test_spark_dataframe_sort.py +++ b/tests/fast/spark/test_spark_dataframe_sort.py @@ -91,11 +91,11 @@ def test_sort_with_desc(self, spark): df = df.sort(desc("name")) res = df.collect() assert res == [ - Row(age=3, name='Dave'), - Row(age=56, name='Carol'), - Row(age=1, name='Ben'), - Row(age=3, name='Anna'), - Row(age=20, name='Alice'), + Row(age=3, name="Dave"), + Row(age=56, name="Carol"), + Row(age=1, name="Ben"), + Row(age=3, name="Anna"), + Row(age=20, name="Alice"), ] def test_sort_with_asc(self, spark): @@ -103,9 +103,9 @@ def test_sort_with_asc(self, spark): df = df.sort(asc("name")) res = df.collect() assert res == [ - Row(age=20, name='Alice'), - Row(age=3, name='Anna'), - Row(age=1, name='Ben'), - Row(age=56, name='Carol'), - Row(age=3, name='Dave'), + Row(age=20, name="Alice"), + Row(age=3, name="Anna"), + Row(age=1, name="Ben"), + Row(age=56, name="Carol"), + Row(age=3, name="Dave"), ] diff --git a/tests/fast/spark/test_spark_drop_duplicates.py b/tests/fast/spark/test_spark_drop_duplicates.py index 6dc7f573..563a5e76 100644 --- a/tests/fast/spark/test_spark_drop_duplicates.py +++ b/tests/fast/spark/test_spark_drop_duplicates.py @@ -34,15 +34,15 @@ def test_spark_drop_duplicates(self, method, spark): res = distinctDF.collect() # James | Sales had a duplicate, has been removed expected = [ - Row(employee_name='James', department='Sales', salary=3000), - Row(employee_name='Jeff', department='Marketing', salary=3000), - Row(employee_name='Jen', department='Finance', salary=3900), - Row(employee_name='Kumar', department='Marketing', salary=2000), - Row(employee_name='Maria', department='Finance', salary=3000), - Row(employee_name='Michael', department='Sales', salary=4600), - Row(employee_name='Robert', department='Sales', salary=4100), - 
Row(employee_name='Saif', department='Sales', salary=4100), - Row(employee_name='Scott', department='Finance', salary=3300), + Row(employee_name="James", department="Sales", salary=3000), + Row(employee_name="Jeff", department="Marketing", salary=3000), + Row(employee_name="Jen", department="Finance", salary=3900), + Row(employee_name="Kumar", department="Marketing", salary=2000), + Row(employee_name="Maria", department="Finance", salary=3000), + Row(employee_name="Michael", department="Sales", salary=4600), + Row(employee_name="Robert", department="Sales", salary=4100), + Row(employee_name="Saif", department="Sales", salary=4100), + Row(employee_name="Scott", department="Finance", salary=3300), ] assert res == expected @@ -52,14 +52,14 @@ def test_spark_drop_duplicates(self, method, spark): assert res2 == res expected_subset = [ - Row(department='Finance', salary=3000), - Row(department='Finance', salary=3300), - Row(department='Finance', salary=3900), - Row(department='Marketing', salary=2000), - Row(department='Marketing', salary=3000), - Row(epartment='Sales', salary=3000), - Row(department='Sales', salary=4100), - Row(department='Sales', salary=4600), + Row(department="Finance", salary=3000), + Row(department="Finance", salary=3300), + Row(department="Finance", salary=3900), + Row(department="Marketing", salary=2000), + Row(department="Marketing", salary=3000), + Row(epartment="Sales", salary=3000), + Row(department="Sales", salary=4100), + Row(department="Sales", salary=4600), ] dropDisDF = getattr(df, method)(["department", "salary"]).sort("department", "salary") diff --git a/tests/fast/spark/test_spark_except.py b/tests/fast/spark/test_spark_except.py index 434ac613..7c28cc29 100644 --- a/tests/fast/spark/test_spark_except.py +++ b/tests/fast/spark/test_spark_except.py @@ -19,7 +19,6 @@ def df2(spark): class TestDataFrameIntersect: def test_exceptAll(self, spark, df, df2): - df3 = df.exceptAll(df2).sort(*df.columns) res = df3.collect() diff --git a/tests/fast/spark/test_spark_filter.py b/tests/fast/spark/test_spark_filter.py index fb6f0b1a..a4733a44 100644 --- a/tests/fast/spark/test_spark_filter.py +++ b/tests/fast/spark/test_spark_filter.py @@ -35,18 +35,18 @@ def test_dataframe_filter(self, spark): schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('languages', ArrayType(StringType()), True), - StructField('state', StringType(), True), - StructField('gender', StringType(), True), + StructField("languages", ArrayType(StringType()), True), + StructField("state", StringType(), True), + StructField("gender", StringType(), True), ] ) @@ -57,51 +57,51 @@ def test_dataframe_filter(self, spark): # Using equals condition df2 = df.filter(df.state == "OH") res = df2.collect() - assert res[0].state == 'OH' + assert res[0].state == "OH" # not equals condition df2 = df.filter(df.state != "OH") df2 = df.filter(~(df.state == "OH")) res = df2.collect() for item in res: - assert item.state == 'NY' or item.state == 'CA' + assert item.state == "NY" or item.state == "CA" df2 = df.filter(col("state") == "OH") res = df2.collect() - assert res[0].state == 'OH' + assert res[0].state == "OH" df2 = df.filter("gender == 'M'") res = df2.collect() - assert res[0].gender == 'M' + assert 
res[0].gender == "M" df2 = df.filter("gender != 'M'") res = df2.collect() - assert res[0].gender == 'F' + assert res[0].gender == "F" df2 = df.filter("gender <> 'M'") res = df2.collect() - assert res[0].gender == 'F' + assert res[0].gender == "F" # Filter multiple condition df2 = df.filter((df.state == "OH") & (df.gender == "M")) res = df2.collect() assert len(res) == 2 for item in res: - assert item.gender == 'M' and item.state == 'OH' + assert item.gender == "M" and item.state == "OH" # Filter IS IN List values li = ["OH", "NY"] df2 = df.filter(df.state.isin(li)) res = df2.collect() for item in res: - assert item.state == 'OH' or item.state == 'NY' + assert item.state == "OH" or item.state == "NY" # Filter NOT IS IN List values # These show all records with NY (NY is not part of the list) df2 = df.filter(~df.state.isin(li)) res = df2.collect() for item in res: - assert item.state != 'OH' and item.state != 'NY' + assert item.state != "OH" and item.state != "NY" df2 = df.filter(df.state.isin(li) == False) res2 = df2.collect() @@ -111,19 +111,19 @@ def test_dataframe_filter(self, spark): df2 = df.filter(df.state.startswith("N")) res = df2.collect() for item in res: - assert item.state == 'NY' + assert item.state == "NY" # using endswith df2 = df.filter(df.state.endswith("H")) res = df2.collect() for item in res: - assert item.state == 'OH' + assert item.state == "OH" # contains df2 = df.filter(df.state.contains("H")) res = df2.collect() for item in res: - assert item.state == 'OH' + assert item.state == "OH" data2 = [(2, "Michael Rose"), (3, "Robert Williams"), (4, "Rames Rose"), (5, "Rames rose")] df2 = spark.createDataFrame(data=data2, schema=["id", "name"]) @@ -131,56 +131,56 @@ def test_dataframe_filter(self, spark): # like - SQL LIKE pattern df3 = df2.filter(df2.name.like("%rose%")) res = df3.collect() - assert res == [Row(id=5, name='Rames rose')] + assert res == [Row(id=5, name="Rames rose")] # rlike - SQL RLIKE pattern (LIKE with Regex) # This check case insensitive df3 = df2.filter(df2.name.rlike("(?i)^*rose$")) res = df3.collect() - assert res == [Row(id=2, name='Michael Rose'), Row(id=4, name='Rames Rose'), Row(id=5, name='Rames rose')] + assert res == [Row(id=2, name="Michael Rose"), Row(id=4, name="Rames Rose"), Row(id=5, name="Rames rose")] df2 = df.filter(array_contains(df.languages, "Java")) res = df2.collect() - james_name = {'firstname': 'James', 'middlename': '', 'lastname': 'Smith'} - anna_name = {'firstname': 'Anna', 'middlename': 'Rose', 'lastname': ''} + james_name = {"firstname": "James", "middlename": "", "lastname": "Smith"} + anna_name = {"firstname": "Anna", "middlename": "Rose", "lastname": ""} if USE_ACTUAL_SPARK: james_name = Row(**james_name) anna_name = Row(**anna_name) assert res == [ Row( name=james_name, - languages=['Java', 'Scala', 'C++'], - state='OH', - gender='M', + languages=["Java", "Scala", "C++"], + state="OH", + gender="M", ), Row( name=anna_name, - languages=['Spark', 'Java', 'C++'], - state='CA', - gender='F', + languages=["Spark", "Java", "C++"], + state="CA", + gender="F", ), ] df2 = df.filter(df.name.lastname == "Williams") res = df2.collect() - julia_name = {'firstname': 'Julia', 'middlename': '', 'lastname': 'Williams'} - mike_name = {'firstname': 'Mike', 'middlename': 'Mary', 'lastname': 'Williams'} + julia_name = {"firstname": "Julia", "middlename": "", "lastname": "Williams"} + mike_name = {"firstname": "Mike", "middlename": "Mary", "lastname": "Williams"} if USE_ACTUAL_SPARK: julia_name = Row(**julia_name) mike_name = Row(**mike_name) 
assert res == [ Row( name=julia_name, - languages=['CSharp', 'VB'], - state='OH', - gender='F', + languages=["CSharp", "VB"], + state="OH", + gender="F", ), Row( name=mike_name, - languages=['Python', 'VB'], - state='OH', - gender='M', + languages=["Python", "VB"], + state="OH", + gender="M", ), ] diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py index f83e0ef2..5ecba132 100644 --- a/tests/fast/spark/test_spark_functions_array.py +++ b/tests/fast/spark/test_spark_functions_array.py @@ -75,7 +75,7 @@ def test_array_min(self, spark): ] def test_get(self, spark): - df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) res = df.select(sf.get(df.data, 1).alias("r")).collect() assert res == [Row(r="b")] @@ -87,25 +87,25 @@ def test_get(self, spark): assert res == [Row(r=None)] res = df.select(sf.get(df.data, "index").alias("r")).collect() - assert res == [Row(r='b')] + assert res == [Row(r="b")] res = df.select(sf.get(df.data, sf.col("index") - 1).alias("r")).collect() - assert res == [Row(r='a')] + assert res == [Row(r="a")] def test_flatten(self, spark): - df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) res = df.select(sf.flatten(df.data).alias("r")).collect() assert res == [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] def test_array_compact(self, spark): - df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) + df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) res = df.select(sf.array_compact(df.data).alias("v")).collect() assert [Row(v=[1, 2, 3]), Row(v=[4, 5, 4])] def test_array_remove(self, spark): - df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) res = df.select(sf.array_remove(df.data, 1).alias("v")).collect() assert res == [Row(v=[2, 3]), Row(v=[])] @@ -126,101 +126,101 @@ def test_array_append(self, spark): df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")], ["c1", "c2"]) res = df.select(sf.array_append(df.c1, df.c2).alias("r")).collect() - assert res == [Row(r=['b', 'a', 'c', 'c'])] + assert res == [Row(r=["b", "a", "c", "c"])] - res = df.select(sf.array_append(df.c1, 'x')).collect() - assert res == [Row(r=['b', 'a', 'c', 'x'])] + res = df.select(sf.array_append(df.c1, "x")).collect() + assert res == [Row(r=["b", "a", "c", "x"])] def test_array_insert(self, spark): df = spark.createDataFrame( - [(['a', 'b', 'c'], 2, 'd'), (['a', 'b', 'c', 'e'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')], - ['data', 'pos', 'val'], + [(["a", "b", "c"], 2, "d"), (["a", "b", "c", "e"], 2, "d"), (["c", "b", "a"], -2, "d")], + ["data", "pos", "val"], ) - res = df.select(sf.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() + res = df.select(sf.array_insert(df.data, df.pos.cast("integer"), df.val).alias("data")).collect() assert res == [ - Row(data=['a', 'd', 'b', 'c']), - Row(data=['a', 'd', 'b', 'c', 'e']), - Row(data=['c', 'b', 'd', 'a']), + Row(data=["a", "d", "b", "c"]), + Row(data=["a", "d", "b", "c", "e"]), + Row(data=["c", "b", "d", "a"]), ] - res = df.select(sf.array_insert(df.data, 5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, 5, "hello").alias("data")).collect() assert res == [ - Row(data=['a', 'b', 'c', None, 'hello']), - 
Row(data=['a', 'b', 'c', 'e', 'hello']), - Row(data=['c', 'b', 'a', None, 'hello']), + Row(data=["a", "b", "c", None, "hello"]), + Row(data=["a", "b", "c", "e", "hello"]), + Row(data=["c", "b", "a", None, "hello"]), ] - res = df.select(sf.array_insert(df.data, -5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, -5, "hello").alias("data")).collect() assert res == [ - Row(data=['hello', None, 'a', 'b', 'c']), - Row(data=['hello', 'a', 'b', 'c', 'e']), - Row(data=['hello', None, 'c', 'b', 'a']), + Row(data=["hello", None, "a", "b", "c"]), + Row(data=["hello", "a", "b", "c", "e"]), + Row(data=["hello", None, "c", "b", "a"]), ] def test_slice(self, spark): - df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) res = df.select(sf.slice(df.x, 2, 2).alias("sliced")).collect() assert res == [Row(sliced=[2, 3]), Row(sliced=[5])] def test_sort_array(self, spark): - df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) + df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) - res = df.select(sf.sort_array(df.data).alias('r')).collect() + res = df.select(sf.sort_array(df.data).alias("r")).collect() assert res == [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] - res = df.select(sf.sort_array(df.data, asc=False).alias('r')).collect() + res = df.select(sf.sort_array(df.data, asc=False).alias("r")).collect() assert res == [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] @pytest.mark.parametrize(("null_replacement", "expected_joined_2"), [(None, "a"), ("replaced", "a,replaced")]) def test_array_join(self, spark, null_replacement, expected_joined_2): - df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) res = df.select(sf.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect() - assert res == [Row(joined='a,b,c'), Row(joined=expected_joined_2)] + assert res == [Row(joined="a,b,c"), Row(joined=expected_joined_2)] def test_array_position(self, spark): - df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) res = df.select(sf.array_position(df.data, "a").alias("pos")).collect() assert res == [Row(pos=3), Row(pos=0)] def test_array_preprend(self, spark): - df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) res = df.select(sf.array_prepend(df.data, 1).alias("pre")).collect() assert res == [Row(pre=[1, 2, 3, 4]), Row(pre=[1])] def test_array_repeat(self, spark): - df = spark.createDataFrame([('ab',)], ['data']) + df = spark.createDataFrame([("ab",)], ["data"]) - res = df.select(sf.array_repeat(df.data, 3).alias('r')).collect() - assert res == [Row(r=['ab', 'ab', 'ab'])] + res = df.select(sf.array_repeat(df.data, 3).alias("r")).collect() + assert res == [Row(r=["ab", "ab", "ab"])] def test_array_size(self, spark): - df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) + df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) - res = df.select(sf.array_size(df.data).alias('r')).collect() + res = df.select(sf.array_size(df.data).alias("r")).collect() assert res == [Row(r=3), Row(r=None)] def test_array_sort(self, spark): - df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) + df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) - res = 
df.select(sf.array_sort(df.data).alias('r')).collect() + res = df.select(sf.array_sort(df.data).alias("r")).collect() assert res == [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] def test_arrays_overlap(self, spark): df = spark.createDataFrame( - [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ['x', 'y'] + [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ["x", "y"] ) res = df.select(sf.arrays_overlap(df.x, df.y).alias("overlap")).collect() assert res == [Row(overlap=True), Row(overlap=False), Row(overlap=None), Row(overlap=None)] def test_arrays_zip(self, spark): - df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3']) + df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ["vals1", "vals2", "vals3"]) - res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect() + res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias("zipped")).collect() # FIXME: The structure of the results should be the same if USE_ACTUAL_SPARK: assert res == [ diff --git a/tests/fast/spark/test_spark_functions_base64.py b/tests/fast/spark/test_spark_functions_base64.py index 734a5275..5a179481 100644 --- a/tests/fast/spark/test_spark_functions_base64.py +++ b/tests/fast/spark/test_spark_functions_base64.py @@ -40,4 +40,4 @@ def test_unbase64(self, spark): .select("decoded_value") .collect() ) - assert res[0].decoded_value == b'quack' + assert res[0].decoded_value == b"quack" diff --git a/tests/fast/spark/test_spark_functions_date.py b/tests/fast/spark/test_spark_functions_date.py index 2a51d9b8..a298c0ff 100644 --- a/tests/fast/spark/test_spark_functions_date.py +++ b/tests/fast/spark/test_spark_functions_date.py @@ -145,43 +145,43 @@ def test_second(self, spark): assert result[0].second_num == 45 def test_unix_date(self, spark): - df = spark.createDataFrame([('1970-01-02',)], ['t']) - res = df.select(F.unix_date(df.t.cast("date")).alias('n')).collect() + df = spark.createDataFrame([("1970-01-02",)], ["t"]) + res = df.select(F.unix_date(df.t.cast("date")).alias("n")).collect() assert res == [Row(n=1)] def test_unix_micros(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_micros(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_micros(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200000000)] def test_unix_millis(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_millis(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_millis(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200000)] def test_unix_seconds(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_seconds(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_seconds(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200)] def test_weekday(self, spark): - df = spark.createDataFrame([('2015-04-08',)], ['dt']) - res = df.select(F.weekday(df.dt.cast("date")).alias('day')).collect() + df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + res = 
df.select(F.weekday(df.dt.cast("date")).alias("day")).collect() assert res == [Row(day=2)] def test_to_date(self, spark): - df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - res = df.select(F.to_date(df.t).alias('date')).collect() + df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + res = df.select(F.to_date(df.t).alias("date")).collect() assert res == [Row(date=date(1997, 2, 28))] def test_to_timestamp(self, spark): - df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - res = df.select(F.to_timestamp(df.t).alias('dt')).collect() + df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + res = df.select(F.to_timestamp(df.t).alias("dt")).collect() assert res == [Row(dt=datetime(1997, 2, 28, 10, 30))] def test_to_timestamp_ltz(self, spark): df = spark.createDataFrame([("2016-12-31",)], ["e"]) - res = df.select(F.to_timestamp_ltz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ltz(df.e).alias("r")).collect() assert res == [Row(r=datetime(2016, 12, 31, 0, 0))] @@ -194,15 +194,15 @@ def test_to_timestamp_ntz(self, spark): if USE_ACTUAL_SPARK: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - res = df.select(F.to_timestamp_ntz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ntz(df.e).alias("r")).collect() else: - res = df.select(F.to_timestamp_ntz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ntz(df.e).alias("r")).collect() assert res == [Row(r=datetime(2016, 4, 8, 0, 0))] def test_last_day(self, spark): - df = spark.createDataFrame([('1997-02-10',)], ['d']) + df = spark.createDataFrame([("1997-02-10",)], ["d"]) - res = df.select(F.last_day(df.d.cast("date")).alias('date')).collect() + res = df.select(F.last_day(df.d.cast("date")).alias("date")).collect() assert res == [Row(date=date(1997, 2, 28))] def test_add_months(self, spark): @@ -219,12 +219,12 @@ def test_add_months(self, spark): assert result[0].with_col == date(2024, 7, 12) def test_date_diff(self, spark): - df = spark.createDataFrame([('2015-04-08', '2015-05-10')], ["d1", "d2"]) + df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) - result_data = df.select(F.date_diff(col("d2").cast('DATE'), col("d1").cast('DATE')).alias("diff")).collect() + result_data = df.select(F.date_diff(col("d2").cast("DATE"), col("d1").cast("DATE")).alias("diff")).collect() assert result_data[0]["diff"] == -32 - result_data = df.select(F.date_diff(col("d1").cast('DATE'), col("d2").cast('DATE')).alias("diff")).collect() + result_data = df.select(F.date_diff(col("d1").cast("DATE"), col("d2").cast("DATE")).alias("diff")).collect() assert result_data[0]["diff"] == 32 def test_try_to_timestamp(self, spark): @@ -239,4 +239,4 @@ def test_try_to_timestamp_with_format(self, spark): res = df.select(F.try_to_timestamp(df.t, format=F.lit("%Y-%m-%d %H:%M:%S")).alias("dt")).collect() assert res[0].dt == datetime(1997, 2, 28, 10, 30) assert res[1].dt is None - assert res[2].dt is None \ No newline at end of file + assert res[2].dt is None diff --git a/tests/fast/spark/test_spark_functions_hex.py b/tests/fast/spark/test_spark_functions_hex.py index e5cbf12f..7d5f3c6a 100644 --- a/tests/fast/spark/test_spark_functions_hex.py +++ b/tests/fast/spark/test_spark_functions_hex.py @@ -20,7 +20,7 @@ def test_hex_string_col(self, spark): def test_hex_binary_col(self, spark): data = [ - (b'quack',), + (b"quack",), ] res = ( spark.createDataFrame(data, ["firstColumn"]) @@ -65,4 +65,4 @@ def test_unhex(self, spark): 
.select("unhex_value") .collect() ) - assert res[0].unhex_value == b'quack' + assert res[0].unhex_value == b"quack" diff --git a/tests/fast/spark/test_spark_functions_miscellaneous.py b/tests/fast/spark/test_spark_functions_miscellaneous.py index 87b6b776..f6af47fe 100644 --- a/tests/fast/spark/test_spark_functions_miscellaneous.py +++ b/tests/fast/spark/test_spark_functions_miscellaneous.py @@ -30,38 +30,38 @@ def test_call_function(self, spark): ] def test_octet_length(self, spark): - df = spark.createDataFrame([('cat',)], ['c1']) - res = df.select(F.octet_length('c1').alias("o")).collect() + df = spark.createDataFrame([("cat",)], ["c1"]) + res = df.select(F.octet_length("c1").alias("o")).collect() assert res == [Row(o=3)] def test_positive(self, spark): - df = spark.createDataFrame([(-1,), (0,), (1,)], ['v']) + df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) res = df.select(F.positive("v").alias("p")).collect() assert res == [Row(p=-1), Row(p=0), Row(p=1)] def test_sequence(self, spark): - df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) - res = df1.select(F.sequence('C1', 'C2').alias('r')).collect() + df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) + res = df1.select(F.sequence("C1", "C2").alias("r")).collect() assert res == [Row(r=[-2, -1, 0, 1, 2])] - df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) - res = df2.select(F.sequence('C1', 'C2', 'C3').alias('r')).collect() + df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) + res = df2.select(F.sequence("C1", "C2", "C3").alias("r")).collect() assert res == [Row(r=[4, 2, 0, -2, -4])] def test_like(self, spark): - df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - res = df.select(F.like(df.a, df.b).alias('r')).collect() + df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + res = df.select(F.like(df.a, df.b).alias("r")).collect() assert res == [Row(r=True)] - df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ['a', 'b']) - res = df.select(F.like(df.a, df.b, F.lit('/')).alias('r')).collect() + df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"]) + res = df.select(F.like(df.a, df.b, F.lit("/")).alias("r")).collect() assert res == [Row(r=True)] def test_ilike(self, spark): - df = spark.createDataFrame([("Spark", "spark")], ['a', 'b']) - res = df.select(F.ilike(df.a, df.b).alias('r')).collect() + df = spark.createDataFrame([("Spark", "spark")], ["a", "b"]) + res = df.select(F.ilike(df.a, df.b).alias("r")).collect() assert res == [Row(r=True)] - df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ['a', 'b']) - res = df.select(F.ilike(df.a, df.b, F.lit('/')).alias('r')).collect() + df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"]) + res = df.select(F.ilike(df.a, df.b, F.lit("/")).alias("r")).collect() assert res == [Row(r=True)] diff --git a/tests/fast/spark/test_spark_functions_null.py b/tests/fast/spark/test_spark_functions_null.py index 3f5ee31b..230634dc 100644 --- a/tests/fast/spark/test_spark_functions_null.py +++ b/tests/fast/spark/test_spark_functions_null.py @@ -62,7 +62,7 @@ def test_nvl2(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.nvl2(df.a, df.b, df.c).alias('r')).collect() + res = df.select(F.nvl2(df.a, df.b, df.c).alias("r")).collect() assert res == [Row(r=6), Row(r=9)] def test_ifnull(self, spark): @@ -92,7 +92,7 @@ def test_nullif(self, spark): ], ["a", "b"], ) - res = df.select(F.nullif(df.a, 
df.b).alias('r')).collect() + res = df.select(F.nullif(df.a, df.b).alias("r")).collect() assert res == [Row(r=None), Row(r=1)] def test_isnull(self, spark): @@ -116,4 +116,4 @@ def test_isnotnull(self, spark): def test_equal_null(self, spark): df = spark.createDataFrame([(1, 1), (None, 2), (None, None)], ("a", "b")) res = df.select(F.equal_null("a", F.col("b")).alias("r")).collect() - assert res == [Row(r=True), Row(r=False), Row(r=True)] \ No newline at end of file + assert res == [Row(r=True), Row(r=False), Row(r=True)] diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 9c4bafb9..3548d439 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -301,7 +301,7 @@ def test_corr(self, spark): # Have to use a groupby to test this as agg is not yet implemented without df = spark.createDataFrame(zip(a, b, ["group1"] * N), ["a", "b", "g"]) - res = df.groupBy("g").agg(sf.corr("a", "b").alias('c')).collect() + res = df.groupBy("g").agg(sf.corr("a", "b").alias("c")).collect() assert pytest.approx(res[0].c) == 1 def test_cot(self, spark): @@ -330,7 +330,7 @@ def test_pow(self, spark): def test_random(self, spark): df = spark.range(0, 2, 1) - res = df.withColumn('rand', sf.rand()).collect() + res = df.withColumn("rand", sf.rand()).collect() assert isinstance(res[0].rand, float) assert res[0].rand >= 0 and res[0].rand < 1 @@ -355,4 +355,4 @@ def test_negative(self, spark): res = df.collect() assert res[0].value == 0 assert res[1].value == -2 - assert res[2].value == -3 \ No newline at end of file + assert res[2].value == -3 diff --git a/tests/fast/spark/test_spark_functions_string.py b/tests/fast/spark/test_spark_functions_string.py index e90cca11..b8d7f483 100644 --- a/tests/fast/spark/test_spark_functions_string.py +++ b/tests/fast/spark/test_spark_functions_string.py @@ -152,47 +152,47 @@ def test_btrim(self, spark): "SL", ) ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.btrim(df.a, df.b).alias('r')).collect() - assert res == [Row(r='parkSQ')] + res = df.select(F.btrim(df.a, df.b).alias("r")).collect() + assert res == [Row(r="parkSQ")] - df = spark.createDataFrame([(" SparkSQL ",)], ['a']) - res = df.select(F.btrim(df.a).alias('r')).collect() - assert res == [Row(r='SparkSQL')] + df = spark.createDataFrame([(" SparkSQL ",)], ["a"]) + res = df.select(F.btrim(df.a).alias("r")).collect() + assert res == [Row(r="SparkSQL")] def test_char(self, spark): df = spark.createDataFrame( [(65,), (65 + 256,), (66 + 256,)], [ - 'a', + "a", ], ) - res = df.select(F.char(df.a).alias('ch')).collect() - assert res == [Row(ch='A'), Row(ch='A'), Row(ch='B')] + res = df.select(F.char(df.a).alias("ch")).collect() + assert res == [Row(ch="A"), Row(ch="A"), Row(ch="B")] def test_encode(self, spark): - df = spark.createDataFrame([('abcd',)], ['c']) + df = spark.createDataFrame([("abcd",)], ["c"]) res = df.select(F.encode("c", "UTF-8").alias("encoded")).collect() # FIXME: Should return the same type if USE_ACTUAL_SPARK: - assert res == [Row(encoded=bytearray(b'abcd'))] + assert res == [Row(encoded=bytearray(b"abcd"))] else: - assert res == [Row(encoded=b'abcd')] + assert res == [Row(encoded=b"abcd")] def test_split(self, spark): df = spark.createDataFrame( - [('oneAtwoBthreeC',)], + [("oneAtwoBthreeC",)], [ - 's', + "s", ], ) - res = df.select(F.split(df.s, '[ABC]').alias('s')).collect() - assert res == [Row(s=['one', 'two', 'three', ''])] + res = df.select(F.split(df.s, 
"[ABC]").alias("s")).collect() + assert res == [Row(s=["one", "two", "three", ""])] def test_split_part(self, spark): df = spark.createDataFrame( @@ -206,8 +206,8 @@ def test_split_part(self, spark): ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='13')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="13")] # If any input is null, should return null df = spark.createDataFrame( @@ -225,8 +225,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r=None), Row(r='11')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r=None), Row(r="11")] # If partNum is out of range, should return an empty string df = spark.createDataFrame( @@ -239,8 +239,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="")] # If partNum is negative, parts are counted backwards df = spark.createDataFrame( @@ -253,8 +253,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='13')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="13")] # If the delimiter is an empty string, the return should be empty df = spark.createDataFrame( @@ -267,8 +267,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="")] def test_substr(self, spark): df = spark.createDataFrame( @@ -282,7 +282,7 @@ def test_substr(self, spark): ["a", "b", "c"], ) res = df.select(F.substr("a", "b", "c").alias("s")).collect() - assert res == [Row(s='k')] + assert res == [Row(s="k")] df = spark.createDataFrame( [ @@ -295,21 +295,21 @@ def test_substr(self, spark): ["a", "b", "c"], ) res = df.select(F.substr("a", "b").alias("s")).collect() - assert res == [Row(s='k SQL')] + assert res == [Row(s="k SQL")] def test_find_in_set(self, spark): string_array = "abc,b,ab,c,def" - df = spark.createDataFrame([("ab", string_array), ("b,c", string_array), ("z", string_array)], ['a', 'b']) + df = spark.createDataFrame([("ab", string_array), ("b,c", string_array), ("z", string_array)], ["a", "b"]) - res = df.select(F.find_in_set(df.a, df.b).alias('r')).collect() + res = df.select(F.find_in_set(df.a, df.b).alias("r")).collect() assert res == [Row(r=3), Row(r=0), Row(r=0)] def test_initcap(self, spark): - df = spark.createDataFrame([('ab cd',)], ['a']) + df = spark.createDataFrame([("ab cd",)], ["a"]) - res = df.select(F.initcap("a").alias('v')).collect() - assert res == [Row(v='Ab Cd')] + res = df.select(F.initcap("a").alias("v")).collect() + assert res == [Row(v="Ab Cd")] def test_left(self, spark): df = spark.createDataFrame( @@ -327,11 +327,11 @@ def test_left(self, spark): -3, ), ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.left(df.a, df.b).alias('r')).collect() - assert res == [Row(r='Spa'), Row(r=''), Row(r='')] + res = df.select(F.left(df.a, df.b).alias("r")).collect() + assert res == [Row(r="Spa"), Row(r=""), Row(r="")] def test_right(self, spark): df = spark.createDataFrame( 
@@ -349,39 +349,39 @@ def test_right(self, spark): -3, ), ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.right(df.a, df.b).alias('r')).collect() - assert res == [Row(r='SQL'), Row(r=''), Row(r='')] + res = df.select(F.right(df.a, df.b).alias("r")).collect() + assert res == [Row(r="SQL"), Row(r=""), Row(r="")] def test_levenshtein(self, spark): - df = spark.createDataFrame([("kitten", "sitting"), ("saturdays", "sunday")], ['a', 'b']) + df = spark.createDataFrame([("kitten", "sitting"), ("saturdays", "sunday")], ["a", "b"]) - res = df.select(F.levenshtein(df.a, df.b).alias('r'), F.levenshtein(df.a, df.b, 3).alias('r_th')).collect() + res = df.select(F.levenshtein(df.a, df.b).alias("r"), F.levenshtein(df.a, df.b, 3).alias("r_th")).collect() assert res == [Row(r=3, r_th=3), Row(r=4, r_th=-1)] def test_lpad(self, spark): df = spark.createDataFrame( - [('abcd',)], + [("abcd",)], [ - 's', + "s", ], ) - res = df.select(F.lpad(df.s, 6, '#').alias('s')).collect() - assert res == [Row(s='##abcd')] + res = df.select(F.lpad(df.s, 6, "#").alias("s")).collect() + assert res == [Row(s="##abcd")] def test_rpad(self, spark): df = spark.createDataFrame( - [('abcd',)], + [("abcd",)], [ - 's', + "s", ], ) - res = df.select(F.rpad(df.s, 6, '#').alias('s')).collect() - assert res == [Row(s='abcd##')] + res = df.select(F.rpad(df.s, 6, "#").alias("s")).collect() + assert res == [Row(s="abcd##")] def test_printf(self, spark): df = spark.createDataFrame( @@ -395,79 +395,79 @@ def test_printf(self, spark): ["a", "b", "c"], ) res = df.select(F.printf("a", "b", "c").alias("r")).collect() - assert res == [Row(r='aa123cc')] + assert res == [Row(r="aa123cc")] @pytest.mark.parametrize("regexp_func", [F.regexp, F.regexp_like]) def test_regexp_and_regexp_like(self, spark, regexp_func): df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.lit(r'(\d+)')).alias("m")).collect() + res = df.select(regexp_func("str", F.lit(r"(\d+)")).alias("m")).collect() assert res[0].m is True df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.lit(r'\d{2}b')).alias("m")).collect() + res = df.select(regexp_func("str", F.lit(r"\d{2}b")).alias("m")).collect() assert res[0].m is False df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.col("regexp")).alias("m")).collect() + res = df.select(regexp_func("str", F.col("regexp")).alias("m")).collect() assert res[0].m is True def test_regexp_count(self, spark): df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - res = df.select(F.regexp_count('str', F.lit(r'\d+')).alias('d')).collect() + res = df.select(F.regexp_count("str", F.lit(r"\d+")).alias("d")).collect() assert res == [Row(d=3)] - res = df.select(F.regexp_count('str', F.lit(r'mmm')).alias('d')).collect() + res = df.select(F.regexp_count("str", F.lit(r"mmm")).alias("d")).collect() assert res == [Row(d=0)] - res = df.select(F.regexp_count("str", F.col("regexp")).alias('d')).collect() + res = df.select(F.regexp_count("str", F.col("regexp")).alias("d")).collect() assert res == [Row(d=3)] def test_regexp_extract(self, spark): - df = spark.createDataFrame([('100-200',)], ['str']) - res = df.select(F.regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() - assert res == [Row(d='100')] + df = spark.createDataFrame([("100-200",)], ["str"]) + res = df.select(F.regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() + assert res == 
[Row(d="100")] - df = spark.createDataFrame([('foo',)], ['str']) - res = df.select(F.regexp_extract('str', r'(\d+)', 1).alias('d')).collect() - assert res == [Row(d='')] + df = spark.createDataFrame([("foo",)], ["str"]) + res = df.select(F.regexp_extract("str", r"(\d+)", 1).alias("d")).collect() + assert res == [Row(d="")] - df = spark.createDataFrame([('aaaac',)], ['str']) - res = df.select(F.regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() - assert res == [Row(d='')] + df = spark.createDataFrame([("aaaac",)], ["str"]) + res = df.select(F.regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() + assert res == [Row(d="")] def test_regexp_extract_all(self, spark): df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)')).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)")).alias("d")).collect() + assert res == [Row(d=["100", "300"])] - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)'), 1).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)"), 1).alias("d")).collect() + assert res == [Row(d=["100", "300"])] - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)'), 2).alias('d')).collect() - assert res == [Row(d=['200', '400'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)"), 2).alias("d")).collect() + assert res == [Row(d=["200", "400"])] - res = df.select(F.regexp_extract_all('str', F.col("regexp")).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.col("regexp")).alias("d")).collect() + assert res == [Row(d=["100", "300"])] def test_regexp_substr(self, spark): df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - res = df.select(F.regexp_substr('str', F.lit(r'\d+')).alias('d')).collect() - assert res == [Row(d='1')] + res = df.select(F.regexp_substr("str", F.lit(r"\d+")).alias("d")).collect() + assert res == [Row(d="1")] - res = df.select(F.regexp_substr('str', F.lit(r'mmm')).alias('d')).collect() + res = df.select(F.regexp_substr("str", F.lit(r"mmm")).alias("d")).collect() assert res == [Row(d=None)] - res = df.select(F.regexp_substr("str", F.col("regexp")).alias('d')).collect() - assert res == [Row(d='1')] + res = df.select(F.regexp_substr("str", F.col("regexp")).alias("d")).collect() + assert res == [Row(d="1")] def test_repeat(self, spark): df = spark.createDataFrame( - [('ab',)], + [("ab",)], [ - 's', + "s", ], ) - res = df.select(F.repeat(df.s, 3).alias('s')).collect() - assert res == [Row(s='ababab')] + res = df.select(F.repeat(df.s, 3).alias("s")).collect() + assert res == [Row(s="ababab")] def test_reverse(self, spark): data = [ diff --git a/tests/fast/spark/test_spark_group_by.py b/tests/fast/spark/test_spark_group_by.py index 8b66901f..9e8a8ea0 100644 --- a/tests/fast/spark/test_spark_group_by.py +++ b/tests/fast/spark/test_spark_group_by.py @@ -175,7 +175,7 @@ def test_group_by_empty(self, spark): ) res = df.groupBy("name").count().columns - assert res == ['name', 'count'] + assert res == ["name", "count"] def test_group_by_first_and_last(self, spark): df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) @@ -188,7 +188,7 @@ def test_group_by_first_and_last(self, spark): .collect() ) - assert res == [Row(name='Alice', first_age=None, last_age=2), 
Row(name='Bob', first_age=5, last_age=5)] + assert res == [Row(name="Alice", first_age=None, last_age=2), Row(name="Bob", first_age=5, last_age=5)] def test_standard_deviations(self, spark): df = spark.createDataFrame( @@ -265,7 +265,7 @@ def test_group_by_mean(self, spark): res = df.groupBy("course").agg(median("earnings").alias("m")).collect() - assert sorted(res, key=lambda x: x.course) == [Row(course='Java', m=22000), Row(course='dotNET', m=10000)] + assert sorted(res, key=lambda x: x.course) == [Row(course="Java", m=22000), Row(course="dotNET", m=10000)] def test_group_by_mode(self, spark): df = spark.createDataFrame( @@ -282,11 +282,11 @@ def test_group_by_mode(self, spark): res = df.groupby("course").agg(mode("year").alias("mode")).collect() - assert sorted(res, key=lambda x: x.course) == [Row(course='Java', mode=2012), Row(course='dotNET', mode=2012)] + assert sorted(res, key=lambda x: x.course) == [Row(course="Java", mode=2012), Row(course="dotNET", mode=2012)] def test_group_by_product(self, spark): - df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3) - res = df.groupBy('mod3').agg(product('x').alias('product')).orderBy("mod3").collect() + df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) + res = df.groupBy("mod3").agg(product("x").alias("product")).orderBy("mod3").collect() assert res == [Row(mod3=0, product=162), Row(mod3=1, product=28), Row(mod3=2, product=80)] def test_group_by_skewness(self, spark): diff --git a/tests/fast/spark/test_spark_intersect.py b/tests/fast/spark/test_spark_intersect.py index 7fd97d40..ba0afbdd 100644 --- a/tests/fast/spark/test_spark_intersect.py +++ b/tests/fast/spark/test_spark_intersect.py @@ -19,7 +19,6 @@ def df2(spark): class TestDataFrameIntersect: def test_intersect(self, spark, df, df2): - df3 = df.intersect(df2).sort(df.C1) res = df3.collect() @@ -29,7 +28,6 @@ def test_intersect(self, spark, df, df2): ] def test_intersect_all(self, spark, df, df2): - df3 = df.intersectAll(df2).sort(df.C1) res = df3.collect() diff --git a/tests/fast/spark/test_spark_join.py b/tests/fast/spark/test_spark_join.py index c7ef9878..f67c54cb 100644 --- a/tests/fast/spark/test_spark_join.py +++ b/tests/fast/spark/test_spark_join.py @@ -49,63 +49,63 @@ def test_inner_join(self, dataframe_a, dataframe_b): expected = [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + 
dept_name="IT", dept_id=40, ), ] assert sorted(res) == sorted(expected) - @pytest.mark.parametrize('how', ['outer', 'fullouter', 'full', 'full_outer']) + @pytest.mark.parametrize("how", ["outer", "fullouter", "full", "full_outer"]) def test_outer_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -114,66 +114,66 @@ def test_outer_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), Row( emp_id=6, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='50', - gender='', + year_joined="2010", + emp_dept_id="50", + gender="", salary=-1, dept_name=None, dept_id=None, @@ -186,14 +186,14 @@ def test_outer_join(self, dataframe_a, dataframe_b, how): emp_dept_id=None, gender=None, salary=None, - dept_name='Sales', + dept_name="Sales", dept_id=30, ), ], key=lambda x: x.emp_id or 0, ) - @pytest.mark.parametrize('how', ['right', 'rightouter', 'right_outer']) + @pytest.mark.parametrize("how", ["right", "rightouter", "right_outer"]) def test_right_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -202,57 +202,57 @@ def test_right_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', 
+ dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), Row( @@ -263,14 +263,14 @@ def test_right_join(self, dataframe_a, dataframe_b, how): emp_dept_id=None, gender=None, salary=None, - dept_name='Sales', + dept_name="Sales", dept_id=30, ), ], key=lambda x: x.emp_id or 0, ) - @pytest.mark.parametrize('how', ['semi', 'leftsemi', 'left_semi']) + @pytest.mark.parametrize("how", ["semi", "leftsemi", "left_semi"]) def test_semi_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -279,59 +279,59 @@ def test_semi_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, ), ] ) - @pytest.mark.parametrize('how', ['anti', 'leftanti', 'left_anti']) + @pytest.mark.parametrize("how", ["anti", "leftanti", "left_anti"]) def test_anti_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) res = df.collect() assert res == [ - Row(emp_id=6, name='Brown', superior_emp_id=2, year_joined='2010', emp_dept_id='50', gender='', salary=-1) + Row(emp_id=6, name="Brown", superior_emp_id=2, year_joined="2010", emp_dept_id="50", gender="", salary=-1) ] def test_self_join(self, dataframe_a): @@ -351,11 +351,11 @@ def test_self_join(self, dataframe_a): res = df.collect() assert sorted(res, key=lambda x: x.emp_id) == sorted( [ - Row(emp_id=2, name='Rose', superior_emp_id=1, superior_emp_name='Smith'), - Row(emp_id=3, name='Williams', superior_emp_id=1, superior_emp_name='Smith'), - Row(emp_id=4, name='Jones', superior_emp_id=2, superior_emp_name='Rose'), - Row(emp_id=5, name='Brown', superior_emp_id=2, superior_emp_name='Rose'), - Row(emp_id=6, name='Brown', superior_emp_id=2, superior_emp_name='Rose'), + Row(emp_id=2, name="Rose", superior_emp_id=1, superior_emp_name="Smith"), + Row(emp_id=3, name="Williams", superior_emp_id=1, superior_emp_name="Smith"), + Row(emp_id=4, name="Jones", superior_emp_id=2, superior_emp_name="Rose"), + Row(emp_id=5, name="Brown", superior_emp_id=2, superior_emp_name="Rose"), + Row(emp_id=6, name="Brown", superior_emp_id=2, superior_emp_name="Rose"), ], key=lambda x: x.emp_id, ) @@ -382,29 +382,29 @@ def test_cross_join(self, spark): ) def test_join_with_using_clause(self, spark, dataframe_a): - dataframe_a = 
dataframe_a.select('name', 'year_joined') + dataframe_a = dataframe_a.select("name", "year_joined") - df = dataframe_a.alias('df1') - df2 = dataframe_a.alias('df2') - res = df.join(df2, ['name', 'year_joined']).sort('name', 'year_joined') + df = dataframe_a.alias("df1") + df2 = dataframe_a.alias("df2") + res = df.join(df2, ["name", "year_joined"]).sort("name", "year_joined") res = res.collect() assert res == [ - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Jones', year_joined='2005'), - Row(name='Rose', year_joined='2010'), - Row(name='Smith', year_joined='2018'), - Row(name='Williams', year_joined='2010'), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Jones", year_joined="2005"), + Row(name="Rose", year_joined="2010"), + Row(name="Smith", year_joined="2018"), + Row(name="Williams", year_joined="2010"), ] def test_join_with_common_column(self, spark, dataframe_a): - dataframe_a = dataframe_a.select('name', 'year_joined') + dataframe_a = dataframe_a.select("name", "year_joined") - df = dataframe_a.alias('df1') - df2 = dataframe_a.alias('df2') - res = df.join(df2, df.name == df2.name).sort('df1.name') + df = dataframe_a.alias("df1") + df2 = dataframe_a.alias("df2") + res = df.join(df2, df.name == df2.name).sort("df1.name") res = res.collect() assert ( str(res) diff --git a/tests/fast/spark/test_spark_order_by.py b/tests/fast/spark/test_spark_order_by.py index 92aa4d3a..cc08dd7c 100644 --- a/tests/fast/spark/test_spark_order_by.py +++ b/tests/fast/spark/test_spark_order_by.py @@ -38,15 +38,15 @@ def test_order_by(self, spark): df2 = df.sort("department", "state") res1 = df2.collect() assert res1 == [ - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, 
bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), ] df2 = df.sort(col("department"), col("state")) @@ -60,15 +60,15 @@ def test_order_by(self, spark): df2 = df.sort(df.department.asc(), df.state.desc()) res1 = df2.collect() assert res1 == [ - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] df2 = df.sort(col("department").asc(), col("state").desc()) @@ -94,15 +94,15 @@ def test_order_by(self, spark): ) res = df2.collect() assert res == [ - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", 
department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] def test_null_ordering(self, spark): @@ -130,56 +130,56 @@ def test_null_ordering(self, spark): res = df.orderBy("value1", "value2").collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy("value1", "value2", ascending=True).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy("value1", "value2", ascending=False).collect() assert res == [ - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), Row(value1=3, value2=None), - Row(value1=2, value2='A'), - Row(value1=None, value2='A'), + Row(value1=2, value2="A"), + Row(value1=None, value2="A"), ] res = df.orderBy(df.value1, df.value2).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy(df.value1.asc(), df.value2.asc()).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy(df.value1.desc(), df.value2.desc()).collect() assert res == [ - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), Row(value1=3, value2=None), - Row(value1=2, value2='A'), - Row(value1=None, value2='A'), + Row(value1=2, value2="A"), + Row(value1=None, value2="A"), ] res = df.orderBy(df.value1, df.value2, ascending=[True, False]).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2="A"), Row(value1=3, value2=None), ] diff --git a/tests/fast/spark/test_spark_pandas_dataframe.py b/tests/fast/spark/test_spark_pandas_dataframe.py index dcec77a8..6491b7a6 100644 --- a/tests/fast/spark/test_spark_pandas_dataframe.py +++ b/tests/fast/spark/test_spark_pandas_dataframe.py @@ -23,9 +23,9 @@ @pytest.fixture def pandasDF(spark): - data = [['Scott', 50], ['Jeff', 45], ['Thomas', 54], ['Ann', 34]] + data = [["Scott", 50], ["Jeff", 45], ["Thomas", 54], ["Ann", 34]] # Create the pandas DataFrame - df = pd.DataFrame(data, columns=['Name', 'Age']) + df = pd.DataFrame(data, columns=["Name", "Age"]) yield df @@ -35,10 +35,10 @@ def test_pd_conversion_basic(self, spark, pandasDF): res = sparkDF.collect() sparkDF.show() expected = [ - Row(Name='Scott', Age=50), - Row(Name='Jeff', Age=45), - Row(Name='Thomas', Age=54), - Row(Name='Ann', Age=34), + Row(Name="Scott", Age=50), + Row(Name="Jeff", Age=45), + Row(Name="Thomas", Age=54), + Row(Name="Ann", Age=34), ] assert res == expected diff --git a/tests/fast/spark/test_spark_readcsv.py 
b/tests/fast/spark/test_spark_readcsv.py index 8e6c0515..5ba3d199 100644 --- a/tests/fast/spark/test_spark_readcsv.py +++ b/tests/fast/spark/test_spark_readcsv.py @@ -9,8 +9,8 @@ class TestSparkReadCSV(object): def test_read_csv(self, spark, tmp_path): - file_path = tmp_path / 'basic.csv' - with open(file_path, 'w+') as f: + file_path = tmp_path / "basic.csv" + with open(file_path, "w+") as f: f.write( textwrap.dedent( """ diff --git a/tests/fast/spark/test_spark_readjson.py b/tests/fast/spark/test_spark_readjson.py index a6ad05f0..638bee2d 100644 --- a/tests/fast/spark/test_spark_readjson.py +++ b/tests/fast/spark/test_spark_readjson.py @@ -9,9 +9,9 @@ class TestSparkReadJson(object): def test_read_json(self, duckdb_cursor, spark, tmp_path): - file_path = tmp_path / 'basic.parquet' + file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() duckdb_cursor.execute(f"COPY (select 42 a, true b, 'this is a long string' c) to '{file_path}' (FORMAT JSON)") df = spark.read.json(file_path) res = df.collect() - assert res == [Row(a=42, b=True, c='this is a long string')] + assert res == [Row(a=42, b=True, c="this is a long string")] diff --git a/tests/fast/spark/test_spark_readparquet.py b/tests/fast/spark/test_spark_readparquet.py index a08ab16d..1b3ddd74 100644 --- a/tests/fast/spark/test_spark_readparquet.py +++ b/tests/fast/spark/test_spark_readparquet.py @@ -9,11 +9,11 @@ class TestSparkReadParquet(object): def test_read_parquet(self, duckdb_cursor, spark, tmp_path): - file_path = tmp_path / 'basic.parquet' + file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() duckdb_cursor.execute( f"COPY (select 42 a, true b, 'this is a long string' c) to '{file_path}' (FORMAT PARQUET)" ) df = spark.read.parquet(file_path) res = df.collect() - assert res == [Row(a=42, b=True, c='this is a long string')] + assert res == [Row(a=42, b=True, c="this is a long string")] diff --git a/tests/fast/spark/test_spark_session.py b/tests/fast/spark/test_spark_session.py index 7c338898..06c9dbcb 100644 --- a/tests/fast/spark/test_spark_session.py +++ b/tests/fast/spark/test_spark_session.py @@ -14,14 +14,14 @@ def test_spark_session_default(self): session = SparkSession.builder.getOrCreate() def test_spark_session(self): - session = SparkSession.builder.master("local[1]").appName('SparkByExamples.com').getOrCreate() + session = SparkSession.builder.master("local[1]").appName("SparkByExamples.com").getOrCreate() def test_new_session(self, spark: SparkSession): session = spark.newSession() - @pytest.mark.skip(reason='not tested yet') + @pytest.mark.skip(reason="not tested yet") def test_retrieve_same_session(self): - spark = SparkSession.builder.master('test').appName('test2').getOrCreate() + spark = SparkSession.builder.master("test").appName("test2").getOrCreate() spark2 = SparkSession.builder.getOrCreate() # Same connection should be returned assert spark == spark2 @@ -49,7 +49,7 @@ def test_hive_support(self): @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Different version numbers") def test_version(self, spark): version = spark.version - assert version == '1.0.0' + assert version == "1.0.0" def test_get_active_session(self, spark): active_session = spark.getActiveSession() @@ -58,7 +58,7 @@ def test_read(self, spark): reader = spark.read def test_write(self, spark): - df = spark.sql('select 42') + df = spark.sql("select 42") writer = df.write def test_read_stream(self, spark): @@ -68,7 +68,7 @@ def test_spark_context(self, spark): context = spark.sparkContext def test_sql(self, spark): 
- df = spark.sql('select 42') + df = spark.sql("select 42") def test_stop_context(self, spark): context = spark.sparkContext @@ -78,8 +78,8 @@ def test_stop_context(self, spark): USE_ACTUAL_SPARK, reason="Can't create table with the local PySpark setup in the CI/CD pipeline" ) def test_table(self, spark): - spark.sql('create table tbl(a varchar(10))') - df = spark.table('tbl') + spark.sql("create table tbl(a varchar(10))") + df = spark.table("tbl") def test_range(self, spark): res_1 = spark.range(3).collect() diff --git a/tests/fast/spark/test_spark_to_csv.py b/tests/fast/spark/test_spark_to_csv.py index 5048e579..e5387a6c 100644 --- a/tests/fast/spark/test_spark_to_csv.py +++ b/tests/fast/spark/test_spark_to_csv.py @@ -40,14 +40,14 @@ def df(spark): @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_ints(request, spark): pandas = request.param - dataframe = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + dataframe = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) yield dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_strings(request, spark): pandas = request.param - dataframe = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + dataframe = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) yield dataframe @@ -68,15 +68,15 @@ def test_to_csv_sep(self, pandas_df_ints, spark, tmp_path): df = spark.createDataFrame(pandas_df_ints) - df.write.csv(temp_file_name, sep=',') + df.write.csv(temp_file_name, sep=",") - csv_rel = spark.read.csv(temp_file_name, sep=',') + csv_rel = spark.read.csv(temp_file_name, sep=",") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_na_rep(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -85,10 +85,10 @@ def test_to_csv_na_rep(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name, nullValue="test") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_header(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -97,20 +97,20 @@ def test_to_csv_header(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name) assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quotechar(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]}) + pandas_df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, quote='\'', sep=',') + df.write.csv(temp_file_name, quote="'", sep=",") - csv_rel = spark.read.csv(temp_file_name, sep=',', quote='\'') 
+ csv_rel = spark.read.csv(temp_file_name, sep=",", quote="'") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_escapechar(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") pandas_df = pandas.DataFrame( @@ -124,11 +124,11 @@ def test_to_csv_escapechar(self, pandas, spark, tmp_path): df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, quote='"', escape='!') - csv_rel = spark.read.csv(temp_file_name, quote='"', escape='!') + df.write.csv(temp_file_name, quote='"', escape="!") + csv_rel = spark.read.csv(temp_file_name, quote='"', escape="!") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_date_format(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") pandas_df = pandas.DataFrame(getTimeSeriesData()) @@ -143,17 +143,17 @@ def test_to_csv_date_format(self, pandas, spark, tmp_path): assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_timestamp_format(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - pandas_df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + pandas_df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, timestampFormat='%m/%d/%Y') + df.write.csv(temp_file_name, timestampFormat="%m/%d/%Y") - csv_rel = spark.read.csv(temp_file_name, timestampFormat='%m/%d/%Y') + csv_rel = spark.read.csv(temp_file_name, timestampFormat="%m/%d/%Y") assert df.collect() == csv_rel.collect() diff --git a/tests/fast/spark/test_spark_transform.py b/tests/fast/spark/test_spark_transform.py index 83e219a5..1f1186c5 100644 --- a/tests/fast/spark/test_spark_transform.py +++ b/tests/fast/spark/test_spark_transform.py @@ -62,15 +62,15 @@ def apply_discount(df): df2 = df.transform(to_upper_str_columns).transform(reduce_price, 1000).transform(apply_discount) res = df2.collect() assert res == [ - Row(CourseName='JAVA', fee=4000, discount=5, new_fee=3000, discounted_fee=2850.0), - Row(CourseName='PYTHON', fee=4600, discount=10, new_fee=3600, discounted_fee=3240.0), - Row(CourseName='SCALA', fee=4100, discount=15, new_fee=3100, discounted_fee=2635.0), - Row(CourseName='SCALA', fee=4500, discount=15, new_fee=3500, discounted_fee=2975.0), - Row(CourseName='PHP', fee=3000, discount=20, new_fee=2000, discounted_fee=1600.0), + Row(CourseName="JAVA", fee=4000, discount=5, new_fee=3000, discounted_fee=2850.0), + Row(CourseName="PYTHON", fee=4600, discount=10, new_fee=3600, discounted_fee=3240.0), + Row(CourseName="SCALA", fee=4100, discount=15, new_fee=3100, discounted_fee=2635.0), + Row(CourseName="SCALA", fee=4500, discount=15, new_fee=3500, discounted_fee=2975.0), + Row(CourseName="PHP", fee=3000, discount=20, new_fee=2000, discounted_fee=1600.0), ] # https://sparkbyexamples.com/pyspark/pyspark-transform-function/ - @pytest.mark.skip(reason='LambdaExpressions are currently under development, waiting til that is finished') + @pytest.mark.skip(reason="LambdaExpressions are currently 
under development, waiting til that is finished") def test_transform_function(self, spark, array_df): from spark_namespace.sql.functions import upper, transform diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index fb6e6102..6c97c2d9 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -70,65 +70,65 @@ def test_all_types_schema(self, spark): schema = df.schema assert schema == StructType( [ - StructField('bool', BooleanType(), True), - StructField('tinyint', ByteType(), True), - StructField('smallint', ShortType(), True), - StructField('int', IntegerType(), True), - StructField('bigint', LongType(), True), - StructField('hugeint', HugeIntegerType(), True), - StructField('uhugeint', UnsignedHugeIntegerType(), True), - StructField('utinyint', UnsignedByteType(), True), - StructField('usmallint', UnsignedShortType(), True), - StructField('uint', UnsignedIntegerType(), True), - StructField('ubigint', UnsignedLongType(), True), - StructField('date', DateType(), True), - StructField('time', TimeNTZType(), True), - StructField('timestamp', TimestampNTZType(), True), - StructField('timestamp_s', TimestampSecondNTZType(), True), - StructField('timestamp_ms', TimestampNanosecondNTZType(), True), - StructField('timestamp_ns', TimestampMilisecondNTZType(), True), - StructField('time_tz', TimeType(), True), - StructField('timestamp_tz', TimestampType(), True), - StructField('float', FloatType(), True), - StructField('double', DoubleType(), True), - StructField('dec_4_1', DecimalType(4, 1), True), - StructField('dec_9_4', DecimalType(9, 4), True), - StructField('dec_18_6', DecimalType(18, 6), True), - StructField('dec38_10', DecimalType(38, 10), True), - StructField('uuid', UUIDType(), True), - StructField('interval', DayTimeIntervalType(0, 3), True), - StructField('varchar', StringType(), True), - StructField('blob', BinaryType(), True), - StructField('bit', BitstringType(), True), - StructField('int_array', ArrayType(IntegerType(), True), True), - StructField('double_array', ArrayType(DoubleType(), True), True), - StructField('date_array', ArrayType(DateType(), True), True), - StructField('timestamp_array', ArrayType(TimestampNTZType(), True), True), - StructField('timestamptz_array', ArrayType(TimestampType(), True), True), - StructField('varchar_array', ArrayType(StringType(), True), True), - StructField('nested_int_array', ArrayType(ArrayType(IntegerType(), True), True), True), + StructField("bool", BooleanType(), True), + StructField("tinyint", ByteType(), True), + StructField("smallint", ShortType(), True), + StructField("int", IntegerType(), True), + StructField("bigint", LongType(), True), + StructField("hugeint", HugeIntegerType(), True), + StructField("uhugeint", UnsignedHugeIntegerType(), True), + StructField("utinyint", UnsignedByteType(), True), + StructField("usmallint", UnsignedShortType(), True), + StructField("uint", UnsignedIntegerType(), True), + StructField("ubigint", UnsignedLongType(), True), + StructField("date", DateType(), True), + StructField("time", TimeNTZType(), True), + StructField("timestamp", TimestampNTZType(), True), + StructField("timestamp_s", TimestampSecondNTZType(), True), + StructField("timestamp_ms", TimestampNanosecondNTZType(), True), + StructField("timestamp_ns", TimestampMilisecondNTZType(), True), + StructField("time_tz", TimeType(), True), + StructField("timestamp_tz", TimestampType(), True), + StructField("float", FloatType(), True), + StructField("double", DoubleType(), 
True), + StructField("dec_4_1", DecimalType(4, 1), True), + StructField("dec_9_4", DecimalType(9, 4), True), + StructField("dec_18_6", DecimalType(18, 6), True), + StructField("dec38_10", DecimalType(38, 10), True), + StructField("uuid", UUIDType(), True), + StructField("interval", DayTimeIntervalType(0, 3), True), + StructField("varchar", StringType(), True), + StructField("blob", BinaryType(), True), + StructField("bit", BitstringType(), True), + StructField("int_array", ArrayType(IntegerType(), True), True), + StructField("double_array", ArrayType(DoubleType(), True), True), + StructField("date_array", ArrayType(DateType(), True), True), + StructField("timestamp_array", ArrayType(TimestampNTZType(), True), True), + StructField("timestamptz_array", ArrayType(TimestampType(), True), True), + StructField("varchar_array", ArrayType(StringType(), True), True), + StructField("nested_int_array", ArrayType(ArrayType(IntegerType(), True), True), True), StructField( - 'struct', - StructType([StructField('a', IntegerType(), True), StructField('b', StringType(), True)]), + "struct", + StructType([StructField("a", IntegerType(), True), StructField("b", StringType(), True)]), True, ), StructField( - 'struct_of_arrays', + "struct_of_arrays", StructType( [ - StructField('a', ArrayType(IntegerType(), True), True), - StructField('b', ArrayType(StringType(), True), True), + StructField("a", ArrayType(IntegerType(), True), True), + StructField("b", ArrayType(StringType(), True), True), ] ), True, ), StructField( - 'array_of_structs', + "array_of_structs", ArrayType( - StructType([StructField('a', IntegerType(), True), StructField('b', StringType(), True)]), True + StructType([StructField("a", IntegerType(), True), StructField("b", StringType(), True)]), True ), True, ), - StructField('map', MapType(StringType(), StringType(), True), True), + StructField("map", MapType(StringType(), StringType(), True), True), ] ) diff --git a/tests/fast/spark/test_spark_udf.py b/tests/fast/spark/test_spark_udf.py index 3b5a5d36..eebabbb3 100644 --- a/tests/fast/spark/test_spark_udf.py +++ b/tests/fast/spark/test_spark_udf.py @@ -5,7 +5,6 @@ class TestSparkUDF(object): def test_udf_register(self, spark): - def to_upper_fn(s: str) -> str: return s.upper() diff --git a/tests/fast/spark/test_spark_union.py b/tests/fast/spark/test_spark_union.py index ea889e05..8a3ff9ce 100644 --- a/tests/fast/spark/test_spark_union.py +++ b/tests/fast/spark/test_spark_union.py @@ -40,15 +40,15 @@ def test_merge_with_union(self, df, df2): unionDF = df.union(df2) res = unionDF.collect() assert res == [ - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, 
age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), ] unionDF = df.unionAll(df2) res2 = unionDF.collect() @@ -60,11 +60,11 @@ def test_merge_without_duplicates(self, df, df2): disDF = df.union(df2).distinct().sort(col("employee_name")) res = disDF.collect() assert res == [ - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] diff --git a/tests/fast/spark/test_spark_union_by_name.py b/tests/fast/spark/test_spark_union_by_name.py index 08f3c62b..4739f0d8 100644 --- a/tests/fast/spark/test_spark_union_by_name.py +++ b/tests/fast/spark/test_spark_union_by_name.py @@ -38,14 +38,14 @@ def test_union_by_name(self, df1, df2): rel = df1.unionByName(df2) res = rel.collect() expected = [ - Row(name='James', id=34), - Row(name='Michael', id=56), - Row(name='Robert', id=30), - Row(name='Maria', id=24), - Row(name='James', id=34), - Row(name='Maria', id=45), - Row(name='Jen', id=45), - Row(name='Jeff', id=34), + Row(name="James", id=34), + Row(name="Michael", id=56), + Row(name="Robert", id=30), + Row(name="Maria", id=24), + Row(name="James", id=34), + Row(name="Maria", id=45), + Row(name="Jen", id=45), + Row(name="Jeff", id=34), ] assert res == expected @@ -53,13 +53,13 @@ def test_union_by_name_allow_missing_cols(self, df1, df2): rel = df1.unionByName(df2.drop("id"), allowMissingColumns=True) res = rel.collect() expected = [ - Row(name='James', id=34), - Row(name='Michael', id=56), - Row(name='Robert', id=30), - Row(name='Maria', id=24), - 
Row(name='James', id=None), - Row(name='Maria', id=None), - Row(name='Jen', id=None), - Row(name='Jeff', id=None), + Row(name="James", id=34), + Row(name="Michael", id=56), + Row(name="Robert", id=30), + Row(name="Maria", id=24), + Row(name="James", id=None), + Row(name="Maria", id=None), + Row(name="Jen", id=None), + Row(name="Jeff", id=None), ] assert res == expected diff --git a/tests/fast/spark/test_spark_with_column.py b/tests/fast/spark/test_spark_with_column.py index 80da34c3..2980e7fe 100644 --- a/tests/fast/spark/test_spark_with_column.py +++ b/tests/fast/spark/test_spark_with_column.py @@ -23,20 +23,20 @@ class TestWithColumn(object): def test_with_column(self, spark): data = [ - ('James', '', 'Smith', '1991-04-01', 'M', 3000), - ('Michael', 'Rose', '', '2000-05-19', 'M', 4000), - ('Robert', '', 'Williams', '1978-09-05', 'M', 4000), - ('Maria', 'Anne', 'Jones', '1967-12-01', 'F', 4000), - ('Jen', 'Mary', 'Brown', '1980-02-17', 'F', -1), + ("James", "", "Smith", "1991-04-01", "M", 3000), + ("Michael", "Rose", "", "2000-05-19", "M", 4000), + ("Robert", "", "Williams", "1978-09-05", "M", 4000), + ("Maria", "Anne", "Jones", "1967-12-01", "F", 4000), + ("Jen", "Mary", "Brown", "1980-02-17", "F", -1), ] columns = ["firstname", "middlename", "lastname", "dob", "gender", "salary"] df = spark.createDataFrame(data=data, schema=columns) - assert df.schema['salary'].dataType.typeName() == ('long' if USE_ACTUAL_SPARK else 'integer') + assert df.schema["salary"].dataType.typeName() == ("long" if USE_ACTUAL_SPARK else "integer") # The type of 'salary' has been cast to Bigint new_df = df.withColumn("salary", col("salary").cast("BIGINT")) - assert new_df.schema['salary'].dataType.typeName() == 'long' + assert new_df.schema["salary"].dataType.typeName() == "long" # Replace the 'salary' column with '(salary * 100)' df2 = df.withColumn("salary", col("salary") * 100) @@ -50,16 +50,16 @@ def test_with_column(self, spark): df2 = df.withColumn("Country", lit("USA")) res = df2.collect() - assert res[0].Country == 'USA' + assert res[0].Country == "USA" df2 = df.withColumn("Country", lit("USA")).withColumn("anotherColumn", lit("anotherValue")) res = df2.collect() - assert res[0].Country == 'USA' - assert res[1].anotherColumn == 'anotherValue' + assert res[0].Country == "USA" + assert res[1].anotherColumn == "anotherValue" df2 = df.withColumnRenamed("gender", "sex") - assert 'gender' not in df2.columns - assert 'sex' in df2.columns + assert "gender" not in df2.columns + assert "sex" in df2.columns df2 = df.drop("salary") - assert 'salary' not in df2.columns + assert "salary" not in df2.columns diff --git a/tests/fast/spark/test_spark_with_column_renamed.py b/tests/fast/spark/test_spark_with_column_renamed.py index 168ff23a..8534ab0b 100644 --- a/tests/fast/spark/test_spark_with_column_renamed.py +++ b/tests/fast/spark/test_spark_with_column_renamed.py @@ -22,49 +22,49 @@ class TestWithColumnRenamed(object): def test_with_column_renamed(self, spark): dataDF = [ - (('James', '', 'Smith'), '1991-04-01', 'M', 3000), - (('Michael', 'Rose', ''), '2000-05-19', 'M', 4000), - (('Robert', '', 'Williams'), '1978-09-05', 'M', 4000), - (('Maria', 'Anne', 'Jones'), '1967-12-01', 'F', 4000), - (('Jen', 'Mary', 'Brown'), '1980-02-17', 'F', -1), + (("James", "", "Smith"), "1991-04-01", "M", 3000), + (("Michael", "Rose", ""), "2000-05-19", "M", 4000), + (("Robert", "", "Williams"), "1978-09-05", "M", 4000), + (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), + (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] from 
spark_namespace.sql.types import StructType, StructField, StringType, IntegerType schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('dob', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("dob", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) df = spark.createDataFrame(data=dataDF, schema=schema) df2 = df.withColumnRenamed("dob", "DateOfBirth").withColumnRenamed("salary", "salary_amount") - assert 'dob' not in df2.columns - assert 'salary' not in df2.columns - assert 'DateOfBirth' in df2.columns - assert 'salary_amount' in df2.columns + assert "dob" not in df2.columns + assert "salary" not in df2.columns + assert "DateOfBirth" in df2.columns + assert "salary_amount" in df2.columns schema2 = StructType( [ StructField( - 'full name', + "full name", StructType( [ - StructField('fname', StringType(), True), - StructField('mname', StringType(), True), - StructField('lname', StringType(), True), + StructField("fname", StringType(), True), + StructField("mname", StringType(), True), + StructField("lname", StringType(), True), ] ), ), @@ -72,9 +72,9 @@ def test_with_column_renamed(self, spark): ) df2 = df.withColumnRenamed("name", "full name") - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() df2 = df.select( col("name").alias("full name"), @@ -82,9 +82,9 @@ def test_with_column_renamed(self, spark): col("gender"), col("salary"), ) - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() df2 = df.select( col("name.firstname").alias("fname"), @@ -94,5 +94,5 @@ def test_with_column_renamed(self, spark): col("gender"), col("salary"), ) - assert 'firstname' not in df2.columns - assert 'fname' in df2.columns + assert "firstname" not in df2.columns + assert "fname" in df2.columns diff --git a/tests/fast/spark/test_spark_with_columns.py b/tests/fast/spark/test_spark_with_columns.py index 6e1bedea..535f357d 100644 --- a/tests/fast/spark/test_spark_with_columns.py +++ b/tests/fast/spark/test_spark_with_columns.py @@ -10,20 +10,20 @@ class TestWithColumns: def test_with_columns(self, spark): data = [ - ('James', '', 'Smith', '1991-04-01', 'M', 3000), - ('Michael', 'Rose', '', '2000-05-19', 'M', 4000), - ('Robert', '', 'Williams', '1978-09-05', 'M', 4000), - ('Maria', 'Anne', 'Jones', '1967-12-01', 'F', 4000), - ('Jen', 'Mary', 'Brown', '1980-02-17', 'F', -1), + ("James", "", "Smith", "1991-04-01", "M", 3000), + ("Michael", "Rose", "", "2000-05-19", "M", 4000), + ("Robert", "", "Williams", "1978-09-05", "M", 4000), + ("Maria", "Anne", "Jones", "1967-12-01", "F", 4000), + ("Jen", "Mary", "Brown", "1980-02-17", "F", -1), ] columns = ["firstname", "middlename", "lastname", "dob", "gender", 
"salary"] df = spark.createDataFrame(data=data, schema=columns) - assert df.schema['salary'].dataType.typeName() == ('long' if USE_ACTUAL_SPARK else 'integer') + assert df.schema["salary"].dataType.typeName() == ("long" if USE_ACTUAL_SPARK else "integer") # The type of 'salary' has been cast to Bigint new_df = df.withColumns({"salary": col("salary").cast("BIGINT")}) - assert new_df.schema['salary'].dataType.typeName() == 'long' + assert new_df.schema["salary"].dataType.typeName() == "long" # Replace the 'salary' column with '(salary * 100)' and add a new column # from an existing column @@ -34,12 +34,12 @@ def test_with_columns(self, spark): df2 = df.withColumns({"Country": lit("USA")}) res = df2.collect() - assert res[0].Country == 'USA' + assert res[0].Country == "USA" df2 = df.withColumns({"Country": lit("USA")}).withColumns({"anotherColumn": lit("anotherValue")}) res = df2.collect() - assert res[0].Country == 'USA' - assert res[1].anotherColumn == 'anotherValue' + assert res[0].Country == "USA" + assert res[1].anotherColumn == "anotherValue" df2 = df.drop("salary") - assert 'salary' not in df2.columns + assert "salary" not in df2.columns diff --git a/tests/fast/spark/test_spark_with_columns_renamed.py b/tests/fast/spark/test_spark_with_columns_renamed.py index 99c4ce63..80b8b9e0 100644 --- a/tests/fast/spark/test_spark_with_columns_renamed.py +++ b/tests/fast/spark/test_spark_with_columns_renamed.py @@ -9,44 +9,44 @@ class TestWithColumnsRenamed(object): def test_with_columns_renamed(self, spark): dataDF = [ - (('James', '', 'Smith'), '1991-04-01', 'M', 3000), - (('Michael', 'Rose', ''), '2000-05-19', 'M', 4000), - (('Robert', '', 'Williams'), '1978-09-05', 'M', 4000), - (('Maria', 'Anne', 'Jones'), '1967-12-01', 'F', 4000), - (('Jen', 'Mary', 'Brown'), '1980-02-17', 'F', -1), + (("James", "", "Smith"), "1991-04-01", "M", 3000), + (("Michael", "Rose", ""), "2000-05-19", "M", 4000), + (("Robert", "", "Williams"), "1978-09-05", "M", 4000), + (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), + (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('dob', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("dob", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) df = spark.createDataFrame(data=dataDF, schema=schema) df2 = df.withColumnsRenamed({"dob": "DateOfBirth", "salary": "salary_amount"}) - assert 'dob' not in df2.columns - assert 'salary' not in df2.columns - assert 'DateOfBirth' in df2.columns - assert 'salary_amount' in df2.columns + assert "dob" not in df2.columns + assert "salary" not in df2.columns + assert "DateOfBirth" in df2.columns + assert "salary_amount" in df2.columns df2 = df.withColumnsRenamed({"name": "full name"}) - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full 
name"].dataType.fieldNames() # PySpark does not raise an error. This is a convenience we provide in DuckDB. if not USE_ACTUAL_SPARK: diff --git a/tests/fast/sqlite/test_types.py b/tests/fast/sqlite/test_types.py index d4be447a..3ffdceae 100644 --- a/tests/fast/sqlite/test_types.py +++ b/tests/fast/sqlite/test_types.py @@ -42,10 +42,10 @@ def tearDown(self): self.con.close() def test_CheckString(self): - self.cur.execute("insert into test(s) values (?)", (u"Österreich",)) + self.cur.execute("insert into test(s) values (?)", ("Österreich",)) self.cur.execute("select s from test") row = self.cur.fetchone() - self.assertEqual(row[0], u"Österreich") + self.assertEqual(row[0], "Österreich") def test_CheckSmallInt(self): self.cur.execute("insert into test(i) values (?)", (42,)) @@ -75,7 +75,7 @@ def test_CheckDecimalTooBig(self): self.assertEqual(row[0], val) def test_CheckDecimal(self): - val = '17.29' + val = "17.29" val = decimal.Decimal(val) self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") @@ -83,7 +83,7 @@ def test_CheckDecimal(self): self.assertEqual(row[0], self.cur.execute("select 17.29::DOUBLE").fetchone()[0]) def test_CheckDecimalWithExponent(self): - val = '1E5' + val = "1E5" val = decimal.Decimal(val) self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") @@ -93,14 +93,14 @@ def test_CheckDecimalWithExponent(self): def test_CheckNaN(self): import math - val = decimal.Decimal('nan') + val = decimal.Decimal("nan") self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() self.assertEqual(math.isnan(row[0]), True) def test_CheckInf(self): - val = decimal.Decimal('inf') + val = decimal.Decimal("inf") self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() @@ -122,7 +122,7 @@ def test_CheckMemoryviewBlob(self): self.assertEqual(row[0], sample) def test_CheckMemoryviewFromhexBlob(self): - sample = bytes.fromhex('00FF0F2E3D4C5B6A798800FF00') + sample = bytes.fromhex("00FF0F2E3D4C5B6A798800FF00") val = memoryview(sample) self.cur.execute("insert into test(b) values (?)", (val,)) self.cur.execute("select b from test") @@ -137,9 +137,9 @@ def test_CheckNoneBlob(self): self.assertEqual(row[0], val) def test_CheckUnicodeExecute(self): - self.cur.execute(u"select 'Österreich'") + self.cur.execute("select 'Österreich'") row = self.cur.fetchone() - self.assertEqual(row[0], u"Österreich") + self.assertEqual(row[0], "Österreich") class CommonTableExpressionTests(unittest.TestCase): @@ -206,7 +206,7 @@ def test_CheckTimestamp(self): self.assertEqual(ts, ts2) def test_CheckSqlTimestamp(self): - now = datetime.datetime.now(datetime.UTC) if hasattr(datetime, 'UTC') else datetime.datetime.utcnow() + now = datetime.datetime.now(datetime.UTC) if hasattr(datetime, "UTC") else datetime.datetime.utcnow() self.cur.execute("insert into test(ts) values (current_timestamp)") self.cur.execute("select ts from test") ts = self.cur.fetchone()[0] diff --git a/tests/fast/test_alex_multithread.py b/tests/fast/test_alex_multithread.py index 92768ec0..bcb0181b 100644 --- a/tests/fast/test_alex_multithread.py +++ b/tests/fast/test_alex_multithread.py @@ -41,7 +41,7 @@ def test_multiple_cursors(self, duckdb_cursor): # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - 
threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -50,9 +50,9 @@ def test_multiple_cursors(self, duckdb_cursor): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] def test_same_connection(self, duckdb_cursor): @@ -67,7 +67,7 @@ def test_same_connection(self, duckdb_cursor): # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): cursors.append(duckdb_con.cursor()) - threads.append(Thread(target=insert_from_same_connection, args=(cursors[i],), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_same_connection, args=(cursors[i],), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -76,9 +76,9 @@ def test_same_connection(self, duckdb_cursor): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] def test_multiple_cursors_persisted(self, tmp_database): @@ -91,7 +91,7 @@ def test_multiple_cursors_persisted(self, tmp_database): # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -99,9 +99,9 @@ def test_multiple_cursors_persisted(self, tmp_database): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] duckdb_con.close() @@ -115,7 +115,7 @@ def test_same_connection_persisted(self, tmp_database): # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_same_connection, args=(duckdb_con,), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_same_connection, args=(duckdb_con,), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -123,8 +123,8 @@ def test_same_connection_persisted(self, tmp_database): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] duckdb_con.close() diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index 2128f9f1..3e701ced 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -12,7 +12,7 @@ def replace_with_ndarray(obj): - if hasattr(obj, '__getitem__'): + if hasattr(obj, "__getitem__"): if isinstance(obj, dict): for key, value in obj.items(): obj[key] = replace_with_ndarray(value) @@ -115,69 +115,69 @@ def recursive_equality(o1, o2): class TestAllTypes(object): - @pytest.mark.parametrize('cur_type', all_types) + 
@pytest.mark.parametrize("cur_type", all_types) def test_fetchall(self, cur_type): conn = duckdb.connect() conn.execute("SET TimeZone =UTC") # We replace these values since the extreme ranges are not supported in native-python. replacement_values = { - 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", - 'timestamp_s': "'1990-01-01 00:00:00'::TIMESTAMP_S", - 'timestamp_ns': "'1990-01-01 00:00:00'::TIMESTAMP_NS", - 'timestamp_ms': "'1990-01-01 00:00:00'::TIMESTAMP_MS", - 'timestamp_tz': "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", - 'date': "'1990-01-01'::DATE", - 'date_array': "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", - 'timestamp_array': "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", - 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", + "timestamp": "'1990-01-01 00:00:00'::TIMESTAMP", + "timestamp_s": "'1990-01-01 00:00:00'::TIMESTAMP_S", + "timestamp_ns": "'1990-01-01 00:00:00'::TIMESTAMP_NS", + "timestamp_ms": "'1990-01-01 00:00:00'::TIMESTAMP_MS", + "timestamp_tz": "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", + "date": "'1990-01-01'::DATE", + "date_array": "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", + "timestamp_array": "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", + "timestamptz_array": "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", } adjusted_values = { - 'time': """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, - 'time_tz': """CASE WHEN time_tz = '24:00:00-1559'::TIMETZ THEN '23:59:59.999999-1559'::TIMETZ ELSE time_tz END AS "time_tz" """, + "time": """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, + "time_tz": """CASE WHEN time_tz = '24:00:00-1559'::TIMETZ THEN '23:59:59.999999-1559'::TIMETZ ELSE time_tz END AS "time_tz" """, } min_datetime = datetime.datetime.min min_datetime_with_utc = min_datetime.replace(tzinfo=pytz.UTC) max_datetime = datetime.datetime.max max_datetime_with_utc = max_datetime.replace(tzinfo=pytz.UTC) correct_answer_map = { - 'bool': [(False,), (True,), (None,)], - 'tinyint': [(-128,), (127,), (None,)], - 'smallint': [(-32768,), (32767,), (None,)], - 'int': [(-2147483648,), (2147483647,), (None,)], - 'bigint': [(-9223372036854775808,), (9223372036854775807,), (None,)], - 'hugeint': [ + "bool": [(False,), (True,), (None,)], + "tinyint": [(-128,), (127,), (None,)], + "smallint": [(-32768,), (32767,), (None,)], + "int": [(-2147483648,), (2147483647,), (None,)], + "bigint": [(-9223372036854775808,), (9223372036854775807,), (None,)], + "hugeint": [ (-170141183460469231731687303715884105728,), (170141183460469231731687303715884105727,), (None,), ], - 'utinyint': [(0,), (255,), (None,)], - 'usmallint': [(0,), (65535,), (None,)], - 'uint': [(0,), (4294967295,), (None,)], - 'ubigint': [(0,), (18446744073709551615,), (None,)], - 'time': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], - 'float': [(-3.4028234663852886e38,), (3.4028234663852886e38,), (None,)], - 'double': [(-1.7976931348623157e308,), (1.7976931348623157e308,), (None,)], - 'dec_4_1': [(Decimal('-999.9'),), 
(Decimal('999.9'),), (None,)], - 'dec_9_4': [(Decimal('-99999.9999'),), (Decimal('99999.9999'),), (None,)], - 'dec_18_6': [(Decimal('-999999999999.999999'),), (Decimal('999999999999.999999'),), (None,)], - 'dec38_10': [ - (Decimal('-9999999999999999999999999999.9999999999'),), - (Decimal('9999999999999999999999999999.9999999999'),), + "utinyint": [(0,), (255,), (None,)], + "usmallint": [(0,), (65535,), (None,)], + "uint": [(0,), (4294967295,), (None,)], + "ubigint": [(0,), (18446744073709551615,), (None,)], + "time": [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + "float": [(-3.4028234663852886e38,), (3.4028234663852886e38,), (None,)], + "double": [(-1.7976931348623157e308,), (1.7976931348623157e308,), (None,)], + "dec_4_1": [(Decimal("-999.9"),), (Decimal("999.9"),), (None,)], + "dec_9_4": [(Decimal("-99999.9999"),), (Decimal("99999.9999"),), (None,)], + "dec_18_6": [(Decimal("-999999999999.999999"),), (Decimal("999999999999.999999"),), (None,)], + "dec38_10": [ + (Decimal("-9999999999999999999999999999.9999999999"),), + (Decimal("9999999999999999999999999999.9999999999"),), (None,), ], - 'uuid': [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + "uuid": [ + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ], - 'varchar': [('🦆🦆🦆🦆🦆🦆',), ('goo\0se',), (None,)], - 'json': [('🦆🦆🦆🦆🦆🦆',), ('goose',), (None,)], - 'blob': [(b'thisisalongblob\x00withnullbytes',), (b'\x00\x00\x00a',), (None,)], - 'bit': [('0010001001011100010101011010111',), ('10101',), (None,)], - 'small_enum': [('DUCK_DUCK_ENUM',), ('GOOSE',), (None,)], - 'medium_enum': [('enum_0',), ('enum_299',), (None,)], - 'large_enum': [('enum_0',), ('enum_69999',), (None,)], - 'date_array': [ + "varchar": [("🦆🦆🦆🦆🦆🦆",), ("goo\0se",), (None,)], + "json": [("🦆🦆🦆🦆🦆🦆",), ("goose",), (None,)], + "blob": [(b"thisisalongblob\x00withnullbytes",), (b"\x00\x00\x00a",), (None,)], + "bit": [("0010001001011100010101011010111",), ("10101",), (None,)], + "small_enum": [("DUCK_DUCK_ENUM",), ("GOOSE",), (None,)], + "medium_enum": [("enum_0",), ("enum_299",), (None,)], + "large_enum": [("enum_0",), ("enum_69999",), (None,)], + "date_array": [ ( [], [datetime.date(1970, 1, 1), None, datetime.date.min, datetime.date.max], @@ -186,7 +186,7 @@ def test_fetchall(self, cur_type): ], ) ], - 'timestamp_array': [ + "timestamp_array": [ ( [], [datetime.datetime(1970, 1, 1), None, datetime.datetime.min, datetime.datetime.max], @@ -195,7 +195,7 @@ def test_fetchall(self, cur_type): ], ), ], - 'timestamptz_array': [ + "timestamptz_array": [ ( [], [ @@ -209,67 +209,67 @@ def test_fetchall(self, cur_type): ], ), ], - 'int_array': [([],), ([42, 999, None, None, -42],), (None,)], - 'varchar_array': [([],), (['🦆🦆🦆🦆🦆🦆', 'goose', None, ''],), (None,)], - 'double_array': [([],), ([42.0, float('nan'), float('inf'), float('-inf'), None, -42.0],), (None,)], - 'nested_int_array': [ + "int_array": [([],), ([42, 999, None, None, -42],), (None,)], + "varchar_array": [([],), (["🦆🦆🦆🦆🦆🦆", "goose", None, ""],), (None,)], + "double_array": [([],), ([42.0, float("nan"), float("inf"), float("-inf"), None, -42.0],), (None,)], + "nested_int_array": [ ([],), ([[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]],), (None,), ], - 'struct': [({'a': None, 'b': None},), ({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'},), (None,)], - 'struct_of_arrays': [ - ({'a': None, 'b': None},), - ({'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']},), + 
"struct": [({"a": None, "b": None},), ({"a": 42, "b": "🦆🦆🦆🦆🦆🦆"},), (None,)], + "struct_of_arrays": [ + ({"a": None, "b": None},), + ({"a": [42, 999, None, None, -42], "b": ["🦆🦆🦆🦆🦆🦆", "goose", None, ""]},), (None,), ], - 'array_of_structs': [([],), ([{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None],), (None,)], - 'map': [ + "array_of_structs": [([],), ([{"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None],), (None,)], + "map": [ ({},), - ({'key1': '🦆🦆🦆🦆🦆🦆', 'key2': 'goose'},), + ({"key1": "🦆🦆🦆🦆🦆🦆", "key2": "goose"},), (None,), ], - 'time_tz': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], - 'interval': [ + "time_tz": [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + "interval": [ (datetime.timedelta(0),), (datetime.timedelta(days=30969, seconds=999, microseconds=999999),), (None,), ], - 'timestamp': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'date': [(datetime.date(1990, 1, 1),)], - 'timestamp_s': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_ns': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_ms': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_tz': [(datetime.datetime(1990, 1, 1, 0, 0, tzinfo=pytz.UTC),)], - 'union': [('Frank',), (5,), (None,)], - 'fixed_int_array': [((None, 2, 3),), ((4, 5, 6),), (None,)], - 'fixed_varchar_array': [(('a', None, 'c'),), (('d', 'e', 'f'),), (None,)], - 'fixed_nested_int_array': [ + "timestamp": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "date": [(datetime.date(1990, 1, 1),)], + "timestamp_s": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_ns": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_ms": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_tz": [(datetime.datetime(1990, 1, 1, 0, 0, tzinfo=pytz.UTC),)], + "union": [("Frank",), (5,), (None,)], + "fixed_int_array": [((None, 2, 3),), ((4, 5, 6),), (None,)], + "fixed_varchar_array": [(("a", None, "c"),), (("d", "e", "f"),), (None,)], + "fixed_nested_int_array": [ (((None, 2, 3), None, (None, 2, 3)),), (((4, 5, 6), (None, 2, 3), (4, 5, 6)),), (None,), ], - 'fixed_nested_varchar_array': [ - ((('a', None, 'c'), None, ('a', None, 'c')),), - ((('d', 'e', 'f'), ('a', None, 'c'), ('d', 'e', 'f')),), + "fixed_nested_varchar_array": [ + ((("a", None, "c"), None, ("a", None, "c")),), + ((("d", "e", "f"), ("a", None, "c"), ("d", "e", "f")),), (None,), ], - 'fixed_struct_array': [ - (({'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, {'a': None, 'b': None}),), - (({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, {'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}),), + "fixed_struct_array": [ + (({"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, {"a": None, "b": None}),), + (({"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, {"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}),), (None,), ], - 'struct_of_fixed_array': [ - ({'a': (None, 2, 3), 'b': ('a', None, 'c')},), - ({'a': (4, 5, 6), 'b': ('d', 'e', 'f')},), + "struct_of_fixed_array": [ + ({"a": (None, 2, 3), "b": ("a", None, "c")},), + ({"a": (4, 5, 6), "b": ("d", "e", "f")},), (None,), ], - 'fixed_array_of_int_list': [ + "fixed_array_of_int_list": [ (([], [42, 999, None, None, -42], []),), (([42, 999, None, None, -42], [], [42, 999, None, None, -42]),), (None,), ], - 'list_of_fixed_int_array': [ + "list_of_fixed_int_array": [ ([(None, 2, 3), (4, 5, 6), (None, 2, 3)],), ([(4, 5, 6), (None, 2, 3), (4, 5, 6)],), (None,), @@ -278,14 +278,14 @@ def test_fetchall(self, cur_type): if cur_type in replacement_values: result = conn.execute("select " + replacement_values[cur_type]).fetchall() elif cur_type in 
adjusted_values: - result = conn.execute(f'select {adjusted_values[cur_type]} from test_all_types()').fetchall() + result = conn.execute(f"select {adjusted_values[cur_type]} from test_all_types()").fetchall() else: result = conn.execute(f'select "{cur_type}" from test_all_types()').fetchall() correct_result = correct_answer_map[cur_type] assert recursive_equality(result, correct_result) def test_bytearray_with_nulls(self): - con = duckdb.connect(database=':memory:') + con = duckdb.connect(database=":memory:") con.execute("CREATE TABLE test (content BLOB)") want = bytearray([1, 2, 0, 3, 4]) con.execute("INSERT INTO test VALUES (?)", [want]) @@ -295,90 +295,90 @@ def test_bytearray_with_nulls(self): # Don't truncate the array on the nullbyte assert want == bytearray(got) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_fetchnumpy(self, cur_type): conn = duckdb.connect() correct_answer_map = { - 'bool': np.ma.array( + "bool": np.ma.array( [False, True, False], mask=[0, 0, 1], ), - 'tinyint': np.ma.array( + "tinyint": np.ma.array( [-128, 127, -1], mask=[0, 0, 1], dtype=np.int8, ), - 'smallint': np.ma.array( + "smallint": np.ma.array( [-32768, 32767, -1], mask=[0, 0, 1], dtype=np.int16, ), - 'int': np.ma.array( + "int": np.ma.array( [-2147483648, 2147483647, -1], mask=[0, 0, 1], dtype=np.int32, ), - 'bigint': np.ma.array( + "bigint": np.ma.array( [-9223372036854775808, 9223372036854775807, -1], mask=[0, 0, 1], dtype=np.int64, ), - 'utinyint': np.ma.array( + "utinyint": np.ma.array( [0, 255, 42], mask=[0, 0, 1], dtype=np.uint8, ), - 'usmallint': np.ma.array( + "usmallint": np.ma.array( [0, 65535, 42], mask=[0, 0, 1], dtype=np.uint16, ), - 'uint': np.ma.array( + "uint": np.ma.array( [0, 4294967295, 42], mask=[0, 0, 1], dtype=np.uint32, ), - 'ubigint': np.ma.array( + "ubigint": np.ma.array( [0, 18446744073709551615, 42], mask=[0, 0, 1], dtype=np.uint64, ), - 'float': np.ma.array( + "float": np.ma.array( [-3.4028234663852886e38, 3.4028234663852886e38, 42.0], mask=[0, 0, 1], dtype=np.float32, ), - 'double': np.ma.array( + "double": np.ma.array( [-1.7976931348623157e308, 1.7976931348623157e308, 42.0], mask=[0, 0, 1], dtype=np.float64, ), - 'uuid': np.ma.array( + "uuid": np.ma.array( [ - UUID('00000000-0000-0000-0000-000000000000'), - UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), - UUID('00000000-0000-0000-0000-000000000042'), + UUID("00000000-0000-0000-0000-000000000000"), + UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + UUID("00000000-0000-0000-0000-000000000042"), ], mask=[0, 0, 1], dtype=object, ), - 'varchar': np.ma.array( - ['🦆🦆🦆🦆🦆🦆', 'goo\0se', "42"], + "varchar": np.ma.array( + ["🦆🦆🦆🦆🦆🦆", "goo\0se", "42"], mask=[0, 0, 1], dtype=object, ), - 'json': np.ma.array( - ['🦆🦆🦆🦆🦆🦆', 'goose', "42"], + "json": np.ma.array( + ["🦆🦆🦆🦆🦆🦆", "goose", "42"], mask=[0, 0, 1], dtype=object, ), - 'blob': np.ma.array( - [b'thisisalongblob\x00withnullbytes', b'\x00\x00\x00a', b"42"], + "blob": np.ma.array( + [b"thisisalongblob\x00withnullbytes", b"\x00\x00\x00a", b"42"], mask=[0, 0, 1], dtype=object, ), - 'interval': np.ma.array( + "interval": np.ma.array( [ np.timedelta64(0), np.timedelta64(2675722599999999000), @@ -388,7 +388,7 @@ def test_fetchnumpy(self, cur_type): ), # For timestamp_ns, the lowest value is out-of-range for numpy, # such that the conversion yields "Not a Time" - 'timestamp_ns': np.ma.array( + "timestamp_ns": np.ma.array( [ np.datetime64("NaT"), np.datetime64(9223372036854775806, "ns"), @@ -397,21 +397,21 @@ def test_fetchnumpy(self, 
cur_type): mask=[0, 0, 1], ), # Enums don't have a numpy equivalent and yield pandas Categorical. - 'small_enum': pd.Categorical( - ['DUCK_DUCK_ENUM', 'GOOSE', np.nan], + "small_enum": pd.Categorical( + ["DUCK_DUCK_ENUM", "GOOSE", np.nan], ordered=True, ), - 'medium_enum': pd.Categorical( - ['enum_0', 'enum_299', np.nan], + "medium_enum": pd.Categorical( + ["enum_0", "enum_299", np.nan], ordered=True, ), - 'large_enum': pd.Categorical( - ['enum_0', 'enum_69999', np.nan], + "large_enum": pd.Categorical( + ["enum_0", "enum_69999", np.nan], ordered=True, ), # The following types don't have a numpy equivalent and yield # object arrays: - 'int_array': np.ma.array( + "int_array": np.ma.array( [ [], [42, 999, None, None, -42], @@ -420,25 +420,25 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], dtype=object, ), - 'varchar_array': np.ma.array( + "varchar_array": np.ma.array( [ [], - ['🦆🦆🦆🦆🦆🦆', 'goose', None, ''], + ["🦆🦆🦆🦆🦆🦆", "goose", None, ""], None, ], mask=[0, 0, 1], dtype=object, ), - 'double_array': np.ma.array( + "double_array": np.ma.array( [ [], - [42.0, float('nan'), float('inf'), float('-inf'), None, -42.0], + [42.0, float("nan"), float("inf"), float("-inf"), None, -42.0], None, ], mask=[0, 0, 1], dtype=object, ), - 'nested_int_array': np.ma.array( + "nested_int_array": np.ma.array( [ [], [[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]], @@ -447,53 +447,53 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], dtype=object, ), - 'struct': np.ma.array( + "struct": np.ma.array( [ - {'a': None, 'b': None}, - {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, + {"a": None, "b": None}, + {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None, ], mask=[0, 0, 1], dtype=object, ), - 'struct_of_arrays': np.ma.array( + "struct_of_arrays": np.ma.array( [ - {'a': None, 'b': None}, - {'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']}, + {"a": None, "b": None}, + {"a": [42, 999, None, None, -42], "b": ["🦆🦆🦆🦆🦆🦆", "goose", None, ""]}, None, ], mask=[0, 0, 1], dtype=object, ), - 'array_of_structs': np.ma.array( + "array_of_structs": np.ma.array( [ [], - [{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None], + [{"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None], None, ], mask=[0, 0, 1], dtype=object, ), - 'map': np.ma.array( + "map": np.ma.array( [ {}, - {'key1': '🦆🦆🦆🦆🦆🦆', 'key2': 'goose'}, + {"key1": "🦆🦆🦆🦆🦆🦆", "key2": "goose"}, None, ], mask=[0, 0, 1], dtype=object, ), - 'time': np.ma.array( - ['00:00:00', '24:00:00', None], + "time": np.ma.array( + ["00:00:00", "24:00:00", None], mask=[0, 0, 1], dtype=object, ), - 'time_tz': np.ma.array( - ['00:00:00', '23:59:59.999999', None], + "time_tz": np.ma.array( + ["00:00:00", "23:59:59.999999", None], mask=[0, 0, 1], dtype=object, ), - 'union': np.ma.array(['Frank', 5, None], mask=[0, 0, 1], dtype=object), + "union": np.ma.array(["Frank", 5, None], mask=[0, 0, 1], dtype=object), } correct_answer_map = replace_with_ndarray(correct_answer_map) @@ -535,19 +535,19 @@ def test_fetchnumpy(self, cur_type): assert np.all(result.mask == correct_answer.mask) np.testing.assert_equal(result, correct_answer) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_arrow(self, cur_type): try: import pyarrow as pa except: return # We skip those since the extreme ranges are not supported in arrow. 
- replacement_values = {'interval': "INTERVAL '2 years'"} + replacement_values = {"interval": "INTERVAL '2 years'"} # We do not round trip enum types - enum_types = {'small_enum', 'medium_enum', 'large_enum', 'double_array'} + enum_types = {"small_enum", "medium_enum", "large_enum", "double_array"} # uhugeint currently not supported by arrow - skip_types = {'uhugeint'} + skip_types = {"uhugeint"} if cur_type in skip_types: return @@ -565,33 +565,33 @@ def test_arrow(self, cur_type): round_trip_arrow_table = conn.execute("select * from arrow_table").fetch_arrow_table() assert arrow_table.equals(round_trip_arrow_table, check_metadata=True) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_pandas(self, cur_type): # We skip those since the extreme ranges are not supported in python. replacement_values = { - 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", - 'timestamp_s': "'1990-01-01 00:00:00'::TIMESTAMP_S", - 'timestamp_ns': "'1990-01-01 00:00:00'::TIMESTAMP_NS", - 'timestamp_ms': "'1990-01-01 00:00:00'::TIMESTAMP_MS", - 'timestamp_tz': "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", - 'date': "'1990-01-01'::DATE", - 'date_array': "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", - 'timestamp_array': "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", - 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", + "timestamp": "'1990-01-01 00:00:00'::TIMESTAMP", + "timestamp_s": "'1990-01-01 00:00:00'::TIMESTAMP_S", + "timestamp_ns": "'1990-01-01 00:00:00'::TIMESTAMP_NS", + "timestamp_ms": "'1990-01-01 00:00:00'::TIMESTAMP_MS", + "timestamp_tz": "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", + "date": "'1990-01-01'::DATE", + "date_array": "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", + "timestamp_array": "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", + "timestamptz_array": "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", } adjusted_values = { - 'time': """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, + "time": """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, } conn = duckdb.connect() # Pandas <= 2.2.3 does not convert without throwing a warning conn.execute("SET timezone = UTC") - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) with suppress(TypeError): if cur_type in replacement_values: dataframe = conn.execute("select " + replacement_values[cur_type]).df() elif cur_type in adjusted_values: - dataframe = conn.execute(f'select {adjusted_values[cur_type]} from test_all_types()').df() + dataframe = conn.execute(f"select {adjusted_values[cur_type]} from test_all_types()").df() else: dataframe = conn.execute(f'select "{cur_type}" from test_all_types()').df() print(cur_type) diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 4fcbd49c..2e42f0ed 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -7,35 +7,35 @@ class TestCaseAlias(object): - 
@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): import numpy as np import datetime import duckdb - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) - r1 = con.from_df(df).query('df', 'select * from df').df() + r1 = con.from_df(df).query("df", "select * from df").df() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - r2 = con.from_df(df).query('df', 'select COL1, COL2 from df').df() + r2 = con.from_df(df).query("df", "select COL1, COL2 from df").df() assert r2["COL1"][0] == "val1" assert r2["COL1"][1] == "val3" assert r2["CoL2"][0] == 1.05 assert r2["CoL2"][1] == 17 - r3 = con.from_df(df).query('df', 'select COL1, COL2 from df ORDER BY COL1').df() + r3 = con.from_df(df).query("df", "select COL1, COL2 from df ORDER BY COL1").df() assert r3["COL1"][0] == "val1" assert r3["COL1"][1] == "val3" assert r3["CoL2"][0] == 1.05 assert r3["CoL2"][1] == 17 - r4 = con.from_df(df).query('df', 'select COL1, COL2 from df GROUP BY COL1, COL2 ORDER BY COL1').df() + r4 = con.from_df(df).query("df", "select COL1, COL2 from df GROUP BY COL1, COL2 ORDER BY COL1").df() assert r4["COL1"][0] == "val1" assert r4["COL1"][1] == "val3" assert r4["CoL2"][0] == 1.05 diff --git a/tests/fast/test_context_manager.py b/tests/fast/test_context_manager.py index 2ac451d1..65ec1d33 100644 --- a/tests/fast/test_context_manager.py +++ b/tests/fast/test_context_manager.py @@ -3,5 +3,5 @@ class TestContextManager(object): def test_context_manager(self, duckdb_cursor): - with duckdb.connect(database=':memory:', read_only=False) as con: + with duckdb.connect(database=":memory:", read_only=False) as con: assert con.execute("select 1").fetchall() == [(1,)] diff --git a/tests/fast/test_duckdb_api.py b/tests/fast/test_duckdb_api.py index f5dcfb60..ea847d50 100644 --- a/tests/fast/test_duckdb_api.py +++ b/tests/fast/test_duckdb_api.py @@ -5,4 +5,4 @@ def test_duckdb_api(): res = duckdb.execute("SELECT name, value FROM duckdb_settings() WHERE name == 'duckdb_api'") formatted_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - assert res.fetchall() == [('duckdb_api', f'python/{formatted_python_version}')] + assert res.fetchall() == [("duckdb_api", f"python/{formatted_python_version}")] diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index e0f830c5..82753382 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -21,7 +21,7 @@ ) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def filter_rel(): con = duckdb.connect() rel = con.sql( @@ -59,7 +59,7 @@ def test_constant_expression(self): res = rel.fetchall() assert res == [(5,)] - @pytest.mark.skipif(platform.system() == 'Windows', reason="There is some weird interaction in Windows CI") + @pytest.mark.skipif(platform.system() == "Windows", reason="There is some weird interaction in Windows CI") def test_column_expression(self): con = duckdb.connect() @@ -71,12 +71,12 @@ def test_column_expression(self): 3 as c """ ) - column = ColumnExpression('a') + column = ColumnExpression("a") rel2 = rel.select(column) res = rel2.fetchall() assert res == [(1,)] - column = ColumnExpression('d') + column = ColumnExpression("d") with pytest.raises(duckdb.BinderException, match='Referenced column "d" not 
found'): rel2 = rel.select(column) @@ -89,9 +89,9 @@ def test_coalesce_operator(self): """ ) - rel2 = rel.select(CoalesceOperator(ConstantExpression(None), ConstantExpression('hello').cast(int))) + rel2 = rel.select(CoalesceOperator(ConstantExpression(None), ConstantExpression("hello").cast(int))) res = rel2.explain() - assert 'COALESCE' in res + assert "COALESCE" in res with pytest.raises(duckdb.ConversionException, match="Could not convert string 'hello' to INT64"): rel2.fetchall() @@ -103,7 +103,7 @@ def test_coalesce_operator(self): """ ) - with pytest.raises(duckdb.InvalidInputException, match='Please provide at least one argument'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide at least one argument"): rel3 = rel.select(CoalesceOperator()) rel4 = rel.select(CoalesceOperator(ConstantExpression(None))) @@ -112,7 +112,7 @@ def test_coalesce_operator(self): rel5 = rel.select(CoalesceOperator(ConstantExpression(42))) assert rel5.fetchone() == (42,) - exprtest = con.table('exprtest') + exprtest = con.table("exprtest") rel6 = exprtest.select(CoalesceOperator(ColumnExpression("a"))) res = rel6.fetchall() assert res == [(42,), (43,), (None,), (45,)] @@ -193,17 +193,17 @@ def test_column_expression_explain(self): """ ) rel = rel.select( - ConstantExpression("a").alias('c0'), - ConstantExpression(42).alias('c1'), - ConstantExpression(None).alias('c2'), + ConstantExpression("a").alias("c0"), + ConstantExpression(42).alias("c1"), + ConstantExpression(None).alias("c2"), ) res = rel.explain() - assert 'c0' in res - assert 'c1' in res + assert "c0" in res + assert "c1" in res # 'c2' is not in the explain result because it shows NULL instead - assert 'NULL' in res + assert "NULL" in res res = rel.fetchall() - assert res == [('a', 42, None)] + assert res == [("a", 42, None)] def test_column_expression_table(self): con = duckdb.connect() @@ -219,10 +219,10 @@ def test_column_expression_table(self): """ ) - rel = con.table('tbl') - rel2 = rel.select('c0', 'c1', 'c2') + rel = con.table("tbl") + rel2 = rel.select("c0", "c1", "c2") res = rel2.fetchall() - assert res == [('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'h', 'i')] + assert res == [("a", "b", "c"), ("d", "e", "f"), ("g", "h", "i")] def test_column_expression_view(self): con = duckdb.connect() @@ -241,18 +241,18 @@ def test_column_expression_view(self): CREATE VIEW v1 as select c0 as c3, c2 as c4 from tbl; """ ) - rel = con.view('v1') - rel2 = rel.select('c3', 'c4') + rel = con.view("v1") + rel2 = rel.select("c3", "c4") res = rel2.fetchall() - assert res == [('a', 'c'), ('d', 'f'), ('g', 'i')] + assert res == [("a", "c"), ("d", "f"), ("g", "i")] def test_column_expression_replacement_scan(self): con = duckdb.connect() pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [42, 43, 0], 'b': [True, False, True], 'c': [23.123, 623.213, 0.30234]}) + df = pd.DataFrame({"a": [42, 43, 0], "b": [True, False, True], "c": [23.123, 623.213, 0.30234]}) rel = con.sql("select * from df") - rel2 = rel.select('a', 'b') + rel2 = rel.select("a", "b") res = rel2.fetchall() assert res == [(42, True), (43, False), (0, True)] @@ -271,7 +271,7 @@ def test_add_operator(self): ) constant = ConstantExpression(val) - col = ColumnExpression('b') + col = ColumnExpression("b") expr = col + constant rel = rel.select(expr, expr) @@ -288,7 +288,7 @@ def test_binary_function_expression(self): 5 as b """ ) - function = FunctionExpression("-", ColumnExpression('b'), ColumnExpression('a')) + function = FunctionExpression("-", ColumnExpression("b"), 
ColumnExpression("a")) rel2 = rel.select(function) res = rel2.fetchall() assert res == [(4,)] @@ -301,7 +301,7 @@ def test_negate_expression(self): select 5 as a """ ) - col = ColumnExpression('a') + col = ColumnExpression("a") col = -col rel = rel.select(col) res = rel.fetchall() @@ -317,8 +317,8 @@ def test_subtract_expression(self): 1 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 - col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -337,8 +337,8 @@ def test_multiply_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 * col2 rel = rel.select(expr) res = rel.fetchall() @@ -354,8 +354,8 @@ def test_division_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 / col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -376,8 +376,8 @@ def test_modulus_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 % col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -393,8 +393,8 @@ def test_power_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1**col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -411,9 +411,9 @@ def test_between_expression(self): 3 as c """ ) - a = ColumnExpression('a') - b = ColumnExpression('b') - c = ColumnExpression('c') + a = ColumnExpression("a") + b = ColumnExpression("b") + c = ColumnExpression("c") # 5 BETWEEN 2 AND 3 -> false assert rel.select(a.between(b, c)).fetchall() == [(False,)] @@ -437,32 +437,32 @@ def test_collate_expression(self): """ ) - col1 = ColumnExpression('c0') - col2 = ColumnExpression('c1') + col1 = ColumnExpression("c0") + col2 = ColumnExpression("c1") - lower_a = ConstantExpression('a') - upper_a = ConstantExpression('A') + lower_a = ConstantExpression("a") + upper_a = ConstantExpression("A") # SELECT c0 LIKE 'a' == True - assert rel.select(FunctionExpression('~~', col1, lower_a)).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col1, lower_a)).fetchall() == [(True,)] # SELECT c0 LIKE 'A' == False - assert rel.select(FunctionExpression('~~', col1, upper_a)).fetchall() == [(False,)] + assert rel.select(FunctionExpression("~~", col1, upper_a)).fetchall() == [(False,)] # SELECT c0 LIKE 'A' COLLATE NOCASE == True - assert rel.select(FunctionExpression('~~', col1, upper_a.collate('NOCASE'))).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col1, upper_a.collate("NOCASE"))).fetchall() == [(True,)] # SELECT c1 LIKE 'a' == False - assert rel.select(FunctionExpression('~~', col2, lower_a)).fetchall() == [(False,)] + assert rel.select(FunctionExpression("~~", col2, lower_a)).fetchall() == [(False,)] # SELECT c1 LIKE 'a' COLLATE NOCASE == True - assert rel.select(FunctionExpression('~~', col2, lower_a.collate('NOCASE'))).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col2, lower_a.collate("NOCASE"))).fetchall() == [(True,)] - with pytest.raises(duckdb.BinderException, match='collations are only supported for type varchar'): - rel.select(FunctionExpression('~~', col2, lower_a).collate('NOCASE')) + with pytest.raises(duckdb.BinderException, 
match="collations are only supported for type varchar"): + rel.select(FunctionExpression("~~", col2, lower_a).collate("NOCASE")) - with pytest.raises(duckdb.CatalogException, match='Collation with name non-existant does not exist'): - rel.select(FunctionExpression('~~', col2, lower_a.collate('non-existant'))) + with pytest.raises(duckdb.CatalogException, match="Collation with name non-existant does not exist"): + rel.select(FunctionExpression("~~", col2, lower_a.collate("non-existant"))) def test_equality_expression(self): con = duckdb.connect() @@ -475,9 +475,9 @@ def test_equality_expression(self): 5 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") expr1 = col1 == col2 expr2 = col1 == col3 rel2 = rel.select(expr1, expr2) @@ -497,22 +497,22 @@ def test_lambda_expression(self): # Use a tuple of strings as 'lhs' func = FunctionExpression( "list_reduce", - ColumnExpression('a'), - LambdaExpression(('x', 'y'), ColumnExpression('x') + ColumnExpression('y')), + ColumnExpression("a"), + LambdaExpression(("x", "y"), ColumnExpression("x") + ColumnExpression("y")), ) rel2 = rel.select(func) res = rel2.fetchall() assert res == [(6,)] # Use only a string name as 'lhs' - func = FunctionExpression("list_apply", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('x') + 3)) + func = FunctionExpression("list_apply", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("x") + 3)) rel2 = rel.select(func) res = rel2.fetchall() assert res == [([4, 5, 6],)] # 'row' is not a lambda function, so it doesn't accept a lambda expression - func = FunctionExpression("row", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('x') + 3)) - with pytest.raises(duckdb.BinderException, match='This scalar function does not support lambdas'): + func = FunctionExpression("row", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("x") + 3)) + with pytest.raises(duckdb.BinderException, match="This scalar function does not support lambdas"): rel2 = rel.select(func) # lhs has to be a tuple of strings or a single string @@ -520,11 +520,11 @@ def test_lambda_expression(self): ValueError, match="Please provide 'lhs' as either a tuple containing strings, or a single string" ): func = FunctionExpression( - "list_filter", ColumnExpression('a'), LambdaExpression(42, ColumnExpression('x') + 3) + "list_filter", ColumnExpression("a"), LambdaExpression(42, ColumnExpression("x") + 3) ) func = FunctionExpression( - "list_filter", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('y') != 3) + "list_filter", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("y") != 3) ) with pytest.raises(duckdb.BinderException, match='Referenced column "y" not found in FROM clause'): rel2 = rel.select(func) @@ -540,9 +540,9 @@ def test_inequality_expression(self): 5 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") expr1 = col1 != col2 expr2 = col1 != col3 rel2 = rel.select(expr1, expr2) @@ -561,10 +561,10 @@ def test_comparison_expressions(self): 3 as d """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') - col4 = ColumnExpression('d') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") + col4 = ColumnExpression("d") # 
Greater than expr1 = col1 > col2 @@ -606,11 +606,11 @@ def test_expression_alias(self): select 1 as a """ ) - col = ColumnExpression('a') - col = col.alias('b') + col = ColumnExpression("a") + col = col.alias("b") rel2 = rel.select(col) - assert rel2.columns == ['b'] + assert rel2.columns == ["b"] def test_star_expression(self): con = duckdb.connect() @@ -628,7 +628,7 @@ def test_star_expression(self): assert res == [(1, 2)] # With exclude list - star = StarExpression(exclude=['a']) + star = StarExpression(exclude=["a"]) rel2 = rel.select(star) res = rel2.fetchall() assert res == [(2,)] @@ -644,13 +644,13 @@ def test_struct_expression(self): """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - expr = FunctionExpression('struct_pack', col1, col2).alias('struct') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + expr = FunctionExpression("struct_pack", col1, col2).alias("struct") rel = rel.select(expr) res = rel.fetchall() - assert res == [({'a': 1, 'b': 2},)] + assert res == [({"a": 1, "b": 2},)] def test_function_expression_udf(self): con = duckdb.connect() @@ -658,7 +658,7 @@ def test_function_expression_udf(self): def my_simple_func(a: int, b: int, c: int) -> int: return a + b + c - con.create_function('my_func', my_simple_func) + con.create_function("my_func", my_simple_func) rel = con.sql( """ @@ -668,10 +668,10 @@ def my_simple_func(a: int, b: int, c: int) -> int: 3 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') - expr = FunctionExpression('my_func', col1, col2, col3) + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") + expr = FunctionExpression("my_func", col1, col2, col3) rel2 = rel.select(expr) res = rel2.fetchall() assert res == [(6,)] @@ -688,10 +688,10 @@ def test_function_expression_basic(self): ) tbl(text, start, "end") """ ) - expr = FunctionExpression('array_slice', "start", "text", "end") + expr = FunctionExpression("array_slice", "start", "text", "end") rel2 = rel.select(expr) res = rel2.fetchall() - assert res == [('tes',), ('his is',), ('di',)] + assert res == [("tes",), ("his is",), ("di",)] def test_column_expression_function_coverage(self): con = duckdb.connect() @@ -707,11 +707,11 @@ def test_column_expression_function_coverage(self): """ ) - rel = con.table('tbl') - expr = FunctionExpression('||', FunctionExpression('||', 'c0', 'c1'), 'c2') + rel = con.table("tbl") + expr = FunctionExpression("||", FunctionExpression("||", "c0", "c1"), "c2") rel2 = rel.select(expr) res = rel2.fetchall() - assert res == [('abc',), ('def',), ('ghi',)] + assert res == [("abc",), ("def",), ("ghi",)] def test_function_expression_aggregate(self): con = duckdb.connect() @@ -725,9 +725,9 @@ def test_function_expression_aggregate(self): ) tbl(text) """ ) - expr = FunctionExpression('first', 'text') + expr = FunctionExpression("first", "text") with pytest.raises( - duckdb.BinderException, match='Binder Error: Aggregates cannot be present in a Project relation!' + duckdb.BinderException, match="Binder Error: Aggregates cannot be present in a Project relation!" 
): rel2 = rel.select(expr) @@ -743,9 +743,9 @@ def test_case_expression(self): """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") const1 = ConstantExpression(IntegerValue(1)) # CASE WHEN col1 > 1 THEN 5 ELSE NULL @@ -796,7 +796,7 @@ def test_implicit_constant_conversion(self): def test_numeric_overflow(self): con = duckdb.connect() - rel = con.sql('select 3000::SHORT salary') + rel = con.sql("select 3000::SHORT salary") with pytest.raises(duckdb.OutOfRangeException, match="Overflow in multiplication of INT16"): expr = ColumnExpression("salary") * 100 rel2 = rel.select(expr) @@ -823,7 +823,7 @@ def test_filter_equality(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(1, 'a'), (1, 'b')] + assert res == [(1, "a"), (1, "b")] def test_filter_not(self, filter_rel): expr = ColumnExpression("a") == 1 @@ -832,18 +832,18 @@ def test_filter_not(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(2, 'b'), (3, 'c'), (4, 'a')] + assert res == [(2, "b"), (3, "c"), (4, "a")] def test_filter_and(self, filter_rel): expr = ColumnExpression("a") == 1 expr = ~expr # AND operator - expr = expr & ('b' != ConstantExpression('b')) + expr = expr & ("b" != ConstantExpression("b")) rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] def test_filter_or(self, filter_rel): # OR operator @@ -851,7 +851,7 @@ def test_filter_or(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(1, 'a'), (1, 'b'), (4, 'a')] + assert res == [(1, "a"), (1, "b"), (4, "a")] def test_filter_mixed(self, filter_rel): # Mixed @@ -861,7 +861,7 @@ def test_filter_mixed(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(1, 'a'), (4, 'a')] + assert res == [(1, "a"), (4, "a")] def test_empty_in(self, filter_rel): expr = ColumnExpression("a") @@ -884,7 +884,7 @@ def test_filter_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(1, 'a'), (2, 'b'), (1, 'b')] + assert res == [(1, "a"), (2, "b"), (1, "b")] def test_filter_not_in(self, filter_rel): expr = ColumnExpression("a") @@ -894,7 +894,7 @@ def test_filter_not_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] # NOT IN expression expr = ColumnExpression("a") @@ -902,7 +902,7 @@ def test_filter_not_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] def test_null(self): con = duckdb.connect() @@ -924,7 +924,7 @@ def test_null(self): assert res == [(False,), (False,), (True,), (False,), (False,)] res2 = rel.filter(b.isnotnull()).fetchall() - assert res2 == [(1, 'a'), (2, 'b'), (4, 'c'), (5, 'a')] + assert res2 == [(1, "a"), (2, "b"), (4, "c"), (5, "a")] def test_sort(self): con = duckdb.connect() @@ -956,12 +956,12 @@ def test_sort(self): # Nulls first rel2 = rel.sort(b.desc().nulls_first()) res = rel2.b.fetchall() - assert res == [(None,), ('c',), ('b',), ('a',), ('a',)] + assert res == [(None,), ("c",), ("b",), ("a",), ("a",)] # Nulls last 
rel2 = rel.sort(b.desc().nulls_last()) res = rel2.b.fetchall() - assert res == [('c',), ('b',), ('a',), ('a',), (None,)] + assert res == [("c",), ("b",), ("a",), ("a",), (None,)] def test_aggregate(self): con = duckdb.connect() @@ -983,7 +983,7 @@ def test_aggregate_error(self): # Providing something that can not be converted into an expression is an error: with pytest.raises( - duckdb.InvalidInputException, match='Invalid Input Error: Please provide arguments of type Expression!' + duckdb.InvalidInputException, match="Invalid Input Error: Please provide arguments of type Expression!" ): class MyClass: diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index 195de165..7b8fbb05 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -10,12 +10,12 @@ from duckdb import DuckDBPyConnection, InvalidInputException from pytest import raises, importorskip, fixture, MonkeyPatch, mark -importorskip('fsspec', '2022.11.0') +importorskip("fsspec", "2022.11.0") from fsspec import filesystem, AbstractFileSystem from fsspec.implementations.memory import MemoryFileSystem from fsspec.implementations.local import LocalFileOpener, LocalFileSystem -FILENAME = 'integers.csv' +FILENAME = "integers.csv" logging.basicConfig(level=logging.DEBUG) @@ -43,11 +43,11 @@ def duckdb_cursor(): @fixture() def memory(): - fs = filesystem('memory', skip_instance_cache=True) + fs = filesystem("memory", skip_instance_cache=True) # ensure each instance is independent (to work around a weird quirk in fsspec) fs.store = {} - fs.pseudo_dirs = [''] + fs.pseudo_dirs = [""] # copy csv into memory filesystem add_file(fs) @@ -55,39 +55,39 @@ def memory(): def add_file(fs, filename=FILENAME): - with (Path(__file__).parent / 'data' / filename).open('rb') as source, fs.open(filename, 'wb') as dest: + with (Path(__file__).parent / "data" / filename).open("rb") as source, fs.open(filename, "wb") as dest: copyfileobj(source, dest) class TestPythonFilesystem: def test_unregister_non_existent_filesystem(self, duckdb_cursor: DuckDBPyConnection): with raises(InvalidInputException): - duckdb_cursor.unregister_filesystem('fake') + duckdb_cursor.unregister_filesystem("fake") def test_memory_filesystem(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) - assert memory.protocol == 'memory' + assert memory.protocol == "memory" duckdb_cursor.execute(f"select * from 'memory://{FILENAME}'") assert duckdb_cursor.fetchall() == [(1, 10, 0), (2, 50, 30)] - duckdb_cursor.unregister_filesystem('memory') + duckdb_cursor.unregister_filesystem("memory") def test_reject_abstract_filesystem(self, duckdb_cursor: DuckDBPyConnection): with raises(InvalidInputException): duckdb_cursor.register_filesystem(AbstractFileSystem()) def test_unregister_builtin(self, require: Callable[[str], DuckDBPyConnection]): - duckdb_cursor = require('httpfs') - assert duckdb_cursor.filesystem_is_registered('S3FileSystem') == True - duckdb_cursor.unregister_filesystem('S3FileSystem') - assert duckdb_cursor.filesystem_is_registered('S3FileSystem') == False + duckdb_cursor = require("httpfs") + assert duckdb_cursor.filesystem_is_registered("S3FileSystem") == True + duckdb_cursor.unregister_filesystem("S3FileSystem") + assert duckdb_cursor.filesystem_is_registered("S3FileSystem") == False def test_multiple_protocol_filesystems(self, duckdb_cursor: DuckDBPyConnection): class ExtendedMemoryFileSystem(MemoryFileSystem): - protocol = ('file', 'local') + protocol = ("file", "local") # 
defer to the original implementation that doesn't hardcode the protocol _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) @@ -104,51 +104,51 @@ def test_write(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSyst duckdb_cursor.execute("copy (select 1) to 'memory://01.csv' (FORMAT CSV, HEADER 0)") - assert memory.open('01.csv').read() == b'1\n' + assert memory.open("01.csv").read() == b"1\n" def test_null_bytes(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): - with memory.open('test.csv', 'wb') as fh: - fh.write(b'hello\n\0world\0') + with memory.open("test.csv", "wb") as fh: + fh.write(b"hello\n\0world\0") duckdb_cursor.register_filesystem(memory) - duckdb_cursor.execute('select * from read_csv("memory://test.csv", header = 0, quote = \'"\', escape = \'"\')') + duckdb_cursor.execute("select * from read_csv(\"memory://test.csv\", header = 0, quote = '\"', escape = '\"')") - assert duckdb_cursor.fetchall() == [('hello',), ('\0world\0',)] + assert duckdb_cursor.fetchall() == [("hello",), ("\0world\0",)] def test_read_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): - filename = 'binary_string.parquet' + filename = "binary_string.parquet" add_file(memory, filename) duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute(f"select * from read_parquet('memory://{filename}')") - assert duckdb_cursor.fetchall() == [(b'foo',), (b'bar',), (b'baz',)] + assert duckdb_cursor.fetchall() == [(b"foo",), (b"bar",), (b"baz",)] def test_write_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) - filename = 'output.parquet' + filename = "output.parquet" - duckdb_cursor.execute(f'''COPY (SELECT 1) TO 'memory://{filename}' (FORMAT PARQUET);''') + duckdb_cursor.execute(f"""COPY (SELECT 1) TO 'memory://{filename}' (FORMAT PARQUET);""") - assert memory.open(filename).read().startswith(b'PAR1') + assert memory.open(filename).read().startswith(b"PAR1") def test_when_fsspec_not_installed(self, duckdb_cursor: DuckDBPyConnection, monkeypatch: MonkeyPatch): - monkeypatch.setitem(sys.modules, 'fsspec', None) + monkeypatch.setitem(sys.modules, "fsspec", None) with raises(ModuleNotFoundError): duckdb_cursor.register_filesystem(None) @mark.skipif(sys.version_info < (3, 8), reason="ArrowFSWrapper requires python 3.8 or higher") def test_arrow_fs_wrapper(self, tmp_path: Path, duckdb_cursor: DuckDBPyConnection): - fs = importorskip('pyarrow.fs') + fs = importorskip("pyarrow.fs") from fsspec.implementations.arrow import ArrowFSWrapper local = fs.LocalFileSystem() local_fsspec = ArrowFSWrapper(local, skip_instance_cache=True) # posix calls here required as ArrowFSWrapper only supports url-like paths (not Windows paths) filename = str(PurePosixPath(tmp_path.as_posix()) / "test.csv") - with local_fsspec.open(filename, mode='w') as f: + with local_fsspec.open(filename, mode="w") as f: f.write("a,b,c\n") f.write("1,2,3\n") f.write("4,5,6\n") @@ -159,29 +159,29 @@ def test_arrow_fs_wrapper(self, tmp_path: Path, duckdb_cursor: DuckDBPyConnectio assert duckdb_cursor.fetchall() == [(1, 2, 3), (4, 5, 6)] def test_database_attach(self, tmp_path: Path, monkeypatch: MonkeyPatch): - db_path = str(tmp_path / 'hello.db') + db_path = str(tmp_path / "hello.db") # setup a database to attach later with duckdb.connect(db_path) as conn: conn.execute( - ''' + """ CREATE TABLE t (id int); INSERT INTO t VALUES (0) - ''' + """ ) assert exists(db_path) with duckdb.connect() as conn: - fs = 
filesystem('file', skip_instance_cache=True) - write_errors = intercept(monkeypatch, LocalFileOpener, 'write') + fs = filesystem("file", skip_instance_cache=True) + write_errors = intercept(monkeypatch, LocalFileOpener, "write") conn.register_filesystem(fs) db_path_posix = str(PurePosixPath(tmp_path.as_posix()) / "hello.db") conn.execute(f"ATTACH 'file://{db_path_posix}'") - conn.execute('INSERT INTO hello.t VALUES (1)') + conn.execute("INSERT INTO hello.t VALUES (1)") - conn.execute('FROM hello.t') + conn.execute("FROM hello.t") assert conn.fetchall() == [(0,), (1,)] # duckdb sometimes seems to swallow write errors, so we use this to ensure that @@ -193,7 +193,7 @@ def test_copy_partition(self, duckdb_cursor: DuckDBPyConnection, memory: Abstrac duckdb_cursor.execute("copy (select 1 as a, 2 as b) to 'memory://root' (partition_by (a), HEADER 0)") - assert memory.open('/root/a=1/data_0.csv').read() == b'2\n' + assert memory.open("/root/a=1/data_0.csv").read() == b"2\n" def test_copy_partition_with_columns_written(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) @@ -202,7 +202,7 @@ def test_copy_partition_with_columns_written(self, duckdb_cursor: DuckDBPyConnec "copy (select 1 as a) to 'memory://root' (partition_by (a), HEADER 0, WRITE_PARTITION_COLUMNS)" ) - assert memory.open('/root/a=1/data_0.csv').read() == b'1\n' + assert memory.open("/root/a=1/data_0.csv").read() == b"1\n" def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) @@ -210,25 +210,25 @@ def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: Ab "copy (select 2 as a, 3 as b, 4 as c) to 'memory://partition' (partition_by (a), HEADER 0)" ) - path = 'memory:///partition/*/*.csv' + path = "memory:///partition/*/*.csv" query = "SELECT * FROM read_csv_auto('" + path + "'" # hive partitioning - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: auto detection - duckdb_cursor.execute(query + ');') + duckdb_cursor.execute(query + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=1" + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: no cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=0' + ');') - assert duckdb_cursor.fetchall() == [(3, 4, '2')] + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=0" + ");") + assert duckdb_cursor.fetchall() == [(3, 4, "2")] def test_read_hive_partition_with_columns_written( self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem @@ -238,34 +238,34 @@ def test_read_hive_partition_with_columns_written( "copy (select 2 as a) to 'memory://partition' (partition_by (a), HEADER 0, WRITE_PARTITION_COLUMNS)" ) - path = 'memory:///partition/*/*.csv' + path = "memory:///partition/*/*.csv" query = "SELECT * FROM read_csv_auto('" + path + "'" # hive partitioning - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: auto detection - duckdb_cursor.execute(query 
+ ');') + duckdb_cursor.execute(query + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=1" + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: no cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=0' + ');') - assert duckdb_cursor.fetchall() == [(2, '2')] + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=0" + ");") + assert duckdb_cursor.fetchall() == [(2, "2")] def test_parallel_union_by_name(self, tmp_path): - pa = importorskip('pyarrow') - pq = importorskip('pyarrow.parquet') - fsspec = importorskip('fsspec') + pa = importorskip("pyarrow") + pq = importorskip("pyarrow.parquet") + fsspec = importorskip("fsspec") table1 = pa.Table.from_pylist( [ - {'time': 1719568210134107692, 'col1': 1}, + {"time": 1719568210134107692, "col1": 1}, ] ) table1_path = tmp_path / "table1.parquet" @@ -273,7 +273,7 @@ def test_parallel_union_by_name(self, tmp_path): table2 = pa.Table.from_pylist( [ - {'time': 1719568210134107692, 'col1': 1}, + {"time": 1719568210134107692, "col1": 1}, ] ) table2_path = tmp_path / "table2.parquet" diff --git a/tests/fast/test_get_table_names.py b/tests/fast/test_get_table_names.py index c11b8a65..1f90e444 100644 --- a/tests/fast/test_get_table_names.py +++ b/tests/fast/test_get_table_names.py @@ -6,7 +6,7 @@ class TestGetTableNames(object): def test_table_success(self, duckdb_cursor): conn = duckdb.connect() table_names = conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") - assert table_names == {'my_table2', 'my_table3', 'my_table1'} + assert table_names == {"my_table2", "my_table3", "my_table1"} def test_table_fail(self, duckdb_cursor): conn = duckdb.connect() @@ -19,11 +19,11 @@ def test_qualified_parameter_basic(self): # Default (qualified=False) table_names = conn.get_table_names("SELECT * FROM test_table") - assert table_names == {'test_table'} + assert table_names == {"test_table"} # Explicit qualified=False table_names = conn.get_table_names("SELECT * FROM test_table", qualified=False) - assert table_names == {'test_table'} + assert table_names == {"test_table"} def test_qualified_parameter_schemas(self): conn = duckdb.connect() @@ -31,11 +31,11 @@ def test_qualified_parameter_schemas(self): # Default (qualified=False) query = "SELECT * FROM test_schema.schema_table, main_table" table_names = conn.get_table_names(query) - assert table_names == {'schema_table', 'main_table'} + assert table_names == {"schema_table", "main_table"} # Test with qualified names table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'test_schema.schema_table', 'main_table'} + assert table_names == {"test_schema.schema_table", "main_table"} def test_qualified_parameter_catalogs(self): conn = duckdb.connect() @@ -45,11 +45,11 @@ def test_qualified_parameter_catalogs(self): # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'catalog_table', 'regular_table'} + assert table_names == {"catalog_table", "regular_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'catalog1.test_schema.catalog_table', 'regular_table'} + assert table_names == {"catalog1.test_schema.catalog_table", "regular_table"} def test_qualified_parameter_quoted_identifiers(self): conn = 
duckdb.connect() @@ -59,7 +59,7 @@ def test_qualified_parameter_quoted_identifiers(self): # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'Table.With.Dots', 'Table With Spaces'} + assert table_names == {"Table.With.Dots", "Table With Spaces"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) @@ -67,45 +67,45 @@ def test_qualified_parameter_quoted_identifiers(self): def test_expanded_views(self): conn = duckdb.connect() - conn.execute('CREATE TABLE my_table(i INT)') - conn.execute('CREATE VIEW v1 AS SELECT * FROM my_table') + conn.execute("CREATE TABLE my_table(i INT)") + conn.execute("CREATE VIEW v1 AS SELECT * FROM my_table") # Test that v1 expands to my_table - query = 'SELECT col_a FROM v1' + query = "SELECT col_a FROM v1" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'my_table'} + assert table_names == {"my_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'my_table'} + assert table_names == {"my_table"} def test_expanded_views_with_schema(self): conn = duckdb.connect() - conn.execute('CREATE SCHEMA my_schema') - conn.execute('CREATE TABLE my_schema.my_table(i INT)') - conn.execute('CREATE VIEW v1 AS SELECT * FROM my_schema.my_table') + conn.execute("CREATE SCHEMA my_schema") + conn.execute("CREATE TABLE my_schema.my_table(i INT)") + conn.execute("CREATE VIEW v1 AS SELECT * FROM my_schema.my_table") # Test that v1 expands to my_table - query = 'SELECT col_a FROM v1' + query = "SELECT col_a FROM v1" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'my_table'} + assert table_names == {"my_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'my_schema.my_table'} + assert table_names == {"my_schema.my_table"} def test_select_function(self): conn = duckdb.connect() - query = 'SELECT EXTRACT(second FROM i) FROM timestamps;' + query = "SELECT EXTRACT(second FROM i) FROM timestamps;" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'timestamps'} + assert table_names == {"timestamps"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'timestamps'} + assert table_names == {"timestamps"} diff --git a/tests/fast/test_import_export.py b/tests/fast/test_import_export.py index 2fce1636..d98a2d73 100644 --- a/tests/fast/test_import_export.py +++ b/tests/fast/test_import_export.py @@ -33,7 +33,7 @@ def move_database(export_location, import_location): assert path.exists(export_location) assert path.exists(import_location) - for file in ['schema.sql', 'load.sql', 'tbl.csv']: + for file in ["schema.sql", "load.sql", "tbl.csv"]: shutil.move(path.join(export_location, file), import_location) @@ -56,7 +56,7 @@ def export_and_import_empty_db(db_path, _): class TestDuckDBImportExport: - @pytest.mark.parametrize('routine', [export_move_and_import, export_and_import_empty_db]) + @pytest.mark.parametrize("routine", [export_move_and_import, export_and_import_empty_db]) def test_import_and_export(self, routine, tmp_path_factory): export_path = str(tmp_path_factory.mktemp("export_dbs", numbered=True)) import_path = str(tmp_path_factory.mktemp("import_dbs", numbered=True)) @@ -66,15 +66,15 @@ def test_import_empty_db(self, tmp_path_factory): import_path = str(tmp_path_factory.mktemp("empty_db", numbered=True)) # Create an empty 
db folder structure - Path(Path(import_path) / 'load.sql').touch() - Path(Path(import_path) / 'schema.sql').touch() + Path(Path(import_path) / "load.sql").touch() + Path(Path(import_path) / "schema.sql").touch() con = duckdb.connect() con.execute(f"import database '{import_path}'") # Put a single comment into the 'schema.sql' file - with open(Path(import_path) / 'schema.sql', 'w') as f: - f.write('--\n') + with open(Path(import_path) / "schema.sql", "w") as f: + f.write("--\n") con.close() con = duckdb.connect() diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index 1465b68a..baae75b4 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -6,7 +6,7 @@ class TestInsert(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_insert(self, pandas): test_df = pandas.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) # connect to an in-memory temporary database @@ -15,19 +15,19 @@ def test_insert(self, pandas): cursor = conn.cursor() conn.execute("CREATE TABLE test (i INTEGER, j STRING)") rel = conn.table("test") - rel.insert([1, 'one']) - rel.insert([2, 'two']) - rel.insert([3, 'three']) - rel_a3 = cursor.table('test').project('CAST(i as BIGINT)i, j').to_df() + rel.insert([1, "one"]) + rel.insert([2, "two"]) + rel.insert([3, "three"]) + rel_a3 = cursor.table("test").project("CAST(i as BIGINT)i, j").to_df() pandas.testing.assert_frame_equal(rel_a3, test_df) def test_insert_with_schema(self, duckdb_cursor): duckdb_cursor.sql("create schema not_main") duckdb_cursor.sql("create table not_main.tbl as select * from range(10)") - res = duckdb_cursor.table('not_main.tbl').fetchall() + res = duckdb_cursor.table("not_main.tbl").fetchall() assert len(res) == 10 # FIXME: This is not currently supported - with pytest.raises(duckdb.CatalogException, match='Table with name tbl does not exist'): - duckdb_cursor.table('not_main.tbl').insert([42, 21, 1337]) + with pytest.raises(duckdb.CatalogException, match="Table with name tbl does not exist"): + duckdb_cursor.table("not_main.tbl").insert([42, 21, 1337]) diff --git a/tests/fast/test_many_con_same_file.py b/tests/fast/test_many_con_same_file.py index 6b7362a6..3cef2494 100644 --- a/tests/fast/test_many_con_same_file.py +++ b/tests/fast/test_many_con_same_file.py @@ -23,7 +23,7 @@ def test_multiple_writes(): con1.close() con3 = duckdb.connect("test.db") tbls = get_tables(con3) - assert tbls == ['bar1', 'foo1'] + assert tbls == ["bar1", "foo1"] del con1 del con2 del con3 @@ -41,9 +41,9 @@ def test_multiple_writes_memory(): con2.execute("CREATE TABLE bar1 as SELECT 2 as a, 3 as b") con3 = duckdb.connect(":memory:") tbls = get_tables(con1) - assert tbls == ['foo1'] + assert tbls == ["foo1"] tbls = get_tables(con2) - assert tbls == ['bar1'] + assert tbls == ["bar1"] tbls = get_tables(con3) assert tbls == [] del con1 @@ -58,7 +58,7 @@ def test_multiple_writes_named_memory(): con2.execute("CREATE TABLE bar1 as SELECT 2 as a, 3 as b") con3 = duckdb.connect(":memory:1") tbls = get_tables(con3) - assert tbls == ['bar1', 'foo1'] + assert tbls == ["bar1", "foo1"] del con1 del con2 del con3 @@ -76,7 +76,7 @@ def test_diff_config(): def test_diff_config_extended(): - con1 = duckdb.connect("test.db", config={'null_order': 'NULLS FIRST'}) + con1 = duckdb.connect("test.db", config={"null_order": "NULLS FIRST"}) with pytest.raises( duckdb.ConnectionException, match="Can't open a connection to same database file with a different 
configuration than existing connections", diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index 4dbd1a36..f86dd60b 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -9,36 +9,36 @@ # column count differs from bind def evil1(df): if len(df) == 0: - return df['col0'].to_frame() + return df["col0"].to_frame() else: return df class TestMap(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_evil_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) - with pytest.raises(duckdb.InvalidInputException, match='Expected 1 columns from UDF, got 2'): - rel = testrel.map(evil1, schema={'i': str}) + with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): + rel = testrel.map(evil1, schema={"i": str}) df = rel.df() print(df) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) conn = duckdb_cursor - conn.execute('CREATE TABLE t (a integer)') - empty_rel = conn.table('t') + conn.execute("CREATE TABLE t (a integer)") + empty_rel = conn.table("t") - newdf1 = testrel.map(lambda df: df['col0'].add(42).to_frame()) - newdf2 = testrel.map(lambda df: df['col0'].astype('string').to_frame()) + newdf1 = testrel.map(lambda df: df["col0"].add(42).to_frame()) + newdf2 = testrel.map(lambda df: df["col0"].astype("string").to_frame()) newdf3 = testrel.map(lambda df: df) # column type differs from bind def evil2(df): result = df.copy(deep=True) if len(result) == 0: - result['col0'] = result['col0'].astype('double') + result["col0"] = result["col0"].astype("double") return result # column name differs from bind @@ -56,10 +56,10 @@ def evil5(df): raise TypeError def return_dataframe(df): - return pandas.DataFrame({'A': [1]}) + return pandas.DataFrame({"A": [1]}) def return_big_dataframe(df): - return pandas.DataFrame({'A': [1] * 5000}) + return pandas.DataFrame({"A": [1] * 5000}) def return_none(df): return None @@ -67,13 +67,13 @@ def return_none(df): def return_empty_df(df): return pandas.DataFrame() - with pytest.raises(duckdb.InvalidInputException, match='Expected 1 columns from UDF, got 2'): + with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): print(testrel.map(evil1).df()) - with pytest.raises(duckdb.InvalidInputException, match='UDF column type mismatch'): + with pytest.raises(duckdb.InvalidInputException, match="UDF column type mismatch"): print(testrel.map(evil2).df()) - with pytest.raises(duckdb.InvalidInputException, match='UDF column name mismatch'): + with pytest.raises(duckdb.InvalidInputException, match="UDF column name mismatch"): print(testrel.map(evil3).df()) with pytest.raises( @@ -92,19 +92,19 @@ def return_empty_df(df): with pytest.raises(TypeError): print(testrel.map().df()) - testrel.map(return_dataframe).df().equals(pandas.DataFrame({'A': [1]})) + testrel.map(return_dataframe).df().equals(pandas.DataFrame({"A": [1]})) with pytest.raises( - duckdb.InvalidInputException, match='UDF returned more than 2048 rows, which is not allowed.' + duckdb.InvalidInputException, match="UDF returned more than 2048 rows, which is not allowed." 
): testrel.map(return_big_dataframe).df() - empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({'A': []})) + empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({"A": []})) - with pytest.raises(duckdb.InvalidInputException, match='No return value from Python function'): + with pytest.raises(duckdb.InvalidInputException, match="No return value from Python function"): testrel.map(return_none).df() - with pytest.raises(duckdb.InvalidInputException, match='Need a DataFrame with at least one column'): + with pytest.raises(duckdb.InvalidInputException, match="Need a DataFrame with at least one column"): testrel.map(return_empty_df).df() def test_map_with_object_column(self, duckdb_cursor): @@ -115,21 +115,21 @@ def return_with_no_modification(df): # when a dataframe with 'object' column is returned, we use the content to infer the type # when the dataframe is empty, this results in NULL, which is not desirable # in this case we assume the returned type should be the same as the input type - duckdb_cursor.values([b'1234']).map(return_with_no_modification).fetchall() + duckdb_cursor.values([b"1234"]).map(return_with_no_modification).fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_isse_3237(self, duckdb_cursor, pandas): def process(rel): def mapper(x): - dates = x['date'].to_numpy("datetime64[us]") - days = x['days_to_add'].to_numpy("int") + dates = x["date"].to_numpy("datetime64[us]") + days = x["days_to_add"].to_numpy("int") x["result1"] = pandas.Series( [pandas.to_datetime(y[0]).date() + timedelta(days=y[1].item()) for y in zip(dates, days)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) x["result2"] = pandas.Series( [pandas.to_datetime(y[0]).date() + timedelta(days=-y[1].item()) for y in zip(dates, days)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) return x @@ -140,22 +140,22 @@ def mapper(x): return rel df = pandas.DataFrame( - {'date': pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), 'days_to_add': [1, 2]} + {"date": pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), "days_to_add": [1, 2]} ) rel = duckdb.from_df(df) rel = process(rel) x = rel.fetchdf() - assert x['days_to_add'].to_numpy()[0] == 1 + assert x["days_to_add"].to_numpy()[0] == 1 def test_explicit_schema(self): def cast_to_string(df): - df['i'] = df['i'].astype(str) + df["i"] = df["i"].astype(str) return df con = duckdb.connect() - rel = con.sql('select i from range (10) tbl(i)') + rel = con.sql("select i from range (10) tbl(i)") assert rel.types[0] == duckdb.NUMBER - mapped_rel = rel.map(cast_to_string, schema={'i': str}) + mapped_rel = rel.map(cast_to_string, schema={"i": str}) assert mapped_rel.types[0] == duckdb.STRING def test_explicit_schema_returntype_mismatch(self): @@ -163,45 +163,45 @@ def does_nothing(df): return df con = duckdb.connect() - rel = con.sql('select i from range(10) tbl(i)') + rel = con.sql("select i from range(10) tbl(i)") # expects the mapper to return a string column - rel = rel.map(does_nothing, schema={'i': str}) + rel = rel.map(does_nothing, schema={"i": str}) with pytest.raises( duckdb.InvalidInputException, match=re.escape("UDF column type mismatch, expected [VARCHAR], got [BIGINT]") ): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_name_mismatch(self, pandas): def renames_column(df): - return 
pandas.DataFrame({'a': df['i']}) + return pandas.DataFrame({"a": df["i"]}) con = duckdb.connect() - rel = con.sql('select i from range(10) tbl(i)') - rel = rel.map(renames_column, schema={'i': int}) - with pytest.raises(duckdb.InvalidInputException, match=re.escape('UDF column name mismatch')): + rel = con.sql("select i from range(10) tbl(i)") + rel = rel.map(renames_column, schema={"i": int}) + with pytest.raises(duckdb.InvalidInputException, match=re.escape("UDF column name mismatch")): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_error(self, pandas): def no_op(df): return df con = duckdb.connect() - rel = con.sql('select 42') + rel = con.sql("select 42") with pytest.raises( duckdb.InvalidInputException, match=re.escape("Invalid Input Error: 'schema' should be given as a Dict[str, DuckDBType]"), ): rel.map(no_op, schema=[int]) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_returns_non_dataframe(self, pandas): def returns_series(df): - return df.loc[:, 'i'] + return df.loc[:, "i"] con = duckdb.connect() - rel = con.sql('select i, i as j from range(10) tbl(i)') + rel = con.sql("select i, i as j from range(10) tbl(i)") with pytest.raises( duckdb.InvalidInputException, match=re.escape( @@ -210,29 +210,29 @@ def returns_series(df): ): rel = rel.map(returns_series) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_columncount_mismatch(self, pandas): def returns_subset(df): - return pandas.DataFrame({'i': df.loc[:, 'i']}) + return pandas.DataFrame({"i": df.loc[:, "i"]}) con = duckdb.connect() - rel = con.sql('select i, i as j from range(10) tbl(i)') - rel = rel.map(returns_subset, schema={'i': int, 'j': int}) + rel = con.sql("select i, i as j from range(10) tbl(i)") + rel = rel.map(returns_subset, schema={"i": int, "j": int}) with pytest.raises( - duckdb.InvalidInputException, match='Invalid Input Error: Expected 2 columns from UDF, got 1' + duckdb.InvalidInputException, match="Invalid Input Error: Expected 2 columns from UDF, got 1" ): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_pyarrow_df(self, pandas): # PyArrow backed dataframes only exist on pandas >= 2.0.0 _ = pytest.importorskip("pandas", "2.0.0") def basic_function(df): # Create a pyarrow backed dataframe - df = pandas.DataFrame({'a': [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend='pyarrow') + df = pandas.DataFrame({"a": [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend="pyarrow") return df con = duckdb.connect() with pytest.raises(duckdb.InvalidInputException): - rel = con.sql('select 42').map(basic_function) + rel = con.sql("select 42").map(basic_function) diff --git a/tests/fast/test_metatransaction.py b/tests/fast/test_metatransaction.py index 158bb6a9..f617cba2 100644 --- a/tests/fast/test_metatransaction.py +++ b/tests/fast/test_metatransaction.py @@ -10,7 +10,7 @@ class TestMetaTransaction(object): def test_fetchmany(self, duckdb_cursor): duckdb_cursor.execute("CREATE SEQUENCE id_seq") - column_names = ',\n'.join([f'column_{i} FLOAT' for i in range(1, NUMBER_OF_COLUMNS + 1)]) + column_names = ",\n".join([f"column_{i} FLOAT" for i in range(1, NUMBER_OF_COLUMNS + 1)]) create_table_query = f""" CREATE TABLE my_table ( id INTEGER DEFAULT nextval('id_seq'), @@ -23,7 +23,7 @@ def test_fetchmany(self, 
duckdb_cursor): for i in range(20): # Then insert a large amount of tuples, triggering a parallel execution data = np.random.rand(NUMBER_OF_ROWS, NUMBER_OF_COLUMNS) - columns = [f'Column_{i+1}' for i in range(NUMBER_OF_COLUMNS)] + columns = [f"Column_{i + 1}" for i in range(NUMBER_OF_COLUMNS)] df = pd.DataFrame(data, columns=columns) df_columns = ", ".join(df.columns) # This gets executed in parallel, causing NextValFunction to be called in parallel diff --git a/tests/fast/test_multi_statement.py b/tests/fast/test_multi_statement.py index db82eaf3..722ab31a 100644 --- a/tests/fast/test_multi_statement.py +++ b/tests/fast/test_multi_statement.py @@ -7,36 +7,36 @@ class TestMultiStatement(object): def test_multi_statement(self, duckdb_cursor): import duckdb - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # test empty statement - con.execute('') + con.execute("") # run multiple statements in one call to execute con.execute( - ''' + """ CREATE TABLE integers(i integer); insert into integers select * from range(10); select * from integers; - ''' + """ ) results = [x[0] for x in con.fetchall()] assert results == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # test export/import - export_location = os.path.join(os.getcwd(), 'duckdb_pytest_dir_export') + export_location = os.path.join(os.getcwd(), "duckdb_pytest_dir_export") try: shutil.rmtree(export_location) except: pass - con.execute('CREATE TABLE integers2(i INTEGER)') - con.execute('INSERT INTO integers2 VALUES (1), (5), (7), (1928)') + con.execute("CREATE TABLE integers2(i INTEGER)") + con.execute("INSERT INTO integers2 VALUES (1), (5), (7), (1928)") con.execute("EXPORT DATABASE '%s'" % (export_location,)) # reset connection - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.execute("IMPORT DATABASE '%s'" % (export_location,)) - integers = [x[0] for x in con.execute('SELECT * FROM integers').fetchall()] - integers2 = [x[0] for x in con.execute('SELECT * FROM integers2').fetchall()] + integers = [x[0] for x in con.execute("SELECT * FROM integers").fetchall()] + integers2 = [x[0] for x in con.execute("SELECT * FROM integers2").fetchall()] assert integers == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] assert integers2 == [1, 5, 7, 1928] shutil.rmtree(export_location) diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index ad2d56fd..628aacd8 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -16,7 +16,7 @@ def connect_duck(duckdb_conn): - out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() + out = duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchall() assert out == [(42,), (84,), (None,), (128,)] @@ -39,7 +39,7 @@ def multithread_test(self, result_verification=everything_succeeded): for i in range(0, self.duckdb_insert_thread_count): self.threads.append( threading.Thread( - target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name='duckdb_thread_' + str(i) + target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name="duckdb_thread_" + str(i) ) ) @@ -60,7 +60,7 @@ def multithread_test(self, result_verification=everything_succeeded): def execute_query_same_connection(duckdb_conn, queue, pandas): try: - out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)') + out = duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(False) except: queue.put(True) @@ -70,7 +70,7 @@ def execute_query(duckdb_conn, 
queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)') + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(True) except: queue.put(False) @@ -80,7 +80,7 @@ def insert_runtime_error(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") queue.put(False) except: queue.put(True) @@ -104,9 +104,9 @@ def execute_many_query(duckdb_conn, queue, pandas): ) # Larger example that inserts many records at a time purchases = [ - ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ("2006-03-28", "BUY", "IBM", 1000, 45.00), + ("2006-04-05", "BUY", "MSFT", 1000, 72.00), + ("2006-04-06", "SELL", "IBM", 500, 53.00), ] duckdb_conn.executemany( """ @@ -123,7 +123,7 @@ def fetchone_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchone() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchone() queue.put(True) except: queue.put(False) @@ -133,7 +133,7 @@ def fetchall_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchall() queue.put(True) except: queue.put(False) @@ -153,7 +153,7 @@ def fetchnp_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchnumpy() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchnumpy() queue.put(True) except: queue.put(False) @@ -163,7 +163,7 @@ def fetchdf_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchdf() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchdf() queue.put(True) except: queue.put(False) @@ -173,7 +173,7 @@ def fetchdf_chunk_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_df_chunk() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_df_chunk() queue.put(True) except: queue.put(False) @@ -183,7 +183,7 @@ def fetch_arrow_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_arrow_table() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_arrow_table() queue.put(True) except: queue.put(False) @@ -193,7 +193,7 @@ def fetch_record_batch_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_record_batch() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_record_batch() queue.put(True) except: 
queue.put(False) @@ -205,9 +205,9 @@ def transaction_query(duckdb_conn, queue, pandas): duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") try: duckdb_conn.begin() - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") duckdb_conn.rollback() - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") duckdb_conn.commit() queue.put(True) except: @@ -218,9 +218,9 @@ def df_append(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.append('T', df) + duckdb_conn.append("T", df) queue.put(True) except: queue.put(False) @@ -229,9 +229,9 @@ def df_append(duckdb_conn, queue, pandas): def df_register(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.register('T', df) + duckdb_conn.register("T", df) queue.put(True) except: queue.put(False) @@ -240,10 +240,10 @@ def df_register(duckdb_conn, queue, pandas): def df_unregister(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.register('T', df) - duckdb_conn.unregister('T') + duckdb_conn.register("T", df) + duckdb_conn.unregister("T") queue.put(True) except: queue.put(False) @@ -251,12 +251,12 @@ def df_unregister(duckdb_conn, queue, pandas): def arrow_register_unregister(duckdb_conn, queue, pandas): # Get a new connection - pa = pytest.importorskip('pyarrow') + pa = pytest.importorskip("pyarrow") duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({"my_column": pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: - duckdb_conn.register('T', arrow_tbl) - duckdb_conn.unregister('T') + duckdb_conn.register("T", arrow_tbl) + duckdb_conn.unregister("T") queue.put(True) except: queue.put(False) @@ -267,7 +267,7 @@ def table(duckdb_conn, queue, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") try: - out = duckdb_conn.table('T') + out = duckdb_conn.table("T") queue.put(True) except: queue.put(False) @@ -279,7 +279,7 @@ def view(duckdb_conn, queue, pandas): duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") duckdb_conn.execute("CREATE VIEW V as (SELECT * FROM T)") try: - out = duckdb_conn.values([5, 'five']) + out = duckdb_conn.values([5, "five"]) queue.put(True) except: queue.put(False) @@ -289,7 +289,7 @@ def values(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - out = duckdb_conn.values([5, 'five']) + out = duckdb_conn.values([5, "five"]) queue.put(True) except: queue.put(False) @@ -308,7 +308,7 @@ def from_query(duckdb_conn, queue, pandas): def from_df(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(['bla', 'blabla'] * 10, columns=['A']) + df = pandas.DataFrame(["bla", "blabla"] * 10, columns=["A"]) 
try: out = duckdb_conn.execute("select * from df").fetchall() queue.put(True) @@ -318,9 +318,9 @@ def from_df(duckdb_conn, queue, pandas): def from_arrow(duckdb_conn, queue, pandas): # Get a new connection - pa = pytest.importorskip('pyarrow') + pa = pytest.importorskip("pyarrow") duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({"my_column": pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: out = duckdb_conn.from_arrow(arrow_tbl) queue.put(True) @@ -331,7 +331,7 @@ def from_arrow(duckdb_conn, queue, pandas): def from_csv_auto(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "integers.csv") try: out = duckdb_conn.from_csv_auto(filename) queue.put(True) @@ -342,7 +342,7 @@ def from_csv_auto(duckdb_conn, queue, pandas): def from_parquet(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "binary_string.parquet") try: out = duckdb_conn.from_parquet(filename) queue.put(True) @@ -353,7 +353,7 @@ def from_parquet(duckdb_conn, queue, pandas): def description(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - duckdb_conn.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + duckdb_conn.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") duckdb_conn.execute("INSERT INTO test VALUES (TRUE, '01:01:01', 'bla' )") rel = duckdb_conn.table("test") rel.execute() @@ -368,138 +368,138 @@ def cursor(duckdb_conn, queue, pandas): # Get a new connection cx = duckdb_conn.cursor() try: - cx.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + cx.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") queue.put(False) except: queue.put(True) class TestDuckMultithread(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute_many(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_many_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchone(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchone_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchall(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchall_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_close(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, conn_close, pandas) duck_threads.multithread_test() - 
@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchnp(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchnp_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchdf(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchdf_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchdfchunk(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchdf_chunk_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetcharrow(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, fetch_arrow_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetch_record_batch(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, fetch_record_batch_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_transaction(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, transaction_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_append(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_append, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_register(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_register, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_unregister(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_unregister, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_arrow_register_unregister(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, arrow_register_unregister, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_table(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, table, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_view(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, view, pandas) duck_threads.multithread_test() - 
@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_values(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, values, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_query(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_DF(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_df, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, from_arrow, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_csv_auto(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_csv_auto, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_parquet(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_parquet, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_description(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, description, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_cursor(self, duckdb_cursor, pandas): def only_some_succeed(results: list[bool]): if not any([result == True for result in results]): diff --git a/tests/fast/test_non_default_conn.py b/tests/fast/test_non_default_conn.py index bc9fa5f0..cb0218e3 100644 --- a/tests/fast/test_non_default_conn.py +++ b/tests/fast/test_non_default_conn.py @@ -24,7 +24,7 @@ def test_from_csv(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_csv(temp_file_name, index=False) rel = duckdb_cursor.from_csv_auto(temp_file_name) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_parquet(self, duckdb_cursor): try: @@ -37,16 +37,16 @@ def test_from_parquet(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_parquet(temp_file_name, index=False) rel = duckdb_cursor.from_parquet(temp_file_name) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.df(test_df, 
connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) rel = duckdb_cursor.from_df(test_df) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_arrow(self, duckdb_cursor): try: @@ -59,55 +59,55 @@ def test_from_arrow(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_arrow = pa.Table.from_pandas(test_df) rel = duckdb_cursor.from_arrow(test_arrow) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) rel = duckdb.arrow(test_arrow, connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_filter_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.filter(test_df, "i < 2", connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_project_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) rel = duckdb.project(test_df, "i", connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_agg_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) rel = duckdb.aggregate(test_df, "count(*) as i", connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (4, 4) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (4, 4) def test_distinct_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 1, 2, 3, 4]}) rel = duckdb.distinct(test_df, connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_limit_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.limit(test_df, 1, connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_query_df(self, duckdb_cursor): 
duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) - rel = duckdb.query_df(test_df, 't_2', 'select * from t inner join t_2 on (a = i)', connection=duckdb_cursor) + rel = duckdb.query_df(test_df, "t_2", "select * from t inner join t_2 on (a = i)", connection=duckdb_cursor) assert rel.fetchall()[0] == (1, 1) def test_query_order(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) - rel = duckdb.order(test_df, 'i', connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + rel = duckdb.order(test_df, "i", connection=duckdb_cursor) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) diff --git a/tests/fast/test_parameter_list.py b/tests/fast/test_parameter_list.py index 032b1b9c..5a85ac2f 100644 --- a/tests/fast/test_parameter_list.py +++ b/tests/fast/test_parameter_list.py @@ -11,22 +11,22 @@ def test_bool(self, duckdb_cursor): res = conn.execute("select count(*) from bool_table where a =?", [True]) assert res.fetchone()[0] == 1 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_exception(self, duckdb_cursor, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create table bool_table (a bool)") conn.execute("insert into bool_table values (TRUE)") - with pytest.raises(duckdb.NotImplementedException, match='Unable to transform'): + with pytest.raises(duckdb.NotImplementedException, match="Unable to transform"): res = conn.execute("select count(*) from bool_table where a =?", [df_in]) def test_explicit_nan_param(self): con = duckdb.default_connection() - res = con.execute('select isnan(cast(? as double))', (float("nan"),)) + res = con.execute("select isnan(cast(? 
as double))", (float("nan"),)) assert res.fetchone()[0] == True def test_string_parameter(self, duckdb_cursor): diff --git a/tests/fast/test_parquet.py b/tests/fast/test_parquet.py index 51d8d276..61d74023 100644 --- a/tests/fast/test_parquet.py +++ b/tests/fast/test_parquet.py @@ -7,13 +7,13 @@ VARCHAR = duckdb.typing.VARCHAR BIGINT = duckdb.typing.BIGINT -filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') +filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "binary_string.parquet") @pytest.fixture(scope="session") def tmp_parquets(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp('parquets', numbered=True) - tmp_parquets = [str(tmp_dir / ('tmp' + str(i) + '.parquet')) for i in range(1, 4)] + tmp_dir = tmp_path_factory.mktemp("parquets", numbered=True) + tmp_parquets = [str(tmp_dir / ("tmp" + str(i) + ".parquet")) for i in range(1, 4)] return tmp_parquets @@ -21,34 +21,34 @@ class TestParquet(object): def test_scan_binary(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_from_parquet_binary(self, duckdb_cursor): rel = duckdb.from_parquet(filename) - assert rel.types == ['BLOB'] + assert rel.types == ["BLOB"] res = rel.execute().fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_scan_binary_as_string(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute( "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=True) limit 1" ).fetchall() - assert res[0] == ('VARCHAR',) + assert res[0] == ("VARCHAR",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "',binary_as_string=True)").fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_binary_as_string(self, duckdb_cursor): rel = duckdb.from_parquet(filename, True) assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_file_row_number(self, duckdb_cursor): rel = duckdb.from_parquet(filename, binary_as_string=True, file_row_number=True) @@ -56,7 +56,7 @@ def test_from_parquet_file_row_number(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", 0, ) @@ -66,7 +66,7 @@ def test_from_parquet_filename(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", filename, ) @@ -75,7 +75,7 @@ def test_from_parquet_list_binary_as_string(self, duckdb_cursor): assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_list_file_row_number(self, duckdb_cursor): rel = duckdb.from_parquet([filename], binary_as_string=True, file_row_number=True) @@ -83,7 +83,7 @@ def test_from_parquet_list_file_row_number(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", 0, ) @@ -93,41 +93,41 @@ def test_from_parquet_list_filename(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", filename, ) def test_parquet_binary_as_string_pragma(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) 
+ assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) conn.execute("PRAGMA binary_as_string=1") res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('VARCHAR',) + assert res[0] == ("VARCHAR",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) res = conn.execute( "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=False) limit 1" ).fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "',binary_as_string=False)").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) conn.execute("PRAGMA binary_as_string=0") res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): duckdb.execute("PRAGMA binary_as_string=1") @@ -136,7 +136,7 @@ def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_union_by_name(self, tmp_parquets): conn = duckdb.connect() @@ -159,7 +159,7 @@ def test_from_parquet_union_by_name(self, tmp_parquets): + "' (format 'parquet');" ) - rel = duckdb.from_parquet(tmp_parquets, union_by_name=True).order('a') + rel = duckdb.from_parquet(tmp_parquets, union_by_name=True).order("a") assert rel.execute().fetchall() == [ ( 1, diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index 6e1460e2..84d4c9ff 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -15,51 +15,61 @@ duckdb_packaging = pytest.importorskip("duckdb_packaging") from duckdb_packaging.pypi_cleanup import ( - PyPICleanup, CsrfParser, PyPICleanupError, AuthenticationError, ValidationError, - setup_logging, validate_username, create_argument_parser, session_with_retries, - load_credentials, validate_arguments, main + PyPICleanup, + CsrfParser, + PyPICleanupError, + AuthenticationError, + ValidationError, + setup_logging, + validate_username, + create_argument_parser, + session_with_retries, + load_credentials, + validate_arguments, + main, ) + class TestValidation: """Test input validation functions.""" - + def test_validate_username_valid(self): """Test valid usernames.""" assert validate_username("user123") == "user123" assert validate_username(" user.name ") == "user.name" assert validate_username("test-user_name") == "test-user_name" assert validate_username("a") == "a" - + def test_validate_username_invalid(self): """Test invalid usernames.""" from argparse import ArgumentTypeError - + with pytest.raises(ArgumentTypeError, match="cannot be empty"): validate_username("") - + with pytest.raises(ArgumentTypeError, match="cannot be empty"): validate_username(" ") - + with pytest.raises(ArgumentTypeError, match="too long"): validate_username("a" * 101) - + with pytest.raises(ArgumentTypeError, match="Invalid username format"): validate_username("-invalid") - + with pytest.raises(ArgumentTypeError, match="Invalid username format"): 
validate_username("invalid-") - + def test_validate_arguments_dry_run(self): """Test argument validation for dry run mode.""" args = Mock(dry_run=True, username=None, max_nightlies=2) validate_arguments(args) # Should not raise - + def test_validate_arguments_live_mode_no_username(self): """Test argument validation for live mode without username.""" args = Mock(dry_run=False, username=None, max_nightlies=2) with pytest.raises(ValidationError, match="username is required"): validate_arguments(args) - + def test_validate_arguments_negative_nightlies(self): """Test argument validation with negative max nightlies.""" args = Mock(dry_run=True, username="test", max_nightlies=-1) @@ -69,27 +79,27 @@ def test_validate_arguments_negative_nightlies(self): class TestCredentials: """Test credential loading.""" - + def test_load_credentials_dry_run(self): """Test credential loading in dry run mode.""" password, otp = load_credentials(dry_run=True) assert password is None assert otp is None - - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass', 'PYPI_CLEANUP_OTP': 'test_otp'}) + + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test_pass", "PYPI_CLEANUP_OTP": "test_otp"}) def test_load_credentials_live_mode_success(self): """Test successful credential loading in live mode.""" password, otp = load_credentials(dry_run=False) - assert password == 'test_pass' - assert otp == 'test_otp' - + assert password == "test_pass" + assert otp == "test_otp" + @patch.dict(os.environ, {}, clear=True) def test_load_credentials_missing_password(self): """Test credential loading with missing password.""" with pytest.raises(ValidationError, match="PYPI_CLEANUP_PASSWORD"): load_credentials(dry_run=False) - - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass'}) + + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test_pass"}) def test_load_credentials_missing_otp(self): """Test credential loading with missing OTP.""" with pytest.raises(ValidationError, match="PYPI_CLEANUP_OTP"): @@ -105,56 +115,56 @@ def test_create_session_with_retries(self): assert isinstance(session, requests.Session) # Verify retry adapter is mounted adapter = session.get_adapter("https://example.com") - assert hasattr(adapter, 'max_retries') - retries = getattr(adapter, 'max_retries') + assert hasattr(adapter, "max_retries") + retries = getattr(adapter, "max_retries") assert isinstance(retries, Retry) - @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") def test_setup_logging_normal(self, mock_basicConfig): """Test logging setup in normal mode.""" setup_logging(verbose=False) mock_basicConfig.assert_called_once() call_args = mock_basicConfig.call_args[1] - assert call_args['level'] == 20 # INFO level + assert call_args["level"] == 20 # INFO level - @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") def test_setup_logging_verbose(self, mock_basicConfig): """Test logging setup in verbose mode.""" setup_logging(verbose=True) mock_basicConfig.assert_called_once() call_args = mock_basicConfig.call_args[1] - assert call_args['level'] == 10 # DEBUG level + assert call_args["level"] == 10 # DEBUG level class TestCsrfParser: """Test CSRF token parser.""" - + def test_csrf_parser_simple_form(self): """Test parsing CSRF token from simple form.""" - html = ''' + html = """
      - ''' + """ parser = CsrfParser("/test") parser.feed(html) assert parser.csrf == "abc123" - + def test_csrf_parser_multiple_forms(self): """Test parsing CSRF token when multiple forms exist.""" - html = ''' + html = """
      - ''' + """ parser = CsrfParser("/test") parser.feed(html) assert parser.csrf == "correct" - + def test_csrf_parser_no_token(self): """Test parser when no CSRF token is found.""" html = '
      ' @@ -165,6 +175,7 @@ def test_csrf_parser_no_token(self): class TestPyPICleanup: """Test the main PyPICleanup class.""" + @pytest.fixture def cleanup_dryrun_max_2(self) -> PyPICleanup: return PyPICleanup("https://test.pypi.org/", False, 2) @@ -175,26 +186,59 @@ def cleanup_dryrun_max_0(self) -> PyPICleanup: @pytest.fixture def cleanup_max_2(self) -> PyPICleanup: - return PyPICleanup("https://test.pypi.org/", True, 2, - username="", password="", otp="") + return PyPICleanup("https://test.pypi.org/", True, 2, username="", password="", otp="") def test_determine_versions_to_delete_max_2(self, cleanup_dryrun_max_2): start_state = { "0.1.0", - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", - "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.0", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.0.1", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.0", + "1.1.0.post1", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.0", + "2.0.1.dev974", + "2.0.1.rc1", + "2.0.1.rc2", + "2.0.1.rc3", } expected_deletions = { - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", "1.1.1.dev142", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", - "2.0.1.dev974" + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.1.dev974", } versions_to_delete = cleanup_dryrun_max_2._determine_versions_to_delete(start_state) assert versions_to_delete == expected_deletions @@ -202,35 +246,82 @@ def test_determine_versions_to_delete_max_2(self, cleanup_dryrun_max_2): def test_determine_versions_to_delete_max_0(self, cleanup_dryrun_max_0): start_state = { "0.1.0", - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", - "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.0", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.0.1", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.0", + "1.1.0.post1", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.0", + "2.0.1.dev974", + "2.0.1.rc1", + "2.0.1.rc2", + "2.0.1.rc3", } expected_deletions = { - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", - "2.0.1.dev974" + "1.0.0.dev1", + 
"1.0.0.dev2", + "1.0.0.rc1", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.1.dev974", } versions_to_delete = cleanup_dryrun_max_0._determine_versions_to_delete(start_state) assert versions_to_delete == expected_deletions def test_determine_versions_to_delete_only_devs_max_2(self, cleanup_dryrun_max_2): start_state = { - "1.0.0.dev1", "1.0.0.dev2", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", "2.0.0.dev602", "2.0.1.dev974", } expected_deletions = { - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", "1.1.0.dev34", "1.1.1.dev142", } @@ -239,19 +330,28 @@ def test_determine_versions_to_delete_only_devs_max_2(self, cleanup_dryrun_max_2 def test_determine_versions_to_delete_only_devs_max_0_fails(self, cleanup_dryrun_max_0): start_state = { - "1.0.0.dev1", "1.0.0.dev2", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", "2.0.0.dev602", "2.0.1.dev974", } with pytest.raises(PyPICleanupError, match="Safety check failed"): cleanup_dryrun_max_0._determine_versions_to_delete(start_state) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_versions') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._determine_versions_to_delete') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_versions") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._determine_versions_to_delete") def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, cleanup_dryrun_max_2): mock_fetch.return_value = {"1.0.0.dev1"} mock_determine.return_value = {"1.0.0.dev1"} @@ -264,14 +364,14 @@ def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, mock_determine.assert_called_once() mock_delete.assert_not_called() - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions") def test_execute_cleanup_no_releases(self, mock_fetch, cleanup_dryrun_max_2): mock_fetch.return_value = {} with session_with_retries() as session: result = cleanup_dryrun_max_2._execute_cleanup(session) assert result == 0 - @patch('requests.Session.get') + @patch("requests.Session.get") def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): """Test successful package release fetching.""" mock_response = Mock() @@ -288,7 +388,7 @@ def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): assert releases == {"1.0.0", "1.0.0.dev1"} - @patch('requests.Session.get') + 
@patch("requests.Session.get") def test_fetch_released_versions_not_found(self, mock_get, cleanup_dryrun_max_2): """Test package release fetching when package not found.""" mock_response = Mock() @@ -299,8 +399,8 @@ def test_fetch_released_versions_not_found(self, mock_get, cleanup_dryrun_max_2) with session_with_retries() as session: cleanup_dryrun_max_2._fetch_released_versions(session) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('requests.Session.post') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token") + @patch("requests.Session.post") def test_authenticate_success(self, mock_post, mock_csrf, cleanup_max_2): """Test successful authentication.""" mock_csrf.return_value = "csrf123" @@ -313,11 +413,11 @@ def test_authenticate_success(self, mock_post, mock_csrf, cleanup_max_2): mock_csrf.assert_called_once_with(session, "/account/login/") mock_post.assert_called_once() - assert mock_post.call_args.args[0].endswith('/account/login/') + assert mock_post.call_args.args[0].endswith("/account/login/") - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('requests.Session.post') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token") + @patch("requests.Session.post") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth") def test_authenticate_with_2fa(self, mock_2fa, mock_post, mock_csrf, cleanup_max_2): mock_csrf.return_value = "csrf123" mock_response = Mock() @@ -332,7 +432,7 @@ def test_authenticate_missing_credentials(self, cleanup_dryrun_max_2): with pytest.raises(AuthenticationError, match="Username and password are required"): cleanup_dryrun_max_2._authenticate(None) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version") def test_delete_versions_success(self, mock_delete, cleanup_max_2): """Test successful version deletion.""" versions = {"1.0.0.dev1", "1.0.0.dev2"} @@ -343,7 +443,7 @@ def test_delete_versions_success(self, mock_delete, cleanup_max_2): assert mock_delete.call_count == 2 - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version") def test_delete_versions_partial_failure(self, mock_delete, cleanup_max_2): """Test version deletion with partial failures.""" versions = {"1.0.0.dev1", "1.0.0.dev2"} @@ -360,75 +460,75 @@ def test_delete_single_version_safety_check(self, cleanup_max_2): class TestArgumentParser: """Test command line argument parsing.""" - + def test_argument_parser_creation(self): """Test argument parser creation.""" parser = create_argument_parser() assert parser.prog is not None - + def test_parse_args_prod_dry_run(self): """Test parsing arguments for production dry run.""" parser = create_argument_parser() - args = parser.parse_args(['--prod', '--dry-run']) - + args = parser.parse_args(["--prod", "--dry-run"]) + assert args.prod is True assert args.test is False assert args.dry_run is True assert args.max_nightlies == 2 assert args.verbose is False - + def test_parse_args_test_with_username(self): """Test parsing arguments for test with username.""" parser = create_argument_parser() - args = parser.parse_args(['--test', '-u', 'testuser', '--verbose']) - + args = parser.parse_args(["--test", "-u", "testuser", "--verbose"]) + assert args.test is True assert args.prod is False - 
assert args.username == 'testuser' + assert args.username == "testuser" assert args.verbose is True - + def test_parse_args_missing_host(self): """Test parsing arguments with missing host selection.""" parser = create_argument_parser() - + with pytest.raises(SystemExit): - parser.parse_args(['--dry-run']) # Missing --prod or --test + parser.parse_args(["--dry-run"]) # Missing --prod or --test class TestMainFunction: """Test the main function.""" - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup') - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test', 'PYPI_CLEANUP_OTP': 'test'}) + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup") + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test", "PYPI_CLEANUP_OTP": "test"}) def test_main_success(self, mock_cleanup_class, mock_setup_logging): """Test successful main function execution.""" mock_cleanup = Mock() mock_cleanup.run.return_value = 0 mock_cleanup_class.return_value = mock_cleanup - - with patch('sys.argv', ['pypi_cleanup.py', '--test', '-u', 'testuser']): + + with patch("sys.argv", ["pypi_cleanup.py", "--test", "-u", "testuser"]): result = main() - + assert result == 0 mock_setup_logging.assert_called_once() mock_cleanup.run.assert_called_once() - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") def test_main_validation_error(self, mock_setup_logging): """Test main function with validation error.""" - with patch('sys.argv', ['pypi_cleanup.py', '--test']): # Missing username for live mode + with patch("sys.argv", ["pypi_cleanup.py", "--test"]): # Missing username for live mode result = main() - + assert result == 2 # Validation error exit code - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') - @patch('duckdb_packaging.pypi_cleanup.validate_arguments') + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") + @patch("duckdb_packaging.pypi_cleanup.validate_arguments") def test_main_keyboard_interrupt(self, mock_validate, mock_setup_logging): """Test main function with keyboard interrupt.""" mock_validate.side_effect = KeyboardInterrupt() - - with patch('sys.argv', ['pypi_cleanup.py', '--test', '--dry-run']): + + with patch("sys.argv", ["pypi_cleanup.py", "--test", "--dry-run"]): result = main() - + assert result == 130 # Keyboard interrupt exit code diff --git a/tests/fast/test_pytorch.py b/tests/fast/test_pytorch.py index 365585cc..c5b9b4d6 100644 --- a/tests/fast/test_pytorch.py +++ b/tests/fast/test_pytorch.py @@ -2,7 +2,7 @@ import pytest -torch = pytest.importorskip('torch') +torch = pytest.importorskip("torch") @pytest.mark.skip(reason="some issues with Numpy, to be reverted") @@ -15,16 +15,16 @@ def test_pytorch(): # Test from connection duck_torch = con.execute("select * from t").torch() duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Test from relation duck_torch = con.sql("select * from t").torch() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Test all Numeric Types - numeric_types = 
['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + numeric_types = ["TINYINT", "SMALLINT", "BIGINT", "HUGEINT", "FLOAT", "DOUBLE", "DECIMAL(4,1)", "UTINYINT"] for supported_type in numeric_types: con = duckdb.connect() @@ -32,8 +32,8 @@ def test_pytorch(): con.execute("insert into t values (1,2), (3,4)") duck_torch = con.sql("select * from t").torch() duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Comment out test that might fail or not depending on pytorch versions # with pytest.raises(TypeError, match="can't convert"): diff --git a/tests/fast/test_relation.py b/tests/fast/test_relation.py index 8e68c149..31ca393c 100644 --- a/tests/fast/test_relation.py +++ b/tests/fast/test_relation.py @@ -37,10 +37,10 @@ def test_csv_auto(self): csv_rel = duckdb.from_csv_auto(temp_file_name) assert df_rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_view(self, duckdb_cursor, pandas): def create_view(duckdb_cursor): - df_in = pandas.DataFrame({'numbers': [1, 2, 3, 4, 5]}) + df_in = pandas.DataFrame({"numbers": [1, 2, 3, 4, 5]}) rel = duckdb_cursor.query("select * from df_in") rel.to_view("my_view") @@ -59,23 +59,23 @@ def create_view(duckdb_cursor): def test_filter_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.filter('i > 1').execute().fetchall() == [(2, 'two'), (3, 'three'), (4, 'four')] + assert rel.filter("i > 1").execute().fetchall() == [(2, "two"), (3, "three"), (4, "four")] def test_projection_operator_single(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.project('i').execute().fetchall() == [(1,), (2,), (3,), (4,)] + assert rel.project("i").execute().fetchall() == [(1,), (2,), (3,), (4,)] def test_projection_operator_double(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.order('j').execute().fetchall() == [(4, 'four'), (1, 'one'), (3, 'three'), (2, 'two')] + assert rel.order("j").execute().fetchall() == [(4, "four"), (1, "one"), (3, "three"), (2, "two")] def test_limit_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.limit(2).execute().fetchall() == [(1, 'one'), (2, 'two')] - assert rel.limit(2, offset=1).execute().fetchall() == [(2, 'two'), (3, 'three')] + assert rel.limit(2).execute().fetchall() == [(1, "one"), (2, "two")] + assert rel.limit(2, offset=1).execute().fetchall() == [(2, "two"), (3, "three")] def test_intersect_operator(self): conn = duckdb.connect() @@ -86,23 +86,23 @@ def test_intersect_operator(self): rel = conn.from_df(test_df) rel_2 = conn.from_df(test_df_2) - assert rel.intersect(rel_2).order('i').execute().fetchall() == [(3,), (4,)] + assert rel.intersect(rel_2).order("i").execute().fetchall() == [(3,), (4,)] def test_aggregate_operator(self): conn = duckdb.connect() rel = get_relation(conn) assert rel.aggregate("sum(i)").execute().fetchall() == [(10,)] - assert rel.aggregate("j, sum(i)").order('#2').execute().fetchall() == [ - ('one', 1), - ('two', 2), - ('three', 3), - ('four', 4), + assert rel.aggregate("j, sum(i)").order("#2").execute().fetchall() == [ + ("one", 1), + ("two", 2), + ("three", 3), + ("four", 4), ] def 
test_relation_fetch_df_chunk(self, duckdb_cursor): duckdb_cursor.execute(f"create table tbl as select * from range({duckdb.__standard_vector_size__ * 3})") - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") # default arguments df1 = rel.fetch_df_chunk() assert len(df1) == duckdb.__standard_vector_size__ @@ -114,40 +114,40 @@ def test_relation_fetch_df_chunk(self, duckdb_cursor): f"create table dates as select (DATE '2021/02/21' + INTERVAL (i) DAYS)::DATE a from range({duckdb.__standard_vector_size__ * 4}) t(i)" ) - rel = duckdb_cursor.table('dates') + rel = duckdb_cursor.table("dates") # default arguments df1 = rel.fetch_df_chunk() assert len(df1) == duckdb.__standard_vector_size__ - assert df1['a'][0].__class__ == pd.Timestamp + assert df1["a"][0].__class__ == pd.Timestamp # date as object df1 = rel.fetch_df_chunk(date_as_object=True) assert len(df1) == duckdb.__standard_vector_size__ - assert df1['a'][0].__class__ == datetime.date + assert df1["a"][0].__class__ == datetime.date # vectors and date as object df1 = rel.fetch_df_chunk(2, date_as_object=True) assert len(df1) == duckdb.__standard_vector_size__ * 2 - assert df1['a'][0].__class__ == datetime.date + assert df1["a"][0].__class__ == datetime.date def test_distinct_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.distinct().order('all').execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + assert rel.distinct().order("all").execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_union_operator(self): conn = duckdb.connect() rel = get_relation(conn) print(rel.union(rel).execute().fetchall()) assert rel.union(rel).execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_join_operator(self): @@ -156,11 +156,11 @@ def test_join_operator(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel2 = conn.from_df(test_df) - assert rel.join(rel2, 'i').execute().fetchall() == [ - (1, 'one', 'one'), - (2, 'two', 'two'), - (3, 'three', 'three'), - (4, 'four', 'four'), + assert rel.join(rel2, "i").execute().fetchall() == [ + (1, "one", "one"), + (2, "two", "two"), + (3, "three", "three"), + (4, "four", "four"), ] def test_except_operator(self): @@ -176,10 +176,10 @@ def test_create_operator(self): rel = conn.from_df(test_df) rel.create("test_df") assert conn.query("select * from test_df").execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_create_view_operator(self): @@ -188,31 +188,31 @@ def test_create_view_operator(self): rel = conn.from_df(test_df) rel.create_view("test_df") assert conn.query("select * from test_df").execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_update_relation(self, duckdb_cursor): duckdb_cursor.sql("create table tbl (a varchar default 'test', b int)") - duckdb_cursor.table('tbl').insert(['hello', 21]) - duckdb_cursor.table('tbl').insert(['hello', 42]) + duckdb_cursor.table("tbl").insert(["hello", 21]) + duckdb_cursor.table("tbl").insert(["hello", 42]) # UPDATE tbl SET a = DEFAULT where b = 42 - duckdb_cursor.table('tbl').update( - 
{'a': duckdb.DefaultExpression()}, condition=duckdb.ColumnExpression('b') == 42 + duckdb_cursor.table("tbl").update( + {"a": duckdb.DefaultExpression()}, condition=duckdb.ColumnExpression("b") == 42 ) - assert duckdb_cursor.table('tbl').fetchall() == [('hello', 21), ('test', 42)] + assert duckdb_cursor.table("tbl").fetchall() == [("hello", 21), ("test", 42)] - rel = duckdb_cursor.table('tbl') - with pytest.raises(duckdb.InvalidInputException, match='Please provide at least one set expression'): + rel = duckdb_cursor.table("tbl") + with pytest.raises(duckdb.InvalidInputException, match="Please provide at least one set expression"): rel.update({}) with pytest.raises( - duckdb.InvalidInputException, match='Please provide the column name as the key of the dictionary' + duckdb.InvalidInputException, match="Please provide the column name as the key of the dictionary" ): rel.update({1: 21}) - with pytest.raises(duckdb.BinderException, match='Referenced update column c not found in table!'): - rel.update({'c': 21}) + with pytest.raises(duckdb.BinderException, match="Referenced update column c not found in table!"): + rel.update({"c": 21}) with pytest.raises( duckdb.InvalidInputException, match="Please provide 'set' as a dictionary of column name to Expression" ): @@ -221,11 +221,11 @@ def test_update_relation(self, duckdb_cursor): duckdb.InvalidInputException, match="Please provide an object of type Expression as the value, not ", ): - rel.update({'a': {21}}) + rel.update({"a": {21}}) def test_value_relation(self, duckdb_cursor): # Needs at least one input - with pytest.raises(duckdb.InvalidInputException, match='Could not create a ValueRelation without any inputs'): + with pytest.raises(duckdb.InvalidInputException, match="Could not create a ValueRelation without any inputs"): duckdb_cursor.values() # From a list of (python) values @@ -233,28 +233,28 @@ def test_value_relation(self, duckdb_cursor): assert rel.fetchall() == [(1, 2, 3)] # From an Expression - rel = duckdb_cursor.values(duckdb.ConstantExpression('test')) - assert rel.fetchall() == [('test',)] + rel = duckdb_cursor.values(duckdb.ConstantExpression("test")) + assert rel.fetchall() == [("test",)] # From multiple Expressions rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), duckdb.ConstantExpression('2'), duckdb.ConstantExpression('3') + duckdb.ConstantExpression("1"), duckdb.ConstantExpression("2"), duckdb.ConstantExpression("3") ) - assert rel.fetchall() == [('1', '2', '3')] + assert rel.fetchall() == [("1", "2", "3")] # From Expressions mixed with random values - with pytest.raises(duckdb.InvalidInputException, match='Please provide arguments of type Expression!'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide arguments of type Expression!"): rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), - {'test'}, - duckdb.ConstantExpression('3'), + duckdb.ConstantExpression("1"), + {"test"}, + duckdb.ConstantExpression("3"), ) # From Expressions mixed with values that *can* be autocast to Expression rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), + duckdb.ConstantExpression("1"), 2, - duckdb.ConstantExpression('3'), + duckdb.ConstantExpression("3"), ) const = duckdb.ConstantExpression @@ -264,21 +264,21 @@ def test_value_relation(self, duckdb_cursor): # From mismatching tuples of Expressions with pytest.raises( - duckdb.InvalidInputException, match='Mismatch between length of tuples in input, expected 3 but found 2' + duckdb.InvalidInputException, match="Mismatch between length 
of tuples in input, expected 3 but found 2" ): rel = duckdb_cursor.values((const(1), const(2), const(3)), (const(5), const(4))) # From an empty tuple - with pytest.raises(duckdb.InvalidInputException, match='Please provide a non-empty tuple'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide a non-empty tuple"): rel = duckdb_cursor.values(()) # Mixing tuples with Expressions - with pytest.raises(duckdb.InvalidInputException, match='Expected objects of type tuple'): + with pytest.raises(duckdb.InvalidInputException, match="Expected objects of type tuple"): rel = duckdb_cursor.values((const(1), const(2), const(3)), const(4)) # Using Expressions that can't be resolved: with pytest.raises(duckdb.BinderException, match='Referenced column "a" not found in FROM clause!'): - duckdb_cursor.values(duckdb.ColumnExpression('a')) + duckdb_cursor.values(duckdb.ColumnExpression("a")) def test_insert_into_operator(self): conn = duckdb.connect() @@ -290,17 +290,17 @@ def test_insert_into_operator(self): rel.insert_into("test_table3") # Inserting elements into table_3 - print(conn.values([5, 'five']).insert_into("test_table3")) + print(conn.values([5, "five"]).insert_into("test_table3")) rel_3 = conn.table("test_table3") - rel_3.insert([6, 'six']) + rel_3.insert([6, "six"]) assert rel_3.execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), - (5, 'five'), - (6, 'six'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), + (5, "five"), + (6, "six"), ] def test_write_csv_operator(self): @@ -316,8 +316,8 @@ def test_table_update_with_schema(self, duckdb_cursor): duckdb_cursor.sql("create schema not_main;") duckdb_cursor.sql("create table not_main.tbl as select * from range(10) t(a)") - duckdb_cursor.table('not_main.tbl').update({'a': 21}, condition=ColumnExpression('a') == 5) - res = duckdb_cursor.table('not_main.tbl').fetchall() + duckdb_cursor.table("not_main.tbl").update({"a": 21}, condition=ColumnExpression("a") == 5) + res = duckdb_cursor.table("not_main.tbl").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (21,), (6,), (7,), (8,), (9,)] def test_table_update_with_catalog(self, duckdb_cursor): @@ -325,8 +325,8 @@ def test_table_update_with_catalog(self, duckdb_cursor): duckdb_cursor.sql("create schema pg.not_main;") duckdb_cursor.sql("create table pg.not_main.tbl as select * from range(10) t(a)") - duckdb_cursor.table('pg.not_main.tbl').update({'a': 21}, condition=ColumnExpression('a') == 5) - res = duckdb_cursor.table('pg.not_main.tbl').fetchall() + duckdb_cursor.table("pg.not_main.tbl").update({"a": 21}, condition=ColumnExpression("a") == 5) + res = duckdb_cursor.table("pg.not_main.tbl").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (21,), (6,), (7,), (8,), (9,)] def test_get_attr_operator(self): @@ -335,50 +335,50 @@ def test_get_attr_operator(self): rel = conn.table("test") assert rel.alias == "test" assert rel.type == "TABLE_RELATION" - assert rel.columns == ['i'] - assert rel.types == ['INTEGER'] + assert rel.columns == ["i"] + assert rel.types == ["INTEGER"] def test_query_fail(self): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): rel.query("select j from test") def test_execute_fail(self): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") - with pytest.raises(TypeError, 
match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): rel.execute("select j from test") def test_df_proj(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.project(test_df, 'i') + rel = duckdb.project(test_df, "i") assert rel.execute().fetchall() == [(1,), (2,), (3,), (4,)] def test_relation_lifetime(self, duckdb_cursor): def create_relation(con): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return con.sql("select * from df") assert create_relation(duckdb_cursor).fetchall() == [(1,), (2,), (3,)] def create_simple_join(con): - df1 = pd.DataFrame({'a': ['a', 'b', 'c'], 'b': [1, 2, 3]}) - df2 = pd.DataFrame({'a': ['a', 'b', 'c'], 'b': [4, 5, 6]}) + df1 = pd.DataFrame({"a": ["a", "b", "c"], "b": [1, 2, 3]}) + df2 = pd.DataFrame({"a": ["a", "b", "c"], "b": [4, 5, 6]}) return con.sql("select * from df1 JOIN df2 USING (a, a)") - assert create_simple_join(duckdb_cursor).fetchall() == [('a', 1, 4), ('b', 2, 5), ('c', 3, 6)] + assert create_simple_join(duckdb_cursor).fetchall() == [("a", 1, 4), ("b", 2, 5), ("c", 3, 6)] def create_complex_join(con): - df1 = pd.DataFrame({'a': [1], '1': [1]}) - df2 = pd.DataFrame({'a': [1], '2': [2]}) - df3 = pd.DataFrame({'a': [1], '3': [3]}) - df4 = pd.DataFrame({'a': [1], '4': [4]}) - df5 = pd.DataFrame({'a': [1], '5': [5]}) - df6 = pd.DataFrame({'a': [1], '6': [6]}) + df1 = pd.DataFrame({"a": [1], "1": [1]}) + df2 = pd.DataFrame({"a": [1], "2": [2]}) + df3 = pd.DataFrame({"a": [1], "3": [3]}) + df4 = pd.DataFrame({"a": [1], "4": [4]}) + df5 = pd.DataFrame({"a": [1], "5": [5]}) + df6 = pd.DataFrame({"a": [1], "6": [6]}) query = "select * from df1" for i in range(5): query += f" JOIN df{i + 2} USING (a, a)" @@ -407,7 +407,7 @@ def test_project_on_types(self): assert projection.columns == ["c2", "c4"] # select bigint, tinyint and a type that isn't there - projection = rel.select_types([BIGINT, "tinyint", con.struct_type({'a': VARCHAR, 'b': TINYINT})]) + projection = rel.select_types([BIGINT, "tinyint", con.struct_type({"a": VARCHAR, "b": TINYINT})]) assert projection.columns == ["c0", "c1"] ## select with empty projection list, not possible @@ -420,30 +420,30 @@ def test_project_on_types(self): def test_df_alias(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.alias(test_df, 'dfzinho') + rel = duckdb.alias(test_df, "dfzinho") assert rel.alias == "dfzinho" def test_df_filter(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.filter(test_df, 'i > 1') - assert rel.execute().fetchall() == [(2, 'two'), (3, 'three'), (4, 'four')] + rel = duckdb.filter(test_df, "i > 1") + assert rel.execute().fetchall() == [(2, "two"), (3, "three"), (4, "four")] def test_df_order_by(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.order(test_df, 'j') - assert rel.execute().fetchall() == [(4, 'four'), (1, 'one'), (3, 'three'), (2, 'two')] + rel = duckdb.order(test_df, "j") + assert rel.execute().fetchall() == [(4, "four"), (1, "one"), (3, "three"), (2, "two")] def test_df_distinct(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.distinct(test_df).order('i') - assert rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + rel = 
duckdb.distinct(test_df).order("i") + assert rel.execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_df_write_csv(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) duckdb.write_csv(test_df, temp_file_name) csv_rel = duckdb.from_csv_auto(temp_file_name) - assert csv_rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + assert csv_rel.execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_join_types(self): test_df1 = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) @@ -452,9 +452,9 @@ def test_join_types(self): rel1 = con.from_df(test_df1) rel2 = con.from_df(test_df2) - assert rel1.join(rel2, 'i=j', 'inner').aggregate('count()').fetchone()[0] == 2 + assert rel1.join(rel2, "i=j", "inner").aggregate("count()").fetchone()[0] == 2 - assert rel1.join(rel2, 'i=j', 'left').aggregate('count()').fetchone()[0] == 4 + assert rel1.join(rel2, "i=j", "left").aggregate("count()").fetchone()[0] == 4 def test_fetchnumpy(self): start, stop = -1000, 2000 @@ -493,10 +493,10 @@ def counter(): counter.count = 0 conn = duckdb.connect() - conn.create_function('my_counter', counter, [], BIGINT) + conn.create_function("my_counter", counter, [], BIGINT) # Create a relation - rel = conn.sql('select my_counter()') + rel = conn.sql("select my_counter()") # Execute the relation once rel.fetchall() assert counter.count == 1 @@ -508,20 +508,20 @@ def counter(): assert counter.count == 2 # Verify that the query is run at least once if it's closed before it was executed. - rel = conn.sql('select my_counter()') + rel = conn.sql("select my_counter()") rel.close() assert counter.count == 3 def test_relation_print(self): con = duckdb.connect() con.execute("Create table t1 as select * from range(1000000)") - rel1 = con.table('t1') + rel1 = con.table("t1") text1 = str(rel1) - assert '? rows' in text1 - assert '>9999 rows' in text1 + assert "? 
rows" in text1 + assert ">9999 rows" in text1 @pytest.mark.parametrize( - 'num_rows', + "num_rows", [ 1024, 2048, @@ -563,7 +563,7 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): assert len(res) == num_rows rel = duckdb_cursor.sql(query) - projection = rel.select('column0') + projection = rel.select("column0") assert projection.fetchall() == [(42,) for _ in range(num_rows)] filtered = rel.filter("column1 != 'test'") @@ -575,58 +575,58 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): ): rel.insert([1, 2, 3, 4]) - query_rel = rel.query('x', "select 42 from x where column0 != 42") + query_rel = rel.query("x", "select 42 from x where column0 != 42") assert query_rel.fetchall() == [] distinct_rel = rel.distinct() - assert distinct_rel.fetchall() == [(42, 'test', 'this is a long string', True)] + assert distinct_rel.fetchall() == [(42, "test", "this is a long string", True)] limited_rel = rel.limit(50) assert len(limited_rel.fetchall()) == 50 # Using parameters also results in a MaterializedRelation materialized_one = duckdb_cursor.sql("select * from range(?)", params=[10]).project( - ColumnExpression('range').cast(str).alias('range') + ColumnExpression("range").cast(str).alias("range") ) materialized_two = duckdb_cursor.sql("call repeat('a', 5)") - joined_rel = materialized_one.join(materialized_two, 'range != a') + joined_rel = materialized_one.join(materialized_two, "range != a") res = joined_rel.fetchall() assert len(res) == 50 relation = duckdb_cursor.sql("select a from materialized_two") - assert relation.fetchone() == ('a',) + assert relation.fetchone() == ("a",) described = materialized_one.describe() res = described.fetchall() - assert res == [('count', '10'), ('mean', None), ('stddev', None), ('min', '0'), ('max', '9'), ('median', None)] + assert res == [("count", "10"), ("mean", None), ("stddev", None), ("min", "0"), ("max", "9"), ("median", None)] unioned_rel = materialized_one.union(materialized_two) res = unioned_rel.fetchall() assert res == [ - ('0',), - ('1',), - ('2',), - ('3',), - ('4',), - ('5',), - ('6',), - ('7',), - ('8',), - ('9',), - ('a',), - ('a',), - ('a',), - ('a',), - ('a',), + ("0",), + ("1",), + ("2",), + ("3",), + ("4",), + ("5",), + ("6",), + ("7",), + ("8",), + ("9",), + ("a",), + ("a",), + ("a",), + ("a",), + ("a",), ] except_rel = unioned_rel.except_(materialized_one) res = except_rel.fetchall() - assert res == [tuple('a') for _ in range(5)] + assert res == [tuple("a") for _ in range(5)] - intersect_rel = unioned_rel.intersect(materialized_one).order('range') + intersect_rel = unioned_rel.intersect(materialized_one).order("range") res = intersect_rel.fetchall() - assert res == [('0',), ('1',), ('2',), ('3',), ('4',), ('5',), ('6',), ('7',), ('8',), ('9',)] + assert res == [("0",), ("1",), ("2",), ("3",), ("4",), ("5",), ("6",), ("7",), ("8",), ("9",)] def test_materialized_relation_view(self, duckdb_cursor): def create_view(duckdb_cursor): @@ -635,11 +635,11 @@ def create_view(duckdb_cursor): create table tbl(a varchar); insert into tbl values ('test') returning * """ - ).to_view('vw') + ).to_view("vw") create_view(duckdb_cursor) res = duckdb_cursor.sql("select * from vw").fetchone() - assert res == ('test',) + assert res == ("test",) def test_materialized_relation_view2(self, duckdb_cursor): # This creates a MaterializedRelation @@ -654,7 +654,7 @@ def test_materialized_relation_view2(self, duckdb_cursor): # The VIEW still works because the CDC that is being referenced is kept alive through the MaterializedDependency item 
rel = duckdb_cursor.sql("select * from test") res = rel.fetchall() - assert res == [([2], ['Alice'])] + assert res == [([2], ["Alice"])] def test_serialized_materialized_relation(self, tmp_database): con = duckdb.connect(tmp_database) @@ -663,9 +663,9 @@ def create_view(con, view_name: str): rel = con.sql("select 'this is not a small string ' || range::varchar from range(?)", params=[10]) rel.to_view(view_name) - expected = [(f'this is not a small string {i}',) for i in range(10)] + expected = [(f"this is not a small string {i}",) for i in range(10)] - create_view(con, 'vw') + create_view(con, "vw") res = con.sql("select * from vw").fetchall() assert res == expected diff --git a/tests/fast/test_relation_dependency_leak.py b/tests/fast/test_relation_dependency_leak.py index ca505704..ee98e30a 100644 --- a/tests/fast/test_relation_dependency_leak.py +++ b/tests/fast/test_relation_dependency_leak.py @@ -31,13 +31,13 @@ def from_df(pandas, duckdb_cursor): def from_arrow(pandas, duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_table) def arrow_replacement(pandas, duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.query("select sum(a) from arrow_table").fetchall() @@ -47,27 +47,27 @@ def pandas_replacement(pandas, duckdb_cursor): class TestRelationDependencyMemoryLeak(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow_leak(self, pandas, duckdb_cursor): if not can_run: return check_memory(from_arrow, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_df_leak(self, pandas, duckdb_cursor): check_memory(from_df, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_arrow_replacement_scan_leak(self, pandas, duckdb_cursor): if not can_run: return check_memory(arrow_replacement, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_replacement_scan_leak(self, pandas, duckdb_cursor): check_memory(pandas_replacement, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_view_leak(self, pandas, duckdb_cursor): rel = from_df(pandas, duckdb_cursor) rel.create_view("bla") diff --git a/tests/fast/test_replacement_scan.py b/tests/fast/test_replacement_scan.py index 0cf69356..555773dc 100644 --- a/tests/fast/test_replacement_scan.py +++ b/tests/fast/test_replacement_scan.py @@ -8,13 +8,13 @@ def using_table(con, to_scan, object_name): - local_scope = {'con': con, object_name: to_scan, 'object_name': object_name} + local_scope = {"con": con, object_name: to_scan, "object_name": object_name} exec(f"result = con.table(object_name)", globals(), local_scope) return local_scope["result"] def using_sql(con, to_scan, object_name): - local_scope = {'con': con, object_name: to_scan, 'object_name': object_name} + local_scope = 
{"con": con, object_name: to_scan, "object_name": object_name} exec(f"result = con.sql('select * from \"{object_name}\"')", globals(), local_scope) return local_scope["result"] @@ -60,40 +60,40 @@ def fetch_relation(rel): def from_pandas(): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return df def from_arrow(): - schema = pa.schema([('field_1', pa.int64())]) + schema = pa.schema([("field_1", pa.int64())]) df = pa.RecordBatchReader.from_batches(schema, [pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema)]) return df def create_relation(conn, query: str) -> duckdb.DuckDBPyRelation: - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return conn.sql(query) class TestReplacementScan(object): def test_csv_replacement(self): con = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "integers.csv") res = con.execute("select count(*) from '%s'" % (filename)) assert res.fetchone()[0] == 2 def test_parquet_replacement(self): con = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "binary_string.parquet") res = con.execute("select count(*) from '%s'" % (filename)) assert res.fetchone()[0] == 3 - @pytest.mark.parametrize('get_relation', [using_table, using_sql]) + @pytest.mark.parametrize("get_relation", [using_table, using_sql]) @pytest.mark.parametrize( - 'fetch_method', + "fetch_method", [fetch_polars, fetch_df, fetch_arrow, fetch_arrow_table, fetch_arrow_record_batch, fetch_relation], ) - @pytest.mark.parametrize('object_name', ['tbl', 'table', 'select', 'update']) + @pytest.mark.parametrize("object_name", ["tbl", "table", "select", "update"]) def test_table_replacement_scans(self, duckdb_cursor, get_relation, fetch_method, object_name): base_rel = duckdb_cursor.values([1, 2, 3]) to_scan = fetch_method(base_rel) @@ -105,29 +105,29 @@ def test_table_replacement_scans(self, duckdb_cursor, get_relation, fetch_method def test_scan_global(self, duckdb_cursor): duckdb_cursor.execute("set python_enable_replacements=false") - with pytest.raises(duckdb.CatalogException, match='Table with name global_polars_df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name global_polars_df does not exist"): # We set the depth to look for global variables to 0 so it's never found duckdb_cursor.sql("select * from global_polars_df") duckdb_cursor.execute("set python_enable_replacements=true") # Now the depth is 1, which is enough to locate the variable rel = duckdb_cursor.sql("select * from global_polars_df") res = rel.fetchone() - assert res == (1, 'banana', 5, 'beetle') + assert res == (1, "banana", 5, "beetle") def test_scan_local(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) def inner_func(duckdb_cursor): duckdb_cursor.execute("set python_enable_replacements=false") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # We set the depth to look for local variables to 0 so it's never found duckdb_cursor.sql("select * from df") duckdb_cursor.execute("set python_enable_replacements=true") - with pytest.raises(duckdb.CatalogException, match='Table with name df 
does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # Here it's still not found, because it's not visible to this frame duckdb_cursor.sql("select * from df") - df = pd.DataFrame({'a': [4, 5, 6]}) + df = pd.DataFrame({"a": [4, 5, 6]}) duckdb_cursor.execute("set python_enable_replacements=true") # We can find the newly defined 'df' with depth 1 rel = duckdb_cursor.sql("select * from df") @@ -137,11 +137,11 @@ def inner_func(duckdb_cursor): inner_func(duckdb_cursor) def test_scan_local_unlimited(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) def inner_func(duckdb_cursor): duckdb_cursor.execute("set python_enable_replacements=true") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # We set the depth to look for local variables to 1 so it's still not found because it wasn't defined in this function duckdb_cursor.sql("select * from df") duckdb_cursor.execute("set python_scan_all_frames=true") @@ -155,37 +155,37 @@ def inner_func(duckdb_cursor): def test_replacement_scan_relapi(self): con = duckdb.connect() - pyrel1 = con.query('from (values (42), (84), (120)) t(i)') + pyrel1 = con.query("from (values (42), (84), (120)) t(i)") assert isinstance(pyrel1, duckdb.DuckDBPyRelation) assert pyrel1.fetchall() == [(42,), (84,), (120,)] - pyrel2 = con.query('from pyrel1 limit 2') + pyrel2 = con.query("from pyrel1 limit 2") assert isinstance(pyrel2, duckdb.DuckDBPyRelation) assert pyrel2.fetchall() == [(42,), (84,)] - pyrel3 = con.query('select i + 100 from pyrel2') + pyrel3 = con.query("select i + 100 from pyrel2") assert type(pyrel3) == duckdb.DuckDBPyRelation assert pyrel3.fetchall() == [(142,), (184,)] def test_replacement_scan_not_found(self): con = duckdb.connect() con.execute("set python_scan_all_frames=true") - with pytest.raises(duckdb.CatalogException, match='Table with name non_existant does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name non_existant does not exist"): res = con.sql("select * from non_existant").fetchall() def test_replacement_scan_alias(self): con = duckdb.connect() - pyrel1 = con.query('from (values (1, 2)) t(i, j)') - pyrel2 = con.query('from (values (1, 10)) t(i, k)') - pyrel3 = con.query('from pyrel1 join pyrel2 using(i)') + pyrel1 = con.query("from (values (1, 2)) t(i, j)") + pyrel2 = con.query("from (values (1, 10)) t(i, k)") + pyrel3 = con.query("from pyrel1 join pyrel2 using(i)") assert type(pyrel3) == duckdb.DuckDBPyRelation assert pyrel3.fetchall() == [(1, 2, 10)] def test_replacement_scan_pandas_alias(self): con = duckdb.connect() - df1 = con.query('from (values (1, 2)) t(i, j)').df() - df2 = con.query('from (values (1, 10)) t(i, k)').df() - df3 = con.query('from df1 join df2 using(i)') + df1 = con.query("from (values (1, 2)) t(i, j)").df() + df2 = con.query("from (values (1, 10)) t(i, k)").df() + df3 = con.query("from df1 join df2 using(i)") assert df3.fetchall() == [(1, 2, 10)] def test_replacement_scan_after_creation(self, duckdb_cursor): @@ -194,14 +194,14 @@ def test_replacement_scan_after_creation(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from df") duckdb_cursor.execute("drop table df") - df = pd.DataFrame({'b': [1, 2, 3]}) + df = pd.DataFrame({"b": [1, 2, 3]}) res = rel.fetchall() # FIXME: this should error instead, the 'df' table we relied on has been removed and replaced with a 
replacement scan assert res == [(1,), (2,), (3,)] def test_replacement_scan_caching(self, duckdb_cursor): def return_rel(conn): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) rel = conn.sql("select * from df") return rel @@ -220,7 +220,7 @@ def test_replacement_scan_fail(self): con.execute("select count(*) from random_object").fetchone() @pytest.mark.parametrize( - 'df_create', + "df_create", [ from_pandas, from_arrow, @@ -332,7 +332,7 @@ def test_same_name_cte(self, duckdb_cursor): def test_use_with_view(self, duckdb_cursor): rel = create_relation(duckdb_cursor, "select * from df") - rel.create_view('v1') + rel.create_view("v1") del rel rel = duckdb_cursor.sql("select * from v1") @@ -342,12 +342,12 @@ def test_use_with_view(self, duckdb_cursor): def create_view_in_func(con): df = pd.DataFrame({"a": [1, 2, 3]}) - con.execute('CREATE VIEW v1 AS SELECT * FROM df') + con.execute("CREATE VIEW v1 AS SELECT * FROM df") create_view_in_func(duckdb_cursor) # FIXME: this should be fixed in the future, likely by unifying the behavior of .sql and .execute - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): rel = duckdb_cursor.sql("select * from v1") def test_recursive_cte(self, duckdb_cursor): @@ -409,7 +409,7 @@ def test_multiple_replacements(self, duckdb_cursor): """ rel = duckdb_cursor.sql(query) res = rel.fetchall() - assert res == [(2, 'Bob', None), (3, 'Charlie', None), (4, 'David', 1.0), (5, 'Eve', 1.0)] + assert res == [(2, "Bob", None), (3, "Charlie", None), (4, "David", 1.0), (5, "Eve", 1.0)] def test_cte_at_different_levels(self, duckdb_cursor): query = """ @@ -459,17 +459,17 @@ def test_replacement_disabled(self): ## disable external access con.execute("set enable_external_access=false") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist!'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist!"): rel = create_relation(con, "select * from df") res = rel.fetchall() with pytest.raises( - duckdb.InvalidInputException, match='Cannot change enable_external_access setting while database is running' + duckdb.InvalidInputException, match="Cannot change enable_external_access setting while database is running" ): con.execute("set enable_external_access=true") # Create connection with external access disabled - con = duckdb.connect(config={'enable_external_access': False}) - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist!'): + con = duckdb.connect(config={"enable_external_access": False}) + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist!"): rel = create_relation(con, "select * from df") res = rel.fetchall() @@ -486,23 +486,23 @@ def test_replacement_disabled(self): assert res == [(1,), (2,), (3,)] def test_replacement_of_cross_connection_relation(self): - con1 = duckdb.connect(':memory:') - con2 = duckdb.connect(':memory:') - con1.query('create table integers(i int)') - con2.query('create table integers(v varchar)') - con1.query('insert into integers values (42)') - con2.query('insert into integers values (\'xxx\')') - rel1 = con1.query('select * from integers') + con1 = duckdb.connect(":memory:") + con2 = duckdb.connect(":memory:") + con1.query("create table integers(i int)") + con2.query("create table integers(v varchar)") + con1.query("insert into integers values (42)") + con2.query("insert into 
integers values ('xxx')") + rel1 = con1.query("select * from integers") with pytest.raises( duckdb.InvalidInputException, - match=r'The object was created by another Connection and can therefore not be used by this Connection.', + match=r"The object was created by another Connection and can therefore not be used by this Connection.", ): - con2.query('from rel1') + con2.query("from rel1") del con1 with pytest.raises( duckdb.InvalidInputException, - match=r'The object was created by another Connection and can therefore not be used by this Connection.', + match=r"The object was created by another Connection and can therefore not be used by this Connection.", ): - con2.query('from rel1') + con2.query("from rel1") diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index af68e268..906b1198 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -5,42 +5,42 @@ class TestPythonResult(object): def test_result_closed(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE integers (i integer)') - cursor.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + cursor.execute("CREATE TABLE integers (i integer)") + cursor.execute("INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)") rel = connection.table("integers") res = rel.aggregate("sum(i)").execute() res.close() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchone() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchall() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchnumpy() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_table() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_reader(1) def test_result_describe_types(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + cursor.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") cursor.execute("INSERT INTO test VALUES (TRUE, '01:01:01', 'bla' )") rel = connection.table("test") res = rel.execute() assert res.description == [ - ('i', 'BOOLEAN', None, None, None, None, None), - ('j', 'TIME', None, None, None, None, None), - ('k', 'VARCHAR', None, None, None, None, None), + ("i", "BOOLEAN", None, None, None, None, None), + ("j", "TIME", None, None, None, None, None), + ("k", "VARCHAR", None, None, None, None, None), ] def test_result_timestamps(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() cursor.execute( - 'CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );' + "CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );" ) cursor.execute( "INSERT INTO timestamps VALUES ('2008-01-01 
00:00:11','2008-01-01 00:00:01.794','2008-01-01 00:00:01.98926','2008-01-01 00:00:01.899268321' )" @@ -59,12 +59,12 @@ def test_result_timestamps(self, duckdb_cursor): def test_result_interval(self): connection = duckdb.connect() cursor = connection.cursor() - cursor.execute('CREATE TABLE IF NOT EXISTS intervals (ivals INTERVAL)') + cursor.execute("CREATE TABLE IF NOT EXISTS intervals (ivals INTERVAL)") cursor.execute("INSERT INTO intervals VALUES ('1 day'), ('2 second'), ('1 microsecond')") rel = connection.table("intervals") res = rel.execute() - assert res.description == [('ivals', 'INTERVAL', None, None, None, None, None)] + assert res.description == [("ivals", "INTERVAL", None, None, None, None, None)] assert res.fetchall() == [ (datetime.timedelta(days=1.0),), (datetime.timedelta(seconds=2.0),), diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 29e81d1e..327be004 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -2,8 +2,8 @@ import pytest from conftest import NumpyPandas, ArrowPandas -closed = lambda: pytest.raises(duckdb.ConnectionException, match='Connection already closed') -no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match='No open result set') +closed = lambda: pytest.raises(duckdb.ConnectionException, match="Connection already closed") +no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match="No open result set") class TestRuntimeError(object): @@ -20,7 +20,7 @@ def test_df_error(self): con.execute("select i::int from tbl").df() def test_arrow_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() con.execute("create table tbl as select 'hello' i") @@ -34,83 +34,83 @@ def test_register_error(self): con.register(py_obj, "v") def test_arrow_fetch_table_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_table() def test_arrow_record_batch_reader_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() - with pytest.raises(duckdb.ProgrammingError, match='There is no query result'): + with pytest.raises(duckdb.ProgrammingError, match="There is no query result"): res.fetch_arrow_reader(1) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_cache_fetchall(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.ProgrammingError, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.ProgrammingError, match="Table with name df_in does not exist"): # Even when we preserve ExternalDependency objects correctly, this is not supported # Relations only save dependencies for their immediate TableRefs, # so the dependency of 'x' on 'df_in' is 
not registered in 'rel' rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_cache_execute(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.ProgrammingError, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.ProgrammingError, match="Table with name df_in does not exist"): rel.execute() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_query_error(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.CatalogException, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df_in does not exist"): rel.query("bla", "select * from bla") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_conn_broken_statement_error(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") del df_in - with pytest.raises(duckdb.CatalogException, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df_in does not exist"): conn.execute("select 1; select * from x; select 3;") def test_conn_prepared_statement_error(self): @@ -118,17 +118,17 @@ def test_conn_prepared_statement_error(self): conn.execute("create table integers (a integer, b integer)") with pytest.raises( duckdb.InvalidInputException, - match='Values were not provided for the following prepared statement parameters: 2', + match="Values were not provided for the following prepared statement parameters: 2", ): conn.execute("select * from integers where a =? 
and b=?", [1]) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_closed_conn_exceptions(self, pandas): conn = duckdb.connect() conn.close() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) diff --git a/tests/fast/test_sql_expression.py b/tests/fast/test_sql_expression.py index 771be84d..4dc4cab5 100644 --- a/tests/fast/test_sql_expression.py +++ b/tests/fast/test_sql_expression.py @@ -9,7 +9,6 @@ class TestSQLExpression(object): def test_sql_expression_basic(self, duckdb_cursor): - # Test simple constant expressions expr = SQLExpression("42") rel = duckdb_cursor.sql("SELECT 1").select(expr) @@ -17,7 +16,7 @@ def test_sql_expression_basic(self, duckdb_cursor): expr = SQLExpression("'hello'") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('hello',)] + assert rel.fetchall() == [("hello",)] expr = SQLExpression("NULL") rel = duckdb_cursor.sql("SELECT 1").select(expr) @@ -43,14 +42,13 @@ def test_sql_expression_basic(self, duckdb_cursor): # Test function calls expr = SQLExpression("UPPER('test')") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('TEST',)] + assert rel.fetchall() == [("TEST",)] expr = SQLExpression("CONCAT('hello', ' ', 'world')") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('hello world',)] + assert rel.fetchall() == [("hello world",)] def test_sql_expression_with_columns(self, duckdb_cursor): - # Create a test table duckdb_cursor.execute( """ @@ -75,12 +73,12 @@ def test_sql_expression_with_columns(self, duckdb_cursor): expr = SQLExpression("UPPER(b)") rel2 = rel.select(expr) - assert rel2.fetchall() == [('ONE',), ('TWO',), ('THREE',)] + assert rel2.fetchall() == [("ONE",), ("TWO",), ("THREE",)] # Test complex expressions expr = SQLExpression("CASE WHEN a > 1 THEN b ELSE 'default' END") rel2 = rel.select(expr) - assert rel2.fetchall() == [('default',), ('two',), ('three',)] + assert rel2.fetchall() == [("default",), ("two",), ("three",)] # Test combining with other expression types expr1 = SQLExpression("a + 5") @@ -122,8 +120,8 @@ def test_sql_expression_alias(self, duckdb_cursor): rel = duckdb_cursor.table("test_alias") expr = SQLExpression("a + 10").alias("a_plus_10") rel2 = rel.select(expr, "b") - assert rel2.fetchall() == [(11, 'one'), (12, 'two')] - assert rel2.columns == ['a_plus_10', 'b'] + assert rel2.fetchall() == [(11, "one"), (12, "two")] + assert rel2.columns == ["a_plus_10", "b"] def test_sql_expression_in_filter(self, duckdb_cursor): duckdb_cursor.execute( @@ -142,18 +140,18 @@ def test_sql_expression_in_filter(self, duckdb_cursor): # Test filter with SQL expression expr = SQLExpression("a > 2") rel2 = rel.filter(expr) - assert rel2.fetchall() == [(3, 'three'), (4, 'four')] + assert rel2.fetchall() == [(3, "three"), (4, "four")] # Test complex filter expr = SQLExpression("a % 2 = 0 AND b LIKE '%o%'") rel2 = rel.filter(expr) - assert rel2.fetchall() == [(2, 'two'), (4, 'four')] + assert rel2.fetchall() == [(2, "two"), (4, "four")] # Test combining with other expression types expr1 = SQLExpression("a > 1") expr2 = ColumnExpression("b") == ConstantExpression("four") rel2 = rel.filter(expr1 & expr2) - assert rel2.fetchall() == [(4, 'four')] + assert rel2.fetchall() == [(4, "four")] def test_sql_expression_in_aggregates(self, duckdb_cursor): duckdb_cursor.execute( @@ -176,14 +174,14 @@ def test_sql_expression_in_aggregates(self, 
duckdb_cursor): # Test aggregation with group by expr = SQLExpression("SUM(c)") - rel2 = rel.aggregate([expr, "b"]).sort('b') + rel2 = rel.aggregate([expr, "b"]).sort("b") result = rel2.fetchall() - assert result == [(30, 'group1'), (70, 'group2')] + assert result == [(30, "group1"), (70, "group2")] # Test multiple aggregations expr1 = SQLExpression("SUM(a)").alias("sum_a") expr2 = SQLExpression("AVG(c)").alias("avg_c") - rel2 = rel.aggregate([expr1, expr2], "b").sort('sum_a', 'avg_c') + rel2 = rel.aggregate([expr1, expr2], "b").sort("sum_a", "avg_c") result = rel2.fetchall() result.sort() assert result == [(3, 15.0), (7, 35.0)] diff --git a/tests/fast/test_string_annotation.py b/tests/fast/test_string_annotation.py index c5500c66..83685bed 100644 --- a/tests/fast/test_string_annotation.py +++ b/tests/fast/test_string_annotation.py @@ -14,7 +14,7 @@ def test_base(): test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__ ) # Add the 'type' string as return_annotation - test_function.__annotations__ = {'return': type} + test_function.__annotations__ = {"return": type} return test_function @@ -33,12 +33,12 @@ class TestStringAnnotation(object): python_version_lower_than_3_10(), reason="inspect.signature(eval_str=True) only supported since 3.10 and higher" ) @pytest.mark.parametrize( - ['input', 'expected'], + ["input", "expected"], [ - ('str', 'VARCHAR'), - ('list[str]', 'VARCHAR[]'), - ('dict[str, str]', 'MAP(VARCHAR, VARCHAR)'), - ('dict[Union[str, bool], str]', 'MAP(UNION(u1 VARCHAR, u2 BOOLEAN), VARCHAR)'), + ("str", "VARCHAR"), + ("list[str]", "VARCHAR[]"), + ("dict[str, str]", "MAP(VARCHAR, VARCHAR)"), + ("dict[Union[str, bool], str]", "MAP(UNION(u1 VARCHAR, u2 BOOLEAN), VARCHAR)"), ], ) def test_string_annotations(self, duckdb_cursor, input, expected): diff --git a/tests/fast/test_tf.py b/tests/fast/test_tf.py index b65acec6..db93d0de 100644 --- a/tests/fast/test_tf.py +++ b/tests/fast/test_tf.py @@ -2,7 +2,7 @@ import pytest -tf = pytest.importorskip('tensorflow') +tf = pytest.importorskip("tensorflow") def test_tf(): @@ -14,16 +14,16 @@ def test_tf(): # Test from connection duck_tf = con.execute("select * from t").tf() duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) # Test from relation duck_tf = con.sql("select * from t").tf() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) # Test all Numeric Types - numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + numeric_types = ["TINYINT", "SMALLINT", "BIGINT", "HUGEINT", "FLOAT", "DOUBLE", "DECIMAL(4,1)", "UTINYINT"] for supported_type in numeric_types: con = duckdb.connect() @@ -31,5 +31,5 @@ def test_tf(): con.execute("insert into t values (1,2), (3,4)") duck_tf = con.sql("select * from t").tf() duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], 
tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) diff --git a/tests/fast/test_transaction.py b/tests/fast/test_transaction.py index 54deaf82..ff0ba1a7 100644 --- a/tests/fast/test_transaction.py +++ b/tests/fast/test_transaction.py @@ -5,16 +5,16 @@ class TestConnectionTransaction(object): def test_transaction(self, duckdb_cursor): con = duckdb.connect() - con.execute('create table t (i integer)') - con.execute('insert into t values (1)') + con.execute("create table t (i integer)") + con.execute("insert into t values (1)") con.begin() - con.execute('insert into t values (1)') - assert con.execute('select count (*) from t').fetchone()[0] == 2 + con.execute("insert into t values (1)") + assert con.execute("select count (*) from t").fetchone()[0] == 2 con.rollback() - assert con.execute('select count (*) from t').fetchone()[0] == 1 + assert con.execute("select count (*) from t").fetchone()[0] == 1 con.begin() - con.execute('insert into t values (1)') - assert con.execute('select count (*) from t').fetchone()[0] == 2 + con.execute("insert into t values (1)") + assert con.execute("select count (*) from t").fetchone()[0] == 2 con.commit() - assert con.execute('select count (*) from t').fetchone()[0] == 2 + assert con.execute("select count (*) from t").fetchone()[0] == 2 diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index c5a62694..1e8ebc25 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -40,83 +40,83 @@ class TestType(object): def test_sqltype(self): - assert str(duckdb.sqltype('struct(a VARCHAR, b BIGINT)')) == 'STRUCT(a VARCHAR, b BIGINT)' + assert str(duckdb.sqltype("struct(a VARCHAR, b BIGINT)")) == "STRUCT(a VARCHAR, b BIGINT)" # todo: add tests with invalid type_str def test_primitive_types(self): assert str(SQLNULL) == '"NULL"' - assert str(BOOLEAN) == 'BOOLEAN' - assert str(TINYINT) == 'TINYINT' - assert str(UTINYINT) == 'UTINYINT' - assert str(SMALLINT) == 'SMALLINT' - assert str(USMALLINT) == 'USMALLINT' - assert str(INTEGER) == 'INTEGER' - assert str(UINTEGER) == 'UINTEGER' - assert str(BIGINT) == 'BIGINT' - assert str(UBIGINT) == 'UBIGINT' - assert str(HUGEINT) == 'HUGEINT' - assert str(UHUGEINT) == 'UHUGEINT' - assert str(UUID) == 'UUID' - assert str(FLOAT) == 'FLOAT' - assert str(DOUBLE) == 'DOUBLE' - assert str(DATE) == 'DATE' - assert str(TIMESTAMP) == 'TIMESTAMP' - assert str(TIMESTAMP_MS) == 'TIMESTAMP_MS' - assert str(TIMESTAMP_NS) == 'TIMESTAMP_NS' - assert str(TIMESTAMP_S) == 'TIMESTAMP_S' - assert str(TIME) == 'TIME' - assert str(TIME_TZ) == 'TIME WITH TIME ZONE' - assert str(TIMESTAMP_TZ) == 'TIMESTAMP WITH TIME ZONE' - assert str(VARCHAR) == 'VARCHAR' - assert str(BLOB) == 'BLOB' - assert str(BIT) == 'BIT' - assert str(INTERVAL) == 'INTERVAL' + assert str(BOOLEAN) == "BOOLEAN" + assert str(TINYINT) == "TINYINT" + assert str(UTINYINT) == "UTINYINT" + assert str(SMALLINT) == "SMALLINT" + assert str(USMALLINT) == "USMALLINT" + assert str(INTEGER) == "INTEGER" + assert str(UINTEGER) == "UINTEGER" + assert str(BIGINT) == "BIGINT" + assert str(UBIGINT) == "UBIGINT" + assert str(HUGEINT) == "HUGEINT" + assert str(UHUGEINT) == "UHUGEINT" + assert str(UUID) == "UUID" + assert str(FLOAT) == "FLOAT" + assert str(DOUBLE) == "DOUBLE" + assert str(DATE) == "DATE" + assert str(TIMESTAMP) == "TIMESTAMP" + assert str(TIMESTAMP_MS) == "TIMESTAMP_MS" + assert str(TIMESTAMP_NS) == "TIMESTAMP_NS" + assert str(TIMESTAMP_S) == "TIMESTAMP_S" + assert str(TIME) == "TIME" + assert str(TIME_TZ) == 
"TIME WITH TIME ZONE" + assert str(TIMESTAMP_TZ) == "TIMESTAMP WITH TIME ZONE" + assert str(VARCHAR) == "VARCHAR" + assert str(BLOB) == "BLOB" + assert str(BIT) == "BIT" + assert str(INTERVAL) == "INTERVAL" def test_list_type(self): type = duckdb.list_type(BIGINT) - assert str(type) == 'BIGINT[]' + assert str(type) == "BIGINT[]" def test_array_type(self): type = duckdb.array_type(BIGINT, 3) - assert str(type) == 'BIGINT[3]' + assert str(type) == "BIGINT[3]" def test_struct_type(self): - type = duckdb.struct_type({'a': BIGINT, 'b': BOOLEAN}) - assert str(type) == 'STRUCT(a BIGINT, b BOOLEAN)' + type = duckdb.struct_type({"a": BIGINT, "b": BOOLEAN}) + assert str(type) == "STRUCT(a BIGINT, b BOOLEAN)" # FIXME: create an unnamed struct when fields are provided as a list type = duckdb.struct_type([BIGINT, BOOLEAN]) - assert str(type) == 'STRUCT(v1 BIGINT, v2 BOOLEAN)' + assert str(type) == "STRUCT(v1 BIGINT, v2 BOOLEAN)" def test_incomplete_struct_type(self): with pytest.raises( - duckdb.InvalidInputException, match='Could not convert empty dictionary to a duckdb STRUCT type' + duckdb.InvalidInputException, match="Could not convert empty dictionary to a duckdb STRUCT type" ): type = duckdb.typing.DuckDBPyType(dict()) def test_map_type(self): type = duckdb.map_type(duckdb.sqltype("BIGINT"), duckdb.sqltype("DECIMAL(10, 2)")) - assert str(type) == 'MAP(BIGINT, DECIMAL(10,2))' + assert str(type) == "MAP(BIGINT, DECIMAL(10,2))" def test_decimal_type(self): type = duckdb.decimal_type(5, 3) - assert str(type) == 'DECIMAL(5,3)' + assert str(type) == "DECIMAL(5,3)" def test_string_type(self): type = duckdb.string_type() - assert str(type) == 'VARCHAR' + assert str(type) == "VARCHAR" def test_string_type_collation(self): - type = duckdb.string_type('NOCASE') + type = duckdb.string_type("NOCASE") # collation does not show up in the string representation.. 
- assert str(type) == 'VARCHAR' + assert str(type) == "VARCHAR" def test_union_type(self): type = duckdb.union_type([BIGINT, VARCHAR, TINYINT]) - assert str(type) == 'UNION(v1 BIGINT, v2 VARCHAR, v3 TINYINT)' + assert str(type) == "UNION(v1 BIGINT, v2 VARCHAR, v3 TINYINT)" - type = duckdb.union_type({'a': BIGINT, 'b': VARCHAR, 'c': TINYINT}) - assert str(type) == 'UNION(a BIGINT, b VARCHAR, c TINYINT)' + type = duckdb.union_type({"a": BIGINT, "b": VARCHAR, "c": TINYINT}) + assert str(type) == "UNION(a BIGINT, b VARCHAR, c TINYINT)" import sys @@ -125,42 +125,42 @@ def test_implicit_convert_from_builtin_type(self): type = duckdb.list_type(list[str]) assert str(type.child) == "VARCHAR[]" - mapping = {str: 'VARCHAR', int: 'BIGINT', bytes: 'BLOB', bytearray: 'BLOB', bool: 'BOOLEAN', float: 'DOUBLE'} + mapping = {str: "VARCHAR", int: "BIGINT", bytes: "BLOB", bytearray: "BLOB", bool: "BOOLEAN", float: "DOUBLE"} for duckdb_type, expected in mapping.items(): res = duckdb.list_type(duckdb_type) assert str(res.child) == expected - res = duckdb.list_type({'a': str, 'b': int}) - assert str(res.child) == 'STRUCT(a VARCHAR, b BIGINT)' + res = duckdb.list_type({"a": str, "b": int}) + assert str(res.child) == "STRUCT(a VARCHAR, b BIGINT)" res = duckdb.list_type(dict[str, int]) - assert str(res.child) == 'MAP(VARCHAR, BIGINT)' + assert str(res.child) == "MAP(VARCHAR, BIGINT)" res = duckdb.list_type(list[str]) - assert str(res.child) == 'VARCHAR[]' + assert str(res.child) == "VARCHAR[]" res = duckdb.list_type(list[dict[str, dict[list[str], str]]]) - assert str(res.child) == 'MAP(VARCHAR, MAP(VARCHAR[], VARCHAR))[]' + assert str(res.child) == "MAP(VARCHAR, MAP(VARCHAR[], VARCHAR))[]" res = duckdb.list_type(list[Union[str, int]]) - assert str(res.child) == 'UNION(u1 VARCHAR, u2 BIGINT)[]' + assert str(res.child) == "UNION(u1 VARCHAR, u2 BIGINT)[]" def test_implicit_convert_from_numpy(self, duckdb_cursor): np = pytest.importorskip("numpy") type_mapping = { - 'bool': 'BOOLEAN', - 'int8': 'TINYINT', - 'uint8': 'UTINYINT', - 'int16': 'SMALLINT', - 'uint16': 'USMALLINT', - 'int32': 'INTEGER', - 'uint32': 'UINTEGER', - 'int64': 'BIGINT', - 'uint64': 'UBIGINT', - 'float16': 'FLOAT', - 'float32': 'FLOAT', - 'float64': 'DOUBLE', + "bool": "BOOLEAN", + "int8": "TINYINT", + "uint8": "UTINYINT", + "int16": "SMALLINT", + "uint16": "USMALLINT", + "int32": "INTEGER", + "uint32": "UINTEGER", + "int64": "BIGINT", + "uint64": "UBIGINT", + "float16": "FLOAT", + "float32": "FLOAT", + "float64": "DOUBLE", } builtins = [] @@ -189,30 +189,30 @@ def test_implicit_convert_from_numpy(self, duckdb_cursor): def test_attribute_accessor(self): type = duckdb.row_type([BIGINT, duckdb.list_type(duckdb.map_type(BLOB, BIT))]) - assert hasattr(type, 'a') == False - assert hasattr(type, 'v1') == True + assert hasattr(type, "a") == False + assert hasattr(type, "v1") == True - field_one = type['v1'] - assert str(field_one) == 'BIGINT' + field_one = type["v1"] + assert str(field_one) == "BIGINT" field_one = type.v1 - assert str(field_one) == 'BIGINT' + assert str(field_one) == "BIGINT" - field_two = type['v2'] - assert str(field_two) == 'MAP(BLOB, BIT)[]' + field_two = type["v2"] + assert str(field_two) == "MAP(BLOB, BIT)[]" child_type = type.v2.child - assert str(child_type) == 'MAP(BLOB, BIT)' + assert str(child_type) == "MAP(BLOB, BIT)" def test_json_type(self): - json_type = duckdb.type('JSON') + json_type = duckdb.type("JSON") val = duckdb.Value('{"duck": 42}', json_type) res = duckdb.execute("select typeof($1)", [val]).fetchone() - assert res 
== ('JSON',) + assert res == ("JSON",) def test_struct_from_dict(self): - res = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) - assert res == 'STRUCT(a VARCHAR, b VARCHAR)[]' + res = duckdb.list_type({"a": VARCHAR, "b": VARCHAR}) + assert res == "STRUCT(a VARCHAR, b VARCHAR)[]" def test_hash_method(self): type1 = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) @@ -232,29 +232,29 @@ def test_hash_method(self): @pytest.mark.skipif(sys.version_info < (3, 9), reason="python3.7 does not store Optional[..] in a recognized way") def test_optional(self): type = duckdb.typing.DuckDBPyType(Optional[str]) - assert type == 'VARCHAR' + assert type == "VARCHAR" type = duckdb.typing.DuckDBPyType(Optional[Union[int, bool]]) - assert type == 'UNION(u1 BIGINT, u2 BOOLEAN)' + assert type == "UNION(u1 BIGINT, u2 BOOLEAN)" type = duckdb.typing.DuckDBPyType(Optional[list[int]]) - assert type == 'BIGINT[]' + assert type == "BIGINT[]" type = duckdb.typing.DuckDBPyType(Optional[dict[int, str]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[Union[Optional[str], Optional[bool]]]) - assert type == 'UNION(u1 VARCHAR, u2 BOOLEAN)' + assert type == "UNION(u1 VARCHAR, u2 BOOLEAN)" type = duckdb.typing.DuckDBPyType(Union[str, None]) - assert type == 'VARCHAR' + assert type == "VARCHAR" @pytest.mark.skipif(sys.version_info < (3, 10), reason="'str | None' syntax requires Python 3.10 or higher") def test_optional_310(self): type = duckdb.typing.DuckDBPyType(str | None) - assert type == 'VARCHAR' + assert type == "VARCHAR" def test_children_attribute(self): - assert DuckDBPyType('INTEGER[]').children == [('child', DuckDBPyType('INTEGER'))] - assert DuckDBPyType('INTEGER[2]').children == [('child', DuckDBPyType('INTEGER')), ('size', 2)] - assert DuckDBPyType('INTEGER[2][3]').children == [('child', DuckDBPyType('INTEGER[2]')), ('size', 3)] - assert DuckDBPyType("ENUM('a', 'b', 'c')").children == [('values', ['a', 'b', 'c'])] + assert DuckDBPyType("INTEGER[]").children == [("child", DuckDBPyType("INTEGER"))] + assert DuckDBPyType("INTEGER[2]").children == [("child", DuckDBPyType("INTEGER")), ("size", 2)] + assert DuckDBPyType("INTEGER[2][3]").children == [("child", DuckDBPyType("INTEGER[2]")), ("size", 3)] + assert DuckDBPyType("ENUM('a', 'b', 'c')").children == [("values", ["a", "b", "c"])] diff --git a/tests/fast/test_type_explicit.py b/tests/fast/test_type_explicit.py index 23dcddc3..7b0797e6 100644 --- a/tests/fast/test_type_explicit.py +++ b/tests/fast/test_type_explicit.py @@ -2,19 +2,18 @@ class TestMap(object): - def test_array_list_tuple_ambiguity(self): con = duckdb.connect() - res = con.sql("SELECT $arg", params={'arg': (1, 2)}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": (1, 2)}).fetchall()[0][0] assert res == [1, 2] # By using an explicit duckdb.Value with an array type, we should convert the input as an array # and get an array (tuple) back typ = duckdb.array_type(duckdb.typing.BIGINT, 2) val = duckdb.Value((1, 2), typ) - res = con.sql("SELECT $arg", params={'arg': val}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": val}).fetchall()[0][0] assert res == (1, 2) val = duckdb.Value([3, 4], typ) 
- res = con.sql("SELECT $arg", params={'arg': val}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": val}).fetchall()[0][0] assert res == (3, 4) diff --git a/tests/fast/test_unicode.py b/tests/fast/test_unicode.py index b697f84a..7d08ac88 100644 --- a/tests/fast/test_unicode.py +++ b/tests/fast/test_unicode.py @@ -7,7 +7,7 @@ class TestUnicode(object): def test_unicode_pandas_scan(self, duckdb_cursor): - con = duckdb.connect(database=':memory:', read_only=False) - test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", u"ë"]}) - con.register('test_df_view', test_df) - con.execute('SELECT i, j, LENGTH(j) FROM test_df_view').fetchall() + con = duckdb.connect(database=":memory:", read_only=False) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", "ë"]}) + con.register("test_df_view", test_df) + con.execute("SELECT i, j, LENGTH(j) FROM test_df_view").fetchall() diff --git a/tests/fast/test_value.py b/tests/fast/test_value.py index 4f74516c..c17264fd 100644 --- a/tests/fast/test_value.py +++ b/tests/fast/test_value.py @@ -71,7 +71,7 @@ class TestValue(object): # This excludes timezone aware values, as those are a pain to test @pytest.mark.parametrize( - 'item', + "item", [ (BOOLEAN, BooleanValue(True), True), (UTINYINT, UnsignedBinaryValue(129), 129), @@ -88,17 +88,17 @@ class TestValue(object): (DOUBLE, DoubleValue(0.23234234234), 0.23234234234), ( duckdb.decimal_type(12, 8), - DecimalValue(decimal.Decimal('1234.12345678'), 12, 8), - decimal.Decimal('1234.12345678'), + DecimalValue(decimal.Decimal("1234.12345678"), 12, 8), + decimal.Decimal("1234.12345678"), ), - (VARCHAR, StringValue('this is a long string'), 'this is a long string'), + (VARCHAR, StringValue("this is a long string"), "this is a long string"), ( UUID, - UUIDValue(uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), - uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), + UUIDValue(uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")), + uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), ), - (BIT, BitValue(b'010101010101'), '010101010101'), - (BLOB, BlobValue(b'\x00\x00\x00a'), b'\x00\x00\x00a'), + (BIT, BitValue(b"010101010101"), "010101010101"), + (BLOB, BlobValue(b"\x00\x00\x00a"), b"\x00\x00\x00a"), (DATE, DateValue(datetime.date(2000, 5, 4)), datetime.date(2000, 5, 4)), (INTERVAL, IntervalValue(datetime.timedelta(days=5)), datetime.timedelta(days=5)), ( @@ -116,10 +116,10 @@ def test_value_helpers(self, item): expected_value = item[2] con = duckdb.connect() - observed_type = con.execute('select typeof(a) from (select $1) tbl(a)', [value_object]).fetchall()[0][0] + observed_type = con.execute("select typeof(a) from (select $1) tbl(a)", [value_object]).fetchall()[0][0] assert observed_type == str(expected_type) - con.execute('select $1', [value_object]) + con.execute("select $1", [value_object]) result = con.fetchone() result = result[0] assert result == expected_value @@ -129,10 +129,10 @@ def test_float_to_decimal_prevention(self): con = duckdb.connect() with pytest.raises(duckdb.ConversionException, match="Can't losslessly convert"): - con.execute('select $1', [value]).fetchall() + con.execute("select $1", [value]).fetchall() @pytest.mark.parametrize( - 'value', + "value", [ TimestampSecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), TimestampMilisecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), @@ -144,10 +144,10 @@ def test_timestamp_sec_not_supported(self, value): with pytest.raises( duckdb.NotImplementedException, match="Conversion from 'datetime' to type .* is not 
implemented yet" ): - con.execute('select $1', [value]).fetchall() + con.execute("select $1", [value]).fetchall() @pytest.mark.parametrize( - 'target_type,test_value,expected_conversion_success', + "target_type,test_value,expected_conversion_success", [ (TINYINT, 0, True), (TINYINT, 255, False), @@ -187,7 +187,7 @@ def test_numeric_values(self, target_type, test_value, expected_conversion_succe value = Value(test_value, target_type) con = duckdb.connect() - work = lambda: con.execute('select typeof(a) from (select $1) tbl(a)', [value]).fetchall() + work = lambda: con.execute("select typeof(a) from (select $1) tbl(a)", [value]).fetchall() if expected_conversion_success: res = work() diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 7a3c7a68..2ec3f784 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -1,6 +1,7 @@ """ Tests for duckdb_pytooling versioning functionality. """ + import os import unittest @@ -109,26 +110,26 @@ def test_bump_version_exact_tag(self): assert _bump_version("1.2.3", 0, False) == "1.2.3" assert _bump_version("1.2.3.post1", 0, False) == "1.2.3.post1" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_bump_version_with_distance(self): """Test bump_version with distance from tag.""" assert _bump_version("1.2.3", 5, False) == "1.3.0.dev5" - + # Post-release development assert _bump_version("1.2.3.post1", 3, False) == "1.2.3.post2.dev3" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '0'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "0"}) def test_bump_version_release_branch(self): """Test bump_version on bugfix branch.""" assert _bump_version("1.2.3", 5, False) == "1.2.4.dev5" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_bump_version_dirty(self): """Test bump_version with dirty working directory.""" assert _bump_version("1.2.3", 0, True) == "1.3.0.dev0" assert _bump_version("1.2.3.post1", 0, True) == "1.2.3.post2.dev0" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_version_scheme_function(self): """Test the version_scheme function that setuptools_scm calls.""" # Mock setuptools_scm version object @@ -136,7 +137,7 @@ def test_version_scheme_function(self): mock_version.tag = "1.2.3" mock_version.distance = 5 mock_version.dirty = False - + result = version_scheme(mock_version) assert result == "1.3.0.dev5" @@ -149,48 +150,45 @@ def test_bump_version_invalid_format(self): class TestGitOperations(unittest.TestCase): """Test git-related operations (mocked).""" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_success(self, mock_run): """Test successful current version retrieval.""" mock_run.return_value.stdout = "v1.2.3\n" mock_run.return_value.check = True - + result = get_current_version() assert result == "1.2.3" mock_run.assert_called_once_with( - ["git", "describe", "--tags", "--abbrev=0"], - capture_output=True, - text=True, - check=True + ["git", "describe", "--tags", "--abbrev=0"], capture_output=True, text=True, check=True ) - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_with_post_release(self, mock_run): """Test current version retrieval with post-release tag.""" mock_run.return_value.stdout = "v1.2.3-post1\n" mock_run.return_value.check = True - + result = 
get_current_version() assert result == "1.2.3.post1" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_no_tags(self, mock_run): """Test current version retrieval when no tags exist.""" mock_run.side_effect = subprocess.CalledProcessError(1, "git") - + result = get_current_version() assert result is None - @patch('subprocess.run') + @patch("subprocess.run") def test_get_git_describe_success(self, mock_run): """Test successful git describe.""" mock_run.return_value.stdout = "v1.2.3-5-g1234567\n" mock_run.return_value.check = True - + result = get_git_describe() assert result == "v1.2.3-5-g1234567" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_git_describe_no_tags(self, mock_run): """Test git describe when no tags exist.""" mock_run.side_effect = subprocess.CalledProcessError(1, "git") @@ -202,21 +200,21 @@ def test_get_git_describe_no_tags(self, mock_run): class TestEnvironmentVariableHandling(unittest.TestCase): """Test environment variable handling in setuptools_scm integration.""" - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'v1.2.3-5-g1234567'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "v1.2.3-5-g1234567"}) def test_override_git_describe_basic(self): """Test OVERRIDE_GIT_DESCRIBE with basic format.""" forced_version_from_env() # Check that the environment variable was processed - assert 'SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB' in os.environ + assert "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" in os.environ - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'v1.2.3-post1-3-g1234567'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "v1.2.3-post1-3-g1234567"}) def test_override_git_describe_post_release(self): """Test OVERRIDE_GIT_DESCRIBE with post-release format.""" forced_version_from_env() # Check that post-release was converted correctly - assert 'SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB' in os.environ + assert "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" in os.environ - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'invalid-format'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "invalid-format"}) def test_override_git_describe_invalid(self): """Test OVERRIDE_GIT_DESCRIBE with invalid format.""" with pytest.raises(ValueError, match="Invalid git describe override"): diff --git a/tests/fast/test_windows_abs_path.py b/tests/fast/test_windows_abs_path.py index bc9f05ec..4ce8311b 100644 --- a/tests/fast/test_windows_abs_path.py +++ b/tests/fast/test_windows_abs_path.py @@ -6,15 +6,15 @@ class TestWindowsAbsPath(object): def test_windows_path_accent(self): - if os.name != 'nt': + if os.name != "nt": return current_directory = os.getcwd() - test_dir = os.path.join(current_directory, 'tést') + test_dir = os.path.join(current_directory, "tést") if os.path.isdir(test_dir): shutil.rmtree(test_dir) os.mkdir(test_dir) - dbname = 'test.db' + dbname = "test.db" dbpath = os.path.join(test_dir, dbname) con = duckdb.connect(dbpath) con.execute("CREATE OR REPLACE TABLE int AS SELECT * FROM range(10) t(i)") @@ -23,8 +23,8 @@ def test_windows_path_accent(self): del res del con - os.chdir('tést') - dbpath = os.path.join('..', dbpath) + os.chdir("tést") + dbpath = os.path.join("..", dbpath) con = duckdb.connect(dbpath) res = con.execute("SELECT COUNT(*) FROM int").fetchall() assert res[0][0] == 10 @@ -37,13 +37,13 @@ def test_windows_path_accent(self): del res del con - os.chdir('..') + os.chdir("..") def test_windows_abs_path(self): - if os.name != 'nt': + if os.name != "nt": return current_directory = 
os.getcwd() - dbpath = os.path.join(current_directory, 'test.db') + dbpath = os.path.join(current_directory, "test.db") con = duckdb.connect(dbpath) con.execute("CREATE OR REPLACE TABLE int AS SELECT * FROM range(10) t(i)") res = con.execute("SELECT COUNT(*) FROM int").fetchall() @@ -51,7 +51,7 @@ def test_windows_abs_path(self): del res del con - assert dbpath[1] == ':' + assert dbpath[1] == ":" # remove the drive separator and reconnect dbpath = dbpath[2:] con = duckdb.connect(dbpath) @@ -61,7 +61,7 @@ def test_windows_abs_path(self): del con # forward slashes work as well - dbpath = dbpath.replace('\\', '/') + dbpath = dbpath.replace("\\", "/") con = duckdb.connect(dbpath) res = con.execute("SELECT COUNT(*) FROM int").fetchall() assert res[0][0] == 10 diff --git a/tests/fast/types/test_blob.py b/tests/fast/types/test_blob.py index 162859d2..0d331f7f 100644 --- a/tests/fast/types/test_blob.py +++ b/tests/fast/types/test_blob.py @@ -6,8 +6,8 @@ class TestBlob(object): def test_blob(self, duckdb_cursor): duckdb_cursor.execute("SELECT BLOB 'hello'") results = duckdb_cursor.fetchall() - assert results[0][0] == b'hello' + assert results[0][0] == b"hello" duckdb_cursor.execute("SELECT BLOB 'hello' AS a") results = duckdb_cursor.fetchnumpy() - assert results['a'] == numpy.array([b'hello'], dtype=object) + assert results["a"] == numpy.array([b"hello"], dtype=object) diff --git a/tests/fast/types/test_datetime_datetime.py b/tests/fast/types/test_datetime_datetime.py index 08a9953d..2df14b18 100644 --- a/tests/fast/types/test_datetime_datetime.py +++ b/tests/fast/types/test_datetime_datetime.py @@ -4,29 +4,29 @@ def create_query(positive, type): - inf = 'infinity' if positive else '-infinity' + inf = "infinity" if positive else "-infinity" return f""" select '{inf}'::{type} """ class TestDateTimeDateTime(object): - @pytest.mark.parametrize('positive', [True, False]) + @pytest.mark.parametrize("positive", [True, False]) @pytest.mark.parametrize( - 'type', + "type", [ - 'TIMESTAMP', - 'TIMESTAMP_S', - 'TIMESTAMP_MS', - 'TIMESTAMP_NS', - 'TIMESTAMPTZ', - 'TIMESTAMP_US', + "TIMESTAMP", + "TIMESTAMP_S", + "TIMESTAMP_MS", + "TIMESTAMP_NS", + "TIMESTAMPTZ", + "TIMESTAMP_US", ], ) def test_timestamp_infinity(self, positive, type): con = duckdb.connect() - if type in ['TIMESTAMP_S', 'TIMESTAMP_MS', 'TIMESTAMP_NS']: + if type in ["TIMESTAMP_S", "TIMESTAMP_MS", "TIMESTAMP_NS"]: # Infinity (both positive and negative) is not supported for non-usecond timetamps return diff --git a/tests/fast/types/test_decimal.py b/tests/fast/types/test_decimal.py index 30cb13e7..b068056d 100644 --- a/tests/fast/types/test_decimal.py +++ b/tests/fast/types/test_decimal.py @@ -6,21 +6,21 @@ class TestDecimal(object): def test_decimal(self, duckdb_cursor): duckdb_cursor.execute( - 'SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL' + "SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL" ) result = duckdb_cursor.fetchall() assert result == [ - (Decimal('1.2'), Decimal('100.3'), Decimal('320938.4298'), Decimal('49082094824.904820482094'), None) + (Decimal("1.2"), Decimal("100.3"), Decimal("320938.4298"), Decimal("49082094824.904820482094"), None) ] def test_decimal_numpy(self, duckdb_cursor): duckdb_cursor.execute( - 'SELECT 1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d' + "SELECT 
1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d" ) result = duckdb_cursor.fetchnumpy() assert result == { - 'a': numpy.array([1.2]), - 'b': numpy.array([100.3]), - 'c': numpy.array([320938.4298]), - 'd': numpy.array([49082094824.904820482094]), + "a": numpy.array([1.2]), + "b": numpy.array([100.3]), + "c": numpy.array([320938.4298]), + "d": numpy.array([49082094824.904820482094]), } diff --git a/tests/fast/types/test_hugeint.py b/tests/fast/types/test_hugeint.py index f0254380..e9b5016a 100644 --- a/tests/fast/types/test_hugeint.py +++ b/tests/fast/types/test_hugeint.py @@ -4,11 +4,11 @@ class TestHugeint(object): def test_hugeint(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 437894723897234238947043214') + duckdb_cursor.execute("SELECT 437894723897234238947043214") result = duckdb_cursor.fetchall() assert result == [(437894723897234238947043214,)] def test_hugeint_numpy(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 1::HUGEINT AS i') + duckdb_cursor.execute("SELECT 1::HUGEINT AS i") result = duckdb_cursor.fetchnumpy() - assert result == {'i': numpy.array([1.0])} + assert result == {"i": numpy.array([1.0])} diff --git a/tests/fast/types/test_nan.py b/tests/fast/types/test_nan.py index b714ae6c..fe99a990 100644 --- a/tests/fast/types/test_nan.py +++ b/tests/fast/types/test_nan.py @@ -15,34 +15,34 @@ def test_pandas_nan(self, duckdb_cursor): # now create a new column with the current time # (FIXME: we replace the microseconds with 0 for now, because we only support millisecond resolution) current_time = datetime.datetime.now().replace(microsecond=0) - df['datetest'] = current_time + df["datetest"] = current_time # introduce a NaT (Not a Time value) - df.loc[0, 'datetest'] = pandas.NaT + df.loc[0, "datetest"] = pandas.NaT # now pass the DF through duckdb: - conn = duckdb.connect(':memory:') - conn.register('testing_null_values', df) + conn = duckdb.connect(":memory:") + conn.register("testing_null_values", df) # scan the DF and fetch the results normally - results = conn.execute('select * from testing_null_values').fetchall() - assert results[0][0] == 'val1' + results = conn.execute("select * from testing_null_values").fetchall() + assert results[0][0] == "val1" assert results[0][1] == 1.05 assert results[0][2] == None assert results[0][3] == None - assert results[1][0] == 'val3' + assert results[1][0] == "val3" assert results[1][1] == None - assert results[1][2] == 'val3' + assert results[1][2] == "val3" assert results[1][3] == current_time # now fetch the results as numpy: - result_np = conn.execute('select * from testing_null_values').fetchnumpy() - assert result_np['col1'][0] == df['col1'][0] - assert result_np['col1'][1] == df['col1'][1] - assert result_np['col2'][0] == df['col2'][0] + result_np = conn.execute("select * from testing_null_values").fetchnumpy() + assert result_np["col1"][0] == df["col1"][0] + assert result_np["col1"][1] == df["col1"][1] + assert result_np["col2"][0] == df["col2"][0] - assert result_np['col2'].mask[1] - assert result_np['newcol1'].mask[0] - assert result_np['newcol1'][1] == df['newcol1'][1] + assert result_np["col2"].mask[1] + assert result_np["newcol1"].mask[0] + assert result_np["newcol1"][1] == df["newcol1"][1] - result_df = conn.execute('select * from testing_null_values').fetchdf() - assert pandas.isnull(result_df['datetest'][0]) - assert result_df['datetest'][1] == df['datetest'][1] + result_df = conn.execute("select * from testing_null_values").fetchdf() + 
assert pandas.isnull(result_df["datetest"][0]) + assert result_df["datetest"][1] == df["datetest"][1] diff --git a/tests/fast/types/test_nested.py b/tests/fast/types/test_nested.py index e005b3f3..7f777384 100644 --- a/tests/fast/types/test_nested.py +++ b/tests/fast/types/test_nested.py @@ -23,24 +23,24 @@ def test_nested_lists(self, duckdb_cursor): def test_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := 43)").fetchall() - assert result == [({'a': 42, 'b': 43},)] + assert result == [({"a": 42, "b": 43},)] result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := NULL)").fetchall() - assert result == [({'a': 42, 'b': None},)] + assert result == [({"a": 42, "b": None},)] def test_unnamed_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT row('aa','bb') AS x").fetchall() - assert result == [(('aa', 'bb'),)] + assert result == [(("aa", "bb"),)] result = duckdb_cursor.execute("SELECT row('aa',NULL) AS x").fetchall() - assert result == [(('aa', None),)] + assert result == [(("aa", None),)] def test_nested_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, 7))").fetchall() - assert result == [({'a': 42, 'b': [10, 9, 8, 7]},)] + assert result == [({"a": 42, "b": [10, 9, 8, 7]},)] result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, NULL))").fetchall() - assert result == [({'a': 42, 'b': [10, 9, 8, None]},)] + assert result == [({"a": 42, "b": [10, 9, 8, None]},)] def test_map(self, duckdb_cursor): result = duckdb_cursor.execute("select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7))").fetchall() diff --git a/tests/fast/types/test_numpy.py b/tests/fast/types/test_numpy.py index 42ae33a0..40b1a5de 100644 --- a/tests/fast/types/test_numpy.py +++ b/tests/fast/types/test_numpy.py @@ -11,7 +11,7 @@ def test_numpy_datetime64(self, duckdb_cursor): duckdb_con.execute("create table tbl(col TIMESTAMP)") duckdb_con.execute( "insert into tbl VALUES (CAST(? 
AS TIMESTAMP WITHOUT TIME ZONE))", - parameters=[np.datetime64('2022-02-08T06:01:38.761310')], + parameters=[np.datetime64("2022-02-08T06:01:38.761310")], ) assert [(datetime.datetime(2022, 2, 8, 6, 1, 38, 761310),)] == duckdb_con.execute( "select * from tbl" @@ -24,11 +24,11 @@ def test_numpy_datetime_big(self): duckdb_con.execute("INSERT INTO TEST VALUES ('2263-02-28')") res1 = duckdb_con.execute("select * from test").fetchnumpy() - date_value = {'date': np.array(['2263-02-28'], dtype='datetime64[us]')} + date_value = {"date": np.array(["2263-02-28"], dtype="datetime64[us]")} assert res1 == date_value def test_numpy_enum_conversion(self, duckdb_cursor): - arr = np.array(['a', 'b', 'c']) + arr = np.array(["a", "b", "c"]) rel = duckdb_cursor.sql("select * from arr") - res = rel.fetchnumpy()['column0'] + res = rel.fetchnumpy()["column0"] np.testing.assert_equal(res, arr) diff --git a/tests/fast/types/test_object_int.py b/tests/fast/types/test_object_int.py index ce153d49..ed3a8d14 100644 --- a/tests/fast/types/test_object_int.py +++ b/tests/fast/types/test_object_int.py @@ -12,19 +12,19 @@ def test_object_integer(self, duckdb_cursor): pd = pytest.importorskip("pandas") df_in = pd.DataFrame( { - 'int8': pd.Series([None, 1, -1], dtype="Int8"), - 'int16': pd.Series([None, 1, -1], dtype="Int16"), - 'int32': pd.Series([None, 1, -1], dtype="Int32"), - 'int64': pd.Series([None, 1, -1], dtype="Int64"), + "int8": pd.Series([None, 1, -1], dtype="Int8"), + "int16": pd.Series([None, 1, -1], dtype="Int16"), + "int32": pd.Series([None, 1, -1], dtype="Int32"), + "int64": pd.Series([None, 1, -1], dtype="Int64"), } ) - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) df_expected_res = pd.DataFrame( { - 'int8': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int8'), - 'int16': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int16'), - 'int32': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int32'), - 'int64': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int64'), + "int8": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int8"), + "int16": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int16"), + "int32": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int32"), + "int64": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int64"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() @@ -37,22 +37,22 @@ def test_object_uinteger(self, duckdb_cursor): with suppress(TypeError): df_in = pd.DataFrame( { - 'uint8': pd.Series([None, 1, 255], dtype="UInt8"), - 'uint16': pd.Series([None, 1, 65535], dtype="UInt16"), - 'uint32': pd.Series([None, 1, 4294967295], dtype="UInt32"), - 'uint64': pd.Series([None, 1, 18446744073709551615], dtype="UInt64"), + "uint8": pd.Series([None, 1, 255], dtype="UInt8"), + "uint16": pd.Series([None, 1, 65535], dtype="UInt16"), + "uint32": pd.Series([None, 1, 4294967295], dtype="UInt32"), + "uint64": pd.Series([None, 1, 18446744073709551615], dtype="UInt64"), } ) - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) df_expected_res = pd.DataFrame( { - 'uint8': pd.Series(np.ma.masked_array([0, 1, 255], mask=[True, False, False]), dtype='UInt8'), - 'uint16': pd.Series(np.ma.masked_array([0, 1, 
65535], mask=[True, False, False]), dtype='UInt16'), - 'uint32': pd.Series( - np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype='UInt32' + "uint8": pd.Series(np.ma.masked_array([0, 1, 255], mask=[True, False, False]), dtype="UInt8"), + "uint16": pd.Series(np.ma.masked_array([0, 1, 65535], mask=[True, False, False]), dtype="UInt16"), + "uint32": pd.Series( + np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype="UInt32" ), - 'uint64': pd.Series( - np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype='UInt64' + "uint64": pd.Series( + np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype="UInt64" ), } ) @@ -63,20 +63,20 @@ def test_object_uinteger(self, duckdb_cursor): # Unsigned Masked float/double types def test_object_float(self, duckdb_cursor): # Require pandas 1.2.0 >= for this, because Float32|Float64 was not added before this version - pd = pytest.importorskip("pandas", '1.2.0') + pd = pytest.importorskip("pandas", "1.2.0") df_in = pd.DataFrame( { - 'float32': pd.Series([None, 1, 4294967295], dtype="Float32"), - 'float64': pd.Series([None, 1, 18446744073709551615], dtype="Float64"), + "float32": pd.Series([None, 1, 4294967295], dtype="Float32"), + "float64": pd.Series([None, 1, 18446744073709551615], dtype="Float64"), } ) df_expected_res = pd.DataFrame( { - 'float32': pd.Series( - np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype='float32' + "float32": pd.Series( + np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype="float32" ), - 'float64': pd.Series( - np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype='float64' + "float64": pd.Series( + np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype="float64" ), } ) diff --git a/tests/fast/types/test_time_tz.py b/tests/fast/types/test_time_tz.py index 66475df8..eceed79a 100644 --- a/tests/fast/types/test_time_tz.py +++ b/tests/fast/types/test_time_tz.py @@ -11,7 +11,7 @@ class TestTimeTz(object): def test_time_tz(self, duckdb_cursor): df = pandas.DataFrame({"col1": [time(1, 2, 3, tzinfo=timezone.utc)]}) - sql = f'SELECT * FROM df' + sql = f"SELECT * FROM df" duckdb_cursor.execute(sql) diff --git a/tests/fast/types/test_unsigned.py b/tests/fast/types/test_unsigned.py index 6ac50727..a35a2216 100644 --- a/tests/fast/types/test_unsigned.py +++ b/tests/fast/types/test_unsigned.py @@ -1,7 +1,7 @@ class TestUnsigned(object): def test_unsigned(self, duckdb_cursor): - duckdb_cursor.execute('create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)') - duckdb_cursor.execute('insert into unsigned values (1,1,1,1), (null,null,null,null)') - duckdb_cursor.execute('select * from unsigned order by a nulls first') + duckdb_cursor.execute("create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)") + duckdb_cursor.execute("insert into unsigned values (1,1,1,1), (null,null,null,null)") + duckdb_cursor.execute("select * from unsigned order by a nulls first") result = duckdb_cursor.fetchall() assert result == [(None, None, None, None), (1, 1, 1, 1)] diff --git a/tests/fast/udf/test_null_filtering.py b/tests/fast/udf/test_null_filtering.py index 208a9246..fd5b45d0 100644 --- a/tests/fast/udf/test_null_filtering.py +++ b/tests/fast/udf/test_null_filtering.py @@ -2,7 +2,7 @@ import pytest pd = pytest.importorskip("pandas") -pa = pytest.importorskip('pyarrow', '18.0.0') +pa = pytest.importorskip("pyarrow", "18.0.0") from typing import 
Union import pyarrow.compute as pc import uuid @@ -22,11 +22,11 @@ class Candidate(NamedTuple): def layout(index: int): return [ - ['x', 'x', 'y'], - ['x', None, 'y'], - [None, 'y', None], - ['x', None, None], - [None, None, 'y'], + ["x", "x", "y"], + ["x", None, "y"], + [None, "y", None], + ["x", None, None], + [None, None, "y"], [None, None, None], ][index] @@ -36,14 +36,14 @@ def add_variations(data, index: int): data.extend( [ { - 'a': layout(index), - 'b': layout(0), - 'c': layout(0), + "a": layout(index), + "b": layout(0), + "c": layout(0), }, { - 'a': layout(0), - 'b': layout(0), - 'c': layout(index), + "a": layout(0), + "b": layout(0), + "c": layout(index), }, ] ) @@ -83,9 +83,9 @@ def get_types(): 2147483647, ), Candidate(UBIGINT, 18446744073709551615, 9223372036854776000), - Candidate(VARCHAR, 'long_string_test', 'smallstring'), + Candidate(VARCHAR, "long_string_test", "smallstring"), Candidate( - UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), uuid.UUID('ffffffff-ffff-ffff-ffff-000000000000') + UUID, uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), uuid.UUID("ffffffff-ffff-ffff-ffff-000000000000") ), Candidate( FLOAT, @@ -106,8 +106,8 @@ def get_types(): ), Candidate( BLOB, - b'\xf6\x96\xb0\x85', - b'\x85\xb0\x96\xf6', + b"\xf6\x96\xb0\x85", + b"\x85\xb0\x96\xf6", ), Candidate( INTERVAL, @@ -120,24 +120,24 @@ def get_types(): False, ), Candidate( - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}, - {'v1': [5, 4, 3, 2, 1], 'v2': ['non-inlined-string', 'a', 'b', 'c', 'duckdb']}, + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + {"v1": [1, 2, 3], "v2": ["a", "non-inlined string", "duckdb"]}, + {"v1": [5, 4, 3, 2, 1], "v2": ["non-inlined-string", "a", "b", "c", "duckdb"]}, ), - Candidate(duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string'], ['non-inlined-string', 'test']), + Candidate(duckdb.list_type("VARCHAR"), ["the", "duck", "non-inlined string"], ["non-inlined-string", "test"]), ] def construct_query(tuples) -> str: def construct_values_list(row, start_param_idx): parameter_count = len(row) - parameters = [f'${x+start_param_idx}' for x in range(parameter_count)] - parameters = '(' + ', '.join(parameters) + ')' + parameters = [f"${x + start_param_idx}" for x in range(parameter_count)] + parameters = "(" + ", ".join(parameters) + ")" return parameters row_size = len(tuples[0]) values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)] - values_list = ', '.join(values_list) + values_list = ", ".join(values_list) query = f""" select * from (values {values_list}) @@ -154,19 +154,19 @@ def construct_parameters(tuples, dbtype): class TestUDFNullFiltering(object): @pytest.mark.parametrize( - 'table_data', + "table_data", get_table_data(), ) @pytest.mark.parametrize( - 'test_type', + "test_type", get_types(), ) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_null_filtering(self, duckdb_cursor, table_data: dict, test_type: Candidate, udf_type): null_count = sum([1 for x in list(zip(*table_data.values())) if any([y == None for y in x])]) row_count = len(table_data) table_data = { - key: [None if not x else test_type.variant_one if x == 'x' else test_type.variant_two for x in value] + key: [None if not x else test_type.variant_one if x == "x" else test_type.variant_two for x in value] for key, value in table_data.items() } @@ -174,21 +174,21 @@ def test_null_filtering(self, duckdb_cursor, 
table_data: dict, test_type: Candid query = construct_query(tuples) parameters = construct_parameters(tuples, test_type.type) rel = duckdb_cursor.sql(query + " t(a, b, c)", params=parameters) - rel.to_table('tbl') + rel.to_table("tbl") rel.show() def my_func(*args): - if udf_type == 'arrow': + if udf_type == "arrow": my_func.count += len(args[0]) else: my_func.count += 1 return args[0] def create_parameters(table_data, dbtype): - return ", ".join(f'{key}::{dbtype}' for key in list(table_data.keys())) + return ", ".join(f"{key}::{dbtype}" for key in list(table_data.keys())) my_func.count = 0 - duckdb_cursor.create_function('test', my_func, None, test_type.type, type=udf_type) + duckdb_cursor.create_function("test", my_func, None, test_type.type, type=udf_type) query = f"select test({create_parameters(table_data, test_type.type)}) from tbl" result = duckdb_cursor.sql(query).fetchall() @@ -201,7 +201,7 @@ def create_parameters(table_data, dbtype): assert my_func.count == row_count - null_count @pytest.mark.parametrize( - 'table_data', + "table_data", [ [1, 2, 3, 4], [1, 2, None, 4], @@ -211,14 +211,14 @@ def test_nulls_from_default_null_handling_native(self, duckdb_cursor, table_data def returns_null(x): return None - df = pd.DataFrame({'a': table_data}) + df = pd.DataFrame({"a": table_data}) duckdb_cursor.execute("create table tbl as select * from df") - duckdb_cursor.create_function('test', returns_null, [str], int, type='native') - with pytest.raises(duckdb.InvalidInputException, match='The UDF is not expected to return NULL values'): + duckdb_cursor.create_function("test", returns_null, [str], int, type="native") + with pytest.raises(duckdb.InvalidInputException, match="The UDF is not expected to return NULL values"): result = duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() @pytest.mark.parametrize( - 'table_data', + "table_data", [ [1, 2, 3, 4], [1, 2, None, 4], @@ -229,9 +229,9 @@ def returns_null(x): l = x.to_pylist() return pa.array([None for _ in l], type=pa.int64()) - df = pd.DataFrame({'a': table_data}) + df = pd.DataFrame({"a": table_data}) duckdb_cursor.execute("create table tbl as select * from df") - duckdb_cursor.create_function('test', returns_null, [str], int, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='The UDF is not expected to return NULL values'): + duckdb_cursor.create_function("test", returns_null, [str], int, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="The UDF is not expected to return NULL values"): result = duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() print(result) diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index e67045c4..d03fd7e6 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -21,37 +21,37 @@ def test_not_created(self): duckdb.InvalidInputException, match="No function by the name of 'not_a_registered_function' was found in the list of registered functions", ): - con.remove_function('not_a_registered_function') + con.remove_function("not_a_registered_function") def test_double_remove(self): def func(x: int) -> int: return x con = duckdb.connect() - con.create_function('func', func) - con.sql('select func(42)') - con.remove_function('func') + con.create_function("func", func) + con.sql("select func(42)") + con.remove_function("func") with pytest.raises( duckdb.InvalidInputException, match="No function by the name of 'func' was found in the list of registered functions", ): 
- con.remove_function('func') + con.remove_function("func") - with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): - con.sql('select func(42)') + with pytest.raises(duckdb.CatalogException, match="Scalar Function with name func does not exist!"): + con.sql("select func(42)") def test_use_after_remove(self): def func(x: int) -> int: return x con = duckdb.connect() - con.create_function('func', func) - rel = con.sql('select func(42)') - con.remove_function('func') + con.create_function("func", func) + rel = con.sql("select func(42)") + con.remove_function("func") """ Error: Catalog Error: Scalar Function with name func does not exist! """ - with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): + with pytest.raises(duckdb.CatalogException, match="Scalar Function with name func does not exist!"): res = rel.fetchall() def test_use_after_remove_and_recreation(self): @@ -59,18 +59,18 @@ def func(x: str) -> str: return x con = duckdb.connect() - con.create_function('func', func) + con.create_function("func", func) - with pytest.raises(duckdb.BinderException, match='No function matches the given name'): - rel1 = con.sql('select func(42)') + with pytest.raises(duckdb.BinderException, match="No function matches the given name"): + rel1 = con.sql("select func(42)") rel2 = con.sql("select func('test'::VARCHAR)") - con.remove_function('func') + con.remove_function("func") def also_func(x: int) -> int: return x - con.create_function('func', also_func) - with pytest.raises(duckdb.BinderException, match='No function matches the given name'): + con.create_function("func", also_func) + with pytest.raises(duckdb.BinderException, match="No function matches the given name"): res = rel2.fetchall() def test_overwrite_name(self): @@ -79,7 +79,7 @@ def func(x): con = duckdb.connect() # create first version of the function - con.create_function('func', func, [BIGINT], BIGINT) + con.create_function("func", func, [BIGINT], BIGINT) # create relation that uses the function rel1 = con.sql("select func('3')") @@ -91,17 +91,17 @@ def other_func(x): duckdb.NotImplementedException, match="A function by the name of 'func' is already created, creating multiple functions with the same name is not supported yet, please remove it first", ): - con.create_function('func', other_func, [VARCHAR], VARCHAR) + con.create_function("func", other_func, [VARCHAR], VARCHAR) - con.remove_function('func') + con.remove_function("func") with pytest.raises( - duckdb.CatalogException, match='Catalog Error: Scalar Function with name func does not exist!' + duckdb.CatalogException, match="Catalog Error: Scalar Function with name func does not exist!" 
): # Attempted to execute the relation using the 'func' function, but it was deleted rel1.fetchall() - con.create_function('func', other_func, [VARCHAR], VARCHAR) + con.create_function("func", other_func, [VARCHAR], VARCHAR) # create relation that uses the new version rel2 = con.sql("select func('test')") @@ -109,5 +109,5 @@ def other_func(x): res1 = rel1.fetchall() res2 = rel2.fetchall() # This has been converted to string, because the previous version of the function no longer exists - assert res1 == [('3',)] - assert res2 == [('test',)] + assert res1 == [("3",)] + assert res2 == [("test",)] diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index 8e0eb8b1..c156f94b 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -3,7 +3,7 @@ import pytest pd = pytest.importorskip("pandas") -pa = pytest.importorskip('pyarrow', '18.0.0') +pa = pytest.importorskip("pyarrow", "18.0.0") from typing import Union, Any import pyarrow.compute as pc import uuid @@ -25,14 +25,14 @@ def test_base(x): test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__ ) # Add annotations for the return type and 'x' - test_function.__annotations__ = {'return': type, 'x': type} + test_function.__annotations__ = {"return": type, "x": type} return test_function class TestScalarUDF(object): - @pytest.mark.parametrize('function_type', ['native', 'arrow']) + @pytest.mark.parametrize("function_type", ["native", "arrow"]) @pytest.mark.parametrize( - 'test_type', + "test_type", [ (TINYINT, -42), (SMALLINT, -512), @@ -43,21 +43,21 @@ class TestScalarUDF(object): (UINTEGER, 4294967295), (UBIGINT, 18446744073709551615), (HUGEINT, 18446744073709551616), - (VARCHAR, 'long_string_test'), - (UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), + (VARCHAR, "long_string_test"), + (UUID, uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")), (FLOAT, 0.12246409803628922), (DOUBLE, 123142.12312416293784721232344), (DATE, datetime.date(2005, 3, 11)), (TIMESTAMP, datetime.datetime(2009, 2, 13, 11, 5, 53)), (TIME, datetime.time(14, 1, 12)), - (BLOB, b'\xf6\x96\xb0\x85'), + (BLOB, b"\xf6\x96\xb0\x85"), (INTERVAL, datetime.timedelta(days=30969, seconds=999, microseconds=999999)), (BOOLEAN, True), ( - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}, + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + {"v1": [1, 2, 3], "v2": ["a", "non-inlined string", "duckdb"]}, ), - (duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string']), + (duckdb.list_type("VARCHAR"), ["the", "duck", "non-inlined string"]), ], ) def test_type_coverage(self, test_type, function_type): @@ -67,7 +67,7 @@ def test_type_coverage(self, test_type, function_type): test_function = make_annotated_function(type) con = duckdb.connect() - con.create_function('test', test_function, type=function_type) + con.create_function("test", test_function, type=function_type) # Single value res = con.execute(f"select test(?::{str(type)})", [value]).fetchall() assert res[0][0] == value @@ -114,46 +114,46 @@ def test_type_coverage(self, test_type, function_type): # Using 'relation.project' con.execute(f"create table tbl as select ?::{str(type)} as x", [value]) - table_rel = con.table('tbl') - res = table_rel.project('test(x)').fetchall() + table_rel = con.table("tbl") + res = table_rel.project("test(x)").fetchall() assert res[0][0] == value - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + 
@pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_map_coverage(self, udf_type): def no_op(x): return x con = duckdb.connect() - map_type = con.map_type('VARCHAR', 'BIGINT') - con.create_function('test_map', no_op, [map_type], map_type, type=udf_type) + map_type = con.map_type("VARCHAR", "BIGINT") + con.create_function("test_map", no_op, [map_type], map_type, type=udf_type) rel = con.sql("select test_map(map(['non-inlined string', 'test', 'duckdb'], [42, 1337, 123]))") res = rel.fetchall() - assert res == [({'non-inlined string': 42, 'test': 1337, 'duckdb': 123},)] + assert res == [({"non-inlined string": 42, "test": 1337, "duckdb": 123},)] - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_exceptions(self, udf_type): def raises_exception(x): raise AttributeError("error") con = duckdb.connect() - con.create_function('raises', raises_exception, [BIGINT], BIGINT, type=udf_type) + con.create_function("raises", raises_exception, [BIGINT], BIGINT, type=udf_type) with pytest.raises( duckdb.InvalidInputException, - match=' Python exception occurred while executing the UDF: AttributeError: error', + match=" Python exception occurred while executing the UDF: AttributeError: error", ): - res = con.sql('select raises(3)').fetchall() + res = con.sql("select raises(3)").fetchall() - con.remove_function('raises') + con.remove_function("raises") con.create_function( - 'raises', raises_exception, [BIGINT], BIGINT, exception_handling='return_null', type=udf_type + "raises", raises_exception, [BIGINT], BIGINT, exception_handling="return_null", type=udf_type ) - res = con.sql('select raises(3) from range(5)').fetchall() + res = con.sql("select raises(3) from range(5)").fetchall() assert res == [(None,), (None,), (None,), (None,), (None,)] def test_non_callable(self): con = duckdb.connect() with pytest.raises(TypeError): - con.create_function('func', 5, [BIGINT], BIGINT, type='arrow') + con.create_function("func", 5, [BIGINT], BIGINT, type="arrow") class MyCallable: def __init__(self) -> None: @@ -163,22 +163,22 @@ def __call__(self, x) -> Any: return x my_callable = MyCallable() - con.create_function('func', my_callable, [BIGINT], BIGINT, type='arrow') - res = con.sql('select func(5)').fetchall() + con.create_function("func", my_callable, [BIGINT], BIGINT, type="arrow") + res = con.sql("select func(5)").fetchall() assert res == [(5,)] # pyarrow does not support creating an array filled with pd.NA values - @pytest.mark.parametrize('udf_type', ['native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_pd_nan(self, duckdb_type, udf_type): def return_pd_nan(): - if udf_type == 'native': + if udf_type == "native": return pd.NA con = duckdb.connect() - con.create_function('return_pd_nan', return_pd_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type) + con.create_function("return_pd_nan", return_pd_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type) - res = con.sql('select return_pd_nan()').fetchall() + res = con.sql("select return_pd_nan()").fetchall() assert res[0][0] == None def test_side_effects(self): @@ -190,21 +190,21 @@ def count() -> int: count.counter = 0 con = duckdb.connect() - con.create_function('my_counter', count, side_effects=False) - res = con.sql('select my_counter() from range(10)').fetchall() + con.create_function("my_counter", count, 
side_effects=False) + res = con.sql("select my_counter() from range(10)").fetchall() assert res == [(0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,)] count.counter = 0 - con.remove_function('my_counter') - con.create_function('my_counter', count, side_effects=True) - res = con.sql('select my_counter() from range(10)').fetchall() + con.remove_function("my_counter") + con.create_function("my_counter", count, side_effects=True) + res = con.sql("select my_counter() from range(10)").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)] - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_np_nan(self, duckdb_type, udf_type): def return_np_nan(): - if udf_type == 'native': + if udf_type == "native": return np.nan else: import pyarrow as pa @@ -212,18 +212,18 @@ def return_np_nan(): return pa.chunked_array([[np.nan]], type=pa.float64()) con = duckdb.connect() - con.create_function('return_np_nan', return_np_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type) + con.create_function("return_np_nan", return_np_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type) - res = con.sql('select return_np_nan()').fetchall() + res = con.sql("select return_np_nan()").fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): import cmath - if udf_type == 'native': + if udf_type == "native": return cmath.nan else: import pyarrow as pa @@ -232,15 +232,15 @@ def return_math_nan(): con = duckdb.connect() con.create_function( - 'return_math_nan', return_math_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type + "return_math_nan", return_math_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type ) - res = con.sql('select return_math_nan()').fetchall() + res = con.sql("select return_math_nan()").fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) @pytest.mark.parametrize( - 'data_type', + "data_type", [ TINYINT, SMALLINT, @@ -262,13 +262,13 @@ def return_math_nan(): BLOB, INTERVAL, BOOLEAN, - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - duckdb.list_type('VARCHAR'), + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + duckdb.list_type("VARCHAR"), ], ) def test_return_null(self, data_type, udf_type): def return_null(): - if udf_type == 'native': + if udf_type == "native": return None else: import pyarrow as pa @@ -276,8 +276,8 @@ def return_null(): return pa.nulls(1) con = duckdb.connect() - con.create_function('return_null', return_null, None, data_type, null_handling='special', type=udf_type) - rel = con.sql('select return_null() as x') + con.create_function("return_null", return_null, None, data_type, null_handling="special", type=udf_type) + rel = con.sql("select return_null() as x") assert rel.types[0] == data_type assert rel.fetchall()[0][0] == None @@ -286,13 +286,13 @@ def func(x: int) -> int: return x con = duckdb.connect() - rel = con.sql('select 42') + rel = con.sql("select 42") # Using fetchone keeps the result open, with a transaction 
rel.fetchone() - con.create_function('func', func) + con.create_function("func", func) rel.fetchall() - res = con.sql('select func(5)').fetchall() + res = con.sql("select func(5)").fetchall() assert res == [(5,)] diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 5773c474..794ebc35 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -15,35 +15,35 @@ class TestPyArrowUDF(object): def test_basic_use(self): def plus_one(x): - table = pa.lib.Table.from_arrays([x], names=['c0']) + table = pa.lib.Table.from_arrays([x], names=["c0"]) import pandas as pd df = pd.DataFrame(x.to_pandas()) - df['c0'] = df['c0'] + 1 + df["c0"] = df["c0"] + 1 return pa.lib.Table.from_pandas(df) con = duckdb.connect() - con.create_function('plus_one', plus_one, [BIGINT], BIGINT, type='arrow') - assert [(6,)] == con.sql('select plus_one(5)').fetchall() + con.create_function("plus_one", plus_one, [BIGINT], BIGINT, type="arrow") + assert [(6,)] == con.sql("select plus_one(5)").fetchall() - range_table = con.table_function('range', [5000]) - res = con.sql('select plus_one(i) from range_table tbl(i)').fetchall() + range_table = con.table_function("range", [5000]) + res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() assert len(res) == 5000 vector_size = duckdb.__standard_vector_size__ - res = con.sql(f'select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})') + res = con.sql(f"select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})") assert len(res) == (vector_size * 11) # NOTE: This only works up to duckdb.__standard_vector_size__, # because we process up to STANDARD_VECTOR_SIZE tuples at a time def test_sort_table(self): def sort_table(x): - table = pa.lib.Table.from_arrays([x], names=['c0']) + table = pa.lib.Table.from_arrays([x], names=["c0"]) sorted_table = table.sort_by([("c0", "ascending")]) return sorted_table con = duckdb.connect() - con.create_function('sort_table', sort_table, [BIGINT], BIGINT, type='arrow') + con.create_function("sort_table", sort_table, [BIGINT], BIGINT, type="arrow") res = con.sql("select 100-i as original, sort_table(original) from range(100) tbl(i)").fetchall() assert res[0] == (100, 1) @@ -57,7 +57,7 @@ def variable_args(*args): con = duckdb.connect() # This function takes any number of arguments, returning the first column - con.create_function('varargs', variable_args, None, BIGINT, type='arrow') + con.create_function("varargs", variable_args, None, BIGINT, type="arrow") res = con.sql("""select varargs(5, '3', '2', 1, 0.12345)""").fetchall() assert res == [(5,)] @@ -70,7 +70,7 @@ def takes_string(col): con = duckdb.connect() # The return type of the function is set to BIGINT, but it takes a VARCHAR - con.create_function('pyarrow_string_to_num', takes_string, [VARCHAR], BIGINT, type='arrow') + con.create_function("pyarrow_string_to_num", takes_string, [VARCHAR], BIGINT, type="arrow") # Successful conversion res = con.sql("""select pyarrow_string_to_num('5')""").fetchall() @@ -84,14 +84,14 @@ def returns_two_columns(col): import pandas as pd # Return a pyarrow table consisting of two columns - return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5, 4, 3], 'b': ['test', 'quack', 'duckdb']})) + return pa.lib.Table.from_pandas(pd.DataFrame({"a": [5, 4, 3], "b": ["test", "quack", "duckdb"]})) con = duckdb.connect() # Scalar functions only return a single value per tuple - con.create_function('two_columns', 
returns_two_columns, [BIGINT], BIGINT, type='arrow') + con.create_function("two_columns", returns_two_columns, [BIGINT], BIGINT, type="arrow") with pytest.raises( duckdb.InvalidInputException, - match='The returned table from a pyarrow scalar udf should only contain one column, found 2', + match="The returned table from a pyarrow scalar udf should only contain one column, found 2", ): res = con.sql("""select two_columns(5)""").fetchall() @@ -100,35 +100,35 @@ def returns_none(col): return None con = duckdb.connect() - con.create_function('will_crash', returns_none, [BIGINT], BIGINT, type='arrow') + con.create_function("will_crash", returns_none, [BIGINT], BIGINT, type="arrow") with pytest.raises(duckdb.Error, match="""Could not convert the result into an Arrow Table"""): res = con.sql("""select will_crash(5)""").fetchall() def test_empty_result(self): def return_empty(col): # Always returns an empty table - return pa.lib.Table.from_arrays([[]], names=['c0']) + return pa.lib.Table.from_arrays([[]], names=["c0"]) con = duckdb.connect() - con.create_function('empty_result', return_empty, [BIGINT], BIGINT, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='Returned pyarrow table should have 1 tuples, found 0'): + con.create_function("empty_result", return_empty, [BIGINT], BIGINT, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="Returned pyarrow table should have 1 tuples, found 0"): res = con.sql("""select empty_result(5)""").fetchall() def test_excessive_result(self): def return_too_many(col): # Always returns a table consisting of 5 tuples - return pa.lib.Table.from_arrays([[5, 4, 3, 2, 1]], names=['c0']) + return pa.lib.Table.from_arrays([[5, 4, 3, 2, 1]], names=["c0"]) con = duckdb.connect() - con.create_function('too_many_tuples', return_too_many, [BIGINT], BIGINT, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='Returned pyarrow table should have 1 tuples, found 5'): + con.create_function("too_many_tuples", return_too_many, [BIGINT], BIGINT, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="Returned pyarrow table should have 1 tuples, found 5"): res = con.sql("""select too_many_tuples(5)""").fetchall() def test_arrow_side_effects(self, duckdb_cursor): import random as r def random_arrow(x): - if not hasattr(random_arrow, 'data'): + if not hasattr(random_arrow, "data"): random_arrow.data = 0 input = x.to_pylist() @@ -158,17 +158,17 @@ def return_struct(col): ).fetch_arrow_table() con = duckdb.connect() - struct_type = con.struct_type({'a': BIGINT, 'b': VARCHAR, 'c': con.list_type(BIGINT)}) - con.create_function('return_struct', return_struct, [BIGINT], struct_type, type='arrow') + struct_type = con.struct_type({"a": BIGINT, "b": VARCHAR, "c": con.list_type(BIGINT)}) + con.create_function("return_struct", return_struct, [BIGINT], struct_type, type="arrow") res = con.sql("""select return_struct(5)""").fetchall() - assert res == [({'a': 5, 'b': 'test', 'c': [5, 3, 2]},)] + assert res == [({"a": 5, "b": "test", "c": [5, 3, 2]},)] def test_multiple_chunks(self): def return_unmodified(col): return col con = duckdb.connect() - con.create_function('unmodified', return_unmodified, [BIGINT], BIGINT, type='arrow') + con.create_function("unmodified", return_unmodified, [BIGINT], BIGINT, type="arrow") res = con.sql( """ select unmodified(i) from range(5000) tbl(i) @@ -176,19 +176,19 @@ def return_unmodified(col): ).fetchall() assert len(res) == 5000 - assert res == con.sql('select * from range(5000)').fetchall() + 
assert res == con.sql("select * from range(5000)").fetchall() def test_inferred(self): def func(x: int) -> int: import pandas as pd - df = pd.DataFrame({'c0': x}) - df['c0'] = df['c0'] ** 2 + df = pd.DataFrame({"c0": x}) + df["c0"] = df["c0"] ** 2 return pa.lib.Table.from_pandas(df) con = duckdb.connect() - con.create_function('inferred', func, type='arrow') - res = con.sql('select inferred(42)').fetchall() + con.create_function("inferred", func, type="arrow") + res = con.sql("select inferred(42)").fetchall() assert res == [(1764,)] def test_nulls(self): @@ -196,27 +196,27 @@ def return_five(x): import pandas as pd length = len(x) - return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5 for _ in range(length)]})) + return pa.lib.Table.from_pandas(pd.DataFrame({"a": [5 for _ in range(length)]})) con = duckdb.connect() - con.create_function('return_five', return_five, [BIGINT], BIGINT, null_handling='special', type='arrow') - res = con.sql('select return_five(NULL) from range(10)').fetchall() + con.create_function("return_five", return_five, [BIGINT], BIGINT, null_handling="special", type="arrow") + res = con.sql("select return_five(NULL) from range(10)").fetchall() # without 'special' null handling these would all be NULL assert res == [(5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,)] con = duckdb.connect() - con.create_function('return_five', return_five, [BIGINT], BIGINT, null_handling='default', type='arrow') - res = con.sql('select return_five(NULL) from range(10)').fetchall() + con.create_function("return_five", return_five, [BIGINT], BIGINT, null_handling="default", type="arrow") + res = con.sql("select return_five(NULL) from range(10)").fetchall() # Because we didn't specify 'special' null handling, these are all NULL assert res == [(None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,)] def test_struct_with_non_inlined_string(self, duckdb_cursor): def func(data): - return pa.array([{'x': 1, 'y': 'this is not an inlined string'}] * data.length()) + return pa.array([{"x": 1, "y": "this is not an inlined string"}] * data.length()) duckdb_cursor.create_function( name="func", function=func, return_type="STRUCT(x integer, y varchar)", type="arrow", side_effects=False ) res = duckdb_cursor.sql("select func(1).y").fetchone() - assert res == ('this is not an inlined string',) + assert res == ("this is not an inlined string",) diff --git a/tests/fast/udf/test_scalar_native.py b/tests/fast/udf/test_scalar_native.py index df58f6a4..0c5cf927 100644 --- a/tests/fast/udf/test_scalar_native.py +++ b/tests/fast/udf/test_scalar_native.py @@ -11,8 +11,8 @@ def test_default_conn(self): def passthrough(x): return x - duckdb.create_function('default_conn_passthrough', passthrough, [BIGINT], BIGINT) - res = duckdb.sql('select default_conn_passthrough(5)').fetchall() + duckdb.create_function("default_conn_passthrough", passthrough, [BIGINT], BIGINT) + res = duckdb.sql("select default_conn_passthrough(5)").fetchall() assert res == [(5,)] def test_basic_use(self): @@ -22,15 +22,15 @@ def plus_one(x): return x + 1 con = duckdb.connect() - con.create_function('plus_one', plus_one, [BIGINT], BIGINT) - assert [(6,)] == con.sql('select plus_one(5)').fetchall() + con.create_function("plus_one", plus_one, [BIGINT], BIGINT) + assert [(6,)] == con.sql("select plus_one(5)").fetchall() - range_table = con.table_function('range', [5000]) - res = con.sql('select plus_one(i) from range_table tbl(i)').fetchall() + range_table = con.table_function("range", [5000]) + res = 
con.sql("select plus_one(i) from range_table tbl(i)").fetchall() assert len(res) == 5000 vector_size = duckdb.__standard_vector_size__ - res = con.sql(f'select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})') + res = con.sql(f"select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})") assert len(res) == (vector_size * 11) def test_passthrough(self): @@ -38,10 +38,10 @@ def passthrough(x): return x con = duckdb.connect() - con.create_function('passthrough', passthrough, [BIGINT], BIGINT) + con.create_function("passthrough", passthrough, [BIGINT], BIGINT) assert ( - con.sql('select passthrough(i) from range(5000) tbl(i)').fetchall() - == con.sql('select * from range(5000)').fetchall() + con.sql("select passthrough(i) from range(5000) tbl(i)").fetchall() + == con.sql("select * from range(5000)").fetchall() ) def test_execute(self): @@ -49,8 +49,8 @@ def func(x): return x % 2 con = duckdb.connect() - con.create_function('modulo_op', func, [BIGINT], TINYINT) - res = con.execute('select modulo_op(?)', [5]).fetchall() + con.create_function("modulo_op", func, [BIGINT], TINYINT) + res = con.execute("select modulo_op(?)", [5]).fetchall() assert res == [(1,)] def test_cast_output(self): @@ -58,7 +58,7 @@ def takes_string(x): return x con = duckdb.connect() - con.create_function('casts_from_string', takes_string, [VARCHAR], BIGINT) + con.create_function("casts_from_string", takes_string, [VARCHAR], BIGINT) res = con.sql("select casts_from_string('42')").fetchall() assert res == [(42,)] @@ -71,13 +71,13 @@ def concatenate(a: str, b: str): return a + b con = duckdb.connect() - con.create_function('py_concatenate', concatenate, None, VARCHAR) + con.create_function("py_concatenate", concatenate, None, VARCHAR) res = con.sql( """ select py_concatenate('5','3'); """ ).fetchall() - assert res[0][0] == '53' + assert res[0][0] == "53" def test_detected_return_type(self): def add_nums(*args) -> int: @@ -87,7 +87,7 @@ def add_nums(*args) -> int: return sum con = duckdb.connect() - con.create_function('add_nums', add_nums) + con.create_function("add_nums", add_nums) res = con.sql( """ select add_nums(5,3,2,1); @@ -101,20 +101,20 @@ def variable_args(*args): return amount con = duckdb.connect() - con.create_function('varargs', variable_args, None, BIGINT) + con.create_function("varargs", variable_args, None, BIGINT) res = con.sql("""select varargs('5', '3', '2', 1, 0.12345)""").fetchall() assert res == [(5,)] def test_return_incorrectly_typed_object(self): def returns_duckdb() -> int: - return 'duckdb' + return "duckdb" con = duckdb.connect() - con.create_function('fastest_database_in_the_west', returns_duckdb) + con.create_function("fastest_database_in_the_west", returns_duckdb) with pytest.raises( duckdb.InvalidInputException, match="Failed to cast value: Could not convert string 'duckdb' to INT64" ): - res = con.sql('select fastest_database_in_the_west()').fetchall() + res = con.sql("select fastest_database_in_the_west()").fetchall() def test_nulls(self): def five_if_null(x): @@ -123,12 +123,12 @@ def five_if_null(x): return x con = duckdb.connect() - con.create_function('null_test', five_if_null, [BIGINT], BIGINT, null_handling="SPECIAL") - res = con.sql('select null_test(NULL)').fetchall() + con.create_function("null_test", five_if_null, [BIGINT], BIGINT, null_handling="SPECIAL") + res = con.sql("select null_test(NULL)").fetchall() assert res == [(5,)] @pytest.mark.parametrize( - 'pair', + "pair", [ (TINYINT, -129), (TINYINT, 128), @@ 
-159,26 +159,26 @@ def return_overflow(): return overflowing_value con = duckdb.connect() - con.create_function('return_overflow', return_overflow, None, duckdb_type) + con.create_function("return_overflow", return_overflow, None, duckdb_type) with pytest.raises(duckdb.InvalidInputException): - rel = con.sql('select return_overflow()') + rel = con.sql("select return_overflow()") res = rel.fetchall() print(duckdb_type) print(res) def test_structs(self): def add_extra_column(original): - original['a'] = 200 - original['c'] = 0 + original["a"] = 200 + original["c"] = 0 return original con = duckdb.connect() - range_table = con.table_function('range', [5000]) + range_table = con.table_function("range", [5000]) con.create_function( "append_field", add_extra_column, - [duckdb.struct_type({'a': BIGINT, 'b': BIGINT})], - duckdb.struct_type({'a': BIGINT, 'b': BIGINT, 'c': BIGINT}), + [duckdb.struct_type({"a": BIGINT, "b": BIGINT})], + duckdb.struct_type({"a": BIGINT, "b": BIGINT, "c": BIGINT}), ) res = con.sql( @@ -205,17 +205,17 @@ def swap_keys(dict): return result con.create_function( - 'swap_keys', + "swap_keys", swap_keys, - [con.struct_type({'a': BIGINT, 'b': VARCHAR})], - con.struct_type({'a': VARCHAR, 'b': BIGINT}), + [con.struct_type({"a": BIGINT, "b": VARCHAR})], + con.struct_type({"a": VARCHAR, "b": BIGINT}), ) res = con.sql( """ select swap_keys({'a': 42, 'b': 'answer_to_life'}) """ ).fetchall() - assert res == [({'a': 'answer_to_life', 'b': 42},)] + assert res == [({"a": "answer_to_life", "b": 42},)] def test_struct_different_field_order(self, duckdb_cursor): def example(): diff --git a/tests/fast/udf/test_transactionality.py b/tests/fast/udf/test_transactionality.py index 50286e8e..134df663 100644 --- a/tests/fast/udf/test_transactionality.py +++ b/tests/fast/udf/test_transactionality.py @@ -3,7 +3,7 @@ class TestUDFTransactionality(object): - @pytest.mark.xfail(reason='fetchone() does not realize the stream result was closed before completion') + @pytest.mark.xfail(reason="fetchone() does not realize the stream result was closed before completion") def test_type_coverage(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from range(4096)") res = rel.fetchone() @@ -12,7 +12,7 @@ def test_type_coverage(self, duckdb_cursor): def my_func(x: str) -> int: return int(x) - duckdb_cursor.create_function('test', my_func) + duckdb_cursor.create_function("test", my_func) - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res = rel.fetchone() diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py index 40bde07b..b0901ab8 100644 --- a/tests/slow/test_h2oai_arrow.py +++ b/tests/slow/test_h2oai_arrow.py @@ -3,17 +3,17 @@ import math from pytest import mark, fixture, importorskip -read_csv = importorskip('pyarrow.csv').read_csv -requests = importorskip('requests') -requests_adapters = importorskip('requests.adapters') -urllib3_util = importorskip('urllib3.util') -np = importorskip('numpy') +read_csv = importorskip("pyarrow.csv").read_csv +requests = importorskip("requests") +requests_adapters = importorskip("requests.adapters") +urllib3_util = importorskip("urllib3.util") +np = importorskip("numpy") def group_by_q1(con): con.execute("CREATE TABLE ans AS SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1") res = con.execute("SELECT COUNT(*), sum(v1)::varchar AS v1 FROM ans").fetchall() - assert res == [(96, '28498857')] + assert res == [(96, "28498857")] con.execute("DROP TABLE 
ans") @@ -155,7 +155,7 @@ def join_by_q5(con): class TestH2OAIArrow(object): @mark.parametrize( - 'function', + "function", [ group_by_q1, group_by_q2, @@ -169,15 +169,15 @@ class TestH2OAIArrow(object): group_by_q10, ], ) - @mark.parametrize('threads', [1, 4]) - @mark.usefixtures('group_by_data') + @mark.parametrize("threads", [1, 4]) + @mark.usefixtures("group_by_data") def test_group_by(self, threads, function, group_by_data): group_by_data.execute(f"PRAGMA threads={threads}") function(group_by_data) - @mark.parametrize('threads', [1, 4]) + @mark.parametrize("threads", [1, 4]) @mark.parametrize( - 'function', + "function", [ join_by_q1, join_by_q2, @@ -186,7 +186,7 @@ def test_group_by(self, threads, function, group_by_data): join_by_q5, ], ) - @mark.usefixtures('large_data') + @mark.usefixtures("large_data") def test_join(self, threads, function, large_data): large_data.execute(f"PRAGMA threads={threads}") @@ -198,7 +198,7 @@ def arrow_dataset_register(): """Single fixture to download files and register them on the given connection""" session = requests.Session() retries = urllib3_util.Retry( - allowed_methods={'GET'}, # only retry on GETs (all we do) + allowed_methods={"GET"}, # only retry on GETs (all we do) total=None, # disable to make the below take effect redirect=10, # Don't follow more than 10 redirects in a row connect=3, # try 3 times before giving up on connection errors @@ -211,12 +211,12 @@ def arrow_dataset_register(): raise_on_status=True, # raise exception when status error retries are exhausted respect_retry_after_header=True, # respect Retry-After headers ) - session.mount('https://', requests_adapters.HTTPAdapter(max_retries=retries)) + session.mount("https://", requests_adapters.HTTPAdapter(max_retries=retries)) saved_filenames = set() def _register(url, filename, con, tablename): r = session.get(url) - with open(filename, 'wb') as f: + with open(filename, "wb") as f: f.write(r.content) con.register(tablename, read_csv(filename)) saved_filenames.add(filename) @@ -232,26 +232,26 @@ def _register(url, filename, con, tablename): def large_data(arrow_dataset_register): con = duckdb.connect() arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz', - 'J1_1e7_NA_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz", + "J1_1e7_NA_0_0.csv.gz", con, "x", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz', - 'J1_1e7_1e1_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz", + "J1_1e7_1e1_0_0.csv.gz", con, "small", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz', - 'J1_1e7_1e4_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz", + "J1_1e7_1e4_0_0.csv.gz", con, "medium", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz', - 'J1_1e7_1e7_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz", + "J1_1e7_1e7_0_0.csv.gz", con, "big", ) @@ -263,8 +263,8 @@ def large_data(arrow_dataset_register): def group_by_data(arrow_dataset_register): con = duckdb.connect() arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz', - 'G1_1e7_1e2_5_0.csv.gz', + 
"https://github.com/duckdb/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz", + "G1_1e7_1e2_5_0.csv.gz", con, "x", ) diff --git a/tests/stubs/test_stubs.py b/tests/stubs/test_stubs.py index 2f178bcc..c68f7068 100644 --- a/tests/stubs/test_stubs.py +++ b/tests/stubs/test_stubs.py @@ -2,18 +2,18 @@ from mypy import stubtest -MYPY_INI_PATH = os.path.join(os.path.dirname(__file__), 'mypy.ini') +MYPY_INI_PATH = os.path.join(os.path.dirname(__file__), "mypy.ini") def test_generated_stubs(): - skip_stubs_errors = ['pybind11', 'git_revision', 'is inconsistent, metaclass differs'] + skip_stubs_errors = ["pybind11", "git_revision", "is inconsistent, metaclass differs"] - options = stubtest.parse_options(['duckdb', '--mypy-config-file', MYPY_INI_PATH]) + options = stubtest.parse_options(["duckdb", "--mypy-config-file", MYPY_INI_PATH]) stubtest.test_stubs(options) broken_stubs = [ error.get_description() - for error in stubtest.test_module('duckdb') + for error in stubtest.test_module("duckdb") if not any(skip in error.get_description() for skip in skip_stubs_errors) ] From 0ec05720d1bef8e64c7305280c82e7263cca940f Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:28:42 +0200 Subject: [PATCH 012/135] Ruff linter fixes --- adbc_driver_duckdb/dbapi.py | 4 +- duckdb/__init__.py | 234 ++-- duckdb/experimental/spark/__init__.py | 6 +- duckdb/experimental/spark/_globals.py | 3 +- duckdb/experimental/spark/_typing.py | 7 +- duckdb/experimental/spark/conf.py | 3 +- duckdb/experimental/spark/context.py | 4 +- duckdb/experimental/spark/errors/__init__.py | 66 +- .../spark/errors/exceptions/base.py | 91 +- duckdb/experimental/spark/errors/utils.py | 10 +- duckdb/experimental/spark/exception.py | 3 +- duckdb/experimental/spark/sql/__init__.py | 10 +- duckdb/experimental/spark/sql/_typing.py | 2 - duckdb/experimental/spark/sql/catalog.py | 5 +- duckdb/experimental/spark/sql/column.py | 36 +- duckdb/experimental/spark/sql/conf.py | 3 +- duckdb/experimental/spark/sql/dataframe.py | 131 +- duckdb/experimental/spark/sql/functions.py | 1162 ++++++++--------- duckdb/experimental/spark/sql/group.py | 30 +- duckdb/experimental/spark/sql/readwriter.py | 18 +- duckdb/experimental/spark/sql/session.py | 24 +- duckdb/experimental/spark/sql/streaming.py | 2 +- duckdb/experimental/spark/sql/type_utils.py | 48 +- duckdb/experimental/spark/sql/types.py | 150 +-- duckdb/filesystem.py | 8 +- duckdb/functional/__init__.py | 4 +- duckdb/polars_io.py | 33 +- duckdb/query_graph/__main__.py | 4 +- duckdb/typing/__init__.py | 10 +- duckdb/udf.py | 5 +- duckdb/value/constant/__init__.py | 49 +- duckdb_packaging/_versioning.py | 6 +- duckdb_packaging/build_backend.py | 34 +- duckdb_packaging/pypi_cleanup.py | 21 +- duckdb_packaging/setuptools_scm_version.py | 13 +- scripts/generate_connection_methods.py | 6 +- scripts/generate_connection_stubs.py | 6 +- .../generate_connection_wrapper_methods.py | 14 +- scripts/generate_connection_wrapper_stubs.py | 8 +- scripts/generate_import_cache_cpp.py | 3 +- scripts/generate_import_cache_json.py | 8 +- scripts/get_cpp_methods.py | 4 +- sqllogic/conftest.py | 16 +- sqllogic/test_sqllogic.py | 17 +- tests/conftest.py | 19 +- .../test_pandas_categorical_coverage.py | 8 +- tests/extensions/json/test_read_json.py | 10 +- tests/extensions/test_extensions_loading.py | 4 +- tests/extensions/test_httpfs.py | 18 +- tests/fast/adbc/test_adbc.py | 8 +- tests/fast/adbc/test_connection_get_info.py | 5 +- tests/fast/adbc/test_statement_bind.py | 2 +- tests/fast/api/test_3324.py | 3 +- 
tests/fast/api/test_3654.py | 7 +- tests/fast/api/test_3728.py | 2 +- tests/fast/api/test_6315.py | 2 +- tests/fast/api/test_attribute_getter.py | 13 +- tests/fast/api/test_config.py | 15 +- tests/fast/api/test_connection_close.py | 8 +- tests/fast/api/test_connection_interrupt.py | 5 +- tests/fast/api/test_cursor.py | 3 +- tests/fast/api/test_dbapi00.py | 5 +- tests/fast/api/test_dbapi01.py | 3 +- tests/fast/api/test_dbapi04.py | 2 +- tests/fast/api/test_dbapi05.py | 2 +- tests/fast/api/test_dbapi07.py | 7 +- tests/fast/api/test_dbapi08.py | 6 +- tests/fast/api/test_dbapi09.py | 5 +- tests/fast/api/test_dbapi10.py | 7 +- tests/fast/api/test_dbapi11.py | 7 +- tests/fast/api/test_dbapi12.py | 8 +- tests/fast/api/test_dbapi13.py | 5 +- tests/fast/api/test_dbapi_fetch.py | 10 +- tests/fast/api/test_duckdb_connection.py | 51 +- tests/fast/api/test_duckdb_execute.py | 5 +- tests/fast/api/test_duckdb_query.py | 7 +- tests/fast/api/test_explain.py | 3 +- tests/fast/api/test_fsspec.py | 10 +- tests/fast/api/test_insert_into.py | 7 +- tests/fast/api/test_join.py | 5 +- tests/fast/api/test_native_tz.py | 7 +- tests/fast/api/test_query_interrupt.py | 11 +- tests/fast/api/test_query_progress.py | 9 +- tests/fast/api/test_read_csv.py | 12 +- tests/fast/api/test_relation_to_view.py | 3 +- tests/fast/api/test_streaming_result.py | 3 +- tests/fast/api/test_to_csv.py | 15 +- tests/fast/api/test_to_parquet.py | 16 +- .../api/test_with_propagating_exceptions.py | 3 +- tests/fast/arrow/parquet_write_roundtrip.py | 10 +- tests/fast/arrow/test_10795.py | 3 +- tests/fast/arrow/test_12384.py | 6 +- tests/fast/arrow/test_14344.py | 7 +- tests/fast/arrow/test_2426.py | 5 +- tests/fast/arrow/test_5547.py | 5 +- tests/fast/arrow/test_6584.py | 4 +- tests/fast/arrow/test_6796.py | 5 +- tests/fast/arrow/test_7652.py | 7 +- tests/fast/arrow/test_7699.py | 6 +- tests/fast/arrow/test_8522.py | 7 +- tests/fast/arrow/test_9443.py | 3 +- tests/fast/arrow/test_arrow_batch_index.py | 5 +- tests/fast/arrow/test_arrow_binary_view.py | 5 +- tests/fast/arrow/test_arrow_case_sensitive.py | 3 +- tests/fast/arrow/test_arrow_decimal256.py | 8 +- tests/fast/arrow/test_arrow_decimal_32_64.py | 8 +- tests/fast/arrow/test_arrow_extensions.py | 12 +- tests/fast/arrow/test_arrow_fetch.py | 5 +- .../arrow/test_arrow_fetch_recordbatch.py | 5 +- tests/fast/arrow/test_arrow_fixed_binary.py | 2 +- tests/fast/arrow/test_arrow_ipc.py | 3 +- tests/fast/arrow/test_arrow_list.py | 7 +- tests/fast/arrow/test_arrow_offsets.py | 42 +- tests/fast/arrow/test_arrow_pycapsule.py | 7 +- .../arrow/test_arrow_recordbatchreader.py | 6 +- .../fast/arrow/test_arrow_replacement_scan.py | 9 +- .../fast/arrow/test_arrow_run_end_encoding.py | 43 +- tests/fast/arrow/test_arrow_scanner.py | 10 +- tests/fast/arrow/test_arrow_string_view.py | 6 +- tests/fast/arrow/test_arrow_types.py | 5 +- tests/fast/arrow/test_arrow_union.py | 3 +- tests/fast/arrow/test_arrow_version_format.py | 14 +- tests/fast/arrow/test_binary_type.py | 6 +- tests/fast/arrow/test_buffer_size_option.py | 5 +- tests/fast/arrow/test_dataset.py | 6 +- tests/fast/arrow/test_date.py | 8 +- tests/fast/arrow/test_dictionary_arrow.py | 3 +- tests/fast/arrow/test_filter_pushdown.py | 12 +- tests/fast/arrow/test_integration.py | 10 +- tests/fast/arrow/test_interval.py | 9 +- tests/fast/arrow/test_large_offsets.py | 9 +- tests/fast/arrow/test_large_string.py | 6 +- tests/fast/arrow/test_multiple_reads.py | 5 +- tests/fast/arrow/test_nested_arrow.py | 6 +- tests/fast/arrow/test_parallel.py | 7 +- 
tests/fast/arrow/test_polars.py | 9 +- tests/fast/arrow/test_progress.py | 7 +- tests/fast/arrow/test_projection_pushdown.py | 4 +- tests/fast/arrow/test_time.py | 8 +- tests/fast/arrow/test_timestamp_timezone.py | 8 +- tests/fast/arrow/test_timestamps.py | 8 +- tests/fast/arrow/test_tpch.py | 6 +- tests/fast/arrow/test_unregister.py | 10 +- tests/fast/arrow/test_view.py | 4 +- tests/fast/numpy/test_numpy_new_path.py | 8 +- tests/fast/pandas/test_2304.py | 7 +- tests/fast/pandas/test_append_df.py | 7 +- tests/fast/pandas/test_bug2281.py | 9 +- tests/fast/pandas/test_bug5922.py | 7 +- tests/fast/pandas/test_copy_on_write.py | 5 +- .../pandas/test_create_table_from_pandas.py | 9 +- tests/fast/pandas/test_date_as_datetime.py | 5 +- tests/fast/pandas/test_datetime_time.py | 10 +- tests/fast/pandas/test_datetime_timestamp.py | 7 +- tests/fast/pandas/test_df_analyze.py | 9 +- .../fast/pandas/test_df_object_resolution.py | 18 +- tests/fast/pandas/test_df_recursive_nested.py | 11 +- tests/fast/pandas/test_fetch_df_chunk.py | 3 +- tests/fast/pandas/test_fetch_nested.py | 7 +- .../fast/pandas/test_implicit_pandas_scan.py | 7 +- tests/fast/pandas/test_import_cache.py | 5 +- tests/fast/pandas/test_issue_1767.py | 9 +- tests/fast/pandas/test_limit.py | 7 +- tests/fast/pandas/test_pandas_arrow.py | 7 +- tests/fast/pandas/test_pandas_category.py | 7 +- tests/fast/pandas/test_pandas_df_none.py | 6 +- tests/fast/pandas/test_pandas_enum.py | 7 +- tests/fast/pandas/test_pandas_limit.py | 5 +- tests/fast/pandas/test_pandas_na.py | 11 +- tests/fast/pandas/test_pandas_object.py | 9 +- tests/fast/pandas/test_pandas_string.py | 11 +- tests/fast/pandas/test_pandas_timestamp.py | 8 +- tests/fast/pandas/test_pandas_types.py | 14 +- tests/fast/pandas/test_pandas_unregister.py | 13 +- tests/fast/pandas/test_pandas_update.py | 5 +- .../fast/pandas/test_parallel_pandas_scan.py | 13 +- .../pandas/test_partitioned_pandas_scan.py | 10 +- tests/fast/pandas/test_progress_bar.py | 10 +- .../test_pyarrow_projection_pushdown.py | 8 +- tests/fast/pandas/test_same_name.py | 4 +- tests/fast/pandas/test_stride.py | 8 +- tests/fast/pandas/test_timedelta.py | 8 +- tests/fast/pandas/test_timestamp.py | 10 +- tests/fast/relational_api/test_groupings.py | 7 +- tests/fast/relational_api/test_joins.py | 7 +- tests/fast/relational_api/test_pivot.py | 6 +- .../relational_api/test_rapi_aggregations.py | 9 +- tests/fast/relational_api/test_rapi_close.py | 5 +- .../relational_api/test_rapi_description.py | 3 +- .../relational_api/test_rapi_functions.py | 2 +- tests/fast/relational_api/test_rapi_query.py | 14 +- .../fast/relational_api/test_rapi_windows.py | 5 +- .../relational_api/test_table_function.py | 8 +- tests/fast/spark/test_replace_column_value.py | 2 +- tests/fast/spark/test_replace_empty_value.py | 3 +- tests/fast/spark/test_spark_arrow_table.py | 2 - tests/fast/spark/test_spark_catalog.py | 4 +- tests/fast/spark/test_spark_column.py | 12 +- tests/fast/spark/test_spark_dataframe.py | 28 +- tests/fast/spark/test_spark_dataframe_sort.py | 8 +- .../fast/spark/test_spark_drop_duplicates.py | 4 +- tests/fast/spark/test_spark_except.py | 2 - tests/fast/spark/test_spark_filter.py | 22 +- .../spark/test_spark_function_concat_ws.py | 4 +- .../fast/spark/test_spark_functions_array.py | 5 +- .../fast/spark/test_spark_functions_base64.py | 2 +- tests/fast/spark/test_spark_functions_date.py | 5 +- tests/fast/spark/test_spark_functions_expr.py | 2 +- tests/fast/spark/test_spark_functions_hash.py | 2 +- tests/fast/spark/test_spark_functions_hex.py 
| 6 +- tests/fast/spark/test_spark_functions_null.py | 2 +- .../spark/test_spark_functions_numeric.py | 3 +- .../fast/spark/test_spark_functions_string.py | 2 +- tests/fast/spark/test_spark_group_by.py | 48 +- tests/fast/spark/test_spark_intersect.py | 2 - tests/fast/spark/test_spark_join.py | 18 +- tests/fast/spark/test_spark_limit.py | 2 +- tests/fast/spark/test_spark_order_by.py | 15 +- .../fast/spark/test_spark_pandas_dataframe.py | 20 +- tests/fast/spark/test_spark_readcsv.py | 7 +- tests/fast/spark/test_spark_readjson.py | 5 +- tests/fast/spark/test_spark_readparquet.py | 5 +- tests/fast/spark/test_spark_runtime_config.py | 2 +- tests/fast/spark/test_spark_session.py | 7 +- tests/fast/spark/test_spark_to_csv.py | 24 +- tests/fast/spark/test_spark_to_parquet.py | 9 +- tests/fast/spark/test_spark_transform.py | 19 +- tests/fast/spark/test_spark_types.py | 41 +- tests/fast/spark/test_spark_udf.py | 2 +- tests/fast/spark/test_spark_union.py | 9 +- tests/fast/spark/test_spark_union_by_name.py | 17 +- tests/fast/spark/test_spark_with_column.py | 18 +- .../spark/test_spark_with_column_renamed.py | 20 +- tests/fast/spark/test_spark_with_columns.py | 2 +- .../spark/test_spark_with_columns_renamed.py | 5 +- tests/fast/sqlite/test_types.py | 2 +- tests/fast/test_alex_multithread.py | 5 +- tests/fast/test_all_types.py | 19 +- tests/fast/test_ambiguous_prepare.py | 5 +- tests/fast/test_case_alias.py | 13 +- tests/fast/test_context_manager.py | 2 +- tests/fast/test_duckdb_api.py | 3 +- tests/fast/test_expression.py | 25 +- tests/fast/test_filesystem.py | 18 +- tests/fast/test_get_table_names.py | 5 +- tests/fast/test_import_export.py | 9 +- tests/fast/test_insert.py | 10 +- tests/fast/test_json_logging.py | 3 +- tests/fast/test_many_con_same_file.py | 4 +- tests/fast/test_map.py | 13 +- tests/fast/test_metatransaction.py | 2 +- tests/fast/test_multi_statement.py | 6 +- tests/fast/test_multithread.py | 20 +- tests/fast/test_non_default_conn.py | 9 +- tests/fast/test_parameter_list.py | 7 +- tests/fast/test_parquet.py | 10 +- tests/fast/test_pypi_cleanup.py | 17 +- tests/fast/test_pytorch.py | 2 +- tests/fast/test_relation.py | 17 +- tests/fast/test_relation_dependency_leak.py | 8 +- tests/fast/test_replacement_scan.py | 8 +- tests/fast/test_result.py | 8 +- tests/fast/test_runtime_error.py | 7 +- tests/fast/test_sql_expression.py | 5 +- tests/fast/test_string_annotation.py | 8 +- tests/fast/test_tf.py | 2 +- tests/fast/test_transaction.py | 4 +- tests/fast/test_type.py | 56 +- tests/fast/test_type_explicit.py | 2 +- tests/fast/test_unicode.py | 6 +- tests/fast/test_union.py | 4 +- tests/fast/test_value.py | 110 +- tests/fast/test_version.py | 3 +- tests/fast/test_versioning.py | 15 +- tests/fast/test_windows_abs_path.py | 6 +- tests/fast/types/test_blob.py | 3 +- tests/fast/types/test_boolean.py | 4 +- tests/fast/types/test_datetime_date.py | 5 +- tests/fast/types/test_datetime_datetime.py | 6 +- tests/fast/types/test_decimal.py | 6 +- tests/fast/types/test_hugeint.py | 3 +- tests/fast/types/test_nan.py | 8 +- tests/fast/types/test_nested.py | 3 +- tests/fast/types/test_null.py | 3 +- tests/fast/types/test_numeric.py | 4 +- tests/fast/types/test_numpy.py | 9 +- tests/fast/types/test_object_int.py | 11 +- tests/fast/types/test_time_tz.py | 9 +- tests/fast/types/test_unsigned.py | 2 +- tests/fast/udf/test_null_filtering.py | 13 +- tests/fast/udf/test_remove_function.py | 13 +- tests/fast/udf/test_scalar.py | 28 +- tests/fast/udf/test_scalar_arrow.py | 14 +- tests/fast/udf/test_scalar_native.py | 9 +- 
tests/fast/udf/test_transactionality.py | 5 +- tests/slow/test_h2oai_arrow.py | 10 +- 296 files changed, 2147 insertions(+), 2425 deletions(-) diff --git a/adbc_driver_duckdb/dbapi.py b/adbc_driver_duckdb/dbapi.py index 793c4242..7d703713 100644 --- a/adbc_driver_duckdb/dbapi.py +++ b/adbc_driver_duckdb/dbapi.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. -""" -DBAPI 2.0-compatible facade for the ADBC DuckDB driver. +"""DBAPI 2.0-compatible facade for the ADBC DuckDB driver. """ import typing import adbc_driver_manager import adbc_driver_manager.dbapi + import adbc_driver_duckdb __all__ = [ diff --git a/duckdb/__init__.py b/duckdb/__init__.py index bf50be5b..73fcbbd2 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -1,8 +1,10 @@ # Modules +from importlib.metadata import version + +from _duckdb import __version__ as duckdb_version + import duckdb.functional as functional import duckdb.typing as typing -from _duckdb import __version__ as duckdb_version -from importlib.metadata import version # duckdb.__version__ returns the version of the distribution package, i.e. the pypi version __version__ = version("duckdb") @@ -62,25 +64,25 @@ def __repr__(self): # Classes from _duckdb import ( - DuckDBPyRelation, + CaseExpression, + CoalesceOperator, + ColumnExpression, + ConstantExpression, + CSVLineTerminator, + DefaultExpression, DuckDBPyConnection, - Statement, - ExplainType, - StatementType, + DuckDBPyRelation, ExpectedResultType, - CSVLineTerminator, - PythonExceptionHandling, - RenderMode, + ExplainType, Expression, - ConstantExpression, - ColumnExpression, - DefaultExpression, - CoalesceOperator, - LambdaExpression, - StarExpression, FunctionExpression, - CaseExpression, + LambdaExpression, + PythonExceptionHandling, + RenderMode, SQLExpression, + StarExpression, + Statement, + StatementType, ) _exported_symbols.extend( @@ -104,91 +106,85 @@ def __repr__(self): # These are overloaded twice, we define them inside of C++ so pybind can deal with it _exported_symbols.extend(["df", "arrow"]) -from _duckdb import df, arrow - # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_methods.py. # Do not edit this section manually, your changes will be overwritten! 
- # START OF CONNECTION WRAPPER - from _duckdb import ( - cursor, - register_filesystem, - unregister_filesystem, - list_filesystems, - filesystem_is_registered, - create_function, - remove_function, - sqltype, - dtype, - type, + aggregate, + alias, + append, array_type, - list_type, - union_type, - string_type, - enum_type, + arrow, + begin, + checkpoint, + close, + commit, + create_function, + cursor, decimal_type, - struct_type, - row_type, - map_type, + description, + df, + distinct, + dtype, duplicate, + enum_type, execute, executemany, - close, - interrupt, - query_progress, - fetchone, - fetchmany, - fetchall, - fetchnumpy, - fetchdf, + extract_statements, + fetch_arrow_table, fetch_df, - df, fetch_df_chunk, - pl, - fetch_arrow_table, - arrow, fetch_record_batch, - torch, - tf, - begin, - commit, - rollback, - checkpoint, - append, - register, - unregister, - table, - view, - values, - table_function, - read_json, - extract_statements, - sql, - query, - from_query, - read_csv, + fetchall, + fetchdf, + fetchmany, + fetchnumpy, + fetchone, + filesystem_is_registered, + filter, + from_arrow, from_csv_auto, from_df, - from_arrow, - from_parquet, - read_parquet, from_parquet, - read_parquet, + from_query, get_table_names, install_extension, - load_extension, - project, - distinct, - write_csv, - aggregate, - alias, - filter, + interrupt, limit, + list_filesystems, + list_type, + load_extension, + map_type, order, + pl, + project, + query, query_df, - description, + query_progress, + read_csv, + read_json, + read_parquet, + register, + register_filesystem, + remove_function, + rollback, + row_type, rowcount, + sql, + sqltype, + string_type, + struct_type, + table, + table_function, + tf, + torch, + type, + union_type, + unregister, + unregister_filesystem, + values, + view, + write_csv, ) _exported_symbols.extend( @@ -276,17 +272,17 @@ def __repr__(self): # END OF CONNECTION WRAPPER # Enums -from _duckdb import ANALYZE, DEFAULT, RETURN_NULL, STANDARD, COLUMNS, ROWS +from _duckdb import ANALYZE, COLUMNS, DEFAULT, RETURN_NULL, ROWS, STANDARD _exported_symbols.extend(["ANALYZE", "DEFAULT", "RETURN_NULL", "STANDARD"]) # read-only properties from _duckdb import ( - __standard_vector_size__, + __formatted_python_version__, __interactive__, __jupyter__, - __formatted_python_version__, + __standard_vector_size__, apilevel, comment, identifier, @@ -337,35 +333,35 @@ def __repr__(self): # Exceptions from _duckdb import ( - Error, - DataError, + BinderException, + CatalogException, + ConnectionException, + ConstraintException, ConversionException, - OutOfRangeException, - TypeMismatchException, + DataError, + Error, FatalException, + HTTPException, IntegrityError, - ConstraintException, InternalError, InternalException, InterruptException, - NotSupportedError, + InvalidInputException, + InvalidTypeException, + IOException, NotImplementedException, + NotSupportedError, OperationalError, - ConnectionException, - IOException, - HTTPException, OutOfMemoryException, - SerializationException, - TransactionException, + OutOfRangeException, + ParserException, PermissionException, ProgrammingError, - BinderException, - CatalogException, - InvalidInputException, - InvalidTypeException, - ParserException, - SyntaxException, SequenceException, + SerializationException, + SyntaxException, + TransactionException, + TypeMismatchException, Warning, ) @@ -406,34 +402,34 @@ def __repr__(self): # Value from duckdb.value.constant import ( - Value, - NullValue, - BooleanValue, - UnsignedBinaryValue, - 
UnsignedShortValue, - UnsignedIntegerValue, - UnsignedLongValue, BinaryValue, - ShortValue, - IntegerValue, - LongValue, - HugeIntegerValue, - FloatValue, - DoubleValue, - DecimalValue, - StringValue, - UUIDValue, BitValue, BlobValue, + BooleanValue, DateValue, + DecimalValue, + DoubleValue, + FloatValue, + HugeIntegerValue, + IntegerValue, IntervalValue, - TimestampValue, - TimestampSecondValue, + LongValue, + NullValue, + ShortValue, + StringValue, TimestampMilisecondValue, TimestampNanosecondValue, + TimestampSecondValue, TimestampTimeZoneValue, - TimeValue, + TimestampValue, TimeTimeZoneValue, + TimeValue, + UnsignedBinaryValue, + UnsignedIntegerValue, + UnsignedLongValue, + UnsignedShortValue, + UUIDValue, + Value, ) _exported_symbols.extend( diff --git a/duckdb/experimental/spark/__init__.py b/duckdb/experimental/spark/__init__.py index 66895dcb..bdde2ef8 100644 --- a/duckdb/experimental/spark/__init__.py +++ b/duckdb/experimental/spark/__init__.py @@ -1,7 +1,7 @@ -from .sql import SparkSession, DataFrame +from ._globals import _NoValue from .conf import SparkConf from .context import SparkContext -from ._globals import _NoValue from .exception import ContributionsAcceptedError +from .sql import DataFrame, SparkSession -__all__ = ["SparkSession", "DataFrame", "SparkConf", "SparkContext", "ContributionsAcceptedError"] +__all__ = ["ContributionsAcceptedError", "DataFrame", "SparkConf", "SparkContext", "SparkSession"] diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index d6a02326..4bc325f7 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -15,8 +15,7 @@ # limitations under the License. # -""" -Module defining global singleton classes. +"""Module defining global singleton classes. This module raises a RuntimeError if an attempt to reload it is made. In that way the identities of the classes defined here are fixed and will remain so diff --git a/duckdb/experimental/spark/_typing.py b/duckdb/experimental/spark/_typing.py index 251ef695..12d16ced 100644 --- a/duckdb/experimental/spark/_typing.py +++ b/duckdb/experimental/spark/_typing.py @@ -16,10 +16,11 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Callable, Iterable, Sized, TypeVar, Union -from typing_extensions import Literal, Protocol +from collections.abc import Iterable, Sized +from typing import Callable, TypeVar, Union -from numpy import int32, int64, float32, float64, ndarray +from numpy import float32, float64, int32, int64, ndarray +from typing_extensions import Literal, Protocol F = TypeVar("F", bound=Callable) T_co = TypeVar("T_co", covariant=True) diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index 79706781..ea1153b4 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -1,4 +1,5 @@ -from typing import Optional, List, Tuple +from typing import Optional + from duckdb.experimental.spark.exception import ContributionsAcceptedError diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index dd4b016c..9f1b4155 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -1,9 +1,9 @@ from typing import Optional + import duckdb from duckdb import DuckDBPyConnection - -from duckdb.experimental.spark.exception import ContributionsAcceptedError from duckdb.experimental.spark.conf import SparkConf +from duckdb.experimental.spark.exception import ContributionsAcceptedError class SparkContext: diff --git a/duckdb/experimental/spark/errors/__init__.py b/duckdb/experimental/spark/errors/__init__.py index 6aac49d7..2f265d97 100644 --- a/duckdb/experimental/spark/errors/__init__.py +++ b/duckdb/experimental/spark/errors/__init__.py @@ -15,59 +15,57 @@ # limitations under the License. # -""" -PySpark exceptions. +"""PySpark exceptions. """ -from .exceptions.base import ( # noqa: F401 - PySparkException, +from .exceptions.base import ( AnalysisException, - TempTableAlreadyExistsException, - ParseException, - IllegalArgumentException, ArithmeticException, - UnsupportedOperationException, ArrayIndexOutOfBoundsException, DateTimeException, + IllegalArgumentException, NumberFormatException, - StreamingQueryException, - QueryExecutionException, + ParseException, + PySparkAssertionError, + PySparkAttributeError, + PySparkException, + PySparkIndexError, + PySparkNotImplementedError, + PySparkRuntimeError, + PySparkTypeError, + PySparkValueError, PythonException, - UnknownException, + QueryExecutionException, SparkRuntimeException, SparkUpgradeException, - PySparkTypeError, - PySparkValueError, - PySparkIndexError, - PySparkAttributeError, - PySparkRuntimeError, - PySparkAssertionError, - PySparkNotImplementedError, + StreamingQueryException, + TempTableAlreadyExistsException, + UnknownException, + UnsupportedOperationException, ) - __all__ = [ - "PySparkException", "AnalysisException", - "TempTableAlreadyExistsException", - "ParseException", - "IllegalArgumentException", "ArithmeticException", - "UnsupportedOperationException", "ArrayIndexOutOfBoundsException", "DateTimeException", + "IllegalArgumentException", "NumberFormatException", - "StreamingQueryException", - "QueryExecutionException", + "ParseException", + "PySparkAssertionError", + "PySparkAttributeError", + "PySparkException", + "PySparkIndexError", + "PySparkNotImplementedError", + "PySparkRuntimeError", + "PySparkTypeError", + "PySparkValueError", "PythonException", - "UnknownException", + "QueryExecutionException", "SparkRuntimeException", "SparkUpgradeException", - "PySparkTypeError", - "PySparkValueError", - "PySparkIndexError", - "PySparkAttributeError", - "PySparkRuntimeError", - "PySparkAssertionError", - 
"PySparkNotImplementedError", + "StreamingQueryException", + "TempTableAlreadyExistsException", + "UnknownException", + "UnsupportedOperationException", ] diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 48a3ea95..a6f1f940 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -1,11 +1,10 @@ -from typing import Dict, Optional, cast +from typing import Optional, cast from ..utils import ErrorClassesReader class PySparkException(Exception): - """ - Base Exception for handling errors generated from PySpark. + """Base Exception for handling errors generated from PySpark. """ def __init__( @@ -25,7 +24,7 @@ def __init__( if message is None: self.message = self.error_reader.get_error_message( - cast(str, error_class), cast(dict[str, str], message_parameters) + cast("str", error_class), cast("dict[str, str]", message_parameters) ) else: self.message = message @@ -34,12 +33,11 @@ def __init__( self.message_parameters = message_parameters def getErrorClass(self) -> Optional[str]: - """ - Returns an error class as a string. + """Returns an error class as a string. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getMessageParameters` :meth:`PySparkException.getSqlState` @@ -47,12 +45,11 @@ def getErrorClass(self) -> Optional[str]: return self.error_class def getMessageParameters(self) -> Optional[dict[str, str]]: - """ - Returns a message parameters as a dictionary. + """Returns a message parameters as a dictionary. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getErrorClass` :meth:`PySparkException.getSqlState` @@ -60,14 +57,13 @@ def getMessageParameters(self) -> Optional[dict[str, str]]: return self.message_parameters def getSqlState(self) -> None: - """ - Returns an SQLSTATE as a string. + """Returns an SQLSTATE as a string. Errors generated in Python have no SQLSTATE, so it always returns None. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getErrorClass` :meth:`PySparkException.getMessageParameters` @@ -82,138 +78,115 @@ def __str__(self) -> str: class AnalysisException(PySparkException): - """ - Failed to analyze a SQL query plan. + """Failed to analyze a SQL query plan. """ class SessionNotSameException(PySparkException): - """ - Performed the same operation on different SparkSession. + """Performed the same operation on different SparkSession. """ class TempTableAlreadyExistsException(AnalysisException): - """ - Failed to create temp view since it is already exists. + """Failed to create temp view since it is already exists. """ class ParseException(AnalysisException): - """ - Failed to parse a SQL command. + """Failed to parse a SQL command. """ class IllegalArgumentException(PySparkException): - """ - Passed an illegal or inappropriate argument. + """Passed an illegal or inappropriate argument. """ class ArithmeticException(PySparkException): - """ - Arithmetic exception thrown from Spark with an error class. + """Arithmetic exception thrown from Spark with an error class. """ class UnsupportedOperationException(PySparkException): - """ - Unsupported operation exception thrown from Spark with an error class. + """Unsupported operation exception thrown from Spark with an error class. """ class ArrayIndexOutOfBoundsException(PySparkException): - """ - Array index out of bounds exception thrown from Spark with an error class. 
+ """Array index out of bounds exception thrown from Spark with an error class. """ class DateTimeException(PySparkException): - """ - Datetime exception thrown from Spark with an error class. + """Datetime exception thrown from Spark with an error class. """ class NumberFormatException(IllegalArgumentException): - """ - Number format exception thrown from Spark with an error class. + """Number format exception thrown from Spark with an error class. """ class StreamingQueryException(PySparkException): - """ - Exception that stopped a :class:`StreamingQuery`. + """Exception that stopped a :class:`StreamingQuery`. """ class QueryExecutionException(PySparkException): - """ - Failed to execute a query. + """Failed to execute a query. """ class PythonException(PySparkException): - """ - Exceptions thrown from Python workers. + """Exceptions thrown from Python workers. """ class SparkRuntimeException(PySparkException): - """ - Runtime exception thrown from Spark with an error class. + """Runtime exception thrown from Spark with an error class. """ class SparkUpgradeException(PySparkException): - """ - Exception thrown because of Spark upgrade. + """Exception thrown because of Spark upgrade. """ class UnknownException(PySparkException): - """ - None of the above exceptions. + """None of the above exceptions. """ class PySparkValueError(PySparkException, ValueError): - """ - Wrapper class for ValueError to support error classes. + """Wrapper class for ValueError to support error classes. """ class PySparkIndexError(PySparkException, IndexError): - """ - Wrapper class for IndexError to support error classes. + """Wrapper class for IndexError to support error classes. """ class PySparkTypeError(PySparkException, TypeError): - """ - Wrapper class for TypeError to support error classes. + """Wrapper class for TypeError to support error classes. """ class PySparkAttributeError(PySparkException, AttributeError): - """ - Wrapper class for AttributeError to support error classes. + """Wrapper class for AttributeError to support error classes. """ class PySparkRuntimeError(PySparkException, RuntimeError): - """ - Wrapper class for RuntimeError to support error classes. + """Wrapper class for RuntimeError to support error classes. """ class PySparkAssertionError(PySparkException, AssertionError): - """ - Wrapper class for AssertionError to support error classes. + """Wrapper class for AssertionError to support error classes. """ class PySparkNotImplementedError(PySparkException, NotImplementedError): - """ - Wrapper class for NotImplementedError to support error classes. + """Wrapper class for NotImplementedError to support error classes. """ diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index f1b37f75..c8c66896 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -16,22 +16,19 @@ # import re -from typing import Dict from .error_classes import ERROR_CLASSES_MAP class ErrorClassesReader: - """ - A reader to load error information from error_classes.py. + """A reader to load error information from error_classes.py. """ def __init__(self) -> None: self.error_info_map = ERROR_CLASSES_MAP def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: - """ - Returns the completed error message by applying message parameters to the message template. + """Returns the completed error message by applying message parameters to the message template. 
""" message_template = self.get_message_template(error_class) # Verify message parameters. @@ -44,8 +41,7 @@ def get_error_message(self, error_class: str, message_parameters: dict[str, str] return message_template.translate(table).format(**message_parameters) def get_message_template(self, error_class: str) -> str: - """ - Returns the message template for corresponding error class from error_classes.py. + """Returns the message template for corresponding error class from error_classes.py. For example, when given `error_class` is "EXAMPLE_ERROR_CLASS", diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 60495d88..791f7090 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,6 +1,5 @@ class ContributionsAcceptedError(NotImplementedError): - """ - This method is not planned to be implemented, if you would like to implement this method + """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb """ diff --git a/duckdb/experimental/spark/sql/__init__.py b/duckdb/experimental/spark/sql/__init__.py index 2312ee50..9ae09308 100644 --- a/duckdb/experimental/spark/sql/__init__.py +++ b/duckdb/experimental/spark/sql/__init__.py @@ -1,7 +1,7 @@ -from .session import SparkSession -from .readwriter import DataFrameWriter -from .dataframe import DataFrame -from .conf import RuntimeConfig from .catalog import Catalog +from .conf import RuntimeConfig +from .dataframe import DataFrame +from .readwriter import DataFrameWriter +from .session import SparkSession -__all__ = ["SparkSession", "DataFrame", "RuntimeConfig", "DataFrameWriter", "Catalog"] +__all__ = ["Catalog", "DataFrame", "DataFrameWriter", "RuntimeConfig", "SparkSession"] diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index b5a8b079..caf0058c 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -19,9 +19,7 @@ from typing import ( Any, Callable, - List, Optional, - Tuple, TypeVar, Union, ) diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index 3cc96f45..8e510fdf 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -1,4 +1,5 @@ -from typing import List, NamedTuple, Optional +from typing import NamedTuple, Optional + from .session import SparkSession @@ -75,4 +76,4 @@ def setCurrentDatabase(self, dbName: str) -> None: raise NotImplementedError -__all__ = ["Catalog", "Table", "Column", "Function", "Database"] +__all__ = ["Catalog", "Column", "Database", "Function", "Table"] diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index f78b31ae..3a6f6cea 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -1,13 +1,12 @@ -from typing import Union, TYPE_CHECKING, Any, cast, Callable, Tuple -from ..exception import ContributionsAcceptedError +from typing import TYPE_CHECKING, Any, Callable, Union, cast +from ..exception import ContributionsAcceptedError from .types import DataType if TYPE_CHECKING: - from ._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral - -from duckdb import ConstantExpression, ColumnExpression, FunctionExpression, Expression + from ._typing import 
DateTimeLiteral, DecimalLiteral, LiteralType +from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression from duckdb.typing import DuckDBPyType __all__ = ["Column"] @@ -78,8 +77,7 @@ def _( class Column: - """ - A column in a DataFrame. + """A column in a DataFrame. :class:`Column` instances can be created by:: @@ -139,8 +137,7 @@ def __neg__(self) -> "Column": __rpow__ = _bin_op("__rpow__") def __getitem__(self, k: Any) -> "Column": - """ - An expression that gets an item at position ``ordinal`` out of a list, + """An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. .. versionadded:: 1.3.0 @@ -153,13 +150,13 @@ def __getitem__(self, k: Any) -> "Column": k a literal value, or a slice object without step. - Returns + Returns: ------- :class:`Column` Column representing the item got by key out of a dict, or substrings sliced by the given slice object. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) >>> df.select(df.l[slice(1, 3)], df.d["key"]).show() @@ -180,8 +177,7 @@ def __getitem__(self, k: Any) -> "Column": return Column(ColumnExpression(expr_str)) def __getattr__(self, item: Any) -> "Column": - """ - An expression that gets an item at position ``ordinal`` out of a list, + """An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. Parameters @@ -189,12 +185,12 @@ def __getattr__(self, item: Any) -> "Column": item a literal value. - Returns + Returns: ------- :class:`Column` Column representing the item got by key out of a dict. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) >>> df.select(df.d.key).show() @@ -234,10 +230,10 @@ def cast(self, dataType: Union[DataType, str]) -> "Column": def isin(self, *cols: Any) -> "Column": if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list - cols = cast(tuple, cols[0]) + cols = cast("tuple", cols[0]) cols = cast( - tuple, + "tuple", [_get_expr(c) for c in cols], ) return Column(self.expr.isin(*cols)) @@ -247,14 +243,14 @@ def __eq__( # type: ignore[override] self, other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"], ) -> "Column": - """binary function""" + """Binary function""" return Column(self.expr == (_get_expr(other))) def __ne__( # type: ignore[override] self, - other: Any, + other: object, ) -> "Column": - """binary function""" + """Binary function""" return Column(self.expr != (_get_expr(other))) __lt__ = _bin_op("__lt__") diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index 8e30d7ca..8ab9fa38 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ b/duckdb/experimental/spark/sql/conf.py @@ -1,6 +1,7 @@ from typing import Optional, Union -from duckdb.experimental.spark._globals import _NoValueType, _NoValue + from duckdb import DuckDBPyConnection +from duckdb.experimental.spark._globals import _NoValue, _NoValueType class RuntimeConfig: diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 19f5576b..3f32aa32 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,25 +1,22 @@ +import uuid from functools import reduce +from keyword import iskeyword from typing import ( TYPE_CHECKING, Any, Callable, - List, - Dict, Optional, - Tuple, Union, cast, overload, ) -import uuid 
-from keyword import iskeyword import duckdb from duckdb import ColumnExpression, Expression, StarExpression -from ._typing import ColumnOrName -from ..errors import PySparkTypeError, PySparkValueError, PySparkIndexError +from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError from ..exception import ContributionsAcceptedError +from ._typing import ColumnOrName from .column import Column from .readwriter import DataFrameWriter from .type_utils import duckdb_to_spark_schema @@ -29,10 +26,9 @@ import pyarrow as pa from pandas.core.frame import DataFrame as PandasDataFrame - from .group import GroupedData, Grouping + from .group import GroupedData from .session import SparkSession -from ..errors import PySparkValueError from .functions import _to_column_expr, col, lit @@ -51,21 +47,20 @@ def toPandas(self) -> "PandasDataFrame": return self.relation.df() def toArrow(self) -> "pa.Table": - """ - Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. + """Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. This is only available if PyArrow is installed and available. .. versionadded:: 4.0.0 - Notes + Notes: ----- This method should only be used if the resulting PyArrow ``pyarrow.Table`` is expected to be small, as all the data is loaded into the driver's memory. This API is a developer API. - Examples + Examples: -------- >>> df.toArrow() # doctest: +SKIP pyarrow.Table @@ -88,7 +83,7 @@ def createOrReplaceTempView(self, name: str) -> None: name : str Name of the view. - Examples + Examples: -------- Create a local temporary view named 'people'. @@ -144,8 +139,7 @@ def withColumn(self, columnName: str, col: Column) -> "DataFrame": return DataFrame(rel, self.session) def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": - """ - Returns a new :class:`DataFrame` by adding multiple columns or replacing the + """Returns a new :class:`DataFrame` by adding multiple columns or replacing the existing columns that have the same names. The colsMap is a map of column name and column, the column must only refer to attributes @@ -162,12 +156,12 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": colsMap : dict a dict of column name and :class:`Column`. Currently, only a single map is supported. - Returns + Returns: ------- :class:`DataFrame` DataFrame with new or replaced columns. - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) >>> df.withColumns({"age2": df.age + 2, "age3": df.age + 3}).show() @@ -219,8 +213,7 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": return DataFrame(rel, self.session) def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": - """ - Returns a new :class:`DataFrame` by renaming multiple columns. + """Returns a new :class:`DataFrame` by renaming multiple columns. This is a no-op if the schema doesn't contain the given column names. .. versionadded:: 3.4.0 @@ -232,20 +225,20 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": a dict of existing column names and corresponding desired column names. Currently, only a single map is supported. - Returns + Returns: ------- :class:`DataFrame` DataFrame with renamed columns. 
- See Also + See Also: -------- :meth:`withColumnRenamed` - Notes + Notes: ----- Support Spark Connect - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) >>> df = df.withColumns({"age2": df.age + 2, "age3": df.age + 3}) @@ -308,12 +301,12 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) .. versionadded:: 3.3.0 - Returns + Returns: ------- :class:`DataFrame` Transformed DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import col >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"]) @@ -362,12 +355,12 @@ def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: An Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the `cols`. - Returns + Returns: ------- :class:`DataFrame` Sorted DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import desc, asc >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) @@ -499,12 +492,12 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": a :class:`Column` of :class:`types.BooleanType` or a string of SQL expressions. - Returns + Returns: ------- :class:`DataFrame` Filtered DataFrame. - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) @@ -567,7 +560,7 @@ def select(self, *cols) -> "DataFrame": def columns(self) -> list[str]: """Returns all column names as a list. - Examples + Examples: -------- >>> df.columns ['age', 'name'] @@ -607,12 +600,12 @@ def join( ``right``, ``rightouter``, ``right_outer``, ``semi``, ``leftsemi``, ``left_semi``, ``anti``, ``leftanti`` and ``left_anti``. - Returns + Returns: ------- :class:`DataFrame` Joined DataFrame. - Examples + Examples: -------- The following performs a full outer join between ``df1`` and ``df2``. @@ -678,7 +671,6 @@ def join( | Bob| 5| +-----+---+ """ - if on is not None and not isinstance(on, list): on = [on] # type: ignore[assignment] if on is not None and not all([isinstance(x, str) for x in on]): @@ -688,7 +680,7 @@ def join( # & all the Expressions together to form one Expression assert isinstance(on[0], Expression), "on should be Column or list of Column" - on = reduce(lambda x, y: x.__and__(y), cast(list[Expression], on)) + on = reduce(lambda x, y: x.__and__(y), cast("list[Expression]", on)) if on is None and how is None: result = self.relation.join(other.relation) @@ -740,12 +732,12 @@ def crossJoin(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Right side of the cartesian product. - Returns + Returns: ------- :class:`DataFrame` Joined DataFrame. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) @@ -772,12 +764,12 @@ def alias(self, alias: str) -> "DataFrame": alias : str an alias name to be set for the :class:`DataFrame`. - Returns + Returns: ------- :class:`DataFrame` Aliased DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import col, desc >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) @@ -827,12 +819,12 @@ def limit(self, num: int) -> "DataFrame": Number of records to return. Will return this number of records or all records if the DataFrame contains less than this number of records. 
- Returns + Returns: ------- :class:`DataFrame` Subset of the records - Examples + Examples: -------- >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df.limit(1).show() @@ -851,8 +843,7 @@ def limit(self, num: int) -> "DataFrame": return DataFrame(rel, self.session) def __contains__(self, item: str) -> bool: - """ - Check if the :class:`DataFrame` contains a column by the name of `item` + """Check if the :class:`DataFrame` contains a column by the name of `item` """ return item in self.relation @@ -860,7 +851,7 @@ def __contains__(self, item: str) -> bool: def schema(self) -> StructType: """Returns the schema of this :class:`DataFrame` as a :class:`duckdb.experimental.spark.sql.types.StructType`. - Examples + Examples: -------- >>> df.schema StructType([StructField('age', IntegerType(), True), @@ -877,7 +868,7 @@ def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. - Examples + Examples: -------- >>> df.select(df["age"]).collect() [Row(age=2), Row(age=5)] @@ -902,7 +893,7 @@ def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Colum def __getattr__(self, name: str) -> Column: """Returns the :class:`Column` denoted by ``name``. - Examples + Examples: -------- >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] @@ -931,12 +922,12 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] Each element should be a column name (string) or an expression (:class:`Column`) or list of them. - Returns + Returns: ------- :class:`GroupedData` Grouped data by given columns. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] @@ -1008,22 +999,22 @@ def union(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be unioned - Returns + Returns: ------- :class:`DataFrame` - See Also + See Also: -------- DataFrame.unionAll - Notes + Notes: ----- This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by :func:`distinct`. Also as standard in SQL, this function resolves columns by position (not by name). - Examples + Examples: -------- >>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"]) >>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col0"]) @@ -1067,12 +1058,12 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> .. versionadded:: 3.1.0 - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Examples + Examples: -------- The difference between this function and :func:`union` is that this function resolves columns by name (not by position): @@ -1131,16 +1122,16 @@ def intersect(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be combined. - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Notes + Notes: ----- This is equivalent to `INTERSECT` in SQL. 
- Examples + Examples: -------- >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) @@ -1171,12 +1162,12 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be combined. - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) @@ -1208,11 +1199,11 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` The other :class:`DataFrame` to compare to. - Returns + Returns: ------- :class:`DataFrame` - Examples + Examples: -------- >>> df1 = spark.createDataFrame( ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"] @@ -1248,12 +1239,12 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": subset : List of column names, optional List of columns to use for duplicate comparison (default All columns). - Returns + Returns: ------- :class:`DataFrame` DataFrame without duplicates. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame( @@ -1297,12 +1288,12 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": def distinct(self) -> "DataFrame": """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. - Returns + Returns: ------- :class:`DataFrame` DataFrame with distinct records. - Examples + Examples: -------- >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) @@ -1317,12 +1308,12 @@ def distinct(self) -> "DataFrame": def count(self) -> int: """Returns the number of rows in this :class:`DataFrame`. - Returns + Returns: ------- int Number of rows. - Examples + Examples: -------- >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) @@ -1377,16 +1368,16 @@ def cache(self) -> "DataFrame": .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0. - Returns + Returns: ------- :class:`DataFrame` Cached DataFrame. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.cache() diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index dfcf7e2e..501c9503 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, Union, overload, Optional, List, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload from duckdb import ( CaseExpression, @@ -17,14 +17,13 @@ from ..errors import PySparkTypeError from ..exception import ContributionsAcceptedError +from . import types as _types from ._typing import ColumnOrName from .column import Column, _get_expr -from . import types as _types def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: - """ - Invokes n-ary JVM function identified by name + """Invokes n-ary JVM function identified by name and wraps the result with :class:`~pyspark.sql.Column`. 
""" cols = [_to_column_expr(expr) for expr in cols] @@ -36,8 +35,7 @@ def col(column: str): def upper(col: "ColumnOrName") -> Column: - """ - Converts a string expression to upper case. + """Converts a string expression to upper case. .. versionadded:: 1.5.0 @@ -49,12 +47,12 @@ def upper(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` upper case values. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(upper("value")).show() @@ -70,8 +68,7 @@ def upper(col: "ColumnOrName") -> Column: def ucase(str: "ColumnOrName") -> Column: - """ - Returns `str` with all characters changed to uppercase. + """Returns `str` with all characters changed to uppercase. .. versionadded:: 3.5.0 @@ -80,7 +77,7 @@ def ucase(str: "ColumnOrName") -> Column: str : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.ucase(sf.lit("Spark"))).show() @@ -123,12 +120,12 @@ def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["Column column names or :class:`~pyspark.sql.Column`\\s that have the same data type. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of array type. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) >>> df.select(array("age", "age").alias("arr")).collect() @@ -170,7 +167,7 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum .. versionadded:: 1.5.0 - Examples + Examples: -------- >>> df = spark.createDataFrame([("100-200",)], ["str"]) >>> df.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() @@ -186,8 +183,7 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]) -> Column: - """ - Collection function: returns an array containing all the elements in `x` from index `start` + """Collection function: returns an array containing all the elements in `x` from index `start` (array indices start at 1, or from the end if `start` is negative) with the specified `length`. .. versionadded:: 2.4.0 @@ -204,12 +200,12 @@ def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["C length : :class:`~pyspark.sql.Column` or str or int column name, column, or int containing the length of the slice - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of array type. Subset of array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() @@ -224,8 +220,7 @@ def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["C def asc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given column name. + """Returns a sort expression based on the ascending order of the given column name. .. versionadded:: 1.3.0 @@ -237,12 +232,12 @@ def asc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. 
@@ -276,8 +271,7 @@ def asc(col: "ColumnOrName") -> Column: def asc_nulls_first(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given + """Returns a sort expression based on the ascending order of the given column name, and null values return before non-null values. .. versionadded:: 2.4.0 @@ -290,12 +284,12 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(1, "Bob"), (0, None), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_first(df1.name)).show() @@ -312,8 +306,7 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: def asc_nulls_last(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given + """Returns a sort expression based on the ascending order of the given column name, and null values appear after non-null values. .. versionadded:: 2.4.0 @@ -326,12 +319,12 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_last(df1.name)).show() @@ -348,8 +341,7 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: def desc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given column name. + """Returns a sort expression based on the descending order of the given column name. .. versionadded:: 1.3.0 @@ -361,12 +353,12 @@ def desc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -385,8 +377,7 @@ def desc(col: "ColumnOrName") -> Column: def desc_nulls_first(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given + """Returns a sort expression based on the descending order of the given column name, and null values appear before non-null values. .. versionadded:: 2.4.0 @@ -399,12 +390,12 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_first(df1.name)).show() @@ -421,8 +412,7 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: def desc_nulls_last(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given + """Returns a sort expression based on the descending order of the given column name, and null values appear after non-null values. .. versionadded:: 2.4.0 @@ -435,12 +425,12 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_last(df1.name)).show() @@ -457,8 +447,7 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: - """ - Returns the leftmost `len`(`len` can be string type) characters from the string `str`, + """Returns the leftmost `len`(`len` can be string type) characters from the string `str`, if `len` is less or equal than 0 the result is an empty string. .. versionadded:: 3.5.0 @@ -470,7 +459,7 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: len : :class:`~pyspark.sql.Column` or str Input column or strings, the leftmost `len`. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -493,8 +482,7 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: - """ - Returns the rightmost `len`(`len` can be string type) characters from the string `str`, + """Returns the rightmost `len`(`len` can be string type) characters from the string `str`, if `len` is less or equal than 0 the result is an empty string. .. versionadded:: 3.5.0 @@ -506,7 +494,7 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: len : :class:`~pyspark.sql.Column` or str Input column or strings, the rightmost `len`. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -549,12 +537,12 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional .. versionchanged: 3.5.0 Added ``threshold`` argument. - Returns + Returns: ------- :class:`~pyspark.sql.Column` Levenshtein distance as integer value. - Examples + Examples: -------- >>> df0 = spark.createDataFrame( ... [ @@ -581,8 +569,7 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: - """ - Left-pad the string column to width `len` with `pad`. + """Left-pad the string column to width `len` with `pad`. .. versionadded:: 1.5.0 @@ -598,12 +585,12 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: pad : str chars to prepend. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left padded result. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -618,8 +605,7 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: - """ - Right-pad the string column to width `len` with `pad`. + """Right-pad the string column to width `len` with `pad`. .. versionadded:: 1.5.0 @@ -635,12 +621,12 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: pad : str chars to append. - Returns + Returns: ------- :class:`~pyspark.sql.Column` right padded result. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -655,8 +641,7 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: def ascii(col: "ColumnOrName") -> Column: - """ - Computes the numeric value of the first character of the string column. + """Computes the numeric value of the first character of the string column. .. versionadded:: 1.5.0 @@ -668,12 +653,12 @@ def ascii(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` numeric value. 
- Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(ascii("value")).show() @@ -689,8 +674,7 @@ def ascii(col: "ColumnOrName") -> Column: def asin(col: "ColumnOrName") -> Column: - """ - Computes inverse sine of the input column. + """Computes inverse sine of the input column. .. versionadded:: 1.4.0 @@ -702,12 +686,12 @@ def asin(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse sine of `col`, as if computed by `java.lang.Math.asin()` - Examples + Examples: -------- >>> df = spark.createDataFrame([(0,), (2,)]) >>> df.select(asin(df.schema.fieldNames()[0])).show() @@ -728,8 +712,7 @@ def asin(col: "ColumnOrName") -> Column: def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """ - Returns true if str matches `pattern` with `escape`, + """Returns true if str matches `pattern` with `escape`, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -746,7 +729,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col If an escape character precedes a special symbol or another escape character, the following character is matched literally. It is invalid to escape any other character. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) >>> df.select(like(df.a, df.b).alias("r")).collect() @@ -766,8 +749,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """ - Returns true if str matches `pattern` with `escape` case-insensitively, + """Returns true if str matches `pattern` with `escape` case-insensitively, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -784,7 +766,7 @@ def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Co If an escape character precedes a special symbol or another escape character, the following character is matched literally. It is invalid to escape any other character. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) >>> df.select(ilike(df.a, df.b).alias("r")).collect() @@ -804,8 +786,7 @@ def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Co def array_agg(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns a list of objects with duplicates. + """Aggregate function: returns a list of objects with duplicates. .. versionadded:: 3.5.0 @@ -814,12 +795,12 @@ def array_agg(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` list of objects with duplicates. - Examples + Examples: -------- >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) >>> df.agg(array_agg("c").alias("r")).collect() @@ -829,15 +810,14 @@ def array_agg(col: "ColumnOrName") -> Column: def collect_list(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns a list of objects with duplicates. + """Aggregate function: returns a list of objects with duplicates. .. versionadded:: 1.6.0 .. versionchanged:: 3.4.0 Supports Spark Connect. 
- Notes + Notes: ----- The function is non-deterministic because the order of collected results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -847,12 +827,12 @@ def collect_list(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` list of objects with duplicates. - Examples + Examples: -------- >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ("age",)) >>> df2.agg(collect_list("age")).collect() @@ -862,8 +842,7 @@ def collect_list(col: "ColumnOrName") -> Column: def array_append(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: returns an array of the elements in col1 along + """Collection function: returns an array of the elements in col1 along with the added element in col2 at the last of the array. .. versionadded:: 3.4.0 @@ -875,16 +854,16 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values from first array along with the element. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")]) @@ -897,8 +876,7 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column: - """ - Collection function: adds an item into a given array at a specified array index. + """Collection function: adds an item into a given array at a specified array index. Array indices start at 1, or start from the end if index is negative. Index above array size appends the array, or prepends the array if index is negative, with 'null' elements. @@ -915,16 +893,16 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values, including the new specified value - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(["a", "b", "c"], 2, "d"), (["c", "b", "a"], -2, "d")], ["data", "pos", "val"] @@ -991,8 +969,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An def array_contains(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: returns null if the array is null, true if the array contains the + """Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise. Parameters @@ -1002,12 +979,12 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: value : value or column to check for in array - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of Boolean type. - Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ["data"]) >>> df.select(array_contains(df.data, "a")).collect() @@ -1020,8 +997,7 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: def array_distinct(col: "ColumnOrName") -> Column: - """ - Collection function: removes duplicate values from the array. + """Collection function: removes duplicate values from the array. .. 
versionadded:: 2.4.0 @@ -1033,12 +1009,12 @@ def array_distinct(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of unique values. - Examples + Examples: -------- >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ["data"]) >>> df.select(array_distinct(df.data)).collect() @@ -1048,8 +1024,7 @@ def array_distinct(col: "ColumnOrName") -> Column: def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Collection function: returns an array of the elements in the intersection of col1 and col2, + """Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates. .. versionadded:: 2.4.0 @@ -1064,12 +1039,12 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str name of column containing array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values in the intersection of two arrays. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) @@ -1080,8 +1055,7 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Collection function: returns an array of the elements in the union of col1 and col2, + """Collection function: returns an array of the elements in the union of col1 and col2, without duplicates. .. versionadded:: 2.4.0 @@ -1096,12 +1070,12 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str name of column containing array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values in union of two arrays. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) @@ -1112,8 +1086,7 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def array_max(col: "ColumnOrName") -> Column: - """ - Collection function: returns the maximum value of the array. + """Collection function: returns the maximum value of the array. .. versionadded:: 2.4.0 @@ -1125,12 +1098,12 @@ def array_max(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` maximum value of an array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) >>> df.select(array_max(df.data).alias("max")).collect() @@ -1142,8 +1115,7 @@ def array_max(col: "ColumnOrName") -> Column: def array_min(col: "ColumnOrName") -> Column: - """ - Collection function: returns the minimum value of the array. + """Collection function: returns the minimum value of the array. .. versionadded:: 2.4.0 @@ -1155,12 +1127,12 @@ def array_min(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` minimum value of array. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) >>> df.select(array_min(df.data).alias("min")).collect() @@ -1172,8 +1144,7 @@ def array_min(col: "ColumnOrName") -> Column: def avg(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the average of the values in a group. + """Aggregate function: returns the average of the values in a group. .. versionadded:: 1.3.0 @@ -1185,12 +1156,12 @@ def avg(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(avg(col("id"))).show() @@ -1204,8 +1175,7 @@ def avg(col: "ColumnOrName") -> Column: def sum(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the sum of all values in the expression. + """Aggregate function: returns the sum of all values in the expression. .. versionadded:: 1.3.0 @@ -1217,12 +1187,12 @@ def sum(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(sum(df["id"])).show() @@ -1236,8 +1206,7 @@ def sum(col: "ColumnOrName") -> Column: def max(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the maximum value of the expression in a group. + """Aggregate function: returns the maximum value of the expression in a group. .. versionadded:: 1.3.0 @@ -1249,12 +1218,12 @@ def max(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(max(col("id"))).show() @@ -1268,8 +1237,7 @@ def max(col: "ColumnOrName") -> Column: def mean(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the average of the values in a group. + """Aggregate function: returns the average of the values in a group. An alias of :func:`avg`. .. versionadded:: 1.4.0 @@ -1282,12 +1250,12 @@ def mean(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(mean(df.id)).show() @@ -1301,8 +1269,7 @@ def mean(col: "ColumnOrName") -> Column: def median(col: "ColumnOrName") -> Column: - """ - Returns the median of the values in a group. + """Returns the median of the values in a group. .. versionadded:: 3.4.0 @@ -1311,16 +1278,16 @@ def median(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the median of the values in a group. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -1345,8 +1312,7 @@ def median(col: "ColumnOrName") -> Column: def mode(col: "ColumnOrName") -> Column: - """ - Returns the most frequent value in a group. + """Returns the most frequent value in a group. .. versionadded:: 3.4.0 @@ -1355,16 +1321,16 @@ def mode(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` the most frequent value in a group. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -1389,8 +1355,7 @@ def mode(col: "ColumnOrName") -> Column: def min(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the minimum value of the expression in a group. + """Aggregate function: returns the minimum value of the expression in a group. .. versionadded:: 1.3.0 @@ -1402,12 +1367,12 @@ def min(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(min(df.id)).show() @@ -1432,12 +1397,12 @@ def any_value(col: "ColumnOrName") -> Column: ignorenulls : :class:`~pyspark.sql.Column` or bool if first value is null then look for first non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` some value of `col` for a group of rows. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] @@ -1451,8 +1416,7 @@ def any_value(col: "ColumnOrName") -> Column: def count(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the number of items in a group. + """Aggregate function: returns the number of items in a group. .. versionadded:: 1.3.0 @@ -1464,12 +1428,12 @@ def count(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- Count by all columns (start), and by a column that does not count ``None``. @@ -1500,12 +1464,12 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C maximum relative standard deviation allowed (default = 0.05). For rsd < 0.01, it is more efficient to use :func:`count_distinct` - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column of computed results. - Examples + Examples: -------- >>> df = spark.createDataFrame([1, 2, 2, 3], "INT") >>> df.agg(approx_count_distinct("value").alias("distinct_values")).show() @@ -1521,8 +1485,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: - """ - .. versionadded:: 1.3.0 + """.. versionadded:: 1.3.0 .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -1546,8 +1509,7 @@ def transform( col: "ColumnOrName", f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], ) -> Column: - """ - Returns an array of elements after applying a transformation to each element in the input array. + """Returns an array of elements after applying a transformation to each element in the input array. .. versionadded:: 3.1.0 @@ -1571,12 +1533,12 @@ def transform( Python ``UserDefinedFunctions`` are not supported (`SPARK-27052 `__). - Returns + Returns: ------- :class:`~pyspark.sql.Column` a new array of transformed elements. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values")) >>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show() @@ -1599,8 +1561,7 @@ def transform( def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": - """ - Concatenates multiple input string columns together into a single string column, + """Concatenates multiple input string columns together into a single string column, using the given separator. .. versionadded:: 1.5.0 @@ -1615,12 +1576,12 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": cols : :class:`~pyspark.sql.Column` or str list of columns to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string of concatenated words. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) >>> df.select(concat_ws("-", df.s, df.d).alias("s")).collect() @@ -1631,8 +1592,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": def lower(col: "ColumnOrName") -> Column: - """ - Converts a string expression to lower case. + """Converts a string expression to lower case. .. versionadded:: 1.5.0 @@ -1644,12 +1604,12 @@ def lower(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` lower case values. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(lower("value")).show() @@ -1665,8 +1625,7 @@ def lower(col: "ColumnOrName") -> Column: def lcase(str: "ColumnOrName") -> Column: - """ - Returns `str` with all characters changed to lowercase. + """Returns `str` with all characters changed to lowercase. .. versionadded:: 3.5.0 @@ -1675,7 +1634,7 @@ def lcase(str: "ColumnOrName") -> Column: str : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.lcase(sf.lit("Spark"))).show() @@ -1689,8 +1648,7 @@ def lcase(str: "ColumnOrName") -> Column: def ceil(col: "ColumnOrName") -> Column: - """ - Computes the ceiling of the given value. + """Computes the ceiling of the given value. .. versionadded:: 1.4.0 @@ -1702,12 +1660,12 @@ def ceil(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(ceil(lit(-0.1))).show() @@ -1725,8 +1683,7 @@ def ceiling(col: "ColumnOrName") -> Column: def floor(col: "ColumnOrName") -> Column: - """ - Computes the floor of the given value. + """Computes the floor of the given value. .. versionadded:: 1.4.0 @@ -1738,12 +1695,12 @@ def floor(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str column to find floor for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` nearest integer that is less than or equal to given value. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(floor(lit(2.5))).show() @@ -1757,8 +1714,7 @@ def floor(col: "ColumnOrName") -> Column: def abs(col: "ColumnOrName") -> Column: - """ - Computes the absolute value. + """Computes the absolute value. .. versionadded:: 1.3.0 @@ -1770,12 +1726,12 @@ def abs(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(abs(lit(-1))).show() @@ -1801,12 +1757,12 @@ def isnan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` True if value is NaN and False otherwise. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1.0, float("nan")), (float("nan"), 2.0)], ("a", "b")) >>> df.select("a", "b", isnan("a").alias("r1"), isnan(df.b).alias("r2")).show() @@ -1833,12 +1789,12 @@ def isnull(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` True if value is null and False otherwise. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b")) >>> df.select("a", "b", isnull("a").alias("r1"), isnull(df.b).alias("r2")).show() @@ -1853,8 +1809,7 @@ def isnull(col: "ColumnOrName") -> Column: def isnotnull(col: "ColumnOrName") -> Column: - """ - Returns true if `col` is not null, or false otherwise. + """Returns true if `col` is not null, or false otherwise. .. versionadded:: 3.5.0 @@ -1862,7 +1817,7 @@ def isnotnull(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) >>> df.select(isnotnull(df.e).alias("r")).collect() @@ -1872,15 +1827,15 @@ def isnotnull(col: "ColumnOrName") -> Column: def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns same result as the EQUAL(=) operator for non-null operands, + """Returns same result as the EQUAL(=) operator for non-null operands, but returns true if both are null, false if one of the them is null. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -1908,8 +1863,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def flatten(col: "ColumnOrName") -> Column: - """ - Collection function: creates a single array from an array of arrays. + """Collection function: creates a single array from an array of arrays. If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -1923,12 +1877,12 @@ def flatten(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` flattened array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) >>> df.show(truncate=False) @@ -1952,8 +1906,7 @@ def flatten(col: "ColumnOrName") -> Column: def array_compact(col: "ColumnOrName") -> Column: - """ - Collection function: removes null values from the array. + """Collection function: removes null values from the array. .. versionadded:: 3.4.0 @@ -1962,16 +1915,16 @@ def array_compact(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array by excluding the null values. - Notes + Notes: ----- Supports Spark Connect. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) >>> df.select(array_compact(df.data)).collect() @@ -1983,8 +1936,7 @@ def array_compact(col: "ColumnOrName") -> Column: def array_remove(col: "ColumnOrName", element: Any) -> Column: - """ - Collection function: Remove all elements that equal to element from the given array. + """Collection function: Remove all elements that equal to element from the given array. .. versionadded:: 2.4.0 @@ -1998,12 +1950,12 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: element : element to be removed from the array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array excluding given value. - Examples + Examples: -------- >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) >>> df.select(array_remove(df.data, 1)).collect() @@ -2015,8 +1967,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: def last_day(date: "ColumnOrName") -> Column: - """ - Returns the last day of the month which the given date belongs to. + """Returns the last day of the month which the given date belongs to. .. versionadded:: 1.5.0 @@ -2028,12 +1979,12 @@ def last_day(date: "ColumnOrName") -> Column: date : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` last day of the month. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-10",)], ["d"]) >>> df.select(last_day(df.d).alias("date")).collect() @@ -2043,8 +1994,7 @@ def last_day(date: "ColumnOrName") -> Column: def sqrt(col: "ColumnOrName") -> Column: - """ - Computes the square root of the specified float value. + """Computes the square root of the specified float value. .. versionadded:: 1.3.0 @@ -2056,12 +2006,12 @@ def sqrt(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(sqrt(lit(4))).show() @@ -2075,8 +2025,7 @@ def sqrt(col: "ColumnOrName") -> Column: def cbrt(col: "ColumnOrName") -> Column: - """ - Computes the cube-root of the given value. + """Computes the cube-root of the given value. .. versionadded:: 1.4.0 @@ -2088,12 +2037,12 @@ def cbrt(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(cbrt(lit(27))).show() @@ -2107,8 +2056,7 @@ def cbrt(col: "ColumnOrName") -> Column: def char(col: "ColumnOrName") -> Column: - """ - Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the + """Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the result is equivalent to char(col % 256) .. versionadded:: 3.5.0 @@ -2118,7 +2066,7 @@ def char(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.char(sf.lit(65))).show() @@ -2148,12 +2096,12 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate correlation. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` Pearson Correlation Coefficient of these two column values. - Examples + Examples: -------- >>> a = range(20) >>> b = [2 * x for x in range(20)] @@ -2165,8 +2113,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def cot(col: "ColumnOrName") -> Column: - """ - Computes cotangent of the input column. + """Computes cotangent of the input column. .. versionadded:: 3.3.0 @@ -2178,12 +2125,12 @@ def cot(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians. - Returns + Returns: ------- :class:`~pyspark.sql.Column` cotangent of the angle. - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -2198,7 +2145,7 @@ def e() -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.range(1).select(e()).show() +-----------------+ @@ -2211,18 +2158,19 @@ def e() -> Column: def negative(col: "ColumnOrName") -> Column: - """ - Returns the negative value. + """Returns the negative value. .. versionadded:: 3.5.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str column to calculate negative value for. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` negative value. - Examples + + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(3).select(sf.negative("id")).show() @@ -2242,7 +2190,7 @@ def pi() -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.range(1).select(pi()).show() +-----------------+ @@ -2255,8 +2203,7 @@ def pi() -> Column: def positive(col: "ColumnOrName") -> Column: - """ - Returns the value. + """Returns the value. .. versionadded:: 3.5.0 @@ -2265,12 +2212,12 @@ def positive(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str input value column. - Returns + Returns: ------- :class:`~pyspark.sql.Column` value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) >>> df.select(positive("v").alias("p")).show() @@ -2286,8 +2233,7 @@ def positive(col: "ColumnOrName") -> Column: def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """ - Returns the value of the first argument raised to the power of the second argument. + """Returns the value of the first argument raised to the power of the second argument. .. versionadded:: 1.4.0 @@ -2301,12 +2247,12 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) col2 : str, :class:`~pyspark.sql.Column` or float the exponent number. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the base rased to the power the argument. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(pow(lit(3), lit(2))).first() @@ -2316,8 +2262,7 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: - """ - Formats the arguments in printf-style and returns the result as a string column. + """Formats the arguments in printf-style and returns the result as a string column. .. 
versionadded:: 3.5.0 @@ -2328,7 +2273,7 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or :class:`~pyspark.sql.Column`\\s to be used in formatting - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( @@ -2351,8 +2296,7 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: def product(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the product of the values in a group. + """Aggregate function: returns the product of the values in a group. .. versionadded:: 3.2.0 @@ -2364,12 +2308,12 @@ def product(col: "ColumnOrName") -> Column: col : str, :class:`Column` column containing values to be multiplied together - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) >>> prods = df.groupBy("mod3").agg(product("x").alias("product")) @@ -2394,7 +2338,7 @@ def rand(seed: Optional[int] = None) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic in general case. @@ -2403,12 +2347,12 @@ def rand(seed: Optional[int] = None) -> Column: seed : int (default: None) seed value for random generator. - Returns + Returns: ------- :class:`~pyspark.sql.Column` random values. - Examples + Examples: -------- >>> from pyspark.sql import functions as sf >>> spark.range(0, 2, 1, 1).withColumn("rand", sf.rand(seed=42) * 3).show() @@ -2437,12 +2381,12 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` true if `str` matches a Java regex, or false otherwise. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( @@ -2490,12 +2434,12 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the number of times that a Java regex pattern is matched in the string. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) >>> df.select(regexp_count("str", lit(r"\d+")).alias("d")).collect() @@ -2526,12 +2470,12 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: idx : int matched group id. - Returns + Returns: ------- :class:`~pyspark.sql.Column` matched value specified by `idx` group id. - Examples + Examples: -------- >>> df = spark.createDataFrame([("100-200",)], ["str"]) >>> df.select(regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() @@ -2563,12 +2507,12 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona idx : int matched group id. - Returns + Returns: ------- :class:`~pyspark.sql.Column` all strings in the `str` that match a Java regex and corresponding to the regex group index. - Examples + Examples: -------- >>> df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)")).alias("d")).collect() @@ -2599,12 +2543,12 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` true if `str` matches a Java regex, or false otherwise. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( @@ -2652,12 +2596,12 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the substring that matches a Java regex within the string `str`. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) >>> df.select(regexp_substr("str", lit(r"\d+")).alias("d")).collect() @@ -2677,8 +2621,7 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def repeat(col: "ColumnOrName", n: int) -> Column: - """ - Repeats a string column n times, and returns it as a new string column. + """Repeats a string column n times, and returns it as a new string column. .. versionadded:: 1.5.0 @@ -2692,12 +2635,12 @@ def repeat(col: "ColumnOrName", n: int) -> Column: n : int number of times to repeat value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string with repeated values. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("ab",)], @@ -2712,8 +2655,7 @@ def repeat(col: "ColumnOrName", n: int) -> Column: def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None) -> Column: - """ - Generate a sequence of integers from `start` to `stop`, incrementing by `step`. + """Generate a sequence of integers from `start` to `stop`, incrementing by `step`. If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, otherwise -1. @@ -2731,12 +2673,12 @@ def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["Column step : :class:`~pyspark.sql.Column` or str, optional value to add to current to get next element (default is 1) - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of sequence values - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) >>> df1.select(sequence("C1", "C2").alias("r")).collect() @@ -2752,8 +2694,7 @@ def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["Column def sign(col: "ColumnOrName") -> Column: - """ - Computes the signum of the given value. + """Computes the signum of the given value. .. versionadded:: 1.4.0 @@ -2765,12 +2706,12 @@ def sign(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.sign(sf.lit(-5)), sf.sign(sf.lit(6))).show() @@ -2784,8 +2725,7 @@ def sign(col: "ColumnOrName") -> Column: def signum(col: "ColumnOrName") -> Column: - """ - Computes the signum of the given value. + """Computes the signum of the given value. .. versionadded:: 1.4.0 @@ -2797,12 +2737,12 @@ def signum(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. 
- Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.signum(sf.lit(-5)), sf.signum(sf.lit(6))).show() @@ -2816,8 +2756,7 @@ def signum(col: "ColumnOrName") -> Column: def sin(col: "ColumnOrName") -> Column: - """ - Computes sine of the input column. + """Computes sine of the input column. .. versionadded:: 1.4.0 @@ -2829,12 +2768,12 @@ def sin(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sine of the angle, as if computed by `java.lang.Math.sin()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -2845,8 +2784,7 @@ def sin(col: "ColumnOrName") -> Column: def skewness(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the skewness of the values in a group. + """Aggregate function: returns the skewness of the values in a group. .. versionadded:: 1.6.0 @@ -2858,12 +2796,12 @@ def skewness(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` skewness of given column. - Examples + Examples: -------- >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) >>> df.select(skewness(df.c)).first() @@ -2873,8 +2811,7 @@ def skewness(col: "ColumnOrName") -> Column: def encode(col: "ColumnOrName", charset: str) -> Column: - """ - Computes the first argument into a binary from a string using the provided character set + """Computes the first argument into a binary from a string using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). .. versionadded:: 1.5.0 @@ -2889,12 +2826,12 @@ def encode(col: "ColumnOrName", charset: str) -> Column: charset : str charset to use to encode. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcd",)], ["c"]) >>> df.select(encode("c", "UTF-8")).show() @@ -2910,8 +2847,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: - """ - Returns the index (1-based) of the given string (`str`) in the comma-delimited + """Returns the index (1-based) of the given string (`str`) in the comma-delimited list (`strArray`). Returns 0, if the string was not found or if the given string (`str`) contains a comma. @@ -2924,7 +2860,7 @@ def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: str_array : :class:`~pyspark.sql.Column` or str The comma-delimited list. - Examples + Examples: -------- >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ["a", "b"]) >>> df.select(find_in_set(df.a, df.b).alias("r")).collect() @@ -2956,7 +2892,7 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -2968,12 +2904,12 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: ignorenulls : :class:`~pyspark.sql.Column` or str if first value is null then look for first non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` first value of the group. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) >>> df = df.orderBy(df.age) @@ -3011,7 +2947,7 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -3023,12 +2959,12 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: ignorenulls : :class:`~pyspark.sql.Column` or str if last value is null then look for non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` last value of the group. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) >>> df = df.orderBy(df.age.desc()) @@ -3056,8 +2992,7 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: def greatest(*cols: "ColumnOrName") -> Column: - """ - Returns the greatest value of the list of column names, skipping null values. + """Returns the greatest value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 @@ -3070,18 +3005,17 @@ def greatest(*cols: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str columns to check for gratest value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` gratest value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] """ - if len(cols) < 2: raise ValueError("greatest should take at least 2 columns") @@ -3090,8 +3024,7 @@ def greatest(*cols: "ColumnOrName") -> Column: def least(*cols: "ColumnOrName") -> Column: - """ - Returns the least value of the list of column names, skipping null values. + """Returns the least value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 @@ -3104,12 +3037,12 @@ def least(*cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or columns to be compared - Returns + Returns: ------- :class:`~pyspark.sql.Column` least value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() @@ -3123,8 +3056,7 @@ def least(*cols: "ColumnOrName") -> Column: def trim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from left end for the specified string value. + """Trim the spaces from left end for the specified string value. .. versionadded:: 1.5.0 @@ -3136,12 +3068,12 @@ def trim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3157,8 +3089,7 @@ def trim(col: "ColumnOrName") -> Column: def rtrim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from right end for the specified string value. + """Trim the spaces from right end for the specified string value. .. 
versionadded:: 1.5.0 @@ -3170,12 +3101,12 @@ def rtrim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` right trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3191,8 +3122,7 @@ def rtrim(col: "ColumnOrName") -> Column: def ltrim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from left end for the specified string value. + """Trim the spaces from left end for the specified string value. .. versionadded:: 1.5.0 @@ -3204,12 +3134,12 @@ def ltrim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3225,8 +3155,7 @@ def ltrim(col: "ColumnOrName") -> Column: def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: - """ - Remove the leading and trailing `trim` characters from `str`. + """Remove the leading and trailing `trim` characters from `str`. .. versionadded:: 3.5.0 @@ -3237,7 +3166,7 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: trim : :class:`~pyspark.sql.Column` or str The trim string characters to trim, the default value is a single space - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3262,8 +3191,7 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if str ends with suffix. + """Returns a boolean. The value is True if str ends with suffix. Returns NULL if either input expression is NULL. Otherwise, returns False. Both str or suffix must be of STRING or BINARY type. @@ -3276,7 +3204,7 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: suffix : :class:`~pyspark.sql.Column` or str A column of string, the suffix. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3315,8 +3243,7 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if str starts with prefix. + """Returns a boolean. The value is True if str starts with prefix. Returns NULL if either input expression is NULL. Otherwise, returns False. Both str or prefix must be of STRING or BINARY type. @@ -3329,7 +3256,7 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: prefix : :class:`~pyspark.sql.Column` or str A column of string, the prefix. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3382,12 +3309,12 @@ def length(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` length of the value. - Examples + Examples: -------- >>> spark.createDataFrame([("ABC ",)], ["a"]).select(length("a").alias("length")).collect() [Row(length=4)] @@ -3404,11 +3331,13 @@ def coalesce(*cols: "ColumnOrName") -> Column: ---------- cols : :class:`~pyspark.sql.Column` or str list of columns to work on. 
- Returns + + Returns: ------- :class:`~pyspark.sql.Column` value of the first column that is not null. - Examples + + Examples: -------- >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) >>> cDf.show() @@ -3436,20 +3365,19 @@ def coalesce(*cols: "ColumnOrName") -> Column: |NULL| 2| 0.0| +----+----+----------------+ """ - cols = [_to_column_expr(expr) for expr in cols] return Column(CoalesceOperator(*cols)) def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is null, or `col1` otherwise. + """Returns `col2` if `col1` is null, or `col1` otherwise. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3467,13 +3395,11 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] """ - return coalesce(col1, col2) def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is not null, or `col3` otherwise. + """Returns `col2` if `col1` is not null, or `col3` otherwise. .. versionadded:: 3.5.0 @@ -3483,7 +3409,7 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co col2 : :class:`~pyspark.sql.Column` or str col3 : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3510,14 +3436,14 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is null, or `col1` otherwise. + """Returns `col2` if `col1` is null, or `col1` otherwise. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) @@ -3533,8 +3459,7 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns null if `col1` equals to `col2`, or `col1` otherwise. + """Returns null if `col1` equals to `col2`, or `col1` otherwise. .. versionadded:: 3.5.0 @@ -3543,7 +3468,7 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3577,12 +3502,12 @@ def md5(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> spark.createDataFrame([("ABC",)], ["a"]).select(md5("a").alias("hash")).collect() [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] @@ -3608,12 +3533,12 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) >>> df.withColumn("sha2", sha2(df.name, 256)).show(truncate=False) @@ -3624,7 +3549,6 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: |Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961| +-----+----------------------------------------------------------------+ """ - if numBits not in {224, 256, 384, 512, 0}: raise ValueError("numBits should be one of {224, 256, 384, 512, 0}") @@ -3635,18 +3559,17 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: def curdate() -> Column: - """ - Returns the current date at the start of query evaluation as a :class:`DateType` column. + """Returns the current date at the start of query evaluation as a :class:`DateType` column. All calls of current_date within the same query return the same value. .. versionadded:: 3.5.0 - Returns + Returns: ------- :class:`~pyspark.sql.Column` current date. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP @@ -3660,8 +3583,7 @@ def curdate() -> Column: def current_date() -> Column: - """ - Returns the current date at the start of query evaluation as a :class:`DateType` column. + """Returns the current date at the start of query evaluation as a :class:`DateType` column. All calls of current_date within the same query return the same value. .. versionadded:: 1.5.0 @@ -3669,12 +3591,12 @@ def current_date() -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Returns + Returns: ------- :class:`~pyspark.sql.Column` current date. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(current_date()).show() # doctest: +SKIP @@ -3688,17 +3610,16 @@ def current_date() -> Column: def now() -> Column: - """ - Returns the current timestamp at the start of query evaluation. + """Returns the current timestamp at the start of query evaluation. .. versionadded:: 3.5.0 - Returns + Returns: ------- :class:`~pyspark.sql.Column` current timestamp at the start of query evaluation. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(now()).show(truncate=False) # doctest: +SKIP @@ -3712,8 +3633,7 @@ def now() -> Column: def desc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given column name. + """Returns a sort expression based on the descending order of the given column name. .. versionadded:: 1.3.0 @@ -3725,12 +3645,12 @@ def desc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -3749,8 +3669,7 @@ def desc(col: "ColumnOrName") -> Column: def asc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given column name. + """Returns a sort expression based on the ascending order of the given column name. .. versionadded:: 1.3.0 @@ -3762,12 +3681,12 @@ def asc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. 
@@ -3801,8 +3720,7 @@ def asc(col: "ColumnOrName") -> Column: def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: - """ - Returns timestamp truncated to the unit specified by the format. + """Returns timestamp truncated to the unit specified by the format. .. versionadded:: 2.3.0 @@ -3813,7 +3731,7 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: 'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter' timestamp : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 05:02:11",)], ["t"]) >>> df.select(date_trunc("year", df.t).alias("year")).collect() @@ -3834,8 +3752,7 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3847,12 +3764,12 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -3870,8 +3787,7 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3882,12 +3798,12 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -3905,8 +3821,7 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3918,12 +3833,12 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -3941,8 +3856,7 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: - """ - Returns the number of days from `start` to `end`. + """Returns the number of days from `start` to `end`. .. versionadded:: 3.5.0 @@ -3953,12 +3867,12 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: start : :class:`~pyspark.sql.Column` or column name from date column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` difference in days between two dates. 
- See Also + See Also: -------- :meth:`pyspark.sql.functions.dateadd` :meth:`pyspark.sql.functions.date_add` @@ -3966,7 +3880,7 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: :meth:`pyspark.sql.functions.datediff` :meth:`pyspark.sql.functions.timestamp_diff` - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) @@ -3992,8 +3906,7 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: def year(col: "ColumnOrName") -> Column: - """ - Extract the year of a given date/timestamp as integer. + """Extract the year of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4005,12 +3918,12 @@ def year(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` year part of the date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(year("dt").alias("year")).collect() @@ -4020,8 +3933,7 @@ def year(col: "ColumnOrName") -> Column: def quarter(col: "ColumnOrName") -> Column: - """ - Extract the quarter of a given date/timestamp as integer. + """Extract the quarter of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4033,12 +3945,12 @@ def quarter(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` quarter of the date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(quarter("dt").alias("quarter")).collect() @@ -4048,8 +3960,7 @@ def quarter(col: "ColumnOrName") -> Column: def month(col: "ColumnOrName") -> Column: - """ - Extract the month of a given date/timestamp as integer. + """Extract the month of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4061,12 +3972,12 @@ def month(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` month part of the date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(month("dt").alias("month")).collect() @@ -4076,8 +3987,7 @@ def month(col: "ColumnOrName") -> Column: def dayofweek(col: "ColumnOrName") -> Column: - """ - Extract the day of the week of a given date/timestamp as integer. + """Extract the day of the week of a given date/timestamp as integer. Ranges from 1 for a Sunday through to 7 for a Saturday .. versionadded:: 2.3.0 @@ -4090,12 +4000,12 @@ def dayofweek(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the week for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofweek("dt").alias("day")).collect() @@ -4105,8 +4015,7 @@ def dayofweek(col: "ColumnOrName") -> Column: def day(col: "ColumnOrName") -> Column: - """ - Extract the day of the month of a given date/timestamp as integer. + """Extract the day of the month of a given date/timestamp as integer. .. 
versionadded:: 3.5.0 @@ -4115,12 +4024,12 @@ def day(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the month for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(day("dt").alias("day")).collect() @@ -4130,8 +4039,7 @@ def day(col: "ColumnOrName") -> Column: def dayofmonth(col: "ColumnOrName") -> Column: - """ - Extract the day of the month of a given date/timestamp as integer. + """Extract the day of the month of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4143,12 +4051,12 @@ def dayofmonth(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the month for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofmonth("dt").alias("day")).collect() @@ -4158,8 +4066,7 @@ def dayofmonth(col: "ColumnOrName") -> Column: def dayofyear(col: "ColumnOrName") -> Column: - """ - Extract the day of the year of a given date/timestamp as integer. + """Extract the day of the year of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4171,12 +4078,12 @@ def dayofyear(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the year for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofyear("dt").alias("day")).collect() @@ -4186,8 +4093,7 @@ def dayofyear(col: "ColumnOrName") -> Column: def hour(col: "ColumnOrName") -> Column: - """ - Extract the hours of a given timestamp as integer. + """Extract the hours of a given timestamp as integer. .. versionadded:: 1.5.0 @@ -4199,12 +4105,12 @@ def hour(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` hour part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -4215,8 +4121,7 @@ def hour(col: "ColumnOrName") -> Column: def minute(col: "ColumnOrName") -> Column: - """ - Extract the minutes of a given timestamp as integer. + """Extract the minutes of a given timestamp as integer. .. versionadded:: 1.5.0 @@ -4228,12 +4133,12 @@ def minute(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` minutes part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -4244,8 +4149,7 @@ def minute(col: "ColumnOrName") -> Column: def second(col: "ColumnOrName") -> Column: - """ - Extract the seconds of a given date as integer. + """Extract the seconds of a given date as integer. .. versionadded:: 1.5.0 @@ -4257,12 +4161,12 @@ def second(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. 
- Returns + Returns: ------- :class:`~pyspark.sql.Column` `seconds` part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -4273,8 +4177,7 @@ def second(col: "ColumnOrName") -> Column: def weekofyear(col: "ColumnOrName") -> Column: - """ - Extract the week number of a given date as integer. + """Extract the week number of a given date as integer. A week is considered to start on a Monday and week 1 is the first week with more than 3 days, as defined by ISO 8601 @@ -4288,12 +4191,12 @@ def weekofyear(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` `week` of the year for given date as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(weekofyear(df.dt).alias("week")).collect() @@ -4303,8 +4206,7 @@ def weekofyear(col: "ColumnOrName") -> Column: def cos(col: "ColumnOrName") -> Column: - """ - Computes cosine of the input column. + """Computes cosine of the input column. .. versionadded:: 1.4.0 @@ -4316,12 +4218,12 @@ def cos(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` cosine of the angle, as if computed by `java.lang.Math.cos()`. - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -4332,8 +4234,7 @@ def cos(col: "ColumnOrName") -> Column: def acos(col: "ColumnOrName") -> Column: - """ - Computes inverse cosine of the input column. + """Computes inverse cosine of the input column. .. versionadded:: 1.4.0 @@ -4345,12 +4246,12 @@ def acos(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse cosine of `col`, as if computed by `java.lang.Math.acos()` - Examples + Examples: -------- >>> df = spark.range(1, 3) >>> df.select(acos(df.id)).show() @@ -4365,8 +4266,7 @@ def acos(col: "ColumnOrName") -> Column: def call_function(funcName: str, *cols: "ColumnOrName") -> Column: - """ - Call a SQL function. + """Call a SQL function. .. versionadded:: 3.5.0 @@ -4377,12 +4277,12 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or :class:`~pyspark.sql.Column`\\s to be used in the function - Returns + Returns: ------- :class:`~pyspark.sql.Column` result of executed function. - Examples + Examples: -------- >>> from pyspark.sql.functions import call_udf, col >>> from pyspark.sql.types import IntegerType, StringType @@ -4447,12 +4347,12 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate covariance. - Returns + Returns: ------- :class:`~pyspark.sql.Column` covariance of these two column values. - Examples + Examples: -------- >>> a = [1] * 10 >>> b = [1] * 10 @@ -4479,12 +4379,12 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate covariance. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sample covariance of these two column values. 
-    Examples
+    Examples:
     --------
     >>> a = [1] * 10
     >>> b = [1] * 10
@@ -4496,8 +4396,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
 
 
 def exp(col: "ColumnOrName") -> Column:
-    """
-    Computes the exponential of the given value.
+    """Computes the exponential of the given value.
 
     .. versionadded:: 1.4.0
 
@@ -4509,12 +4408,12 @@ def exp(col: "ColumnOrName") -> Column:
     col : :class:`~pyspark.sql.Column` or str
         column to calculate exponential for.
 
-    Returns
+    Returns:
     -------
     :class:`~pyspark.sql.Column`
        exponential of the given value.
 
-    Examples
+    Examples:
     --------
     >>> df = spark.range(1)
     >>> df.select(exp(lit(0))).show()
@@ -4528,8 +4427,7 @@ def exp(col: "ColumnOrName") -> Column:
 
 
 def factorial(col: "ColumnOrName") -> Column:
-    """
-    Computes the factorial of the given value.
+    """Computes the factorial of the given value.
 
     .. versionadded:: 1.5.0
 
@@ -4541,12 +4439,12 @@ def factorial(col: "ColumnOrName") -> Column:
     col : :class:`~pyspark.sql.Column` or str
         a column to calculate factorial for.
 
-    Returns
+    Returns:
     -------
     :class:`~pyspark.sql.Column`
         factorial of given value.
 
-    Examples
+    Examples:
     --------
     >>> df = spark.createDataFrame([(5,)], ["n"])
     >>> df.select(factorial(df.n).alias("f")).collect()
@@ -4568,12 +4466,12 @@ def log2(col: "ColumnOrName") -> Column:
     col : :class:`~pyspark.sql.Column` or str
         a column to calculate logarithm for.
 
-    Returns
+    Returns:
     -------
     :class:`~pyspark.sql.Column`
         logarithm of given value.
 
-    Examples
+    Examples:
     --------
     >>> df = spark.createDataFrame([(4,)], ["a"])
     >>> df.select(log2("a").alias("log2")).show()
@@ -4596,12 +4494,12 @@ def ln(col: "ColumnOrName") -> Column:
     col : :class:`~pyspark.sql.Column` or str
         a column to calculate logarithm for.
 
-    Returns
+    Returns:
     -------
     :class:`~pyspark.sql.Column`
         natural logarithm of given value.
 
-    Examples
+    Examples:
     --------
     >>> df = spark.createDataFrame([(4,)], ["a"])
     >>> df.select(ln("a")).show()
@@ -4615,8 +4513,7 @@ def ln(col: "ColumnOrName") -> Column:
 
 
 def degrees(col: "ColumnOrName") -> Column:
-    """
-    Converts an angle measured in radians to an approximately equivalent angle
+    """Converts an angle measured in radians to an approximately equivalent angle
     measured in degrees.
 
     .. versionadded:: 2.1.0
@@ -4629,12 +4526,12 @@ def degrees(col: "ColumnOrName") -> Column:
     col : :class:`~pyspark.sql.Column` or str
         angle in radians
 
-    Returns
+    Returns:
     -------
     :class:`~pyspark.sql.Column`
         angle in degrees, as if computed by `java.lang.Math.toDegrees()`
 
-    Examples
+    Examples:
     --------
     >>> import math
     >>> df = spark.range(1)
@@ -4645,8 +4542,7 @@ def degrees(col: "ColumnOrName") -> Column:
 
 
 def radians(col: "ColumnOrName") -> Column:
-    """
-    Converts an angle measured in degrees to an approximately equivalent angle
+    """Converts an angle measured in degrees to an approximately equivalent angle
     measured in radians.
 
     .. versionadded:: 2.1.0
@@ -4659,12 +4555,12 @@ def radians(col: "ColumnOrName") -> Column:
     col : :class:`~pyspark.sql.Column` or str
         angle in degrees
 
-    Returns
+    Returns:
     -------
     :class:`~pyspark.sql.Column`
         angle in radians, as if computed by `java.lang.Math.toRadians()`
 
-    Examples
+    Examples:
     --------
     >>> df = spark.range(1)
     >>> df.select(radians(lit(180))).first()
@@ -4674,8 +4570,7 @@ def radians(col: "ColumnOrName") -> Column:
 
 
 def atan(col: "ColumnOrName") -> Column:
-    """
-    Compute inverse tangent of the input column.
+    """Compute inverse tangent of the input column.
 
     .. 
versionadded:: 1.4.0 @@ -4687,12 +4582,12 @@ def atan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse tangent of `col`, as if computed by `java.lang.Math.atan()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(atan(df.id)).show() @@ -4706,8 +4601,7 @@ def atan(col: "ColumnOrName") -> Column: def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """ - .. versionadded:: 1.4.0 + """.. versionadded:: 1.4.0 .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -4719,7 +4613,7 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] col2 : str, :class:`~pyspark.sql.Column` or float coordinate on x-axis - Returns + Returns: ------- :class:`~pyspark.sql.Column` the `theta` component of the point @@ -4728,7 +4622,7 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] (`x`, `y`) in Cartesian coordinates, as if computed by `java.lang.Math.atan2()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(atan2(lit(1), lit(2))).first() @@ -4744,8 +4638,7 @@ def lit_or_column(x: Union["ColumnOrName", float]) -> Column: def tan(col: "ColumnOrName") -> Column: - """ - Computes tangent of the input column. + """Computes tangent of the input column. .. versionadded:: 1.4.0 @@ -4757,12 +4650,12 @@ def tan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` tangent of the given value, as if computed by `java.lang.Math.tan()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -4773,8 +4666,7 @@ def tan(col: "ColumnOrName") -> Column: def round(col: "ColumnOrName", scale: int = 0) -> Column: - """ - Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 + """Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 or at integral part when `scale` < 0. .. versionadded:: 1.5.0 @@ -4789,12 +4681,12 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: scale : int optional default 0 scale value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` rounded values. - Examples + Examples: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(round("a", 0).alias("r")).collect() [Row(r=3.0)] @@ -4803,8 +4695,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: def bround(col: "ColumnOrName", scale: int = 0) -> Column: - """ - Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 + """Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 or at integral part when `scale` < 0. .. versionadded:: 2.0.0 @@ -4819,12 +4710,12 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: scale : int optional default 0 scale value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` rounded values. - Examples + Examples: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(bround("a", 0).alias("r")).collect() [Row(r=2.0)] @@ -4833,8 +4724,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: - """ - Collection function: Returns element of array at given (0-based) index. + """Collection function: Returns element of array at given (0-based) index. 
If the index points outside of the array boundaries, then this function returns NULL. @@ -4847,21 +4737,21 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: index : :class:`~pyspark.sql.Column` or str or int index to check for in array - Returns + Returns: ------- :class:`~pyspark.sql.Column` value at given position. - Notes + Notes: ----- The position is not 1 based, but 0 based index. Supports Spark Connect. - See Also + See Also: -------- :meth:`element_at` - Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) >>> df.select(get(df.data, 1)).show() @@ -4919,12 +4809,12 @@ def initcap(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string with all first letters are uppercase in each word. - Examples + Examples: -------- >>> spark.createDataFrame([("ab cd",)], ["a"]).select(initcap("a").alias("v")).collect() [Row(v='Ab Cd')] @@ -4953,8 +4843,7 @@ def initcap(col: "ColumnOrName") -> Column: def octet_length(col: "ColumnOrName") -> Column: - """ - Calculates the byte length for the specified string column. + """Calculates the byte length for the specified string column. .. versionadded:: 3.3.0 @@ -4966,12 +4855,12 @@ def octet_length(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str Source column or strings - Returns + Returns: ------- :class:`~pyspark.sql.Column` Byte length of the col - Examples + Examples: -------- >>> from pyspark.sql.functions import octet_length >>> spark.createDataFrame([('cat',), ( '\U0001f408',)], ['cat']) \\ @@ -4982,8 +4871,7 @@ def octet_length(col: "ColumnOrName") -> Column: def hex(col: "ColumnOrName") -> Column: - """ - Computes hex value of the given column, which could be :class:`~pyspark.sql.types.StringType`, :class:`~pyspark.sql.types.BinaryType`, :class:`~pyspark.sql.types.IntegerType` or :class:`~pyspark.sql.types.LongType`. + """Computes hex value of the given column, which could be :class:`~pyspark.sql.types.StringType`, :class:`~pyspark.sql.types.BinaryType`, :class:`~pyspark.sql.types.IntegerType` or :class:`~pyspark.sql.types.LongType`. .. versionadded:: 1.5.0 @@ -4995,12 +4883,12 @@ def hex(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` hexadecimal representation of given value as string. - Examples + Examples: -------- >>> spark.createDataFrame([("ABC", 3)], ["a", "b"]).select(hex("a"), hex("b")).collect() [Row(hex(a)='414243', hex(b)='3')] @@ -5009,8 +4897,7 @@ def hex(col: "ColumnOrName") -> Column: def unhex(col: "ColumnOrName") -> Column: - """ - Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. column and returns it as a binary column. + """Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. column and returns it as a binary column. .. versionadded:: 1.5.0 @@ -5022,12 +4909,12 @@ def unhex(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string representation of given hexadecimal value. 
- Examples + Examples: -------- >>> spark.createDataFrame([("414243",)], ["a"]).select(unhex("a")).collect() [Row(unhex(a)=bytearray(b'ABC'))] @@ -5036,8 +4923,7 @@ def unhex(col: "ColumnOrName") -> Column: def base64(col: "ColumnOrName") -> Column: - """ - Computes the BASE64 encoding of a binary column and returns it as a string column. + """Computes the BASE64 encoding of a binary column and returns it as a string column. .. versionadded:: 1.5.0 @@ -5049,12 +4935,12 @@ def base64(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` BASE64 encoding of string value. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(base64("value")).show() @@ -5072,8 +4958,7 @@ def base64(col: "ColumnOrName") -> Column: def unbase64(col: "ColumnOrName") -> Column: - """ - Decodes a BASE64 encoded string column and returns it as a binary column. + """Decodes a BASE64 encoded string column and returns it as a binary column. .. versionadded:: 1.5.0 @@ -5085,12 +4970,12 @@ def unbase64(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` encoded string value. - Examples + Examples: -------- >>> df = spark.createDataFrame(["U3Bhcms=", "UHlTcGFyaw==", "UGFuZGFzIEFQSQ=="], "STRING") >>> df.select(unbase64("value")).show() @@ -5106,8 +4991,7 @@ def unbase64(col: "ColumnOrName") -> Column: def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column: - """ - Returns the date that is `months` months after `start`. If `months` is a negative value + """Returns the date that is `months` months after `start`. If `months` is a negative value then these amount of months will be deducted from the `start`. .. versionadded:: 1.5.0 @@ -5123,12 +5007,12 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col how many months after the given date to calculate. Accepts negative value as well to calculate backwards. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a date after/before given number of months. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08", 2)], ["dt", "add"]) >>> df.select(add_months(df.dt, 1).alias("next_month")).collect() @@ -5143,8 +5027,7 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None) -> Column: - """ - Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + """Concatenates the elements of `column` using the `delimiter`. Null values are replaced with `null_replacement` if set, otherwise they are ignored. .. versionadded:: 2.4.0 @@ -5161,12 +5044,12 @@ def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[s null_replacement : str, optional if set then null values will be replaced by this value - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of string type. Concatenated values. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) >>> df.select(array_join(df.data, ",").alias("joined")).collect() @@ -5190,8 +5073,7 @@ def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[s def array_position(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: Locates the position of the first occurrence of the given value + """Collection function: Locates the position of the first occurrence of the given value in the given array. Returns null if either of the arguments are null. .. versionadded:: 2.4.0 @@ -5199,7 +5081,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The position is not zero based, but 1 based index. Returns 0 if the given value could not be found in the array. @@ -5211,12 +5093,12 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: value : Any value to look for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` position of the value in the given array if found and 0 otherwise. - Examples + Examples: -------- >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) >>> df.select(array_position(df.data, "a")).collect() @@ -5230,8 +5112,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: def array_prepend(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: Returns an array containing element as + """Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. @@ -5244,12 +5125,12 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array excluding given value. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) >>> df.select(array_prepend(df.data, 1)).collect() @@ -5259,8 +5140,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: - """ - Collection function: creates an array containing a column repeated count times. + """Collection function: creates an array containing a column repeated count times. .. versionadded:: 2.4.0 @@ -5274,12 +5154,12 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu count : :class:`~pyspark.sql.Column` or str or int column name, column, or int containing the number of times to repeat the first argument - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of repeated elements. - Examples + Examples: -------- >>> df = spark.createDataFrame([("ab",)], ["data"]) >>> df.select(array_repeat(df.data, 3).alias("r")).collect() @@ -5291,8 +5171,7 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu def array_size(col: "ColumnOrName") -> Column: - """ - Returns the total number of elements in the array. The function returns null for null input. + """Returns the total number of elements in the array. The function returns null for null input. .. versionadded:: 3.5.0 @@ -5301,12 +5180,12 @@ def array_size(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` total number of elements in the array. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) >>> df.select(array_size(df.data).alias("r")).collect() @@ -5316,8 +5195,7 @@ def array_size(col: "ColumnOrName") -> Column: def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None) -> Column: - """ - Collection function: sorts the input array in ascending order. The elements of the input array + """Collection function: sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. .. versionadded:: 2.4.0 @@ -5339,12 +5217,12 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum positive integer as the first element is less than, equal to, or greater than the second element. If the comparator function returns null, the function will fail and raise an error. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sorted array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) >>> df.select(array_sort(df.data).alias("r")).collect() @@ -5369,8 +5247,7 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: - """ - Collection function: sorts the input array in ascending or descending order according + """Collection function: sorts the input array in ascending or descending order according to the natural ordering of the array elements. Null elements will be placed at the beginning of the returned array in ascending order or at the end of the returned array in descending order. @@ -5388,12 +5265,12 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: whether to sort in ascending or descending order. If `asc` is True (default) then ascending and if False then descending. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sorted array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) >>> df.select(sort_array(df.data).alias("r")).collect() @@ -5411,8 +5288,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: - """ - Splits str around matches of the given pattern. + """Splits str around matches of the given pattern. .. versionadded:: 1.5.0 @@ -5438,12 +5314,12 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: .. versionchanged:: 3.0 `split` now takes an optional `limit` field. If not provided, default limit value is -1. - Returns + Returns: ------- :class:`~pyspark.sql.Column` array of separated strings. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("oneAtwoBthreeC",)], @@ -5464,8 +5340,7 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnOrName") -> Column: - """ - Splits `str` by delimiter and return requested part of the split (1-based). + """Splits `str` by delimiter and return requested part of the split (1-based). If any input is null, returns null. if `partNum` is out of range of split parts, returns empty string. If `partNum` is 0, throws an error. If `partNum` is negative, the parts are counted backward from the end of the string. 
@@ -5482,7 +5357,7 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO partNum : :class:`~pyspark.sql.Column` or str A column of string, requested part of the split (1-based). - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -5510,8 +5385,7 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO def stddev_samp(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the unbiased sample standard deviation of + """Aggregate function: returns the unbiased sample standard deviation of the expression in a group. .. versionadded:: 1.6.0 @@ -5524,12 +5398,12 @@ def stddev_samp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev_samp("id")).show() @@ -5543,8 +5417,7 @@ def stddev_samp(col: "ColumnOrName") -> Column: def stddev(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for stddev_samp. + """Aggregate function: alias for stddev_samp. .. versionadded:: 1.6.0 @@ -5556,12 +5429,12 @@ def stddev(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev("id")).show() @@ -5575,8 +5448,7 @@ def stddev(col: "ColumnOrName") -> Column: def std(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for stddev_samp. + """Aggregate function: alias for stddev_samp. .. versionadded:: 3.5.0 @@ -5585,12 +5457,12 @@ def std(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.std("id")).show() @@ -5604,8 +5476,7 @@ def std(col: "ColumnOrName") -> Column: def stddev_pop(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns population standard deviation of + """Aggregate function: returns population standard deviation of the expression in a group. .. versionadded:: 1.6.0 @@ -5618,12 +5489,12 @@ def stddev_pop(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev_pop("id")).show() @@ -5637,8 +5508,7 @@ def stddev_pop(col: "ColumnOrName") -> Column: def var_pop(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the population variance of the values in a group. + """Aggregate function: returns the population variance of the values in a group. .. versionadded:: 1.6.0 @@ -5650,12 +5520,12 @@ def var_pop(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. 
- Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(var_pop(df.id)).first() @@ -5665,8 +5535,7 @@ def var_pop(col: "ColumnOrName") -> Column: def var_samp(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the unbiased sample variance of + """Aggregate function: returns the unbiased sample variance of the values in a group. .. versionadded:: 1.6.0 @@ -5679,12 +5548,12 @@ def var_samp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(var_samp(df.id)).show() @@ -5698,8 +5567,7 @@ def var_samp(col: "ColumnOrName") -> Column: def variance(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for var_samp + """Aggregate function: alias for var_samp .. versionadded:: 1.6.0 @@ -5711,12 +5579,12 @@ def variance(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(variance(df.id)).show() @@ -5730,8 +5598,7 @@ def variance(col: "ColumnOrName") -> Column: def weekday(col: "ColumnOrName") -> Column: - """ - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). + """Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). .. versionadded:: 3.5.0 @@ -5740,12 +5607,12 @@ def weekday(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(weekday("dt").alias("day")).show() @@ -5759,8 +5626,7 @@ def weekday(col: "ColumnOrName") -> Column: def zeroifnull(col: "ColumnOrName") -> Column: - """ - Returns zero if `col` is null, or `col` otherwise. + """Returns zero if `col` is null, or `col` otherwise. .. versionadded:: 4.0.0 @@ -5768,7 +5634,7 @@ def zeroifnull(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([(None,), (1,)], ["a"]) >>> df.select(zeroifnull(df.a).alias("result")).show() @@ -5811,12 +5677,12 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: format: str, optional format to use to convert date values. - Returns + Returns: ------- :class:`~pyspark.sql.Column` date value as :class:`pyspark.sql.types.DateType` type. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_date(df.t).alias("date")).collect() @@ -5849,12 +5715,12 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: format: str, optional format to use to convert timestamp values. - Returns + Returns: ------- :class:`~pyspark.sql.Column` timestamp value as :class:`pyspark.sql.types.TimestampType` type. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_timestamp(df.t).alias("dt")).collect() @@ -5871,8 +5737,7 @@ def to_timestamp_ltz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, ) -> Column: - """ - Parses the `timestamp` with the `format` to a timestamp without time zone. + """Parses the `timestamp` with the `format` to a timestamp without time zone. Returns null with invalid input. .. versionadded:: 3.5.0 @@ -5884,7 +5749,7 @@ def to_timestamp_ltz( format : :class:`~pyspark.sql.Column` or str, optional format to use to convert type `TimestampType` timestamp values. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() @@ -5903,8 +5768,7 @@ def to_timestamp_ntz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, ) -> Column: - """ - Parses the `timestamp` with the `format` to a timestamp without time zone. + """Parses the `timestamp` with the `format` to a timestamp without time zone. Returns null with invalid input. .. versionadded:: 3.5.0 @@ -5916,7 +5780,7 @@ def to_timestamp_ntz( format : :class:`~pyspark.sql.Column` or str, optional format to use to convert type `TimestampNTZType` timestamp values. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() @@ -5932,8 +5796,7 @@ def to_timestamp_ntz( def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column: - """ - Parses the `col` with the `format` to a timestamp. The function always + """Parses the `col` with the `format` to a timestamp. The function always returns null on an invalid input with/without ANSI SQL mode enabled. The result data type is consistent with the value of configuration `spark.sql.timestampType`. .. versionadded:: 3.5.0 @@ -5943,7 +5806,8 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non column values to convert. format: str, optional format to use to convert timestamp values. - Examples + + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(try_to_timestamp(df.t).alias("dt")).collect() @@ -5958,8 +5822,7 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None) -> Column: - """ - Returns the substring of `str` that starts at `pos` and is of length `len`, + """Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of byte array that starts at `pos` and is of length `len`. .. versionadded:: 3.5.0 @@ -5973,7 +5836,7 @@ def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName len : :class:`~pyspark.sql.Column` or str, optional A column of string, the substring of `str` is of length `len`. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( @@ -6026,7 +5889,7 @@ def unix_date(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("1970-01-02",)], ["t"]) @@ -6042,7 +5905,7 @@ def unix_micros(col: "ColumnOrName") -> Column: .. 
versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) @@ -6059,7 +5922,7 @@ def unix_millis(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) @@ -6076,7 +5939,7 @@ def unix_seconds(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) @@ -6088,8 +5951,7 @@ def unix_seconds(col: "ColumnOrName") -> Column: def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: - """ - Collection function: returns true if the arrays contain any common non-null element; if not, + """Collection function: returns true if the arrays contain any common non-null element; if not, returns null if both the arrays are non-empty and any of them contains a null element; returns false otherwise. @@ -6098,12 +5960,12 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of Boolean type. - Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ["x", "y"]) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() @@ -6134,8 +5996,7 @@ def _list_contains_null(c: ColumnExpression) -> Expression: def arrays_zip(*cols: "ColumnOrName") -> Column: - """ - Collection function: Returns a merged array of structs in which the N-th struct contains all + """Collection function: Returns a merged array of structs in which the N-th struct contains all N-th values of input arrays. If one of the arrays is shorter than others then resulting struct type value will be a `null` for missing elements. @@ -6149,12 +6010,12 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str columns of arrays to be merged. - Returns + Returns: ------- :class:`~pyspark.sql.Column` merged array of entries. - Examples + Examples: -------- >>> from pyspark.sql.functions import arrays_zip >>> df = spark.createDataFrame( @@ -6179,14 +6040,14 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: def substring(str: "ColumnOrName", pos: int, len: int) -> Column: - """ - Substring starts at `pos` and is of length `len` when str is String type or + """Substring starts at `pos` and is of length `len` when str is String type or returns the slice of byte array that starts at `pos` in byte and is of length `len` when str is Binary type. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + + Notes: ----- The position is not zero based, but 1 based index. Parameters @@ -6197,11 +6058,13 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: starting position in str. len : int length of chars. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` substring of given value. - Examples + + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -6221,8 +6084,7 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if right is found inside left. + """Returns a boolean. 
The value is True if right is found inside left. Returns NULL if either input expression is NULL. Otherwise, returns False. Both left or right must be of STRING or BINARY type. .. versionadded:: 3.5.0 @@ -6232,7 +6094,8 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: The input column or strings to check, may be NULL. right : :class:`~pyspark.sql.Column` or str The input column or strings to find, may be NULL. - Examples + + Examples: -------- >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ["a", "b"]) >>> df.select(contains(df.a, df.b).alias("r")).collect() @@ -6262,8 +6125,7 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: def reverse(col: "ColumnOrName") -> Column: - """ - Collection function: returns a reversed string or an array with reverse order of elements. + """Collection function: returns a reversed string or an array with reverse order of elements. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -6271,11 +6133,13 @@ def reverse(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + + Returns: ------- :class:`~pyspark.sql.Column` array of elements in reverse order. - Examples + + Examples: -------- >>> df = spark.createDataFrame([("Spark SQL",)], ["data"]) >>> df.select(reverse(df.data).alias("s")).collect() @@ -6288,8 +6152,7 @@ def reverse(col: "ColumnOrName") -> Column: def concat(*cols: "ColumnOrName") -> Column: - """ - Concatenates multiple input columns together into a single column. + """Concatenates multiple input columns together into a single column. The function works with strings, numeric, binary and compatible array columns. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 @@ -6298,14 +6161,17 @@ def concat(*cols: "ColumnOrName") -> Column: ---------- cols : :class:`~pyspark.sql.Column` or str target column or columns to work on. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` concatenated values. Type of the `Column` depends on input columns' type. - See Also + + See Also: -------- :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter - Examples + + Examples: -------- >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) >>> df = df.select(concat(df.s, df.d).alias("s")) @@ -6326,8 +6192,7 @@ def concat(*cols: "ColumnOrName") -> Column: def instr(str: "ColumnOrName", substr: str) -> Column: - """ - Locate the position of the first occurrence of substr column in the given string. + """Locate the position of the first occurrence of substr column in the given string. Returns null if either of the arguments are null. .. versionadded:: 1.5.0 @@ -6335,7 +6200,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The position is not zero based, but 1 based index. Returns 0 if substr could not be found in str. @@ -6347,12 +6212,12 @@ def instr(str: "ColumnOrName", substr: str) -> Column: substr : str substring to look for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` location of the first occurrence of the substring as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -6379,12 +6244,12 @@ def expr(str: str) -> Column: str : str expression defined in string. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column representing the expression. 
- Examples + Examples: -------- >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) >>> df.select("name", expr("length(name)")).show() @@ -6399,8 +6264,7 @@ def expr(str: str) -> Column: def broadcast(df: "DataFrame") -> "DataFrame": - """ - The broadcast function in Spark is used to optimize joins by broadcasting a smaller + """The broadcast function in Spark is used to optimize joins by broadcasting a smaller dataset to all the worker nodes. However, DuckDB operates on a single-node architecture . As a result, the function simply returns the input DataFrame without applying any modifications or optimizations, since broadcasting is not applicable in the DuckDB context. diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index 29210e29..7aa9eb11 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -15,19 +15,16 @@ # limitations under the License. # -from ..exception import ContributionsAcceptedError -from typing import Callable, TYPE_CHECKING, overload, Dict, Union, List +from typing import Callable, Union, overload +from ..exception import ContributionsAcceptedError +from ._typing import ColumnOrName from .column import Column -from .session import SparkSession from .dataframe import DataFrame from .functions import _to_column_expr -from ._typing import ColumnOrName +from .session import SparkSession from .types import NumericType -if TYPE_CHECKING: - from ._typing import LiteralType - __all__ = ["GroupedData", "Grouping"] @@ -35,7 +32,7 @@ def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: expressions = ",".join(list(cols)) group_by = str(self._grouping) if self._grouping else "" projections = self._grouping.get_columns() - jdf = getattr(self._df.relation, "apply")( + jdf = self._df.relation.apply( function_name=name, # aggregate function function_aggr=expressions, # inputs to aggregate group_expr=group_by, # groups @@ -76,8 +73,7 @@ def __str__(self) -> str: class GroupedData: - """ - A set of methods for aggregations on a :class:`DataFrame`, + """A set of methods for aggregations on a :class:`DataFrame`, created by :func:`DataFrame.groupBy`. """ @@ -93,7 +89,7 @@ def __repr__(self) -> str: def count(self) -> DataFrame: """Counts the number of records for each group. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] @@ -142,7 +138,7 @@ def avg(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -188,7 +184,7 @@ def avg(self, *cols: str) -> DataFrame: def max(self, *cols: str) -> DataFrame: """Computes the max value for each numeric columns for each group. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -233,7 +229,7 @@ def min(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -278,7 +274,7 @@ def sum(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... 
[(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -352,12 +348,12 @@ def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. - Notes + Notes: ----- Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed in a single call to this function. - Examples + Examples: -------- >>> from pyspark.sql import functions as F >>> from pyspark.sql.functions import pandas_udf, PandasUDFType diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 18095ab6..607e9d36 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -1,11 +1,9 @@ -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast +from ..errors import PySparkNotImplementedError, PySparkTypeError from ..exception import ContributionsAcceptedError from .types import StructType - -from ..errors import PySparkNotImplementedError, PySparkTypeError - PrimitiveType = Union[bool, float, int, str] OptionalPrimitiveType = Optional[PrimitiveType] @@ -123,7 +121,7 @@ def load( if schema: if not isinstance(schema, StructType): raise ContributionsAcceptedError - schema = cast(StructType, schema) + schema = cast("StructType", schema) types, names = schema.extract_types_and_names() df = df._cast_types(types) df = df.toDF(names) @@ -225,7 +223,7 @@ def csv( dtype = None names = None if schema: - schema = cast(StructType, schema) + schema = cast("StructType", schema) dtype, names = schema.extract_types_and_names() rel = self.session.conn.read_csv( @@ -289,8 +287,7 @@ def json( modifiedAfter: Optional[Union[bool, str]] = None, allowNonNumericNumbers: Optional[Union[bool, str]] = None, ) -> "DataFrame": - """ - Loads JSON files and returns the results as a :class:`DataFrame`. + """Loads JSON files and returns the results as a :class:`DataFrame`. `JSON Lines `_ (newline-delimited JSON) is supported by default. For JSON (one record per file), set the ``multiLine`` parameter to ``true``. @@ -321,7 +318,7 @@ def json( .. # noqa - Examples + Examples: -------- Write a DataFrame into a JSON file and read it back. 
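(Aside, not part of the patch: on the cast("StructType", schema) edits above, typing.cast also accepts the target type spelled as a string; it performs no runtime conversion or check either way, so the change is behaviour-preserving while letting the quoted name refer to a type that may only be imported for type checking. A minimal sketch with hypothetical names:)

from typing import cast

def as_int(value: object) -> int:
    # cast() only informs the type checker; at runtime it returns ``value`` unchanged.
    return cast("int", value)

assert as_int(5) == 5
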
@@ -340,7 +337,6 @@ def json( |100|Hyukjin Kwon| +---+------------+ """ - if schema is not None: raise ContributionsAcceptedError("The 'schema' option is not supported") if primitivesAsString is not None: @@ -410,4 +406,4 @@ def json( ) -__all__ = ["DataFrameWriter", "DataFrameReader"] +__all__ = ["DataFrameReader", "DataFrameWriter"] diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index c83c7e82..4b919446 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -1,24 +1,24 @@ -from typing import Optional, List, Any, Union, Iterable, TYPE_CHECKING import uuid +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: - from .catalog import Catalog from pandas.core.frame import DataFrame as PandasDataFrame -from ..exception import ContributionsAcceptedError -from .types import StructType, AtomicType, DataType + from .catalog import Catalog + + from ..conf import SparkConf -from .dataframe import DataFrame +from ..context import SparkContext +from ..errors import PySparkTypeError +from ..errors.error_classes import * +from ..exception import ContributionsAcceptedError from .conf import RuntimeConfig +from .dataframe import DataFrame from .readwriter import DataFrameReader -from ..context import SparkContext -from .udf import UDFRegistration from .streaming import DataStreamReader -import duckdb - -from ..errors import PySparkTypeError, PySparkValueError - -from ..errors.error_classes import * +from .types import StructType +from .udf import UDFRegistration # In spark: # SparkSession holds a SparkContext diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 4dcba01f..ba54db60 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, Optional, Union + from .types import StructType if TYPE_CHECKING: @@ -29,7 +30,6 @@ def load( schema: Union[StructType, str, None] = None, **options: OptionalPrimitiveType, ) -> "DataFrame": - from duckdb.experimental.spark.sql.dataframe import DataFrame raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index f8c8ce4f..446eac97 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -1,38 +1,40 @@ +from typing import cast + from duckdb.typing import DuckDBPyType -from typing import List, Tuple, cast + from .types import ( - DataType, - StringType, + ArrayType, BinaryType, BitstringType, - UUIDType, BooleanType, + ByteType, + DataType, DateType, - TimestampType, - TimestampNTZType, - TimeType, - TimeNTZType, - TimestampNanosecondNTZType, - TimestampMilisecondNTZType, - TimestampSecondNTZType, + DayTimeIntervalType, DecimalType, DoubleType, FloatType, - ByteType, - UnsignedByteType, - ShortType, - UnsignedShortType, + HugeIntegerType, IntegerType, - UnsignedIntegerType, LongType, - UnsignedLongType, - HugeIntegerType, - UnsignedHugeIntegerType, - DayTimeIntervalType, - ArrayType, MapType, + ShortType, + StringType, StructField, StructType, + TimeNTZType, + TimestampMilisecondNTZType, + TimestampNanosecondNTZType, + TimestampNTZType, + TimestampSecondNTZType, + TimestampType, + TimeType, + UnsignedByteType, + UnsignedHugeIntegerType, + UnsignedIntegerType, + UnsignedLongType, + UnsignedShortType, + UUIDType, ) _sqltype_to_spark_class = { @@ 
-93,8 +95,8 @@ def convert_type(dtype: DuckDBPyType) -> DataType: return convert_nested_type(dtype) if id == "decimal": children: list[tuple[str, DuckDBPyType]] = dtype.children - precision = cast(int, children[0][1]) - scale = cast(int, children[1][1]) + precision = cast("int", children[0][1]) + scale = cast("int", children[1][1]) return DecimalType(precision, scale) spark_type = _sqltype_to_spark_class[id] return spark_type() diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 81293caf..d8a04b8e 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -1,25 +1,21 @@ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'spark' folder. +import calendar +import datetime +import math +import re +import time +from builtins import tuple +from collections.abc import Iterator from typing import ( - cast, - overload, - Dict, - Optional, - List, - Tuple, Any, - Union, - Type, - TypeVar, ClassVar, - Iterator, + Optional, + TypeVar, + Union, + cast, + overload, ) -from builtins import tuple -import datetime -import calendar -import time -import math -import re import duckdb from duckdb.typing import DuckDBPyType @@ -30,40 +26,40 @@ U = TypeVar("U") __all__ = [ - "DataType", - "NullType", - "StringType", + "ArrayType", "BinaryType", - "UUIDType", "BitstringType", "BooleanType", + "ByteType", + "DataType", "DateType", - "TimestampType", - "TimestampNTZType", - "TimestampNanosecondNTZType", - "TimestampMilisecondNTZType", - "TimestampSecondNTZType", - "TimeType", - "TimeNTZType", + "DayTimeIntervalType", "DecimalType", "DoubleType", "FloatType", - "ByteType", - "UnsignedByteType", - "ShortType", - "UnsignedShortType", + "HugeIntegerType", "IntegerType", - "UnsignedIntegerType", "LongType", - "UnsignedLongType", - "HugeIntegerType", - "UnsignedHugeIntegerType", - "DayTimeIntervalType", - "Row", - "ArrayType", "MapType", + "NullType", + "Row", + "ShortType", + "StringType", "StructField", "StructType", + "TimeNTZType", + "TimeType", + "TimestampMilisecondNTZType", + "TimestampNTZType", + "TimestampNanosecondNTZType", + "TimestampSecondNTZType", + "TimestampType", + "UUIDType", + "UnsignedByteType", + "UnsignedHugeIntegerType", + "UnsignedIntegerType", + "UnsignedLongType", + "UnsignedShortType", ] @@ -79,10 +75,10 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @classmethod @@ -99,22 +95,19 @@ def json(self) -> str: raise ContributionsAcceptedError def needConversion(self) -> bool: - """ - Does this type needs conversion between Python object and internal SQL object. + """Does this type needs conversion between Python object and internal SQL object. This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. """ return False def toInternal(self, obj: Any) -> Any: - """ - Converts a Python object into an internal SQL object. + """Converts a Python object into an internal SQL object. """ return obj def fromInternal(self, obj: Any) -> Any: - """ - Converts an internal SQL object into a native Python object. + """Converts an internal SQL object into a native Python object. 
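(Aside, not part of the patch: the __eq__/__ne__ signatures above move from ``Any`` to ``object``, which is the parameter type used by ``object.__eq__`` itself; runtime behaviour is unchanged as long as the body narrows the argument before using it. A hypothetical sketch, not DuckDB code:)

class Label:
    """Toy value object used only to illustrate the annotation."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, other: object) -> bool:
        # isinstance() narrows ``other`` before its attributes are touched.
        return isinstance(other, Label) and self.name == other.name

    def __ne__(self, other: object) -> bool:
        return not self.__eq__(other)

assert Label("a") == Label("a")
assert Label("a") != Label("b")
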
""" return obj @@ -148,7 +141,8 @@ def typeName(cls) -> str: class AtomicType(DataType): """An internal type used to represent everything that is not - null, UDTs, arrays, structs, and maps.""" + null, UDTs, arrays, structs, and maps. + """ class NumericType(AtomicType): @@ -538,8 +532,8 @@ def __init__(self, startField: Optional[int] = None, endField: Optional[int] = N fields = DayTimeIntervalType._fields if startField not in fields.keys() or endField not in fields.keys(): raise RuntimeError("interval %s to %s is invalid" % (startField, endField)) - self.startField = cast(int, startField) - self.endField = cast(int, endField) + self.startField = cast("int", startField) + self.endField = cast("int", endField) def _str_repr(self) -> str: fields = DayTimeIntervalType._fields @@ -577,7 +571,7 @@ class ArrayType(DataType): containsNull : bool, optional whether the array can contain null (None) values. - Examples + Examples: -------- >>> ArrayType(StringType()) == ArrayType(StringType(), True) True @@ -626,11 +620,11 @@ class MapType(DataType): valueContainsNull : bool, optional indicates whether values can contain null (None) values. - Notes + Notes: ----- Keys in a map data type are not allowed to be null (None). - Examples + Examples: -------- >>> (MapType(StringType(), IntegerType()) == MapType(StringType(), IntegerType(), True)) True @@ -693,7 +687,7 @@ class StructField(DataType): metadata : dict, optional a dict from string to simple type that can be toInternald to JSON automatically - Examples + Examples: -------- >>> (StructField("f1", StringType(), True) == StructField("f1", StringType(), True)) True @@ -750,7 +744,7 @@ class StructType(DataType): Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by its name or position. - Examples + Examples: -------- >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] @@ -805,8 +799,7 @@ def add( nullable: bool = True, metadata: Optional[dict[str, Any]] = None, ) -> "StructType": - """ - Construct a :class:`StructType` by adding new elements to it, to define the schema. + """Construct a :class:`StructType` by adding new elements to it, to define the schema. The method accepts either: a) A single parameter which is a :class:`StructField` object. @@ -825,11 +818,11 @@ def add( metadata : dict, optional Any additional metadata (default None) - Returns + Returns: ------- :class:`StructType` - Examples + Examples: -------- >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) >>> struct2 = StructType([StructField("f1", StringType(), True), \\ @@ -875,7 +868,7 @@ def __getitem__(self, key: Union[str, int]) -> StructField: for field in self: if field.name == key: return field - raise KeyError("No StructField named {0}".format(key)) + raise KeyError(f"No StructField named {key}") elif isinstance(key, int): try: return self.fields[key] @@ -904,10 +897,9 @@ def extract_types_and_names(self) -> tuple[list[str], list[str]]: return (types, names) def fieldNames(self) -> list[str]: - """ - Returns all field names in a list. + """Returns all field names in a list. - Examples + Examples: -------- >>> struct = StructType([StructField("f1", StringType(), True)]) >>> struct.fieldNames() @@ -987,22 +979,19 @@ def typeName(cls) -> str: @classmethod def sqlType(cls) -> DataType: - """ - Underlying SQL storage type for this UDT. + """Underlying SQL storage type for this UDT. 
""" raise NotImplementedError("UDT must implement sqlType().") @classmethod def module(cls) -> str: - """ - The Python module of the UDT. + """The Python module of the UDT. """ raise NotImplementedError("UDT must implement module().") @classmethod def scalaUDT(cls) -> str: - """ - The class name of the paired Scala UDT (could be '', if there + """The class name of the paired Scala UDT (could be '', if there is no corresponding one). """ return "" @@ -1012,8 +1001,7 @@ def needConversion(self) -> bool: @classmethod def _cachedSqlType(cls) -> DataType: - """ - Cache the sqlType() into class, because it's heavily used in `toInternal`. + """Cache the sqlType() into class, because it's heavily used in `toInternal`. """ if not hasattr(cls, "_cached_sql_type"): cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined] @@ -1029,21 +1017,19 @@ def fromInternal(self, obj: Any) -> Any: return self.deserialize(v) def serialize(self, obj: Any) -> Any: - """ - Converts a user-type object into a SQL datum. + """Converts a user-type object into a SQL datum. """ raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum: Any) -> Any: - """ - Converts a SQL datum into a user-type object. + """Converts a SQL datum into a user-type object. """ raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self) -> str: return "udt" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return type(self) == type(other) @@ -1086,8 +1072,7 @@ def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], class Row(tuple): - """ - A row in :class:`DataFrame`. + """A row in :class:`DataFrame`. The fields in it can be accessed: * like attributes (``row.key``) @@ -1104,7 +1089,7 @@ class Row(tuple): field names sorted alphabetically and will be ordered in the position as entered. - Examples + Examples: -------- >>> row = Row(name="Alice", age=11) >>> row @@ -1159,15 +1144,14 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": return tuple.__new__(cls, args) def asDict(self, recursive: bool = False) -> dict[str, Any]: - """ - Return as a dict + """Return as a dict Parameters ---------- recursive : bool, optional turns the nested Rows to dict (default: False). - Notes + Notes: ----- If a row contains duplicate field names, e.g., the rows of a join between two :class:`DataFrame` that both have the fields of same names, @@ -1175,7 +1159,7 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]: will also return one of the duplicate fields, however returned value might be different to ``asDict``. 
- Examples + Examples: -------- >>> Row(name="Alice", age=11).asDict() == {"name": "Alice", "age": 11} True @@ -1212,7 +1196,7 @@ def __contains__(self, item: Any) -> bool: # let object acts like class def __call__(self, *args: Any) -> "Row": - """create new Row object""" + """Create new Row object""" if len(args) > len(self): raise ValueError( "Can not create Row with fields %s, expected %d values but got %s" % (self, len(self), args) diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index ea4ba540..885c797f 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -1,8 +1,10 @@ -from fsspec import filesystem, AbstractFileSystem -from fsspec.implementations.memory import MemoryFileSystem, MemoryFile -from .bytes_io_wrapper import BytesIOWrapper from io import TextIOBase +from fsspec import AbstractFileSystem +from fsspec.implementations.memory import MemoryFile, MemoryFileSystem + +from .bytes_io_wrapper import BytesIOWrapper + def is_file_like(obj): # We only care that we can read from the file diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index 90c2a561..b1ddab19 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,3 +1,3 @@ -from _duckdb.functional import FunctionNullHandling, PythonUDFType, SPECIAL, DEFAULT, NATIVE, ARROW +from _duckdb.functional import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType -__all__ = ["FunctionNullHandling", "PythonUDFType", "SPECIAL", "DEFAULT", "NATIVE", "ARROW"] +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index ef87f03a..b1fc244c 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -1,17 +1,18 @@ -import duckdb -import polars as pl -from typing import Iterator, Optional +import datetime +import json +from collections.abc import Iterator +from decimal import Decimal +from typing import Optional +import polars as pl from polars.io.plugins import register_io_source + +import duckdb from duckdb import SQLExpression -import json -from decimal import Decimal -import datetime def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: - """ - Convert a Polars predicate expression to a DuckDB-compatible SQL expression. + """Convert a Polars predicate expression to a DuckDB-compatible SQL expression. Parameters: predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5) @@ -37,8 +38,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: def _pl_operation_to_sql(op: str) -> str: - """ - Map Polars binary operation strings to SQL equivalents. + """Map Polars binary operation strings to SQL equivalents. Example: >>> _pl_operation_to_sql("Eq") @@ -60,8 +60,7 @@ def _pl_operation_to_sql(op: str) -> str: def _escape_sql_identifier(identifier: str) -> str: - """ - Escape SQL identifiers by doubling any double quotes and wrapping in double quotes. + """Escape SQL identifiers by doubling any double quotes and wrapping in double quotes. Example: >>> _escape_sql_identifier('column"name') @@ -72,8 +71,7 @@ def _escape_sql_identifier(identifier: str) -> str: def _pl_tree_to_sql(tree: dict) -> str: - """ - Recursively convert a Polars expression tree (as JSON) to a SQL string. + """Recursively convert a Polars expression tree (as JSON) to a SQL string. 
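(Aside, not part of the patch: the small polars_io helpers above are summarized by the sketch below, with hypothetical local names and a trimmed operator table; Polars operator names are mapped onto SQL operators, and identifiers are made safe by doubling embedded double quotes before wrapping them in quotes.)

_OPERATORS = {
    "Eq": "=",
    "NotEq": "!=",
    "Lt": "<",
    "LtEq": "<=",
    "Gt": ">",
    "GtEq": ">=",
    "And": "AND",
    "Or": "OR",
}

def op_to_sql(op: str) -> str:
    # Unknown operators abort the pushdown rather than producing wrong SQL.
    try:
        return _OPERATORS[op]
    except KeyError:
        raise NotImplementedError(f"Unsupported operation: {op}") from None

def escape_identifier(identifier: str) -> str:
    # 'column"name' becomes '"column""name"', the standard SQL quoting rule.
    return '"' + identifier.replace('"', '""') + '"'

assert op_to_sql("Eq") == "="
assert escape_identifier('column"name') == '"column""name"'
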
Parameters: tree (dict): JSON-deserialized expression tree from Polars @@ -158,7 +156,7 @@ def _pl_tree_to_sql(tree: dict) -> str: if dtype.startswith("{'Datetime'") or dtype == "Datetime": micros = value["Datetime"][0] dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC) - return f"'{str(dt_timestamp)}'::TIMESTAMP" + return f"'{dt_timestamp!s}'::TIMESTAMP" # Match simple numeric/boolean types if dtype in ( @@ -202,14 +200,13 @@ def _pl_tree_to_sql(tree: dict) -> str: string_val = value.get("StringOwned", value.get("String", None)) return f"'{string_val}'" - raise NotImplementedError(f"Unsupported scalar type {str(dtype)}, with value {value}") + raise NotImplementedError(f"Unsupported scalar type {dtype!s}, with value {value}") raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: - """ - A polars IO plugin for DuckDB. + """A polars IO plugin for DuckDB. """ def source_generator( diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index aa67b42f..88d96350 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -1,10 +1,10 @@ +import argparse import json import os -import sys import re +import sys import webbrowser from functools import reduce -import argparse qgraph_css = """ .styled-table { diff --git a/duckdb/typing/__init__.py b/duckdb/typing/__init__.py index 33cf4cd7..53207418 100644 --- a/duckdb/typing/__init__.py +++ b/duckdb/typing/__init__.py @@ -1,5 +1,4 @@ from _duckdb.typing import ( - DuckDBPyType, BIGINT, BIT, BLOB, @@ -8,29 +7,29 @@ DOUBLE, FLOAT, HUGEINT, - UHUGEINT, INTEGER, INTERVAL, SMALLINT, SQLNULL, TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, TIMESTAMP_TZ, - TIME_TZ, TINYINT, UBIGINT, + UHUGEINT, UINTEGER, USMALLINT, UTINYINT, UUID, VARCHAR, + DuckDBPyType, ) __all__ = [ - "DuckDBPyType", "BIGINT", "BIT", "BLOB", @@ -39,7 +38,6 @@ "DOUBLE", "FLOAT", "HUGEINT", - "UHUGEINT", "INTEGER", "INTERVAL", "SMALLINT", @@ -53,9 +51,11 @@ "TIME_TZ", "TINYINT", "UBIGINT", + "UHUGEINT", "UINTEGER", "USMALLINT", "UTINYINT", "UUID", "VARCHAR", + "DuckDBPyType", ] diff --git a/duckdb/udf.py b/duckdb/udf.py index bbf05c7d..0eb59ba9 100644 --- a/duckdb/udf.py +++ b/duckdb/udf.py @@ -1,9 +1,8 @@ def vectorized(func): + """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output """ - Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output - """ - from inspect import signature import types + from inspect import signature new_func = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) # Construct the annotations: diff --git a/duckdb/value/constant/__init__.py b/duckdb/value/constant/__init__.py index fb7d7284..9bbf2493 100644 --- a/duckdb/value/constant/__init__.py +++ b/duckdb/value/constant/__init__.py @@ -1,5 +1,5 @@ from typing import Any, Dict -from duckdb.typing import DuckDBPyType + from duckdb.typing import ( BIGINT, BIT, @@ -9,25 +9,26 @@ DOUBLE, FLOAT, HUGEINT, - UHUGEINT, INTEGER, INTERVAL, SMALLINT, SQLNULL, TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, TIMESTAMP_TZ, - TIME_TZ, TINYINT, UBIGINT, + UHUGEINT, UINTEGER, USMALLINT, UTINYINT, 
UUID, VARCHAR, + DuckDBPyType, ) @@ -236,33 +237,33 @@ def __init__(self, object: Any, members: dict[str, DuckDBPyType]) -> None: # TODO: add EnumValue once `duckdb.enum_type` is added __all__ = [ - "Value", - "NullValue", - "BooleanValue", - "UnsignedBinaryValue", - "UnsignedShortValue", - "UnsignedIntegerValue", - "UnsignedLongValue", "BinaryValue", - "ShortValue", - "IntegerValue", - "LongValue", - "HugeIntegerValue", - "UnsignedHugeIntegerValue", - "FloatValue", - "DoubleValue", - "DecimalValue", - "StringValue", - "UUIDValue", "BitValue", "BlobValue", + "BooleanValue", "DateValue", + "DecimalValue", + "DoubleValue", + "FloatValue", + "HugeIntegerValue", + "IntegerValue", "IntervalValue", - "TimestampValue", - "TimestampSecondValue", + "LongValue", + "NullValue", + "ShortValue", + "StringValue", + "TimeTimeZoneValue", + "TimeValue", "TimestampMilisecondValue", "TimestampNanosecondValue", + "TimestampSecondValue", "TimestampTimeZoneValue", - "TimeValue", - "TimeTimeZoneValue", + "TimestampValue", + "UUIDValue", + "UnsignedBinaryValue", + "UnsignedHugeIntegerValue", + "UnsignedIntegerValue", + "UnsignedLongValue", + "UnsignedShortValue", + "Value", ] diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index 3709dac0..57008fa3 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -7,10 +7,9 @@ """ import pathlib +import re import subprocess from typing import Optional -import re - VERSION_RE = re.compile( r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$" @@ -139,8 +138,7 @@ def create_git_tag(version: str, message: Optional[str] = None, repo_path: Optio def strip_post_from_version(version: str) -> str: - """ - Removing post-release suffixes from the given version. + """Removing post-release suffixes from the given version. DuckDB doesn't allow post-release versions, so .post* suffixes are stripped. """ diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index b9a005db..dc94eeaa 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -13,25 +13,29 @@ Also see https://peps.python.org/pep-0517/#in-tree-build-backends. 
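(Aside, not part of the patch: a version pattern of the shape VERSION_RE targets above, X.Y.Z with an optional rcN or .postN suffix, can be written with named groups as below; this is a hypothetical reconstruction for illustration, not a verbatim copy of the module's constant.)

import re

VERSION_LIKE = re.compile(
    r"^(?P<major>[0-9]+)\.(?P<minor>[0-9]+)\.(?P<patch>[0-9]+)"
    r"(?:rc(?P<rc>[0-9]+)|\.post(?P<post>[0-9]+))?$"
)

assert VERSION_LIKE.match("1.2.3").group("patch") == "3"
assert VERSION_LIKE.match("1.2.3rc1").group("rc") == "1"
assert VERSION_LIKE.match("1.2.3.post2").group("post") == "2"
assert VERSION_LIKE.match("1.2.3-dev1") is None
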
""" -import sys import os import subprocess +import sys from pathlib import Path -from typing import Optional, Dict, List, Union +from typing import Optional, Union + from scikit_build_core.build import ( - build_wheel as skbuild_build_wheel, build_editable, - build_sdist as skbuild_build_sdist, - get_requires_for_build_wheel, - get_requires_for_build_sdist, get_requires_for_build_editable, - prepare_metadata_for_build_wheel, + get_requires_for_build_sdist, + get_requires_for_build_wheel, prepare_metadata_for_build_editable, + prepare_metadata_for_build_wheel, +) +from scikit_build_core.build import ( + build_sdist as skbuild_build_sdist, +) +from scikit_build_core.build import ( + build_wheel as skbuild_build_wheel, ) -from duckdb_packaging._versioning import create_git_tag, pep440_to_git_tag, get_git_describe, strip_post_from_version -from duckdb_packaging.setuptools_scm_version import forced_version_from_env, MAIN_BRANCH_VERSIONING - +from duckdb_packaging._versioning import get_git_describe, pep440_to_git_tag, strip_post_from_version +from duckdb_packaging.setuptools_scm_version import MAIN_BRANCH_VERSIONING, forced_version_from_env _DUCKDB_VERSION_FILENAME = "duckdb_version.txt" _LOGGING_FORMAT = "[duckdb_pytooling.build_backend] {}" @@ -251,12 +255,12 @@ def build_wheel( __all__ = [ - "build_wheel", - "build_sdist", "build_editable", - "get_requires_for_build_wheel", - "get_requires_for_build_sdist", + "build_sdist", + "build_wheel", "get_requires_for_build_editable", - "prepare_metadata_for_build_wheel", + "get_requires_for_build_sdist", + "get_requires_for_build_wheel", "prepare_metadata_for_build_editable", + "prepare_metadata_for_build_wheel", ] diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 80073c0e..8e91b34f 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -1,5 +1,4 @@ -""" -!!HERE BE DRAGONS!! Use this script with care! +"""!!HERE BE DRAGONS!! Use this script with care! PyPI package cleanup tool. 
This script will: * Never remove a stable version (including a post release version) @@ -17,8 +16,9 @@ import sys import time from collections import defaultdict +from collections.abc import Generator from html.parser import HTMLParser -from typing import Optional, Set, Generator +from typing import Optional from urllib.parse import urlparse import pyotp @@ -77,19 +77,16 @@ def create_argument_parser() -> argparse.ArgumentParser: class PyPICleanupError(Exception): """Base exception for PyPI cleanup operations.""" - pass class AuthenticationError(PyPICleanupError): """Raised when authentication fails.""" - pass class ValidationError(PyPICleanupError): """Raised when input validation fails.""" - pass def setup_logging(verbose: bool = False) -> None: @@ -236,7 +233,7 @@ def run(self) -> int: int: Exit code (0 for success, non-zero for failure) """ if self._do_delete: - logging.warning(f"NOT A DRILL: WILL DELETE PACKAGES") + logging.warning("NOT A DRILL: WILL DELETE PACKAGES") else: logging.info("Running in DRY RUN mode, nothing will be deleted") @@ -246,7 +243,7 @@ def run(self) -> int: with session_with_retries() as http_session: return self._execute_cleanup(http_session) except PyPICleanupError as e: - logging.error(f"Cleanup failed: {e}") + logging.exception(f"Cleanup failed: {e}") return 1 except Exception as e: logging.error(f"Unexpected error: {e}", exc_info=True) @@ -254,7 +251,6 @@ def run(self) -> int: def _execute_cleanup(self, http_session: Session) -> int: """Execute the main cleanup logic.""" - # Get released versions versions = self._fetch_released_versions(http_session) if not versions: @@ -418,7 +414,6 @@ def _get_csrf_token(self, http_session: Session, form_action: str) -> str: def _perform_login(self, http_session: Session) -> requests.Response: """Perform the initial login with username/password.""" - # Get login form and CSRF token csrf_token = self._get_csrf_token(http_session, "/account/login/") @@ -487,7 +482,7 @@ def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) logging.info(f"Successfully deleted {self._package} version {version}") except Exception as e: # Continue with other versions rather than failing completely - logging.error(f"Failed to delete version {version}: {e}") + logging.exception(f"Failed to delete version {version}: {e}") failed_deletions.append(version) if failed_deletions: @@ -547,13 +542,13 @@ def main() -> int: return cleanup.run() except ValidationError as e: - logging.error(f"Configuration error: {e}") + logging.exception(f"Configuration error: {e}") return 2 except KeyboardInterrupt: logging.info("Operation cancelled by user") return 130 except Exception as e: - logging.error(f"Unexpected error: {e}", exc_info=args.verbose) + logging.exception(f"Unexpected error: {e}", exc_info=args.verbose) return 1 diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 217b2ffe..2ff79f80 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -1,5 +1,4 @@ -""" -setuptools_scm integration for DuckDB Python versioning. +"""setuptools_scm integration for DuckDB Python versioning. This module provides the setuptools_scm version scheme and handles environment variable overrides to match the exact behavior of the original DuckDB Python package. 
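(Aside, not part of the patch: on the logging.error to logging.exception edits in the pypi_cleanup hunks above, logging.exception logs at ERROR level and automatically appends the traceback of the exception currently being handled, so explicit ``exc_info`` plumbing becomes unnecessary inside except blocks. A minimal sketch with hypothetical names:)

import logging

def safe_ratio(numerator: float, denominator: float) -> float:
    try:
        return numerator / denominator
    except ZeroDivisionError:
        # Equivalent to logging.error(..., exc_info=True) inside an except block.
        logging.exception("Could not compute ratio of %s to %s", numerator, denominator)
        return 0.0

assert safe_ratio(1.0, 0.0) == 0.0
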
@@ -10,7 +9,7 @@ from typing import Any # Import from our own versioning module to avoid duplication -from ._versioning import parse_version, format_version +from ._versioning import format_version, parse_version # MAIN_BRANCH_VERSIONING should be 'True' on main branch only MAIN_BRANCH_VERSIONING = False @@ -26,8 +25,7 @@ def _main_branch_versioning(): def version_scheme(version: Any) -> str: - """ - setuptools_scm version scheme that matches DuckDB's original behavior. + """setuptools_scm version scheme that matches DuckDB's original behavior. Args: version: setuptools_scm version object @@ -55,7 +53,7 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: # Validate the base version (this should never include anything else than X.Y.Z or X.Y.Z.[rc|post]N) try: major, minor, patch, post, rc = parse_version(base_version) - except ValueError as e: + except ValueError: raise ValueError(f"Incorrect version format: {base_version} (expected X.Y.Z or X.Y.Z.postN)") # If we're exactly on a tag (distance = 0, dirty=False) @@ -76,8 +74,7 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: def forced_version_from_env(): - """ - Handle getting versions from environment variables. + """Handle getting versions from environment variables. Only supports a single way of manually overriding the version through OVERRIDE_GIT_DESCRIBE. If SETUPTOOLS_SCM_PRETEND_VERSION* is set, it gets unset. diff --git a/scripts/generate_connection_methods.py b/scripts/generate_connection_methods.py index a48b6142..51f667f6 100644 --- a/scripts/generate_connection_methods.py +++ b/scripts/generate_connection_methods.py @@ -1,5 +1,5 @@ -import os import json +import os os.chdir(os.path.dirname(__file__)) @@ -29,7 +29,7 @@ def is_py_args(method): def generate(): # Read the PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, "r") as source_file: + with open(PYCONNECTION_SOURCE) as source_file: source_code = source_file.readlines() start_index = -1 @@ -52,7 +52,7 @@ def generate(): # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) DEFAULT_ARGUMENT_MAP = { diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index e3831173..9b1be9aa 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -1,5 +1,5 @@ -import os import json +import os os.chdir(os.path.dirname(__file__)) @@ -12,7 +12,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, "r") as source_file: + with open(DUCKDB_STUBS_FILE) as source_file: source_code = source_file.readlines() start_index = -1 @@ -35,7 +35,7 @@ def generate(): # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) body = [] diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index 45ac45cc..d2ef0bba 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -1,10 +1,8 @@ -import os -import sys import json +import os # Requires `python3 -m pip install cxxheaderparser pcpp` -from get_cpp_methods import get_methods, FunctionParam, ConnectionMethod -from typing import List, Tuple +from get_cpp_methods import ConnectionMethod, get_methods 
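(Aside, not part of the patch: the generator-script hunks here drop the explicit "r" argument because reading text is already open()'s default mode, so the two spellings below behave identically. The file name is hypothetical:)

from pathlib import Path

path = Path("example_methods.json")
path.write_text("{}")

with open(path, "r") as explicit, open(path) as implicit:
    # Both handles are text-mode readers; only the spelling differs.
    assert explicit.read() == implicit.read() == "{}"

path.unlink()
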
os.chdir(os.path.dirname(__file__)) @@ -40,7 +38,7 @@ INIT_PY_END = "# END OF CONNECTION WRAPPER" # Read the JSON file -with open(WRAPPER_JSON_PATH, "r") as json_file: +with open(WRAPPER_JSON_PATH) as json_file: wrapper_methods = json.load(json_file) # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke @@ -94,19 +92,19 @@ def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[s def generate(): # Read the DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, "r") as source_file: + with open(DUCKDB_PYTHON_SOURCE) as source_file: source_code = source_file.readlines() start_section, end_section = remove_section(source_code, START_MARKER, END_MARKER) # Read the DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, "r") as source_file: + with open(DUCKDB_INIT_FILE) as source_file: source_code = source_file.readlines() py_start, py_end = remove_section(source_code, INIT_PY_START, INIT_PY_END) # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) # Collect the definitions from the pyconnection.hpp header diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 02e36c4e..3b3b8c93 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -1,5 +1,5 @@ -import os import json +import os os.chdir(os.path.dirname(__file__)) @@ -13,7 +13,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, "r") as source_file: + with open(DUCKDB_STUBS_FILE) as source_file: source_code = source_file.readlines() start_index = -1 @@ -38,10 +38,10 @@ def generate(): methods = [] # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) - with open(WRAPPER_JSON_PATH, "r") as json_file: + with open(WRAPPER_JSON_PATH) as json_file: wrapper_methods = json.load(json_file) methods.extend(connection_methods) diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index 8a4b0c36..036115f4 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -1,14 +1,13 @@ import os script_dir = os.path.dirname(__file__) -from typing import List, Dict import json # Load existing JSON data from a file if it exists json_data = {} json_cache_path = os.path.join(script_dir, "cache_data.json") try: - with open(json_cache_path, "r") as file: + with open(json_cache_path) as file: json_data = json.load(file) except FileNotFoundError: print("Please first use 'generate_import_cache_json.py' first to generate json") diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 099db841..34cd84b6 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -1,8 +1,8 @@ import os script_dir = os.path.dirname(__file__) -from typing import List, Dict, Union import json +from typing import Union lines: list[str] = [file for file in open(f"{script_dir}/imports.py").read().split("\n") if file != ""] @@ -13,7 +13,7 @@ def __init__(self, full_path: str) -> None: self.type = "attribute" self.name = parts[-1] self.full_path = full_path - self.children: dict[str, "ImportCacheAttribute"] = {} + self.children: dict[str, ImportCacheAttribute] = {} def has_item(self, item_name: str) 
-> bool: return item_name in self.children @@ -46,7 +46,7 @@ def __init__(self, full_path) -> None: self.type = "module" self.name = parts[-1] self.full_path = full_path - self.items: dict[str, Union[ImportCacheAttribute, "ImportCacheModule"]] = {} + self.items: dict[str, Union[ImportCacheAttribute, ImportCacheModule]] = {} def add_item(self, item: Union[ImportCacheAttribute, "ImportCacheModule"]): assert self.full_path != item.full_path @@ -156,7 +156,7 @@ def to_json(self): existing_json_data = {} json_cache_path = os.path.join(script_dir, "cache_data.json") try: - with open(json_cache_path, "r") as file: + with open(json_cache_path) as file: existing_json_data = json.load(file) except FileNotFoundError: pass diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index 97b28af3..25aa7c7d 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -1,10 +1,10 @@ # Requires `python3 -m pip install cxxheaderparser pcpp` import os +from typing import Callable import cxxheaderparser.parser -import cxxheaderparser.visitor import cxxheaderparser.preprocessor -from typing import List, Dict, Callable +import cxxheaderparser.visitor scripts_folder = os.path.dirname(os.path.abspath(__file__)) diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index b8d913ea..8d772111 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -1,11 +1,12 @@ import itertools import pathlib -import pytest import random import re import typing import warnings -import glob + +import pytest + from .skipped_tests import SKIPPED_TESTS SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" @@ -126,11 +127,9 @@ def create_parameters_from_paths(paths, root_dir: pathlib.Path, config: pytest.C def scan_for_test_scripts(root_dir: pathlib.Path, config: pytest.Config) -> typing.Iterator[typing.Any]: - """ - Scans for .test files in the given directory and its subdirectories. + """Scans for .test files in the given directory and its subdirectories. Returns an iterator of pytest parameters (argument, id and marks). """ - # TODO: Add tests from extensions test_script_extensions = [".test", ".test_slow", ".test_coverage"] it = itertools.chain.from_iterable(root_dir.rglob(f"*{ext}") for ext in test_script_extensions) @@ -166,13 +165,11 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, int]: - """ - If start_offset and end_offset are specified, then these are used. + """If start_offset and end_offset are specified, then these are used. start_offset defaults to 0. end_offset defaults to and is capped to the last test index. start_offset_percentage and end_offset_percentage are used to calculate the start and end offsets based on the total number of tests. This is done in a way that a test run to 25% and another test run starting at 25% do not overlap by excluding the 25th percent test. 
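(Aside, not part of the patch: as an illustration of the partitioning rule described just above, two consecutive percentage ranges must not both run the boundary test; the arithmetic below is a hypothetical sketch, not the module's implementation.)

def offsets_from_percentages(start_pct: int, end_pct: int, num_tests: int) -> tuple[int, int]:
    start = num_tests * start_pct // 100
    # The run ending at ``end_pct`` stops one test short of the boundary ...
    end = num_tests * end_pct // 100 - 1
    return start, min(end, num_tests - 1)

first = offsets_from_percentages(0, 25, 100)    # (0, 24)
second = offsets_from_percentages(25, 100, 100) # (25, 99)
# ... so the run starting at that same percentage picks it up instead, with no overlap.
assert first[1] + 1 == second[0]
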
""" - start_offset = config.getoption("start_offset") end_offset = config.getoption("end_offset") start_offset_percentage = config.getoption("start_offset_percentage") @@ -271,8 +268,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config def pytest_runtest_setup(item: pytest.Item): - """ - Show the test index after the test name + """Show the test index after the test name """ def get_from_tuple_list(tuples, key): diff --git a/sqllogic/test_sqllogic.py b/sqllogic/test_sqllogic.py index 6f55e931..35736015 100644 --- a/sqllogic/test_sqllogic.py +++ b/sqllogic/test_sqllogic.py @@ -1,24 +1,25 @@ import gc import os import pathlib -import pytest import signal import sys -from typing import Any, Generator, Optional +from collections.abc import Generator +from typing import Any, Optional + +import pytest sys.path.append(str(pathlib.Path(__file__).parent.parent / "external" / "duckdb" / "scripts")) from sqllogictest import ( - SQLParserException, SQLLogicParser, SQLLogicTest, + SQLParserException, ) - from sqllogictest.result import ( - TestException, - SQLLogicRunner, - SQLLogicDatabase, - SQLLogicContext, ExecuteResult, + SQLLogicContext, + SQLLogicDatabase, + SQLLogicRunner, + TestException, ) diff --git a/tests/conftest.py b/tests/conftest.py index d69cdfce..83c10f3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,14 @@ +import glob import os +import shutil +import warnings +from importlib import import_module +from os.path import abspath, dirname, join, normpath from typing import Any import pytest -import shutil -from os.path import abspath, join, dirname, normpath -import glob + import duckdb -import warnings -from importlib import import_module try: # need to ignore warnings that might be thrown deep inside pandas's import tree (from dateutil in this case) @@ -71,11 +72,12 @@ def duckdb_empty_cursor(request): def getTimeSeriesData(nper=None, freq: "Frequency" = "B"): - from pandas import DatetimeIndex, bdate_range, Series + import string from datetime import datetime - from pandas._typing import Frequency + import numpy as np - import string + from pandas import DatetimeIndex, Series, bdate_range + from pandas._typing import Frequency _N = 30 _K = 4 @@ -226,7 +228,6 @@ def _require(extension_name, db_name=""): # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture @pytest.fixture(scope="function") def spark(): - from spark_namespace import USE_ACTUAL_SPARK if not hasattr(spark, "session"): # Cache the import diff --git a/tests/coverage/test_pandas_categorical_coverage.py b/tests/coverage/test_pandas_categorical_coverage.py index 15eee10a..b0130577 100644 --- a/tests/coverage/test_pandas_categorical_coverage.py +++ b/tests/coverage/test_pandas_categorical_coverage.py @@ -1,7 +1,7 @@ -import duckdb -import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import NumpyPandas + +import duckdb def check_result_list(res): @@ -69,7 +69,7 @@ def check_create_table(category, pandas): # TODO: extend tests with ArrowPandas -class TestCategory(object): +class TestCategory: @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint16(self, duckdb_cursor, pandas): category = [] diff --git a/tests/extensions/json/test_read_json.py b/tests/extensions/json/test_read_json.py index f0fd809f..9ac5be88 100644 --- a/tests/extensions/json/test_read_json.py +++ b/tests/extensions/json/test_read_json.py @@ -1,10 +1,8 @@ -import numpy -import 
datetime -import pandas +from io import StringIO + import pytest + import duckdb -import re -from io import StringIO def TestFile(name): @@ -14,7 +12,7 @@ def TestFile(name): return filename -class TestReadJSON(object): +class TestReadJSON: def test_read_json_columns(self): rel = duckdb.read_json(TestFile("example.json"), columns={"id": "integer", "name": "varchar"}) res = rel.fetchone() diff --git a/tests/extensions/test_extensions_loading.py b/tests/extensions/test_extensions_loading.py index f35366ba..3aa5fe81 100644 --- a/tests/extensions/test_extensions_loading.py +++ b/tests/extensions/test_extensions_loading.py @@ -1,10 +1,10 @@ import os import platform -import duckdb -from pytest import raises import pytest +from pytest import raises +import duckdb pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", diff --git a/tests/extensions/test_httpfs.py b/tests/extensions/test_httpfs.py index 866491f0..bd1ec015 100644 --- a/tests/extensions/test_httpfs.py +++ b/tests/extensions/test_httpfs.py @@ -1,9 +1,11 @@ -import duckdb +import datetime import os -from pytest import raises, mark + import pytest -from conftest import NumpyPandas, ArrowPandas -import datetime +from conftest import ArrowPandas, NumpyPandas +from pytest import mark, raises + +import duckdb # We only run this test if this env var is set # FIXME: we can add a custom command line argument to pytest to provide an extension directory @@ -14,7 +16,7 @@ ) -class TestHTTPFS(object): +class TestHTTPFS: def test_read_json_httpfs(self, require): connection = require("httpfs") try: @@ -29,7 +31,7 @@ def test_read_json_httpfs(self, require): def test_s3fs(self, require): connection = require("httpfs") - rel = connection.read_csv(f"s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) + rel = connection.read_csv("s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) res = rel.fetchone() assert res == (1, 0, datetime.date(1965, 2, 28), 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 6, 0, 0, 0, 0) @@ -42,9 +44,7 @@ def test_httpfs(self, require, pandas): ) except RuntimeError as e: # Test will ignore result if it fails due to networking issues while running the test. 
- if str(e).startswith("HTTP HEAD error"): - return - elif str(e).startswith("Unable to connect"): + if str(e).startswith("HTTP HEAD error") or str(e).startswith("Unable to connect"): return else: raise e diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py index 80b6b385..6f6213e6 100644 --- a/tests/fast/adbc/test_adbc.py +++ b/tests/fast/adbc/test_adbc.py @@ -1,9 +1,11 @@ -import duckdb -import pytest -import sys import datetime import os +import sys + import numpy as np +import pytest + +import duckdb if sys.version_info < (3, 9): pytest.skip( diff --git a/tests/fast/adbc/test_connection_get_info.py b/tests/fast/adbc/test_connection_get_info.py index 3744b7da..4f8163bc 100644 --- a/tests/fast/adbc/test_connection_get_info.py +++ b/tests/fast/adbc/test_connection_get_info.py @@ -1,8 +1,9 @@ import sys -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") adbc_driver_manager = pytest.importorskip("adbc_driver_manager") @@ -22,7 +23,7 @@ ) -class TestADBCConnectionGetInfo(object): +class TestADBCConnectionGetInfo: def test_connection_basic(self): con = adbc_driver_duckdb.connect() with con.cursor() as cursor: diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index d1919cb1..dc5d1f59 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -31,7 +31,7 @@ def _bind(stmt, batch): stmt.bind(array, schema) -class TestADBCStatementBind(object): +class TestADBCStatementBind: def test_bind_multiple_rows(self): data = pa.record_batch( [ diff --git a/tests/fast/api/test_3324.py b/tests/fast/api/test_3324.py index f3cd235b..fb860600 100644 --- a/tests/fast/api/test_3324.py +++ b/tests/fast/api/test_3324.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class Test3324(object): +class Test3324: def test_3324(self, duckdb_cursor): create_output = duckdb_cursor.execute( """ diff --git a/tests/fast/api/test_3654.py b/tests/fast/api/test_3654.py index 8fad47e6..2ffee855 100644 --- a/tests/fast/api/test_3654.py +++ b/tests/fast/api/test_3654.py @@ -1,16 +1,17 @@ -import duckdb import pytest +import duckdb + try: import pyarrow as pa can_run = True except: can_run = False -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas -class Test3654(object): +class Test3654: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_pandas(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( diff --git a/tests/fast/api/test_3728.py b/tests/fast/api/test_3728.py index 37b50ee6..bd770bf0 100644 --- a/tests/fast/api/test_3728.py +++ b/tests/fast/api/test_3728.py @@ -1,7 +1,7 @@ import duckdb -class Test3728(object): +class Test3728: def test_3728_describe_enum(self, duckdb_cursor): # Create an in-memory database, but the problem is also present in file-backed DBs cursor = duckdb.connect(":memory:") diff --git a/tests/fast/api/test_6315.py b/tests/fast/api/test_6315.py index b9e7c0cf..3702831e 100644 --- a/tests/fast/api/test_6315.py +++ b/tests/fast/api/test_6315.py @@ -1,7 +1,7 @@ import duckdb -class Test6315(object): +class Test6315: def test_6315(self, duckdb_cursor): # segfault when accessing description after fetching rows c = duckdb.connect(":memory:") diff --git a/tests/fast/api/test_attribute_getter.py b/tests/fast/api/test_attribute_getter.py index eda6845a..3b1513d1 100644 --- a/tests/fast/api/test_attribute_getter.py +++ b/tests/fast/api/test_attribute_getter.py @@ -1,15 +1,10 @@ -import duckdb -import tempfile 
-import os -import pandas as pd -import tempfile -import pandas._testing as tm -import datetime -import csv + import pytest +import duckdb + -class TestGetAttribute(object): +class TestGetAttribute: def test_basic_getattr(self, duckdb_cursor): rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") assert rel.a.fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] diff --git a/tests/fast/api/test_config.py b/tests/fast/api/test_config.py index 4a0a0445..89620e96 100644 --- a/tests/fast/api/test_config.py +++ b/tests/fast/api/test_config.py @@ -1,14 +1,15 @@ # simple DB API testcase -import duckdb -import numpy -import pytest -import re import os -from conftest import NumpyPandas, ArrowPandas +import re + +import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestDBConfig(object): +class TestDBConfig: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_default_order(self, duckdb_cursor, pandas): df = pandas.DataFrame({"a": [1, 2, 3]}) @@ -51,7 +52,7 @@ def test_extension_setting(self): if not repository: return con = duckdb.connect(config={"TimeZone": "UTC", "autoinstall_extension_repository": repository}) - assert "UTC" == con.sql("select current_setting('TimeZone')").fetchone()[0] + assert con.sql("select current_setting('TimeZone')").fetchone()[0] == "UTC" def test_unrecognized_option(self, duckdb_cursor): success = True diff --git a/tests/fast/api/test_connection_close.py b/tests/fast/api/test_connection_close.py index f71a02bb..bbf66772 100644 --- a/tests/fast/api/test_connection_close.py +++ b/tests/fast/api/test_connection_close.py @@ -1,10 +1,12 @@ # cursor description -import duckdb -import tempfile import os +import tempfile + import pytest +import duckdb + def check_exception(f): had_exception = False @@ -15,7 +17,7 @@ def check_exception(f): assert had_exception -class TestConnectionClose(object): +class TestConnectionClose: def test_connection_close(self, duckdb_cursor): fd, db = tempfile.mkstemp() os.close(fd) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index 4efd68b5..8a027b5a 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -2,11 +2,12 @@ import threading import time -import duckdb import pytest +import duckdb + -class TestConnectionInterrupt(object): +class TestConnectionInterrupt: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", diff --git a/tests/fast/api/test_cursor.py b/tests/fast/api/test_cursor.py index 69c3fe79..7a2c4176 100644 --- a/tests/fast/api/test_cursor.py +++ b/tests/fast/api/test_cursor.py @@ -1,10 +1,11 @@ # simple DB API testcase import pytest + import duckdb -class TestDBAPICursor(object): +class TestDBAPICursor: def test_cursor_basic(self): # Create a connection con = duckdb.connect(":memory:") diff --git a/tests/fast/api/test_dbapi00.py b/tests/fast/api/test_dbapi00.py index 38d87887..6201d569 100644 --- a/tests/fast/api/test_dbapi00.py +++ b/tests/fast/api/test_dbapi00.py @@ -2,15 +2,14 @@ import numpy import pytest -import duckdb -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas def assert_result_equal(result): assert result == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (None,)], "Incorrect result returned" -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_regular_selection(self, duckdb_cursor, integers): 
duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() diff --git a/tests/fast/api/test_dbapi01.py b/tests/fast/api/test_dbapi01.py index f7f00a10..4d52fd64 100644 --- a/tests/fast/api/test_dbapi01.py +++ b/tests/fast/api/test_dbapi01.py @@ -1,10 +1,11 @@ # multiple result sets import numpy + import duckdb -class TestMultipleResultSets(object): +class TestMultipleResultSets: def test_regular_selection(self, duckdb_cursor, integers): duckdb_cursor.execute("SELECT * FROM integers") duckdb_cursor.execute("SELECT * FROM integers") diff --git a/tests/fast/api/test_dbapi04.py b/tests/fast/api/test_dbapi04.py index 1125f819..2c2259ce 100644 --- a/tests/fast/api/test_dbapi04.py +++ b/tests/fast/api/test_dbapi04.py @@ -1,7 +1,7 @@ # simple DB API testcase -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_regular_selection(self, duckdb_cursor, integers): duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() diff --git a/tests/fast/api/test_dbapi05.py b/tests/fast/api/test_dbapi05.py index 234fb2ec..6c6d4fa1 100644 --- a/tests/fast/api/test_dbapi05.py +++ b/tests/fast/api/test_dbapi05.py @@ -1,7 +1,7 @@ # simple DB API testcase -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_prepare(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT CAST(? AS INTEGER), CAST(? AS INTEGER)", ["42", "84"]).fetchall() assert result == [ diff --git a/tests/fast/api/test_dbapi07.py b/tests/fast/api/test_dbapi07.py index 238f30fc..eab581e5 100644 --- a/tests/fast/api/test_dbapi07.py +++ b/tests/fast/api/test_dbapi07.py @@ -1,16 +1,17 @@ # timestamp ms precision -import numpy from datetime import datetime +import numpy + -class TestNumpyTimestampMilliseconds(object): +class TestNumpyTimestampMilliseconds: def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchnumpy() assert res["test_time"] == numpy.datetime64("2019-11-26 21:11:42.501") -class TestTimestampMilliseconds(object): +class TestTimestampMilliseconds: def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchone()[0] assert res == datetime.strptime("2019-11-26 21:11:42.501", "%Y-%m-%d %H:%M:%S.%f") diff --git a/tests/fast/api/test_dbapi08.py b/tests/fast/api/test_dbapi08.py index 457a9e78..def4e925 100644 --- a/tests/fast/api/test_dbapi08.py +++ b/tests/fast/api/test_dbapi08.py @@ -1,11 +1,11 @@ # test fetchdf with various types -import numpy import pytest -import duckdb from conftest import NumpyPandas +import duckdb + -class TestType(object): +class TestType: @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_fetchdf(self, pandas): con = duckdb.connect() diff --git a/tests/fast/api/test_dbapi09.py b/tests/fast/api/test_dbapi09.py index 538e7fc3..8a31e10e 100644 --- a/tests/fast/api/test_dbapi09.py +++ b/tests/fast/api/test_dbapi09.py @@ -1,11 +1,12 @@ # date type -import numpy import datetime + +import numpy import pandas -class TestNumpyDate(object): +class TestNumpyDate: def test_fetchall_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchall() assert res == [(datetime.date(2020, 1, 10),)] diff --git a/tests/fast/api/test_dbapi10.py b/tests/fast/api/test_dbapi10.py index 0ab69e0b..8b5cb0e4 100644 --- a/tests/fast/api/test_dbapi10.py +++ b/tests/fast/api/test_dbapi10.py @@ -1,10 +1,11 @@ # cursor description -from datetime import datetime, date 
+from datetime import date, datetime + from pytest import mark import duckdb -class TestCursorDescription(object): +class TestCursorDescription: @mark.parametrize( "query,column_name,string_type,real_type", [ @@ -51,6 +52,6 @@ def test_none_description(self, duckdb_empty_cursor): assert duckdb_empty_cursor.description is None -class TestCursorRowcount(object): +class TestCursorRowcount: def test_rowcount(self, duckdb_cursor): assert duckdb_cursor.rowcount == -1 diff --git a/tests/fast/api/test_dbapi11.py b/tests/fast/api/test_dbapi11.py index 91237b9e..56f5724d 100644 --- a/tests/fast/api/test_dbapi11.py +++ b/tests/fast/api/test_dbapi11.py @@ -1,8 +1,9 @@ # cursor description -import duckdb -import tempfile import os +import tempfile + +import duckdb def check_exception(f): @@ -14,7 +15,7 @@ def check_exception(f): assert had_exception -class TestReadOnly(object): +class TestReadOnly: def test_readonly(self, duckdb_cursor): fd, db = tempfile.mkstemp() os.close(fd) diff --git a/tests/fast/api/test_dbapi12.py b/tests/fast/api/test_dbapi12.py index 833d231c..96b1deac 100644 --- a/tests/fast/api/test_dbapi12.py +++ b/tests/fast/api/test_dbapi12.py @@ -1,10 +1,10 @@ -import duckdb -import tempfile -import os + import pandas as pd +import duckdb + -class TestRelationApi(object): +class TestRelationApi: def test_readonly(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["one", "two", "three"]}) diff --git a/tests/fast/api/test_dbapi13.py b/tests/fast/api/test_dbapi13.py index ffdb4884..c08cefb1 100644 --- a/tests/fast/api/test_dbapi13.py +++ b/tests/fast/api/test_dbapi13.py @@ -1,11 +1,12 @@ # time type -import numpy import datetime + +import numpy import pandas -class TestNumpyTime(object): +class TestNumpyTime: def test_fetchall_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchall() assert res == [(datetime.time(13, 6, 40),)] diff --git a/tests/fast/api/test_dbapi_fetch.py b/tests/fast/api/test_dbapi_fetch.py index 9c47c54c..5ec18aca 100644 --- a/tests/fast/api/test_dbapi_fetch.py +++ b/tests/fast/api/test_dbapi_fetch.py @@ -1,11 +1,13 @@ -import duckdb -import pytest -from uuid import UUID import datetime from decimal import Decimal +from uuid import UUID + +import pytest + +import duckdb -class TestDBApiFetch(object): +class TestDBApiFetch: def test_multiple_fetch_one(self, duckdb_cursor): con = duckdb.connect() c = con.execute("SELECT 42") diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 4b0dc4d6..eb241145 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -1,7 +1,8 @@ +import pytest +from conftest import ArrowPandas, NumpyPandas + import duckdb import duckdb.typing -import pytest -from conftest import NumpyPandas, ArrowPandas pa = pytest.importorskip("pyarrow") @@ -22,7 +23,7 @@ def tmp_database(tmp_path_factory): # This file contains tests for DuckDBPyConnection methods, # wrapped by the 'duckdb' module, to execute with the 'default_connection' -class TestDuckDBConnection(object): +class TestDuckDBConnection: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append(self, pandas): duckdb.execute("Create table integers (i integer)") @@ -118,7 +119,7 @@ def test_readonly_properties(self): assert rowcount == -1 def test_execute(self): - assert [([4, 2],)] == duckdb.execute("select [4,2]").fetchall() + assert duckdb.execute("select [4,2]").fetchall() == [([4, 2],)] def 
test_executemany(self): # executemany does not keep an open result set @@ -231,7 +232,7 @@ def test_fetch_record_batch(self): assert len(chunk) == 3000 def test_fetchall(self): - assert [([1, 2, 3],)] == duckdb.execute("select [1,2,3]").fetchall() + assert duckdb.execute("select [1,2,3]").fetchall() == [([1, 2, 3],)] def test_fetchdf(self): ref = [([1, 2, 3],)] @@ -241,7 +242,7 @@ def test_fetchdf(self): assert res == ref def test_fetchmany(self): - assert [(0,), (1,)] == duckdb.execute("select * from range(5)").fetchmany(2) + assert duckdb.execute("select * from range(5)").fetchmany(2) == [(0,), (1,)] def test_fetchnumpy(self): numpy = pytest.importorskip("numpy") @@ -254,37 +255,37 @@ def test_fetchnumpy(self): assert results["a"] == numpy.array([b"hello"], dtype=object) def test_fetchone(self): - assert (0,) == duckdb.execute("select * from range(5)").fetchone() + assert duckdb.execute("select * from range(5)").fetchone() == (0,) def test_from_arrow(self): - assert None != duckdb.from_arrow + assert duckdb.from_arrow != None def test_from_csv_auto(self): - assert None != duckdb.from_csv_auto + assert duckdb.from_csv_auto != None def test_from_df(self): - assert None != duckdb.from_df + assert duckdb.from_df != None def test_from_parquet(self): - assert None != duckdb.from_parquet + assert duckdb.from_parquet != None def test_from_query(self): - assert None != duckdb.from_query + assert duckdb.from_query != None def test_get_table_names(self): - assert None != duckdb.get_table_names + assert duckdb.get_table_names != None def test_install_extension(self): - assert None != duckdb.install_extension + assert duckdb.install_extension != None def test_load_extension(self): - assert None != duckdb.load_extension + assert duckdb.load_extension != None def test_query(self): - assert [(3,)] == duckdb.query("select 3").fetchall() + assert duckdb.query("select 3").fetchall() == [(3,)] def test_register(self): - assert None != duckdb.register + assert duckdb.register != None def test_register_relation(self): con = duckdb.connect() @@ -334,27 +335,27 @@ def temporary_scope(): def test_table(self): con = duckdb.connect() con.execute("create table tbl as select 1") - assert [(1,)] == con.table("tbl").fetchall() + assert con.table("tbl").fetchall() == [(1,)] def test_table_function(self): - assert None != duckdb.table_function + assert duckdb.table_function != None def test_unregister(self): - assert None != duckdb.unregister + assert duckdb.unregister != None def test_values(self): - assert None != duckdb.values + assert duckdb.values != None def test_view(self): duckdb.execute("create view vw as select range(5)") - assert [([0, 1, 2, 3, 4],)] == duckdb.view("vw").fetchall() + assert duckdb.view("vw").fetchall() == [([0, 1, 2, 3, 4],)] duckdb.execute("drop view vw") def test_close(self): - assert None != duckdb.close + assert duckdb.close != None def test_interrupt(self): - assert None != duckdb.interrupt + assert duckdb.interrupt != None def test_wrap_shadowing(self): pd = NumpyPandas() @@ -393,7 +394,7 @@ def test_set_pandas_analyze_sample_size(self): # Find the cached config con2 = duckdb.connect(":memory:named", config={"pandas_analyze_sample": 0}) - con2.execute(f"SET GLOBAL pandas_analyze_sample=2") + con2.execute("SET GLOBAL pandas_analyze_sample=2") # This change is reflected in 'con' because the instance was cached res = con.sql("select current_setting('pandas_analyze_sample')").fetchone() diff --git a/tests/fast/api/test_duckdb_execute.py b/tests/fast/api/test_duckdb_execute.py index 
a025fc42..df8bff63 100644 --- a/tests/fast/api/test_duckdb_execute.py +++ b/tests/fast/api/test_duckdb_execute.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + -class TestDuckDBExecute(object): +class TestDuckDBExecute: def test_execute_basic(self, duckdb_cursor): duckdb_cursor.execute("create table t as select 5") res = duckdb_cursor.table("t").fetchall() diff --git a/tests/fast/api/test_duckdb_query.py b/tests/fast/api/test_duckdb_query.py index 2ecfd8f3..db807f44 100644 --- a/tests/fast/api/test_duckdb_query.py +++ b/tests/fast/api/test_duckdb_query.py @@ -1,10 +1,11 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb from duckdb import Value -class TestDuckDBQuery(object): +class TestDuckDBQuery: def test_duckdb_query(self, duckdb_cursor): # we can use duckdb_cursor.sql to run both DDL statements and select statements duckdb_cursor.sql("create view v1 as select 42 i") diff --git a/tests/fast/api/test_explain.py b/tests/fast/api/test_explain.py index feedc134..23bcfcd4 100644 --- a/tests/fast/api/test_explain.py +++ b/tests/fast/api/test_explain.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class TestExplain(object): +class TestExplain: def test_explain_basic(self, duckdb_cursor): res = duckdb_cursor.sql("select 42").explain() assert isinstance(res, str) diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index 7b797598..d7d2503d 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -1,16 +1,16 @@ -import pytest -import duckdb -import io import datetime +import io + +import pytest fsspec = pytest.importorskip("fsspec") -class TestReadParquet(object): +class TestReadParquet: def test_fsspec_deadlock(self, duckdb_cursor, tmp_path): # Create test parquet data file_path = tmp_path / "data.parquet" - duckdb_cursor.sql("COPY (FROM range(50_000)) TO '{}' (FORMAT parquet)".format(str(file_path))) + duckdb_cursor.sql(f"COPY (FROM range(50_000)) TO '{file_path!s}' (FORMAT parquet)") with open(file_path, "rb") as f: parquet_data = f.read() diff --git a/tests/fast/api/test_insert_into.py b/tests/fast/api/test_insert_into.py index 2537c182..1214203b 100644 --- a/tests/fast/api/test_insert_into.py +++ b/tests/fast/api/test_insert_into.py @@ -1,9 +1,10 @@ -import duckdb -from pandas import DataFrame import pytest +from pandas import DataFrame + +import duckdb -class TestInsertInto(object): +class TestInsertInto: def test_insert_into_schema(self, duckdb_cursor): # open connection con = duckdb.connect() diff --git a/tests/fast/api/test_join.py b/tests/fast/api/test_join.py index 5e2a148f..30ace540 100644 --- a/tests/fast/api/test_join.py +++ b/tests/fast/api/test_join.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + -class TestJoin(object): +class TestJoin: def test_alias_from_sql(self): con = duckdb.connect() rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") diff --git a/tests/fast/api/test_native_tz.py b/tests/fast/api/test_native_tz.py index f4a9d716..39d301e2 100644 --- a/tests/fast/api/test_native_tz.py +++ b/tests/fast/api/test_native_tz.py @@ -1,9 +1,10 @@ -import duckdb import datetime -import pytz import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") from packaging.version import Version @@ -11,7 +12,7 @@ filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "data", "tz.parquet") -class TestNativeTimeZone(object): +class TestNativeTimeZone: def 
test_native_python_timestamp_timezone(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() diff --git a/tests/fast/api/test_query_interrupt.py b/tests/fast/api/test_query_interrupt.py index e6d2b998..56c182f8 100644 --- a/tests/fast/api/test_query_interrupt.py +++ b/tests/fast/api/test_query_interrupt.py @@ -1,10 +1,11 @@ -import duckdb +import _thread as thread +import platform +import threading import time + import pytest -import platform -import threading -import _thread as thread +import duckdb def send_keyboard_interrupt(): @@ -14,7 +15,7 @@ def send_keyboard_interrupt(): thread.interrupt_main() -class TestQueryInterruption(object): +class TestQueryInterruption: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="Emscripten builds cannot use threads", diff --git a/tests/fast/api/test_query_progress.py b/tests/fast/api/test_query_progress.py index f885e36d..8d1d85a9 100644 --- a/tests/fast/api/test_query_progress.py +++ b/tests/fast/api/test_query_progress.py @@ -2,11 +2,12 @@ import threading import time -import duckdb import pytest +import duckdb + -class TestQueryProgress(object): +class TestQueryProgress: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", @@ -33,7 +34,7 @@ def thread_target(): # query never progresses. This will also fail if the query is too # quick as it will be back at -1 as soon as the query is finished. - for _ in range(0, 500): + for _ in range(500): assert thread.is_alive(), "query finished too quick" if (qp1 := conn.query_progress()) > 0: break @@ -42,7 +43,7 @@ def thread_target(): pytest.fail("query start timeout") # keep monitoring and wait for the progress to increase - for _ in range(0, 500): + for _ in range(500): assert thread.is_alive(), "query finished too quick" if (qp2 := conn.query_progress()) > qp1: break diff --git a/tests/fast/api/test_read_csv.py b/tests/fast/api/test_read_csv.py index dff90869..a4e90c44 100644 --- a/tests/fast/api/test_read_csv.py +++ b/tests/fast/api/test_read_csv.py @@ -1,11 +1,12 @@ -from multiprocessing.sharedctypes import Value import datetime -import pytest import platform +import sys +from io import BytesIO, StringIO + +import pytest + import duckdb -from io import StringIO, BytesIO from duckdb import CSVLineTerminator -import sys def TestFile(name): @@ -33,7 +34,7 @@ def create_temp_csv(tmp_path): return file1_path, file2_path -class TestReadCSV(object): +class TestReadCSV: def test_using_connection_wrapper(self): rel = duckdb.read_csv(TestFile("category.csv")) res = rel.fetchone() @@ -361,7 +362,6 @@ def test_filelike_custom(self, duckdb_cursor): class CustomIO: def __init__(self) -> None: self.loc = 0 - pass def seek(self, loc): self.loc = loc diff --git a/tests/fast/api/test_relation_to_view.py b/tests/fast/api/test_relation_to_view.py index 31a19d54..14f4cb4d 100644 --- a/tests/fast/api/test_relation_to_view.py +++ b/tests/fast/api/test_relation_to_view.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class TestRelationToView(object): +class TestRelationToView: def test_values_to_view(self, duckdb_cursor): rel = duckdb_cursor.values(["test", "this is a long string"]) res = rel.fetchall() diff --git a/tests/fast/api/test_streaming_result.py b/tests/fast/api/test_streaming_result.py index 739fd17a..700057ed 100644 --- a/tests/fast/api/test_streaming_result.py +++ b/tests/fast/api/test_streaming_result.py @@ -1,8 +1,9 @@ 
import pytest + import duckdb -class TestStreamingResult(object): +class TestStreamingResult: def test_fetch_one(self, duckdb_cursor): # fetch one res = duckdb_cursor.sql("SELECT * FROM range(100000)") diff --git a/tests/fast/api/test_to_csv.py b/tests/fast/api/test_to_csv.py index 5f8000a9..ef2aef6c 100644 --- a/tests/fast/api/test_to_csv.py +++ b/tests/fast/api/test_to_csv.py @@ -1,14 +1,15 @@ -import duckdb -import tempfile -import os -import pandas._testing as tm -import datetime import csv +import datetime +import os +import tempfile + import pytest -from conftest import NumpyPandas, ArrowPandas, getTimeSeriesData +from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData + +import duckdb -class TestToCSV(object): +class TestToCSV: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_basic_to_csv(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) diff --git a/tests/fast/api/test_to_parquet.py b/tests/fast/api/test_to_parquet.py index c13ac011..834763bf 100644 --- a/tests/fast/api/test_to_parquet.py +++ b/tests/fast/api/test_to_parquet.py @@ -1,15 +1,13 @@ -import duckdb -import tempfile import os import tempfile -import pandas._testing as tm -import datetime -import csv + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestToParquet(object): +class TestToParquet: @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_basic_to_parquet(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) @@ -43,12 +41,12 @@ def test_field_ids(self): rel.to_parquet(temp_file_name, field_ids=dict(i=42, my_struct={"__duckdb_field_id": 43, "j": 44})) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() - assert [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] == duckdb.sql( + assert duckdb.sql( f""" select name,field_id from parquet_schema('{temp_file_name}') """ - ).execute().fetchall() + ).execute().fetchall() == [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("row_group_size_bytes", [122880 * 1024, "2MB"]) diff --git a/tests/fast/api/test_with_propagating_exceptions.py b/tests/fast/api/test_with_propagating_exceptions.py index 8613d6f4..41df088f 100644 --- a/tests/fast/api/test_with_propagating_exceptions.py +++ b/tests/fast/api/test_with_propagating_exceptions.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class TestWithPropagatingExceptions(object): +class TestWithPropagatingExceptions: def test_with(self): # Should propagate exception raised in the 'with duckdb.connect() ..' 
with pytest.raises(duckdb.ParserException, match="syntax error at or near *"): diff --git a/tests/fast/arrow/parquet_write_roundtrip.py b/tests/fast/arrow/parquet_write_roundtrip.py index 5dbf3949..5c42773c 100644 --- a/tests/fast/arrow/parquet_write_roundtrip.py +++ b/tests/fast/arrow/parquet_write_roundtrip.py @@ -1,9 +1,11 @@ -import duckdb -import pytest +import datetime import tempfile + import numpy import pandas -import datetime +import pytest + +import duckdb pa = pytest.importorskip("pyarrow") @@ -37,7 +39,7 @@ def parquet_types_test(type_list): assert read_df.equals(read_from_arrow) -class TestParquetRoundtrip(object): +class TestParquetRoundtrip: def test_roundtrip_numeric(self, duckdb_cursor): type_list = [ ([-(2**7), 0, 2**7 - 1], numpy.int8, "TINYINT"), diff --git a/tests/fast/arrow/test_10795.py b/tests/fast/arrow/test_10795.py index 5503e529..5dc88402 100644 --- a/tests/fast/arrow/test_10795.py +++ b/tests/fast/arrow/test_10795.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_12384.py b/tests/fast/arrow/test_12384.py index d2d4a7fc..e91cbe8c 100644 --- a/tests/fast/arrow/test_12384.py +++ b/tests/fast/arrow/test_12384.py @@ -1,7 +1,9 @@ -import duckdb -import pytest import os +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_14344.py b/tests/fast/arrow/test_14344.py index 86f8728b..77cfaaa2 100644 --- a/tests/fast/arrow/test_14344.py +++ b/tests/fast/arrow/test_14344.py @@ -1,4 +1,3 @@ -import duckdb import pytest pa = pytest.importorskip("pyarrow") @@ -6,13 +5,13 @@ def test_14344(duckdb_cursor): - my_table = pa.Table.from_pydict({"foo": pa.array([hashlib.sha256("foo".encode()).digest()], type=pa.binary())}) + my_table = pa.Table.from_pydict({"foo": pa.array([hashlib.sha256(b"foo").digest()], type=pa.binary())}) my_table2 = pa.Table.from_pydict( - {"foo": pa.array([hashlib.sha256("foo".encode()).digest()], type=pa.binary()), "a": ["123"]} + {"foo": pa.array([hashlib.sha256(b"foo").digest()], type=pa.binary()), "a": ["123"]} ) res = duckdb_cursor.sql( - f""" + """ SELECT my_table2.* EXCLUDE (foo) FROM diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index 6d760500..5e6d42ef 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -1,15 +1,14 @@ + import duckdb -import os try: - import pyarrow as pa can_run = True except: can_run = False -class Test2426(object): +class Test2426: def test_2426(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_5547.py b/tests/fast/arrow/test_5547.py index eb77ab83..8e8b40ed 100644 --- a/tests/fast/arrow/test_5547.py +++ b/tests/fast/arrow/test_5547.py @@ -1,7 +1,8 @@ -import duckdb import pandas as pd -from pandas.testing import assert_frame_equal import pytest +from pandas.testing import assert_frame_equal + +import duckdb pa = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_6584.py b/tests/fast/arrow/test_6584.py index 6f96bf2d..feadc6d7 100644 --- a/tests/fast/arrow/test_6584.py +++ b/tests/fast/arrow/test_6584.py @@ -1,7 +1,9 @@ from concurrent.futures import ThreadPoolExecutor -import duckdb + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_6796.py b/tests/fast/arrow/test_6796.py index ef464f49..454fa005 100644 --- a/tests/fast/arrow/test_6796.py +++ b/tests/fast/arrow/test_6796.py @@ -1,6 +1,7 @@ -import duckdb import pytest -from 
conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb pyarrow = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_7652.py b/tests/fast/arrow/test_7652.py index 857d871d..e38174b9 100644 --- a/tests/fast/arrow/test_7652.py +++ b/tests/fast/arrow/test_7652.py @@ -1,13 +1,12 @@ -import duckdb -import os -import pytest import tempfile +import pytest + pa = pytest.importorskip("pyarrow", minversion="11") pq = pytest.importorskip("pyarrow.parquet", minversion="11") -class Test7652(object): +class Test7652: def test_7652(self, duckdb_cursor): temp_file_name = tempfile.NamedTemporaryFile(suffix=".parquet").name # Generate a list of values that aren't uniform in changes. diff --git a/tests/fast/arrow/test_7699.py b/tests/fast/arrow/test_7699.py index a4de66b9..ba2f4af3 100644 --- a/tests/fast/arrow/test_7699.py +++ b/tests/fast/arrow/test_7699.py @@ -1,13 +1,13 @@ -import duckdb -import pytest import string +import pytest + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") pl = pytest.importorskip("polars") -class Test7699(object): +class Test7699: def test_7699(self, duckdb_cursor): pl_tbl = pl.DataFrame( { diff --git a/tests/fast/arrow/test_8522.py b/tests/fast/arrow/test_8522.py index 84aa125c..681e8fdf 100644 --- a/tests/fast/arrow/test_8522.py +++ b/tests/fast/arrow/test_8522.py @@ -1,15 +1,14 @@ -import duckdb -import pytest -import string import datetime as dt +import pytest + pa = pytest.importorskip("pyarrow") # Reconstruct filters when pushing down into arrow scan # arrow supports timestamp_tz with different units than US, we only support US # so we have to convert ConstantValues back to their native unit when pushing the filter expression containing them down to pyarrow -class Test8522(object): +class Test8522: def test_8522(self, duckdb_cursor): t_us = pa.Table.from_arrays( arrays=[pa.array([dt.datetime(2022, 1, 1)])], diff --git a/tests/fast/arrow/test_9443.py b/tests/fast/arrow/test_9443.py index 7de04bde..f6627c00 100644 --- a/tests/fast/arrow/test_9443.py +++ b/tests/fast/arrow/test_9443.py @@ -1,4 +1,3 @@ -import duckdb import pytest pq = pytest.importorskip("pyarrow.parquet") @@ -8,7 +7,7 @@ from pathlib import PurePosixPath -class Test9443(object): +class Test9443: def test_9443(self, tmp_path, duckdb_cursor): arrow_table = pa.Table.from_pylist( [ diff --git a/tests/fast/arrow/test_arrow_batch_index.py b/tests/fast/arrow/test_arrow_batch_index.py index a8dc2c7f..0cd4d679 100644 --- a/tests/fast/arrow/test_arrow_batch_index.py +++ b/tests/fast/arrow/test_arrow_batch_index.py @@ -1,12 +1,11 @@ -import duckdb import pytest -import pandas as pd + import duckdb pa = pytest.importorskip("pyarrow") -class TestArrowBatchIndex(object): +class TestArrowBatchIndex: def test_arrow_batch_index(self, duckdb_cursor): con = duckdb.connect() df = con.execute("SELECT * FROM range(10000000) t(i)").df() diff --git a/tests/fast/arrow/test_arrow_binary_view.py b/tests/fast/arrow/test_arrow_binary_view.py index 31107f67..4e161ac3 100644 --- a/tests/fast/arrow/test_arrow_binary_view.py +++ b/tests/fast/arrow/test_arrow_binary_view.py @@ -1,10 +1,11 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowBinaryView(object): +class TestArrowBinaryView: def test_arrow_binary_view(self, duckdb_cursor): con = duckdb.connect() tab = pa.table({"x": pa.array([b"abc", b"thisisaverybigbinaryyaymorethanfifteen", None], pa.binary_view())}) diff --git 
a/tests/fast/arrow/test_arrow_case_sensitive.py b/tests/fast/arrow/test_arrow_case_sensitive.py index ef60046a..11bca339 100644 --- a/tests/fast/arrow/test_arrow_case_sensitive.py +++ b/tests/fast/arrow/test_arrow_case_sensitive.py @@ -1,10 +1,9 @@ -import duckdb import pytest pa = pytest.importorskip("pyarrow") -class TestArrowCaseSensitive(object): +class TestArrowCaseSensitive: def test_arrow_case_sensitive(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) arrow_table = pa.Table.from_arrays([data[0], data[1]], ["A1", "a1"]) diff --git a/tests/fast/arrow/test_arrow_decimal256.py b/tests/fast/arrow/test_arrow_decimal256.py index 0ab84d3a..08612918 100644 --- a/tests/fast/arrow/test_arrow_decimal256.py +++ b/tests/fast/arrow/test_arrow_decimal256.py @@ -1,11 +1,13 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimal256(object): +class TestArrowDecimal256: def test_decimal_256_throws(self, duckdb_cursor): with duckdb.connect() as conn: pa_decimal256 = pa.Table.from_pylist( diff --git a/tests/fast/arrow/test_arrow_decimal_32_64.py b/tests/fast/arrow/test_arrow_decimal_32_64.py index 39b6e43a..301d890f 100644 --- a/tests/fast/arrow/test_arrow_decimal_32_64.py +++ b/tests/fast/arrow/test_arrow_decimal_32_64.py @@ -1,11 +1,13 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimalTypes(object): +class TestArrowDecimalTypes: def test_decimal_32(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_output_version = 1.5") diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index 43c995bb..1bc0e179 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -1,14 +1,16 @@ -import duckdb -import pytest -import uuid +import datetime import json +import uuid from uuid import UUID -import datetime + +import pytest + +import duckdb pa = pytest.importorskip("pyarrow", "18.0.0") -class TestCanonicalExtensionTypes(object): +class TestCanonicalExtensionTypes: def test_uuid(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index a969da21..62460912 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -1,8 +1,7 @@ + import duckdb -import pytest try: - import pyarrow as pa can_run = True except: @@ -18,7 +17,7 @@ def check_equal(duckdb_conn): assert arrow_result == true_result -class TestArrowFetch(object): +class TestArrowFetch: def test_empty_table(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_arrow_fetch_recordbatch.py b/tests/fast/arrow/test_arrow_fetch_recordbatch.py index 8915d886..4d7fe28a 100644 --- a/tests/fast/arrow/test_arrow_fetch_recordbatch.py +++ b/tests/fast/arrow/test_arrow_fetch_recordbatch.py @@ -1,10 +1,11 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowFetchRecordBatch(object): +class TestArrowFetchRecordBatch: # Test with basic numeric conversion (integers, floats, and others fall this code-path) def test_record_batch_next_batch_numeric(self, duckdb_cursor): duckdb_cursor = duckdb.connect() diff --git a/tests/fast/arrow/test_arrow_fixed_binary.py 
b/tests/fast/arrow/test_arrow_fixed_binary.py index cec8d520..754a472f 100644 --- a/tests/fast/arrow/test_arrow_fixed_binary.py +++ b/tests/fast/arrow/test_arrow_fixed_binary.py @@ -3,7 +3,7 @@ pa = pytest.importorskip("pyarrow") -class TestArrowFixedBinary(object): +class TestArrowFixedBinary: def test_arrow_fixed_binary(self, duckdb_cursor): ids = [ None, diff --git a/tests/fast/arrow/test_arrow_ipc.py b/tests/fast/arrow/test_arrow_ipc.py index 24718bbc..b3271fcd 100644 --- a/tests/fast/arrow/test_arrow_ipc.py +++ b/tests/fast/arrow/test_arrow_ipc.py @@ -1,4 +1,5 @@ import pytest + import duckdb pa = pytest.importorskip("pyarrow") @@ -11,7 +12,7 @@ def get_record_batch(): return pa.record_batch(data, names=["f0", "f1", "f2"]) -class TestArrowIPCExtension(object): +class TestArrowIPCExtension: # Only thing we can test in core is that it suggests the # instalation and loading of the extension def test_single_buffer(self, duckdb_cursor): diff --git a/tests/fast/arrow/test_arrow_list.py b/tests/fast/arrow/test_arrow_list.py index 47b8cb2a..4c2804a0 100644 --- a/tests/fast/arrow/test_arrow_list.py +++ b/tests/fast/arrow/test_arrow_list.py @@ -1,4 +1,3 @@ -import duckdb import numpy as np import pytest @@ -91,13 +90,13 @@ def generate_list(child_size) -> ListGenerationResult: return ListGenerationResult(list_arr, list_view_arr) -class TestArrowListType(object): +class TestArrowListType: def test_regular_list(self, duckdb_cursor): n = 5 # Amount of lists generated_size = 3 # Size of each list list_size = -1 # Argument passed to `pa._list()` - data = [np.random.random((generated_size)) for _ in range(n)] + data = [np.random.random(generated_size) for _ in range(n)] list_type = pa.list_(pa.float32(), list_size=list_size) create_and_register_arrow_table( @@ -120,7 +119,7 @@ def test_fixedsize_list(self, duckdb_cursor): generated_size = 3 # Size of each list list_size = 3 # Argument passed to `pa._list()` - data = [np.random.random((generated_size)) for _ in range(n)] + data = [np.random.random(generated_size) for _ in range(n)] list_type = pa.list_(pa.float32(), list_size=list_size) create_and_register_arrow_table( diff --git a/tests/fast/arrow/test_arrow_offsets.py b/tests/fast/arrow/test_arrow_offsets.py index 0ddc0f7d..32a59112 100644 --- a/tests/fast/arrow/test_arrow_offsets.py +++ b/tests/fast/arrow/test_arrow_offsets.py @@ -1,9 +1,9 @@ -import duckdb -import pytest -from pytest import mark import datetime import decimal + +import pytest import pytz +from pytest import mark pa = pytest.importorskip("pyarrow") @@ -80,10 +80,10 @@ def expected_result(col1_null, col2_null, expected): ) -class TestArrowOffsets(object): +class TestArrowOffsets: @null_test_parameters() def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -107,7 +107,7 @@ def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): - tuples = [False for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = [False for i in range(MAGIC_ARRAY_SIZE)] tuples[-1] = True col1 = tuples @@ -140,7 +140,7 @@ def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): ) @null_test_parameters() def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, col2_null): - tuples = [i for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = 
[i for i in range(MAGIC_ARRAY_SIZE)] col1 = tuples if col1_null: @@ -192,7 +192,7 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -230,7 +230,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n # FIXME: We limit the size because we don't support time values > 24 hours size = 86400 # The amount of seconds in a day - col1 = [i for i in range(0, size)] + col1 = [i for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -265,7 +265,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converter, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - col1 = [converter(i) for i in range(0, size)] + col1 = [converter(i) for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -300,7 +300,7 @@ def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converte def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - col1 = [i for i in range(0, size)] + col1 = [i for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -336,7 +336,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected size = MAGIC_ARRAY_SIZE duckdb_cursor.execute("set timezone='UTC'") - col1 = [i for i in range(0, size)] + col1 = [i for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -362,7 +362,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected @null_test_parameters() def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -400,7 +400,7 @@ def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): ) def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_null, col2_null): precision, scale = precision_scale - col1 = [decimal_value(i, precision, scale) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [decimal_value(i, precision, scale) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -427,7 +427,7 @@ def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_ @null_test_parameters() def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -457,7 +457,7 @@ def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -487,7 +487,7 @@ def 
test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -518,7 +518,7 @@ def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -549,7 +549,7 @@ def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [i for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [i for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -581,7 +581,7 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): @pytest.mark.parametrize("col1_null", [True, False]) def test_list_of_struct(self, duckdb_cursor, col1_null): # One single tuple containing a very big list - tuples = [{"a": i} for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = [{"a": i} for i in range(MAGIC_ARRAY_SIZE)] if col1_null: tuples[-1] = None tuples = [tuples] @@ -590,7 +590,7 @@ def test_list_of_struct(self, duckdb_cursor, col1_null): schema=pa.schema([("col1", pa.list_(pa.struct({"a": pa.int32()})))]), ) res = duckdb_cursor.sql( - f""" + """ SELECT col1 FROM arrow_table diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index 6df5053f..295f0292 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -1,6 +1,7 @@ -import duckdb + import pytest -import os + +import duckdb pl = pytest.importorskip("polars") @@ -14,7 +15,7 @@ def polars_supports_capsule(): @pytest.mark.skipif( not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface" ) -class TestArrowPyCapsule(object): +class TestArrowPyCapsule: def test_polars_pycapsule_scan(self, duckdb_cursor): class MyObject: def __init__(self, obj) -> None: diff --git a/tests/fast/arrow/test_arrow_recordbatchreader.py b/tests/fast/arrow/test_arrow_recordbatchreader.py index a9523d43..80520499 100644 --- a/tests/fast/arrow/test_arrow_recordbatchreader.py +++ b/tests/fast/arrow/test_arrow_recordbatchreader.py @@ -1,14 +1,16 @@ -import duckdb import os + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") pyarrow.parquet = pytest.importorskip("pyarrow.parquet") pyarrow.dataset = pytest.importorskip("pyarrow.dataset") np = pytest.importorskip("numpy") -class TestArrowRecordBatchReader(object): +class TestArrowRecordBatchReader: def test_parallel_reader(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") diff --git a/tests/fast/arrow/test_arrow_replacement_scan.py b/tests/fast/arrow/test_arrow_replacement_scan.py index f2a9c13b..614e1e9f 100644 --- a/tests/fast/arrow/test_arrow_replacement_scan.py +++ b/tests/fast/arrow/test_arrow_replacement_scan.py @@ -1,14 +1,15 @@ -import duckdb -import pytest import os -import pandas as pd + +import pytest + +import duckdb pa = pytest.importorskip("pyarrow") pq = 
pytest.importorskip("pyarrow.parquet") ds = pytest.importorskip("pyarrow.dataset") -class TestArrowReplacementScan(object): +class TestArrowReplacementScan: def test_arrow_table_replacement_scan(self, duckdb_cursor): parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) diff --git a/tests/fast/arrow/test_arrow_run_end_encoding.py b/tests/fast/arrow/test_arrow_run_end_encoding.py index c6f9fad5..de841dd0 100644 --- a/tests/fast/arrow/test_arrow_run_end_encoding.py +++ b/tests/fast/arrow/test_arrow_run_end_encoding.py @@ -1,7 +1,4 @@ -import duckdb import pytest -import pandas as pd -import duckdb pa = pytest.importorskip("pyarrow", "21.0.0", reason="Needs pyarrow >= 21") pc = pytest.importorskip("pyarrow.compute") @@ -30,7 +27,7 @@ def list_constructors(): return result -class TestArrowREE(object): +class TestArrowREE: @pytest.mark.parametrize( "query", [ @@ -130,7 +127,7 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) duckdb_cursor.execute(query) rel = duckdb_cursor.query("select * from ree_tbl") - expected = duckdb_cursor.query("select {} from ree_tbl where {}".format(projection, filter)).fetchall() + expected = duckdb_cursor.query(f"select {projection} from ree_tbl where {filter}").fetchall() # Create an Arrow Table from the table arrow_conversion = rel.fetch_arrow_table() @@ -156,7 +153,7 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) tbl = pa.Table.from_arrays([encoded_arrays["ree"], encoded_arrays["a"], encoded_arrays["b"]], schema=schema) # Scan the Arrow Table and verify that the results are the same - res = duckdb_cursor.sql("select {} from tbl where {}".format(projection, filter)).fetchall() + res = duckdb_cursor.sql(f"select {projection} from tbl where {filter}").fetchall() assert res == expected def test_arrow_ree_empty_table(self, duckdb_cursor): @@ -227,26 +224,26 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): # This should be pushed down into arrow to only provide us with the necessary columns res = duckdb_cursor.query( - """ - select {} from arrow_tbl - """.format(projection) + f""" + select {projection} from arrow_tbl + """ ).fetch_arrow_table() # Verify correctness by fetching from the original table and the constructed result - expected = duckdb_cursor.query("select {} from tbl".format(projection)).fetchall() - actual = duckdb_cursor.query("select {} from res".format(projection)).fetchall() + expected = duckdb_cursor.query(f"select {projection} from tbl").fetchall() + actual = duckdb_cursor.query(f"select {projection} from res").fetchall() assert expected == actual @pytest.mark.parametrize("create_list", list_constructors()) def test_arrow_ree_list(self, duckdb_cursor, create_list): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, - FROM range({}) t(i) - """.format(size) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -325,15 +322,15 @@ def test_arrow_ree_union(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, i as a, i % 2 == 0 as b, i::VARCHAR as c - FROM range({}) t(i) - """.format(size) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -383,13 +380,13 @@ def test_arrow_ree_map(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, i as a, - FROM range({}) t(i) - """.format(size) + 
FROM range({size}) t(i) + """ ) # Populate the table with data @@ -433,12 +430,12 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, - FROM range({}) t(i) - """.format(size) + FROM range({size}) t(i) + """ ) # Populate the table with data diff --git a/tests/fast/arrow/test_arrow_scanner.py b/tests/fast/arrow/test_arrow_scanner.py index 2e8b1296..ccfa5676 100644 --- a/tests/fast/arrow/test_arrow_scanner.py +++ b/tests/fast/arrow/test_arrow_scanner.py @@ -1,20 +1,20 @@ -import duckdb import os +import duckdb + try: import pyarrow - import pyarrow.parquet + import pyarrow.compute as pc import pyarrow.dataset + import pyarrow.parquet from pyarrow.dataset import Scanner - import pyarrow.compute as pc - import numpy as np can_run = True except: can_run = False -class TestArrowScanner(object): +class TestArrowScanner: def test_parallel_scanner(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_arrow_string_view.py b/tests/fast/arrow/test_arrow_string_view.py index a1b46e5b..0d34bb6e 100644 --- a/tests/fast/arrow/test_arrow_string_view.py +++ b/tests/fast/arrow/test_arrow_string_view.py @@ -1,6 +1,6 @@ -import duckdb import pytest -from packaging import version + +import duckdb pa = pytest.importorskip("pyarrow") @@ -43,7 +43,7 @@ def RoundTripDuckDBInternal(query): assert res[i] == from_arrow_res[i] -class TestArrowStringView(object): +class TestArrowStringView: # Test Small Inlined String View def test_inlined_string_view(self): RoundTripStringView( diff --git a/tests/fast/arrow/test_arrow_types.py b/tests/fast/arrow/test_arrow_types.py index f2bf71c7..199874cf 100644 --- a/tests/fast/arrow/test_arrow_types.py +++ b/tests/fast/arrow/test_arrow_types.py @@ -1,11 +1,12 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") -class TestArrowTypes(object): +class TestArrowTypes: def test_null_type(self, duckdb_cursor): schema = pa.schema([("data", pa.null())]) inputs = [pa.array([None, None, None], type=pa.null())] diff --git a/tests/fast/arrow/test_arrow_union.py b/tests/fast/arrow/test_arrow_union.py index c0a5d568..04fd73b3 100644 --- a/tests/fast/arrow/test_arrow_union.py +++ b/tests/fast/arrow/test_arrow_union.py @@ -2,8 +2,7 @@ importorskip("pyarrow") -import duckdb -from pyarrow import scalar, string, large_string, list_, int32, types +from pyarrow import int32, list_, scalar, string, types def test_nested(duckdb_cursor): diff --git a/tests/fast/arrow/test_arrow_version_format.py b/tests/fast/arrow/test_arrow_version_format.py index fd169ce0..ed335f9e 100644 --- a/tests/fast/arrow/test_arrow_version_format.py +++ b/tests/fast/arrow/test_arrow_version_format.py @@ -1,14 +1,16 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimalTypes(object): +class TestArrowDecimalTypes: def test_decimal_v1_5(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute(f"SET arrow_output_version = 1.5") + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_32 = pa.Table.from_pylist( [ {"data": Decimal("100.20")}, @@ -51,11 +53,11 @@ def test_decimal_v1_5(self, duckdb_cursor): def test_invalide_opt(self, duckdb_cursor): duckdb_cursor = duckdb.connect() with pytest.raises(duckdb.NotImplementedException, match="unrecognized"): - duckdb_cursor.execute(f"SET arrow_output_version = 999.9") + 
duckdb_cursor.execute("SET arrow_output_version = 999.9") def test_view_v1_4(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute(f"SET arrow_output_version = 1.5") + duckdb_cursor.execute("SET arrow_output_version = 1.5") duckdb_cursor.execute("SET produce_arrow_string_view=True") duckdb_cursor.execute("SET arrow_output_list_view=True") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type diff --git a/tests/fast/arrow/test_binary_type.py b/tests/fast/arrow/test_binary_type.py index 489d4caf..5932fba8 100644 --- a/tests/fast/arrow/test_binary_type.py +++ b/tests/fast/arrow/test_binary_type.py @@ -1,10 +1,8 @@ + import duckdb -import os try: import pyarrow as pa - from pyarrow import parquet as pq - import numpy as np can_run = True except: @@ -17,7 +15,7 @@ def create_binary_table(type): return pa.Table.from_arrays(inputs, schema=schema) -class TestArrowBinary(object): +class TestArrowBinary: def test_binary_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_buffer_size_option.py b/tests/fast/arrow/test_buffer_size_option.py index 7d5131e5..845d0db0 100644 --- a/tests/fast/arrow/test_buffer_size_option.py +++ b/tests/fast/arrow/test_buffer_size_option.py @@ -1,11 +1,12 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") from duckdb.typing import * -class TestArrowBufferSize(object): +class TestArrowBufferSize: def test_arrow_buffer_size(self): con = duckdb.connect() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 8ec0094e..aa2a8b9b 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -1,14 +1,16 @@ -import duckdb import os + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") np = pytest.importorskip("numpy") pyarrow.parquet = pytest.importorskip("pyarrow.parquet") pyarrow.dataset = pytest.importorskip("pyarrow.dataset") -class TestArrowDataset(object): +class TestArrowDataset: def test_parallel_dataset(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") diff --git a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 9649ffa6..83c14932 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -1,18 +1,16 @@ + + import duckdb -import os -import datetime -import pytest try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowDate(object): +class TestArrowDate: def test_date_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index e4319f7c..5cb2d38d 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -1,4 +1,3 @@ -import duckdb import pytest @@ -12,7 +11,7 @@ Timestamp = pd.Timestamp -class TestArrowDictionary(object): +class TestArrowDictionary: def test_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index 026b52f4..2238f744 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -1,12 +1,11 @@ -from re import S -import duckdb -import os +import sys + import pytest -import tempfile from conftest import pandas_supports_arrow_backend -import sys from packaging.version 
import Version +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") ds = pytest.importorskip("pyarrow.dataset") @@ -178,7 +177,7 @@ def string_check_or_pushdown(connection, tbl_name, create_table): assert not match -class TestArrowFilterPushdown(object): +class TestArrowFilterPushdown: @pytest.mark.parametrize( "data_type", [ @@ -532,7 +531,6 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ ) def test_9371(self, duckdb_cursor, tmp_path): import datetime - import pathlib # connect to an in-memory database duckdb_cursor.execute("SET TimeZone='UTC';") diff --git a/tests/fast/arrow/test_integration.py b/tests/fast/arrow/test_integration.py index 6ab7350d..1c00c800 100644 --- a/tests/fast/arrow/test_integration.py +++ b/tests/fast/arrow/test_integration.py @@ -1,14 +1,16 @@ -import duckdb -import os import datetime +import os + import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") np = pytest.importorskip("numpy") -class TestArrowIntegration(object): +class TestArrowIntegration: def test_parquet_roundtrip(self, duckdb_cursor): parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" @@ -216,7 +218,7 @@ def test_strings_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a varchar)") # Test Small, Null and Very Big String - for i in range(0, 1000): + for i in range(1000): duckdb_cursor.execute( "INSERT INTO test VALUES ('Matt Damon'),(NULL), ('Jeffffreeeey Jeeeeef Baaaaaaazos'), ('X-Content-Type-Options')" ) diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index 32b7fa64..5cdb04bd 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -1,18 +1,17 @@ -import duckdb -import os -import datetime + import pytest +import duckdb + try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowInterval(object): +class TestArrowInterval: def test_duration_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_large_offsets.py b/tests/fast/arrow/test_large_offsets.py index dccfa101..0a2669f5 100644 --- a/tests/fast/arrow/test_large_offsets.py +++ b/tests/fast/arrow/test_large_offsets.py @@ -1,9 +1,6 @@ -from re import S -import duckdb -import os import pytest -import tempfile -from conftest import pandas_supports_arrow_backend + +import duckdb pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") @@ -11,7 +8,7 @@ np = pytest.importorskip("numpy") -class TestArrowLargeOffsets(object): +class TestArrowLargeOffsets: @pytest.mark.skip(reason="CI does not have enough memory to validate this") def test_large_lists(self, duckdb_cursor): ary = pa.array([np.arange(start=0, stop=3000, dtype=np.uint8) for i in range(1_000_000)]) diff --git a/tests/fast/arrow/test_large_string.py b/tests/fast/arrow/test_large_string.py index 308785af..bb9d1b5b 100644 --- a/tests/fast/arrow/test_large_string.py +++ b/tests/fast/arrow/test_large_string.py @@ -1,17 +1,15 @@ + import duckdb -import os try: import pyarrow as pa - from pyarrow import parquet as pq - import numpy as np can_run = True except: can_run = False -class TestArrowLargeString(object): +class TestArrowLargeString: def test_large_string_type(self, duckdb_cursor): if not can_run: return diff 
--git a/tests/fast/arrow/test_multiple_reads.py b/tests/fast/arrow/test_multiple_reads.py index 36fb8f59..30b0c02a 100644 --- a/tests/fast/arrow/test_multiple_reads.py +++ b/tests/fast/arrow/test_multiple_reads.py @@ -1,6 +1,7 @@ -import duckdb import os +import duckdb + try: import pyarrow import pyarrow.parquet @@ -10,7 +11,7 @@ can_run = False -class TestArrowReads(object): +class TestArrowReads: def test_multiple_queries_same_relation(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_nested_arrow.py b/tests/fast/arrow/test_nested_arrow.py index a906324f..42b674e3 100644 --- a/tests/fast/arrow/test_nested_arrow.py +++ b/tests/fast/arrow/test_nested_arrow.py @@ -1,7 +1,7 @@ -import duckdb - import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") np = pytest.importorskip("numpy") @@ -27,7 +27,7 @@ def get_use_list_view_options(): return result -class TestArrowNested(object): +class TestArrowNested: def test_lists_basic(self, duckdb_cursor): # Test Constant List query = ( diff --git a/tests/fast/arrow/test_parallel.py b/tests/fast/arrow/test_parallel.py index c768a1dd..3348c13e 100644 --- a/tests/fast/arrow/test_parallel.py +++ b/tests/fast/arrow/test_parallel.py @@ -1,17 +1,18 @@ -import duckdb import os +import duckdb + try: + import numpy as np import pyarrow import pyarrow.parquet - import numpy as np can_run = True except: can_run = False -class TestArrowParallel(object): +class TestArrowParallel: def test_parallel_run(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index a4e94d18..329a9758 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -1,8 +1,9 @@ -import duckdb -import pytest -import sys import datetime +import pytest + +import duckdb + pl = pytest.importorskip("polars") arrow = pytest.importorskip("pyarrow") pl_testing = pytest.importorskip("polars.testing") @@ -20,7 +21,7 @@ def invalid_filter(filter): assert sql_expression is None -class TestPolars(object): +class TestPolars: def test_polars(self, duckdb_cursor): df = pl.DataFrame( { diff --git a/tests/fast/arrow/test_progress.py b/tests/fast/arrow/test_progress.py index 6f056937..4c558784 100644 --- a/tests/fast/arrow/test_progress.py +++ b/tests/fast/arrow/test_progress.py @@ -1,12 +1,13 @@ -import duckdb import os + import pytest +import duckdb + pyarrow_parquet = pytest.importorskip("pyarrow.parquet") -import sys -class TestProgressBarArrow(object): +class TestProgressBarArrow: def test_progress_arrow(self): if os.name == "nt": return diff --git a/tests/fast/arrow/test_projection_pushdown.py b/tests/fast/arrow/test_projection_pushdown.py index 802259e1..803a2703 100644 --- a/tests/fast/arrow/test_projection_pushdown.py +++ b/tests/fast/arrow/test_projection_pushdown.py @@ -1,9 +1,7 @@ -import duckdb -import os import pytest -class TestArrowProjectionPushdown(object): +class TestArrowProjectionPushdown: def test_projection_pushdown_no_filter(self, duckdb_cursor): pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index e7c4404e..b3bab360 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -1,18 +1,16 @@ + + import duckdb -import os -import datetime -import pytest try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowTime(object): +class 
TestArrowTime: def test_time_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 08816be1..c056f19f 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -1,8 +1,10 @@ -import duckdb -import pytest import datetime + +import pytest import pytz +import duckdb + pa = pytest.importorskip("pyarrow") @@ -16,7 +18,7 @@ def generate_table(current_time, precision, timezone): timezones = ["UTC", "BET", "CET", "Asia/Kathmandu"] -class TestArrowTimestampsTimezone(object): +class TestArrowTimestampsTimezone: def test_timestamp_timezone(self, duckdb_cursor): precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59, tzinfo=pytz.UTC) diff --git a/tests/fast/arrow/test_timestamps.py b/tests/fast/arrow/test_timestamps.py index 684a333c..6efe0000 100644 --- a/tests/fast/arrow/test_timestamps.py +++ b/tests/fast/arrow/test_timestamps.py @@ -1,18 +1,16 @@ -import duckdb -import os import datetime -import pytest + +import duckdb try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowTimestamps(object): +class TestArrowTimestamps: def test_timestamp_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_tpch.py b/tests/fast/arrow/test_tpch.py index d5d13b20..30eca05c 100644 --- a/tests/fast/arrow/test_tpch.py +++ b/tests/fast/arrow/test_tpch.py @@ -1,10 +1,10 @@ import pytest + import duckdb try: import pyarrow import pyarrow.parquet - import numpy as np can_run = True except: @@ -24,7 +24,7 @@ def check_result(result, answers): db_result = result.fetchone() cq_results = q_res.split("|") # The end of the rows, continue - if cq_results == [""] and str(db_result) == "None" or str(db_result[0]) == "None": + if (cq_results == [""] and str(db_result) == "None") or str(db_result[0]) == "None": continue ans_result = [munge(cell) for cell in cq_results] db_result = [munge(cell) for cell in db_result] @@ -34,7 +34,7 @@ def check_result(result, answers): @pytest.mark.skip(reason="Test needs to be adapted to missing TPCH extension") -class TestTPCHArrow(object): +class TestTPCHArrow: def test_tpch_arrow(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_unregister.py b/tests/fast/arrow/test_unregister.py index 8ff37b5a..de8bab9c 100644 --- a/tests/fast/arrow/test_unregister.py +++ b/tests/fast/arrow/test_unregister.py @@ -1,8 +1,10 @@ -import pytest -import tempfile import gc -import duckdb import os +import tempfile + +import pytest + +import duckdb try: import pyarrow @@ -13,7 +15,7 @@ can_run = False -class TestArrowUnregister(object): +class TestArrowUnregister: def test_arrow_unregister1(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_view.py b/tests/fast/arrow/test_view.py index 7f1410aa..98b0b6cc 100644 --- a/tests/fast/arrow/test_view.py +++ b/tests/fast/arrow/test_view.py @@ -1,12 +1,12 @@ -import duckdb import os + import pytest pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") -class TestArrowView(object): +class TestArrowView: def test_arrow_view(self, duckdb_cursor): parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pa.parquet.read_table(parquet_filename) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index 
b872d4d9..d95c93d1 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -2,13 +2,15 @@ Therefore, we only test the new codes and exec paths. """ -import numpy as np -import duckdb from datetime import timedelta + +import numpy as np import pytest +import duckdb + -class TestScanNumpy(object): +class TestScanNumpy: def test_scan_numpy(self, duckdb_cursor): z = np.array([1, 2, 3]) res = duckdb_cursor.sql("select * from z").fetchall() diff --git a/tests/fast/pandas/test_2304.py b/tests/fast/pandas/test_2304.py index 11344df8..859c5265 100644 --- a/tests/fast/pandas/test_2304.py +++ b/tests/fast/pandas/test_2304.py @@ -1,10 +1,11 @@ -import duckdb import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasMergeSameName(object): +class TestPandasMergeSameName: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_2304(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( diff --git a/tests/fast/pandas/test_append_df.py b/tests/fast/pandas/test_append_df.py index e6d64776..d93cfa2d 100644 --- a/tests/fast/pandas/test_append_df.py +++ b/tests/fast/pandas/test_append_df.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestAppendDF(object): +class TestAppendDF: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_to_table_append(self, duckdb_cursor, pandas): conn = duckdb.connect() diff --git a/tests/fast/pandas/test_bug2281.py b/tests/fast/pandas/test_bug2281.py index 98a90937..ca80504d 100644 --- a/tests/fast/pandas/test_bug2281.py +++ b/tests/fast/pandas/test_bug2281.py @@ -1,12 +1,9 @@ -import duckdb -import os -import datetime -import pytest -import pandas as pd import io +import pandas as pd + -class TestPandasStringNull(object): +class TestPandasStringNull: def test_pandas_string_null(self, duckdb_cursor): csv = """what,is_control,is_test ,0,0 diff --git a/tests/fast/pandas/test_bug5922.py b/tests/fast/pandas/test_bug5922.py index 28daabe9..584fe710 100644 --- a/tests/fast/pandas/test_bug5922.py +++ b/tests/fast/pandas/test_bug5922.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasAcceptFloat16(object): +class TestPandasAcceptFloat16: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_accept_float16(self, duckdb_cursor, pandas): df = pandas.DataFrame({"col": [1, 2, 3]}) diff --git a/tests/fast/pandas/test_copy_on_write.py b/tests/fast/pandas/test_copy_on_write.py index ec1b8786..0fcf503f 100644 --- a/tests/fast/pandas/test_copy_on_write.py +++ b/tests/fast/pandas/test_copy_on_write.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + # https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html pandas = pytest.importorskip("pandas", "1.5", reason="copy_on_write does not exist in earlier versions") import datetime @@ -21,7 +22,7 @@ def convert_to_result(col): return [(x,) for x in col] -class TestCopyOnWrite(object): +class TestCopyOnWrite: @pytest.mark.parametrize( "col", [ diff --git a/tests/fast/pandas/test_create_table_from_pandas.py b/tests/fast/pandas/test_create_table_from_pandas.py index 2194d964..bc5792e0 100644 --- a/tests/fast/pandas/test_create_table_from_pandas.py +++ 
b/tests/fast/pandas/test_create_table_from_pandas.py @@ -1,8 +1,9 @@ +import sys + import pytest +from conftest import ArrowPandas, NumpyPandas + import duckdb -import numpy as np -import sys -from conftest import NumpyPandas, ArrowPandas def assert_create(internal_data, expected_result, data_type, pandas): @@ -25,7 +26,7 @@ def assert_create_register(internal_data, expected_result, data_type, pandas): assert result == expected_result -class TestCreateTableFromPandas(object): +class TestCreateTableFromPandas: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_integer_create_table(self, duckdb_cursor, pandas): if sys.version_info.major < 3: diff --git a/tests/fast/pandas/test_date_as_datetime.py b/tests/fast/pandas/test_date_as_datetime.py index b738b2e1..484674ea 100644 --- a/tests/fast/pandas/test_date_as_datetime.py +++ b/tests/fast/pandas/test_date_as_datetime.py @@ -1,7 +1,8 @@ +import datetime + import pandas as pd + import duckdb -import datetime -import pytest def run_checks(df): diff --git a/tests/fast/pandas/test_datetime_time.py b/tests/fast/pandas/test_datetime_time.py index 1a5a3f7a..0b2642b0 100644 --- a/tests/fast/pandas/test_datetime_time.py +++ b/tests/fast/pandas/test_datetime_time.py @@ -1,13 +1,15 @@ -import duckdb +from datetime import datetime, time, timezone + import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas -from datetime import datetime, timezone, time, timedelta +from conftest import ArrowPandas, NumpyPandas + +import duckdb _ = pytest.importorskip("pandas", minversion="2.0.0") -class TestDateTimeTime(object): +class TestDateTimeTime: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_high(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(23, 1, 34.234345) AS '0'").df() diff --git a/tests/fast/pandas/test_datetime_timestamp.py b/tests/fast/pandas/test_datetime_timestamp.py index ffc1b7d8..2649cee0 100644 --- a/tests/fast/pandas/test_datetime_timestamp.py +++ b/tests/fast/pandas/test_datetime_timestamp.py @@ -1,14 +1,13 @@ -import duckdb import datetime -import numpy as np + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas from packaging.version import Version pd = pytest.importorskip("pandas") -class TestDateTimeTimeStamp(object): +class TestDateTimeTimeStamp: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_high(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql("SELECT '2260-01-01 23:59:00'::TIMESTAMP AS '0'").df() diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 8e67da4a..92318085 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -1,15 +1,16 @@ -import duckdb -import datetime + import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def create_generic_dataframe(data, pandas): return pandas.DataFrame({"col0": pandas.Series(data=data, dtype="object")}) -class TestResolveObjectColumns(object): +class TestResolveObjectColumns: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_correct(self, duckdb_cursor, pandas): print(pandas.backend) diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index 73470818..27dc2116 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ 
b/tests/fast/pandas/test_df_object_resolution.py @@ -1,13 +1,15 @@ -import duckdb import datetime -import numpy as np -import platform -import pytest import decimal import math -from decimal import Decimal +import platform import re -from conftest import NumpyPandas, ArrowPandas +from decimal import Decimal + +import numpy as np +import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb standard_vector_size = duckdb.__standard_vector_size__ @@ -81,7 +83,7 @@ def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, assert expected_type == rel.types[0] -class TestResolveObjectColumns(object): +class TestResolveObjectColumns: # TODO: add support for ArrowPandas @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integers(self, pandas, duckdb_cursor): @@ -674,7 +676,7 @@ def test_multiple_chunks(self, pandas, duckdb_cursor): @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): - duckdb_cursor.execute(f"SET GLOBAL pandas_analyze_sample=4096") + duckdb_cursor.execute("SET GLOBAL pandas_analyze_sample=4096") duckdb_cursor.execute( "create table dates as select '2022-09-14'::DATE + INTERVAL (i::INTEGER) DAY as i from range(4096) tbl(i);" ) diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index fb7d2ad0..4eacf777 100644 --- a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -1,9 +1,8 @@ -import duckdb -import datetime -import numpy as np + import pytest -import copy -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb from duckdb import Value NULL = None @@ -23,7 +22,7 @@ def create_reference_query(): return query -class TestDFRecursiveNested(object): +class TestDFRecursiveNested: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_structs(self, duckdb_cursor, pandas): data = [[{"a": 5}, NULL, {"a": NULL}], NULL, [{"a": 5}, NULL, {"a": NULL}]] diff --git a/tests/fast/pandas/test_fetch_df_chunk.py b/tests/fast/pandas/test_fetch_df_chunk.py index 1f2d4b1b..90f4e428 100644 --- a/tests/fast/pandas/test_fetch_df_chunk.py +++ b/tests/fast/pandas/test_fetch_df_chunk.py @@ -1,10 +1,11 @@ import pytest + import duckdb VECTOR_SIZE = duckdb.__standard_vector_size__ -class TestType(object): +class TestType: def test_fetch_df_chunk(self): size = 3000 con = duckdb.connect() diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index e25a44ba..6e878643 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -1,6 +1,7 @@ + import pytest + import duckdb -import sys pd = pytest.importorskip("pandas") import numpy as np @@ -55,7 +56,7 @@ def list_test_cases(): }), ("SELECT a from (SELECT LIST(i) as a FROM range(10000) tbl(i)) as t", { 'a': [ - list(range(0, 10000)) + list(range(10000)) ] }), ("SELECT LIST(i) as a FROM range(5) tbl(i) group by i%2 order by all", { @@ -146,7 +147,7 @@ def list_test_cases(): return test_cases -class TestFetchNested(object): +class TestFetchNested: @pytest.mark.parametrize("query, expected", list_test_cases()) def test_fetch_df_list(self, duckdb_cursor, query, expected): compare_results(duckdb_cursor, query, expected) diff --git a/tests/fast/pandas/test_implicit_pandas_scan.py b/tests/fast/pandas/test_implicit_pandas_scan.py index 2d4610ff..3808c42a 100644 --- 
a/tests/fast/pandas/test_implicit_pandas_scan.py +++ b/tests/fast/pandas/test_implicit_pandas_scan.py @@ -1,11 +1,12 @@ # simple DB API testcase -import duckdb import pandas as pd import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas from packaging.version import Version +import duckdb + numpy_nullable_df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val4", "CoL2": 17}]) try: @@ -22,7 +23,7 @@ pyarrow_df = numpy_nullable_df -class TestImplicitPandasScan(object): +class TestImplicitPandasScan: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_local_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() diff --git a/tests/fast/pandas/test_import_cache.py b/tests/fast/pandas/test_import_cache.py index 6ed601c5..d67b50ca 100644 --- a/tests/fast/pandas/test_import_cache.py +++ b/tests/fast/pandas/test_import_cache.py @@ -1,6 +1,7 @@ -from conftest import NumpyPandas, ArrowPandas -import duckdb import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) diff --git a/tests/fast/pandas/test_issue_1767.py b/tests/fast/pandas/test_issue_1767.py index 27f0c2ff..48d3e852 100644 --- a/tests/fast/pandas/test_issue_1767.py +++ b/tests/fast/pandas/test_issue_1767.py @@ -1,14 +1,13 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb -import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb # Join from pandas not matching identical strings #1767 -class TestIssue1767(object): +class TestIssue1767: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_unicode_join_pandas(self, duckdb_cursor, pandas): A = pandas.DataFrame({"key": ["a", "п"]}) diff --git a/tests/fast/pandas/test_limit.py b/tests/fast/pandas/test_limit.py index 460716cd..51c4a382 100644 --- a/tests/fast/pandas/test_limit.py +++ b/tests/fast/pandas/test_limit.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestLimitPandas(object): +class TestLimitPandas: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_limit_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( diff --git a/tests/fast/pandas/test_pandas_arrow.py b/tests/fast/pandas/test_pandas_arrow.py index e1661041..cd82736b 100644 --- a/tests/fast/pandas/test_pandas_arrow.py +++ b/tests/fast/pandas/test_pandas_arrow.py @@ -1,16 +1,17 @@ -import duckdb -import pytest import datetime +import pytest from conftest import pandas_supports_arrow_backend +import duckdb + pd = pytest.importorskip("pandas", "2.0.0") import numpy as np from pandas.api.types import is_integer_dtype @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") -class TestPandasArrow(object): +class TestPandasArrow: def test_pandas_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") df = pd.DataFrame({"a": pd.Series([5, 4, 3])}).convert_dtypes() diff --git a/tests/fast/pandas/test_pandas_category.py b/tests/fast/pandas/test_pandas_category.py index 4b29b3fb..b40fefb8 100644 --- a/tests/fast/pandas/test_pandas_category.py +++ b/tests/fast/pandas/test_pandas_category.py @@ -1,8 +1,9 @@ -import duckdb -import pandas as pd import numpy +import pandas as pd import pytest +import duckdb + def check_category_equal(category): 
df_in = pd.DataFrame( @@ -54,7 +55,7 @@ def check_create_table(category): conn.execute("DROP TABLE t1") -class TestCategory(object): +class TestCategory: def test_category_simple(self, duckdb_cursor): df_in = pd.DataFrame({"float": [1.0, 2.0, 1.0], "int": pd.Series([1, 2, 1], dtype="category")}) diff --git a/tests/fast/pandas/test_pandas_df_none.py b/tests/fast/pandas/test_pandas_df_none.py index 50e1553c..5fa76c8c 100644 --- a/tests/fast/pandas/test_pandas_df_none.py +++ b/tests/fast/pandas/test_pandas_df_none.py @@ -1,11 +1,7 @@ -import pandas as pd -import pytest import duckdb -import sys -import gc -class TestPandasDFNone(object): +class TestPandasDFNone: # This used to decrease the ref count of None def test_none_deref(self): con = duckdb.connect() diff --git a/tests/fast/pandas/test_pandas_enum.py b/tests/fast/pandas/test_pandas_enum.py index b1eb2c7f..5b246fcf 100644 --- a/tests/fast/pandas/test_pandas_enum.py +++ b/tests/fast/pandas/test_pandas_enum.py @@ -1,9 +1,10 @@ import pandas as pd import pytest + import duckdb -class TestPandasEnum(object): +class TestPandasEnum: def test_3480(self, duckdb_cursor): duckdb_cursor.execute( """ @@ -14,7 +15,7 @@ def test_3480(self, duckdb_cursor): ); """ ) - df = duckdb_cursor.query(f"SELECT * FROM tab LIMIT 0;").to_df() + df = duckdb_cursor.query("SELECT * FROM tab LIMIT 0;").to_df() assert df["cat"].cat.categories.equals(pd.Index(["marie", "duchess", "toulouse"])) duckdb_cursor.execute("DROP TABLE tab") duckdb_cursor.execute("DROP TYPE cat") @@ -41,7 +42,7 @@ def test_3479(self, duckdb_cursor): duckdb.ConversionException, match="Type UINT8 with value 0 can't be cast because the value is out of range for the destination type UINT8", ): - duckdb_cursor.execute(f"INSERT INTO tab SELECT * FROM df;") + duckdb_cursor.execute("INSERT INTO tab SELECT * FROM df;") assert duckdb_cursor.execute("select * from tab").fetchall() == [] duckdb_cursor.execute("DROP TABLE tab") diff --git a/tests/fast/pandas/test_pandas_limit.py b/tests/fast/pandas/test_pandas_limit.py index d551a6e4..89fe1583 100644 --- a/tests/fast/pandas/test_pandas_limit.py +++ b/tests/fast/pandas/test_pandas_limit.py @@ -1,9 +1,8 @@ + import duckdb -import pandas as pd -import pytest -class TestPandasLimit(object): +class TestPandasLimit: def test_pandas_limit(self, duckdb_cursor): con = duckdb.connect() df = con.execute("select * from range(10000000) tbl(i)").df() diff --git a/tests/fast/pandas/test_pandas_na.py b/tests/fast/pandas/test_pandas_na.py index 7bc01003..f83be08a 100644 --- a/tests/fast/pandas/test_pandas_na.py +++ b/tests/fast/pandas/test_pandas_na.py @@ -1,9 +1,10 @@ +import platform + import numpy as np -import datetime -import duckdb import pytest -import platform -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def assert_nullness(items, null_indices): @@ -15,7 +16,7 @@ def assert_nullness(items, null_indices): @pytest.mark.skipif(platform.system() == "Emscripten", reason="Pandas interaction is broken in Pyodide 3.11") -class TestPandasNA(object): +class TestPandasNA: @pytest.mark.parametrize("rows", [100, duckdb.__standard_vector_size__, 5000, 1000000]) @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_pandas_string_null(self, duckdb_cursor, rows, pd): diff --git a/tests/fast/pandas/test_pandas_object.py b/tests/fast/pandas/test_pandas_object.py index 9e10681c..bb8e3eff 100644 --- a/tests/fast/pandas/test_pandas_object.py +++ b/tests/fast/pandas/test_pandas_object.py @@ -1,11 +1,12 @@ 
-import pandas as pd -import duckdb import datetime + import numpy as np -import random +import pandas as pd + +import duckdb -class TestPandasObject(object): +class TestPandasObject: def test_object_lotof_nulls(self): # Test mostly null column data = [None] + [1] + [None] * 10000 # Last element is 1, others are None diff --git a/tests/fast/pandas/test_pandas_string.py b/tests/fast/pandas/test_pandas_string.py index 4bd5996d..d1302f89 100644 --- a/tests/fast/pandas/test_pandas_string.py +++ b/tests/fast/pandas/test_pandas_string.py @@ -1,9 +1,10 @@ -import duckdb -import pandas as pd import numpy +import pandas as pd + +import duckdb -class TestPandasString(object): +class TestPandasString: def test_pandas_string(self, duckdb_cursor): strings = numpy.array(["foo", "bar", "baz"]) @@ -31,12 +32,12 @@ def test_bug_2467(self, duckdb_cursor): con = duckdb.connect() con.register("df", df) con.execute( - f""" + """ CREATE TABLE t1 AS SELECT * FROM df """ ) assert con.execute( - f""" + """ SELECT count(*) from t1 """ ).fetchall() == [(3000000,)] diff --git a/tests/fast/pandas/test_pandas_timestamp.py b/tests/fast/pandas/test_pandas_timestamp.py index 835ff3af..635cee36 100644 --- a/tests/fast/pandas/test_pandas_timestamp.py +++ b/tests/fast/pandas/test_pandas_timestamp.py @@ -1,11 +1,11 @@ -import duckdb +from datetime import datetime + import pandas import pytest - -from datetime import datetime -from pytz import timezone from conftest import pandas_2_or_higher +import duckdb + @pytest.mark.parametrize("timezone", ["UTC", "CET", "Asia/Kathmandu"]) @pytest.mark.skipif(not pandas_2_or_higher(), reason="Pandas <2.0.0 does not support timezones in the metadata string") diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index fcc63b82..f7df363d 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -1,12 +1,14 @@ -import duckdb -import pytest -import pandas as pd -import numpy import string -from packaging import version import warnings from contextlib import suppress +import numpy +import pandas as pd +import pytest +from packaging import version + +import duckdb + def round_trip(data, pandas_type): df_in = pd.DataFrame( @@ -21,7 +23,7 @@ def round_trip(data, pandas_type): assert df_out.equals(df_in) -class TestNumpyNullableTypes(object): +class TestNumpyNullableTypes: def test_pandas_numeric(self): base_df = pd.DataFrame({"a": range(10)}) diff --git a/tests/fast/pandas/test_pandas_unregister.py b/tests/fast/pandas/test_pandas_unregister.py index fce8f42a..bce93158 100644 --- a/tests/fast/pandas/test_pandas_unregister.py +++ b/tests/fast/pandas/test_pandas_unregister.py @@ -1,13 +1,14 @@ -import duckdb -import pytest -import tempfile -import os import gc +import os +import tempfile + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasUnregister(object): +class TestPandasUnregister: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister1(self, duckdb_cursor, pandas): df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) diff --git a/tests/fast/pandas/test_pandas_update.py b/tests/fast/pandas/test_pandas_update.py index 86d17154..bc1740d9 100644 --- a/tests/fast/pandas/test_pandas_update.py +++ b/tests/fast/pandas/test_pandas_update.py @@ -1,8 +1,9 @@ -import duckdb import pandas as pd +import duckdb + -class TestPandasUpdateList(object): +class TestPandasUpdateList: def 
test_pandas_update_list(self, duckdb_cursor): duckdb_cursor = duckdb.connect(":memory:") duckdb_cursor.execute("create table t (l int[])") diff --git a/tests/fast/pandas/test_parallel_pandas_scan.py b/tests/fast/pandas/test_parallel_pandas_scan.py index d113bbca..b389fce5 100644 --- a/tests/fast/pandas/test_parallel_pandas_scan.py +++ b/tests/fast/pandas/test_parallel_pandas_scan.py @@ -1,14 +1,15 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb -import numpy import datetime + +import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def run_parallel_queries(main_table, left_join_table, expected_df, pandas, iteration_count=5): - for i in range(0, iteration_count): + for i in range(iteration_count): output_df = None sql = """ select @@ -35,7 +36,7 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera duckdb_conn.close() -class TestParallelPandasScan(object): +class TestParallelPandasScan: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_scan(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": 3}]) diff --git a/tests/fast/pandas/test_partitioned_pandas_scan.py b/tests/fast/pandas/test_partitioned_pandas_scan.py index d2447ef8..9f580659 100644 --- a/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -1,11 +1,11 @@ -import duckdb -import pandas as pd + import numpy -import datetime -import time +import pandas as pd + +import duckdb -class TestPartitionedPandasScan(object): +class TestPartitionedPandasScan: def test_parallel_pandas(self, duckdb_cursor): con = duckdb.connect() df = pd.DataFrame({"i": numpy.arange(10000000)}) diff --git a/tests/fast/pandas/test_progress_bar.py b/tests/fast/pandas/test_progress_bar.py index 7c1c21e1..c8cfb2e0 100644 --- a/tests/fast/pandas/test_progress_bar.py +++ b/tests/fast/pandas/test_progress_bar.py @@ -1,11 +1,11 @@ -import duckdb -import pandas as pd + import numpy -import datetime -import time +import pandas as pd + +import duckdb -class TestProgressBarPandas(object): +class TestProgressBarPandas: def test_progress_pandas_single(self, duckdb_cursor): con = duckdb.connect() df = pd.DataFrame({"i": numpy.arange(10000000)}) diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index b04f713a..4191a96e 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -1,16 +1,16 @@ -import duckdb -import os -import pytest +import pytest from conftest import pandas_supports_arrow_backend +import duckdb + pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") _ = pytest.importorskip("pandas", "2.0.0") @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") -class TestArrowDFProjectionPushdown(object): +class TestArrowDFProjectionPushdown: def test_projection_pushdown_no_filter(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (a INTEGER, b INTEGER, c INTEGER)") diff --git a/tests/fast/pandas/test_same_name.py b/tests/fast/pandas/test_same_name.py index ac4f407a..ff499ddf 100644 --- a/tests/fast/pandas/test_same_name.py +++ b/tests/fast/pandas/test_same_name.py @@ -1,9 +1,7 @@ -import pytest -import duckdb import pandas as pd -class 
TestMultipleColumnsSameName(object): +class TestMultipleColumnsSameName: def test_multiple_columns_with_same_name(self, duckdb_cursor): df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) diff --git a/tests/fast/pandas/test_stride.py b/tests/fast/pandas/test_stride.py index 1b2f5052..cbe23cfd 100644 --- a/tests/fast/pandas/test_stride.py +++ b/tests/fast/pandas/test_stride.py @@ -1,10 +1,12 @@ +import datetime + +import numpy as np import pandas as pd + import duckdb -import numpy as np -import datetime -class TestPandasStride(object): +class TestPandasStride: def test_stride(self, duckdb_cursor): expected_df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() diff --git a/tests/fast/pandas/test_timedelta.py b/tests/fast/pandas/test_timedelta.py index c0afeb74..deca62e0 100644 --- a/tests/fast/pandas/test_timedelta.py +++ b/tests/fast/pandas/test_timedelta.py @@ -1,11 +1,13 @@ +import datetime import platform + import pandas as pd -import duckdb -import datetime import pytest +import duckdb + -class TestTimedelta(object): +class TestTimedelta: def test_timedelta_positive(self, duckdb_cursor): duckdb_interval = duckdb_cursor.query( "SELECT '2290-01-01 23:59:00'::TIMESTAMP - '2000-01-01 23:59:00'::TIMESTAMP AS '0'" diff --git a/tests/fast/pandas/test_timestamp.py b/tests/fast/pandas/test_timestamp.py index dbb7273d..e14d82a6 100644 --- a/tests/fast/pandas/test_timestamp.py +++ b/tests/fast/pandas/test_timestamp.py @@ -1,13 +1,15 @@ -import duckdb import datetime import os -import pytest -import pandas as pd import platform + +import pandas as pd +import pytest from conftest import pandas_2_or_higher +import duckdb + -class TestPandasTimestamps(object): +class TestPandasTimestamps: @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_types_roundtrip(self, unit): d = { diff --git a/tests/fast/relational_api/test_groupings.py b/tests/fast/relational_api/test_groupings.py index b0a95410..250df7ad 100644 --- a/tests/fast/relational_api/test_groupings.py +++ b/tests/fast/relational_api/test_groupings.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + @pytest.fixture def con(): @@ -17,10 +18,10 @@ def con(): ) AS tbl(a, b, c)) """ ) - yield conn + return conn -class TestGroupings(object): +class TestGroupings: def test_basic_grouping(self, con): rel = con.table("tbl").sum("a", "b") res = rel.fetchall() diff --git a/tests/fast/relational_api/test_joins.py b/tests/fast/relational_api/test_joins.py index cf3d3cf2..726fdac8 100644 --- a/tests/fast/relational_api/test_joins.py +++ b/tests/fast/relational_api/test_joins.py @@ -1,5 +1,6 @@ -import duckdb import pytest + +import duckdb from duckdb import ColumnExpression @@ -26,10 +27,10 @@ def con(): ) AS t(a, b)) """ ) - yield conn + return conn -class TestRAPIJoins(object): +class TestRAPIJoins: def test_outer_join(self, con): a = con.table("tbl_a") b = con.table("tbl_b") diff --git a/tests/fast/relational_api/test_pivot.py b/tests/fast/relational_api/test_pivot.py index 9cf91e56..1cca02b4 100644 --- a/tests/fast/relational_api/test_pivot.py +++ b/tests/fast/relational_api/test_pivot.py @@ -1,10 +1,8 @@ -import duckdb -import pytest import os import tempfile -class TestPivot(object): +class TestPivot: def test_pivot_issue_14600(self, duckdb_cursor): duckdb_cursor.sql( "create table input_data as select unnest(['u','v','w']) as a, unnest(['x','y','z']) as b, unnest([1,2,3]) as c;" @@ -26,5 +24,5 @@ 
def test_pivot_issue_14601(self, duckdb_cursor): pivot_1.create("pivot_1") export_dir = tempfile.mkdtemp() duckdb_cursor.query(f"EXPORT DATABASE '{export_dir}'") - with open(os.path.join(export_dir, "schema.sql"), "r") as f: + with open(os.path.join(export_dir, "schema.sql")) as f: assert "CREATE TYPE" not in f.read() diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index 3466a77a..31cb21c9 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -1,7 +1,8 @@ -import duckdb -from decimal import Decimal + import pytest +import duckdb + @pytest.fixture(autouse=True) def setup_and_teardown_of_table(duckdb_cursor): @@ -23,12 +24,12 @@ def setup_and_teardown_of_table(duckdb_cursor): duckdb_cursor.execute("drop table agg") -@pytest.fixture() +@pytest.fixture def table(duckdb_cursor): return duckdb_cursor.table("agg") -class TestRAPIAggregations(object): +class TestRAPIAggregations: # General aggregate functions def test_any_value(self, table): diff --git a/tests/fast/relational_api/test_rapi_close.py b/tests/fast/relational_api/test_rapi_close.py index b6355167..969e2792 100644 --- a/tests/fast/relational_api/test_rapi_close.py +++ b/tests/fast/relational_api/test_rapi_close.py @@ -1,9 +1,10 @@ -import duckdb import pytest +import duckdb + # A closed connection should invalidate all relation's methods -class TestRAPICloseConnRel(object): +class TestRAPICloseConnRel: def test_close_conn_rel(self, duckdb_cursor): con = duckdb.connect() con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 80616132..2696ed2f 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -1,8 +1,7 @@ -import duckdb import pytest -class TestRAPIDescription(object): +class TestRAPIDescription: def test_rapi_description(self, duckdb_cursor): res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") desc = res.description diff --git a/tests/fast/relational_api/test_rapi_functions.py b/tests/fast/relational_api/test_rapi_functions.py index c6b1f1fa..143aa8df 100644 --- a/tests/fast/relational_api/test_rapi_functions.py +++ b/tests/fast/relational_api/test_rapi_functions.py @@ -1,7 +1,7 @@ import duckdb -class TestRAPIFunctions(object): +class TestRAPIFunctions: def test_rapi_str_print(self, duckdb_cursor): res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") assert str(res) is not None diff --git a/tests/fast/relational_api/test_rapi_query.py b/tests/fast/relational_api/test_rapi_query.py index 16ed326c..b9f2ef68 100644 --- a/tests/fast/relational_api/test_rapi_query.py +++ b/tests/fast/relational_api/test_rapi_query.py @@ -1,10 +1,12 @@ -import duckdb -import pytest import platform import sys +import pytest + +import duckdb + -@pytest.fixture() +@pytest.fixture def tbl_table(): con = duckdb.default_connection() con.execute("drop table if exists tbl CASCADE") @@ -13,7 +15,7 @@ def tbl_table(): con.execute("drop table tbl CASCADE") -@pytest.fixture() +@pytest.fixture def scoped_default(duckdb_cursor): default = duckdb.connect(":default:") duckdb.set_default_connection(duckdb_cursor) @@ -23,11 +25,11 @@ def scoped_default(duckdb_cursor): duckdb.set_default_connection(default) -class TestRAPIQuery(object): +class TestRAPIQuery: 
@pytest.mark.parametrize("steps", [1, 2, 3, 4]) def test_query_chain(self, steps): con = duckdb.default_connection() - amount = int(1000000) + amount = 1000000 rel = None for _ in range(steps): rel = con.query(f"select i from range({amount}::BIGINT) tbl(i)") diff --git a/tests/fast/relational_api/test_rapi_windows.py b/tests/fast/relational_api/test_rapi_windows.py index cc58b8f1..ce0196fc 100644 --- a/tests/fast/relational_api/test_rapi_windows.py +++ b/tests/fast/relational_api/test_rapi_windows.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + @pytest.fixture(autouse=True) def setup_and_teardown_of_table(duckdb_cursor): @@ -22,7 +23,7 @@ def setup_and_teardown_of_table(duckdb_cursor): duckdb_cursor.execute("drop table win") -@pytest.fixture() +@pytest.fixture def table(duckdb_cursor): return duckdb_cursor.table("win") diff --git a/tests/fast/relational_api/test_table_function.py b/tests/fast/relational_api/test_table_function.py index 5748f762..2a5271f9 100644 --- a/tests/fast/relational_api/test_table_function.py +++ b/tests/fast/relational_api/test_table_function.py @@ -1,11 +1,13 @@ -import duckdb -import pytest import os +import pytest + +import duckdb + script_path = os.path.dirname(__file__) -class TestTableFunction(object): +class TestTableFunction: def test_table_function(self, duckdb_cursor): path = os.path.join(script_path, "..", "data/integers.csv") rel = duckdb_cursor.table_function("read_csv", [path]) diff --git a/tests/fast/spark/test_replace_column_value.py b/tests/fast/spark/test_replace_column_value.py index 65ab85f1..17a2254e 100644 --- a/tests/fast/spark/test_replace_column_value.py +++ b/tests/fast/spark/test_replace_column_value.py @@ -4,7 +4,7 @@ from spark_namespace.sql.types import Row -class TestReplaceValue(object): +class TestReplaceValue: # https://sparkbyexamples.com/pyspark/pyspark-replace-column-values/?expand_article=1 def test_replace_value(self, spark): address = [(1, "14851 Jeffrey Rd", "DE"), (2, "43421 Margarita St", "NY"), (3, "13111 Siemon Ave", "CA")] diff --git a/tests/fast/spark/test_replace_empty_value.py b/tests/fast/spark/test_replace_empty_value.py index aad6a43e..615b15d8 100644 --- a/tests/fast/spark/test_replace_empty_value.py +++ b/tests/fast/spark/test_replace_empty_value.py @@ -2,12 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql.types import Row # https://sparkbyexamples.com/pyspark/pyspark-replace-empty-value-with-none-on-dataframe-2/?expand_article=1 -class TestReplaceEmpty(object): +class TestReplaceEmpty: def test_replace_empty(self, spark): # Create the dataframe data = [("", "CA"), ("Julia", ""), ("Robert", ""), ("", "NJ")] diff --git a/tests/fast/spark/test_spark_arrow_table.py b/tests/fast/spark/test_spark_arrow_table.py index 57c81599..fc773562 100644 --- a/tests/fast/spark/test_spark_arrow_table.py +++ b/tests/fast/spark/test_spark_arrow_table.py @@ -2,8 +2,6 @@ _ = pytest.importorskip("duckdb.experimental.spark") pa = pytest.importorskip("pyarrow") -from spark_namespace import USE_ACTUAL_SPARK - from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql.dataframe import DataFrame diff --git a/tests/fast/spark/test_spark_catalog.py b/tests/fast/spark/test_spark_catalog.py index 2ecaad24..c19ec83c 100644 --- a/tests/fast/spark/test_spark_catalog.py +++ b/tests/fast/spark/test_spark_catalog.py @@ -3,10 +3,10 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK -from 
spark_namespace.sql.catalog import Table, Database, Column +from spark_namespace.sql.catalog import Column, Database, Table -class TestSparkCatalog(object): +class TestSparkCatalog: def test_list_databases(self, spark): dbs = spark.catalog.listDatabases() if USE_ACTUAL_SPARK: diff --git a/tests/fast/spark/test_spark_column.py b/tests/fast/spark/test_spark_column.py index 9ef17d95..e8da1333 100644 --- a/tests/fast/spark/test_spark_column.py +++ b/tests/fast/spark/test_spark_column.py @@ -2,17 +2,15 @@ _ = pytest.importorskip("duckdb.experimental.spark") +import re + from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.column import Column -from spark_namespace.sql.functions import struct, array, col -from spark_namespace.sql.types import Row from spark_namespace.errors import PySparkTypeError - -import duckdb -import re +from spark_namespace.sql.functions import array, col, struct +from spark_namespace.sql.types import Row -class TestSparkColumn(object): +class TestSparkColumn: def test_struct_column(self, spark): df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)]) diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index e86995ec..26006952 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -2,25 +2,22 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError, PySparkValueError +from spark_namespace.sql.column import Column +from spark_namespace.sql.functions import col, struct, when from spark_namespace.sql.types import ( - LongType, - StructType, + ArrayType, BooleanType, - StructField, - StringType, IntegerType, LongType, - Row, - ArrayType, MapType, + Row, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when -from spark_namespace.sql.column import Column -import duckdb -import re - -from spark_namespace.errors import PySparkValueError, PySparkTypeError def assert_column_objects_equal(col1: Column, col2: Column): @@ -29,7 +26,7 @@ def assert_column_objects_equal(col1: Column, col2: Column): assert col1.expr == col2.expr -class TestDataFrame(object): +class TestDataFrame: def test_dataframe_from_list_of_tuples(self, spark): # Valid address = [(1, "14851 Jeffrey Rd", "DE"), (2, "43421 Margarita St", "NY"), (3, "13111 Siemon Ave", "CA")] @@ -194,7 +191,7 @@ def test_df_from_name_list(self, spark): assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_creation_coverage(self, spark): - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType data2 = [ ("James", "", "Smith", "36636", "M", 3000), @@ -298,7 +295,7 @@ def test_df_nested_struct(self, spark): ) def test_df_columns(self, spark): - from spark_namespace.sql.functions import col, struct, when + from spark_namespace.sql.functions import col structureData = [ (("James", "", "Smith"), "36636", "M", 3100), @@ -343,7 +340,6 @@ def test_df_columns(self, spark): def test_array_and_map_type(self, spark): """Array & Map""" - arrayStructureSchema = StructType( [ StructField( diff --git a/tests/fast/spark/test_spark_dataframe_sort.py b/tests/fast/spark/test_spark_dataframe_sort.py index db7dce4b..49631d4d 100644 --- a/tests/fast/spark/test_spark_dataframe_sort.py +++ b/tests/fast/spark/test_spark_dataframe_sort.py @@ -3,13 +3,13 @@ _ = 
pytest.importorskip("duckdb.experimental.spark") import spark_namespace.errors -from spark_namespace.sql.types import Row -from spark_namespace.sql.functions import desc, asc -from spark_namespace.errors import PySparkTypeError, PySparkValueError from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError, PySparkValueError +from spark_namespace.sql.functions import asc, desc +from spark_namespace.sql.types import Row -class TestDataFrameSort(object): +class TestDataFrameSort: data = [(56, "Carol"), (20, "Alice"), (3, "Dave"), (3, "Anna"), (1, "Ben")] def test_sort_ascending(self, spark): diff --git a/tests/fast/spark/test_spark_drop_duplicates.py b/tests/fast/spark/test_spark_drop_duplicates.py index 563a5e76..cd658c77 100644 --- a/tests/fast/spark/test_spark_drop_duplicates.py +++ b/tests/fast/spark/test_spark_drop_duplicates.py @@ -1,6 +1,4 @@ import pytest - - from spark_namespace.sql.types import ( Row, ) @@ -8,7 +6,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestDataFrameDropDuplicates(object): +class TestDataFrameDropDuplicates: @pytest.mark.parametrize("method", ["dropDuplicates", "drop_duplicates"]) def test_spark_drop_duplicates(self, method, spark): # Prepare Data diff --git a/tests/fast/spark/test_spark_except.py b/tests/fast/spark/test_spark_except.py index 7c28cc29..dd6c802d 100644 --- a/tests/fast/spark/test_spark_except.py +++ b/tests/fast/spark/test_spark_except.py @@ -1,10 +1,8 @@ -import platform import pytest _ = pytest.importorskip("duckdb.experimental.spark") from duckdb.experimental.spark.sql.types import Row -from duckdb.experimental.spark.sql.functions import col @pytest.fixture diff --git a/tests/fast/spark/test_spark_filter.py b/tests/fast/spark/test_spark_filter.py index a4733a44..9dbb8c94 100644 --- a/tests/fast/spark/test_spark_filter.py +++ b/tests/fast/spark/test_spark_filter.py @@ -2,26 +2,20 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError +from spark_namespace.sql.functions import array_contains, col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, ArrayType, - MapType, + Row, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.errors import PySparkTypeError -import duckdb -import re -class TestDataFrameFilter(object): +class TestDataFrameFilter: def test_dataframe_filter(self, spark): data = [ (("James", "", "Smith"), ["Java", "Scala", "C++"], "OH", "M"), diff --git a/tests/fast/spark/test_spark_function_concat_ws.py b/tests/fast/spark/test_spark_function_concat_ws.py index 82f19cd1..b4268d0f 100644 --- a/tests/fast/spark/test_spark_function_concat_ws.py +++ b/tests/fast/spark/test_spark_function_concat_ws.py @@ -1,11 +1,11 @@ import pytest _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col, concat_ws from spark_namespace.sql.types import Row -from spark_namespace.sql.functions import concat_ws, col -class TestReplaceEmpty(object): +class TestReplaceEmpty: def test_replace_empty(self, spark): data = [ ("firstRowFirstColumn", "firstRowSecondColumn"), diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py index 5ecba132..36afed54 100644 --- a/tests/fast/spark/test_spark_functions_array.py +++ 
b/tests/fast/spark/test_spark_functions_array.py @@ -1,10 +1,11 @@ -import pytest import platform +import pytest + _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", diff --git a/tests/fast/spark/test_spark_functions_base64.py b/tests/fast/spark/test_spark_functions_base64.py index 5a179481..44e4a7cd 100644 --- a/tests/fast/spark/test_spark_functions_base64.py +++ b/tests/fast/spark/test_spark_functions_base64.py @@ -5,7 +5,7 @@ from spark_namespace.sql import functions as F -class TestSparkFunctionsBase64(object): +class TestSparkFunctionsBase64: def test_base64(self, spark): data = [ ("quack",), diff --git a/tests/fast/spark/test_spark_functions_date.py b/tests/fast/spark/test_spark_functions_date.py index a298c0ff..914d33f6 100644 --- a/tests/fast/spark/test_spark_functions_date.py +++ b/tests/fast/spark/test_spark_functions_date.py @@ -1,4 +1,5 @@ import warnings + import pytest _ = pytest.importorskip("duckdb.experimental.spark") @@ -6,11 +7,11 @@ from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as F -from spark_namespace.sql.types import Row from spark_namespace.sql.functions import col +from spark_namespace.sql.types import Row -class TestsSparkFunctionsDate(object): +class TestsSparkFunctionsDate: def test_date_trunc(self, spark): df = spark.createDataFrame( [(datetime(2019, 1, 23, 14, 34, 9, 87539),)], diff --git a/tests/fast/spark/test_spark_functions_expr.py b/tests/fast/spark/test_spark_functions_expr.py index 7cc47735..f14dbcce 100644 --- a/tests/fast/spark/test_spark_functions_expr.py +++ b/tests/fast/spark/test_spark_functions_expr.py @@ -5,7 +5,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestSparkFunctionsExpr(object): +class TestSparkFunctionsExpr: def test_expr(self, spark): df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) res = df.select("name", F.expr("length(name)").alias("str_len")).collect() diff --git a/tests/fast/spark/test_spark_functions_hash.py b/tests/fast/spark/test_spark_functions_hash.py index 7b14f29e..d1890990 100644 --- a/tests/fast/spark/test_spark_functions_hash.py +++ b/tests/fast/spark/test_spark_functions_hash.py @@ -4,7 +4,7 @@ from spark_namespace.sql import functions as F -class TestSparkFunctionsHash(object): +class TestSparkFunctionsHash: def test_md5(self, spark): data = [ ("quack",), diff --git a/tests/fast/spark/test_spark_functions_hex.py b/tests/fast/spark/test_spark_functions_hex.py index 7d5f3c6a..c58c6d90 100644 --- a/tests/fast/spark/test_spark_functions_hex.py +++ b/tests/fast/spark/test_spark_functions_hex.py @@ -1,11 +1,11 @@ + import pytest -import sys _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql import functions as F -class TestSparkFunctionsHex(object): +class TestSparkFunctionsHex: def test_hex_string_col(self, spark): data = [ ("quack",), @@ -32,7 +32,7 @@ def test_hex_binary_col(self, spark): def test_hex_integer_col(self, spark): data = [ - (int(42),), + (42,), ] res = ( spark.createDataFrame(data, ["firstColumn"]) diff --git a/tests/fast/spark/test_spark_functions_null.py b/tests/fast/spark/test_spark_functions_null.py index 230634dc..2bcfd94a 100644 --- a/tests/fast/spark/test_spark_functions_null.py +++ b/tests/fast/spark/test_spark_functions_null.py @@ -7,7 +7,7 @@ from 
spark_namespace.sql.types import Row -class TestsSparkFunctionsNull(object): +class TestsSparkFunctionsNull: def test_coalesce(self, spark): data = [ (None, 2), diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 3548d439..30224735 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -3,13 +3,14 @@ _ = pytest.importorskip("duckdb.experimental.spark") import math + import numpy as np from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row -class TestSparkFunctionsNumeric(object): +class TestSparkFunctionsNumeric: def test_greatest(self, spark): data = [ (1, 2), diff --git a/tests/fast/spark/test_spark_functions_string.py b/tests/fast/spark/test_spark_functions_string.py index b8d7f483..0001a167 100644 --- a/tests/fast/spark/test_spark_functions_string.py +++ b/tests/fast/spark/test_spark_functions_string.py @@ -7,7 +7,7 @@ from spark_namespace.sql.types import Row -class TestSparkFunctionsString(object): +class TestSparkFunctionsString: def test_length(self, spark): data = [ ("firstRowFirstColumn",), diff --git a/tests/fast/spark/test_spark_group_by.py b/tests/fast/spark/test_spark_group_by.py index 9e8a8ea0..f3748f1d 100644 --- a/tests/fast/spark/test_spark_group_by.py +++ b/tests/fast/spark/test_spark_group_by.py @@ -3,47 +3,35 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, - ArrayType, - MapType, -) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains from spark_namespace.sql.functions import ( - sum, + any_value, + approx_count_distinct, avg, + col, + covar_pop, + covar_samp, + first, + last, max, - min, - stddev_samp, - stddev, + median, + mode, + product, + skewness, std, + stddev, stddev_pop, + stddev_samp, + sum, var_pop, var_samp, variance, - mean, - mode, - median, - product, - count, - skewness, - any_value, - approx_count_distinct, - covar_pop, - covar_samp, - first, - last, +) +from spark_namespace.sql.types import ( + Row, ) -class TestDataFrameGroupBy(object): +class TestDataFrameGroupBy: def test_group_by(self, spark): simpleData = [ ("James", "Sales", "NY", 90000, 34, 10000), diff --git a/tests/fast/spark/test_spark_intersect.py b/tests/fast/spark/test_spark_intersect.py index ba0afbdd..8ec67dd0 100644 --- a/tests/fast/spark/test_spark_intersect.py +++ b/tests/fast/spark/test_spark_intersect.py @@ -1,10 +1,8 @@ -import platform import pytest _ = pytest.importorskip("duckdb.experimental.spark") from duckdb.experimental.spark.sql.types import Row -from duckdb.experimental.spark.sql.functions import col @pytest.fixture diff --git a/tests/fast/spark/test_spark_join.py b/tests/fast/spark/test_spark_join.py index f67c54cb..842dfbc5 100644 --- a/tests/fast/spark/test_spark_join.py +++ b/tests/fast/spark/test_spark_join.py @@ -2,20 +2,10 @@ _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, 
mean, count @pytest.fixture @@ -30,7 +20,7 @@ def dataframe_a(spark): ] empColumns = ["emp_id", "name", "superior_emp_id", "year_joined", "emp_dept_id", "gender", "salary"] dataframe = spark.createDataFrame(data=emp, schema=empColumns) - yield dataframe + return dataframe @pytest.fixture @@ -38,10 +28,10 @@ def dataframe_b(spark): dept = [("Finance", 10), ("Marketing", 20), ("Sales", 30), ("IT", 40)] deptColumns = ["dept_name", "dept_id"] dataframe = spark.createDataFrame(data=dept, schema=deptColumns) - yield dataframe + return dataframe -class TestDataFrameJoin(object): +class TestDataFrameJoin: def test_inner_join(self, dataframe_a, dataframe_b): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, "inner") df = df.sort(*df.columns) diff --git a/tests/fast/spark/test_spark_limit.py b/tests/fast/spark/test_spark_limit.py index c00496a0..eb88fc6a 100644 --- a/tests/fast/spark/test_spark_limit.py +++ b/tests/fast/spark/test_spark_limit.py @@ -7,7 +7,7 @@ ) -class TestDataFrameLimit(object): +class TestDataFrameLimit: def test_dataframe_limit(self, spark): df = spark.sql("select * from range(100000)") df2 = df.limit(10) diff --git a/tests/fast/spark/test_spark_order_by.py b/tests/fast/spark/test_spark_order_by.py index cc08dd7c..030db4b8 100644 --- a/tests/fast/spark/test_spark_order_by.py +++ b/tests/fast/spark/test_spark_order_by.py @@ -2,24 +2,13 @@ _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -import duckdb -import re -class TestDataFrameOrderBy(object): +class TestDataFrameOrderBy: def test_order_by(self, spark): simpleData = [ ("James", "Sales", "NY", 90000, 34, 10000), diff --git a/tests/fast/spark/test_spark_pandas_dataframe.py b/tests/fast/spark/test_spark_pandas_dataframe.py index 6491b7a6..ab069156 100644 --- a/tests/fast/spark/test_spark_pandas_dataframe.py +++ b/tests/fast/spark/test_spark_pandas_dataframe.py @@ -3,22 +3,14 @@ _ = pytest.importorskip("duckdb.experimental.spark") pd = pytest.importorskip("pandas") +from pandas.testing import assert_frame_equal from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, IntegerType, - LongType, Row, - ArrayType, - MapType, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when -import duckdb -import re -from pandas.testing import assert_frame_equal @pytest.fixture @@ -26,10 +18,10 @@ def pandasDF(spark): data = [["Scott", 50], ["Jeff", 45], ["Thomas", 54], ["Ann", 34]] # Create the pandas DataFrame df = pd.DataFrame(data, columns=["Name", "Age"]) - yield df + return df -class TestPandasDataFrame(object): +class TestPandasDataFrame: def test_pd_conversion_basic(self, spark, pandasDF): sparkDF = spark.createDataFrame(pandasDF) res = sparkDF.collect() diff --git a/tests/fast/spark/test_spark_readcsv.py b/tests/fast/spark/test_spark_readcsv.py index 5ba3d199..10d1a17c 100644 --- a/tests/fast/spark/test_spark_readcsv.py +++ b/tests/fast/spark/test_spark_readcsv.py @@ -2,12 +2,13 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK import textwrap +from spark_namespace import USE_ACTUAL_SPARK +from 
spark_namespace.sql.types import Row + -class TestSparkReadCSV(object): +class TestSparkReadCSV: def test_read_csv(self, spark, tmp_path): file_path = tmp_path / "basic.csv" with open(file_path, "w+") as f: diff --git a/tests/fast/spark/test_spark_readjson.py b/tests/fast/spark/test_spark_readjson.py index 638bee2d..aa8d8ec5 100644 --- a/tests/fast/spark/test_spark_readjson.py +++ b/tests/fast/spark/test_spark_readjson.py @@ -2,12 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace.sql.types import Row -import textwrap -import duckdb -class TestSparkReadJson(object): +class TestSparkReadJson: def test_read_json(self, duckdb_cursor, spark, tmp_path): file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() diff --git a/tests/fast/spark/test_spark_readparquet.py b/tests/fast/spark/test_spark_readparquet.py index 1b3ddd74..2f182650 100644 --- a/tests/fast/spark/test_spark_readparquet.py +++ b/tests/fast/spark/test_spark_readparquet.py @@ -2,12 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace.sql.types import Row -import textwrap -import duckdb -class TestSparkReadParquet(object): +class TestSparkReadParquet: def test_read_parquet(self, duckdb_cursor, spark, tmp_path): file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() diff --git a/tests/fast/spark/test_spark_runtime_config.py b/tests/fast/spark/test_spark_runtime_config.py index 5e93ed63..b9053899 100644 --- a/tests/fast/spark/test_spark_runtime_config.py +++ b/tests/fast/spark/test_spark_runtime_config.py @@ -5,7 +5,7 @@ from spark_namespace import USE_ACTUAL_SPARK -class TestSparkRuntimeConfig(object): +class TestSparkRuntimeConfig: def test_spark_runtime_config(self, spark): # This fetches the internal runtime config from the session spark.conf diff --git a/tests/fast/spark/test_spark_session.py b/tests/fast/spark/test_spark_session.py index 06c9dbcb..604c85f1 100644 --- a/tests/fast/spark/test_spark_session.py +++ b/tests/fast/spark/test_spark_session.py @@ -1,15 +1,16 @@ import pytest +from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.sql.types import Row + from duckdb.experimental.spark.exception import ( ContributionsAcceptedError, ) -from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql import SparkSession -class TestSparkSession(object): +class TestSparkSession: def test_spark_session_default(self): session = SparkSession.builder.getOrCreate() diff --git a/tests/fast/spark/test_spark_to_csv.py b/tests/fast/spark/test_spark_to_csv.py index e5387a6c..122f3223 100644 --- a/tests/fast/spark/test_spark_to_csv.py +++ b/tests/fast/spark/test_spark_to_csv.py @@ -1,8 +1,7 @@ -import pytest -import tempfile - import os +import pytest + _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK @@ -15,12 +14,13 @@ allow_module_level=True, ) -from duckdb import connect, InvalidInputException, read_csv -from conftest import NumpyPandas, ArrowPandas, getTimeSeriesData -from spark_namespace import USE_ACTUAL_SPARK -import pandas._testing as tm -import datetime import csv +import datetime + +from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData +from spark_namespace import USE_ACTUAL_SPARK + +from duckdb import InvalidInputException, read_csv @pytest.fixture @@ -34,24 +34,24 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = 
spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_ints(request, spark): pandas = request.param dataframe = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) - yield dataframe + return dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_strings(request, spark): pandas = request.param dataframe = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) - yield dataframe + return dataframe -class TestSparkToCSV(object): +class TestSparkToCSV: def test_basic_to_csv(self, pandas_df_ints, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") diff --git a/tests/fast/spark/test_spark_to_parquet.py b/tests/fast/spark/test_spark_to_parquet.py index 68a10f65..8dc2d386 100644 --- a/tests/fast/spark/test_spark_to_parquet.py +++ b/tests/fast/spark/test_spark_to_parquet.py @@ -1,8 +1,7 @@ -import pytest -import tempfile - import os +import pytest + _ = pytest.importorskip("duckdb.experimental.spark") @@ -17,10 +16,10 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe -class TestSparkToParquet(object): +class TestSparkToParquet: def test_basic_to_parquet(self, df, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.parquet") diff --git a/tests/fast/spark/test_spark_transform.py b/tests/fast/spark/test_spark_transform.py index 1f1186c5..bf1c7b01 100644 --- a/tests/fast/spark/test_spark_transform.py +++ b/tests/fast/spark/test_spark_transform.py @@ -3,19 +3,8 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture @@ -26,7 +15,7 @@ def array_df(spark): ("Robert,,Williams", ["CSharp", "VB"], ["Spark", "Python"]), ] dataframe = spark.createDataFrame(data=data, schema=["Name", "Languages1", "Languages2"]) - yield dataframe + return dataframe @pytest.fixture @@ -40,10 +29,10 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_transform(self, spark, df): # Custom transformation 1 from spark_namespace.sql.functions import upper @@ -72,6 +61,6 @@ def apply_discount(df): # https://sparkbyexamples.com/pyspark/pyspark-transform-function/ @pytest.mark.skip(reason="LambdaExpressions are currently under development, waiting til that is finished") def test_transform_function(self, spark, array_df): - from spark_namespace.sql.functions import upper, transform + from spark_namespace.sql.functions import transform, upper df.select(transform("Languages1", lambda x: upper(x)).alias("languages1")).show() diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index 6c97c2d9..d19b3833 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -9,43 +9,42 @@ "Skipping these tests as they use test_all_types() which is specific to DuckDB", allow_module_level=True ) -from spark_namespace.sql.types import Row from spark_namespace.sql.types 
import ( - StringType, + ArrayType, BinaryType, BitstringType, - UUIDType, BooleanType, + ByteType, DateType, - TimestampType, - TimestampNTZType, - TimeType, - TimeNTZType, - TimestampNanosecondNTZType, - TimestampMilisecondNTZType, - TimestampSecondNTZType, + DayTimeIntervalType, DecimalType, DoubleType, FloatType, - ByteType, - UnsignedByteType, - ShortType, - UnsignedShortType, + HugeIntegerType, IntegerType, - UnsignedIntegerType, LongType, - UnsignedLongType, - HugeIntegerType, - UnsignedHugeIntegerType, - DayTimeIntervalType, - ArrayType, MapType, + ShortType, + StringType, StructField, StructType, + TimeNTZType, + TimestampMilisecondNTZType, + TimestampNanosecondNTZType, + TimestampNTZType, + TimestampSecondNTZType, + TimestampType, + TimeType, + UnsignedByteType, + UnsignedHugeIntegerType, + UnsignedIntegerType, + UnsignedLongType, + UnsignedShortType, + UUIDType, ) -class TestTypes(object): +class TestTypes: def test_all_types_schema(self, spark): # Create DataFrame df = spark.sql( diff --git a/tests/fast/spark/test_spark_udf.py b/tests/fast/spark/test_spark_udf.py index eebabbb3..cee0f256 100644 --- a/tests/fast/spark/test_spark_udf.py +++ b/tests/fast/spark/test_spark_udf.py @@ -3,7 +3,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestSparkUDF(object): +class TestSparkUDF: def test_udf_register(self, spark): def to_upper_fn(s: str) -> str: return s.upper() diff --git a/tests/fast/spark/test_spark_union.py b/tests/fast/spark/test_spark_union.py index 8a3ff9ce..588c7ecd 100644 --- a/tests/fast/spark/test_spark_union.py +++ b/tests/fast/spark/test_spark_union.py @@ -1,10 +1,11 @@ import platform + import pytest _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import Row from spark_namespace.sql.functions import col +from spark_namespace.sql.types import Row @pytest.fixture @@ -18,7 +19,7 @@ def df(spark): columns = ["employee_name", "department", "state", "salary", "age", "bonus"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe @pytest.fixture @@ -32,10 +33,10 @@ def df2(spark): ] columns2 = ["employee_name", "department", "state", "salary", "age", "bonus"] dataframe = spark.createDataFrame(data=simpleData2, schema=columns2) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_merge_with_union(self, df, df2): unionDF = df.union(df2) res = unionDF.collect() diff --git a/tests/fast/spark/test_spark_union_by_name.py b/tests/fast/spark/test_spark_union_by_name.py index 4739f0d8..bec539a2 100644 --- a/tests/fast/spark/test_spark_union_by_name.py +++ b/tests/fast/spark/test_spark_union_by_name.py @@ -4,36 +4,25 @@ from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture def df1(spark): data = [("James", 34), ("Michael", 56), ("Robert", 30), ("Maria", 24)] dataframe = spark.createDataFrame(data=data, schema=["name", "id"]) - yield dataframe + return dataframe @pytest.fixture def df2(spark): data2 = [(34, "James"), (45, "Maria"), (45, "Jen"), (34, "Jeff")] dataframe = spark.createDataFrame(data=data2, schema=["id", "name"]) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def 
test_union_by_name(self, df1, df2): rel = df1.unionByName(df2) res = rel.collect() diff --git a/tests/fast/spark/test_spark_with_column.py b/tests/fast/spark/test_spark_with_column.py index 2980e7fe..4ea62fe1 100644 --- a/tests/fast/spark/test_spark_with_column.py +++ b/tests/fast/spark/test_spark_with_column.py @@ -2,25 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, - ArrayType, - MapType, -) -from spark_namespace.sql.functions import col, struct, when, lit from spark_namespace import USE_ACTUAL_SPARK -import duckdb -import re +from spark_namespace.sql.functions import col, lit -class TestWithColumn(object): +class TestWithColumn: def test_with_column(self, spark): data = [ ("James", "", "Smith", "1991-04-01", "M", 3000), diff --git a/tests/fast/spark/test_spark_with_column_renamed.py b/tests/fast/spark/test_spark_with_column_renamed.py index 8534ab0b..789bf2c1 100644 --- a/tests/fast/spark/test_spark_with_column_renamed.py +++ b/tests/fast/spark/test_spark_with_column_renamed.py @@ -2,24 +2,17 @@ _ = pytest.importorskip("duckdb.experimental.spark") + +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, IntegerType, - LongType, - Row, - ArrayType, - MapType, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when, lit -import duckdb -import re -class TestWithColumnRenamed(object): +class TestWithColumnRenamed: def test_with_column_renamed(self, spark): dataDF = [ (("James", "", "Smith"), "1991-04-01", "M", 3000), @@ -28,7 +21,6 @@ def test_with_column_renamed(self, spark): (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType schema = StructType( [ diff --git a/tests/fast/spark/test_spark_with_columns.py b/tests/fast/spark/test_spark_with_columns.py index 535f357d..244d40a3 100644 --- a/tests/fast/spark/test_spark_with_columns.py +++ b/tests/fast/spark/test_spark_with_columns.py @@ -3,8 +3,8 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.functions import col, lit from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.sql.functions import col, lit class TestWithColumns: diff --git a/tests/fast/spark/test_spark_with_columns_renamed.py b/tests/fast/spark/test_spark_with_columns_renamed.py index 80b8b9e0..8c24062b 100644 --- a/tests/fast/spark/test_spark_with_columns_renamed.py +++ b/tests/fast/spark/test_spark_with_columns_renamed.py @@ -1,4 +1,5 @@ import re + import pytest _ = pytest.importorskip("duckdb.experimental.spark") @@ -6,7 +7,7 @@ from spark_namespace import USE_ACTUAL_SPARK -class TestWithColumnsRenamed(object): +class TestWithColumnsRenamed: def test_with_columns_renamed(self, spark): dataDF = [ (("James", "", "Smith"), "1991-04-01", "M", 3000), @@ -15,7 +16,7 @@ def test_with_columns_renamed(self, spark): (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType schema = StructType( [ diff --git a/tests/fast/sqlite/test_types.py 
b/tests/fast/sqlite/test_types.py index 3ffdceae..47c4b7e1 100644 --- a/tests/fast/sqlite/test_types.py +++ b/tests/fast/sqlite/test_types.py @@ -27,8 +27,8 @@ import datetime import decimal import unittest + import duckdb -import pytest class DuckDBTypeTests(unittest.TestCase): diff --git a/tests/fast/test_alex_multithread.py b/tests/fast/test_alex_multithread.py index bcb0181b..7e25b5bb 100644 --- a/tests/fast/test_alex_multithread.py +++ b/tests/fast/test_alex_multithread.py @@ -1,8 +1,9 @@ import platform -import duckdb from threading import Thread, current_thread + import pytest +import duckdb pytestmark = pytest.mark.xfail( condition=platform.system() == "Emscripten", @@ -30,7 +31,7 @@ def insert_from_same_connection(duckdb_cursor): duckdb_cursor.execute("""INSERT INTO my_inserts VALUES (?)""", (thread_name,)) -class TestPythonMultithreading(object): +class TestPythonMultithreading: def test_multiple_cursors(self, duckdb_cursor): duckdb_con = duckdb.connect() # In Memory DuckDB duckdb_con.execute("""CREATE OR REPLACE TABLE my_inserts (thread_name varchar)""") diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index 3e701ced..e74cca30 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -1,14 +1,16 @@ -import duckdb -import pandas as pd -import numpy as np import datetime import math +import warnings +from contextlib import suppress from decimal import Decimal from uuid import UUID -import pytz + +import numpy as np +import pandas as pd import pytest -import warnings -from contextlib import suppress +import pytz + +import duckdb def replace_with_ndarray(obj): @@ -25,7 +27,6 @@ def replace_with_ndarray(obj): # we need to write our own equality function that considers nan==nan for testing purposes def recursive_equality(o1, o2): - import math if type(o1) != type(o2): return False @@ -114,7 +115,7 @@ def recursive_equality(o1, o2): ] -class TestAllTypes(object): +class TestAllTypes: @pytest.mark.parametrize("cur_type", all_types) def test_fetchall(self, cur_type): conn = duckdb.connect() @@ -538,7 +539,7 @@ def test_fetchnumpy(self, cur_type): @pytest.mark.parametrize("cur_type", all_types) def test_arrow(self, cur_type): try: - import pyarrow as pa + pass except: return # We skip those since the extreme ranges are not supported in arrow. 
diff --git a/tests/fast/test_ambiguous_prepare.py b/tests/fast/test_ambiguous_prepare.py index 998367ec..0865b007 100644 --- a/tests/fast/test_ambiguous_prepare.py +++ b/tests/fast/test_ambiguous_prepare.py @@ -1,9 +1,8 @@ + import duckdb -import pandas as pd -import pytest -class TestAmbiguousPrepare(object): +class TestAmbiguousPrepare: def test_bool(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("select ?, ?, ?", (True, 42, [1, 2, 3])).fetchall() diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 2e42f0ed..5092f099 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -1,17 +1,12 @@ -import pandas -import numpy as np -import datetime -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestCaseAlias(object): +class TestCaseAlias: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): - import numpy as np - import datetime - import duckdb con = duckdb.connect(":memory:") diff --git a/tests/fast/test_context_manager.py b/tests/fast/test_context_manager.py index 65ec1d33..b6a9ebb2 100644 --- a/tests/fast/test_context_manager.py +++ b/tests/fast/test_context_manager.py @@ -1,7 +1,7 @@ import duckdb -class TestContextManager(object): +class TestContextManager: def test_context_manager(self, duckdb_cursor): with duckdb.connect(database=":memory:", read_only=False) as con: assert con.execute("select 1").fetchall() == [(1,)] diff --git a/tests/fast/test_duckdb_api.py b/tests/fast/test_duckdb_api.py index ea847d50..d779a368 100644 --- a/tests/fast/test_duckdb_api.py +++ b/tests/fast/test_duckdb_api.py @@ -1,6 +1,7 @@ -import duckdb import sys +import duckdb + def test_duckdb_api(): res = duckdb.execute("SELECT name, value FROM duckdb_settings() WHERE name == 'duckdb_api'") diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index 82753382..049a2a5c 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -1,19 +1,20 @@ +import datetime import platform -import duckdb + import pytest -from duckdb.typing import INTEGER, VARCHAR, TIMESTAMP + +import duckdb from duckdb import ( - Expression, - ConstantExpression, + CaseExpression, + CoalesceOperator, ColumnExpression, + ConstantExpression, + FunctionExpression, LambdaExpression, - CoalesceOperator, StarExpression, - FunctionExpression, - CaseExpression, ) -from duckdb.value.constant import Value, IntegerValue -import datetime +from duckdb.typing import INTEGER, TIMESTAMP, VARCHAR +from duckdb.value.constant import IntegerValue, Value pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", @@ -35,10 +36,10 @@ def filter_rel(): ) tbl(a, b) """ ) - yield rel + return rel -class TestExpression(object): +class TestExpression: def test_constant_expression(self): con = duckdb.connect() @@ -839,7 +840,7 @@ def test_filter_and(self, filter_rel): expr = ~expr # AND operator - expr = expr & ("b" != ConstantExpression("b")) + expr = expr & (ConstantExpression("b") != "b") rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index 7b8fbb05..3fd6d60d 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -1,19 +1,19 @@ import logging import sys -from pathlib import Path -from shutil import copyfileobj -from typing import Callable, List from os.path import 
exists -from pathlib import PurePosixPath +from pathlib import Path, PurePosixPath +from shutil import copyfileobj +from typing import Callable + +from pytest import MonkeyPatch, fixture, importorskip, mark, raises import duckdb from duckdb import DuckDBPyConnection, InvalidInputException -from pytest import raises, importorskip, fixture, MonkeyPatch, mark importorskip("fsspec", "2022.11.0") -from fsspec import filesystem, AbstractFileSystem -from fsspec.implementations.memory import MemoryFileSystem +from fsspec import AbstractFileSystem, filesystem from fsspec.implementations.local import LocalFileOpener, LocalFileSystem +from fsspec.implementations.memory import MemoryFileSystem FILENAME = "integers.csv" @@ -35,13 +35,13 @@ def ceptor(*args, **kwargs): return error_occurred -@fixture() +@fixture def duckdb_cursor(): with duckdb.connect() as conn: yield conn -@fixture() +@fixture def memory(): fs = filesystem("memory", skip_instance_cache=True) diff --git a/tests/fast/test_get_table_names.py b/tests/fast/test_get_table_names.py index 1f90e444..92fa1c39 100644 --- a/tests/fast/test_get_table_names.py +++ b/tests/fast/test_get_table_names.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + -class TestGetTableNames(object): +class TestGetTableNames: def test_table_success(self, duckdb_cursor): conn = duckdb.connect() table_names = conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") diff --git a/tests/fast/test_import_export.py b/tests/fast/test_import_export.py index d98a2d73..09b8cbda 100644 --- a/tests/fast/test_import_export.py +++ b/tests/fast/test_import_export.py @@ -1,10 +1,11 @@ -import duckdb -import pytest -from os import path import shutil -import os +from os import path from pathlib import Path +import pytest + +import duckdb + def export_database(export_location): # Create the db diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index baae75b4..34489b44 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -1,11 +1,11 @@ -import duckdb -import tempfile -import os + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestInsert(object): +class TestInsert: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_insert(self, pandas): test_df = pandas.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) diff --git a/tests/fast/test_json_logging.py b/tests/fast/test_json_logging.py index a7f305f3..b29ea7bf 100644 --- a/tests/fast/test_json_logging.py +++ b/tests/fast/test_json_logging.py @@ -1,8 +1,9 @@ import json -import duckdb import pytest +import duckdb + def _parse_json_func(error_prefix: str): """Helper to check that the error message is indeed parsable json""" diff --git a/tests/fast/test_many_con_same_file.py b/tests/fast/test_many_con_same_file.py index 3cef2494..79b5db68 100644 --- a/tests/fast/test_many_con_same_file.py +++ b/tests/fast/test_many_con_same_file.py @@ -1,7 +1,9 @@ -import duckdb import os + import pytest +import duckdb + def get_tables(con): tbls = con.execute("SHOW TABLES").fetchall() diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index f86dd60b..1ce63110 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -1,9 +1,10 @@ -import duckdb -import numpy -import pytest -from datetime import date, timedelta import re -from conftest import NumpyPandas, ArrowPandas +from datetime import date, timedelta + +import pytest +from conftest import ArrowPandas, NumpyPandas 
+ +import duckdb # column count differs from bind @@ -14,7 +15,7 @@ def evil1(df): return df -class TestMap(object): +class TestMap: @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_evil_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) diff --git a/tests/fast/test_metatransaction.py b/tests/fast/test_metatransaction.py index f617cba2..35d7c239 100644 --- a/tests/fast/test_metatransaction.py +++ b/tests/fast/test_metatransaction.py @@ -7,7 +7,7 @@ NUMBER_OF_COLUMNS = 1 -class TestMetaTransaction(object): +class TestMetaTransaction: def test_fetchmany(self, duckdb_cursor): duckdb_cursor.execute("CREATE SEQUENCE id_seq") column_names = ",\n".join([f"column_{i} FLOAT" for i in range(1, NUMBER_OF_COLUMNS + 1)]) diff --git a/tests/fast/test_multi_statement.py b/tests/fast/test_multi_statement.py index 722ab31a..cd3111e6 100644 --- a/tests/fast/test_multi_statement.py +++ b/tests/fast/test_multi_statement.py @@ -1,11 +1,11 @@ -import duckdb import os import shutil +import duckdb + -class TestMultiStatement(object): +class TestMultiStatement: def test_multi_statement(self, duckdb_cursor): - import duckdb con = duckdb.connect(":memory:") diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index 628aacd8..aeeeb412 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -1,13 +1,13 @@ +import os import platform -import duckdb -import pytest -import threading import queue as Queue +import threading + import numpy as np -from conftest import NumpyPandas, ArrowPandas -import os -from typing import List +import pytest +from conftest import ArrowPandas, NumpyPandas +import duckdb pytestmark = pytest.mark.xfail( condition=platform.system() == "Emscripten", @@ -36,7 +36,7 @@ def multithread_test(self, result_verification=everything_succeeded): queue = Queue.Queue() # Create all threads - for i in range(0, self.duckdb_insert_thread_count): + for i in range(self.duckdb_insert_thread_count): self.threads.append( threading.Thread( target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name="duckdb_thread_" + str(i) @@ -45,13 +45,13 @@ def multithread_test(self, result_verification=everything_succeeded): # Record for every thread if they succeeded or not thread_results = [] - for i in range(0, len(self.threads)): + for i in range(len(self.threads)): self.threads[i].start() thread_result: bool = queue.get(timeout=60) thread_results.append(thread_result) # Finish all threads - for i in range(0, len(self.threads)): + for i in range(len(self.threads)): self.threads[i].join() # Assert that the results are what we expected @@ -374,7 +374,7 @@ def cursor(duckdb_conn, queue, pandas): queue.put(True) -class TestDuckMultithread(object): +class TestDuckMultithread: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_query, pandas) diff --git a/tests/fast/test_non_default_conn.py b/tests/fast/test_non_default_conn.py index cb0218e3..06cd5fe5 100644 --- a/tests/fast/test_non_default_conn.py +++ b/tests/fast/test_non_default_conn.py @@ -1,11 +1,12 @@ -import pandas as pd -import numpy as np -import duckdb import os import tempfile +import pandas as pd + +import duckdb + -class TestNonDefaultConn(object): +class TestNonDefaultConn: def test_values(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb.values([1], connection=duckdb_cursor).insert_into("t") diff --git a/tests/fast/test_parameter_list.py 
b/tests/fast/test_parameter_list.py index 5a85ac2f..a28838ba 100644 --- a/tests/fast/test_parameter_list.py +++ b/tests/fast/test_parameter_list.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestParameterList(object): +class TestParameterList: def test_bool(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table bool_table (a bool)") diff --git a/tests/fast/test_parquet.py b/tests/fast/test_parquet.py index 61d74023..fd506da2 100644 --- a/tests/fast/test_parquet.py +++ b/tests/fast/test_parquet.py @@ -1,8 +1,8 @@ -import duckdb -import pytest import os -import tempfile -import pandas as pd + +import pytest + +import duckdb VARCHAR = duckdb.typing.VARCHAR BIGINT = duckdb.typing.BIGINT @@ -17,7 +17,7 @@ def tmp_parquets(tmp_path_factory): return tmp_parquets -class TestParquet(object): +class TestParquet: def test_scan_binary(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index 84d4c9ff..0e0439ce 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Unit tests for pypi_cleanup.py +"""Unit tests for pypi_cleanup.py Run with: python -m pytest test_pypi_cleanup.py -v """ @@ -15,18 +14,18 @@ duckdb_packaging = pytest.importorskip("duckdb_packaging") from duckdb_packaging.pypi_cleanup import ( - PyPICleanup, + AuthenticationError, CsrfParser, + PyPICleanup, PyPICleanupError, - AuthenticationError, ValidationError, - setup_logging, - validate_username, create_argument_parser, - session_with_retries, load_credentials, - validate_arguments, main, + session_with_retries, + setup_logging, + validate_arguments, + validate_username, ) @@ -116,7 +115,7 @@ def test_create_session_with_retries(self): # Verify retry adapter is mounted adapter = session.get_adapter("https://example.com") assert hasattr(adapter, "max_retries") - retries = getattr(adapter, "max_retries") + retries = adapter.max_retries assert isinstance(retries, Retry) @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") diff --git a/tests/fast/test_pytorch.py b/tests/fast/test_pytorch.py index c5b9b4d6..c0b9392d 100644 --- a/tests/fast/test_pytorch.py +++ b/tests/fast/test_pytorch.py @@ -1,6 +1,6 @@ -import duckdb import pytest +import duckdb torch = pytest.importorskip("torch") diff --git a/tests/fast/test_relation.py b/tests/fast/test_relation.py index 31ca393c..6628198f 100644 --- a/tests/fast/test_relation.py +++ b/tests/fast/test_relation.py @@ -1,16 +1,17 @@ -import duckdb -import numpy as np +import datetime +import gc +import os import platform import tempfile -import os + +import numpy as np import pandas as pd import pytest from conftest import ArrowPandas, NumpyPandas -import datetime -import gc -from duckdb import ColumnExpression -from duckdb.typing import BIGINT, VARCHAR, TINYINT, BOOLEAN +import duckdb +from duckdb import ColumnExpression +from duckdb.typing import BIGINT, BOOLEAN, TINYINT, VARCHAR @pytest.fixture(scope="session") @@ -25,7 +26,7 @@ def get_relation(conn): return conn.from_df(test_df) -class TestRelation(object): +class TestRelation: def test_csv_auto(self): conn = duckdb.connect() df_rel = get_relation(conn) diff --git a/tests/fast/test_relation_dependency_leak.py b/tests/fast/test_relation_dependency_leak.py index 
ee98e30a..73ea7df7 100644 --- a/tests/fast/test_relation_dependency_leak.py +++ b/tests/fast/test_relation_dependency_leak.py @@ -1,5 +1,6 @@ -import numpy as np import os + +import numpy as np import pytest try: @@ -8,8 +9,7 @@ can_run = True except ImportError: can_run = False -from conftest import NumpyPandas, ArrowPandas - +from conftest import ArrowPandas, NumpyPandas psutil = pytest.importorskip("psutil") @@ -46,7 +46,7 @@ def pandas_replacement(pandas, duckdb_cursor): duckdb_cursor.query("select sum(x) from df").fetchall() -class TestRelationDependencyMemoryLeak(object): +class TestRelationDependencyMemoryLeak: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow_leak(self, pandas, duckdb_cursor): if not can_run: diff --git a/tests/fast/test_replacement_scan.py b/tests/fast/test_replacement_scan.py index 555773dc..c9d9ae3a 100644 --- a/tests/fast/test_replacement_scan.py +++ b/tests/fast/test_replacement_scan.py @@ -1,7 +1,9 @@ -import duckdb import os + import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pl = pytest.importorskip("polars") pd = pytest.importorskip("pandas") @@ -9,7 +11,7 @@ def using_table(con, to_scan, object_name): local_scope = {"con": con, object_name: to_scan, "object_name": object_name} - exec(f"result = con.table(object_name)", globals(), local_scope) + exec("result = con.table(object_name)", globals(), local_scope) return local_scope["result"] @@ -75,7 +77,7 @@ def create_relation(conn, query: str) -> duckdb.DuckDBPyRelation: return conn.sql(query) -class TestReplacementScan(object): +class TestReplacementScan: def test_csv_replacement(self): con = duckdb.connect() filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "integers.csv") diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index 906b1198..38ae1de6 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -1,9 +1,11 @@ -import duckdb -import pytest import datetime +import pytest + +import duckdb + -class TestPythonResult(object): +class TestPythonResult: def test_result_closed(self, duckdb_cursor): connection = duckdb.connect("") cursor = connection.cursor() diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 327be004..7ab160bb 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -1,12 +1,13 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb closed = lambda: pytest.raises(duckdb.ConnectionException, match="Connection already closed") no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match="No open result set") -class TestRuntimeError(object): +class TestRuntimeError: def test_fetch_error(self): con = duckdb.connect() con.execute("create table tbl as select 'hello' i") diff --git a/tests/fast/test_sql_expression.py b/tests/fast/test_sql_expression.py index 4dc4cab5..f3cf41ca 100644 --- a/tests/fast/test_sql_expression.py +++ b/tests/fast/test_sql_expression.py @@ -1,5 +1,6 @@ -import duckdb import pytest + +import duckdb from duckdb import ( ColumnExpression, ConstantExpression, @@ -7,7 +8,7 @@ ) -class TestSQLExpression(object): +class TestSQLExpression: def test_sql_expression_basic(self, duckdb_cursor): # Test simple constant expressions expr = SQLExpression("42") diff --git a/tests/fast/test_string_annotation.py b/tests/fast/test_string_annotation.py index 83685bed..17c22844 100644 --- 
a/tests/fast/test_string_annotation.py +++ b/tests/fast/test_string_annotation.py @@ -1,7 +1,6 @@ -import duckdb -import pytest import sys -from typing import Union + +import pytest def make_annotated_function(type: str): @@ -19,7 +18,6 @@ def test_base(): def python_version_lower_than_3_10(): - import sys if sys.version_info[0] < 3: return True @@ -28,7 +26,7 @@ def python_version_lower_than_3_10(): return False -class TestStringAnnotation(object): +class TestStringAnnotation: @pytest.mark.skipif( python_version_lower_than_3_10(), reason="inspect.signature(eval_str=True) only supported since 3.10 and higher" ) diff --git a/tests/fast/test_tf.py b/tests/fast/test_tf.py index db93d0de..ceec2ee0 100644 --- a/tests/fast/test_tf.py +++ b/tests/fast/test_tf.py @@ -1,6 +1,6 @@ -import duckdb import pytest +import duckdb tf = pytest.importorskip("tensorflow") diff --git a/tests/fast/test_transaction.py b/tests/fast/test_transaction.py index ff0ba1a7..4a06c9e7 100644 --- a/tests/fast/test_transaction.py +++ b/tests/fast/test_transaction.py @@ -1,8 +1,8 @@ + import duckdb -import pandas as pd -class TestConnectionTransaction(object): +class TestConnectionTransaction: def test_transaction(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i integer)") diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index 1e8ebc25..768b7782 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -1,47 +1,46 @@ -import duckdb -import os -import pandas as pd -import pytest -from typing import Union, Optional import sys +from typing import Optional, Union + +import pytest +import duckdb +import duckdb.typing from duckdb.typing import ( - SQLNULL, - BOOLEAN, - TINYINT, - UTINYINT, - SMALLINT, - USMALLINT, - INTEGER, - UINTEGER, BIGINT, - UBIGINT, - HUGEINT, - UHUGEINT, - UUID, - FLOAT, - DOUBLE, + BIT, + BLOB, + BOOLEAN, DATE, + DOUBLE, + FLOAT, + HUGEINT, + INTEGER, + INTERVAL, + SMALLINT, + SQLNULL, + TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, - DuckDBPyType, - TIME, - TIME_TZ, TIMESTAMP_TZ, + TINYINT, + UBIGINT, + UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, + UUID, VARCHAR, - BLOB, - BIT, - INTERVAL, + DuckDBPyType, ) -import duckdb.typing -class TestType(object): +class TestType: def test_sqltype(self): assert str(duckdb.sqltype("struct(a VARCHAR, b BIGINT)")) == "STRUCT(a VARCHAR, b BIGINT)" - # todo: add tests with invalid type_str + # TODO: add tests with invalid type_str def test_primitive_types(self): assert str(SQLNULL) == '"NULL"' @@ -118,7 +117,6 @@ def test_union_type(self): type = duckdb.union_type({"a": BIGINT, "b": VARCHAR, "c": TINYINT}) assert str(type) == "UNION(a BIGINT, b VARCHAR, c TINYINT)" - import sys @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires >= python3.9") def test_implicit_convert_from_builtin_type(self): diff --git a/tests/fast/test_type_explicit.py b/tests/fast/test_type_explicit.py index 7b0797e6..3b9fe334 100644 --- a/tests/fast/test_type_explicit.py +++ b/tests/fast/test_type_explicit.py @@ -1,7 +1,7 @@ import duckdb -class TestMap(object): +class TestMap: def test_array_list_tuple_ambiguity(self): con = duckdb.connect() res = con.sql("SELECT $arg", params={"arg": (1, 2)}).fetchall()[0][0] diff --git a/tests/fast/test_unicode.py b/tests/fast/test_unicode.py index 7d08ac88..f1ed8501 100644 --- a/tests/fast/test_unicode.py +++ b/tests/fast/test_unicode.py @@ -1,11 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb import pandas as pd +import duckdb + -class 
TestUnicode(object): +class TestUnicode: def test_unicode_pandas_scan(self, duckdb_cursor): con = duckdb.connect(database=":memory:", read_only=False) test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", "ë"]}) diff --git a/tests/fast/test_union.py b/tests/fast/test_union.py index 912caff9..d47a8192 100644 --- a/tests/fast/test_union.py +++ b/tests/fast/test_union.py @@ -1,8 +1,8 @@ + import duckdb -import pandas as pd -class TestUnion(object): +class TestUnion: def test_union_by_all(self): connection = duckdb.connect() diff --git a/tests/fast/test_value.py b/tests/fast/test_value.py index c17264fd..9e446fc3 100644 --- a/tests/fast/test_value.py +++ b/tests/fast/test_value.py @@ -1,74 +1,68 @@ -import duckdb -from pytest import raises -from duckdb import NotImplementedException, InvalidInputException -from duckdb.value.constant import ( - Value, - NullValue, - BooleanValue, - UnsignedBinaryValue, - UnsignedShortValue, - UnsignedIntegerValue, - UnsignedLongValue, - BinaryValue, - ShortValue, - IntegerValue, - LongValue, - HugeIntegerValue, - UnsignedHugeIntegerValue, - FloatValue, - DoubleValue, - DecimalValue, - StringValue, - UUIDValue, - BitValue, - BlobValue, - DateValue, - IntervalValue, - TimestampValue, - TimestampSecondValue, - TimestampMilisecondValue, - TimestampNanosecondValue, - TimestampTimeZoneValue, - TimeValue, - TimeTimeZoneValue, -) -import uuid import datetime -import pytest import decimal +import uuid + +import pytest +from pytest import raises +import duckdb +from duckdb import InvalidInputException, NotImplementedException from duckdb.typing import ( - SQLNULL, + BIGINT, + BIT, + BLOB, BOOLEAN, - TINYINT, - UTINYINT, - SMALLINT, - USMALLINT, + DATE, + DOUBLE, + FLOAT, + HUGEINT, INTEGER, - UINTEGER, - BIGINT, + INTERVAL, + SMALLINT, + SQLNULL, + TIME, + TIMESTAMP, + TINYINT, UBIGINT, - HUGEINT, UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, UUID, - FLOAT, - DOUBLE, - DATE, - TIMESTAMP, - TIMESTAMP_MS, - TIMESTAMP_NS, - TIMESTAMP_S, - TIME, - TIME_TZ, - TIMESTAMP_TZ, VARCHAR, - BLOB, - BIT, - INTERVAL, +) +from duckdb.value.constant import ( + BinaryValue, + BitValue, + BlobValue, + BooleanValue, + DateValue, + DecimalValue, + DoubleValue, + FloatValue, + HugeIntegerValue, + IntegerValue, + IntervalValue, + LongValue, + NullValue, + ShortValue, + StringValue, + TimestampMilisecondValue, + TimestampNanosecondValue, + TimestampSecondValue, + TimestampValue, + TimeValue, + UnsignedBinaryValue, + UnsignedHugeIntegerValue, + UnsignedIntegerValue, + UnsignedLongValue, + UnsignedShortValue, + UUIDValue, + Value, ) -class TestValue(object): +class TestValue: # This excludes timezone aware values, as those are a pain to test @pytest.mark.parametrize( "item", diff --git a/tests/fast/test_version.py b/tests/fast/test_version.py index cdeb42b0..81f72855 100644 --- a/tests/fast/test_version.py +++ b/tests/fast/test_version.py @@ -1,6 +1,7 @@ -import duckdb import sys +import duckdb + def test_version(): assert duckdb.__version__ != "0.0.0" diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 2ec3f784..207b24fe 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -1,25 +1,24 @@ -""" -Tests for duckdb_pytooling versioning functionality. +"""Tests for duckdb_pytooling versioning functionality. 
""" import os +import subprocess import unittest +from unittest.mock import MagicMock, patch import pytest -import subprocess -from unittest.mock import patch, MagicMock duckdb_packaging = pytest.importorskip("duckdb_packaging") from duckdb_packaging._versioning import ( - parse_version, format_version, - git_tag_to_pep440, - pep440_to_git_tag, get_current_version, get_git_describe, + git_tag_to_pep440, + parse_version, + pep440_to_git_tag, ) -from duckdb_packaging.setuptools_scm_version import _bump_version, version_scheme, forced_version_from_env +from duckdb_packaging.setuptools_scm_version import _bump_version, forced_version_from_env, version_scheme class TestVersionParsing(unittest.TestCase): diff --git a/tests/fast/test_windows_abs_path.py b/tests/fast/test_windows_abs_path.py index 4ce8311b..7cc31d0b 100644 --- a/tests/fast/test_windows_abs_path.py +++ b/tests/fast/test_windows_abs_path.py @@ -1,10 +1,10 @@ -import duckdb -import pytest import os import shutil +import duckdb + -class TestWindowsAbsPath(object): +class TestWindowsAbsPath: def test_windows_path_accent(self): if os.name != "nt": return diff --git a/tests/fast/types/test_blob.py b/tests/fast/types/test_blob.py index 0d331f7f..74f7f0b8 100644 --- a/tests/fast/types/test_blob.py +++ b/tests/fast/types/test_blob.py @@ -1,8 +1,7 @@ -import duckdb import numpy -class TestBlob(object): +class TestBlob: def test_blob(self, duckdb_cursor): duckdb_cursor.execute("SELECT BLOB 'hello'") results = duckdb_cursor.fetchall() diff --git a/tests/fast/types/test_boolean.py b/tests/fast/types/test_boolean.py index 8e8d2147..5a519e51 100644 --- a/tests/fast/types/test_boolean.py +++ b/tests/fast/types/test_boolean.py @@ -1,8 +1,6 @@ -import duckdb -import numpy -class TestBoolean(object): +class TestBoolean: def test_bool(self, duckdb_cursor): duckdb_cursor.execute("SELECT TRUE") results = duckdb_cursor.fetchall() diff --git a/tests/fast/types/test_datetime_date.py b/tests/fast/types/test_datetime_date.py index 9efb6bd1..d1c3d30b 100644 --- a/tests/fast/types/test_datetime_date.py +++ b/tests/fast/types/test_datetime_date.py @@ -1,8 +1,9 @@ -import duckdb import datetime +import duckdb + -class TestDateTimeDate(object): +class TestDateTimeDate: def test_date_infinity(self): con = duckdb.connect() # Positive infinity diff --git a/tests/fast/types/test_datetime_datetime.py b/tests/fast/types/test_datetime_datetime.py index 2df14b18..c486f9c9 100644 --- a/tests/fast/types/test_datetime_datetime.py +++ b/tests/fast/types/test_datetime_datetime.py @@ -1,7 +1,9 @@ -import duckdb import datetime + import pytest +import duckdb + def create_query(positive, type): inf = "infinity" if positive else "-infinity" @@ -10,7 +12,7 @@ def create_query(positive, type): """ -class TestDateTimeDateTime(object): +class TestDateTimeDateTime: @pytest.mark.parametrize("positive", [True, False]) @pytest.mark.parametrize( "type", diff --git a/tests/fast/types/test_decimal.py b/tests/fast/types/test_decimal.py index b068056d..8be55e44 100644 --- a/tests/fast/types/test_decimal.py +++ b/tests/fast/types/test_decimal.py @@ -1,9 +1,9 @@ -import numpy -import pandas from decimal import * +import numpy + -class TestDecimal(object): +class TestDecimal: def test_decimal(self, duckdb_cursor): duckdb_cursor.execute( "SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL" diff --git a/tests/fast/types/test_hugeint.py b/tests/fast/types/test_hugeint.py index e9b5016a..aa8c900d 100644 --- 
a/tests/fast/types/test_hugeint.py +++ b/tests/fast/types/test_hugeint.py @@ -1,8 +1,7 @@ import numpy -import pandas -class TestHugeint(object): +class TestHugeint: def test_hugeint(self, duckdb_cursor): duckdb_cursor.execute("SELECT 437894723897234238947043214") result = duckdb_cursor.fetchall() diff --git a/tests/fast/types/test_nan.py b/tests/fast/types/test_nan.py index fe99a990..8ffbe1bc 100644 --- a/tests/fast/types/test_nan.py +++ b/tests/fast/types/test_nan.py @@ -1,12 +1,14 @@ -import numpy as np import datetime -import duckdb + +import numpy as np import pytest +import duckdb + pandas = pytest.importorskip("pandas") -class TestPandasNaN(object): +class TestPandasNaN: def test_pandas_nan(self, duckdb_cursor): # create a DataFrame with some basic values df = pandas.DataFrame([{"col1": "val1", "col2": 1.05}, {"col1": "val3", "col2": np.nan}]) diff --git a/tests/fast/types/test_nested.py b/tests/fast/types/test_nested.py index 7f777384..e82673c7 100644 --- a/tests/fast/types/test_nested.py +++ b/tests/fast/types/test_nested.py @@ -1,7 +1,6 @@ -import duckdb -class TestNested(object): +class TestNested: def test_lists(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT LIST_VALUE(1, 2, 3, 4) ").fetchall() assert result == [([1, 2, 3, 4],)] diff --git a/tests/fast/types/test_null.py b/tests/fast/types/test_null.py index fa4105b6..e5fe2e3d 100644 --- a/tests/fast/types/test_null.py +++ b/tests/fast/types/test_null.py @@ -1,7 +1,6 @@ -import traceback -class TestNull(object): +class TestNull: def test_fetchone_null(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE atable (Value int)") duckdb_cursor.execute("INSERT INTO atable VALUES (1)") diff --git a/tests/fast/types/test_numeric.py b/tests/fast/types/test_numeric.py index f25b72b1..174700aa 100644 --- a/tests/fast/types/test_numeric.py +++ b/tests/fast/types/test_numeric.py @@ -1,5 +1,3 @@ -import duckdb -import numpy def check_result(duckdb_cursor, value, type): @@ -8,7 +6,7 @@ def check_result(duckdb_cursor, value, type): assert results[0][0] == value -class TestNumeric(object): +class TestNumeric: def test_numeric_results(self, duckdb_cursor): check_result(duckdb_cursor, 1, "TINYINT") check_result(duckdb_cursor, 1, "SMALLINT") diff --git a/tests/fast/types/test_numpy.py b/tests/fast/types/test_numpy.py index 40b1a5de..b5fe6b3c 100644 --- a/tests/fast/types/test_numpy.py +++ b/tests/fast/types/test_numpy.py @@ -1,10 +1,11 @@ -import duckdb -import numpy as np import datetime -import pytest + +import numpy as np + +import duckdb -class TestNumpyDatetime64(object): +class TestNumpyDatetime64: def test_numpy_datetime64(self, duckdb_cursor): duckdb_con = duckdb.connect() diff --git a/tests/fast/types/test_object_int.py b/tests/fast/types/test_object_int.py index ed3a8d14..f0665535 100644 --- a/tests/fast/types/test_object_int.py +++ b/tests/fast/types/test_object_int.py @@ -1,12 +1,13 @@ -import numpy as np -import datetime -import duckdb -import pytest import warnings from contextlib import suppress +import numpy as np +import pytest + +import duckdb + -class TestPandasObjectInteger(object): +class TestPandasObjectInteger: # Signed Masked Integer types def test_object_integer(self, duckdb_cursor): pd = pytest.importorskip("pandas") diff --git a/tests/fast/types/test_time_tz.py b/tests/fast/types/test_time_tz.py index eceed79a..2215a046 100644 --- a/tests/fast/types/test_time_tz.py +++ b/tests/fast/types/test_time_tz.py @@ -1,17 +1,16 @@ -import numpy as np +import datetime from datetime import time, timezone -import 
duckdb + import pytest -import datetime pandas = pytest.importorskip("pandas") -class TestTimeTz(object): +class TestTimeTz: def test_time_tz(self, duckdb_cursor): df = pandas.DataFrame({"col1": [time(1, 2, 3, tzinfo=timezone.utc)]}) - sql = f"SELECT * FROM df" + sql = "SELECT * FROM df" duckdb_cursor.execute(sql) diff --git a/tests/fast/types/test_unsigned.py b/tests/fast/types/test_unsigned.py index a35a2216..5639d33b 100644 --- a/tests/fast/types/test_unsigned.py +++ b/tests/fast/types/test_unsigned.py @@ -1,4 +1,4 @@ -class TestUnsigned(object): +class TestUnsigned: def test_unsigned(self, duckdb_cursor): duckdb_cursor.execute("create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)") duckdb_cursor.execute("insert into unsigned values (1,1,1,1), (null,null,null,null)") diff --git a/tests/fast/udf/test_null_filtering.py b/tests/fast/udf/test_null_filtering.py index fd5b45d0..db86168c 100644 --- a/tests/fast/udf/test_null_filtering.py +++ b/tests/fast/udf/test_null_filtering.py @@ -1,15 +1,12 @@ -import duckdb import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow", "18.0.0") -from typing import Union -import pyarrow.compute as pc -import uuid import datetime -import numpy as np -import cmath -from typing import NamedTuple, Any, List +import uuid +from typing import Any, NamedTuple from duckdb.typing import * @@ -152,7 +149,7 @@ def construct_parameters(tuples, dbtype): return parameters -class TestUDFNullFiltering(object): +class TestUDFNullFiltering: @pytest.mark.parametrize( "table_data", get_table_data(), diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index d03fd7e6..c909c61d 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -1,20 +1,15 @@ -import duckdb -import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -from typing import Union -import pyarrow.compute as pc -import uuid -import datetime -import numpy as np -import cmath from duckdb.typing import * -class TestRemoveFunction(object): +class TestRemoveFunction: def test_not_created(self): con = duckdb.connect() with pytest.raises( diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index c156f94b..e8b1e6d9 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -1,15 +1,16 @@ -import duckdb -import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow", "18.0.0") -from typing import Union, Any -import pyarrow.compute as pc -import uuid +import cmath import datetime +import uuid +from typing import Any + import numpy as np -import cmath from duckdb.typing import * @@ -29,7 +30,7 @@ def test_base(x): return test_function -class TestScalarUDF(object): +class TestScalarUDF: @pytest.mark.parametrize("function_type", ["native", "arrow"]) @pytest.mark.parametrize( "test_type", @@ -69,16 +70,16 @@ def test_type_coverage(self, test_type, function_type): con = duckdb.connect() con.create_function("test", test_function, type=function_type) # Single value - res = con.execute(f"select test(?::{str(type)})", [value]).fetchall() + res = con.execute(f"select test(?::{type!s})", [value]).fetchall() assert res[0][0] == value # NULLs - res = con.execute(f"select res from (select ?, test(NULL::{str(type)}) as res)", [value]).fetchall() + res = con.execute(f"select res from (select ?, test(NULL::{type!s}) as res)", [value]).fetchall() assert 
res[0][0] == None # Multiple chunks size = duckdb.__standard_vector_size__ * 3 - res = con.execute(f"select test(x) from repeat(?::{str(type)}, {size}) as tbl(x)", [value]).fetchall() + res = con.execute(f"select test(x) from repeat(?::{type!s}, {size}) as tbl(x)", [value]).fetchall() assert len(res) == size # Mixed NULL/NON-NULL @@ -88,7 +89,7 @@ def test_type_coverage(self, test_type, function_type): f""" select test( case when (x > 0.5) then - ?::{str(type)} + ?::{type!s} else NULL end @@ -102,7 +103,7 @@ def test_type_coverage(self, test_type, function_type): f""" select case when (x > 0.5) then - ?::{str(type)} + ?::{type!s} else NULL end @@ -113,7 +114,7 @@ def test_type_coverage(self, test_type, function_type): assert expected == actual # Using 'relation.project' - con.execute(f"create table tbl as select ?::{str(type)} as x", [value]) + con.execute(f"create table tbl as select ?::{type!s} as x", [value]) table_rel = con.table("tbl") res = table_rel.project("test(x)").fetchall() assert res[0][0] == value @@ -221,7 +222,6 @@ def return_np_nan(): @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): - import cmath if udf_type == "native": return cmath.nan diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 794ebc35..856c760d 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -1,18 +1,15 @@ -import duckdb -import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -from typing import Union -import pyarrow.compute as pc -import uuid -import datetime from duckdb.typing import * -class TestPyArrowUDF(object): +class TestPyArrowUDF: def test_basic_use(self): def plus_one(x): table = pa.lib.Table.from_arrays([x], names=["c0"]) @@ -24,7 +21,7 @@ def plus_one(x): con = duckdb.connect() con.create_function("plus_one", plus_one, [BIGINT], BIGINT, type="arrow") - assert [(6,)] == con.sql("select plus_one(5)").fetchall() + assert con.sql("select plus_one(5)").fetchall() == [(6,)] range_table = con.table_function("range", [5000]) res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() @@ -125,7 +122,6 @@ def return_too_many(col): res = con.sql("""select too_many_tuples(5)""").fetchall() def test_arrow_side_effects(self, duckdb_cursor): - import random as r def random_arrow(x): if not hasattr(random_arrow, "data"): diff --git a/tests/fast/udf/test_scalar_native.py b/tests/fast/udf/test_scalar_native.py index 0c5cf927..94b2949e 100644 --- a/tests/fast/udf/test_scalar_native.py +++ b/tests/fast/udf/test_scalar_native.py @@ -1,12 +1,11 @@ -import duckdb -import os -import pandas as pd + import pytest +import duckdb from duckdb.typing import * -class TestNativeUDF(object): +class TestNativeUDF: def test_default_conn(self): def passthrough(x): return x @@ -23,7 +22,7 @@ def plus_one(x): con = duckdb.connect() con.create_function("plus_one", plus_one, [BIGINT], BIGINT) - assert [(6,)] == con.sql("select plus_one(5)").fetchall() + assert con.sql("select plus_one(5)").fetchall() == [(6,)] range_table = con.table_function("range", [5000]) res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() diff --git a/tests/fast/udf/test_transactionality.py b/tests/fast/udf/test_transactionality.py index 134df663..acad21ef 100644 --- a/tests/fast/udf/test_transactionality.py +++ b/tests/fast/udf/test_transactionality.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + 
-class TestUDFTransactionality(object): +class TestUDFTransactionality: @pytest.mark.xfail(reason="fetchone() does not realize the stream result was closed before completion") def test_type_coverage(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from range(4096)") diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py index b0901ab8..d0dbc2fe 100644 --- a/tests/slow/test_h2oai_arrow.py +++ b/tests/slow/test_h2oai_arrow.py @@ -1,7 +1,9 @@ -import duckdb -import os import math -from pytest import mark, fixture, importorskip +import os + +from pytest import fixture, importorskip, mark + +import duckdb read_csv = importorskip("pyarrow.csv").read_csv requests = importorskip("requests") @@ -153,7 +155,7 @@ def join_by_q5(con): con.execute("DROP TABLE ans") -class TestH2OAIArrow(object): +class TestH2OAIArrow: @mark.parametrize( "function", [ From c499e404228b6c7e52ecd9b1cfb893f466a6a0c3 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:29:34 +0200 Subject: [PATCH 013/135] Ruff format fixes --- adbc_driver_duckdb/dbapi.py | 3 +- duckdb/experimental/spark/errors/__init__.py | 3 +- .../spark/errors/exceptions/base.py | 72 +++++++------------ duckdb/experimental/spark/errors/utils.py | 6 +- duckdb/experimental/spark/sql/dataframe.py | 3 +- duckdb/experimental/spark/sql/streaming.py | 1 - duckdb/experimental/spark/sql/types.py | 21 ++---- duckdb/polars_io.py | 3 +- duckdb/udf.py | 3 +- duckdb_packaging/pypi_cleanup.py | 3 - sqllogic/conftest.py | 3 +- tests/conftest.py | 1 - tests/fast/api/test_attribute_getter.py | 1 - tests/fast/api/test_dbapi12.py | 1 - tests/fast/arrow/test_2426.py | 2 - tests/fast/arrow/test_arrow_fetch.py | 2 - tests/fast/arrow/test_arrow_pycapsule.py | 1 - tests/fast/arrow/test_binary_type.py | 1 - tests/fast/arrow/test_date.py | 2 - tests/fast/arrow/test_dictionary_arrow.py | 1 - tests/fast/arrow/test_interval.py | 1 - tests/fast/arrow/test_large_string.py | 1 - tests/fast/arrow/test_time.py | 2 - tests/fast/pandas/test_df_analyze.py | 1 - tests/fast/pandas/test_df_recursive_nested.py | 1 - tests/fast/pandas/test_fetch_nested.py | 1 - tests/fast/pandas/test_pandas_limit.py | 1 - .../pandas/test_partitioned_pandas_scan.py | 1 - tests/fast/pandas/test_progress_bar.py | 1 - .../test_pyarrow_projection_pushdown.py | 1 - .../relational_api/test_rapi_aggregations.py | 1 - tests/fast/spark/test_spark_functions_hex.py | 1 - tests/fast/test_all_types.py | 1 - tests/fast/test_ambiguous_prepare.py | 1 - tests/fast/test_case_alias.py | 1 - tests/fast/test_insert.py | 1 - tests/fast/test_multi_statement.py | 1 - tests/fast/test_string_annotation.py | 1 - tests/fast/test_transaction.py | 1 - tests/fast/test_type.py | 1 - tests/fast/test_union.py | 1 - tests/fast/test_versioning.py | 3 +- tests/fast/types/test_boolean.py | 2 - tests/fast/types/test_nested.py | 2 - tests/fast/types/test_null.py | 2 - tests/fast/types/test_numeric.py | 2 - tests/fast/udf/test_remove_function.py | 1 - tests/fast/udf/test_scalar.py | 2 - tests/fast/udf/test_scalar_arrow.py | 2 - tests/fast/udf/test_scalar_native.py | 1 - 50 files changed, 40 insertions(+), 132 deletions(-) diff --git a/adbc_driver_duckdb/dbapi.py b/adbc_driver_duckdb/dbapi.py index 7d703713..5d0a8702 100644 --- a/adbc_driver_duckdb/dbapi.py +++ b/adbc_driver_duckdb/dbapi.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""DBAPI 2.0-compatible facade for the ADBC DuckDB driver. 
-""" +"""DBAPI 2.0-compatible facade for the ADBC DuckDB driver.""" import typing diff --git a/duckdb/experimental/spark/errors/__init__.py b/duckdb/experimental/spark/errors/__init__.py index 2f265d97..ee7688ea 100644 --- a/duckdb/experimental/spark/errors/__init__.py +++ b/duckdb/experimental/spark/errors/__init__.py @@ -15,8 +15,7 @@ # limitations under the License. # -"""PySpark exceptions. -""" +"""PySpark exceptions.""" from .exceptions.base import ( AnalysisException, diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index a6f1f940..0b2c6a43 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -4,8 +4,7 @@ class PySparkException(Exception): - """Base Exception for handling errors generated from PySpark. - """ + """Base Exception for handling errors generated from PySpark.""" def __init__( self, @@ -78,115 +77,92 @@ def __str__(self) -> str: class AnalysisException(PySparkException): - """Failed to analyze a SQL query plan. - """ + """Failed to analyze a SQL query plan.""" class SessionNotSameException(PySparkException): - """Performed the same operation on different SparkSession. - """ + """Performed the same operation on different SparkSession.""" class TempTableAlreadyExistsException(AnalysisException): - """Failed to create temp view since it is already exists. - """ + """Failed to create temp view since it is already exists.""" class ParseException(AnalysisException): - """Failed to parse a SQL command. - """ + """Failed to parse a SQL command.""" class IllegalArgumentException(PySparkException): - """Passed an illegal or inappropriate argument. - """ + """Passed an illegal or inappropriate argument.""" class ArithmeticException(PySparkException): - """Arithmetic exception thrown from Spark with an error class. - """ + """Arithmetic exception thrown from Spark with an error class.""" class UnsupportedOperationException(PySparkException): - """Unsupported operation exception thrown from Spark with an error class. - """ + """Unsupported operation exception thrown from Spark with an error class.""" class ArrayIndexOutOfBoundsException(PySparkException): - """Array index out of bounds exception thrown from Spark with an error class. - """ + """Array index out of bounds exception thrown from Spark with an error class.""" class DateTimeException(PySparkException): - """Datetime exception thrown from Spark with an error class. - """ + """Datetime exception thrown from Spark with an error class.""" class NumberFormatException(IllegalArgumentException): - """Number format exception thrown from Spark with an error class. - """ + """Number format exception thrown from Spark with an error class.""" class StreamingQueryException(PySparkException): - """Exception that stopped a :class:`StreamingQuery`. - """ + """Exception that stopped a :class:`StreamingQuery`.""" class QueryExecutionException(PySparkException): - """Failed to execute a query. - """ + """Failed to execute a query.""" class PythonException(PySparkException): - """Exceptions thrown from Python workers. - """ + """Exceptions thrown from Python workers.""" class SparkRuntimeException(PySparkException): - """Runtime exception thrown from Spark with an error class. - """ + """Runtime exception thrown from Spark with an error class.""" class SparkUpgradeException(PySparkException): - """Exception thrown because of Spark upgrade. 
- """ + """Exception thrown because of Spark upgrade.""" class UnknownException(PySparkException): - """None of the above exceptions. - """ + """None of the above exceptions.""" class PySparkValueError(PySparkException, ValueError): - """Wrapper class for ValueError to support error classes. - """ + """Wrapper class for ValueError to support error classes.""" class PySparkIndexError(PySparkException, IndexError): - """Wrapper class for IndexError to support error classes. - """ + """Wrapper class for IndexError to support error classes.""" class PySparkTypeError(PySparkException, TypeError): - """Wrapper class for TypeError to support error classes. - """ + """Wrapper class for TypeError to support error classes.""" class PySparkAttributeError(PySparkException, AttributeError): - """Wrapper class for AttributeError to support error classes. - """ + """Wrapper class for AttributeError to support error classes.""" class PySparkRuntimeError(PySparkException, RuntimeError): - """Wrapper class for RuntimeError to support error classes. - """ + """Wrapper class for RuntimeError to support error classes.""" class PySparkAssertionError(PySparkException, AssertionError): - """Wrapper class for AssertionError to support error classes. - """ + """Wrapper class for AssertionError to support error classes.""" class PySparkNotImplementedError(PySparkException, NotImplementedError): - """Wrapper class for NotImplementedError to support error classes. - """ + """Wrapper class for NotImplementedError to support error classes.""" diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index c8c66896..8b737dde 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -21,15 +21,13 @@ class ErrorClassesReader: - """A reader to load error information from error_classes.py. - """ + """A reader to load error information from error_classes.py.""" def __init__(self) -> None: self.error_info_map = ERROR_CLASSES_MAP def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: - """Returns the completed error message by applying message parameters to the message template. - """ + """Returns the completed error message by applying message parameters to the message template.""" message_template = self.get_message_template(error_class) # Verify message parameters. 
message_parameters_from_template = re.findall("<([a-zA-Z0-9_-]+)>", message_template) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 3f32aa32..d0d4835d 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -843,8 +843,7 @@ def limit(self, num: int) -> "DataFrame": return DataFrame(rel, self.session) def __contains__(self, item: str) -> bool: - """Check if the :class:`DataFrame` contains a column by the name of `item` - """ + """Check if the :class:`DataFrame` contains a column by the name of `item`""" return item in self.relation @property diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index ba54db60..201b889b 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -30,7 +30,6 @@ def load( schema: Union[StructType, str, None] = None, **options: OptionalPrimitiveType, ) -> "DataFrame": - raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index d8a04b8e..9d2b4b7d 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -102,13 +102,11 @@ def needConversion(self) -> bool: return False def toInternal(self, obj: Any) -> Any: - """Converts a Python object into an internal SQL object. - """ + """Converts a Python object into an internal SQL object.""" return obj def fromInternal(self, obj: Any) -> Any: - """Converts an internal SQL object into a native Python object. - """ + """Converts an internal SQL object into a native Python object.""" return obj @@ -979,14 +977,12 @@ def typeName(cls) -> str: @classmethod def sqlType(cls) -> DataType: - """Underlying SQL storage type for this UDT. - """ + """Underlying SQL storage type for this UDT.""" raise NotImplementedError("UDT must implement sqlType().") @classmethod def module(cls) -> str: - """The Python module of the UDT. - """ + """The Python module of the UDT.""" raise NotImplementedError("UDT must implement module().") @classmethod @@ -1001,8 +997,7 @@ def needConversion(self) -> bool: @classmethod def _cachedSqlType(cls) -> DataType: - """Cache the sqlType() into class, because it's heavily used in `toInternal`. - """ + """Cache the sqlType() into class, because it's heavily used in `toInternal`.""" if not hasattr(cls, "_cached_sql_type"): cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined] return cls._cached_sql_type # type: ignore[attr-defined] @@ -1017,13 +1012,11 @@ def fromInternal(self, obj: Any) -> Any: return self.deserialize(v) def serialize(self, obj: Any) -> Any: - """Converts a user-type object into a SQL datum. - """ + """Converts a user-type object into a SQL datum.""" raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum: Any) -> Any: - """Converts a SQL datum into a user-type object. - """ + """Converts a SQL datum into a user-type object.""" raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self) -> str: diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index b1fc244c..59758f19 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -206,8 +206,7 @@ def _pl_tree_to_sql(tree: dict) -> str: def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: - """A polars IO plugin for DuckDB. 
- """ + """A polars IO plugin for DuckDB.""" def source_generator( with_columns: Optional[list[str]], diff --git a/duckdb/udf.py b/duckdb/udf.py index 0eb59ba9..21d6d53f 100644 --- a/duckdb/udf.py +++ b/duckdb/udf.py @@ -1,6 +1,5 @@ def vectorized(func): - """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output - """ + """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output""" import types from inspect import signature diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 8e91b34f..b45cf1a1 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -78,17 +78,14 @@ class PyPICleanupError(Exception): """Base exception for PyPI cleanup operations.""" - class AuthenticationError(PyPICleanupError): """Raised when authentication fails.""" - class ValidationError(PyPICleanupError): """Raised when input validation fails.""" - def setup_logging(verbose: bool = False) -> None: """Configure logging with appropriate level and format.""" level = logging.DEBUG if verbose else logging.INFO diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 8d772111..77281d54 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -268,8 +268,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config def pytest_runtest_setup(item: pytest.Item): - """Show the test index after the test name - """ + """Show the test index after the test name""" def get_from_tuple_list(tuples, key): for t in tuples: diff --git a/tests/conftest.py b/tests/conftest.py index 83c10f3a..cc385c31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -228,7 +228,6 @@ def _require(extension_name, db_name=""): # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture @pytest.fixture(scope="function") def spark(): - if not hasattr(spark, "session"): # Cache the import from spark_namespace.sql import SparkSession as session diff --git a/tests/fast/api/test_attribute_getter.py b/tests/fast/api/test_attribute_getter.py index 3b1513d1..208ccc40 100644 --- a/tests/fast/api/test_attribute_getter.py +++ b/tests/fast/api/test_attribute_getter.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/api/test_dbapi12.py b/tests/fast/api/test_dbapi12.py index 96b1deac..f8dcdbe6 100644 --- a/tests/fast/api/test_dbapi12.py +++ b/tests/fast/api/test_dbapi12.py @@ -1,4 +1,3 @@ - import pandas as pd import duckdb diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index 5e6d42ef..a4bdeff7 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -1,8 +1,6 @@ - import duckdb try: - can_run = True except: can_run = False diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index 62460912..11deab23 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -1,8 +1,6 @@ - import duckdb try: - can_run = True except: can_run = False diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index 295f0292..0799c206 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git 
a/tests/fast/arrow/test_binary_type.py b/tests/fast/arrow/test_binary_type.py index 5932fba8..0a0062f5 100644 --- a/tests/fast/arrow/test_binary_type.py +++ b/tests/fast/arrow/test_binary_type.py @@ -1,4 +1,3 @@ - import duckdb try: diff --git a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 83c14932..bebb55a0 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -1,5 +1,3 @@ - - import duckdb try: diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index 5cb2d38d..1b24c2b9 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -1,4 +1,3 @@ - import pytest pa = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index 5cdb04bd..7d3ec128 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/arrow/test_large_string.py b/tests/fast/arrow/test_large_string.py index bb9d1b5b..d6a4c76a 100644 --- a/tests/fast/arrow/test_large_string.py +++ b/tests/fast/arrow/test_large_string.py @@ -1,4 +1,3 @@ - import duckdb try: diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index b3bab360..ff16002c 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -1,5 +1,3 @@ - - import duckdb try: diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 92318085..e1e0a2a7 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -1,4 +1,3 @@ - import numpy as np import pytest from conftest import ArrowPandas, NumpyPandas diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index 4eacf777..4ef84c84 100644 --- a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -1,4 +1,3 @@ - import pytest from conftest import ArrowPandas, NumpyPandas diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index 6e878643..5b8cfe50 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/pandas/test_pandas_limit.py b/tests/fast/pandas/test_pandas_limit.py index 89fe1583..9c63cfdc 100644 --- a/tests/fast/pandas/test_pandas_limit.py +++ b/tests/fast/pandas/test_pandas_limit.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/pandas/test_partitioned_pandas_scan.py b/tests/fast/pandas/test_partitioned_pandas_scan.py index 9f580659..c1ab7b34 100644 --- a/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -1,4 +1,3 @@ - import numpy import pandas as pd diff --git a/tests/fast/pandas/test_progress_bar.py b/tests/fast/pandas/test_progress_bar.py index c8cfb2e0..5635edae 100644 --- a/tests/fast/pandas/test_progress_bar.py +++ b/tests/fast/pandas/test_progress_bar.py @@ -1,4 +1,3 @@ - import numpy import pandas as pd diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index 4191a96e..87f49f04 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -1,4 +1,3 @@ - import pytest from conftest import pandas_supports_arrow_backend diff --git 
a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index 31cb21c9..9cc0492b 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/spark/test_spark_functions_hex.py b/tests/fast/spark/test_spark_functions_hex.py index c58c6d90..54caaf28 100644 --- a/tests/fast/spark/test_spark_functions_hex.py +++ b/tests/fast/spark/test_spark_functions_hex.py @@ -1,4 +1,3 @@ - import pytest _ = pytest.importorskip("duckdb.experimental.spark") diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index e74cca30..be920cf8 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -27,7 +27,6 @@ def replace_with_ndarray(obj): # we need to write our own equality function that considers nan==nan for testing purposes def recursive_equality(o1, o2): - if type(o1) != type(o2): return False if type(o1) == float and math.isnan(o1) and math.isnan(o2): diff --git a/tests/fast/test_ambiguous_prepare.py b/tests/fast/test_ambiguous_prepare.py index 0865b007..48f217cd 100644 --- a/tests/fast/test_ambiguous_prepare.py +++ b/tests/fast/test_ambiguous_prepare.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 5092f099..d1afb4d8 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -7,7 +7,6 @@ class TestCaseAlias: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): - con = duckdb.connect(":memory:") df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index 34489b44..030d255a 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -1,4 +1,3 @@ - import pytest from conftest import ArrowPandas, NumpyPandas diff --git a/tests/fast/test_multi_statement.py b/tests/fast/test_multi_statement.py index cd3111e6..2b255375 100644 --- a/tests/fast/test_multi_statement.py +++ b/tests/fast/test_multi_statement.py @@ -6,7 +6,6 @@ class TestMultiStatement: def test_multi_statement(self, duckdb_cursor): - con = duckdb.connect(":memory:") # test empty statement diff --git a/tests/fast/test_string_annotation.py b/tests/fast/test_string_annotation.py index 17c22844..b8014740 100644 --- a/tests/fast/test_string_annotation.py +++ b/tests/fast/test_string_annotation.py @@ -18,7 +18,6 @@ def test_base(): def python_version_lower_than_3_10(): - if sys.version_info[0] < 3: return True if sys.version_info[1] < 10: diff --git a/tests/fast/test_transaction.py b/tests/fast/test_transaction.py index 4a06c9e7..0dfabafa 100644 --- a/tests/fast/test_transaction.py +++ b/tests/fast/test_transaction.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index 768b7782..4824ce7c 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -117,7 +117,6 @@ def test_union_type(self): type = duckdb.union_type({"a": BIGINT, "b": VARCHAR, "c": TINYINT}) assert str(type) == "UNION(a BIGINT, b VARCHAR, c TINYINT)" - @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires >= python3.9") def test_implicit_convert_from_builtin_type(self): type = duckdb.list_type(list[str]) diff --git a/tests/fast/test_union.py b/tests/fast/test_union.py index d47a8192..8df17238 100644 --- a/tests/fast/test_union.py +++ 
b/tests/fast/test_union.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 207b24fe..5f48c3cb 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -1,5 +1,4 @@ -"""Tests for duckdb_pytooling versioning functionality. -""" +"""Tests for duckdb_pytooling versioning functionality.""" import os import subprocess diff --git a/tests/fast/types/test_boolean.py b/tests/fast/types/test_boolean.py index 5a519e51..dfa67aaa 100644 --- a/tests/fast/types/test_boolean.py +++ b/tests/fast/types/test_boolean.py @@ -1,5 +1,3 @@ - - class TestBoolean: def test_bool(self, duckdb_cursor): duckdb_cursor.execute("SELECT TRUE") diff --git a/tests/fast/types/test_nested.py b/tests/fast/types/test_nested.py index e82673c7..824b2825 100644 --- a/tests/fast/types/test_nested.py +++ b/tests/fast/types/test_nested.py @@ -1,5 +1,3 @@ - - class TestNested: def test_lists(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT LIST_VALUE(1, 2, 3, 4) ").fetchall() diff --git a/tests/fast/types/test_null.py b/tests/fast/types/test_null.py index e5fe2e3d..27f287c8 100644 --- a/tests/fast/types/test_null.py +++ b/tests/fast/types/test_null.py @@ -1,5 +1,3 @@ - - class TestNull: def test_fetchone_null(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE atable (Value int)") diff --git a/tests/fast/types/test_numeric.py b/tests/fast/types/test_numeric.py index 174700aa..6540735d 100644 --- a/tests/fast/types/test_numeric.py +++ b/tests/fast/types/test_numeric.py @@ -1,5 +1,3 @@ - - def check_result(duckdb_cursor, value, type): duckdb_cursor.execute("SELECT " + str(value) + "::" + type) results = duckdb_cursor.fetchall() diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index c909c61d..2e7cc670 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index e8b1e6d9..b7f4e343 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -1,4 +1,3 @@ - import pytest import duckdb @@ -222,7 +221,6 @@ def return_np_nan(): @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): - if udf_type == "native": return cmath.nan else: diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 856c760d..984a1f8c 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -1,4 +1,3 @@ - import pytest import duckdb @@ -122,7 +121,6 @@ def return_too_many(col): res = con.sql("""select too_many_tuples(5)""").fetchall() def test_arrow_side_effects(self, duckdb_cursor): - def random_arrow(x): if not hasattr(random_arrow, "data"): random_arrow.data = 0 diff --git a/tests/fast/udf/test_scalar_native.py b/tests/fast/udf/test_scalar_native.py index 94b2949e..76295060 100644 --- a/tests/fast/udf/test_scalar_native.py +++ b/tests/fast/udf/test_scalar_native.py @@ -1,4 +1,3 @@ - import pytest import duckdb From cde722e24a71faf0bf205ff090db55d8104e5998 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:32:18 +0200 Subject: [PATCH 014/135] Ruff config: dont add future annotations --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a53f9eb5..03570028 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -329,7 +329,6 @@ select 
= [ "E", # pycodestyle "EM", # flake8-errmsg "F", # pyflakes - "FA", # flake8-future-annotations "FBT001", # flake8-boolean-trap "I", # isort "ICN", # flake8-import-conventions From 89c6387b68b6409c959f3aaa6f0fe0000d95488e Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:37:19 +0200 Subject: [PATCH 015/135] Ruff config: temporarily skip import checks --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 03570028..811d9c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -331,7 +331,7 @@ select = [ "F", # pyflakes "FBT001", # flake8-boolean-trap "I", # isort - "ICN", # flake8-import-conventions + #"ICN", # flake8-import-conventions "INT", # flake8-gettext "PERF", # perflint "PIE", # flake8-pie @@ -342,7 +342,7 @@ select = [ "SIM", # flake8-simplify "TCH", # flake8-type-checking "TD", # flake8-todos - "TID", # flake8-tidy-imports + #"TID", # flake8-tidy-imports "TRY", # tryceratops "UP", # pyupgrade "W", # pycodestyle From 6370653a70748094df961c5c35c8209290482758 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:45:55 +0200 Subject: [PATCH 016/135] Ruff EM: Error messages assigned to a var first --- duckdb/experimental/spark/_globals.py | 3 +- duckdb/experimental/spark/errors/utils.py | 6 +- duckdb/experimental/spark/sql/column.py | 6 +- duckdb/experimental/spark/sql/dataframe.py | 9 ++- duckdb/experimental/spark/sql/functions.py | 30 ++++--- duckdb/experimental/spark/sql/readwriter.py | 78 ++++++++++++------- duckdb/experimental/spark/sql/types.py | 36 ++++++--- duckdb/filesystem.py | 3 +- duckdb/polars_io.py | 12 ++- duckdb_packaging/_versioning.py | 9 ++- duckdb_packaging/build_backend.py | 26 ++++--- duckdb_packaging/pypi_cleanup.py | 65 ++++++++++------ duckdb_packaging/setuptools_scm_version.py | 12 ++- scripts/generate_connection_methods.py | 12 ++- scripts/generate_connection_stubs.py | 12 ++- .../generate_connection_wrapper_methods.py | 9 ++- scripts/generate_connection_wrapper_stubs.py | 12 ++- scripts/generate_import_cache_json.py | 3 +- sqllogic/conftest.py | 17 ++-- tests/fast/adbc/test_statement_bind.py | 3 +- tests/fast/udf/test_scalar.py | 3 +- tests/fast/udf/test_scalar_arrow.py | 3 +- 22 files changed, 242 insertions(+), 127 deletions(-) diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index 4bc325f7..771daceb 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -39,7 +39,8 @@ def foo(arg=pyducdkb.spark._NoValue): # Disallow reloading this module so as to preserve the identities of the # classes defined here. 
if "_is_loaded" in globals(): - raise RuntimeError("Reloading duckdb.experimental.spark._globals is not allowed") + msg = "Reloading duckdb.experimental.spark._globals is not allowed" + raise RuntimeError(msg) _is_loaded = True diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index 8b737dde..984504a4 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -86,7 +86,8 @@ def get_message_template(self, error_class: str) -> str: if main_error_class in self.error_info_map: main_error_class_info_map = self.error_info_map[main_error_class] else: - raise ValueError(f"Cannot find main error class '{main_error_class}'") + msg = f"Cannot find main error class '{main_error_class}'" + raise ValueError(msg) main_message_template = "\n".join(main_error_class_info_map["message"]) @@ -101,7 +102,8 @@ def get_message_template(self, error_class: str) -> str: if sub_error_class in main_error_class_subclass_info_map: sub_error_class_info_map = main_error_class_subclass_info_map[sub_error_class] else: - raise ValueError(f"Cannot find sub error class '{sub_error_class}'") + msg = f"Cannot find sub error class '{sub_error_class}'" + raise ValueError(msg) sub_message_template = "\n".join(sub_error_class_info_map["message"]) message_template = main_message_template + " " + sub_message_template diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 3a6f6cea..6cc92523 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -201,7 +201,8 @@ def __getattr__(self, item: Any) -> "Column": +------+ """ if item.startswith("__"): - raise AttributeError("Can not access __ (dunder) method") + msg = "Can not access __ (dunder) method" + raise AttributeError(msg) return self[item] def alias(self, alias: str): @@ -209,7 +210,8 @@ def alias(self, alias: str): def when(self, condition: "Column", value: Any): if not isinstance(condition, Column): - raise TypeError("condition should be a Column") + msg = "condition should be a Column" + raise TypeError(msg) v = _get_expr(value) expr = self.expr.when(condition.expr, v) return Column(expr) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index d0d4835d..57c8cd03 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -108,7 +108,8 @@ def createGlobalTempView(self, name: str) -> None: def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": if columnName not in self.relation: - raise ValueError(f"DataFrame does not contain a column named {columnName}") + msg = f"DataFrame does not contain a column named {columnName}" + raise ValueError(msg) cols = [] for x in self.relation.columns: col = ColumnExpression(x) @@ -258,7 +259,8 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": unknown_columns = set(colsMap.keys()) - set(self.relation.columns) if unknown_columns: - raise ValueError(f"DataFrame does not contain column(s): {', '.join(unknown_columns)}") + msg = f"DataFrame does not contain column(s): {', '.join(unknown_columns)}" + raise ValueError(msg) # Compute this only once old_column_names = list(colsMap.keys()) @@ -887,7 +889,8 @@ def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Colum elif isinstance(item, int): return col(self._schema[item].name) else: - raise TypeError(f"Unexpected item type: {type(item)}") + msg = f"Unexpected 
item type: {type(item)}" + raise TypeError(msg) def __getattr__(self, name: str) -> Column: """Returns the :class:`Column` denoted by ``name``. diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 501c9503..fddcd4c5 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -92,7 +92,8 @@ def ucase(str: "ColumnOrName") -> Column: def when(condition: "Column", value: Any) -> Column: if not isinstance(condition, Column): - raise TypeError("condition should be a Column") + msg = "condition should be a Column" + raise TypeError(msg) v = _get_expr(value) expr = CaseExpression(condition.expr, v) return Column(expr) @@ -1480,7 +1481,8 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C +---------------+ """ if rsd is not None: - raise ValueError("rsd is not supported by DuckDB") + msg = "rsd is not supported by DuckDB" + raise ValueError(msg) return _invoke_function_over_columns("approx_count_distinct", col) @@ -2365,7 +2367,8 @@ def rand(seed: Optional[int] = None) -> Column: """ if seed is not None: # Maybe call setseed just before but how do we know when it is executed? - raise ContributionsAcceptedError("Seed is not yet implemented") + msg = "Seed is not yet implemented" + raise ContributionsAcceptedError(msg) return _invoke_function("random") @@ -2842,7 +2845,8 @@ def encode(col: "ColumnOrName", charset: str) -> Column: +----------------+ """ if charset != "UTF-8": - raise ContributionsAcceptedError("Only UTF-8 charset is supported right now") + msg = "Only UTF-8 charset is supported right now" + raise ContributionsAcceptedError(msg) return _invoke_function("encode", _to_column_expr(col)) @@ -3017,7 +3021,8 @@ def greatest(*cols: "ColumnOrName") -> Column: [Row(greatest=4)] """ if len(cols) < 2: - raise ValueError("greatest should take at least 2 columns") + msg = "greatest should take at least 2 columns" + raise ValueError(msg) cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("greatest", *cols) @@ -3049,7 +3054,8 @@ def least(*cols: "ColumnOrName") -> Column: [Row(least=1)] """ if len(cols) < 2: - raise ValueError("least should take at least 2 columns") + msg = "least should take at least 2 columns" + raise ValueError(msg) cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("least", *cols) @@ -3550,12 +3556,14 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: +-----+----------------------------------------------------------------+ """ if numBits not in {224, 256, 384, 512, 0}: - raise ValueError("numBits should be one of {224, 256, 384, 512, 0}") + msg = "numBits should be one of {224, 256, 384, 512, 0}" + raise ValueError(msg) if numBits == 256: return _invoke_function_over_columns("sha256", col) - raise ContributionsAcceptedError("SHA-224, SHA-384, and SHA-512 are not supported yet.") + msg = "SHA-224, SHA-384, and SHA-512 are not supported yet." 
+ raise ContributionsAcceptedError(msg) def curdate() -> Column: @@ -5241,7 +5249,8 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] """ if comparator is not None: - raise ContributionsAcceptedError("comparator is not yet supported") + msg = "comparator is not yet supported" + raise ContributionsAcceptedError(msg) else: return _invoke_function_over_columns("list_sort", col, lit("ASC"), lit("NULLS LAST")) @@ -5335,7 +5344,8 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: if limit > 0: # Unclear how to implement this in DuckDB as we'd need to map back from the split array # to the original array which is tricky with regular expressions. - raise ContributionsAcceptedError("limit is not yet supported") + msg = "limit is not yet supported" + raise ContributionsAcceptedError(msg) return _invoke_function_over_columns("regexp_split_to_array", str, lit(pattern)) diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 607e9d36..714ed797 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -248,10 +248,12 @@ def csv( def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": input = list(paths) if len(input) != 1: - raise NotImplementedError("Only single paths are supported for now") + msg = "Only single paths are supported for now" + raise NotImplementedError(msg) option_amount = len(options.keys()) if option_amount != 0: - raise ContributionsAcceptedError("Options are not supported") + msg = "Options are not supported" + raise ContributionsAcceptedError(msg) path = input[0] rel = self.session.conn.read_parquet(path) from ..sql.dataframe import DataFrame @@ -338,53 +340,77 @@ def json( +---+------------+ """ if schema is not None: - raise ContributionsAcceptedError("The 'schema' option is not supported") + msg = "The 'schema' option is not supported" + raise ContributionsAcceptedError(msg) if primitivesAsString is not None: - raise ContributionsAcceptedError("The 'primitivesAsString' option is not supported") + msg = "The 'primitivesAsString' option is not supported" + raise ContributionsAcceptedError(msg) if prefersDecimal is not None: - raise ContributionsAcceptedError("The 'prefersDecimal' option is not supported") + msg = "The 'prefersDecimal' option is not supported" + raise ContributionsAcceptedError(msg) if allowComments is not None: - raise ContributionsAcceptedError("The 'allowComments' option is not supported") + msg = "The 'allowComments' option is not supported" + raise ContributionsAcceptedError(msg) if allowUnquotedFieldNames is not None: - raise ContributionsAcceptedError("The 'allowUnquotedFieldNames' option is not supported") + msg = "The 'allowUnquotedFieldNames' option is not supported" + raise ContributionsAcceptedError(msg) if allowSingleQuotes is not None: - raise ContributionsAcceptedError("The 'allowSingleQuotes' option is not supported") + msg = "The 'allowSingleQuotes' option is not supported" + raise ContributionsAcceptedError(msg) if allowNumericLeadingZero is not None: - raise ContributionsAcceptedError("The 'allowNumericLeadingZero' option is not supported") + msg = "The 'allowNumericLeadingZero' option is not supported" + raise ContributionsAcceptedError(msg) if allowBackslashEscapingAnyCharacter is not None: - raise ContributionsAcceptedError("The 'allowBackslashEscapingAnyCharacter' option is not supported") 
+ msg = "The 'allowBackslashEscapingAnyCharacter' option is not supported" + raise ContributionsAcceptedError(msg) if mode is not None: - raise ContributionsAcceptedError("The 'mode' option is not supported") + msg = "The 'mode' option is not supported" + raise ContributionsAcceptedError(msg) if columnNameOfCorruptRecord is not None: - raise ContributionsAcceptedError("The 'columnNameOfCorruptRecord' option is not supported") + msg = "The 'columnNameOfCorruptRecord' option is not supported" + raise ContributionsAcceptedError(msg) if dateFormat is not None: - raise ContributionsAcceptedError("The 'dateFormat' option is not supported") + msg = "The 'dateFormat' option is not supported" + raise ContributionsAcceptedError(msg) if timestampFormat is not None: - raise ContributionsAcceptedError("The 'timestampFormat' option is not supported") + msg = "The 'timestampFormat' option is not supported" + raise ContributionsAcceptedError(msg) if multiLine is not None: - raise ContributionsAcceptedError("The 'multiLine' option is not supported") + msg = "The 'multiLine' option is not supported" + raise ContributionsAcceptedError(msg) if allowUnquotedControlChars is not None: - raise ContributionsAcceptedError("The 'allowUnquotedControlChars' option is not supported") + msg = "The 'allowUnquotedControlChars' option is not supported" + raise ContributionsAcceptedError(msg) if lineSep is not None: - raise ContributionsAcceptedError("The 'lineSep' option is not supported") + msg = "The 'lineSep' option is not supported" + raise ContributionsAcceptedError(msg) if samplingRatio is not None: - raise ContributionsAcceptedError("The 'samplingRatio' option is not supported") + msg = "The 'samplingRatio' option is not supported" + raise ContributionsAcceptedError(msg) if dropFieldIfAllNull is not None: - raise ContributionsAcceptedError("The 'dropFieldIfAllNull' option is not supported") + msg = "The 'dropFieldIfAllNull' option is not supported" + raise ContributionsAcceptedError(msg) if encoding is not None: - raise ContributionsAcceptedError("The 'encoding' option is not supported") + msg = "The 'encoding' option is not supported" + raise ContributionsAcceptedError(msg) if locale is not None: - raise ContributionsAcceptedError("The 'locale' option is not supported") + msg = "The 'locale' option is not supported" + raise ContributionsAcceptedError(msg) if pathGlobFilter is not None: - raise ContributionsAcceptedError("The 'pathGlobFilter' option is not supported") + msg = "The 'pathGlobFilter' option is not supported" + raise ContributionsAcceptedError(msg) if recursiveFileLookup is not None: - raise ContributionsAcceptedError("The 'recursiveFileLookup' option is not supported") + msg = "The 'recursiveFileLookup' option is not supported" + raise ContributionsAcceptedError(msg) if modifiedBefore is not None: - raise ContributionsAcceptedError("The 'modifiedBefore' option is not supported") + msg = "The 'modifiedBefore' option is not supported" + raise ContributionsAcceptedError(msg) if modifiedAfter is not None: - raise ContributionsAcceptedError("The 'modifiedAfter' option is not supported") + msg = "The 'modifiedAfter' option is not supported" + raise ContributionsAcceptedError(msg) if allowNonNumericNumbers is not None: - raise ContributionsAcceptedError("The 'allowNonNumericNumbers' option is not supported") + msg = "The 'allowNonNumericNumbers' option is not supported" + raise ContributionsAcceptedError(msg) if isinstance(path, str): path = [path] diff --git a/duckdb/experimental/spark/sql/types.py 
b/duckdb/experimental/spark/sql/types.py index 9d2b4b7d..55eb9855 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -731,7 +731,8 @@ def fromInternal(self, obj: T) -> T: return self.dataType.fromInternal(obj) def typeName(self) -> str: # type: ignore[override] - raise TypeError("StructField does not have typeName. Use typeName on its type explicitly instead.") + msg = "StructField does not have typeName. Use typeName on its type explicitly instead." + raise TypeError(msg) class StructType(DataType): @@ -841,7 +842,8 @@ def add( self.names.append(field.name) else: if isinstance(field, str) and data_type is None: - raise ValueError("Must specify DataType if passing name of struct_field to create.") + msg = "Must specify DataType if passing name of struct_field to create." + raise ValueError(msg) else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) @@ -866,16 +868,19 @@ def __getitem__(self, key: Union[str, int]) -> StructField: for field in self: if field.name == key: return field - raise KeyError(f"No StructField named {key}") + msg = f"No StructField named {key}" + raise KeyError(msg) elif isinstance(key, int): try: return self.fields[key] except IndexError: - raise IndexError("StructType index out of range") + msg = "StructType index out of range" + raise IndexError(msg) elif isinstance(key, slice): return StructType(self.fields[key]) else: - raise TypeError("StructType keys should be strings, integers or slices") + msg = "StructType keys should be strings, integers or slices" + raise TypeError(msg) def simpleString(self) -> str: return "struct<%s>" % (",".join(f.simpleString() for f in self)) @@ -978,12 +983,14 @@ def typeName(cls) -> str: @classmethod def sqlType(cls) -> DataType: """Underlying SQL storage type for this UDT.""" - raise NotImplementedError("UDT must implement sqlType().") + msg = "UDT must implement sqlType()." + raise NotImplementedError(msg) @classmethod def module(cls) -> str: """The Python module of the UDT.""" - raise NotImplementedError("UDT must implement module().") + msg = "UDT must implement module()." + raise NotImplementedError(msg) @classmethod def scalaUDT(cls) -> str: @@ -1013,11 +1020,13 @@ def fromInternal(self, obj: Any) -> Any: def serialize(self, obj: Any) -> Any: """Converts a user-type object into a SQL datum.""" - raise NotImplementedError("UDT must implement toInternal().") + msg = "UDT must implement toInternal()." + raise NotImplementedError(msg) def deserialize(self, datum: Any) -> Any: """Converts a SQL datum into a user-type object.""" - raise NotImplementedError("UDT must implement fromInternal().") + msg = "UDT must implement fromInternal()." + raise NotImplementedError(msg) def simpleString(self) -> str: return "udt" @@ -1126,7 +1135,8 @@ def __new__(cls, **kwargs: Any) -> "Row": ... 
def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": if args and kwargs: - raise ValueError("Can not use both args and kwargs to create Row") + msg = "Can not use both args and kwargs to create Row" + raise ValueError(msg) if kwargs: # create row objects row = tuple.__new__(cls, list(kwargs.values())) @@ -1163,7 +1173,8 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]: True """ if not hasattr(self, "__fields__"): - raise TypeError("Cannot convert a Row class into dict") + msg = "Cannot convert a Row class into dict" + raise TypeError(msg) if recursive: @@ -1224,7 +1235,8 @@ def __getattr__(self, item: str) -> Any: def __setattr__(self, key: Any, value: Any) -> None: if key != "__fields__": - raise RuntimeError("Row is read-only") + msg = "Row is read-only" + raise RuntimeError(msg) self.__dict__[key] = value def __reduce__( diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index 885c797f..77838103 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -18,7 +18,8 @@ class ModifiedMemoryFileSystem(MemoryFileSystem): def add_file(self, object, path): if not is_file_like(object): - raise ValueError("Can not read from a non file-like object") + msg = "Can not read from a non file-like object" + raise ValueError(msg) path = self._strip_protocol(path) if isinstance(object, TextIOBase): # Wrap this so that we can return a bytes object from 'read' diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index 59758f19..69e1e7ea 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -132,9 +132,11 @@ def _pl_tree_to_sql(tree: dict) -> str: return f"({arg_sql} IS NULL)" if func == "IsNotNull": return f"({arg_sql} IS NOT NULL)" - raise NotImplementedError(f"Boolean function not supported: {func}") + msg = f"Boolean function not supported: {func}" + raise NotImplementedError(msg) - raise NotImplementedError(f"Unsupported function type: {func_dict}") + msg = f"Unsupported function type: {func_dict}" + raise NotImplementedError(msg) if node_type == "Scalar": # Detect format: old style (dtype/value) or new style (direct type key) @@ -200,9 +202,11 @@ def _pl_tree_to_sql(tree: dict) -> str: string_val = value.get("StringOwned", value.get("String", None)) return f"'{string_val}'" - raise NotImplementedError(f"Unsupported scalar type {dtype!s}, with value {value}") + msg = f"Unsupported scalar type {dtype!s}, with value {value}" + raise NotImplementedError(msg) - raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") + msg = f"Node type: {node_type} is not implemented. 
{subtree}" + raise NotImplementedError(msg) def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index 57008fa3..b338ef6b 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -30,7 +30,8 @@ def parse_version(version: str) -> tuple[int, int, int, int, int]: """ match = VERSION_RE.match(version) if not match: - raise ValueError(f"Invalid version format: {version} (expected X.Y.Z, X.Y.Z.rcM or X.Y.Z.postN)") + msg = f"Invalid version format: {version} (expected X.Y.Z, X.Y.Z.rcM or X.Y.Z.postN)" + raise ValueError(msg) major, minor, patch, rc, post = match.groups() return int(major), int(minor), int(patch), int(post or 0), int(rc or 0) @@ -51,7 +52,8 @@ def format_version(major: int, minor: int, patch: int, post: int = 0, rc: int = """ version = f"{major}.{minor}.{patch}" if post != 0 and rc != 0: - raise ValueError("post and rc are mutually exclusive") + msg = "post and rc are mutually exclusive" + raise ValueError(msg) if post != 0: version += f".post{post}" if rc != 0: @@ -168,4 +170,5 @@ def get_git_describe(repo_path: Optional[pathlib.Path] = None, since_major=False result.check_returncode() return result.stdout.strip() except FileNotFoundError: - raise RuntimeError("git executable can't be found") + msg = "git executable can't be found" + raise RuntimeError(msg) diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index dc94eeaa..aa5e4515 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -75,7 +75,8 @@ def _in_sdist() -> bool: def _duckdb_submodule_path() -> Path: """Verify that the duckdb submodule is checked out and usable and return its path.""" if not _in_git_repository(): - raise RuntimeError("Not in a git repository, no duckdb submodule present") + msg = "Not in a git repository, no duckdb submodule present" + raise RuntimeError(msg) # search the duckdb submodule gitmodules_path = Path(".gitmodules") modules = dict() @@ -97,7 +98,8 @@ def _duckdb_submodule_path() -> Path: modules[cur_module_reponame] = cur_module_path if "duckdb" not in modules: - raise RuntimeError("DuckDB submodule missing") + msg = "DuckDB submodule missing" + raise RuntimeError(msg) duckdb_path = modules["duckdb"] # now check that the submodule is usable @@ -106,9 +108,11 @@ def _duckdb_submodule_path() -> Path: status = status.decode("ascii", "replace") for line in status.splitlines(): if line.startswith("-"): - raise RuntimeError(f"Duckdb submodule not initialized: {line}") + msg = f"Duckdb submodule not initialized: {line}" + raise RuntimeError(msg) if line.startswith("U"): - raise RuntimeError(f"Duckdb submodule has merge conflicts: {line}") + msg = f"Duckdb submodule has merge conflicts: {line}" + raise RuntimeError(msg) if line.startswith("+"): _log(f"WARNING: Duckdb submodule not clean: {line}") # all good @@ -169,7 +173,8 @@ def _skbuild_config_add( if not key_exists: config_settings[store_key] = value elif fail_if_exists: - raise RuntimeError(f"{key} already present in config and may not be overridden") + msg = f"{key} already present in config and may not be overridden" + raise RuntimeError(msg) elif key_exists_as_list and val_is_list: config_settings[store_key].extend(value) elif key_exists_as_list and val_is_str: @@ -178,9 +183,8 @@ def _skbuild_config_add( _log(f"WARNING: overriding existing value in {store_key}") config_settings[store_key] = value else: - raise RuntimeError( - 
f"Type mismatch: cannot set {store_key} ({type(config_settings[store_key])}) to `{value}` ({type(value)})" - ) + msg = f"Type mismatch: cannot set {store_key} ({type(config_settings[store_key])}) to `{value}` ({type(value)})" + raise RuntimeError(msg) def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str], str]]] = None) -> str: @@ -201,7 +205,8 @@ def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[ RuntimeError: If not in a git repository or DuckDB submodule issues. """ if not _in_git_repository(): - raise RuntimeError("Not in a git repository, can't create an sdist") + msg = "Not in a git repository, can't create an sdist" + raise RuntimeError(msg) submodule_path = _duckdb_submodule_path() if _FORCED_PEP440_VERSION is not None: duckdb_version = pep440_to_git_tag(strip_post_from_version(_FORCED_PEP440_VERSION)) @@ -237,7 +242,8 @@ def build_wheel( duckdb_version = None if not _in_git_repository(): if not _in_sdist(): - raise RuntimeError("Not in a git repository nor in an sdist, can't build a wheel") + msg = "Not in a git repository nor in an sdist, can't build a wheel" + raise RuntimeError(msg) _log("Building duckdb wheel from sdist. Reading duckdb version from file.") config_settings = config_settings or {} duckdb_version = _read_duckdb_long_version() diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index b45cf1a1..428e07dd 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -95,15 +95,18 @@ def setup_logging(verbose: bool = False) -> None: def validate_username(value: str) -> str: """Validate and sanitize username input.""" if not value or not value.strip(): - raise argparse.ArgumentTypeError("Username cannot be empty") + msg = "Username cannot be empty" + raise argparse.ArgumentTypeError(msg) username = value.strip() if len(username) > 100: # Reasonable limit - raise argparse.ArgumentTypeError("Username too long (max 100 characters)") + msg = "Username too long (max 100 characters)" + raise argparse.ArgumentTypeError(msg) # Basic validation - PyPI usernames are alphanumeric with limited special chars if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$", username): - raise argparse.ArgumentTypeError("Invalid username format") + msg = "Invalid username format" + raise argparse.ArgumentTypeError(msg) return username @@ -140,9 +143,11 @@ def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: otp = os.getenv("PYPI_CLEANUP_OTP") if not password: - raise ValidationError("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") + msg = "PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode" + raise ValidationError(msg) if not otp: - raise ValidationError("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") + msg = "PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode" + raise ValidationError(msg) return password, otp @@ -150,10 +155,12 @@ def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: def validate_arguments(args: argparse.Namespace) -> None: """Validate parsed arguments.""" if not args.dry_run and not args.username: - raise ValidationError("--username is required when not in dry-run mode") + msg = "--username is required when not in dry-run mode" + raise ValidationError(msg) if args.max_nightlies < 0: - raise ValidationError("--max-nightlies must be non-negative") + msg = "--max-nightlies must be 
non-negative" + raise ValidationError(msg) class CsrfParser(HTMLParser): @@ -287,7 +294,8 @@ def _fetch_released_versions(self, http_session: Session) -> set[str]: logging.debug(f"Found {len(versions)} releases with files") return versions except RequestException as e: - raise PyPICleanupError(f"Failed to fetch package information for '{self._package}': {e}") from e + msg = f"Failed to fetch package information for '{self._package}': {e}" + raise PyPICleanupError(msg) from e def _is_stable_release_version(self, version: str) -> bool: """Determine whether a version string denotes a stable release.""" @@ -305,14 +313,16 @@ def _parse_rc_version(self, version: str) -> str: """Parse a rc version string to determine the base version.""" match = self._rc_version_pattern.match(version) if not match: - raise PyPICleanupError(f"Invalid rc version '{version}'") + msg = f"Invalid rc version '{version}'" + raise PyPICleanupError(msg) return match.group("version") if match else None def _parse_dev_version(self, version: str) -> tuple[str, int]: """Parse a dev version string to determine the base version and dev version id.""" match = self._dev_version_pattern.match(version) if not match: - raise PyPICleanupError(f"Invalid dev version '{version}'") + msg = f"Invalid dev version '{version}'" + raise PyPICleanupError(msg) return match.group("version"), int(match.group("dev_id")) def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: @@ -363,15 +373,17 @@ def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: # Final safety checks if versions_to_delete == versions: - raise PyPICleanupError( + msg = ( f"Safety check failed: cleanup would delete ALL versions of '{self._package}'. " "This would make the package permanently inaccessible. Aborting." ) + raise PyPICleanupError(msg) if len(versions_to_delete.intersection(stable_versions)) > 0: - raise PyPICleanupError( + msg = ( f"Safety check failed: cleanup would delete one or more stable versions of '{self._package}'. " f"A regexp might be broken? 
(would delete {versions_to_delete.intersection(stable_versions)})" ) + raise PyPICleanupError(msg) unknown_versions = versions.difference(stable_versions).difference(rc_versions).difference(dev_versions) if unknown_versions: logging.warning(f"Found version string(s) in an unsupported format: {unknown_versions}") @@ -381,7 +393,8 @@ def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: def _authenticate(self, http_session: Session) -> None: """Authenticate with PyPI.""" if not self._username or not self._password: - raise AuthenticationError("Username and password are required for authentication") + msg = "Username and password are required for authentication" + raise AuthenticationError(msg) logging.info(f"Authenticating user '{self._username}' with PyPI") @@ -397,7 +410,8 @@ def _authenticate(self, http_session: Session) -> None: logging.info("Authentication successful") except RequestException as e: - raise AuthenticationError(f"Network error during authentication: {e}") from e + msg = f"Network error during authentication: {e}" + raise AuthenticationError(msg) from e def _get_csrf_token(self, http_session: Session, form_action: str) -> str: """Extract CSRF token from a form page.""" @@ -406,7 +420,8 @@ def _get_csrf_token(self, http_session: Session, form_action: str) -> str: parser = CsrfParser(form_action) parser.feed(resp.text) if not parser.csrf: - raise AuthenticationError(f"No CSRF token found in {form_action}") + msg = f"No CSRF token found in {form_action}" + raise AuthenticationError(msg) return parser.csrf def _perform_login(self, http_session: Session) -> requests.Response: @@ -425,14 +440,16 @@ def _perform_login(self, http_session: Session) -> requests.Response: # Check if login failed (redirected back to login page) if response.url == f"{self._index_url}/account/login/": - raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") + msg = f"Login failed for user '{self._username}' - check credentials" + raise AuthenticationError(msg) return response def _handle_two_factor_auth(self, http_session: Session, response: requests.Response) -> None: """Handle two-factor authentication.""" if not self._otp: - raise AuthenticationError("Two-factor authentication required but no OTP secret provided") + msg = "Two-factor authentication required but no OTP secret provided" + raise AuthenticationError(msg) two_factor_url = response.url form_action = two_factor_url[len(self._index_url) :] @@ -462,11 +479,13 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp except RequestException as e: if attempt == _LOGIN_RETRY_ATTEMPTS - 1: - raise AuthenticationError(f"Network error during 2FA: {e}") from e + msg = f"Network error during 2FA: {e}" + raise AuthenticationError(msg) from e logging.debug(f"Network error during 2FA attempt {attempt + 1}, retrying...") time.sleep(_LOGIN_RETRY_DELAY) - raise AuthenticationError("Two-factor authentication failed after all attempts") + msg = "Two-factor authentication failed after all attempts" + raise AuthenticationError(msg) def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) -> None: """Delete the specified package versions.""" @@ -483,15 +502,15 @@ def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) failed_deletions.append(version) if failed_deletions: - raise PyPICleanupError( - f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" - ) + msg = f"Failed to delete 
{len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" + raise PyPICleanupError(msg) def _delete_single_version(self, http_session: Session, version: str) -> None: """Delete a single package version.""" # Safety check if not self._is_dev_version(version) or self._is_rc_version(version): - raise PyPICleanupError(f"Refusing to delete non-[dev|rc] version: {version}") + msg = f"Refusing to delete non-[dev|rc] version: {version}" + raise PyPICleanupError(msg) logging.debug(f"Deleting {self._package} version {version}") diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 2ff79f80..5b0c5383 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -40,12 +40,14 @@ def version_scheme(version: Any) -> str: # Handle case where tag is None if version.tag is None: - raise ValueError("Need a valid version. Did you set a fallback_version in pyproject.toml?") + msg = "Need a valid version. Did you set a fallback_version in pyproject.toml?" + raise ValueError(msg) try: return _bump_version(str(version.tag), version.distance, version.dirty) except Exception as e: - raise RuntimeError(f"Failed to bump version: {e}") + msg = f"Failed to bump version: {e}" + raise RuntimeError(msg) def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: @@ -54,7 +56,8 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: try: major, minor, patch, post, rc = parse_version(base_version) except ValueError: - raise ValueError(f"Incorrect version format: {base_version} (expected X.Y.Z or X.Y.Z.postN)") + msg = f"Incorrect version format: {base_version} (expected X.Y.Z or X.Y.Z.postN)" + raise ValueError(msg) # If we're exactly on a tag (distance = 0, dirty=False) distance = int(distance or 0) @@ -110,7 +113,8 @@ def _git_describe_override_to_pep_440(override_value: str) -> str: match = describe_pattern.match(override_value) if not match: - raise ValueError(f"Invalid git describe override: {override_value}") + msg = f"Invalid git describe override: {override_value}" + raise ValueError(msg) version, distance, commit_hash = match.groups() diff --git a/scripts/generate_connection_methods.py b/scripts/generate_connection_methods.py index 51f667f6..a3bf36ad 100644 --- a/scripts/generate_connection_methods.py +++ b/scripts/generate_connection_methods.py @@ -37,15 +37,18 @@ def generate(): for i, line in enumerate(source_code): if line.startswith(INITIALIZE_METHOD): if start_index != -1: - raise ValueError("Encountered the INITIALIZE_METHOD a second time, quitting!") + msg = "Encountered the INITIALIZE_METHOD a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" 
+ raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] @@ -128,5 +131,6 @@ def create_definition(name, method) -> str: if __name__ == "__main__": - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index 9b1be9aa..910e657a 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -20,15 +20,18 @@ def generate(): for i, line in enumerate(source_code): if line.startswith(START_MARKER): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] @@ -94,5 +97,6 @@ def create_definition(name, method, overloaded: bool) -> str: if __name__ == "__main__": - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index d2ef0bba..743d0224 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -75,15 +75,18 @@ def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[s for i, line in enumerate(content): if line.startswith(start_marker): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(end_marker): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" 
+ raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = content[: start_index + 1] end_section = content[end_index:] diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 3b3b8c93..4066d0ea 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -21,15 +21,18 @@ def generate(): for i, line in enumerate(source_code): if line.startswith(START_MARKER): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] @@ -118,5 +121,6 @@ def create_definition(name, method, overloaded: bool) -> str: if __name__ == "__main__": - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 34cd84b6..7b43d175 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -107,7 +107,8 @@ def add_or_get_module(self, module_name: str) -> ImportCacheModule: def get_module(self, module_name: str) -> ImportCacheModule: if module_name not in self.modules: - raise ValueError("Import the module before registering its attributes!") + msg = "Import the module before registering its attributes!" 
+ raise ValueError(msg) return self.modules[module_name] def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttribute]: diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 77281d54..db875566 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -179,16 +179,20 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, percentage_specified = start_offset_percentage is not None or end_offset_percentage is not None if index_specified and percentage_specified: - raise ValueError("You can only specify either start/end offsets or start/end offset percentages, not both") + msg = "You can only specify either start/end offsets or start/end offset percentages, not both" + raise ValueError(msg) if start_offset is not None and start_offset < 0: - raise ValueError("--start-offset must be a non-negative integer") + msg = "--start-offset must be a non-negative integer" + raise ValueError(msg) if start_offset_percentage is not None and (start_offset_percentage < 0 or start_offset_percentage > 100): - raise ValueError("--start-offset-percentage must be between 0 and 100") + msg = "--start-offset-percentage must be between 0 and 100" + raise ValueError(msg) if end_offset_percentage is not None and (end_offset_percentage < 0 or end_offset_percentage > 100): - raise ValueError("--end-offset-percentage must be between 0 and 100") + msg = "--end-offset-percentage must be between 0 and 100" + raise ValueError(msg) if start_offset is None: if start_offset_percentage is not None: @@ -197,9 +201,8 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, start_offset = 0 if end_offset is not None and end_offset < start_offset: - raise ValueError( - f"--end-offset ({end_offset}) must be greater than or equal to the start offset ({start_offset})" - ) + msg = f"--end-offset ({end_offset}) must be greater than or equal to the start offset ({start_offset})" + raise ValueError(msg) if end_offset is None: if end_offset_percentage is not None: diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index dc5d1f59..c8b935cb 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -21,7 +21,8 @@ def _import(handle): return pa.RecordBatchReader._import_from_c(handle.address) elif isinstance(handle, adbc_driver_manager.ArrowSchemaHandle): return pa.Schema._import_from_c(handle.address) - raise NotImplementedError(f"Importing {handle!r}") + msg = f"Importing {handle!r}" + raise NotImplementedError(msg) def _bind(stmt, batch): diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index b7f4e343..57160d75 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -133,7 +133,8 @@ def no_op(x): @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_exceptions(self, udf_type): def raises_exception(x): - raise AttributeError("error") + msg = "error" + raise AttributeError(msg) con = duckdb.connect() con.create_function("raises", raises_exception, [BIGINT], BIGINT, type=udf_type) diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 984a1f8c..28d86455 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -47,7 +47,8 @@ def test_varargs(self): def variable_args(*args): # We return a chunked array here, but internally we convert this into a Table if len(args) == 0: - raise ValueError("Expected at least one argument") + msg = 
"Expected at least one argument" + raise ValueError(msg) for item in args: return item From 6445bdef32566385cadfd45ff5972e3dfcc5a5e1 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:53:02 +0200 Subject: [PATCH 017/135] Ruff D: Docstring fixes --- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/column.py | 10 ++++---- duckdb/experimental/spark/sql/dataframe.py | 2 +- duckdb/experimental/spark/sql/functions.py | 28 +++++++++++----------- duckdb/experimental/spark/sql/types.py | 8 +++---- duckdb/udf.py | 2 +- sqllogic/conftest.py | 2 +- tests/fast/spark/test_spark_dataframe.py | 2 +- tests/fast/test_json_logging.py | 2 +- tests/fast/test_pypi_cleanup.py | 2 +- tests/slow/test_h2oai_arrow.py | 2 +- 11 files changed, 31 insertions(+), 31 deletions(-) diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 791f7090..1c2ad9a6 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,7 +1,7 @@ class ContributionsAcceptedError(NotImplementedError): """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, - feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb + feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. """ def __init__(self, message=None) -> None: diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 6cc92523..dd676846 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -29,7 +29,7 @@ def _unary_op( name: str, doc: str = "unary operator", ) -> Callable[["Column"], "Column"]: - """Create a method for given unary operator""" + """Create a method for given unary operator.""" def _(self: "Column") -> "Column": # Call the function identified by 'name' on the internal Expression object @@ -44,7 +44,7 @@ def _bin_op( name: str, doc: str = "binary operator", ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]: - """Create a method for given binary operator""" + """Create a method for given binary operator.""" def _( self: "Column", @@ -62,7 +62,7 @@ def _bin_func( name: str, doc: str = "binary function", ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]: - """Create a function expression for the given binary function""" + """Create a function expression for the given binary function.""" def _( self: "Column", @@ -245,14 +245,14 @@ def __eq__( # type: ignore[override] self, other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"], ) -> "Column": - """Binary function""" + """Binary function.""" return Column(self.expr == (_get_expr(other))) def __ne__( # type: ignore[override] self, other: object, ) -> "Column": - """Binary function""" + """Binary function.""" return Column(self.expr != (_get_expr(other))) __lt__ = _bin_op("__lt__") diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 57c8cd03..16d54f0b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -845,7 +845,7 @@ def limit(self, num: int) -> "DataFrame": return DataFrame(rel, self.session) def __contains__(self, item: str) -> bool: - """Check if the :class:`DataFrame` contains a column by the name of `item`""" + 
"""Check if the :class:`DataFrame` contains a column by the name of `item`.""" return item in self.relation @property diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index fddcd4c5..a319ec13 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -164,7 +164,7 @@ def _to_column_expr(col: ColumnOrName) -> Expression: def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: - r"""Replace all substrings of the specified string value that match regexp with rep. + """Replace all substrings of the specified string value that match regexp with rep. .. versionadded:: 1.5.0 @@ -1487,7 +1487,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: - """.. versionadded:: 1.3.0 + """.. versionadded:: 1.3.0. .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -2059,7 +2059,7 @@ def cbrt(col: "ColumnOrName") -> Column: def char(col: "ColumnOrName") -> Column: """Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the - result is equivalent to char(col % 256) + result is equivalent to char(col % 256). .. versionadded:: 3.5.0 @@ -2373,7 +2373,7 @@ def rand(seed: Optional[int] = None) -> Column: def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. + """Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. versionadded:: 3.5.0 @@ -2425,7 +2425,7 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched + """Returns a count of the number of times that the Java regex pattern `regexp` is matched in the string `str`. .. versionadded:: 3.5.0 @@ -2456,7 +2456,7 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: - r"""Extract a specific group matched by the Java regex `regexp`, from the specified string column. + """Extract a specific group matched by the Java regex `regexp`, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned. .. versionadded:: 1.5.0 @@ -2496,7 +2496,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: - r"""Extract all strings in the `str` that match the Java regex `regexp` + """Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. .. versionadded:: 3.5.0 @@ -2535,7 +2535,7 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. + """Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. 
versionadded:: 3.5.0 @@ -2587,7 +2587,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns the substring that matches the Java regex `regexp` within the string `str`. + """Returns the substring that matches the Java regex `regexp` within the string `str`. If the regular expression is not found, the result is null. .. versionadded:: 3.5.0 @@ -3996,7 +3996,7 @@ def month(col: "ColumnOrName") -> Column: def dayofweek(col: "ColumnOrName") -> Column: """Extract the day of the week of a given date/timestamp as integer. - Ranges from 1 for a Sunday through to 7 for a Saturday + Ranges from 1 for a Sunday through to 7 for a Saturday. .. versionadded:: 2.3.0 @@ -4187,7 +4187,7 @@ def second(col: "ColumnOrName") -> Column: def weekofyear(col: "ColumnOrName") -> Column: """Extract the week number of a given date as integer. A week is considered to start on a Monday and week 1 is the first week with more than 3 days, - as defined by ISO 8601 + as defined by ISO 8601. .. versionadded:: 1.5.0 @@ -4609,7 +4609,7 @@ def atan(col: "ColumnOrName") -> Column: def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """.. versionadded:: 1.4.0 + """.. versionadded:: 1.4.0. .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -5577,7 +5577,7 @@ def var_samp(col: "ColumnOrName") -> Column: def variance(col: "ColumnOrName") -> Column: - """Aggregate function: alias for var_samp + """Aggregate function: alias for var_samp. .. versionadded:: 1.6.0 @@ -6242,7 +6242,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: def expr(str: str) -> Column: - """Parses the expression string into the column that it represents + """Parses the expression string into the column that it represents. .. versionadded:: 1.5.0 diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 55eb9855..4418f495 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -113,7 +113,7 @@ def fromInternal(self, obj: Any) -> Any: # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle class DataTypeSingleton(type): - """Metaclass for DataType""" + """Metaclass for DataType.""" _instances: ClassVar[dict[type["DataTypeSingleton"], "DataTypeSingleton"]] = {} @@ -855,7 +855,7 @@ def add( return self def __iter__(self) -> Iterator[StructField]: - """Iterate the fields""" + """Iterate the fields.""" return iter(self.fields) def __len__(self) -> int: @@ -1147,7 +1147,7 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": return tuple.__new__(cls, args) def asDict(self, recursive: bool = False) -> dict[str, Any]: - """Return as a dict + """Return as a dict. 
Parameters ---------- @@ -1200,7 +1200,7 @@ def __contains__(self, item: Any) -> bool: # let object acts like class def __call__(self, *args: Any) -> "Row": - """Create new Row object""" + """Create new Row object.""" if len(args) > len(self): raise ValueError( "Can not create Row with fields %s, expected %d values but got %s" % (self, len(self), args) diff --git a/duckdb/udf.py b/duckdb/udf.py index 21d6d53f..1357dee5 100644 --- a/duckdb/udf.py +++ b/duckdb/udf.py @@ -1,5 +1,5 @@ def vectorized(func): - """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output""" + """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output.""" import types from inspect import signature diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index db875566..40759e9c 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -271,7 +271,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config def pytest_runtest_setup(item: pytest.Item): - """Show the test index after the test name""" + """Show the test index after the test name.""" def get_from_tuple_list(tuples, key): for t in tuples: diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index 26006952..3fd78090 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -339,7 +339,7 @@ def test_df_columns(self, spark): assert "OtherInfo" in updatedDF.columns def test_array_and_map_type(self, spark): - """Array & Map""" + """Array & Map.""" arrayStructureSchema = StructType( [ StructField( diff --git a/tests/fast/test_json_logging.py b/tests/fast/test_json_logging.py index b29ea7bf..9e9908ea 100644 --- a/tests/fast/test_json_logging.py +++ b/tests/fast/test_json_logging.py @@ -6,7 +6,7 @@ def _parse_json_func(error_prefix: str): - """Helper to check that the error message is indeed parsable json""" + """Helper to check that the error message is indeed parsable json.""" def parse_func(exception): msg = exception.args[0] diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index 0e0439ce..74b1266f 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Unit tests for pypi_cleanup.py +"""Unit tests for pypi_cleanup.py. 
Run with: python -m pytest test_pypi_cleanup.py -v """ diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py index d0dbc2fe..35d8b1c7 100644 --- a/tests/slow/test_h2oai_arrow.py +++ b/tests/slow/test_h2oai_arrow.py @@ -197,7 +197,7 @@ def test_join(self, threads, function, large_data): @fixture(scope="module") def arrow_dataset_register(): - """Single fixture to download files and register them on the given connection""" + """Single fixture to download files and register them on the given connection.""" session = requests.Session() retries = urllib3_util.Retry( allowed_methods={"GET"}, # only retry on GETs (all we do) From 2119fb08ac8c5e55a12eb94b2d544d8a21eb69d2 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:55:56 +0200 Subject: [PATCH 018/135] Ruff D301: Make docstring raw if they contain backslashes --- duckdb/experimental/spark/sql/functions.py | 26 +++++++++++----------- duckdb/experimental/spark/sql/types.py | 4 ++-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index a319ec13..92631ee8 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -108,7 +108,7 @@ def struct(*cols: Column) -> Column: def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]]) -> Column: - """Creates a new array column. + r"""Creates a new array column. .. versionadded:: 1.4.0 @@ -164,7 +164,7 @@ def _to_column_expr(col: ColumnOrName) -> Expression: def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: - """Replace all substrings of the specified string value that match regexp with rep. + r"""Replace all substrings of the specified string value that match regexp with rep. .. versionadded:: 1.5.0 @@ -713,7 +713,7 @@ def asin(col: "ColumnOrName") -> Column: def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """Returns true if str matches `pattern` with `escape`, + r"""Returns true if str matches `pattern` with `escape`, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -750,7 +750,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """Returns true if str matches `pattern` with `escape` case-insensitively, + r"""Returns true if str matches `pattern` with `escape` case-insensitively, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -2264,7 +2264,7 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: - """Formats the arguments in printf-style and returns the result as a string column. + r"""Formats the arguments in printf-style and returns the result as a string column. .. versionadded:: 3.5.0 @@ -2373,7 +2373,7 @@ def rand(seed: Optional[int] = None) -> Column: def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns true if `str` matches the Java regex `regexp`, or false otherwise. + r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. 
versionadded:: 3.5.0 @@ -2425,7 +2425,7 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns a count of the number of times that the Java regex pattern `regexp` is matched + r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched in the string `str`. .. versionadded:: 3.5.0 @@ -2456,7 +2456,7 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: - """Extract a specific group matched by the Java regex `regexp`, from the specified string column. + r"""Extract a specific group matched by the Java regex `regexp`, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned. .. versionadded:: 1.5.0 @@ -2496,7 +2496,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: - """Extract all strings in the `str` that match the Java regex `regexp` + r"""Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. .. versionadded:: 3.5.0 @@ -2535,7 +2535,7 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns true if `str` matches the Java regex `regexp`, or false otherwise. + r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. versionadded:: 3.5.0 @@ -2587,7 +2587,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns the substring that matches the Java regex `regexp` within the string `str`. + r"""Returns the substring that matches the Java regex `regexp` within the string `str`. If the regular expression is not found, the result is null. .. versionadded:: 3.5.0 @@ -4274,7 +4274,7 @@ def acos(col: "ColumnOrName") -> Column: def call_function(funcName: str, *cols: "ColumnOrName") -> Column: - """Call a SQL function. + r"""Call a SQL function. .. versionadded:: 3.5.0 @@ -4851,7 +4851,7 @@ def initcap(col: "ColumnOrName") -> Column: def octet_length(col: "ColumnOrName") -> Column: - """Calculates the byte length for the specified string column. + r"""Calculates the byte length for the specified string column. .. versionadded:: 3.3.0 diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 4418f495..fa961eb1 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -736,7 +736,7 @@ def typeName(self) -> str: # type: ignore[override] class StructType(DataType): - """Struct type, consisting of a list of :class:`StructField`. + r"""Struct type, consisting of a list of :class:`StructField`. This is the data type representing a :class:`Row`. @@ -798,7 +798,7 @@ def add( nullable: bool = True, metadata: Optional[dict[str, Any]] = None, ) -> "StructType": - """Construct a :class:`StructType` by adding new elements to it, to define the schema. + r"""Construct a :class:`StructType` by adding new elements to it, to define the schema. The method accepts either: a) A single parameter which is a :class:`StructField` object. 
From b6fb52a69417647042d48e441812e9781b187885 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:08:47 +0200 Subject: [PATCH 019/135] Fix testfixture yield --- tests/fast/test_expression.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index 049a2a5c..c7207338 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -36,7 +36,8 @@ def filter_rel(): ) tbl(a, b) """ ) - return rel + yield rel + con.close() class TestExpression: From 4ff4147b39c94aba93358f95c3c68da07ac58358 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:27:07 +0200 Subject: [PATCH 020/135] Ruff config: disable docstring checks for tests --- pyproject.toml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 811d9c1b..02e85f5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -349,9 +349,6 @@ select = [ ] ignore = [] -[tool.ruff.lint.pycodestyle] -max-doc-length = 88 - [tool.ruff.lint.pydocstyle] convention = "google" @@ -361,6 +358,12 @@ ban-relative-imports = "all" [tool.ruff.lint.flake8-type-checking] strict = true +[tool.ruff.lint.per-file-ignores] +"tests/**.py" = [ + # No need for package, module, class, function, init etc docstrings in tests + 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107' +] + [tool.ruff.format] docstring-code-format = true docstring-code-line-length = 88 From f983b69bc61f75805ba059ac21258c0f511067fd Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:28:29 +0200 Subject: [PATCH 021/135] Ruff noqa D205 on tests: no need to check newline between summary and description in tests --- tests/fast/numpy/test_numpy_new_path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index d95c93d1..272d5e45 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -1,6 +1,6 @@ """The support for scaning over numpy arrays reuses many codes for pandas. Therefore, we only test the new codes and exec paths. -""" +""" # noqa: D205 from datetime import timedelta From 12c0f483691f4f11d6c944b844f7358fe5440023 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:30:34 +0200 Subject: [PATCH 022/135] Ruff noqa D205: dont need the newline in docstrings for existing code --- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/column.py | 4 +- duckdb/experimental/spark/sql/dataframe.py | 18 +-- duckdb/experimental/spark/sql/functions.py | 154 ++++++++++----------- duckdb/experimental/spark/sql/group.py | 2 +- duckdb/experimental/spark/sql/types.py | 8 +- sqllogic/conftest.py | 4 +- 7 files changed, 96 insertions(+), 96 deletions(-) diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 1c2ad9a6..24c4f291 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -2,7 +2,7 @@ class ContributionsAcceptedError(NotImplementedError): """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. 
- """ + """ # noqa: D205 def __init__(self, message=None) -> None: doc = self.__class__.__doc__ diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index dd676846..ea6cd8a8 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -165,7 +165,7 @@ def __getitem__(self, k: Any) -> "Column": +------------------+------+ | abc| value| +------------------+------+ - """ + """ # noqa: D205 if isinstance(k, slice): raise ContributionsAcceptedError # if k.step is not None: @@ -199,7 +199,7 @@ def __getattr__(self, item: Any) -> "Column": +------+ | value| +------+ - """ + """ # noqa: D205 if item.startswith("__"): msg = "Can not access __ (dunder) method" raise AttributeError(msg) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 16d54f0b..99da92ec 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -172,7 +172,7 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": | 2|Alice| 4| 5| | 5| Bob| 7| 8| +---+-----+----+----+ - """ + """ # noqa: D205 # Below code is to help enable kwargs in future. assert len(colsMap) == 1 colsMap = colsMap[0] # type: ignore[assignment] @@ -250,7 +250,7 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": | 2|Alice| 4| 5| | 5| Bob| 7| 8| +---+-----+----+----+ - """ + """ # noqa: D205 if not isinstance(colsMap, dict): raise PySparkTypeError( error_class="NOT_DICT", @@ -974,7 +974,7 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] | Bob| 2| 2| | Bob| 5| 1| +-----+---+-----+ - """ + """ # noqa: D205 from .group import GroupedData, Grouping if len(cols) == 1 and isinstance(cols[0], list): @@ -1034,7 +1034,7 @@ def union(self, other: "DataFrame") -> "DataFrame": | 1| 2| 3| | 1| 2| 3| +----+----+----+ - """ + """ # noqa: D205 return DataFrame(self.relation.union(other.relation), self.session) unionAll = union @@ -1094,7 +1094,7 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> | 1| 2| 3|NULL| |NULL| 4| 5| 6| +----+----+----+----+ - """ + """ # noqa: D205 if allowMissingColumns: cols = [] for col in self.relation.columns: @@ -1144,7 +1144,7 @@ def intersect(self, other: "DataFrame") -> "DataFrame": | b| 3| | a| 1| +---+---+ - """ + """ # noqa: D205 return self.intersectAll(other).drop_duplicates() def intersectAll(self, other: "DataFrame") -> "DataFrame": @@ -1181,7 +1181,7 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame": | a| 1| | b| 3| +---+---+ - """ + """ # noqa: D205 return DataFrame(self.relation.intersect(other.relation), self.session) def exceptAll(self, other: "DataFrame") -> "DataFrame": @@ -1221,7 +1221,7 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": | c| 4| +---+---+ - """ + """ # noqa: D205 return DataFrame(self.relation.except_(other.relation), self.session) def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": @@ -1275,7 +1275,7 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": +-----+---+------+ |Alice| 5| 80| +-----+---+------+ - """ + """ # noqa: D205 if subset: rn_col = f"tmp_col_{uuid.uuid1().hex}" subset_str = ", ".join([f'"{c}"' for c in subset]) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 92631ee8..7ae923f4 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -25,7 
+25,7 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: """Invokes n-ary JVM function identified by name and wraps the result with :class:`~pyspark.sql.Column`. - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return _invoke_function(name, *cols) @@ -211,7 +211,7 @@ def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["C >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() [Row(sliced=[2, 3]), Row(sliced=[5])] - """ + """ # noqa: D205 start = ConstantExpression(start) if isinstance(start, int) else _to_column_expr(start) length = ConstantExpression(length) if isinstance(length, int) else _to_column_expr(length) @@ -302,7 +302,7 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: | 1| Bob| +---+-----+ - """ + """ # noqa: D205 return asc(col).nulls_first() @@ -337,7 +337,7 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: | 0| NULL| +---+-----+ - """ + """ # noqa: D205 return asc(col).nulls_last() @@ -408,7 +408,7 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: | 2|Alice| +---+-----+ - """ + """ # noqa: D205 return desc(col).nulls_first() @@ -443,7 +443,7 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: | 0| NULL| +---+-----+ - """ + """ # noqa: D205 return desc(col).nulls_last() @@ -473,7 +473,7 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: ... ) >>> df.select(left(df.a, df.b).alias("r")).collect() [Row(r='Spa')] - """ + """ # noqa: D205 len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( @@ -508,7 +508,7 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: ... ) >>> df.select(right(df.a, df.b).alias("r")).collect() [Row(r='SQL')] - """ + """ # noqa: D205 len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( @@ -741,7 +741,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col ... ) >>> df.select(like(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] - """ + """ # noqa: D205 if escapeChar is None: escapeChar = ConstantExpression("\\") else: @@ -778,7 +778,7 @@ def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Co ... ) >>> df.select(ilike(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] - """ + """ # noqa: D205 if escapeChar is None: escapeChar = ConstantExpression("\\") else: @@ -872,7 +872,7 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: [Row(array_append(c1, c2)=['b', 'a', 'c', 'c'])] >>> df.select(array_append(df.c1, "x")).collect() [Row(array_append(c1, x)=['b', 'a', 'c', 'x'])] - """ + """ # noqa: D205 return _invoke_function("list_append", _to_column_expr(col), _get_expr(value)) @@ -912,7 +912,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'b', 'd', 'a'])] >>> df.select(array_insert(df.data, 5, "hello").alias("data")).collect() [Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])] - """ + """ # noqa: D205 pos = _get_expr(pos) arr = _to_column_expr(arr) # Depending on if the position is positive or not, we need to interpret it differently. 
@@ -992,7 +992,7 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] >>> df.select(array_contains(df.data, lit("a"))).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] - """ + """ # noqa: D205 value = _get_expr(value) return _invoke_function("array_contains", _to_column_expr(col), value) @@ -1051,7 +1051,7 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_intersect(df.c1, df.c2)).collect() [Row(array_intersect(c1, c2)=['a', 'c'])] - """ + """ # noqa: D205 return _invoke_function_over_columns("array_intersect", col1, col2) @@ -1082,7 +1082,7 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])] - """ + """ # noqa: D205 return _invoke_function_over_columns("array_distinct", _invoke_function_over_columns("array_concat", col1, col2)) @@ -1265,7 +1265,7 @@ def mean(col: "ColumnOrName") -> Column: +-------+ | 4.5| +-------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("mean", col) @@ -1479,7 +1479,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C +---------------+ | 3| +---------------+ - """ + """ # noqa: D205 if rsd is not None: msg = "rsd is not supported by DuckDB" raise ValueError(msg) @@ -1588,7 +1588,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) >>> df.select(concat_ws("-", df.s, df.d).alias("s")).collect() [Row(s='abcd-123')] - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("concat_ws", ConstantExpression(sep), *cols) @@ -1854,7 +1854,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(equal_null(df.a, df.b).alias("r")).collect() [Row(r=True), Row(r=False)] - """ + """ # noqa: D205 if isinstance(col1, str): col1 = col(col1) @@ -1901,7 +1901,7 @@ def flatten(col: "ColumnOrName") -> Column: |[1, 2, 3, 4, 5, 6]| | NULL| +------------------+ - """ + """ # noqa: D205 col = _to_column_expr(col) contains_null = _list_contains_null(col) return Column(CaseExpression(contains_null, None).otherwise(FunctionExpression("flatten", col))) @@ -2077,7 +2077,7 @@ def char(col: "ColumnOrName") -> Column: +--------+ | A| +--------+ - """ + """ # noqa: D205 col = _to_column_expr(col) return Column(FunctionExpression("chr", CaseExpression(col > 256, col % 256).otherwise(col))) @@ -2110,7 +2110,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(corr("a", "b").alias("c")).collect() [Row(c=1.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("corr", col1, col2) @@ -2183,7 +2183,7 @@ def negative(col: "ColumnOrName") -> Column: | -1| | -2| +------------+ - """ + """ # noqa: D205 return abs(col) * -1 @@ -2364,7 +2364,7 @@ def rand(seed: Optional[int] = None) -> Column: | 0|1.8575681106759028| | 1|1.5288056527339444| +---+------------------+ - """ + """ # noqa: D205 if seed is not None: # Maybe call setseed just before but how do we know when it is executed? 
msg = "Seed is not yet implemented" @@ -2451,7 +2451,7 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: [Row(d=0)] >>> df.select(regexp_count("str", col("regexp")).alias("d")).collect() [Row(d=3)] - """ + """ # noqa: D205 return _invoke_function_over_columns("len", _invoke_function_over_columns("regexp_extract_all", str, regexp)) @@ -2489,7 +2489,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: >>> df = spark.createDataFrame([("aaaac",)], ["str"]) >>> df.select(regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() [Row(d='')] - """ + """ # noqa: D205 return _invoke_function( "regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx) ) @@ -2526,7 +2526,7 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona [Row(d=['200', '400'])] >>> df.select(regexp_extract_all("str", col("regexp")).alias("d")).collect() [Row(d=['100', '300'])] - """ + """ # noqa: D205 if idx is None: idx = 1 return _invoke_function( @@ -2613,7 +2613,7 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: [Row(d=None)] >>> df.select(regexp_substr("str", col("regexp")).alias("d")).collect() [Row(d='1')] - """ + """ # noqa: D205 return Column( FunctionExpression( "nullif", @@ -2689,7 +2689,7 @@ def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["Column >>> df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) >>> df2.select(sequence("C1", "C2", "C3").alias("r")).collect() [Row(r=[4, 2, 0, -2, -4])] - """ + """ # noqa: D205 if step is None: return _invoke_function_over_columns("generate_series", start, stop) else: @@ -2843,7 +2843,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: +----------------+ | [61 62 63 64]| +----------------+ - """ + """ # noqa: D205 if charset != "UTF-8": msg = "Only UTF-8 charset is supported right now" raise ContributionsAcceptedError(msg) @@ -2869,7 +2869,7 @@ def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ["a", "b"]) >>> df.select(find_in_set(df.a, df.b).alias("r")).collect() [Row(r=3)] - """ + """ # noqa: D205 str_array = _to_column_expr(str_array) str = _to_column_expr(str) return Column( @@ -3019,7 +3019,7 @@ def greatest(*cols: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] - """ + """ # noqa: D205 if len(cols) < 2: msg = "greatest should take at least 2 columns" raise ValueError(msg) @@ -3052,7 +3052,7 @@ def least(*cols: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() [Row(least=1)] - """ + """ # noqa: D205 if len(cols) < 2: msg = "least should take at least 2 columns" raise ValueError(msg) @@ -3244,7 +3244,7 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("suffix", str, suffix) @@ -3296,7 +3296,7 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: +----------------+----------------+ | true| false| +----------------+----------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("starts_with", str, prefix) @@ -3324,7 +3324,7 @@ def length(col: "ColumnOrName") -> Column: -------- >>> 
spark.createDataFrame([("ABC ",)], ["a"]).select(length("a").alias("length")).collect() [Row(length=4)] - """ + """ # noqa: D205 return _invoke_function_over_columns("length", col) @@ -3370,7 +3370,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1|NULL| 1.0| |NULL| 2| 0.0| +----+----+----------------+ - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return Column(CoalesceOperator(*cols)) @@ -3400,7 +3400,7 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] - """ + """ # noqa: D205 return coalesce(col1, col2) @@ -3460,7 +3460,7 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: | 8| | 1| +------------+ - """ + """ # noqa: D205 return coalesce(col1, col2) @@ -3554,7 +3554,7 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: |Alice|3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043| |Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961| +-----+----------------------------------------------------------------+ - """ + """ # noqa: D205 if numBits not in {224, 256, 384, 512, 0}: msg = "numBits should be one of {224, 256, 384, 512, 0}" raise ValueError(msg) @@ -3586,7 +3586,7 @@ def curdate() -> Column: +--------------+ | 2022-08-26| +--------------+ - """ + """ # noqa: D205 return _invoke_function("today") @@ -3613,7 +3613,7 @@ def current_date() -> Column: +--------------+ | 2022-08-26| +--------------+ - """ + """ # noqa: D205 return curdate() @@ -4018,7 +4018,7 @@ def dayofweek(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofweek("dt").alias("day")).collect() [Row(day=4)] - """ + """ # noqa: D205 return _invoke_function_over_columns("dayofweek", col) + lit(1) @@ -4209,7 +4209,7 @@ def weekofyear(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(weekofyear(df.dt).alias("week")).collect() [Row(week=15)] - """ + """ # noqa: D205 return _invoke_function_over_columns("weekofyear", col) @@ -4367,7 +4367,7 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_pop("a", "b").alias("c")).collect() [Row(c=0.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("covar_pop", col1, col2) @@ -4399,7 +4399,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_samp("a", "b").alias("c")).collect() [Row(c=0.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("covar_samp", col1, col2) @@ -4545,7 +4545,7 @@ def degrees(col: "ColumnOrName") -> Column: >>> df = spark.range(1) >>> df.select(degrees(lit(math.pi))).first() Row(DEGREES(3.14159...)=180.0) - """ + """ # noqa: D205 return _invoke_function_over_columns("degrees", col) @@ -4573,7 +4573,7 @@ def radians(col: "ColumnOrName") -> Column: >>> df = spark.range(1) >>> df.select(radians(lit(180))).first() Row(RADIANS(180)=3.14159...) 
- """ + """ # noqa: D205 return _invoke_function_over_columns("radians", col) @@ -4698,7 +4698,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(round("a", 0).alias("r")).collect() [Row(r=3.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("round", col, lit(scale)) @@ -4727,7 +4727,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(bround("a", 0).alias("r")).collect() [Row(r=2.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("round_even", col, lit(scale)) @@ -4796,7 +4796,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: +----------------------+ | a| +----------------------+ - """ + """ # noqa: D205 index = ConstantExpression(index) if isinstance(index, int) else _to_column_expr(index) # Spark uses 0-indexing, DuckDB 1-indexing index = index + 1 @@ -5029,7 +5029,7 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col [Row(next_month=datetime.date(2015, 6, 8))] >>> df.select(add_months("dt", -2).alias("prev_month")).collect() [Row(prev_month=datetime.date(2015, 2, 8))] - """ + """ # noqa: D205 months = ConstantExpression(months) if isinstance(months, int) else _to_column_expr(months) return _invoke_function("date_add", _to_column_expr(start), FunctionExpression("to_months", months)).cast("date") @@ -5064,7 +5064,7 @@ def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[s [Row(joined='a,b,c'), Row(joined='a')] >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() [Row(joined='a,b,c'), Row(joined='a,NULL')] - """ + """ # noqa: D205 col = _to_column_expr(col) if null_replacement is not None: col = FunctionExpression( @@ -5111,7 +5111,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) >>> df.select(array_position(df.data, "a")).collect() [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] - """ + """ # noqa: D205 return Column( CoalesceOperator( _to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0) @@ -5143,7 +5143,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ # noqa: D205 return _invoke_function_over_columns("list_prepend", lit(value), col) @@ -5247,7 +5247,7 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum ... ).alias("r") ... ).collect() [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] - """ + """ # noqa: D205 if comparator is not None: msg = "comparator is not yet supported" raise ContributionsAcceptedError(msg) @@ -5286,7 +5286,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias("r")).collect() [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] - """ + """ # noqa: D205 if asc: order = "ASC" null_order = "NULLS FIRST" @@ -5381,7 +5381,7 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO ... 
) >>> df.select(split_part(df.a, df.b, df.c).alias("r")).collect() [Row(r='13')] - """ + """ # noqa: D205 src = _to_column_expr(src) delimiter = _to_column_expr(delimiter) partNum = _to_column_expr(partNum) @@ -5422,7 +5422,7 @@ def stddev_samp(col: "ColumnOrName") -> Column: +------------------+ |1.8708286933869...| +------------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("stddev_samp", col) @@ -5513,7 +5513,7 @@ def stddev_pop(col: "ColumnOrName") -> Column: +-----------------+ |1.707825127659...| +-----------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("stddev_pop", col) @@ -5572,7 +5572,7 @@ def var_samp(col: "ColumnOrName") -> Column: +------------+ | 3.5| +------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("var_samp", col) @@ -5701,7 +5701,7 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_date(df.t, "yyyy-MM-dd HH:mm:ss").alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] - """ + """ # noqa: D205 return _to_date_or_timestamp(col, _types.DateType(), format) @@ -5739,7 +5739,7 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_timestamp(df.t, "yyyy-MM-dd HH:mm:ss").alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ + """ # noqa: D205 return _to_date_or_timestamp(col, _types.TimestampNTZType(), format) @@ -5770,7 +5770,7 @@ def to_timestamp_ltz( >>> df.select(to_timestamp_ltz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] - """ + """ # noqa: D205 return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format) @@ -5801,7 +5801,7 @@ def to_timestamp_ntz( >>> df.select(to_timestamp_ntz(df.e).alias("r")).collect() ... 
# doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] - """ + """ # noqa: D205 return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format) @@ -5824,7 +5824,7 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] >>> df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ + """ # noqa: D205 if format is None: format = lit(["%Y-%m-%d", "%Y-%m-%d %H:%M:%S"]) @@ -5881,7 +5881,7 @@ def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName +------------------------+ | k SQL| +------------------------+ - """ + """ # noqa: D205 if len is not None: return _invoke_function_over_columns("substring", str, pos, len) else: @@ -5939,7 +5939,7 @@ def unix_millis(col: "ColumnOrName") -> Column: >>> df.select(unix_millis(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000)] >>> spark.conf.unset("spark.sql.session.timeZone") - """ + """ # noqa: D205 return _unix_diff(col, "milliseconds") @@ -5956,7 +5956,7 @@ def unix_seconds(col: "ColumnOrName") -> Column: >>> df.select(unix_seconds(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400)] >>> spark.conf.unset("spark.sql.session.timeZone") - """ + """ # noqa: D205 return _unix_diff(col, "seconds") @@ -5980,7 +5980,7 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ["x", "y"]) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() [Row(overlap=True), Row(overlap=False)] - """ + """ # noqa: D205 a1 = _to_column_expr(a1) a2 = _to_column_expr(a2) @@ -6045,7 +6045,7 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: | | |-- vals1: long (nullable = true) | | |-- vals2: long (nullable = true) | | |-- vals3: long (nullable = true) - """ + """ # noqa: D205 return _invoke_function_over_columns("list_zip", *cols) @@ -6084,7 +6084,7 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: ... ) >>> df.select(substring(df.s, 1, 2).alias("s")).collect() [Row(s='ab')] - """ + """ # noqa: D205 return _invoke_function( "substring", _to_column_expr(str), @@ -6130,7 +6130,7 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("contains", left, right) @@ -6157,7 +6157,7 @@ def reverse(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ["data"]) >>> df.select(reverse(df.data).alias("r")).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] - """ + """ # noqa: D205 return _invoke_function("reverse", _to_column_expr(col)) @@ -6197,7 +6197,7 @@ def concat(*cols: "ColumnOrName") -> Column: [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] >>> df DataFrame[arr: array] - """ + """ # noqa: D205 return _invoke_function_over_columns("concat", *cols) @@ -6237,7 +6237,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: ... ) >>> df.select(instr(df.s, "b").alias("s")).collect() [Row(s=2)] - """ + """ # noqa: D205 return _invoke_function("instr", _to_column_expr(str), ConstantExpression(substr)) @@ -6278,5 +6278,5 @@ def broadcast(df: "DataFrame") -> "DataFrame": dataset to all the worker nodes. However, DuckDB operates on a single-node architecture . 
As a result, the function simply returns the input DataFrame without applying any modifications or optimizations, since broadcasting is not applicable in the DuckDB context. - """ + """ # noqa: D205 return df diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index 7aa9eb11..ab8e89cf 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -76,7 +76,7 @@ class GroupedData: """A set of methods for aggregations on a :class:`DataFrame`, created by :func:`DataFrame.groupBy`. - """ + """ # noqa: D205 def __init__(self, grouping: Grouping, df: DataFrame) -> None: self._grouping = grouping diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index fa961eb1..606f792c 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -140,7 +140,7 @@ def typeName(cls) -> str: class AtomicType(DataType): """An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. - """ + """ # noqa: D205 class NumericType(AtomicType): @@ -836,7 +836,7 @@ def add( >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - """ + """ # noqa: D205 if isinstance(field, StructField): self.fields.append(field) self.names.append(field.name) @@ -996,7 +996,7 @@ def module(cls) -> str: def scalaUDT(cls) -> str: """The class name of the paired Scala UDT (could be '', if there is no corresponding one). - """ + """ # noqa: D205 return "" def needConversion(self) -> bool: @@ -1125,7 +1125,7 @@ class Row(tuple): >>> row2 = Row(name="Alice", age=11) >>> row1 == row2 True - """ + """ # noqa: D205 @overload def __new__(cls, *args: str) -> "Row": ... diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 40759e9c..48315109 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -129,7 +129,7 @@ def create_parameters_from_paths(paths, root_dir: pathlib.Path, config: pytest.C def scan_for_test_scripts(root_dir: pathlib.Path, config: pytest.Config) -> typing.Iterator[typing.Any]: """Scans for .test files in the given directory and its subdirectories. Returns an iterator of pytest parameters (argument, id and marks). - """ + """ # noqa: D205 # TODO: Add tests from extensions test_script_extensions = [".test", ".test_slow", ".test_coverage"] it = itertools.chain.from_iterable(root_dir.rglob(f"*{ext}") for ext in test_script_extensions) @@ -169,7 +169,7 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, start_offset defaults to 0. end_offset defaults to and is capped to the last test index. start_offset_percentage and end_offset_percentage are used to calculate the start and end offsets based on the total number of tests. This is done in a way that a test run to 25% and another test run starting at 25% do not overlap by excluding the 25th percent test. 
- """ + """ # noqa: D205 start_offset = config.getoption("start_offset") end_offset = config.getoption("end_offset") start_offset_percentage = config.getoption("start_offset_percentage") From c0a94d25cc9a52f324b19812fe49ec7e4c756e7a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:31:52 +0200 Subject: [PATCH 023/135] Ruff config: will not check sqlogic --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 02e85f5e..b596bf60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -316,7 +316,7 @@ line-length = 120 indent-width = 4 target-version = "py39" fix = true -exclude = ['external/duckdb'] +exclude = ['external/duckdb', 'sqllogic'] [tool.ruff.lint] fixable = ["ALL"] From 5cea343428ba33a7ad307ab83a92cff123c2cbf1 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:33:29 +0200 Subject: [PATCH 024/135] Ruff config: will not check scripts for D --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b596bf60..525cfd2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -363,6 +363,10 @@ strict = true # No need for package, module, class, function, init etc docstrings in tests 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107' ] +"scripts/**.py" = [ + # No need for package, module, class, function, init etc docstrings in scripts + 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107', 'D205' +] [tool.ruff.format] docstring-code-format = true From a1675c86e22382bbf8ec9f0cd8ea3ea68e7c4c6f Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:34:42 +0200 Subject: [PATCH 025/135] Ruff noqa D10x: ignore docstring issues in existing code --- duckdb/__init__.py | 2 +- duckdb/bytes_io_wrapper.py | 10 +- duckdb/experimental/__init__.py | 2 +- duckdb/experimental/spark/__init__.py | 2 +- duckdb/experimental/spark/conf.py | 28 +-- duckdb/experimental/spark/context.py | 60 ++--- .../spark/errors/error_classes.py | 2 +- .../spark/errors/exceptions/__init__.py | 2 +- .../spark/errors/exceptions/base.py | 6 +- duckdb/experimental/spark/errors/utils.py | 4 +- duckdb/experimental/spark/exception.py | 4 +- duckdb/experimental/spark/sql/__init__.py | 2 +- duckdb/experimental/spark/sql/catalog.py | 24 +- duckdb/experimental/spark/sql/column.py | 28 +-- duckdb/experimental/spark/sql/conf.py | 14 +- duckdb/experimental/spark/sql/dataframe.py | 36 +-- duckdb/experimental/spark/sql/functions.py | 30 +-- duckdb/experimental/spark/sql/group.py | 14 +- duckdb/experimental/spark/sql/readwriter.py | 22 +- duckdb/experimental/spark/sql/session.py | 52 ++-- duckdb/experimental/spark/sql/streaming.py | 14 +- duckdb/experimental/spark/sql/type_utils.py | 8 +- duckdb/experimental/spark/sql/types.py | 228 +++++++++--------- duckdb/experimental/spark/sql/udf.py | 12 +- duckdb/filesystem.py | 8 +- duckdb/functional/__init__.py | 2 +- duckdb/polars_io.py | 2 +- duckdb/query_graph/__main__.py | 48 ++-- duckdb/typing/__init__.py | 2 +- duckdb/udf.py | 2 +- duckdb/value/__init__.py | 1 + duckdb/value/constant/__init__.py | 120 ++++----- duckdb_packaging/pypi_cleanup.py | 8 +- 33 files changed, 400 insertions(+), 399 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index 73fcbbd2..8d6d68aa 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -1,4 +1,4 @@ -# Modules +# Modules # noqa: D104 from importlib.metadata import version from _duckdb import __version__ as duckdb_version diff --git a/duckdb/bytes_io_wrapper.py 
b/duckdb/bytes_io_wrapper.py index 763fd8b7..9851ad65 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -1,4 +1,4 @@ -from io import StringIO, TextIOBase +from io import StringIO, TextIOBase # noqa: D100 from typing import Any, Union """ @@ -36,10 +36,10 @@ """ -class BytesIOWrapper: +class BytesIOWrapper: # noqa: D101 # Wrapper that wraps a StringIO buffer and reads bytes from it # Created for compat with pyarrow read_csv - def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: + def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: # noqa: D107 self.buffer = buffer self.encoding = encoding # Because a character can be represented by more than 1 byte, @@ -48,10 +48,10 @@ def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") # overflow to the front of the bytestring the next time reading is performed self.overflow = b"" - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, attr: str) -> Any: # noqa: D105 return getattr(self.buffer, attr) - def read(self, n: Union[int, None] = -1) -> bytes: + def read(self, n: Union[int, None] = -1) -> bytes: # noqa: D102 assert self.buffer is not None bytestring = self.buffer.read(n).encode(self.encoding) # When n=-1/n greater than remaining bytes: Read entire file/rest of file diff --git a/duckdb/experimental/__init__.py b/duckdb/experimental/__init__.py index a88a6170..1b5ee51b 100644 --- a/duckdb/experimental/__init__.py +++ b/duckdb/experimental/__init__.py @@ -1,3 +1,3 @@ -from . import spark +from . import spark # noqa: D104 __all__ = spark.__all__ diff --git a/duckdb/experimental/spark/__init__.py b/duckdb/experimental/spark/__init__.py index bdde2ef8..7e56d4b1 100644 --- a/duckdb/experimental/spark/__init__.py +++ b/duckdb/experimental/spark/__init__.py @@ -1,4 +1,4 @@ -from ._globals import _NoValue +from ._globals import _NoValue # noqa: D104 from .conf import SparkConf from .context import SparkContext from .exception import ContributionsAcceptedError diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index ea1153b4..974115d6 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -1,45 +1,45 @@ -from typing import Optional +from typing import Optional # noqa: D100 from duckdb.experimental.spark.exception import ContributionsAcceptedError -class SparkConf: - def __init__(self) -> None: +class SparkConf: # noqa: D101 + def __init__(self) -> None: # noqa: D107 raise NotImplementedError - def contains(self, key: str) -> bool: + def contains(self, key: str) -> bool: # noqa: D102 raise ContributionsAcceptedError - def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: + def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError - def getAll(self) -> list[tuple[str, str]]: + def getAll(self) -> list[tuple[str, str]]: # noqa: D102 raise ContributionsAcceptedError - def set(self, key: str, value: str) -> "SparkConf": + def set(self, key: str, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": + def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setAppName(self, value: str) -> "SparkConf": + def setAppName(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setExecutorEnv( + def 
setExecutorEnv( # noqa: D102 self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[list[tuple[str, str]]] = None ) -> "SparkConf": raise ContributionsAcceptedError - def setIfMissing(self, key: str, value: str) -> "SparkConf": + def setIfMissing(self, key: str, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setMaster(self, value: str) -> "SparkConf": + def setMaster(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setSparkHome(self, value: str) -> "SparkConf": + def setSparkHome(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def toDebugString(self) -> str: + def toDebugString(self) -> str: # noqa: D102 raise ContributionsAcceptedError diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index 9f1b4155..9835fcea 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional # noqa: D100 import duckdb from duckdb import DuckDBPyConnection @@ -6,37 +6,37 @@ from duckdb.experimental.spark.exception import ContributionsAcceptedError -class SparkContext: - def __init__(self, master: str) -> None: +class SparkContext: # noqa: D101 + def __init__(self, master: str) -> None: # noqa: D107 self._connection = duckdb.connect(":memory:") # This aligns the null ordering with Spark. self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'") @property - def connection(self) -> DuckDBPyConnection: + def connection(self) -> DuckDBPyConnection: # noqa: D102 return self._connection - def stop(self) -> None: + def stop(self) -> None: # noqa: D102 self._connection.close() @classmethod - def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": + def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": # noqa: D102 raise ContributionsAcceptedError @classmethod - def setSystemProperty(cls, key: str, value: str) -> None: + def setSystemProperty(cls, key: str, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError @property - def applicationId(self) -> str: + def applicationId(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def defaultMinPartitions(self) -> int: + def defaultMinPartitions(self) -> int: # noqa: D102 raise ContributionsAcceptedError @property - def defaultParallelism(self) -> int: + def defaultParallelism(self) -> int: # noqa: D102 raise ContributionsAcceptedError # @property @@ -44,30 +44,30 @@ def defaultParallelism(self) -> int: # raise ContributionsAcceptedError @property - def startTime(self) -> str: + def startTime(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def uiWebUrl(self) -> str: + def uiWebUrl(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def version(self) -> str: + def version(self) -> str: # noqa: D102 raise ContributionsAcceptedError - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 raise ContributionsAcceptedError # def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None) -> 'Accumulator[T]': # pass - def addArchive(self, path: str) -> None: + def addArchive(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def addFile(self, path: str, recursive: bool = False) -> None: + def addFile(self, path: str, recursive: bool = False) -> None: # noqa: D102 raise ContributionsAcceptedError - def 
addPyFile(self, path: str) -> None: + def addPyFile(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError # def binaryFiles(self, path: str, minPartitions: Optional[int] = None) -> duckdb.experimental.spark.rdd.RDD[typing.Tuple[str, bytes]]: @@ -79,25 +79,25 @@ def addPyFile(self, path: str) -> None: # def broadcast(self, value: ~T) -> 'Broadcast[T]': # pass - def cancelAllJobs(self) -> None: + def cancelAllJobs(self) -> None: # noqa: D102 raise ContributionsAcceptedError - def cancelJobGroup(self, groupId: str) -> None: + def cancelJobGroup(self, groupId: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def dump_profiles(self, path: str) -> None: + def dump_profiles(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError # def emptyRDD(self) -> duckdb.experimental.spark.rdd.RDD[typing.Any]: # pass - def getCheckpointDir(self) -> Optional[str]: + def getCheckpointDir(self) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError - def getConf(self) -> SparkConf: + def getConf(self) -> SparkConf: # noqa: D102 raise ContributionsAcceptedError - def getLocalProperty(self, key: str) -> Optional[str]: + def getLocalProperty(self, key: str) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError # def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: @@ -127,25 +127,25 @@ def getLocalProperty(self, key: str) -> Optional[str]: # def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: # pass - def setCheckpointDir(self, dirName: str) -> None: + def setCheckpointDir(self, dirName: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setJobDescription(self, value: str) -> None: + def setJobDescription(self, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: + def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: # noqa: D102 raise ContributionsAcceptedError - def setLocalProperty(self, key: str, value: str) -> None: + def setLocalProperty(self, key: str, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setLogLevel(self, logLevel: str) -> None: + def setLogLevel(self, logLevel: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def show_profiles(self) -> None: + def show_profiles(self) -> None: # noqa: D102 raise ContributionsAcceptedError - def sparkUser(self) -> str: + def sparkUser(self) -> str: # noqa: D102 raise ContributionsAcceptedError # def statusTracker(self) -> duckdb.experimental.spark.status.StatusTracker: diff --git a/duckdb/experimental/spark/errors/error_classes.py b/duckdb/experimental/spark/errors/error_classes.py index 256fb644..55cea14d 100644 --- a/duckdb/experimental/spark/errors/error_classes.py +++ b/duckdb/experimental/spark/errors/error_classes.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. 
diff --git a/duckdb/experimental/spark/errors/exceptions/__init__.py b/duckdb/experimental/spark/errors/exceptions/__init__.py index cce3acad..edd0e7e1 100644 --- a/duckdb/experimental/spark/errors/exceptions/__init__.py +++ b/duckdb/experimental/spark/errors/exceptions/__init__.py @@ -1,4 +1,4 @@ -# +# # noqa: D104 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 0b2c6a43..2eae2a19 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import Optional, cast # noqa: D100 from ..utils import ErrorClassesReader @@ -6,7 +6,7 @@ class PySparkException(Exception): """Base Exception for handling errors generated from PySpark.""" - def __init__( + def __init__( # noqa: D107 self, message: Optional[str] = None, # The error class, decides the message format, must be one of the valid options listed in 'error_classes.py' @@ -69,7 +69,7 @@ def getSqlState(self) -> None: """ return None - def __str__(self) -> str: + def __str__(self) -> str: # noqa: D105 if self.getErrorClass() is not None: return f"[{self.getErrorClass()}] {self.message}" else: diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index 984504a4..8a71f3b0 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -23,7 +23,7 @@ class ErrorClassesReader: """A reader to load error information from error_classes.py.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 self.error_info_map = ERROR_CLASSES_MAP def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 24c4f291..3973d9c4 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,10 +1,10 @@ -class ContributionsAcceptedError(NotImplementedError): +class ContributionsAcceptedError(NotImplementedError): # noqa: D100 """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. 
""" # noqa: D205 - def __init__(self, message=None) -> None: + def __init__(self, message=None) -> None: # noqa: D107 doc = self.__class__.__doc__ if message: doc = message + "\n" + doc diff --git a/duckdb/experimental/spark/sql/__init__.py b/duckdb/experimental/spark/sql/__init__.py index 9ae09308..418273f0 100644 --- a/duckdb/experimental/spark/sql/__init__.py +++ b/duckdb/experimental/spark/sql/__init__.py @@ -1,4 +1,4 @@ -from .catalog import Catalog +from .catalog import Catalog # noqa: D104 from .conf import RuntimeConfig from .dataframe import DataFrame from .readwriter import DataFrameWriter diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index 8e510fdf..27e6fbb0 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -1,15 +1,15 @@ -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional # noqa: D100 from .session import SparkSession -class Database(NamedTuple): +class Database(NamedTuple): # noqa: D101 name: str description: Optional[str] locationUri: str -class Table(NamedTuple): +class Table(NamedTuple): # noqa: D101 name: str database: Optional[str] description: Optional[str] @@ -17,7 +17,7 @@ class Table(NamedTuple): isTemporary: bool -class Column(NamedTuple): +class Column(NamedTuple): # noqa: D101 name: str description: Optional[str] dataType: str @@ -26,18 +26,18 @@ class Column(NamedTuple): isBucket: bool -class Function(NamedTuple): +class Function(NamedTuple): # noqa: D101 name: str description: Optional[str] className: str isTemporary: bool -class Catalog: - def __init__(self, session: SparkSession) -> None: +class Catalog: # noqa: D101 + def __init__(self, session: SparkSession) -> None: # noqa: D107 self._session = session - def listDatabases(self) -> list[Database]: + def listDatabases(self) -> list[Database]: # noqa: D102 res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall() def transform_to_database(x) -> Database: @@ -46,7 +46,7 @@ def transform_to_database(x) -> Database: databases = [transform_to_database(x) for x in res] return databases - def listTables(self) -> list[Table]: + def listTables(self) -> list[Table]: # noqa: D102 res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall() def transform_to_table(x) -> Table: @@ -55,7 +55,7 @@ def transform_to_table(x) -> Table: tables = [transform_to_table(x) for x in res] return tables - def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: + def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: # noqa: D102 query = f""" select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}' """ @@ -69,10 +69,10 @@ def transform_to_column(x) -> Column: columns = [transform_to_column(x) for x in res] return columns - def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: + def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: # noqa: D102 raise NotImplementedError - def setCurrentDatabase(self, dbName: str) -> None: + def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102 raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index ea6cd8a8..bc84365a 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, 
Union, cast +from typing import TYPE_CHECKING, Any, Callable, Union, cast # noqa: D100 from ..exception import ContributionsAcceptedError from .types import DataType @@ -93,11 +93,11 @@ class Column: .. versionadded:: 1.3.0 """ - def __init__(self, expr: Expression) -> None: + def __init__(self, expr: Expression) -> None: # noqa: D107 self.expr = expr # arithmetic operators - def __neg__(self) -> "Column": + def __neg__(self) -> "Column": # noqa: D105 return Column(-self.expr) # `and`, `or`, `not` cannot be overloaded in Python, @@ -205,10 +205,10 @@ def __getattr__(self, item: Any) -> "Column": raise AttributeError(msg) return self[item] - def alias(self, alias: str): + def alias(self, alias: str): # noqa: D102 return Column(self.expr.alias(alias)) - def when(self, condition: "Column", value: Any): + def when(self, condition: "Column", value: Any): # noqa: D102 if not isinstance(condition, Column): msg = "condition should be a Column" raise TypeError(msg) @@ -216,12 +216,12 @@ def when(self, condition: "Column", value: Any): expr = self.expr.when(condition.expr, v) return Column(expr) - def otherwise(self, value: Any): + def otherwise(self, value: Any): # noqa: D102 v = _get_expr(value) expr = self.expr.otherwise(v) return Column(expr) - def cast(self, dataType: Union[DataType, str]) -> "Column": + def cast(self, dataType: Union[DataType, str]) -> "Column": # noqa: D102 if isinstance(dataType, str): # Try to construct a default DuckDBPyType from it internal_type = DuckDBPyType(dataType) @@ -229,7 +229,7 @@ def cast(self, dataType: Union[DataType, str]) -> "Column": internal_type = dataType.duckdb_type return Column(self.expr.cast(internal_type)) - def isin(self, *cols: Any) -> "Column": + def isin(self, *cols: Any) -> "Column": # noqa: D102 if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list cols = cast("tuple", cols[0]) @@ -345,20 +345,20 @@ def __ne__( # type: ignore[override] nulls_first = _unary_op("nulls_first") nulls_last = _unary_op("nulls_last") - def asc_nulls_first(self) -> "Column": + def asc_nulls_first(self) -> "Column": # noqa: D102 return self.asc().nulls_first() - def asc_nulls_last(self) -> "Column": + def asc_nulls_last(self) -> "Column": # noqa: D102 return self.asc().nulls_last() - def desc_nulls_first(self) -> "Column": + def desc_nulls_first(self) -> "Column": # noqa: D102 return self.desc().nulls_first() - def desc_nulls_last(self) -> "Column": + def desc_nulls_last(self) -> "Column": # noqa: D102 return self.desc().nulls_last() - def isNull(self) -> "Column": + def isNull(self) -> "Column": # noqa: D102 return Column(self.expr.isnull()) - def isNotNull(self) -> "Column": + def isNotNull(self) -> "Column": # noqa: D102 return Column(self.expr.isnotnull()) diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index 8ab9fa38..e44f2566 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ b/duckdb/experimental/spark/sql/conf.py @@ -1,23 +1,23 @@ -from typing import Optional, Union +from typing import Optional, Union # noqa: D100 from duckdb import DuckDBPyConnection from duckdb.experimental.spark._globals import _NoValue, _NoValueType -class RuntimeConfig: - def __init__(self, connection: DuckDBPyConnection) -> None: +class RuntimeConfig: # noqa: D101 + def __init__(self, connection: DuckDBPyConnection) -> None: # noqa: D107 self._connection = connection - def set(self, key: str, value: str) -> None: + def set(self, key: str, value: str) -> None: # noqa: D102 raise NotImplementedError - 
def isModifiable(self, key: str) -> bool: + def isModifiable(self, key: str) -> bool: # noqa: D102 raise NotImplementedError - def unset(self, key: str) -> None: + def unset(self, key: str) -> None: # noqa: D102 raise NotImplementedError - def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: + def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: # noqa: D102 raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 99da92ec..8e83822b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,4 +1,4 @@ -import uuid +import uuid # noqa: D100 from functools import reduce from keyword import iskeyword from typing import ( @@ -32,18 +32,18 @@ from .functions import _to_column_expr, col, lit -class DataFrame: - def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: +class DataFrame: # noqa: D101 + def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: # noqa: D107 self.relation = relation self.session = session self._schema = None if self.relation is not None: self._schema = duckdb_to_spark_schema(self.relation.columns, self.relation.types) - def show(self, **kwargs) -> None: + def show(self, **kwargs) -> None: # noqa: D102 self.relation.show() - def toPandas(self) -> "PandasDataFrame": + def toPandas(self) -> "PandasDataFrame": # noqa: D102 return self.relation.df() def toArrow(self) -> "pa.Table": @@ -103,10 +103,10 @@ def createOrReplaceTempView(self, name: str) -> None: """ self.relation.create_view(name, True) - def createGlobalTempView(self, name: str) -> None: + def createGlobalTempView(self, name: str) -> None: # noqa: D102 raise NotImplementedError - def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": + def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": # noqa: D102 if columnName not in self.relation: msg = f"DataFrame does not contain a column named {columnName}" raise ValueError(msg) @@ -119,7 +119,7 @@ def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumn(self, columnName: str, col: Column) -> "DataFrame": + def withColumn(self, columnName: str, col: Column) -> "DataFrame": # noqa: D102 if not isinstance(col, Column): raise PySparkTypeError( error_class="NOT_COLUMN", @@ -472,7 +472,7 @@ def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: An orderBy = sort - def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: + def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: # noqa: D102 if n is None: rs = self.head(1) return rs[0] if rs else None @@ -480,7 +480,7 @@ def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: first = head - def take(self, num: int) -> list[Row]: + def take(self, num: int) -> list[Row]: # noqa: D102 return self.limit(num).collect() def filter(self, condition: "ColumnOrName") -> "DataFrame": @@ -547,7 +547,7 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": where = filter - def select(self, *cols) -> "DataFrame": + def select(self, *cols) -> "DataFrame": # noqa: D102 cols = list(cols) if len(cols) == 1: cols = cols[0] @@ -574,7 +574,7 @@ def _ipython_key_completions_(self) -> list[str]: # when accessed in bracket notation, e.g. 
df['] return self.columns - def __dir__(self) -> list[str]: + def __dir__(self) -> list[str]: # noqa: D105 out = set(super().__dir__()) out.update(c for c in self.columns if c.isidentifier() and not iskeyword(c)) return sorted(out) @@ -792,7 +792,7 @@ def alias(self, alias: str) -> "DataFrame": assert isinstance(alias, str), "alias should be a string" return DataFrame(self.relation.set_alias(alias), self.session) - def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] + def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] # noqa: D102 exclude = [] for col in cols: if isinstance(col, str): @@ -809,7 +809,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] expr = StarExpression(exclude=exclude) return DataFrame(self.relation.select(expr), self.session) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return str(self.relation) def limit(self, num: int) -> "DataFrame": @@ -986,10 +986,10 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] groupby = groupBy @property - def write(self) -> DataFrameWriter: + def write(self) -> DataFrameWriter: # noqa: D102 return DataFrameWriter(self) - def printSchema(self): + def printSchema(self): # noqa: D102 raise ContributionsAcceptedError def union(self, other: "DataFrame") -> "DataFrame": @@ -1339,7 +1339,7 @@ def _cast_types(self, *types) -> "DataFrame": new_rel = self.relation.project(cast_expressions) return DataFrame(new_rel, self.session) - def toDF(self, *cols) -> "DataFrame": + def toDF(self, *cols) -> "DataFrame": # noqa: D102 existing_columns = self.relation.columns column_count = len(cols) if column_count != len(existing_columns): @@ -1350,7 +1350,7 @@ def toDF(self, *cols) -> "DataFrame": new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) - def collect(self) -> list[Row]: + def collect(self) -> list[Row]: # noqa: D102 columns = self.relation.columns result = self.relation.fetchall() diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 7ae923f4..30764fe1 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -1,4 +1,4 @@ -import warnings +import warnings # noqa: D100 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload from duckdb import ( @@ -30,7 +30,7 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: return _invoke_function(name, *cols) -def col(column: str): +def col(column: str): # noqa: D103 return Column(ColumnExpression(column)) @@ -90,7 +90,7 @@ def ucase(str: "ColumnOrName") -> Column: return upper(str) -def when(condition: "Column", value: Any) -> Column: +def when(condition: "Column", value: Any) -> Column: # noqa: D103 if not isinstance(condition, Column): msg = "condition should be a Column" raise TypeError(msg) @@ -103,7 +103,7 @@ def _inner_expr_or_val(val): return val.expr if isinstance(val, Column) else val -def struct(*cols: Column) -> Column: +def struct(*cols: Column) -> Column: # noqa: D103 return Column(FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols])) @@ -143,7 +143,7 @@ def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["Column return _invoke_function_over_columns("list_value", *cols) -def lit(col: Any) -> Column: +def lit(col: Any) -> Column: # noqa: D103 return col if isinstance(col, Column) else Column(ConstantExpression(col)) @@ -1680,7 +1680,7 @@ def 
ceil(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ceil", col) -def ceiling(col: "ColumnOrName") -> Column: +def ceiling(col: "ColumnOrName") -> Column: # noqa: D103 return ceil(col) @@ -1854,7 +1854,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(equal_null(df.a, df.b).alias("r")).collect() [Row(r=True), Row(r=False)] - """ # noqa: D205 + """ # noqa: D205, D415 if isinstance(col1, str): col1 = col(col1) @@ -2183,7 +2183,7 @@ def negative(col: "ColumnOrName") -> Column: | -1| | -2| +------------+ - """ # noqa: D205 + """ # noqa: D205, D415 return abs(col) * -1 @@ -3370,7 +3370,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1|NULL| 1.0| |NULL| 2| 0.0| +----+----+----------------+ - """ # noqa: D205 + """ # noqa: D205, D415 cols = [_to_column_expr(expr) for expr in cols] return Column(CoalesceOperator(*cols)) @@ -3400,7 +3400,7 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] - """ # noqa: D205 + """ # noqa: D205, D415 return coalesce(col1, col2) @@ -3460,7 +3460,7 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: | 8| | 1| +------------+ - """ # noqa: D205 + """ # noqa: D205, D415 return coalesce(col1, col2) @@ -5824,7 +5824,7 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] >>> df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ # noqa: D205 + """ # noqa: D205, D415 if format is None: format = lit(["%Y-%m-%d", "%Y-%m-%d %H:%M:%S"]) @@ -6130,7 +6130,7 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ # noqa: D205 + """ # noqa: D205, D415 return _invoke_function_over_columns("contains", left, right) @@ -6157,7 +6157,7 @@ def reverse(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ["data"]) >>> df.select(reverse(df.data).alias("r")).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] - """ # noqa: D205 + """ # noqa: D205, D415 return _invoke_function("reverse", _to_column_expr(col)) @@ -6197,7 +6197,7 @@ def concat(*cols: "ColumnOrName") -> Column: [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] >>> df DataFrame[arr: array] - """ # noqa: D205 + """ # noqa: D205, D415 return _invoke_function_over_columns("concat", *cols) diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index ab8e89cf..c4222749 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. 
@@ -51,8 +51,8 @@ def _api(self: "GroupedData", *cols: str) -> DataFrame: return _api -class Grouping: - def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: +class Grouping: # noqa: D101 + def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: # noqa: D107 self._type = "" self._cols = [_to_column_expr(x) for x in cols] if "special" in kwargs: @@ -61,11 +61,11 @@ def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: assert special in accepted_special self._type = special - def get_columns(self) -> str: + def get_columns(self) -> str: # noqa: D102 columns = ",".join([str(x) for x in self._cols]) return columns - def __str__(self) -> str: + def __str__(self) -> str: # noqa: D105 columns = self.get_columns() if self._type: return self._type + "(" + columns + ")" @@ -78,12 +78,12 @@ class GroupedData: """ # noqa: D205 - def __init__(self, grouping: Grouping, df: DataFrame) -> None: + def __init__(self, grouping: Grouping, df: DataFrame) -> None: # noqa: D107 self._grouping = grouping self._df = df self.session: SparkSession = df.session - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return str(self._df) def count(self) -> DataFrame: diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 714ed797..eb714833 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100 from ..errors import PySparkNotImplementedError, PySparkTypeError from ..exception import ContributionsAcceptedError @@ -12,15 +12,15 @@ from duckdb.experimental.spark.sql.session import SparkSession -class DataFrameWriter: - def __init__(self, dataframe: "DataFrame") -> None: +class DataFrameWriter: # noqa: D101 + def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107 self.dataframe = dataframe - def saveAsTable(self, table_name: str) -> None: + def saveAsTable(self, table_name: str) -> None: # noqa: D102 relation = self.dataframe.relation relation.create(table_name) - def parquet( + def parquet( # noqa: D102 self, path: str, mode: Optional[str] = None, @@ -35,7 +35,7 @@ def parquet( relation.write_parquet(path, compression=compression) - def csv( + def csv( # noqa: D102 self, path: str, mode: Optional[str] = None, @@ -86,11 +86,11 @@ def csv( ) -class DataFrameReader: - def __init__(self, session: "SparkSession") -> None: +class DataFrameReader: # noqa: D101 + def __init__(self, session: "SparkSession") -> None: # noqa: D107 self.session = session - def load( + def load( # noqa: D102 self, path: Optional[Union[str, list[str]]] = None, format: Optional[str] = None, @@ -127,7 +127,7 @@ def load( df = df.toDF(names) raise NotImplementedError - def csv( + def csv( # noqa: D102 self, path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, @@ -245,7 +245,7 @@ def csv( df = df.toDF(*names) return df - def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": + def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": # noqa: D102 input = list(paths) if len(input) != 1: msg = "Only single paths are supported for now" diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index 4b919446..8bb6e910 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -1,4 +1,4 @@ -import uuid +import 
uuid # noqa: D100 from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Optional, Union @@ -41,8 +41,8 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType): return new_data -class SparkSession: - def __init__(self, context: SparkContext) -> None: +class SparkSession: # noqa: D101 + def __init__(self, context: SparkContext) -> None: # noqa: D107 self.conn = context.connection self._context = context self._conf = RuntimeConfig(self.conn) @@ -121,7 +121,7 @@ def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> D df = df.toDF(*names) return df - def createDataFrame( + def createDataFrame( # noqa: D102 self, data: Union["PandasDataFrame", Iterable[Any]], schema: Optional[Union[StructType, list[str]]] = None, @@ -184,10 +184,10 @@ def createDataFrame( df = df.toDF(*names) return df - def newSession(self) -> "SparkSession": + def newSession(self) -> "SparkSession": # noqa: D102 return SparkSession(self._context) - def range( + def range( # noqa: D102 self, start: int, end: Optional[int] = None, @@ -203,24 +203,24 @@ def range( return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self) - def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: + def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: # noqa: D102 if kwargs: raise NotImplementedError relation = self.conn.sql(sqlQuery) return DataFrame(relation, self) - def stop(self) -> None: + def stop(self) -> None: # noqa: D102 self._context.stop() - def table(self, tableName: str) -> DataFrame: + def table(self, tableName: str) -> DataFrame: # noqa: D102 relation = self.conn.table(tableName) return DataFrame(relation, self) - def getActiveSession(self) -> "SparkSession": + def getActiveSession(self) -> "SparkSession": # noqa: D102 return self @property - def catalog(self) -> "Catalog": + def catalog(self) -> "Catalog": # noqa: D102 if not hasattr(self, "_catalog"): from duckdb.experimental.spark.sql.catalog import Catalog @@ -228,59 +228,59 @@ def catalog(self) -> "Catalog": return self._catalog @property - def conf(self) -> RuntimeConfig: + def conf(self) -> RuntimeConfig: # noqa: D102 return self._conf @property - def read(self) -> DataFrameReader: + def read(self) -> DataFrameReader: # noqa: D102 return DataFrameReader(self) @property - def readStream(self) -> DataStreamReader: + def readStream(self) -> DataStreamReader: # noqa: D102 return DataStreamReader(self) @property - def sparkContext(self) -> SparkContext: + def sparkContext(self) -> SparkContext: # noqa: D102 return self._context @property - def streams(self) -> Any: + def streams(self) -> Any: # noqa: D102 raise ContributionsAcceptedError @property - def udf(self) -> UDFRegistration: + def udf(self) -> UDFRegistration: # noqa: D102 return UDFRegistration(self) @property - def version(self) -> str: + def version(self) -> str: # noqa: D102 return "1.0.0" - class Builder: - def __init__(self) -> None: + class Builder: # noqa: D106 + def __init__(self) -> None: # noqa: D107 pass - def master(self, name: str) -> "SparkSession.Builder": + def master(self, name: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def appName(self, name: str) -> "SparkSession.Builder": + def appName(self, name: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def remote(self, url: str) -> "SparkSession.Builder": + def remote(self, url: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def getOrCreate(self) -> "SparkSession": + def getOrCreate(self) -> 
"SparkSession": # noqa: D102 context = SparkContext("__ignored__") return SparkSession(context) - def config( + def config( # noqa: D102 self, key: Optional[str] = None, value: Optional[Any] = None, conf: Optional[SparkConf] = None ) -> "SparkSession.Builder": return self - def enableHiveSupport(self) -> "SparkSession.Builder": + def enableHiveSupport(self) -> "SparkSession.Builder": # noqa: D102 # no-op return self diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 201b889b..08b7cc30 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union # noqa: D100 from .types import StructType @@ -10,20 +10,20 @@ OptionalPrimitiveType = Optional[PrimitiveType] -class DataStreamWriter: - def __init__(self, dataframe: "DataFrame") -> None: +class DataStreamWriter: # noqa: D101 + def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107 self.dataframe = dataframe - def toTable(self, table_name: str) -> None: + def toTable(self, table_name: str) -> None: # noqa: D102 # Should we register the dataframe or create a table from the contents? raise NotImplementedError -class DataStreamReader: - def __init__(self, session: "SparkSession") -> None: +class DataStreamReader: # noqa: D101 + def __init__(self, session: "SparkSession") -> None: # noqa: D107 self.session = session - def load( + def load( # noqa: D102 self, path: Optional[str] = None, format: Optional[str] = None, diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index 446eac97..1773eb9e 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import cast # noqa: D100 from duckdb.typing import DuckDBPyType @@ -74,7 +74,7 @@ } -def convert_nested_type(dtype: DuckDBPyType) -> DataType: +def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 id = dtype.id if id == "list" or id == "array": children = dtype.children @@ -89,7 +89,7 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType: raise NotImplementedError -def convert_type(dtype: DuckDBPyType) -> DataType: +def convert_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 id = dtype.id if id in ["list", "struct", "map", "array"]: return convert_nested_type(dtype) @@ -102,6 +102,6 @@ def convert_type(dtype: DuckDBPyType) -> DataType: return spark_type() -def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: +def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: # noqa: D103 fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types])] return StructType(fields) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 606f792c..ad74cd98 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -1,4 +1,4 @@ -# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'spark' folder. +# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'spark' folder. 
# noqa: D100 import calendar import datetime @@ -66,32 +66,32 @@ class DataType: """Base class for data types.""" - def __init__(self, duckdb_type) -> None: + def __init__(self, duckdb_type) -> None: # noqa: D107 self.duckdb_type = duckdb_type - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return self.__class__.__name__ + "()" - def __hash__(self) -> int: + def __hash__(self) -> int: # noqa: D105 return hash(str(self)) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object) -> bool: # noqa: D105 return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - def __ne__(self, other: object) -> bool: + def __ne__(self, other: object) -> bool: # noqa: D105 return not self.__eq__(other) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return cls.__name__[:-4].lower() - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return self.typeName() - def jsonValue(self) -> Union[str, dict[str, Any]]: + def jsonValue(self) -> Union[str, dict[str, Any]]: # noqa: D102 raise ContributionsAcceptedError - def json(self) -> str: + def json(self) -> str: # noqa: D102 raise ContributionsAcceptedError def needConversion(self) -> bool: @@ -129,11 +129,11 @@ class NullType(DataType, metaclass=DataTypeSingleton): The data type representing None, used for the types that cannot be inferred. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("NULL")) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "void" @@ -158,54 +158,54 @@ class FractionalType(NumericType): class StringType(AtomicType, metaclass=DataTypeSingleton): """String data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("VARCHAR")) class BitstringType(AtomicType, metaclass=DataTypeSingleton): """Bitstring data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BIT")) class UUIDType(AtomicType, metaclass=DataTypeSingleton): """UUID data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UUID")) class BinaryType(AtomicType, metaclass=DataTypeSingleton): """Binary (byte array) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BLOB")) class BooleanType(AtomicType, metaclass=DataTypeSingleton): """Boolean data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BOOLEAN")) class DateType(AtomicType, metaclass=DataTypeSingleton): """Date (datetime.date) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("DATE")) EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, d: datetime.date) -> int: + def toInternal(self, d: datetime.date) -> int: # noqa: D102 if d is not None: return d.toordinal() - self.EPOCH_ORDINAL - def fromInternal(self, v: int) -> datetime.date: + def fromInternal(self, v: int) -> datetime.date: # noqa: D102 if v is not None: return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) @@ -213,22 +213,22 @@ def fromInternal(self, v: int) -> datetime.date: class TimestampType(AtomicType, metaclass=DataTypeSingleton): """Timestamp 
(datetime.datetime) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMPTZ")) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamptz" - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 if dt is not None: seconds = calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple()) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 if ts is not None: # using int to avoid precision loss in float return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) @@ -237,22 +237,22 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with microsecond precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 if dt is not None: seconds = calendar.timegm(dt.timetuple()) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 if ts is not None: # using int to avoid precision loss in float return datetime.datetime.utcfromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) @@ -261,60 +261,60 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampSecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with second precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_S")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_s" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError class TimestampMilisecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with milisecond precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_MS")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_ms" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> 
datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError class TimestampNanosecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with nanosecond precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_NS")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_ns" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError @@ -338,90 +338,90 @@ class DecimalType(FractionalType): the number of digits on right side of dot. (default: 0) """ - def __init__(self, precision: int = 10, scale: int = 0) -> None: + def __init__(self, precision: int = 10, scale: int = 0) -> None: # noqa: D107 super().__init__(duckdb.decimal_type(precision, scale)) self.precision = precision self.scale = scale self.hasPrecisionInfo = True # this is a public API - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "decimal(%d,%d)" % (self.precision, self.scale) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "DecimalType(%d,%d)" % (self.precision, self.scale) class DoubleType(FractionalType, metaclass=DataTypeSingleton): """Double data type, representing double precision floats.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("DOUBLE")) class FloatType(FractionalType, metaclass=DataTypeSingleton): """Float data type, representing single precision floats.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("FLOAT")) class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TINYINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "tinyint" class UnsignedByteType(IntegralType): """Unsigned byte data type, i.e. a unsigned integer in a single byte.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UTINYINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "utinyint" class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("SMALLINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "smallint" class UnsignedShortType(IntegralType): """Unsigned short data type, i.e. a unsigned 16-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("USMALLINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "usmallint" class IntegerType(IntegralType): """Int data type, i.e. 
a signed 32-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("INTEGER")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "integer" class UnsignedIntegerType(IntegralType): """Unsigned int data type, i.e. a unsigned 32-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UINTEGER")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "uinteger" @@ -432,10 +432,10 @@ class LongType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BIGINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "bigint" @@ -446,10 +446,10 @@ class UnsignedLongType(IntegralType): please use :class:`HugeIntegerType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UBIGINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "ubigint" @@ -460,10 +460,10 @@ class HugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("HUGEINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "hugeint" @@ -474,30 +474,30 @@ class UnsignedHugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UHUGEINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "uhugeint" class TimeType(IntegralType): """Time (datetime.time) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMETZ")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "timetz" class TimeNTZType(IntegralType): """Time (datetime.time) data type without timezone information.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIME")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "time" @@ -518,7 +518,7 @@ class DayTimeIntervalType(AtomicType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) - def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: + def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: # noqa: D107 super().__init__(DuckDBPyType("INTERVAL")) if startField is None and endField is None: # Default matched to scala side. 
@@ -544,17 +544,17 @@ def _str_repr(self) -> str: simpleString = _str_repr - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, dt: datetime.timedelta) -> Optional[int]: + def toInternal(self, dt: datetime.timedelta) -> Optional[int]: # noqa: D102 if dt is not None: return (math.floor(dt.total_seconds()) * 1000000) + dt.microseconds - def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: + def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: # noqa: D102 if micros is not None: return datetime.timedelta(microseconds=micros) @@ -577,7 +577,7 @@ class ArrayType(DataType): False """ - def __init__(self, elementType: DataType, containsNull: bool = True) -> None: + def __init__(self, elementType: DataType, containsNull: bool = True) -> None: # noqa: D107 super().__init__(duckdb.list_type(elementType.duckdb_type)) assert isinstance(elementType, DataType), "elementType %s should be an instance of %s" % ( elementType, @@ -586,21 +586,21 @@ def __init__(self, elementType: DataType, containsNull: bool = True) -> None: self.elementType = elementType self.containsNull = containsNull - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "array<%s>" % self.elementType.simpleString() - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "ArrayType(%s, %s)" % (self.elementType, str(self.containsNull)) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.elementType.needConversion() - def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: + def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.toInternal(v) for v in obj] - def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: + def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.fromInternal(v) for v in obj] @@ -630,7 +630,7 @@ class MapType(DataType): False """ - def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True) -> None: + def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True) -> None: # noqa: D107 super().__init__(duckdb.map_type(keyType.duckdb_type, valueType.duckdb_type)) assert isinstance(keyType, DataType), "keyType %s should be an instance of %s" % ( keyType, @@ -644,28 +644,28 @@ def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bo self.valueType = valueType self.valueContainsNull = valueContainsNull - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "map<%s,%s>" % ( self.keyType.simpleString(), self.valueType.simpleString(), ) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "MapType(%s, %s, %s)" % ( self.keyType, self.valueType, str(self.valueContainsNull), ) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.keyType.needConversion() or self.valueType.needConversion() - def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: + def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 if not self.needConversion(): return obj 
return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) for k, v in obj.items()) - def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: + def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 if not self.needConversion(): return obj return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()) @@ -693,7 +693,7 @@ class StructField(DataType): False """ - def __init__( + def __init__( # noqa: D107 self, name: str, dataType: DataType, @@ -711,26 +711,26 @@ def __init__( self.nullable = nullable self.metadata = metadata or {} - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "%s:%s" % (self.name, self.dataType.simpleString()) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "StructField('%s', %s, %s)" % ( self.name, self.dataType, str(self.nullable), ) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.dataType.needConversion() - def toInternal(self, obj: T) -> T: + def toInternal(self, obj: T) -> T: # noqa: D102 return self.dataType.toInternal(obj) - def fromInternal(self, obj: T) -> T: + def fromInternal(self, obj: T) -> T: # noqa: D102 return self.dataType.fromInternal(obj) - def typeName(self) -> str: # type: ignore[override] + def typeName(self) -> str: # type: ignore[override] # noqa: D102 msg = "StructField does not have typeName. Use typeName on its type explicitly instead." raise TypeError(msg) @@ -766,7 +766,7 @@ class StructType(DataType): def _update_internal_duckdb_type(self): self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) - def __init__(self, fields: Optional[list[StructField]] = None) -> None: + def __init__(self, fields: Optional[list[StructField]] = None) -> None: # noqa: D107 if not fields: self.fields = [] self.names = [] @@ -836,7 +836,7 @@ def add( >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - """ # noqa: D205 + """ # noqa: D205, D415 if isinstance(field, StructField): self.fields.append(field) self.names.append(field.name) @@ -882,16 +882,16 @@ def __getitem__(self, key: Union[str, int]) -> StructField: msg = "StructType keys should be strings, integers or slices" raise TypeError(msg) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "struct<%s>" % (",".join(f.simpleString() for f in self)) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "StructType([%s])" % ", ".join(str(field) for field in self) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: Any) -> bool: # noqa: D105 return item in self.names - def extract_types_and_names(self) -> tuple[list[str], list[str]]: + def extract_types_and_names(self) -> tuple[list[str], list[str]]: # noqa: D102 names = [] types = [] for f in self.fields: @@ -910,11 +910,11 @@ def fieldNames(self) -> list[str]: """ return list(self.names) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 # We need convert Row()/namedtuple into tuple() return True - def toInternal(self, obj: tuple) -> tuple: + def toInternal(self, obj: tuple) -> tuple: # noqa: D102 if obj is None: return @@ -946,7 +946,7 @@ def toInternal(self, obj: tuple) -> tuple: else: raise ValueError("Unexpected tuple %r with StructType" % obj) - def fromInternal(self, obj: tuple) -> "Row": + def fromInternal(self, obj: tuple) -> 
"Row": # noqa: D102 if obj is None: return if isinstance(obj, Row): @@ -1125,7 +1125,7 @@ class Row(tuple): >>> row2 = Row(name="Alice", age=11) >>> row1 == row2 True - """ # noqa: D205 + """ # noqa: D205, D415 @overload def __new__(cls, *args: str) -> "Row": ... @@ -1133,7 +1133,7 @@ def __new__(cls, *args: str) -> "Row": ... @overload def __new__(cls, **kwargs: Any) -> "Row": ... - def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": + def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # noqa: D102 if args and kwargs: msg = "Can not use both args and kwargs to create Row" raise ValueError(msg) @@ -1192,7 +1192,7 @@ def conv(obj: Any) -> Any: else: return dict(zip(self.__fields__, self)) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: Any) -> bool: # noqa: D105 if hasattr(self, "__fields__"): return item in self.__fields__ else: @@ -1207,7 +1207,7 @@ def __call__(self, *args: Any) -> "Row": ) return _create_row(self, args) - def __getitem__(self, item: Any) -> Any: + def __getitem__(self, item: Any) -> Any: # noqa: D105 if isinstance(item, (int, slice)): return super(Row, self).__getitem__(item) try: @@ -1220,7 +1220,7 @@ def __getitem__(self, item: Any) -> Any: except ValueError: raise ValueError(item) - def __getattr__(self, item: str) -> Any: + def __getattr__(self, item: str) -> Any: # noqa: D105 if item.startswith("__"): raise AttributeError(item) try: @@ -1233,7 +1233,7 @@ def __getattr__(self, item: str) -> Any: except ValueError: raise AttributeError(item) - def __setattr__(self, key: Any, value: Any) -> None: + def __setattr__(self, key: Any, value: Any) -> None: # noqa: D105 if key != "__fields__": msg = "Row is read-only" raise RuntimeError(msg) diff --git a/duckdb/experimental/spark/sql/udf.py b/duckdb/experimental/spark/sql/udf.py index 389d43ab..7437ed6b 100644 --- a/duckdb/experimental/spark/sql/udf.py +++ b/duckdb/experimental/spark/sql/udf.py @@ -1,4 +1,4 @@ -# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ +# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100 from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union from .types import DataType @@ -10,11 +10,11 @@ UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike") -class UDFRegistration: - def __init__(self, sparkSession: "SparkSession") -> None: +class UDFRegistration: # noqa: D101 + def __init__(self, sparkSession: "SparkSession") -> None: # noqa: D107 self.sparkSession = sparkSession - def register( + def register( # noqa: D102 self, name: str, f: Union[Callable[..., Any], "UserDefinedFunctionLike"], @@ -22,7 +22,7 @@ def register( ) -> "UserDefinedFunctionLike": self.sparkSession.conn.create_function(name, f, return_type=returnType) - def registerJavaFunction( + def registerJavaFunction( # noqa: D102 self, name: str, javaClassName: str, @@ -30,7 +30,7 @@ def registerJavaFunction( ) -> None: raise NotImplementedError - def registerJavaUDAF(self, name: str, javaClassName: str) -> None: + def registerJavaUDAF(self, name: str, javaClassName: str) -> None: # noqa: D102 raise NotImplementedError diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index 77838103..1775a9cf 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -1,4 +1,4 @@ -from io import TextIOBase +from io import TextIOBase # noqa: D100 from fsspec import AbstractFileSystem from fsspec.implementations.memory import MemoryFile, MemoryFileSystem @@ -6,17 +6,17 @@ from .bytes_io_wrapper import 
BytesIOWrapper -def is_file_like(obj): +def is_file_like(obj): # noqa: D103 # We only care that we can read from the file return hasattr(obj, "read") and hasattr(obj, "seek") -class ModifiedMemoryFileSystem(MemoryFileSystem): +class ModifiedMemoryFileSystem(MemoryFileSystem): # noqa: D101 protocol = ("DUCKDB_INTERNAL_OBJECTSTORE",) # defer to the original implementation that doesn't hardcode the protocol _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) - def add_file(self, object, path): + def add_file(self, object, path): # noqa: D102 if not is_file_like(object): msg = "Can not read from a non file-like object" raise ValueError(msg) diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index b1ddab19..a1d69d39 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,3 +1,3 @@ -from _duckdb.functional import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType +from _duckdb.functional import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType # noqa: D104 __all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index 69e1e7ea..a11339bb 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -1,4 +1,4 @@ -import datetime +import datetime # noqa: D100 import json from collections.abc import Iterator from decimal import Decimal diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index 88d96350..dedb30a3 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -1,4 +1,4 @@ -import argparse +import argparse # noqa: D100 import json import os import re @@ -76,63 +76,63 @@ """ -class NodeTiming: - def __init__(self, phase: str, time: float) -> object: +class NodeTiming: # noqa: D101 + def __init__(self, phase: str, time: float) -> object: # noqa: D107 self.phase = phase self.time = time # percentage is determined later. 
self.percentage = 0 - def calculate_percentage(self, total_time: float) -> None: + def calculate_percentage(self, total_time: float) -> None: # noqa: D102 self.percentage = self.time / total_time - def combine_timing(l: object, r: object) -> object: + def combine_timing(l: object, r: object) -> object: # noqa: D102 # TODO: can only add timings for same-phase nodes total_time = l.time + r.time return NodeTiming(l.phase, total_time) -class AllTimings: - def __init__(self) -> None: +class AllTimings: # noqa: D101 + def __init__(self) -> None: # noqa: D107 self.phase_to_timings = {} - def add_node_timing(self, node_timing: NodeTiming): + def add_node_timing(self, node_timing: NodeTiming): # noqa: D102 if node_timing.phase in self.phase_to_timings: self.phase_to_timings[node_timing.phase].append(node_timing) return self.phase_to_timings[node_timing.phase] = [node_timing] - def get_phase_timings(self, phase: str): + def get_phase_timings(self, phase: str): # noqa: D102 return self.phase_to_timings[phase] - def get_summary_phase_timings(self, phase: str): + def get_summary_phase_timings(self, phase: str): # noqa: D102 return reduce(NodeTiming.combine_timing, self.phase_to_timings[phase]) - def get_phases(self): + def get_phases(self): # noqa: D102 phases = list(self.phase_to_timings.keys()) phases.sort(key=lambda x: (self.get_summary_phase_timings(x)).time) phases.reverse() return phases - def get_sum_of_all_timings(self): + def get_sum_of_all_timings(self): # noqa: D102 total_timing_sum = 0 for phase in self.phase_to_timings.keys(): total_timing_sum += self.get_summary_phase_timings(phase).time return total_timing_sum -def open_utf8(fpath: str, flags: str) -> object: +def open_utf8(fpath: str, flags: str) -> object: # noqa: D103 return open(fpath, flags, encoding="utf8") -def get_child_timings(top_node: object, query_timings: object) -> str: +def get_child_timings(top_node: object, query_timings: object) -> str: # noqa: D103 node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"])) query_timings.add_node_timing(node_timing) for child in top_node["children"]: get_child_timings(child, query_timings) -def get_pink_shade_hex(fraction: float): +def get_pink_shade_hex(fraction: float): # noqa: D103 fraction = max(0, min(1, fraction)) # Define the RGB values for very light pink (almost white) and dark pink @@ -148,7 +148,7 @@ def get_pink_shade_hex(fraction: float): return f"#{r:02x}{g:02x}{b:02x}" -def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: +def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: # noqa: D103 node_style = f"background-color: {get_pink_shade_hex(float(result) / cpu_time)};" body = f'' @@ -167,7 +167,7 @@ def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, return body -def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: +def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: # noqa: D103 node_prefix_html = "
    • " node_suffix_html = "
    • " @@ -206,7 +206,7 @@ def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: # For generating the table in the top left. -def generate_timing_html(graph_json: object, query_timings: object) -> object: +def generate_timing_html(graph_json: object, query_timings: object) -> object: # noqa: D103 json_graph = json.loads(graph_json) gather_timing_information(json_graph, query_timings) total_time = float(json_graph.get("operator_timing") or json_graph.get("latency")) @@ -244,7 +244,7 @@ def generate_timing_html(graph_json: object, query_timings: object) -> object: return table_head + table_body -def generate_tree_html(graph_json: object) -> str: +def generate_tree_html(graph_json: object) -> str: # noqa: D103 json_graph = json.loads(graph_json) cpu_time = float(json_graph["cpu_time"]) tree_prefix = '
      \n
        ' @@ -255,7 +255,7 @@ def generate_tree_html(graph_json: object) -> str: return tree_prefix + tree_body + tree_suffix -def generate_ipython(json_input: str) -> str: +def generate_ipython(json_input: str) -> str: # noqa: D103 from IPython.core.display import HTML html_output = generate_html(json_input, False) @@ -268,7 +268,7 @@ def generate_ipython(json_input: str) -> str: ) -def generate_style_html(graph_json: str, include_meta_info: bool) -> None: +def generate_style_html(graph_json: str, include_meta_info: bool) -> None: # noqa: D103 treeflex_css = '\n' css = "