diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml
index 858768e5..43dafa2a 100644
--- a/.github/workflows/packaging_wheels.yml
+++ b/.github/workflows/packaging_wheels.yml
@@ -90,7 +90,6 @@ jobs:
         env:
           CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }}
           CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}
-
       - name: Upload wheel
         uses: actions/upload-artifact@v4
         with:
diff --git a/_duckdb-stubs/__init__.pyi b/_duckdb-stubs/__init__.pyi
index 6c36d7be..25461326 100644
--- a/_duckdb-stubs/__init__.pyi
+++ b/_duckdb-stubs/__init__.pyi
@@ -1437,7 +1437,6 @@ __interactive__: bool
 __jupyter__: bool
 __standard_vector_size__: int
 __version__: str
-_clean_default_connection: pytyping.Any  # value =
 apilevel: str
 paramstyle: str
 threadsafety: int
diff --git a/duckdb/__init__.py b/duckdb/__init__.py
index e1a4aa9a..8dc87642 100644
--- a/duckdb/__init__.py
+++ b/duckdb/__init__.py
@@ -66,7 +66,6 @@
     __interactive__,
     __jupyter__,
     __standard_vector_size__,
-    _clean_default_connection,
    aggregate,
    alias,
    apilevel,
@@ -292,7 +291,6 @@
    "__jupyter__",
    "__standard_vector_size__",
    "__version__",
-    "_clean_default_connection",
    "aggregate",
    "alias",
    "apilevel",
diff --git a/pyproject.toml b/pyproject.toml
index adb1dffe..f5456cf3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -233,6 +233,7 @@ test = [ # dependencies used for running tests
    "pytest",
    "pytest-reraise",
    "pytest-timeout",
+    "pytest-run-parallel",
    "mypy",
    "coverage",
    "gcovr; python_version < '3.14'",
diff --git a/src/duckdb_py/CMakeLists.txt b/src/duckdb_py/CMakeLists.txt
index 3d06b062..2c802769 100644
--- a/src/duckdb_py/CMakeLists.txt
+++ b/src/duckdb_py/CMakeLists.txt
@@ -18,6 +18,7 @@ add_library(
   duckdb_python.cpp
   importer.cpp
   map.cpp
+  module_state.cpp
   path_like.cpp
   pyconnection.cpp
   pyexpression.cpp
diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp
index 1dd3ba17..5c3d6cb8 100644
--- a/src/duckdb_py/duckdb_python.cpp
+++ b/src/duckdb_py/duckdb_python.cpp
@@ -20,6 +20,7 @@
 #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp"
 #include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp"
 #include "duckdb/common/enums/statement_type.hpp"
+#include "duckdb_python/module_state.hpp"
 #include "duckdb/common/adbc/adbc-init.hpp"
 #include "duckdb.hpp"
@@ -32,6 +33,20 @@ namespace py = pybind11;
 
 namespace duckdb {
 
+// Private function to initialize module state
+void InitializeModuleState(py::module_ &m) {
+	auto state_ptr = new DuckDBPyModuleState();
+	SetModuleState(state_ptr);
+
+	// https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors
+	auto capsule = py::capsule(state_ptr, [](void *p) {
+		auto state = static_cast<DuckDBPyModuleState *>(p);
+		DuckDBPyModuleState::SetGlobalModuleState(nullptr);
+		delete state;
+	});
+	m.attr("__duckdb_state") = capsule;
+}
+
 enum PySQLTokenType : uint8_t {
 	PY_SQL_TOKEN_IDENTIFIER = 0,
 	PY_SQL_TOKEN_NUMERIC_CONSTANT,
@@ -1031,7 +1046,21 @@ PYBIND11_EXPORT void *_force_symbol_inclusion() {
 	}
 };
 
-PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT
+// Only mark mod_gil_not_used for 3.14t or later; we deliberately do not
+// add free-threading support for 3.13t.
+// The Py_GIL_DISABLED check is not strictly necessary.
+#if defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030e0000
+PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m, py::mod_gil_not_used(),
+                py::multiple_interpreters::not_supported()) { // NOLINT
+#else
+PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m,
+                py::multiple_interpreters::not_supported()) { // NOLINT
+#endif
+	// Initialize module state completely during module initialization.
+	// PEP 489 wants module state to be module-local, but for now it is
+	// static, via g_module_state.
+	InitializeModuleState(m);
+
 	// DO NOT REMOVE: the below forces that we include all symbols we want to export
 	volatile auto *keep_alive = _force_symbol_inclusion();
 	(void)keep_alive;
@@ -1075,9 +1104,10 @@ PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT
 	m.attr("__version__") = std::string(DuckDB::LibraryVersion()).substr(1);
 	m.attr("__standard_vector_size__") = DuckDB::StandardVectorSize();
 	m.attr("__git_revision__") = DuckDB::SourceID();
-	m.attr("__interactive__") = DuckDBPyConnection::DetectAndGetEnvironment();
-	m.attr("__jupyter__") = DuckDBPyConnection::IsJupyter();
-	m.attr("__formatted_python_version__") = DuckDBPyConnection::FormattedPythonVersion();
+	auto &module_state = GetModuleState();
+	m.attr("__interactive__") = module_state.environment != PythonEnvironmentType::NORMAL;
+	m.attr("__jupyter__") = module_state.environment == PythonEnvironmentType::JUPYTER;
+	m.attr("__formatted_python_version__") = module_state.formatted_python_version;
 	m.def("default_connection", &DuckDBPyConnection::DefaultConnection,
 	      "Retrieve the connection currently registered as the default to be used by the module");
 	m.def("set_default_connection", &DuckDBPyConnection::SetDefaultConnection,
@@ -1107,12 +1137,6 @@ PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT
 	    .value("keyword", PySQLTokenType::PY_SQL_TOKEN_KEYWORD)
 	    .value("comment", PySQLTokenType::PY_SQL_TOKEN_COMMENT)
 	    .export_values();
-
-	// we need this because otherwise we try to remove registered_dfs on shutdown when python is already dead
-	auto clean_default_connection = []() {
-		DuckDBPyConnection::Cleanup();
-	};
-	m.add_object("_clean_default_connection", py::capsule(clean_default_connection));
 }
 
 } // namespace duckdb
diff --git a/src/duckdb_py/include/duckdb_python/module_state.hpp b/src/duckdb_py/include/duckdb_python/module_state.hpp
new file mode 100644
index 00000000..d7a4e377
--- /dev/null
+++ b/src/duckdb_py/include/duckdb_python/module_state.hpp
@@ -0,0 +1,63 @@
+//===----------------------------------------------------------------------===//
+//                         DuckDB
+//
+// duckdb_python/module_state.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "duckdb_python/pybind11/pybind_wrapper.hpp"
+#include "duckdb/common/shared_ptr.hpp"
+#include "duckdb/main/db_instance_cache.hpp"
+#include "duckdb/main/database.hpp"
+#include "duckdb_python/import_cache/python_import_cache.hpp"
+#include "duckdb_python/pyconnection/pyconnection.hpp"
+#include
+
+namespace duckdb {
+
+// Module state structure to hold per-interpreter state
+struct DuckDBPyModuleState {
+	// Python environment tracking
+	PythonEnvironmentType environment = PythonEnvironmentType::NORMAL;
+	string formatted_python_version;
+
+	DuckDBPyModuleState();
+
+	shared_ptr<DuckDBPyConnection> GetDefaultConnection();
+	void SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection);
+	void ClearDefaultConnection();
+
+	PythonImportCache *GetImportCache();
+	void ClearImportCache();
+
+	DBInstanceCache *GetInstanceCache();
+
+	static DuckDBPyModuleState &GetGlobalModuleState();
+	static void SetGlobalModuleState(DuckDBPyModuleState *state);
+
+private:
+	shared_ptr<DuckDBPyConnection> default_connection_ptr;
+	PythonImportCache import_cache;
+	DBInstanceCache instance_cache;
+#ifdef Py_GIL_DISABLED
+	py::object default_con_lock;
+#endif
+
+	// Implemented as static as a first step towards PEP 489 / multi-phase init.
+	// The intent is to move to a per-module object, but frequent calls to the
+	// import cache need to be considered carefully.
+	// TODO: Replace with non-static per-interpreter state for multi-interpreter support
+	static DuckDBPyModuleState *g_module_state;
+
+	// Non-copyable
+	DuckDBPyModuleState(const DuckDBPyModuleState &) = delete;
+	DuckDBPyModuleState &operator=(const DuckDBPyModuleState &) = delete;
+};
+
+DuckDBPyModuleState &GetModuleState();
+void SetModuleState(DuckDBPyModuleState *state);
+
+} // namespace duckdb
\ No newline at end of file
diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp
index 48ee055e..7998c14e 100644
--- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp
+++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp
@@ -28,6 +28,7 @@
 namespace duckdb {
 
 struct BoundParameterData;
+struct DuckDBPyModuleState;
 
 enum class PythonEnvironmentType { NORMAL, INTERACTIVE, JUPYTER };
 
@@ -172,8 +173,7 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
 	case_insensitive_set_t registered_objects;
 
 public:
-	explicit DuckDBPyConnection() {
-	}
+	DuckDBPyConnection();
 	~DuckDBPyConnection();
 
 public:
@@ -190,9 +190,17 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
 	static std::string FormattedPythonVersion();
 	static shared_ptr<DuckDBPyConnection> DefaultConnection();
 	static void SetDefaultConnection(shared_ptr<DuckDBPyConnection> conn);
+	static shared_ptr<DuckDBPyConnection> GetDefaultConnection();
+	static void ClearDefaultConnection();
+	static void ClearImportCache();
 	static PythonImportCache *ImportCache();
 	static bool IsInteractive();
 
+	// Instance methods for optimized module state access
+	bool IsJupyterInstance() const;
+	bool IsInteractiveInstance() const;
+	std::string FormattedPythonVersionInstance() const;
+
 	unique_ptr<DuckDBPyRelation> ReadCSV(const py::object &name, py::kwargs &kwargs);
 
 	py::list ExtractStatements(const string &query);
@@ -337,11 +345,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
 	py::list ListFilesystems();
 	bool FileSystemIsRegistered(const string &name);
 
-	//! Default connection to an in-memory database
-	static DefaultConnectionHolder default_connection;
-	//! Caches and provides an interface to get frequently used modules+subtypes
-	static shared_ptr<PythonImportCache> import_cache;
-
 	static bool IsPandasDataframe(const py::object &object);
 	static PyArrowObjectType GetArrowType(const py::handle &obj);
 	static bool IsAcceptedArrowObject(const py::object &object);
@@ -357,10 +360,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
 	                     bool side_effects);
 	void RegisterArrowObject(const py::object &arrow_object, const string &name);
 	vector<unique_ptr<SQLStatement>> GetStatements(const py::object &query);
-
-	static PythonEnvironmentType environment;
-	static std::string formatted_python_version;
-	static void DetectEnvironment();
 };
 
 template
diff --git a/src/duckdb_py/module_state.cpp b/src/duckdb_py/module_state.cpp
new file mode 100644
index 00000000..1e0b6897
--- /dev/null
+++ b/src/duckdb_py/module_state.cpp
@@ -0,0 +1,128 @@
+//===----------------------------------------------------------------------===//
+//                         DuckDB
+//
+// duckdb_python/module_state.cpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "duckdb_python/module_state.hpp"
+#include
+#include
+#include
+
+#define DEBUG_MODULE_STATE 0
+
+namespace duckdb {
+
+// Forward declaration from pyconnection.cpp
+void InstantiateNewInstance(DuckDB &db);
+
+// Static member definition - non-inline static data members need an
+// out-of-line definition in exactly one translation unit
+DuckDBPyModuleState *DuckDBPyModuleState::g_module_state = nullptr;
+
+DuckDBPyModuleState::DuckDBPyModuleState() {
+	// Caches are constructed as direct objects - no heap allocation needed
+
+#ifdef Py_GIL_DISABLED
+	// Initialize lock object for critical sections
+	// TODO: Consider moving to finer-grained locks
+	default_con_lock = py::none();
+#endif
+
+	// Detect the Python environment and version during initialization.
+	// Moved from DuckDBPyConnection::DetectEnvironment()
+	py::module_ sys = py::module_::import("sys");
+	py::object version_info = sys.attr("version_info");
+	int major = py::cast<int>(version_info.attr("major"));
+	int minor = py::cast<int>(version_info.attr("minor"));
+	formatted_python_version = std::to_string(major) + "." + std::to_string(minor);
+
+	// If __main__ does not have a __file__ attribute, we are in interactive mode
+	auto main_module = py::module_::import("__main__");
+	if (!py::hasattr(main_module, "__file__")) {
+		environment = PythonEnvironmentType::INTERACTIVE;
+
+		if (ModuleIsLoaded()) {
+			// Check to see if we are in a Jupyter Notebook
+			auto get_ipython = import_cache.IPython.get_ipython();
+			if (get_ipython.ptr() != nullptr) {
+				auto ipython = get_ipython();
+				if (py::hasattr(ipython, "config")) {
+					py::dict ipython_config = ipython.attr("config");
+					if (ipython_config.contains("IPKernelApp")) {
+						environment = PythonEnvironmentType::JUPYTER;
+					}
+				}
+			}
+		}
+	}
+}
+
+DuckDBPyModuleState &DuckDBPyModuleState::GetGlobalModuleState() {
+	// TODO: Externalize this static cache when adding multi-interpreter support.
+	// For now, the single-interpreter assumption allows simple static caching.
+	if (!g_module_state) {
+		throw InternalException("Module state not initialized - call SetGlobalModuleState() during module init");
+	}
+	return *g_module_state;
+}
+
+void DuckDBPyModuleState::SetGlobalModuleState(DuckDBPyModuleState *state) {
+#if DEBUG_MODULE_STATE
+	printf("DEBUG: SetGlobalModuleState() called - initializing static cache (built: %s %s)\n", __DATE__, __TIME__);
+#endif
+	g_module_state = state;
+}
+
+DuckDBPyModuleState &GetModuleState() {
+#if DEBUG_MODULE_STATE
+	printf("DEBUG: GetModuleState() called\n");
+#endif
+	return DuckDBPyModuleState::GetGlobalModuleState();
+}
+
+void SetModuleState(DuckDBPyModuleState *state) {
+	DuckDBPyModuleState::SetGlobalModuleState(state);
+}
+
+shared_ptr<DuckDBPyConnection> DuckDBPyModuleState::GetDefaultConnection() {
+#if defined(Py_GIL_DISABLED)
+	// TODO: Consider whether to use a mutex or a scoped_critical_section
+	py::scoped_critical_section guard(default_con_lock);
+#endif
+	// Reproduce exact logic from original DefaultConnectionHolder::Get()
+	if (!default_connection_ptr || default_connection_ptr->con.ConnectionIsClosed()) {
+		py::dict config_dict;
+		default_connection_ptr = DuckDBPyConnection::Connect(py::str(":memory:"), false, config_dict);
+	}
+	return default_connection_ptr;
+}
+
+void DuckDBPyModuleState::SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection) {
+#if defined(Py_GIL_DISABLED)
+	py::scoped_critical_section guard(default_con_lock);
+#endif
+	default_connection_ptr = std::move(connection);
+}
+
+void DuckDBPyModuleState::ClearDefaultConnection() {
+#if defined(Py_GIL_DISABLED)
+	py::scoped_critical_section guard(default_con_lock);
+#endif
+	default_connection_ptr = nullptr;
+}
+
+PythonImportCache *DuckDBPyModuleState::GetImportCache() {
+	return &import_cache;
+}
+
+void DuckDBPyModuleState::ClearImportCache() {
+	import_cache = PythonImportCache();
+}
+
+DBInstanceCache *DuckDBPyModuleState::GetInstanceCache() {
+	return &instance_cache;
+}
+
+} // namespace duckdb
\ No newline at end of file
diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp
index b88b88ed..ba394990 100644
--- a/src/duckdb_py/pyconnection.cpp
+++ b/src/duckdb_py/pyconnection.cpp
@@ -1,4 +1,5 @@
 #include "duckdb_python/pyconnection/pyconnection.hpp"
+#include "duckdb_python/module_state.hpp"
 
 #include "duckdb/catalog/default/default_types.hpp"
 #include "duckdb/common/arrow/arrow.hpp"
@@ -66,11 +67,8 @@
 
 namespace duckdb {
 
-DefaultConnectionHolder DuckDBPyConnection::default_connection; // NOLINT: allow global
-DBInstanceCache instance_cache;                                 // NOLINT: allow global
-shared_ptr<PythonImportCache> DuckDBPyConnection::import_cache = nullptr; // NOLINT: allow global
-PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global
-std::string DuckDBPyConnection::formatted_python_version = "";
+DuckDBPyConnection::DuckDBPyConnection() {
+}
 
 DuckDBPyConnection::~DuckDBPyConnection() {
 	try {
@@ -82,53 +80,17 @@ DuckDBPyConnection::~DuckDBPyConnection() {
 	}
 }
 
-void DuckDBPyConnection::DetectEnvironment() {
-	// Get the formatted Python version
-	py::module_ sys = py::module_::import("sys");
-	py::object version_info = sys.attr("version_info");
-	int major = py::cast<int>(version_info.attr("major"));
-	int minor = py::cast<int>(version_info.attr("minor"));
-	DuckDBPyConnection::formatted_python_version = std::to_string(major) + "." + std::to_string(minor);
-
-	// If __main__ does not have a __file__ attribute, we are in interactive mode
-	auto main_module = py::module_::import("__main__");
-	if (py::hasattr(main_module, "__file__")) {
-		return;
-	}
-	DuckDBPyConnection::environment = PythonEnvironmentType::INTERACTIVE;
-	if (!ModuleIsLoaded()) {
-		return;
-	}
-
-	// Check to see if we are in a Jupyter Notebook
-	auto &import_cache_py = *DuckDBPyConnection::ImportCache();
-	auto get_ipython = import_cache_py.IPython.get_ipython();
-	if (get_ipython.ptr() == nullptr) {
-		// Could either not load the IPython module, or it has no 'get_ipython' attribute
-		return;
-	}
-	auto ipython = get_ipython();
-	if (!py::hasattr(ipython, "config")) {
-		return;
-	}
-	py::dict ipython_config = ipython.attr("config");
-	if (ipython_config.contains("IPKernelApp")) {
-		DuckDBPyConnection::environment = PythonEnvironmentType::JUPYTER;
-	}
-	return;
-}
-
 bool DuckDBPyConnection::DetectAndGetEnvironment() {
-	DuckDBPyConnection::DetectEnvironment();
+	// Environment detection now happens during module state construction
 	return DuckDBPyConnection::IsInteractive();
 }
 
 bool DuckDBPyConnection::IsJupyter() {
-	return DuckDBPyConnection::environment == PythonEnvironmentType::JUPYTER;
+	return GetModuleState().environment == PythonEnvironmentType::JUPYTER;
 }
 
 std::string DuckDBPyConnection::FormattedPythonVersion() {
-	return DuckDBPyConnection::formatted_python_version;
+	return GetModuleState().formatted_python_version;
 }
 
 // NOTE: this function is generated by tools/pythonpkg/scripts/generate_connection_methods.py.
@@ -2113,8 +2075,8 @@ static shared_ptr<DuckDBPyConnection> FetchOrCreateInstance(const string &databa
 	D_ASSERT(py::gil_check());
 	py::gil_scoped_release release;
 	unique_lock<mutex> lock(res->py_connection_lock);
-	auto database =
-	    instance_cache.GetOrCreateInstance(database_path, config, cache_instance, InstantiateNewInstance);
+	auto database = GetModuleState().GetInstanceCache()->GetOrCreateInstance(database_path, config, cache_instance,
+	                                                                         InstantiateNewInstance);
 	res->con.SetDatabase(std::move(database));
 	res->con.SetConnection(make_uniq<Connection>(res->con.GetDatabase()));
 }
@@ -2162,6 +2124,7 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Connect(const py::object &dat
 	                          "python_scan_all_frames",
 	                          "If set, restores the old behavior of scanning all preceding frames to locate the referenced variable.",
 	                          LogicalType::BOOLEAN, Value::BOOLEAN(false));
+	// Use static methods here since we don't have a connection instance yet
 	if (!DuckDBPyConnection::IsJupyter()) {
 		config_dict["duckdb_api"] = Value("python/" + DuckDBPyConnection::FormattedPythonVersion());
 	} else {
@@ -2197,18 +2160,27 @@ case_insensitive_map_t<BoundParameterData> DuckDBPyConnection::TransformPythonPa
 }
 
 shared_ptr<DuckDBPyConnection> DuckDBPyConnection::DefaultConnection() {
-	return default_connection.Get();
+	return GetModuleState().GetDefaultConnection();
 }
 
 void DuckDBPyConnection::SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection) {
-	return default_connection.Set(std::move(connection));
+	GetModuleState().SetDefaultConnection(std::move(connection));
 }
 
 PythonImportCache *DuckDBPyConnection::ImportCache() {
-	if (!import_cache) {
-		import_cache = make_shared_ptr<PythonImportCache>();
-	}
-	return import_cache.get();
+	return GetModuleState().GetImportCache();
+}
+
+bool DuckDBPyConnection::IsJupyterInstance() const {
+	return GetModuleState().environment == PythonEnvironmentType::JUPYTER;
+}
+
+bool DuckDBPyConnection::IsInteractiveInstance() const {
+	return GetModuleState().environment != PythonEnvironmentType::NORMAL;
+}
+
+std::string DuckDBPyConnection::FormattedPythonVersionInstance() const {
+	return GetModuleState().formatted_python_version;
 }
 
 ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() {
@@ -2228,7 +2200,7 @@ ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() {
 }
 
 bool DuckDBPyConnection::IsInteractive() {
-	return DuckDBPyConnection::environment != PythonEnvironmentType::NORMAL;
+	return GetModuleState().environment != PythonEnvironmentType::NORMAL;
 }
 
 shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Enter() {
@@ -2246,8 +2218,25 @@ void DuckDBPyConnection::Exit(DuckDBPyConnection &self, const py::object &exc_ty
 }
 
 void DuckDBPyConnection::Cleanup() {
-	default_connection.Set(nullptr);
-	import_cache.reset();
+	try {
+		GetModuleState().ClearDefaultConnection();
+		GetModuleState().ClearImportCache();
+	} catch (...) { // NOLINT
+		// TODO: Can we detect shutdown? Py_IsFinalizing might be appropriate,
+		// although it was renamed from _Py_IsFinalizing
+	}
+}
+
+shared_ptr<DuckDBPyConnection> DuckDBPyConnection::GetDefaultConnection() {
+	return GetModuleState().GetDefaultConnection();
+}
+
+void DuckDBPyConnection::ClearDefaultConnection() {
+	GetModuleState().ClearDefaultConnection();
+}
+
+void DuckDBPyConnection::ClearImportCache() {
+	GetModuleState().ClearImportCache();
 }
 
 bool DuckDBPyConnection::IsPandasDataframe(const py::object &object) {
diff --git a/tests/conftest.py b/tests/conftest.py
index df64f86c..23070b34 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -320,3 +320,13 @@ def timestamps(duckdb_cursor):
     cursor.execute("INSERT INTO timestamps VALUES ('1992-10-03 18:34:45'), ('2010-01-01 00:00:01'), (NULL)")
     yield
     cursor.execute("drop table timestamps")
+
+
+@pytest.fixture
+def num_threads_testing():
+    """Pick a thread count that loads the system while keeping tests fast."""
+    import multiprocessing
+
+    cpu_count = multiprocessing.cpu_count()
+    # Use 1.5x the CPU count, capped at 12 for CI compatibility
+    return min(12, max(4, int(cpu_count * 1.5)))
diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py
index 80920a99..69085596 100644
--- a/tests/fast/adbc/test_adbc.py
+++ b/tests/fast/adbc/test_adbc.py
@@ -2,19 +2,22 @@
 import sys
 from pathlib import Path
 
-import adbc_driver_manager.dbapi
 import numpy as np
 import pyarrow
 import pytest
 
-import adbc_driver_duckdb.dbapi
+import adbc_driver_duckdb
 
-xfail = pytest.mark.xfail
 driver_path = adbc_driver_duckdb.driver_path()
 
+xfail = pytest.mark.xfail
+
+
 @pytest.fixture
 def duck_conn():
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     with adbc_driver_manager.dbapi.connect(driver=driver_path, entrypoint="duckdb_adbc_init") as conn:
         yield conn
 
@@ -95,6 +98,8 @@ def test_connection_get_objects_filters(duck_conn):
 
 
 def test_commit(tmp_path):
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     db = Path(tmp_path) / "tmp.db"
     if db.exists():
         db.unlink()
@@ -142,6 +147,8 @@ def test_commit(tmp_path):
 
 
 def test_connection_get_table_schema(duck_conn):
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     with duck_conn.cursor() as cursor:
         # Test Default Schema
         cursor.execute("CREATE TABLE tableschema (ints BIGINT)")
@@ -209,6 +216,8 @@ def test_statement_query(duck_conn):
 
 @xfail(sys.platform == "win32", reason="adbc-driver-manager returns an invalid table schema on windows")
 def test_insertion(duck_conn):
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     table = example_table()
     reader = table.to_reader()
 
@@ -281,6 +290,8 @@ def test_read(duck_conn):
 
 
 def test_large_chunk(tmp_path):
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     num_chunks = 3
     chunk_size = 10_000
 
@@ -318,6 +329,8 @@ def test_large_chunk(tmp_path):
 
 
 def test_dictionary_data(tmp_path):
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     data = ["apple", "banana", "apple", "orange", "banana", "banana"]
 
     dict_type = pyarrow.dictionary(index_type=pyarrow.int32(), value_type=pyarrow.string())
@@ -346,6 +359,8 @@ def test_dictionary_data(tmp_path):
 
 
 def test_ree_data(tmp_path):
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     run_ends = pyarrow.array([3, 5, 6], type=pyarrow.int32())  # positions: [0-2], [3-4], [5]
     values = pyarrow.array(["apple", "banana", "orange"], type=pyarrow.string())
diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py
index d35693ff..dc7da7b2 100644
--- a/tests/fast/adbc/test_statement_bind.py
+++ b/tests/fast/adbc/test_statement_bind.py
@@ -1,6 +1,5 @@
 import sys
 
-import adbc_driver_manager
 import pyarrow as pa
 import pytest
 
@@ -10,6 +9,8 @@
 
 
 def _import(handle):
     """Helper to import a C Data Interface handle."""
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     if isinstance(handle, adbc_driver_manager.ArrowArrayStreamHandle):
         return pa.RecordBatchReader._import_from_c(handle.address)
@@ -20,6 +21,8 @@ def _import(handle):
 
 
 def _bind(stmt, batch) -> None:
+    adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
     array = adbc_driver_manager.ArrowArrayHandle()
     schema = adbc_driver_manager.ArrowSchemaHandle()
     batch._export_to_c(array.address, schema.address)
@@ -28,6 +31,8 @@ def _bind(stmt, batch) -> None:
 
 class TestADBCStatementBind:
     def test_bind_multiple_rows(self):
+        adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
         data = pa.record_batch(
             [
                 [1, 2, 3, 4],
@@ -141,6 +146,8 @@ def test_bind_composite_type(self):
         assert result == struct_array
 
     def test_too_many_parameters(self):
+        adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
         data = pa.record_batch(
             [[12423], ["not a short string"]],
             names=["ints", "strings"],
@@ -170,6 +177,8 @@
 
     @xfail(sys.platform == "win32", reason="adbc-driver-manager returns an invalid table schema on windows")
     def test_not_enough_parameters(self):
+        adbc_driver_manager = pytest.importorskip("adbc_driver_manager")
+
         data = pa.record_batch(
             [["not a short string"]],
             names=["strings"],
diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py
index 4ea63176..128e6331 100644
--- a/tests/fast/api/test_connection_interrupt.py
+++ b/tests/fast/api/test_connection_interrupt.py
@@ -1,6 +1,7 @@
 import platform
 import threading
 import time
+from concurrent.futures import ThreadPoolExecutor
 
 import pytest
 
@@ -12,19 +13,28 @@ class TestConnectionInterrupt:
         condition=platform.system() == "Emscripten",
         reason="threads not allowed on Emscripten",
     )
+    @pytest.mark.timeout(10)
    def test_connection_interrupt(self):
        conn = duckdb.connect()
+        barrier = threading.Barrier(2)
 
-        def interrupt() -> None:
-            # Wait for query to start running before interrupting
-            time.sleep(0.1)
+        def execute_query():
+            barrier.wait()
+            return conn.execute("select * from range(1000000) t1, range(1000000) t2").fetchall()
+
+        def interrupt_query():
+            barrier.wait()
+            time.sleep(2)
             conn.interrupt()
 
-        thread = threading.Thread(target=interrupt)
-        thread.start()
-        with pytest.raises(duckdb.InterruptException):
-            conn.execute("select count(*) from range(100000000000)").fetchall()
-        thread.join()
+        with ThreadPoolExecutor() as executor:
+            query_future = executor.submit(execute_query)
+            interrupt_future = executor.submit(interrupt_query)
+
+            interrupt_future.result()
+
+            with pytest.raises((duckdb.InterruptException, duckdb.InvalidInputException)):
+                query_future.result()
 
     def test_interrupt_closed_connection(self):
         conn = duckdb.connect()
diff --git a/tests/fast/api/test_query_interrupt.py b/tests/fast/api/test_query_interrupt.py
index 4a5a02e5..3145c616 100644
--- a/tests/fast/api/test_query_interrupt.py
+++ b/tests/fast/api/test_query_interrupt.py
@@ -8,29 +8,27 @@
 import duckdb
 
 
-def send_keyboard_interrupt():
-    # Wait a little, so we're sure the 'execute' has started
-    time.sleep(0.1)
-    # Send an interrupt to the main thread
-    thread.interrupt_main()
-
-
 class TestQueryInterruption:
     @pytest.mark.xfail(
         condition=platform.system() == "Emscripten",
         reason="Emscripten builds cannot use threads",
     )
+    @pytest.mark.timeout(10)
     def test_query_interruption(self):
         con = duckdb.connect()
-        thread = threading.Thread(target=send_keyboard_interrupt)
-        # Start the thread
-        thread.start()
-        try:
-            con.execute("select count(*) from range(100000000000)").fetchall()
-        except RuntimeError:
-            # If this is not reached, we could not cancel the query before it completed
-            # indicating that the query interruption functionality is broken
-            assert True
-        except KeyboardInterrupt:
-            pytest.fail("Interrupted by user")
-        thread.join()
+        barrier = threading.Barrier(2)
+
+        def send_keyboard_interrupt():
+            barrier.wait()
+            time.sleep(2)
+            thread.interrupt_main()
+
+        interrupt_thread = threading.Thread(target=send_keyboard_interrupt)
+        interrupt_thread.start()
+
+        barrier.wait()
+
+        with pytest.raises((KeyboardInterrupt, RuntimeError)):
+            con.execute("select * from range(1000000) t1,range(1000000) t2").fetchall()
+
+        interrupt_thread.join()
diff --git a/tests/fast/threading/README.md b/tests/fast/threading/README.md
new file mode 100644
index 00000000..be5b8f53
--- /dev/null
+++ b/tests/fast/threading/README.md
@@ -0,0 +1,10 @@
+Tests in this directory are intended to be run with [pytest-run-parallel](https://github.com/Quansight-Labs/pytest-run-parallel) to exercise thread safety.
+
+Example usage: `pytest --parallel-threads=10 --iterations=5 --verbose tests/fast/threading -n 4 --durations=5`
+
+#### Thread Safety and DuckDB
+
+Not all DuckDB operations are thread safe - in particular, cursors are not - so take care that tests do not concurrently hit the same cursor or connection.
+
+Tests can be marked as single threaded with:
+- `pytest.mark.thread_unsafe` or the equivalent `pytest.mark.parallel_threads(1)`
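To make the README's advice concrete, here is a minimal sketch of opting a test out of parallel execution while consuming the `num_threads_testing` fixture added in `tests/conftest.py`; the table name and worker body are illustrative only:

```python
import concurrent.futures

import pytest

import duckdb


@pytest.mark.thread_unsafe  # equivalent: @pytest.mark.parallel_threads(1)
def test_shared_state_example(num_threads_testing):
    with duckdb.connect(":memory:") as conn:
        conn.execute("CREATE TABLE t (thread_id INTEGER)")

        def worker(thread_id: int) -> None:
            # Cursors are not thread safe; give each worker its own
            with conn.cursor() as cur:
                cur.execute("INSERT INTO t VALUES (?)", [thread_id])

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads_testing) as pool:
            for fut in [pool.submit(worker, i) for i in range(num_threads_testing)]:
                fut.result()

        assert conn.execute("SELECT COUNT(*) FROM t").fetchone()[0] == num_threads_testing
```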
diff --git a/tests/fast/threading/test_basic_operations.py b/tests/fast/threading/test_basic_operations.py
new file mode 100644
index 00000000..864a084f
--- /dev/null
+++ b/tests/fast/threading/test_basic_operations.py
@@ -0,0 +1,104 @@
+import gc
+import random
+import time
+import uuid
+import weakref
+from threading import get_ident
+
+import duckdb
+
+
+def test_basic():
+    with duckdb.connect(":memory:") as conn:
+        result = conn.execute("SELECT 1").fetchone()
+        assert result[0] == 1
+        int_type = duckdb.type("INTEGER")
+        assert int_type is not None, "type creation failed"
+
+
+def test_connection_instance_cache(tmp_path):
+    thread_id = get_ident()
+    for i in range(10):
+        with duckdb.connect(tmp_path / f"{thread_id}_{uuid.uuid4()}.db") as conn:
+            conn.execute(f"CREATE TABLE IF NOT EXISTS thread_{thread_id}_data_{i} (x BIGINT)")
+            conn.execute(f"INSERT INTO thread_{thread_id}_data_{i} VALUES (100), (100)")
+
+            time.sleep(random.uniform(0.0001, 0.001))
+
+            result = conn.execute(f"SELECT COUNT(*) FROM thread_{thread_id}_data_{i}").fetchone()[0]
+            assert result == 2, f"Iteration {i}: expected 2 rows, got {result}"
+
+
+def test_cleanup():
+    weak_refs = []
+
+    for i in range(5):
+        conn = duckdb.connect(":memory:")
+        weak_refs.append(weakref.ref(conn))
+        try:
+            conn.execute("CREATE TABLE test (x INTEGER)")
+            conn.execute("INSERT INTO test VALUES (1), (2), (3)")
+            result = conn.execute("SELECT COUNT(*) FROM test").fetchone()
+            assert result[0] == 3
+        finally:
+            conn.close()
+            conn = None
+
+        if i % 3 == 0:
+            with duckdb.connect(":memory:") as new_conn:
+                result = new_conn.execute("SELECT 1").fetchone()
+                assert result[0] == 1
+
+        if i % 10 == 0:
+            gc.collect()
+            time.sleep(random.uniform(0.0001, 0.0005))
+
+    gc.collect()
+    time.sleep(0.1)
+    gc.collect()
+
+    alive_refs = [ref for ref in weak_refs if ref() is not None]
+    assert len(alive_refs) <= 10, f"{len(alive_refs)} connections still alive (expected <= 10)"
+
+
+def test_default_connection():
+    with duckdb.connect() as conn1:
+        r1 = conn1.execute("SELECT 1").fetchone()[0]
+        assert r1 == 1, f"expected 1, got {r1}"
+
+    with duckdb.connect(":memory:") as conn2:
+        r2 = conn2.execute("SELECT 2").fetchone()[0]
+        assert r2 == 2, f"expected 2, got {r2}"
+
+
+def test_type_system():
+    for i in range(20):
+        types = [
+            duckdb.type("INTEGER"),
+            duckdb.type("VARCHAR"),
+            duckdb.type("DOUBLE"),
+            duckdb.type("BOOLEAN"),
+            duckdb.list_type(duckdb.type("INTEGER")),
+            duckdb.struct_type({"a": duckdb.type("INTEGER"), "b": duckdb.type("VARCHAR")}),
+        ]
+
+        for t in types:
+            assert t is not None, "type creation failed"
+
+        if i % 5 == 0:
+            with duckdb.connect(":memory:") as conn:
+                conn.execute("CREATE TABLE test (a INTEGER, b VARCHAR, c DOUBLE, d BOOLEAN)")
+                result = conn.execute("SELECT COUNT(*) FROM test").fetchone()
+                assert result[0] == 0
+
+
+def test_import_cache():
+    with duckdb.connect(":memory:") as conn:
+        conn.execute("CREATE TABLE test AS SELECT range as x FROM range(10)")
+        result = conn.fetchdf()
+        assert len(result) > 0, "fetchdf failed"
+
+        result = conn.execute("SELECT range as x FROM range(5)").fetchnumpy()
+        assert len(result["x"]) == 5, "fetchnumpy failed"
+
+        conn.execute("DROP TABLE test")
diff --git a/tests/fast/threading/test_concurrent_access.py b/tests/fast/threading/test_concurrent_access.py
new file mode 100644
index 00000000..111fc757
--- /dev/null
+++ b/tests/fast/threading/test_concurrent_access.py
@@ -0,0 +1,100 @@
+"""Concurrent access tests for DuckDB Python bindings with free threading support.
+
+These tests verify that the DuckDB Python module can handle concurrent access
+from multiple threads safely, testing module state isolation, memory management,
+and connection handling under various stress conditions.
+""" + +import concurrent.futures +import gc +import random +import time + +import pytest + +import duckdb + + +def test_concurrent_connections(): + with duckdb.connect() as conn: + result = conn.execute("SELECT random() as id, random()*2 as doubled").fetchone() + assert result is not None + + +@pytest.mark.parallel_threads(1) +def test_shared_connection_stress(num_threads_testing): + """Test concurrent operations on shared connection using cursors.""" + iterations = 10 + + with duckdb.connect(":memory:") as connection: + connection.execute("CREATE TABLE stress_test (id INTEGER, thread_id INTEGER, value TEXT)") + + def worker_thread(thread_id: int) -> None: + cursor = connection.cursor() + for i in range(iterations): + cursor.execute( + "INSERT INTO stress_test VALUES (?, ?, ?)", + [i, thread_id, f"thread_{thread_id}_value_{i}"], + ) + cursor.execute("SELECT COUNT(*) FROM stress_test WHERE thread_id = ?", [thread_id]).fetchone() + time.sleep(random.uniform(0.0001, 0.001)) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads_testing) as executor: + futures = [executor.submit(worker_thread, i) for i in range(num_threads_testing)] + # Wait for all to complete, will raise if any fail + for future in concurrent.futures.as_completed(futures): + future.result() + + total_rows = connection.execute("SELECT COUNT(*) FROM stress_test").fetchone()[0] + expected_rows = num_threads_testing * iterations + assert total_rows == expected_rows + + +@pytest.mark.parallel_threads(1) +def test_module_state_isolation(): + """Test that module state is properly accessible.""" + with duckdb.connect(":memory:"): + assert hasattr(duckdb, "__version__") + + with duckdb.connect() as default_conn: + result = default_conn.execute("SELECT 'default' as type").fetchone() + assert result[0] == "default" + + int_type = duckdb.type("INTEGER") + string_type = duckdb.type("VARCHAR") + assert int_type is not None + assert string_type is not None + + +def test_rapid_connect_disconnect(): + connections_count = 10 + """Test rapid connection creation and destruction.""" + for i in range(connections_count): + conn = duckdb.connect(":memory:") + try: + result = conn.execute("SELECT 1").fetchone()[0] + assert result == 1 + finally: + conn.close() + + # Sometimes force GC to increase pressure + if i % 3 == 0: + gc.collect() + + +def test_exception_handling(): + """Test exception handling doesn't affect module state.""" + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE test (x INTEGER)") + conn.execute("INSERT INTO test VALUES (1), (2), (3)") + + for i in range(10): + if i % 3 == 0: + with pytest.raises(duckdb.CatalogException): + conn.execute("SELECT * FROM nonexistent_table") + else: + result = conn.execute("SELECT COUNT(*) FROM test").fetchone()[0] + assert result == 3 + finally: + conn.close() diff --git a/tests/fast/threading/test_connection_lifecycle_races.py b/tests/fast/threading/test_connection_lifecycle_races.py new file mode 100644 index 00000000..10c6c151 --- /dev/null +++ b/tests/fast/threading/test_connection_lifecycle_races.py @@ -0,0 +1,86 @@ +"""Test connection lifecycle races. 
+ +Focused on DuckDBPyConnection constructor and Close +""" + +import concurrent.futures +import gc + +import pytest + +import duckdb + + +def test_concurrent_connection_creation_destruction(): + conn = duckdb.connect() + try: + result = conn.execute("SELECT 1").fetchone() + assert result[0] == 1 + finally: + conn.close() + + +def test_connection_destructor_race(): + conn = duckdb.connect() + result = conn.execute("SELECT COUNT(*) FROM range(1)").fetchone() + assert result[0] == 1 + + del conn + gc.collect() + + +@pytest.mark.parallel_threads(1) +def test_concurrent_close_operations(num_threads_testing): + with duckdb.connect(":memory:") as conn: + conn.execute("CREATE TABLE shared_table (id INTEGER, data VARCHAR)") + conn.execute("INSERT INTO shared_table VALUES (1, 'test')") + + def attempt_close_connection(cursor, thread_id): + _result = cursor.execute("SELECT COUNT(*) FROM shared_table").fetchone() + + cursor.close() + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads_testing) as executor: + futures = [executor.submit(attempt_close_connection, conn.cursor(), i) for i in range(num_threads_testing)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + assert all(results) + + +@pytest.mark.parallel_threads(1) +def test_cursor_operations_race(num_threads_testing): + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE cursor_test (id INTEGER, name VARCHAR)") + conn.execute("INSERT INTO cursor_test SELECT i, 'name_' || i FROM range(100) t(i)") + + def cursor_operations(thread_id): + """Perform cursor operations concurrently.""" + # Get a cursor + cursor = conn.cursor() + cursor.execute(f"SELECT * FROM cursor_test WHERE id % {num_threads_testing} = {thread_id}") + cursor.fetchall() + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads_testing) as executor: + futures = [executor.submit(cursor_operations, i) for i in range(num_threads_testing)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + assert all(results) + finally: + conn.close() + + +def test_rapid_connection_cycling(): + """Test rapid connection creation and destruction cycles.""" + num_cycles = 5 + for cycle in range(num_cycles): + conn = duckdb.connect(":memory:") + try: + result = conn.execute(f"SELECT 1 + {cycle}").fetchone() + assert result[0] == 1 + cycle + finally: + conn.close() diff --git a/tests/fast/threading/test_fetching.py b/tests/fast/threading/test_fetching.py new file mode 100644 index 00000000..fc891df5 --- /dev/null +++ b/tests/fast/threading/test_fetching.py @@ -0,0 +1,34 @@ +"""Test fetching operations.""" + +from threading import get_ident + +import duckdb + + +def test_fetching(): + """Test different fetching methods.""" + iterations = 10 + thread_id = get_ident() + + conn = duckdb.connect() + try: + batch_data = [(thread_id * 100 + i, f"name_{thread_id}_{i}") for i in range(iterations)] + conn.execute("CREATE TABLE batch_data (id BIGINT, name VARCHAR)") + conn.executemany("INSERT INTO batch_data VALUES (?, ?)", batch_data) + + # Test different fetch methods + result1 = conn.execute(f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'").fetchone() + assert result1[0] == iterations + + result2 = conn.execute(f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'").fetchall() + assert result2[0][0] == iterations + + result3 = conn.execute(f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'").fetchdf() + 
+        assert len(result3) == 1
+
+        result4 = conn.execute(
+            f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'"
+        ).fetch_arrow_table()
+        assert result4.num_rows == 1
+    finally:
+        conn.close()
diff --git a/tests/fast/threading/test_module_lifecycle.py b/tests/fast/threading/test_module_lifecycle.py
new file mode 100644
index 00000000..355740b1
--- /dev/null
+++ b/tests/fast/threading/test_module_lifecycle.py
@@ -0,0 +1,144 @@
+"""Test module lifecycle.
+
+Reloading and unload are not expected nor required behaviors -
+these tests are to document current behavior so that changes
+are visible.
+"""
+
+import importlib
+import sys
+from threading import get_ident
+
+import pytest
+
+
+@pytest.mark.parallel_threads(1)
+def test_module_reload_safety():
+    """Test module reloading scenarios to detect use-after-free issues."""
+    import duckdb
+
+    with duckdb.connect(":memory:") as conn1:
+        conn1.execute("CREATE TABLE test (id INTEGER)")
+        conn1.execute("INSERT INTO test VALUES (1)")
+        result1 = conn1.execute("SELECT * FROM test").fetchone()[0]
+        assert result1 == 1
+
+        initial_module_id = id(sys.modules["duckdb"])
+
+        # Test importlib.reload() -
+        # does NOT create a new module object in Python
+        importlib.reload(duckdb)
+
+        # Verify the module instance is the same (expected Python behavior)
+        reload_module_id = id(sys.modules["duckdb"])
+        assert initial_module_id == reload_module_id, "importlib.reload() should reuse same module instance"
+
+        # Test if the old connection still works after importlib.reload()
+        result2 = conn1.execute("SELECT * FROM test").fetchone()[0]
+        assert result2 == 1
+
+        # Test a new connection after importlib.reload()
+        with duckdb.connect(":memory:") as conn2:
+            conn2.execute("CREATE TABLE test2 (id INTEGER)")
+            conn2.execute("INSERT INTO test2 VALUES (2)")
+            result3 = conn2.execute("SELECT * FROM test2").fetchone()[0]
+            assert result3 == 2
+
+
+@pytest.mark.parallel_threads(1)
+def test_dynamic_module_loading():
+    import duckdb
+
+    with duckdb.connect(":memory:") as conn:
+        conn.execute("SELECT 1").fetchone()
+
+    module_id_1 = id(sys.modules["duckdb"])
+
+    # "Unload" the module (not really, just to try it)
+    if "duckdb" in sys.modules:
+        del sys.modules["duckdb"]
+
+    # Remove from the local namespace
+    if "duckdb" in locals():
+        del duckdb
+
+    # Verify the module is unloaded
+    assert "duckdb" not in sys.modules, "Module not properly unloaded"
+
+    # import (load) the module again
+    import duckdb
+
+    module_id_2 = id(sys.modules["duckdb"])
+
+    # Verify we have a new module instance
+    assert module_id_1 != module_id_2, "Module not actually reloaded"
+
+    # Test functionality after reload
+    with duckdb.connect(":memory:") as conn:
+        conn.execute("CREATE TABLE test (id INTEGER)")
+        conn.execute("INSERT INTO test VALUES (42)")
+        result = conn.execute("SELECT * FROM test").fetchone()[0]
+        assert result == 42
+
+
+def test_import_cache_consistency():
+    """Test that import cache remains consistent across module operations."""
+    import pandas as pd
+
+    import duckdb
+
+    conn = duckdb.connect(":memory:")
+
+    df = pd.DataFrame({"a": [1, 2, 3]})
+
+    conn.register("test_df", df)
+    result = conn.execute("SELECT COUNT(*) FROM test_df").fetchone()[0]
+    assert result == 3
+
+    conn.close()
+
+
+def test_module_state_memory_safety():
+    """Test memory safety of module state access patterns."""
+    import duckdb
+
+    connections = []
+    for i in range(10):
+        conn = duckdb.connect(":memory:")
+        conn.execute(f"CREATE TABLE test_{i} (id INTEGER)")
+        conn.execute(f"INSERT INTO test_{i} VALUES ({i})")
+        connections.append(conn)
+
+    import gc
+
+    gc.collect()
+
+    for i, conn in enumerate(connections):
+        result = conn.execute(f"SELECT * FROM test_{i}").fetchone()[0]
+        assert result == i
+
+    for conn in connections:
+        conn.close()
+
+
+def test_static_cache_stress():
+    """Test rapid module state access."""
+    import duckdb
+
+    iterations = 5
+    for _ in range(iterations):
+        conn = duckdb.connect(":memory:")
+        result = conn.execute("SELECT 1").fetchone()
+        assert result[0] == 1
+        conn.close()
+
+
+def test_concurrent_module_access():
+    import duckdb
+
+    thread_id = get_ident()
+    with duckdb.connect(":memory:") as conn:
+        conn.execute(f"CREATE TABLE test_{thread_id} (id BIGINT)")
+        conn.execute(f"INSERT INTO test_{thread_id} VALUES ({thread_id})")
+        result = conn.execute(f"SELECT * FROM test_{thread_id}").fetchone()[0]
+        assert result == thread_id
diff --git a/tests/fast/threading/test_module_state.py b/tests/fast/threading/test_module_state.py
new file mode 100644
index 00000000..17fd0b22
--- /dev/null
+++ b/tests/fast/threading/test_module_state.py
@@ -0,0 +1,36 @@
+from threading import get_ident
+
+import duckdb
+
+
+def test_concurrent_connection_creation():
+    thread_id = get_ident()
+    for i in range(5):
+        with duckdb.connect(":memory:") as conn:
+            conn.execute(f"CREATE TABLE test_{i} (x BIGINT)")
+            conn.execute(f"INSERT INTO test_{i} VALUES ({thread_id})")
+            result = conn.execute(f"SELECT * FROM test_{i}").fetchall()
+            assert result == [(thread_id,)], f"Table {i} failed"
+
+
+def test_concurrent_instance_cache_access(tmp_path):
+    thread_id = get_ident()
+    for i in range(10):
+        db_path = str(tmp_path / f"test_{thread_id}_{i}.db")
+        with duckdb.connect(db_path) as conn:
+            conn.execute("CREATE TABLE IF NOT EXISTS test (x BIGINT, thread_id BIGINT)")
+            conn.execute(f"INSERT INTO test VALUES ({i}, {thread_id})")
+            result = conn.execute("SELECT COUNT(*) FROM test").fetchone()
+            assert result[0] >= 1
+
+
+def test_environment_detection():
+    version = duckdb.__formatted_python_version__
+    interactive = duckdb.__interactive__
+
+    assert isinstance(version, str), "version should be string"
+    assert isinstance(interactive, bool), "interactive should be boolean"
+
+    with duckdb.connect(":memory:") as conn:
+        result = conn.execute("SELECT 1").fetchone()
+        assert result[0] == 1
diff --git a/tests/fast/threading/test_query_execution_races.py b/tests/fast/threading/test_query_execution_races.py
new file mode 100644
index 00000000..c3ba19d9
--- /dev/null
+++ b/tests/fast/threading/test_query_execution_races.py
@@ -0,0 +1,165 @@
+"""Test concurrent query execution races.
+
+This tests race conditions in query execution paths where the GIL is released
+during query processing, as identified in pyconnection.cpp.
+""" + +import concurrent.futures +import threading +from threading import get_ident + +import pytest + +import duckdb + + +class QueryRaceTester: + """Increases contention by aligning tests w a barrier.""" + + def setup_barrier(self, num_threads): + self.barrier = threading.Barrier(num_threads) + + def synchronized_execute(self, db, query, description="query"): + with db.cursor() as conn: + self.barrier.wait() + conn.execute(query).fetchall() + return True + + +@pytest.mark.parallel_threads(1) +def test_concurrent_prepare_execute(): + num_threads = 5 + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE test_data (id INTEGER, value VARCHAR)") + conn.execute("INSERT INTO test_data SELECT i, 'value_' || i FROM range(1000) t(i)") + + tester = QueryRaceTester() + tester.setup_barrier(num_threads) + + def prepare_and_execute(thread_id, conn): + queries = [ + f"SELECT COUNT(*) FROM test_data WHERE id > {thread_id * 10}", + f"SELECT value FROM test_data WHERE id = {thread_id + 1}", + f"SELECT id, value FROM test_data WHERE id BETWEEN {thread_id} AND {thread_id + 10}", + f"INSERT INTO test_data VALUES ({1000 + thread_id}, 'thread_{thread_id}')", + f"UPDATE test_data SET value = 'updated_{thread_id}' WHERE id = {thread_id + 500}", + ] + + query = queries[thread_id % len(queries)] + return tester.synchronized_execute(conn, query, f"Prepared query {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(prepare_and_execute, i, conn) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + assert len(results) == num_threads + assert all(results) + finally: + conn.close() + + +@pytest.mark.parallel_threads(1) +def test_concurrent_pending_query_execution(): + conn = duckdb.connect(":memory:") + try: + conn.execute( + "CREATE TABLE large_data AS SELECT i, i*2 as double_val, 'row_' || i as str_val FROM range(10000) t(i)" + ) + + num_threads = 8 + tester = QueryRaceTester() + tester.setup_barrier(num_threads) + + def execute_long_query(thread_id): + queries = [ + "SELECT COUNT(*), AVG(double_val) FROM large_data", + "SELECT str_val, double_val FROM large_data WHERE i % 100 = 0 ORDER BY double_val", + f"SELECT * FROM large_data WHERE i BETWEEN {thread_id * 1000} AND {(thread_id + 1) * 1000}", + "SELECT i, double_val, str_val FROM large_data WHERE double_val > 5000 ORDER BY i DESC", + f"SELECT COUNT(*) as cnt FROM large_data WHERE str_val LIKE '%{thread_id}%'", + ] + + query = queries[thread_id % len(queries)] + return tester.synchronized_execute(conn, query, f"Long query {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(execute_long_query, i) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + assert all(results) + assert len(results) == num_threads + finally: + conn.close() + + +def test_execute_many_race(): + """Test executemany operations.""" + iterations = 10 + thread_id = get_ident() + + conn = duckdb.connect() + try: + batch_data = [(thread_id * 100 + i, f"name_{thread_id}_{i}") for i in range(iterations)] + conn.execute("CREATE TABLE batch_data (id BIGINT, name VARCHAR)") + conn.executemany("INSERT INTO batch_data VALUES (?, ?)", batch_data) + result = conn.execute(f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'").fetchone() + assert result[0] == iterations + finally: + conn.close() + + 
+@pytest.mark.parallel_threads(1)
+def test_query_interruption_race():
+    conn = duckdb.connect(":memory:")
+    try:
+        conn.execute("CREATE TABLE interrupt_test AS SELECT i FROM range(100000) t(i)")
+
+        num_threads = 6
+
+        def run_interruptible_query(thread_id):
+            with conn.cursor() as conn2:
+                if thread_id % 2 == 0:
+                    # Fast query
+                    conn2.execute("SELECT COUNT(*) FROM interrupt_test").fetchall()
+                    return True
+                else:
+                    # Potentially slower query
+                    conn2.execute("SELECT i, i*i FROM interrupt_test WHERE i % 1000 = 0 ORDER BY i").fetchall()
+                    return True
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+            futures = [executor.submit(run_interruptible_query, i) for i in range(num_threads)]
+            results = [future.result() for future in concurrent.futures.as_completed(futures, timeout=30)]
+
+        assert all(results)
+    finally:
+        conn.close()
+
+
+def test_mixed_query_operations():
+    """Test mixed query operations."""
+    thread_id = get_ident()
+
+    with duckdb.connect(":memory:") as conn:
+        conn.execute("CREATE TABLE mixed_ops (id BIGINT PRIMARY KEY, data VARCHAR, num_val DOUBLE)")
+        conn.execute("INSERT INTO mixed_ops SELECT i, 'initial_' || i, i * 1.5 FROM range(1000) t(i)")
+
+        queries = [
+            f"SELECT COUNT(*) FROM mixed_ops WHERE id > {thread_id * 50}",
+            f"INSERT INTO mixed_ops VALUES ({10000 + thread_id}, 'thread_{thread_id}', {thread_id * 2.5})",
+            f"UPDATE mixed_ops SET data = 'updated_{thread_id}' WHERE id = {thread_id + 100}",
+            "SELECT AVG(num_val), MAX(id) FROM mixed_ops WHERE data LIKE 'initial_%'",
+            """
+            SELECT m1.id, m1.data, m2.num_val
+            FROM mixed_ops m1
+            JOIN mixed_ops m2 ON m1.id = m2.id - 1
+            LIMIT 10
+            """,
+        ]
+
+        for query in queries:
+            result = conn.execute(query)
+            if "SELECT" in query.upper():
+                rows = result.fetchall()
+                assert len(rows) >= 0
diff --git a/tests/fast/threading/test_threading.py b/tests/fast/threading/test_threading.py
new file mode 100644
index 00000000..b80dedb2
--- /dev/null
+++ b/tests/fast/threading/test_threading.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+"""Tests designed to expose specific threading bugs in the DuckDB implementation."""
+
+import sys
+from threading import get_ident
+
+import duckdb
+
+
+def test_gil_enabled():
+    # Safeguard for test validity on free-threading builds: verify the build was
+    # compiled with Py_GIL_DISABLED and that the GIL has not been re-enabled at
+    # runtime (for example, by running the tests with PYTHON_GIL=1)
+    if "free-threading" in sys.version:
+        import sysconfig
+
+        print(f"Free-threading Python detected: {sys.version}")
+        print(f"Py_GIL_DISABLED = {sysconfig.get_config_var('Py_GIL_DISABLED')}")
+
+        assert sysconfig.get_config_var("Py_GIL_DISABLED") == 1, (
+            f"Py_GIL_DISABLED must be 1 in free-threading build, got: {sysconfig.get_config_var('Py_GIL_DISABLED')}"
+        )
+        if hasattr(sys, "_is_gil_enabled"):
+            assert not sys._is_gil_enabled(), "GIL was re-enabled at runtime (e.g. PYTHON_GIL=1)"
+
+
+def test_instance_cache_race(tmp_path):
+    """Test opening connections to different files."""
+    tid = get_ident()
+    with duckdb.connect(tmp_path / f"{tid}_testing.db") as conn:
+        conn.execute("CREATE TABLE IF NOT EXISTS test (x INTEGER, y INTEGER)")
+        conn.execute("INSERT INTO test VALUES (123, 456)")
+        result = conn.execute("SELECT COUNT(*) FROM test").fetchone()
+        assert result[0] >= 1
diff --git a/tests/fast/threading/test_udf_threaded.py b/tests/fast/threading/test_udf_threaded.py
new file mode 100644
index 00000000..2a76cd29
--- /dev/null
+++ b/tests/fast/threading/test_udf_threaded.py
@@ -0,0 +1,80 @@
+"""Test User Defined Function (UDF)."""
+
+import concurrent.futures
+
+import pytest
+
+import duckdb
+
+
+def test_concurrent_udf_registration():
+    """Test UDF registration."""
+    with duckdb.connect(":memory:") as conn:
+
+        def my_add(x: int, y: int) -> int:
+            return x + y
+
+        udf_name = "my_add_1"
+        conn.create_function(udf_name, my_add)
+
+        result = conn.execute(f"SELECT {udf_name}(1, 2)").fetchone()
+        assert result[0] == 3
+
+
+def test_mixed_udf_operations():
+    conn = duckdb.connect(":memory:")
+    try:
+        # Register and use a UDF
+        def thread_func(x: int) -> int:
+            return x * 2
+
+        udf_name = "thread_func_1"
+        conn.create_function(udf_name, thread_func)
+        result1 = conn.execute(f"SELECT {udf_name}(5)").fetchone()
+        assert result1[0] == 10
+
+        # Simple query
+        result2 = conn.execute("SELECT 42").fetchone()
+        assert result2[0] == 42
+
+        # Create a table and use built-in functions
+        conn.execute("CREATE TABLE test_table (x INTEGER)")
+        conn.execute("INSERT INTO test_table VALUES (1), (2), (3)")
+        result3 = conn.execute("SELECT COUNT(*) FROM test_table").fetchone()
+        assert result3[0] == 3
+    finally:
+        conn.close()
+
+
+@pytest.mark.parallel_threads(1)
+def test_scalar_udf_concurrent():
+    num_threads = 5
+    conn = duckdb.connect(":memory:")
+
+    # Create test data
+    conn.execute("CREATE TABLE numbers (x INTEGER)")
+    conn.execute("INSERT INTO numbers SELECT * FROM range(100)")
+
+    # Create a simple scalar UDF instead of a vectorized one (simpler for testing)
+    def simple_square(x: int) -> int:
+        """Square a single value."""
+        return x * x
+
+    conn.create_function("simple_square", simple_square)
+
+    def execute_scalar_udf(thread_id):
+        start = thread_id * 10
+        end = start + 10
+        query = f"SELECT simple_square(x) FROM numbers WHERE x BETWEEN {start} AND {end}"
+        with conn.cursor() as c:
+            assert c.execute(query).fetchone()[0] == (start**2)
+
+        return True
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(execute_scalar_udf, i) for i in range(num_threads)]
+        results = [future.result() for future in concurrent.futures.as_completed(futures)]
+
+    conn.close()
+
+    assert all(results)
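For local verification outside pytest, a small smoke script can exercise the behavior this change targets. It assumes CPython 3.13+ for `sys._is_gil_enabled()`, and relies on `duckdb.default_connection()`, which (per the diff above) is now served lazily from the module state via `DuckDBPyModuleState::GetDefaultConnection()`:

```python
import sys

import duckdb

# Runtime GIL status; only exposed on 3.13+ builds
if hasattr(sys, "_is_gil_enabled"):
    print("GIL enabled at runtime:", sys._is_gil_enabled())

# The module-wide default in-memory connection is created lazily
# behind this call, guarded by a critical section on free-threading builds
conn = duckdb.default_connection()
print(conn.execute("SELECT 42").fetchone()[0])
```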