diff --git a/pyproject.toml b/pyproject.toml index 546460cb..cac3e92e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ cmake.define.CMAKE_C_FLAGS = "--coverage -O0" cmake.define.CMAKE_SHARED_LINKER_FLAGS = "--coverage" # Override: if we're in editable mode then make sure a build dir is set. Note that COVERAGE runs have their own -# build-dir, and we don't want to interfere with that. We're also disabling unity builds to help with debugging. +# build-dir, and we don't want to interfere with that. [[tool.scikit-build.overrides]] if.state = "editable" if.env.COVERAGE = false @@ -121,7 +121,8 @@ cmake.build-type = "Debug" if.state = "editable" if.env.COVERAGE = false inherit.cmake.define = "append" -cmake.define.DISABLE_UNITY = "1" +cmake.define.DISABLE_UNITY = "0" + [tool.scikit-build.sdist] include = [ @@ -227,6 +228,7 @@ test = [ # dependencies used for running tests "pytest", "pytest-reraise", "pytest-timeout", + "pytest-xdist", "mypy", "coverage", "gcovr; python_version < '3.14'", @@ -379,4 +381,25 @@ manylinux-x86_64-image = "manylinux_2_28" manylinux-pypy_x86_64-image = "manylinux_2_28" manylinux-aarch64-image = "manylinux_2_28" manylinux-pypy_aarch64-image = "manylinux_2_28" -enable = ["cpython-freethreading", "cpython-prerelease"] + + +[tool.cibuildwheel.linux] +environment-pass = ["SCCACHE_GHA_ENABLED", "ACTIONS_RUNTIME_TOKEN", "ACTIONS_RESULTS_URL", "ACTIONS_CACHE_SERVICE_V2", "SCCACHE_C_CUSTOM_CACHE_BUSTER", "PYTHON_GIL"] +before-build = [ + "if [ \"$(uname -m)\" = \"aarch64\" ]; then ARCH=aarch64; else ARCH=x86_64; fi", + "curl -L https://github.com/mozilla/sccache/releases/download/v0.10.0/sccache-v0.10.0-${ARCH}-unknown-linux-musl.tar.gz | tar xz", + "cp sccache-v0.10.0-${ARCH}-unknown-linux-musl/sccache /usr/bin", + "sccache --show-stats"] +before-test = ["sccache --show-stats"] + +[tool.cibuildwheel.macos] +environment-pass = ["SCCACHE_GHA_ENABLED", "ACTIONS_RUNTIME_TOKEN", "ACTIONS_RESULTS_URL", "ACTIONS_CACHE_SERVICE_V2", "SCCACHE_C_CUSTOM_CACHE_BUSTER", "PYTHON_GIL"] +before-build = ["brew install sccache"] + +[tool.cibuildwheel.windows] +before-build = ["choco install ccache"] + +# Override for free threading builds (cp314t) - set PYTHON_GIL=0 +[[tool.cibuildwheel.overrides]] +select = "cp314t-*" +environment = { PYTHON_GIL = "0" } diff --git a/src/duckdb_py/CMakeLists.txt b/src/duckdb_py/CMakeLists.txt index 2252ba29..7673d17f 100644 --- a/src/duckdb_py/CMakeLists.txt +++ b/src/duckdb_py/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(python_src OBJECT duckdb_python.cpp importer.cpp map.cpp + module_state.cpp path_like.cpp pyconnection.cpp pyexpression.cpp diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp index 939fa41a..f84377a7 100644 --- a/src/duckdb_py/duckdb_python.cpp +++ b/src/duckdb_py/duckdb_python.cpp @@ -20,6 +20,7 @@ #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" #include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/enums/statement_type.hpp" +#include "duckdb_python/module_state.hpp" #include "duckdb.hpp" @@ -31,6 +32,16 @@ namespace py = pybind11; namespace duckdb { +// Private function to initialize module state +void InitializeModuleState(py::module_ &m) { + auto state_ptr = new DuckDBPyModuleState(); + SetModuleState(state_ptr); + + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors + auto capsule = py::capsule(state_ptr, [](void *p) { delete static_cast<DuckDBPyModuleState *>(p); }); + m.attr("__duckdb_state") = capsule; +} +
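The capsule ties the heap-allocated `DuckDBPyModuleState` to the module object itself, so the interpreter invokes the deleter when the module is torn down rather than leaking the state. From the Python side this is only observable as an opaque attribute; a minimal, purely illustrative sketch (assuming the `__duckdb_state` attribute name set above remains exposed):

```python
import duckdb

# The capsule installed by InitializeModuleState() is exposed as a module
# attribute; its deleter runs when the module object is deallocated.
state = getattr(duckdb, "__duckdb_state", None)
print(type(state))  # expected: <class 'PyCapsule'>
```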
enum PySQLTokenType : uint8_t { PY_SQL_TOKEN_IDENTIFIER = 0, PY_SQL_TOKEN_NUMERIC_CONSTANT, @@ -1007,7 +1018,22 @@ static void RegisterExpectedResultType(py::handle &m) { expected_return_type.export_values(); } -PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT +// Only mark mod_gil_not_used for 3.14t or later; +// we deliberately do not add free-threading support for 3.13t. +// The Py_GIL_DISABLED check is not strictly necessary. +#if defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030e0000 +PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m, py::mod_gil_not_used(), + py::multiple_interpreters::not_supported()) { // NOLINT +#else +PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m, + py::multiple_interpreters::not_supported()) { // NOLINT +#endif + + // Initialize the module state fully during module initialization. + // PEP 489 wants module state to be module-local, but it is currently + // static, accessed via g_module_state. + InitializeModuleState(m); + py::enum_<duckdb::ExplainType>(m, "ExplainType") .value("STANDARD", duckdb::ExplainType::EXPLAIN_STANDARD) .value("ANALYZE", duckdb::ExplainType::EXPLAIN_ANALYZE) @@ -1046,9 +1072,10 @@ PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT m.attr("__version__") = std::string(DuckDB::LibraryVersion()).substr(1); m.attr("__standard_vector_size__") = DuckDB::StandardVectorSize(); m.attr("__git_revision__") = DuckDB::SourceID(); - m.attr("__interactive__") = DuckDBPyConnection::DetectAndGetEnvironment(); - m.attr("__jupyter__") = DuckDBPyConnection::IsJupyter(); - m.attr("__formatted_python_version__") = DuckDBPyConnection::FormattedPythonVersion(); + auto &module_state = GetModuleState(); + m.attr("__interactive__") = module_state.environment != PythonEnvironmentType::NORMAL; + m.attr("__jupyter__") = module_state.environment == PythonEnvironmentType::JUPYTER; + m.attr("__formatted_python_version__") = module_state.formatted_python_version; m.def("default_connection", &DuckDBPyConnection::DefaultConnection, "Retrieve the connection currently registered as the default to be used by the module"); m.def("set_default_connection", &DuckDBPyConnection::SetDefaultConnection, diff --git a/src/duckdb_py/include/duckdb_python/module_state.hpp b/src/duckdb_py/include/duckdb_python/module_state.hpp new file mode 100644 index 00000000..d7a4e377 --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/module_state.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/module_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb_python/pybind11/pybind_wrapper.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/main/db_instance_cache.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb_python/import_cache/python_import_cache.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include + +namespace duckdb { + +// Module state structure to hold per-interpreter state +struct DuckDBPyModuleState { + // Python environment tracking + PythonEnvironmentType environment = PythonEnvironmentType::NORMAL; + string formatted_python_version; + + DuckDBPyModuleState(); + + shared_ptr<DuckDBPyConnection> GetDefaultConnection(); + void SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection); + void ClearDefaultConnection(); + + PythonImportCache *GetImportCache(); + void ClearImportCache(); + + DBInstanceCache *GetInstanceCache(); + + static DuckDBPyModuleState &GetGlobalModuleState(); + static void SetGlobalModuleState(DuckDBPyModuleState *state); + +private: + shared_ptr<DuckDBPyConnection> default_connection_ptr;
+ PythonImportCache import_cache; + DBInstanceCache instance_cache; +#ifdef Py_GIL_DISABLED + py::object default_con_lock; +#endif + + // Implemented as a static as a first step towards PEP 489 / multi-phase init. + // The intent is to move to a per-module object, but frequent import_cache + // accesses need to be considered carefully. + // TODO: Replace with non-static per-interpreter state for multi-interpreter support + static DuckDBPyModuleState *g_module_state; + + // Non-copyable + DuckDBPyModuleState(const DuckDBPyModuleState &) = delete; + DuckDBPyModuleState &operator=(const DuckDBPyModuleState &) = delete; +}; + +DuckDBPyModuleState &GetModuleState(); +void SetModuleState(DuckDBPyModuleState *state); + +} // namespace duckdb \ No newline at end of file diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index 48ee055e..7998c14e 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -28,6 +28,7 @@ namespace duckdb { struct BoundParameterData; +struct DuckDBPyModuleState; enum class PythonEnvironmentType { NORMAL, INTERACTIVE, JUPYTER }; @@ -172,8 +173,7 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { case_insensitive_set_t registered_objects; public: - explicit DuckDBPyConnection() { - } + DuckDBPyConnection(); ~DuckDBPyConnection(); public: @@ -190,9 +190,17 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { static std::string FormattedPythonVersion(); static shared_ptr<DuckDBPyConnection> DefaultConnection(); static void SetDefaultConnection(shared_ptr<DuckDBPyConnection> conn); +static shared_ptr<DuckDBPyConnection> GetDefaultConnection(); +static void ClearDefaultConnection(); +static void ClearImportCache(); static PythonImportCache *ImportCache(); static bool IsInteractive(); + // Instance methods for optimized module state access + bool IsJupyterInstance() const; + bool IsInteractiveInstance() const; + std::string FormattedPythonVersionInstance() const; + unique_ptr<DuckDBPyRelation> ReadCSV(const py::object &name, py::kwargs &kwargs); py::list ExtractStatements(const string &query); @@ -337,11 +345,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { py::list ListFilesystems(); bool FileSystemIsRegistered(const string &name); - //! Default connection to an in-memory database - static DefaultConnectionHolder default_connection; - //!
Caches and provides an interface to get frequently used modules+subtypes - static shared_ptr<PythonImportCache> import_cache; - static bool IsPandasDataframe(const py::object &object); static PyArrowObjectType GetArrowType(const py::handle &obj); static bool IsAcceptedArrowObject(const py::object &object); @@ -357,10 +360,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { bool side_effects); void RegisterArrowObject(const py::object &arrow_object, const string &name); vector<unique_ptr<SQLStatement>> GetStatements(const py::object &query); - - static PythonEnvironmentType environment; - static std::string formatted_python_version; - static void DetectEnvironment(); }; template diff --git a/src/duckdb_py/module_state.cpp b/src/duckdb_py/module_state.cpp new file mode 100644 index 00000000..1e0b6897 --- /dev/null +++ b/src/duckdb_py/module_state.cpp @@ -0,0 +1,128 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/module_state.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb_python/module_state.hpp" +#include +#include +#include + +#define DEBUG_MODULE_STATE 0 + +namespace duckdb { + +// Forward declaration from pyconnection.cpp +void InstantiateNewInstance(DuckDB &db); + +// Static data member definition - g_module_state is declared in the class and must be defined in exactly one translation unit +DuckDBPyModuleState *DuckDBPyModuleState::g_module_state = nullptr; + +DuckDBPyModuleState::DuckDBPyModuleState() { + // Caches are constructed as direct objects - no heap allocation needed + +#ifdef Py_GIL_DISABLED + // Initialize lock object for critical sections + // TODO: Consider moving to finer-grained locks + default_con_lock = py::none(); +#endif + + // Detect the Python environment and version during initialization + // Moved from DuckDBPyConnection::DetectEnvironment() + py::module_ sys = py::module_::import("sys"); + py::object version_info = sys.attr("version_info"); + int major = py::cast<int>(version_info.attr("major")); + int minor = py::cast<int>(version_info.attr("minor")); + formatted_python_version = std::to_string(major) + "."
+ std::to_string(minor); + + // If __main__ does not have a __file__ attribute, we are in interactive mode + auto main_module = py::module_::import("__main__"); + if (!py::hasattr(main_module, "__file__")) { + environment = PythonEnvironmentType::INTERACTIVE; + + if (ModuleIsLoaded()) { + // Check to see if we are in a Jupyter Notebook + auto get_ipython = import_cache.IPython.get_ipython(); + if (get_ipython.ptr() != nullptr) { + auto ipython = get_ipython(); + if (py::hasattr(ipython, "config")) { + py::dict ipython_config = ipython.attr("config"); + if (ipython_config.contains("IPKernelApp")) { + environment = PythonEnvironmentType::JUPYTER; + } + } + } + } + } +} + +DuckDBPyModuleState &DuckDBPyModuleState::GetGlobalModuleState() { + // TODO: Externalize this static cache when adding multi-interpreter support + // For now, single interpreter assumption allows simple static caching + if (!g_module_state) { + throw InternalException("Module state not initialized - call SetGlobalModuleState() during module init"); + } + return *g_module_state; +} + +void DuckDBPyModuleState::SetGlobalModuleState(DuckDBPyModuleState *state) { +#if DEBUG_MODULE_STATE + printf("DEBUG: SetGlobalModuleState() called - initializing static cache (built: %s %s)\n", __DATE__, __TIME__); +#endif + g_module_state = state; +} + +DuckDBPyModuleState &GetModuleState() { +#if DEBUG_MODULE_STATE + printf("DEBUG: GetModuleState() called\n"); +#endif + return DuckDBPyModuleState::GetGlobalModuleState(); +} + +void SetModuleState(DuckDBPyModuleState *state) { + DuckDBPyModuleState::SetGlobalModuleState(state); +} + +shared_ptr<DuckDBPyConnection> DuckDBPyModuleState::GetDefaultConnection() { +#if defined(Py_GIL_DISABLED) + // TODO: Consider whether to use a mutex instead of a scoped critical section + py::scoped_critical_section guard(default_con_lock); +#endif + // Reproduce exact logic from original DefaultConnectionHolder::Get() + if (!default_connection_ptr || default_connection_ptr->con.ConnectionIsClosed()) { + py::dict config_dict; + default_connection_ptr = DuckDBPyConnection::Connect(py::str(":memory:"), false, config_dict); + } + return default_connection_ptr; +} + +void DuckDBPyModuleState::SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection) { +#if defined(Py_GIL_DISABLED) + py::scoped_critical_section guard(default_con_lock); +#endif + default_connection_ptr = std::move(connection); +} + +void DuckDBPyModuleState::ClearDefaultConnection() { +#if defined(Py_GIL_DISABLED) + py::scoped_critical_section guard(default_con_lock); +#endif + default_connection_ptr = nullptr; +} + +PythonImportCache *DuckDBPyModuleState::GetImportCache() { + return &import_cache; +} + +void DuckDBPyModuleState::ClearImportCache() { + import_cache = PythonImportCache(); +} + +DBInstanceCache *DuckDBPyModuleState::GetInstanceCache() { + return &instance_cache; +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index b88b88ed..8c129a09 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -1,4 +1,6 @@ #include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/module_state.hpp" +#include + #include "duckdb/catalog/default/default_types.hpp" #include "duckdb/common/arrow/arrow.hpp" @@ -66,11 +68,8 @@ namespace duckdb { -DefaultConnectionHolder DuckDBPyConnection::default_connection; // NOLINT: allow global -DBInstanceCache instance_cache; // NOLINT: allow global -shared_ptr<PythonImportCache> DuckDBPyConnection::import_cache = nullptr; // NOLINT: allow global
-PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global -std::string DuckDBPyConnection::formatted_python_version = ""; +DuckDBPyConnection::DuckDBPyConnection() { +} DuckDBPyConnection::~DuckDBPyConnection() { try { @@ -82,53 +81,17 @@ DuckDBPyConnection::~DuckDBPyConnection() { } } -void DuckDBPyConnection::DetectEnvironment() { - // Get the formatted Python version - py::module_ sys = py::module_::import("sys"); - py::object version_info = sys.attr("version_info"); - int major = py::cast<int>(version_info.attr("major")); - int minor = py::cast<int>(version_info.attr("minor")); - DuckDBPyConnection::formatted_python_version = std::to_string(major) + "." + std::to_string(minor); - - // If __main__ does not have a __file__ attribute, we are in interactive mode - auto main_module = py::module_::import("__main__"); - if (py::hasattr(main_module, "__file__")) { - return; - } - DuckDBPyConnection::environment = PythonEnvironmentType::INTERACTIVE; - if (!ModuleIsLoaded()) { - return; - } - - // Check to see if we are in a Jupyter Notebook - auto &import_cache_py = *DuckDBPyConnection::ImportCache(); - auto get_ipython = import_cache_py.IPython.get_ipython(); - if (get_ipython.ptr() == nullptr) { - // Could either not load the IPython module, or it has no 'get_ipython' attribute - return; - } - auto ipython = get_ipython(); - if (!py::hasattr(ipython, "config")) { - return; - } - py::dict ipython_config = ipython.attr("config"); - if (ipython_config.contains("IPKernelApp")) { - DuckDBPyConnection::environment = PythonEnvironmentType::JUPYTER; - } - return; -} - bool DuckDBPyConnection::DetectAndGetEnvironment() { - DuckDBPyConnection::DetectEnvironment(); + // Environment detection now happens during module state construction return DuckDBPyConnection::IsInteractive(); } bool DuckDBPyConnection::IsJupyter() { - return DuckDBPyConnection::environment == PythonEnvironmentType::JUPYTER; + return GetModuleState().environment == PythonEnvironmentType::JUPYTER; } std::string DuckDBPyConnection::FormattedPythonVersion() { - return DuckDBPyConnection::formatted_python_version; + return GetModuleState().formatted_python_version; } // NOTE: this function is generated by tools/pythonpkg/scripts/generate_connection_methods.py.
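Since environment detection now runs once inside the `DuckDBPyModuleState` constructor, the module-level attributes are fixed at import time rather than re-detected on each call. A small sketch of how this surfaces to users (attribute names taken from the `PYBIND11_MODULE` block above; the exact version string is just an example):

```python
import duckdb

# All three values are computed once, during module initialization.
assert isinstance(duckdb.__interactive__, bool)
assert isinstance(duckdb.__jupyter__, bool)
print(duckdb.__formatted_python_version__)  # e.g. "3.12" (major.minor)
```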
@@ -1820,6 +1783,7 @@ int DuckDBPyConnection::GetRowcount() { void DuckDBPyConnection::Close() { con.SetResult(nullptr); D_ASSERT(py::gil_check()); + py::gil_scoped_release release; con.SetConnection(nullptr); con.SetDatabase(nullptr); @@ -2111,10 +2075,11 @@ static shared_ptr<DuckDBPyConnection> FetchOrCreateInstance(const string &database_path config.replacement_scans.emplace_back(PythonReplacementScan::Replace); { D_ASSERT(py::gil_check()); + py::gil_scoped_release release; unique_lock lock(res->py_connection_lock); - auto database = - instance_cache.GetOrCreateInstance(database_path, config, cache_instance, InstantiateNewInstance); + auto database = GetModuleState().GetInstanceCache()->GetOrCreateInstance(database_path, config, cache_instance, + InstantiateNewInstance); res->con.SetDatabase(std::move(database)); res->con.SetConnection(make_uniq<Connection>(res->con.GetDatabase())); } @@ -2162,6 +2127,7 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Connect(const py::object &database "python_scan_all_frames", "If set, restores the old behavior of scanning all preceding frames to locate the referenced variable.", LogicalType::BOOLEAN, Value::BOOLEAN(false)); + // Use static methods here since we don't have a connection instance yet if (!DuckDBPyConnection::IsJupyter()) { config_dict["duckdb_api"] = Value("python/" + DuckDBPyConnection::FormattedPythonVersion()); } else { @@ -2197,18 +2163,27 @@ case_insensitive_map_t<BoundParameterData> DuckDBPyConnection::TransformPythonParamDict } shared_ptr<DuckDBPyConnection> DuckDBPyConnection::DefaultConnection() { - return default_connection.Get(); + return GetModuleState().GetDefaultConnection(); } void DuckDBPyConnection::SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection) { - return default_connection.Set(std::move(connection)); + GetModuleState().SetDefaultConnection(std::move(connection)); } PythonImportCache *DuckDBPyConnection::ImportCache() { - if (!import_cache) { - import_cache = make_shared_ptr<PythonImportCache>(); - } - return import_cache.get(); + return GetModuleState().GetImportCache(); +} + +bool DuckDBPyConnection::IsJupyterInstance() const { + return GetModuleState().environment == PythonEnvironmentType::JUPYTER; +} + +bool DuckDBPyConnection::IsInteractiveInstance() const { + return GetModuleState().environment != PythonEnvironmentType::NORMAL; +} + +std::string DuckDBPyConnection::FormattedPythonVersionInstance() const { + return GetModuleState().formatted_python_version; } ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { @@ -2228,7 +2203,7 @@ ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { } bool DuckDBPyConnection::IsInteractive() { - return DuckDBPyConnection::environment != PythonEnvironmentType::NORMAL; + return GetModuleState().environment != PythonEnvironmentType::NORMAL; } shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Enter() { @@ -2246,8 +2221,25 @@ void DuckDBPyConnection::Exit(DuckDBPyConnection &self, const py::object &exc_ty } void DuckDBPyConnection::Cleanup() { - default_connection.Set(nullptr); - import_cache.reset(); + try { + GetModuleState().ClearDefaultConnection(); + GetModuleState().ClearImportCache(); + } catch (...) { // NOLINT + // TODO: Can we detect shutdown?
Py_IsFinalizing might be appropriate, although renamed from + // _Py_IsFinalizing + } +} + +shared_ptr<DuckDBPyConnection> DuckDBPyConnection::GetDefaultConnection() { + return GetModuleState().GetDefaultConnection(); +} + +void DuckDBPyConnection::ClearDefaultConnection() { + GetModuleState().ClearDefaultConnection(); +} + +void DuckDBPyConnection::ClearImportCache() { + GetModuleState().ClearImportCache(); } bool DuckDBPyConnection::IsPandasDataframe(const py::object &object) { diff --git a/tests/conftest.py b/tests/conftest.py index 5e297aee..33a15839 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,13 @@ from importlib import import_module import sys +# Safeguard: on a free-threading build, verify the GIL is actually disabled so the tests remain valid if 'free-threading' in sys.version: + import sysconfig + assert sysconfig.get_config_var('Py_GIL_DISABLED') == 1, f"Py_GIL_DISABLED must be 1 in free-threading build, got: {sysconfig.get_config_var('Py_GIL_DISABLED')}" + print(f"Free-threading Python detected: {sys.version}") + print(f"Py_GIL_DISABLED = {sysconfig.get_config_var('Py_GIL_DISABLED')}") + try: # need to ignore warnings that might be thrown deep inside pandas's import tree (from dateutil in this case) warnings.simplefilter(action="ignore", category=DeprecationWarning) @@ -336,3 +343,12 @@ def finalizer(): duckdb.connect(test_dbfarm) return test_dbfarm + +@pytest.fixture(scope="function") +def num_threads_testing(): + """Get a thread count that loads the system while keeping the tests fast.""" + import multiprocessing + + cpu_count = multiprocessing.cpu_count() + # Use 1.5x CPU count, max 12 for CI compatibility + return min(12, max(4, int(cpu_count * 1.5))) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index 3735ff6e..c1122797 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -10,6 +10,7 @@ class TestScanNumpy(object): + @pytest.mark.xfail(sys.version_info[:2] == (3, 14), reason="Fails when testing without pandas https://github.com/duckdb/duckdb-python/issues/48") def test_scan_numpy(self, duckdb_cursor): z = np.array([1, 2, 3]) res = duckdb_cursor.sql("select * from z").fetchall() diff --git a/tests/fast/threading/test_concurrent_access.py b/tests/fast/threading/test_concurrent_access.py new file mode 100644 index 00000000..8b1b528f --- /dev/null +++ b/tests/fast/threading/test_concurrent_access.py @@ -0,0 +1,207 @@ +""" +Concurrent access tests for DuckDB Python bindings with free threading support. + +These tests verify that the DuckDB Python module can handle concurrent access +from multiple threads safely, testing module state isolation, memory management, +and connection handling under various stress conditions.
+""" + +import gc +import random +import time +import concurrent.futures +from typing import Tuple + +import pytest + +import duckdb + +@pytest.mark.parametrize("num_threads", [10, 25, 50]) +def test_concurrent_connections(num_threads): + """Test creating many connections concurrently from multiple threads.""" + + def create_connection_and_query(thread_id: int) -> Tuple[int, Tuple[int, int]]: + conn = duckdb.connect(':memory:') + try: + result = conn.execute(f"SELECT {thread_id} as thread_id, {thread_id * 2} as doubled").fetchone() + return (thread_id, result) + finally: + conn.close() + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(create_connection_and_query, i) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + # Verify results are correct + assert len(results) == num_threads + for thread_id, result in results: + expected = (thread_id, thread_id * 2) + assert result == expected + + +@pytest.mark.parametrize("num_threads,iterations", [(10, 5), (20, 10)]) +def test_shared_connection_stress(num_threads, iterations): + """Test concurrent operations on shared connection using cursors.""" + + with duckdb.connect(':memory:') as connection: + connection.execute("CREATE TABLE stress_test (id INTEGER, thread_id INTEGER, value TEXT)") + + def worker_thread(thread_id: int) -> None: + cursor = connection.cursor() + for i in range(iterations): + cursor.execute( + "INSERT INTO stress_test VALUES (?, ?, ?)", + [i, thread_id, f"thread_{thread_id}_value_{i}"] + ) + cursor.execute( + "SELECT COUNT(*) FROM stress_test WHERE thread_id = ?", + [thread_id] + ).fetchone() + time.sleep(random.uniform(0.0001, 0.001)) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker_thread, i) for i in range(num_threads)] + # Wait for all to complete, will raise if any fail + for future in concurrent.futures.as_completed(futures): + future.result() + + total_rows = connection.execute("SELECT COUNT(*) FROM stress_test").fetchone()[0] + expected_rows = num_threads * iterations + assert total_rows == expected_rows + + +def test_module_state_isolation(): + """Test that module state is properly isolated and accessible from all threads.""" + + def check_module_state(_thread_id: int) -> dict: + with duckdb.connect(':memory:'): + env_info = [ + hasattr(duckdb, '__version__'), + hasattr(duckdb, '__free_threading__'), + ] + + # Test default connection functionality (if available) + try: + with duckdb.connect() as default_conn: + default_conn.execute("SELECT 'default' as type").fetchone() + has_default = True + except Exception: + has_default = False + + int_type = duckdb.type('INTEGER') + string_type = duckdb.type('VARCHAR') + + return { + 'env_info': env_info, + 'has_default': has_default, + 'types_work': bool(int_type and string_type), + } + + num_threads = 30 + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(check_module_state, i) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + # All threads should see the same module state + assert len(results) == num_threads + first_result = results[0] + for result in results[1:]: + assert result == first_result, f"Inconsistent module state: {result} != {first_result}" + +def test_memory_pressure(): + """Test memory management under high pressure with many concurrent 
operations.""" + + def memory_intensive_work(thread_id: int) -> None: + connections = [] + + # Create multiple connections + for i in range(5): + conn = duckdb.connect(':memory:') + connections.append(conn) + + # Create some data + conn.execute(f""" + CREATE TABLE data_{i} AS + SELECT range as id, + 'thread_{thread_id}_conn_{i}_row_' || range as value + FROM range(100) + """) + + # Do some queries + result = conn.execute(f"SELECT COUNT(*) FROM data_{i}").fetchone()[0] + assert result == 100 + + # Force some GC pressure + large_data = [] + for _ in range(10): + large_data.append([random.random() for _ in range(1000)]) + + # Clean up connections + for conn in connections: + conn.close() + + # Force garbage collection + del large_data + gc.collect() + + # Run memory-intensive work across many threads + num_threads = 25 + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(memory_intensive_work, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(futures): + future.result() # Will raise if any thread failed + +@pytest.mark.parametrize("num_threads,connections_per_thread", [(10, 25), (15, 50)]) +def test_rapid_connect_disconnect(num_threads, connections_per_thread): + """Test rapid connection creation and destruction to stress module state.""" + + def rapid_connections(_thread_id: int) -> None: + for i in range(connections_per_thread): + conn = duckdb.connect(':memory:') + try: + result = conn.execute("SELECT 1").fetchone()[0] + assert result == 1 + finally: + conn.close() + + # Sometimes force GC to increase pressure + if i % 10 == 0: + gc.collect() + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(rapid_connections, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(futures): + future.result() # Will raise if any thread failed + +def test_exception_handling(): + """Test that exceptions in one thread don't affect module state for others.""" + + def worker_with_exceptions(_thread_id: int) -> None: + conn = duckdb.connect(':memory:') + try: + # Do some successful operations + conn.execute("CREATE TABLE test (x INTEGER)") + conn.execute("INSERT INTO test VALUES (1), (2), (3)") + + # Intentionally cause errors every few operations + for i in range(10): + try: + if i % 3 == 0: + # This should fail + conn.execute("SELECT * FROM nonexistent_table") + else: + # This should succeed + result = conn.execute("SELECT COUNT(*) FROM test").fetchone()[0] + assert result == 3 + except Exception: + # Expected for every 3rd operation when querying nonexistent table + pass + finally: + conn.close() + + num_threads = 20 + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker_with_exceptions, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(futures): + future.result() # Will raise if any thread failed + diff --git a/tests/fast/threading/test_connection_lifecycle_races.py b/tests/fast/threading/test_connection_lifecycle_races.py new file mode 100644 index 00000000..b8036e8b --- /dev/null +++ b/tests/fast/threading/test_connection_lifecycle_races.py @@ -0,0 +1,217 @@ +""" +Test connection lifecycle races. 
+ +Focused on the DuckDBPyConnection constructor and Close() +""" + +import gc +import threading +import concurrent.futures + +import pytest + +import duckdb + + +class ConnectionRaceTester: + def setup_barrier(self, num_threads): + self.barrier = threading.Barrier(num_threads) + + def synchronized_action(self, action_func, description="action"): + """Ensure all threads start at the same time, then run the action.""" + self.barrier.wait() + return action_func() + + + +@pytest.mark.parametrize("num_threads", [15, 20]) +def test_concurrent_connection_creation_destruction(num_threads): + """Test creating and destroying connections concurrently.""" + + tester = ConnectionRaceTester() + tester.setup_barrier(num_threads) + + def create_and_destroy_connection(thread_id): + """Create, use, and destroy a connection.""" + + def action(): + conn = duckdb.connect() + try: + conn.execute("SELECT 1").fetchone() + finally: + conn.close() + + return True + + return tester.synchronized_action(action, f"Create/Destroy {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(create_and_destroy_connection, i) + for i in range(num_threads) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) + + + +def test_connection_destructor_race(): + num_threads = 15 + + tester = ConnectionRaceTester() + tester.setup_barrier(num_threads) + + def destroy_connection(thread_id): + """Destroy a connection (testing destructor race).""" + + def action(): + conn = duckdb.connect() + + conn.execute("SELECT COUNT(*) FROM range(1)").fetchone() + + del conn + gc.collect() + + return True + + return tester.synchronized_action(action, f"Destructor {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(destroy_connection, i) for i in range(num_threads)] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) + +def test_concurrent_close_operations(): + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE shared_table (id INTEGER, data VARCHAR)") + conn.execute("INSERT INTO shared_table VALUES (1, 'test')") + + num_threads = 10 + tester = ConnectionRaceTester() + tester.setup_barrier(num_threads) + + def attempt_close_connection(thread_id): + cursor = conn.cursor() + def action(): + try: + _result = cursor.execute( + "SELECT COUNT(*) FROM shared_table" + ).fetchone() + + # Close this thread's cursor; each cursor closes independently + cursor.close() + + return f"close_succeeded_{thread_id}" + except Exception as e: + return f"close_failed_{thread_id}_{str(e)}" + + return tester.synchronized_action(action, f"Close attempt {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(attempt_close_connection, i) for i in range(num_threads) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert any(results), "No close attempts succeeded" + + +def test_connection_state_races(): + """Test race conditions in connection state management.""" + + num_threads = 12 + + def connection_state_operations(thread_id): + conn = duckdb.connect(":memory:") + operations = [ + lambda: conn.execute("SELECT 1").fetchone(), + lambda: conn.begin(), + lambda: conn.execute("CREATE TABLE test (x INTEGER)"), + lambda: conn.execute("INSERT INTO test VALUES (1)"), + lambda: conn.commit(), + lambda:
conn.execute("SELECT * FROM test").fetchall(), + ] + + results = [] + for i, op in enumerate(operations): + try: + _result = op() + results.append(f"op_{i}_success") + except Exception as e: + results.append(f"op_{i}_failed_{type(e).__name__}") + + return True + + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(connection_state_operations, i) for i in range(num_threads) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) + + +def test_cursor_operations_race(): + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE cursor_test (id INTEGER, name VARCHAR)") + conn.execute("INSERT INTO cursor_test SELECT i, 'name_' || i FROM range(100) t(i)") + + num_threads = 8 + + def cursor_operations(thread_id): + """Perform cursor operations concurrently.""" + # Get a cursor + cursor = conn.cursor() + cursor.execute( + f"SELECT * FROM cursor_test WHERE id % {num_threads} = {thread_id}" + ) + results = cursor.fetchall() + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(cursor_operations, i) for i in range(num_threads)] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) + finally: + conn.close() + + +@pytest.mark.parametrize("num_cycles,num_threads", [(25, 4), (50, 6)]) +def test_rapid_connection_cycling(num_cycles, num_threads): + """Test rapid connection creation and destruction cycles.""" + + def rapid_cycling(thread_id): + for cycle in range(num_cycles): + conn = duckdb.connect(":memory:") + try: + conn.execute(f"SELECT {thread_id} + {cycle}").fetchone() + finally: + conn.close() + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(rapid_cycling, i) for i in range(num_threads)] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) + \ No newline at end of file diff --git a/tests/fast/threading/test_module_lifecycle.py b/tests/fast/threading/test_module_lifecycle.py new file mode 100644 index 00000000..e485b8be --- /dev/null +++ b/tests/fast/threading/test_module_lifecycle.py @@ -0,0 +1,205 @@ +""" +Test module lifecycle + +Reloading and unload are not expected nor required behaviors - +these tests are to document current behavior so that changes +are visible. 
+""" +import sys +import importlib +import pytest +from concurrent.futures import ThreadPoolExecutor + + +def test_module_reload_safety(): + """Test module reloading scenarios to detect use-after-free issues.""" + import duckdb + + with duckdb.connect(':memory:') as conn1: + conn1.execute("CREATE TABLE test (id INTEGER)") + conn1.execute("INSERT INTO test VALUES (1)") + result1 = conn1.execute("SELECT * FROM test").fetchone()[0] + assert result1 == 1 + + initial_module_id = id(sys.modules['duckdb']) + + # Test importlib.reload() - this does NOT create new module instance in Python + importlib.reload(duckdb) + + # Verify module instance is the same (expected Python behavior) + reload_module_id = id(sys.modules['duckdb']) + assert initial_module_id == reload_module_id, "importlib.reload() should reuse same module instance" + + # Test if old connection still works after importlib.reload() + try: + result2 = conn1.execute("SELECT * FROM test").fetchone()[0] + assert result2 == 1 + except Exception as e: + pytest.fail(f"Old connection failed after importlib.reload(): {e}") + + # Test new connection after importlib.reload() + with duckdb.connect(':memory:') as conn2: + conn2.execute("CREATE TABLE test2 (id INTEGER)") + conn2.execute("INSERT INTO test2 VALUES (2)") + result3 = conn2.execute("SELECT * FROM test2").fetchone()[0] + assert result3 == 2 + + +def test_dynamic_module_loading(): + """Test module loading/unloading cycles.""" + import duckdb + + with duckdb.connect(':memory:') as conn: + conn.execute("SELECT 1").fetchone() + + module_id_1 = id(sys.modules['duckdb']) + + # "Unload" module (not really, just to try it) + if 'duckdb' in sys.modules: + del sys.modules['duckdb'] + + # Remove from local namespace + if 'duckdb' in locals(): + del duckdb + + # Verify module is unloaded + assert 'duckdb' not in sys.modules, "Module not properly unloaded" + + # import (load) module + import duckdb + module_id_2 = id(sys.modules['duckdb']) + + # Verify we have a new module instance + assert module_id_1 != module_id_2, "Module not actually reloaded" + + # Test functionality after reload + with duckdb.connect(':memory:') as conn: + conn.execute("CREATE TABLE test (id INTEGER)") + conn.execute("INSERT INTO test VALUES (42)") + result = conn.execute("SELECT * FROM test").fetchone()[0] + assert result == 42 + + +def test_complete_module_unload_with_live_connections(): + """Test the dangerous scenario: complete module unload with live connections.""" + + import duckdb + conn1 = duckdb.connect(':memory:') + conn1.execute("CREATE TABLE danger_test (id INTEGER)") + conn1.execute("INSERT INTO danger_test VALUES (123)") + + module_id_1 = id(sys.modules['duckdb']) + + if 'duckdb' in sys.modules: + del sys.modules['duckdb'] + del duckdb + + # TODO: Rethink this behavior - the module is unloaded, but we + # didn't invalidate all the connections and state... so even after + # unload, conn1 works. 
+ + result = conn1.execute("SELECT * FROM danger_test").fetchone()[0] + assert result == 123 + + # Reimport creates new module state, but static cache should be reset + import duckdb + module_id_2 = id(sys.modules['duckdb']) + assert module_id_1 != module_id_2, "Should have different module instances" + + conn2 = duckdb.connect(':memory:') + conn2.execute("CREATE TABLE safe_test (id INTEGER)") + conn2.execute("INSERT INTO safe_test VALUES (456)") + result2 = conn2.execute("SELECT * FROM safe_test").fetchone()[0] + assert result2 == 456 + + conn2.close() + try: + conn1.close() + except Exception: + pass + + +def test_concurrent_module_access(): + + import duckdb + + def worker(thread_id): + with duckdb.connect(':memory:') as conn: + conn.execute(f"CREATE TABLE test_{thread_id} (id INTEGER)") + conn.execute(f"INSERT INTO test_{thread_id} VALUES ({thread_id})") + result = conn.execute(f"SELECT * FROM test_{thread_id}").fetchone()[0] + assert result == thread_id + return True + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + results = [f.result() for f in futures] + + assert all(results) + + +def test_import_cache_consistency(): + """Test that import cache remains consistent across module operations.""" + + import duckdb + import pandas as pd + + conn = duckdb.connect(':memory:') + + df = pd.DataFrame({'a': [1, 2, 3]}) + + conn.register('test_df', df) + result = conn.execute("SELECT COUNT(*) FROM test_df").fetchone()[0] + assert result == 3 + + conn.close() + + +def test_module_state_memory_safety(): + """Test memory safety of module state access patterns.""" + + import duckdb + + connections = [] + for i in range(10): + conn = duckdb.connect(':memory:') + conn.execute(f"CREATE TABLE test_{i} (id INTEGER)") + conn.execute(f"INSERT INTO test_{i} VALUES ({i})") + connections.append(conn) + + import gc + gc.collect() + + for i, conn in enumerate(connections): + try: + result = conn.execute(f"SELECT * FROM test_{i}").fetchone()[0] + assert result == i + except Exception as e: + pytest.fail(f"Connection {i} failed after GC: {e}") + + for conn in connections: + conn.close() + + +def test_static_cache_stress(): + """Stress test static cache with rapid module state access.""" + + import duckdb + + def rapid_access_worker(iterations): + """Rapidly access module state.""" + results = [] + for i in range(iterations): + conn = duckdb.connect(':memory:') + conn.execute("SELECT 1").fetchone() + conn.close() + results.append(True) + + assert len(results) == iterations + return True + + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(rapid_access_worker, 50) for _ in range(3)] + all_results = [f.result() for f in futures] + + assert all(all_results) \ No newline at end of file diff --git a/tests/fast/threading/test_module_state.py b/tests/fast/threading/test_module_state.py new file mode 100644 index 00000000..7971c02f --- /dev/null +++ b/tests/fast/threading/test_module_state.py @@ -0,0 +1,197 @@ +import concurrent.futures +import os +import tempfile +import threading +import time + +import pytest + +import duckdb + + +def test_module_state_isolation(): + with duckdb.connect(":memory:") as conn1, duckdb.connect(":memory:") as conn2: + conn1.execute("CREATE TABLE test1 (x INTEGER)") + conn1.execute("INSERT INTO test1 VALUES (1)") + + conn2.execute("CREATE TABLE test2 (x INTEGER)") + conn2.execute("INSERT INTO test2 VALUES (2)") + + result1 = conn1.execute("SELECT * FROM test1").fetchall() + result2 = conn2.execute("SELECT * FROM 
test2").fetchall() + + assert result1 == [(1,)], "Connection 1 isolation failed" + assert result2 == [(2,)], "Connection 2 isolation failed" + + +def test_default_connection_access(): + with duckdb.connect() as conn1: + conn1.execute("CREATE TABLE test1 (x INTEGER)") + conn1.execute("INSERT INTO test1 VALUES (42)") + + # Verify data exists in this connection + result1 = conn1.execute("SELECT * FROM test1").fetchall() + assert result1 == [(42,)], "Connection 1 data missing" + + # New default connection should be isolated (table won't exist) + with duckdb.connect() as conn2: + # This should fail because tables are not shared between connections + try: + conn2.execute("SELECT * FROM test1").fetchall() + assert False, "Table should not exist in new connection" + except duckdb.CatalogException: + pass # Expected behavior - tables are isolated between connections + + +def test_import_cache_access(): + with duckdb.connect(":memory:") as conn: + try: + conn.execute("CREATE TABLE test AS SELECT range as x FROM range(10)") + df = conn.fetchdf() + assert len(df) == 10, "Pandas integration failed" + except Exception: + pass + + try: + result = conn.execute("SELECT range as x FROM range(5)").fetchnumpy() + assert "x" in result, "Numpy integration failed" + except Exception: + pass + + +def test_instance_cache_functionality(): + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + + with duckdb.connect(db_path) as conn1: + conn1.execute("CREATE TABLE test (x INTEGER)") + conn1.execute("INSERT INTO test VALUES (1)") + + with duckdb.connect(db_path) as conn2: + result = conn2.execute("SELECT * FROM test").fetchall() + assert result == [(1,)], "Instance cache failed" + + +def test_environment_detection(): + version = duckdb.__formatted_python_version__ + assert isinstance(version, str) + assert len(version) > 0 + + interactive = duckdb.__interactive__ + assert isinstance(interactive, bool) + + +@pytest.mark.parametrize("num_threads", [15, 20, 25]) +def test_concurrent_connection_creation(num_threads): + barrier = threading.Barrier(num_threads) + + def worker(thread_id): + barrier.wait() + + for i in range(5): + with duckdb.connect(":memory:") as conn: + conn.execute(f"CREATE TABLE test_{i} (x INTEGER)") + conn.execute(f"INSERT INTO test_{i} VALUES ({thread_id})") + result = conn.execute(f"SELECT * FROM test_{i}").fetchall() + assert result == [(thread_id,)], f"Thread {thread_id}, table {i} failed" + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + results = [future.result() for future in futures] + + assert all(results) + +@pytest.mark.parametrize("num_threads", [10, 15, 20]) +def test_concurrent_instance_cache_access(tmp_path, num_threads): + barrier = threading.Barrier(num_threads) + + def worker(thread_id): + barrier.wait() + + for i in range(10): + db_path = str(tmp_path / f"test_{thread_id}_{i}.db") + with duckdb.connect(db_path) as conn: + conn.execute( + "CREATE TABLE IF NOT EXISTS test (x INTEGER, thread_id INTEGER)" + ) + conn.execute(f"INSERT INTO test VALUES ({i}, {thread_id})") + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + results = [future.result() for future in futures] + + assert all(results) + + +@pytest.mark.parametrize("num_threads", [10, 15, 20]) +def test_concurrent_import_cache_access(num_threads): + 
barrier = threading.Barrier(num_threads) + + def worker(thread_id): + barrier.wait() + + for _i in range(20): + with duckdb.connect(":memory:") as conn: + try: + conn.execute("CREATE TABLE test AS SELECT range as x FROM range(5)") + df = conn.fetchdf() + assert len(df) == 5, ( + f"Thread {thread_id}: pandas integration failed" + ) + except Exception: + pass + + try: + result = conn.execute( + "SELECT range as x FROM range(3)" + ).fetchnumpy() + assert "x" in result, ( + f"Thread {thread_id}: numpy integration failed" + ) + except Exception: + pass + + time.sleep(0.0001) + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + results = [future.result() for future in futures] + + assert all(results), "Some threads failed" + + +@pytest.mark.parametrize("num_threads", [10, 15, 20]) +def test_concurrent_environment_detection(num_threads): + """Test concurrent access to environment detection.""" + barrier = threading.Barrier(num_threads) + + def worker(thread_id): + barrier.wait() + + for _i in range(30): + version = duckdb.__formatted_python_version__ + interactive = duckdb.__interactive__ + + assert isinstance(version, str), ( + f"Thread {thread_id}: version should be string" + ) + assert isinstance(interactive, bool), ( + f"Thread {thread_id}: interactive should be boolean" + ) + + with duckdb.connect(":memory:") as conn: + conn.execute("SELECT 1") + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + results = [future.result() for future in futures] + + assert all(results) diff --git a/tests/fast/threading/test_query_execution_races.py b/tests/fast/threading/test_query_execution_races.py new file mode 100644 index 00000000..afcf9cf5 --- /dev/null +++ b/tests/fast/threading/test_query_execution_races.py @@ -0,0 +1,190 @@ +""" +Test concurrent query execution races. + +This tests race conditions in query execution paths where GIL is released +during query processing, as identified in pyconnection.cpp. 
+""" + +import random +import threading +import time +import concurrent.futures + +import pytest + +import duckdb + + +class QueryRaceTester: + """Helper class to coordinate query execution race condition tests.""" + + def setup_barrier(self, num_threads): + self.barrier = threading.Barrier(num_threads) + + def synchronized_execute(self, db, query, description="query"): + """Wait for all threads to be ready, then execute query.""" + try: + with db.cursor() as conn: + self.barrier.wait() # Synchronize thread starts for maximum contention + result = conn.execute(query).fetchall() + return {"success": True, "result": result, "description": description} + except Exception as e: + return {"success": False, "error": str(e), "description": description} + + +@pytest.mark.parametrize("num_threads", [8, 12]) +def test_concurrent_prepare_execute(num_threads): + """Test concurrent PrepareQuery and ExecuteInternal paths.""" + + conn = duckdb.connect(':memory:') + try: + conn.execute("CREATE TABLE test_data (id INTEGER, value VARCHAR)") + conn.execute("INSERT INTO test_data SELECT i, 'value_' || i FROM range(1000) t(i)") + + tester = QueryRaceTester() + tester.setup_barrier(num_threads) + + def prepare_and_execute(thread_id): + queries = [ + f"SELECT COUNT(*) FROM test_data WHERE id > {thread_id * 10}", + f"SELECT value FROM test_data WHERE id = {thread_id + 1}", + f"SELECT id, value FROM test_data WHERE id BETWEEN {thread_id} AND {thread_id + 10}", + f"INSERT INTO test_data VALUES ({1000 + thread_id}, 'thread_{thread_id}')", + f"UPDATE test_data SET value = 'updated_{thread_id}' WHERE id = {thread_id + 500}" + ] + + query = queries[thread_id % len(queries)] + return tester.synchronized_execute(conn, query, f"Prepared query {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(prepare_and_execute, i) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + successful = [r for r in results if r["success"]] + assert len(successful) >= num_threads * 0.9, f"Only {len(successful)}/{num_threads} operations succeeded" + finally: + conn.close() + + +def test_concurrent_pending_query_execution(): + + conn = duckdb.connect(':memory:') + try: + conn.execute("CREATE TABLE large_data AS SELECT i, i*2 as double_val, 'row_' || i as str_val FROM range(10000) t(i)") + + num_threads = 8 + tester = QueryRaceTester() + tester.setup_barrier(num_threads) + + def execute_long_query(thread_id): + queries = [ + "SELECT COUNT(*), AVG(double_val) FROM large_data", + "SELECT str_val, double_val FROM large_data WHERE i % 100 = 0 ORDER BY double_val", + f"SELECT * FROM large_data WHERE i BETWEEN {thread_id * 1000} AND {(thread_id + 1) * 1000}", + "SELECT i, double_val, str_val FROM large_data WHERE double_val > 5000 ORDER BY i DESC", + f"SELECT COUNT(*) as cnt FROM large_data WHERE str_val LIKE '%{thread_id}%'" + ] + + query = queries[thread_id % len(queries)] + return tester.synchronized_execute(conn, query, f"Long query {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(execute_long_query, i) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + successful = [r for r in results if r["success"]] + assert len(successful) == num_threads, f"Only {len(successful)}/{num_threads} long queries succeeded" + finally: + conn.close() + + +def test_execute_many_race(): + + 
with duckdb.connect() as conn: + conn.execute("CREATE TABLE batch_data (id INTEGER, name VARCHAR)") + + num_threads = 10 + iterations = 10 + tester = QueryRaceTester() + tester.setup_barrier(num_threads) + + def execute_many_batch(thread_id): + with conn.cursor() as conn2: + batch_data = [(thread_id * 100 + i, f'name_{thread_id}_{i}') for i in range(iterations)] + tester.barrier.wait() + conn2.executemany("INSERT INTO batch_data VALUES (?, ?)", batch_data) + result = conn2.execute(f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'").fetchone() + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(execute_many_batch, i) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + total_rows = conn.execute("SELECT COUNT(*) FROM batch_data").fetchone()[0] + assert total_rows == num_threads * iterations + assert all(results) + + + +def test_query_interruption_race(): + + conn = duckdb.connect(':memory:') + try: + conn.execute("CREATE TABLE interrupt_test AS SELECT i FROM range(100000) t(i)") + + num_threads = 6 + + def run_interruptible_query(thread_id): + + with conn.cursor() as conn2: + if thread_id % 2 == 0: + # Fast query + result = conn2.execute("SELECT COUNT(*) FROM interrupt_test").fetchall() + return True + else: + # Potentially slower query + result = conn2.execute("SELECT i, i*i FROM interrupt_test WHERE i % 1000 = 0 ORDER BY i").fetchall() + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(run_interruptible_query, i) for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures, timeout=30)] + + assert all(results) + finally: + conn.close() + + +@pytest.mark.parametrize("num_threads", [10, 15, 40]) +def test_mixed_query_operations(num_threads): + + def mixed_query_operations(thread_id, db): + + queries = [ + f"SELECT COUNT(*) FROM mixed_ops WHERE id > {thread_id * 50}", + f"INSERT INTO mixed_ops VALUES ({10000 + thread_id}, 'thread_{thread_id}', {thread_id * 2.5})", + f"UPDATE mixed_ops SET data = 'updated_{thread_id}' WHERE id = {thread_id + 100}", + "SELECT AVG(num_val), MAX(id) FROM mixed_ops WHERE data LIKE 'initial_%'", + """ + SELECT m1.id, m1.data, m2.num_val + FROM mixed_ops m1 + JOIN mixed_ops m2 ON m1.id = m2.id - 1 + LIMIT 10 + """ + ] + + with duckdb.connect(db) as conn2: + conn2.execute("CREATE TABLE mixed_ops (id INTEGER PRIMARY KEY, data VARCHAR, num_val DOUBLE)") + conn2.execute("INSERT INTO mixed_ops SELECT i, 'initial_' || i, i * 1.5 FROM range(1000) t(i)") + + query = queries[thread_id % len(queries)] + + conn2.execute(query) + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(mixed_query_operations, i, f":memory:{i}") for i in range(num_threads)] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + assert all(results) \ No newline at end of file diff --git a/tests/fast/threading/test_race_conditions.py b/tests/fast/threading/test_race_conditions.py new file mode 100644 index 00000000..3784e197 --- /dev/null +++ b/tests/fast/threading/test_race_conditions.py @@ -0,0 +1,221 @@ +import concurrent.futures +import gc +import random +import threading +import time +import weakref + +import pytest + +import duckdb + + + +def test_module_state_race(num_threads_testing): + barrier = 
threading.Barrier(num_threads_testing) + + def worker(thread_id): + barrier.wait() + + for _i in range(30): + with duckdb.connect(":memory:") as conn: + conn.execute("SELECT 1") + int_type = duckdb.type("INTEGER") + assert int_type is not None, f"Thread {thread_id}: type creation failed" + + if _i % 10 == 0: + time.sleep(0.0001) + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads_testing) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads_testing)] + results = [future.result() for future in futures] + + assert all(results) + + +def test_connection_instance_cache_race(tmp_path, num_threads_testing): + num_threads = num_threads_testing + barrier = threading.Barrier(num_threads) + + def worker(thread_id): + barrier.wait() + + for i in range(10): + db_path = tmp_path / f"race_test_t{thread_id}_i{i}.db" + with duckdb.connect(str(db_path)) as conn: + conn.execute( + f"CREATE TABLE IF NOT EXISTS thread_{thread_id}_data_{i} (x INTEGER)" + ) + conn.execute( + f"INSERT INTO thread_{thread_id}_data_{i} VALUES ({thread_id}), ({i})" + ) + + time.sleep(random.uniform(0.0001, 0.001)) + + result = conn.execute( + f"SELECT COUNT(*) FROM thread_{thread_id}_data_{i}" + ).fetchone()[0] + assert result == 2, ( + f"Thread {thread_id}, iteration {i}: expected 2 rows, got {result}" + ) + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + results = [future.result() for future in futures] + + assert all(results) + + +@pytest.mark.parametrize("num_threads", [15, 20, 25]) +def test_cleanup_race(num_threads): + + def worker(thread_id): + weak_refs = [] + + for i in range(50): + conn = duckdb.connect(":memory:") + weak_refs.append(weakref.ref(conn)) + try: + conn.execute("CREATE TABLE test (x INTEGER)") + conn.execute("INSERT INTO test VALUES (1), (2), (3)") + finally: + conn.close() + conn = None + + if i % 3 == 0: + with duckdb.connect(":memory:") as new_conn: + new_conn.execute("SELECT 1") + + if i % 10 == 0: + gc.collect() + time.sleep(random.uniform(0.0001, 0.0005)) + + gc.collect() + time.sleep(0.1) + gc.collect() + + alive_refs = [ref for ref in weak_refs if ref() is not None] + if len(alive_refs) > 10: + assert False, ( + f"Thread {thread_id}: {len(alive_refs)} connections still alive (expected < 10)" + ) + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + results = [future.result() for future in futures] + + assert all(results), "Some threads failed" + + +@pytest.mark.parametrize("num_threads", [20, 25, 30]) +def test_default_connection_race(num_threads): + barrier = threading.Barrier(num_threads) + + def worker(thread_id): + barrier.wait() + + for _i in range(30): + with duckdb.connect() as conn1: + r1 = conn1.execute("SELECT 1").fetchone()[0] + assert r1 == 1, f"Thread {thread_id}: expected 1, got {r1}" + + with duckdb.connect(":memory:") as conn2: + r2 = conn2.execute("SELECT 2").fetchone()[0] + assert r2 == 2, f"Thread {thread_id}: expected 2, got {r2}" + + with duckdb.connect("") as conn3: + r3 = conn3.execute("SELECT 3").fetchone()[0] + assert r3 == 3, f"Thread {thread_id}: expected 3, got {r3}" + + time.sleep(0.0001) + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + results = 
+
+
+@pytest.mark.parametrize("num_threads", [20, 25, 30])
+def test_default_connection_race(num_threads):
+    barrier = threading.Barrier(num_threads)
+
+    def worker(thread_id):
+        barrier.wait()
+
+        for _i in range(30):
+            with duckdb.connect() as conn1:
+                r1 = conn1.execute("SELECT 1").fetchone()[0]
+                assert r1 == 1, f"Thread {thread_id}: expected 1, got {r1}"
+
+            with duckdb.connect(":memory:") as conn2:
+                r2 = conn2.execute("SELECT 2").fetchone()[0]
+                assert r2 == 2, f"Thread {thread_id}: expected 2, got {r2}"
+
+            with duckdb.connect("") as conn3:
+                r3 = conn3.execute("SELECT 3").fetchone()[0]
+                assert r3 == 3, f"Thread {thread_id}: expected 3, got {r3}"
+
+            time.sleep(0.0001)
+
+        return True
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(worker, i) for i in range(num_threads)]
+        results = [future.result() for future in futures]
+
+    assert all(results)
+
+
+@pytest.mark.parametrize("num_threads", [15, 20, 25])
+def test_type_system_race(num_threads):
+    barrier = threading.Barrier(num_threads)
+
+    def worker(thread_id):
+        barrier.wait()
+
+        for i in range(100):
+            types = [
+                duckdb.type("INTEGER"),
+                duckdb.type("VARCHAR"),
+                duckdb.type("DOUBLE"),
+                duckdb.type("BOOLEAN"),
+                duckdb.list_type(duckdb.type("INTEGER")),
+                duckdb.struct_type(
+                    {"a": duckdb.type("INTEGER"), "b": duckdb.type("VARCHAR")}
+                ),
+            ]
+
+            for t in types:
+                assert t is not None, f"Thread {thread_id}: type creation failed"
+
+            if i % 5 == 0:
+                with duckdb.connect(":memory:") as conn:
+                    conn.execute(
+                        "CREATE TABLE test (a INTEGER, b VARCHAR, c DOUBLE, d BOOLEAN)"
+                    )
+
+        return True
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(worker, i) for i in range(num_threads)]
+        results = [future.result() for future in futures]
+
+    assert all(results)
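+
+
+# What the composite type constructors above build, shown in isolation (an
+# illustrative sketch; the printed forms assume DuckDB's usual SQL rendering
+# of nested types):
+#
+#     elem = duckdb.type("INTEGER")
+#     print(duckdb.list_type(elem))           # INTEGER[]
+#     print(duckdb.struct_type({"a": elem}))  # STRUCT(a INTEGER)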
+""" + +import concurrent.futures +import os +import tempfile +import threading +import time + +import pytest + +import duckdb + + +def get_optimal_thread_count(): + """Calculate thread count based on number of cores""" + import multiprocessing + + cpu_count = multiprocessing.cpu_count() + return min(12, max(4, int(cpu_count * 1.5))) + + +@pytest.fixture +def temp_db_files(): + """Provide temporary database files and clean them up.""" + temp_files = [] + with tempfile.TemporaryDirectory() as tmpdir: + for i in range(3): + temp_files.append(os.path.join(tmpdir, f"race_test_{i}.db")) + yield temp_files + + +def test_instance_cache_race(temp_db_files): + """Test the specific race condition in instance cache initialization.""" + # This test tries to trigger the race condition where multiple threads + # see state->instance_cache as null and try to create it simultaneously + + num_threads = get_optimal_thread_count() + barrier = threading.Barrier(num_threads) + results = [] + lock = threading.Lock() + + def trigger_instance_cache_race(thread_id): + try: + # Wait for all threads to be ready + barrier.wait() + + # All threads try to create file-based connections simultaneously + # This should trigger the instance cache initialization race + connections = [] + for i in range(5): # Reduced from 10 + # Use unique database file per thread and iteration to avoid conflicts + db_file = temp_db_files[0] + f"_t{thread_id}_i{i}.db" + conn = duckdb.connect(db_file) + connections.append(conn) + + # Do some work to keep the connection alive + conn.execute("CREATE TABLE IF NOT EXISTS test (x INTEGER, y INTEGER)") + conn.execute(f"INSERT INTO test VALUES ({thread_id}, {i})") + + # Close connections + for conn in connections: + conn.close() + + with lock: + results.append((thread_id, "success")) + + except Exception as e: + with lock: + results.append((thread_id, f"error: {e}")) + + # Start all threads + threads = [] + for i in range(num_threads): + t = threading.Thread(target=trigger_instance_cache_race, args=(i,)) + threads.append(t) + t.start() + + # Wait for completion + for t in threads: + t.join() + + # Analyze results + errors = [r for r in results if not r[1] == "success"] + assert not errors, f"Errors detected in instance cache race test: {errors}" + assert len(results) == num_threads, ( + f"Expected {num_threads} results, got {len(results)}" + ) + + +def test_import_cache_reset_race(): + """Test race condition when import cache is reset while in use.""" + + def worker_thread(thread_id): + try: + for i in range(20): # Reduced from 50 + conn = duckdb.connect(":memory:") + + # These operations might use the import cache + try: + # Try pandas operations (if available) + conn.execute( + "CREATE TABLE test AS SELECT range as x FROM range(10)" + ) + df = conn.fetchdf() # Might use import cache for pandas + + # Try numpy operations (if available) + result = conn.execute("SELECT * FROM test").fetchnumpy() + + except Exception: + # pandas/numpy might not be available, that's fine + pass + + conn.close() + + # Add tiny delay to increase race chance + time.sleep(0.0001) + + return (thread_id, "success") + + except Exception as e: + return (thread_id, f"error: {e}") + + # Run many threads that use import cache + num_threads = get_optimal_thread_count() + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker_thread, i) for i in range(num_threads)] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + errors = [r 
+
+
+def test_import_cache_reset_race():
+    """Test race condition when the import cache is reset while in use."""
+
+    def worker_thread(thread_id):
+        try:
+            for i in range(20):  # Reduced from 50
+                conn = duckdb.connect(":memory:")
+
+                # These operations might use the import cache
+                try:
+                    # Try pandas operations (if available)
+                    conn.execute(
+                        "CREATE TABLE test AS SELECT range as x FROM range(10)"
+                    )
+                    df = conn.fetchdf()  # Might use the import cache for pandas
+
+                    # Try numpy operations (if available)
+                    result = conn.execute("SELECT * FROM test").fetchnumpy()
+
+                except Exception:
+                    # pandas/numpy might not be available, that's fine
+                    pass
+
+                conn.close()
+
+                # Add a tiny delay to increase the chance of a race
+                time.sleep(0.0001)
+
+            return (thread_id, "success")
+
+        except Exception as e:
+            return (thread_id, f"error: {e}")
+
+    # Run many threads that use the import cache
+    num_threads = get_optimal_thread_count()
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(worker_thread, i) for i in range(num_threads)]
+        results = [
+            future.result() for future in concurrent.futures.as_completed(futures)
+        ]
+
+    errors = [r for r in results if r[1] != "success"]
+    assert not errors, f"Errors detected in import cache race test: {errors}"
+    assert len(results) == num_threads, (
+        f"Expected {num_threads} results, got {len(results)}"
+    )
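+
+
+# Rough shape of the lazy import cache these tests target (an illustrative
+# sketch of the idea, not DuckDB's actual implementation); the race window is
+# two threads reaching the first-use path at the same time:
+#
+#     _modules = {}
+#
+#     def cached_import(name):
+#         if name not in _modules:           # two threads can both get here
+#             _modules[name] = __import__(name)
+#         return _modules[name]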
+""" + +import concurrent.futures +import threading + +import pytest + +import duckdb + + +class UDFRaceTester: + def setup_barrier(self, num_threads): + self.barrier = threading.Barrier(num_threads) + + def wait_and_execute(self, db, query, description="query"): + with db.cursor() as conn: + self.barrier.wait() # Synchronize thread starts for maximum contention + result = conn.execute(query).fetchall() + return True + + + +@pytest.mark.parametrize("num_threads", [8, 10, 12]) +def test_concurrent_udf_registration(num_threads): + """Test concurrent registration of UDFs.""" + tester = UDFRaceTester() + tester.setup_barrier(num_threads) + + def register_udf(thread_id): + with duckdb.connect(":memory:") as conn: + + def my_add(x: int, y: int) -> int: + return x + y + + udf_name = f"my_add_{thread_id}" + conn.create_function(udf_name, my_add) + + tester.wait_and_execute( + conn, f"SELECT {udf_name}(1, 2)", f"UDF test {thread_id}" + ) + + return True + + + # Run concurrent UDF registrations + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(register_udf, i) for i in range(num_threads)] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results), "Some UDF registrations failed" + + +@pytest.mark.parametrize("num_threads", [10, 15, 20]) +def test_concurrent_udf_execution(num_threads): + """Test concurrent execution of the same UDF.""" + conn = duckdb.connect(":memory:") + + def slow_multiply(x: int, y: int) -> int: + result = 1 + for _i in range(10): + result = result * 1.0 + (x * y * 0.1) + return int(result) + + conn.create_function("slow_multiply", slow_multiply) + + tester = UDFRaceTester() + tester.setup_barrier(num_threads) + + def execute_udf(thread_id): + query = f"SELECT slow_multiply({thread_id}, 2) as result" + return tester.wait_and_execute(conn, query, f"UDF execution {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(execute_udf, i) for i in range(num_threads)] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + conn.close() + + assert all(results) + + +@pytest.mark.parametrize("num_threads", [8, 12, 16]) +def test_mixed_udf_operations(num_threads): + """Test mixing UDF registration, execution, and unregistration concurrently.""" + tester = UDFRaceTester() + tester.setup_barrier(num_threads) + + def mixed_operations(thread_id): + conn = duckdb.connect(":memory:") + + if thread_id % 3 == 0: + # Register and use UDF + def thread_func(x: int) -> int: + return x * thread_id + + udf_name = f"thread_func_{thread_id}" + conn.create_function(udf_name, thread_func) + result = tester.wait_and_execute( + conn, f"SELECT {udf_name}(5)", f"Register+Execute {thread_id}" + ) + elif thread_id % 3 == 1: + # Use a common UDF that might be registered by other threads + result = tester.wait_and_execute( + conn, "SELECT 42", f"Simple query {thread_id}" + ) + else: + # Create table and use built-in functions + conn.execute("CREATE TABLE test_table (x INTEGER)") + conn.execute("INSERT INTO test_table VALUES (1), (2), (3)") + result = tester.wait_and_execute( + conn, "SELECT COUNT(*) FROM test_table", f"Table ops {thread_id}" + ) + + conn.close() + return result + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(mixed_operations, i) for i in range(num_threads)] + results = [ + future.result() for future in 
+
+
+@pytest.mark.parametrize("num_threads", [10, 15, 20])
+def test_concurrent_udf_execution(num_threads):
+    """Test concurrent execution of the same UDF."""
+    conn = duckdb.connect(":memory:")
+
+    def slow_multiply(x: int, y: int) -> int:
+        result = 1
+        for _i in range(10):
+            result = result * 1.0 + (x * y * 0.1)
+        return int(result)
+
+    conn.create_function("slow_multiply", slow_multiply)
+
+    tester = UDFRaceTester()
+    tester.setup_barrier(num_threads)
+
+    def execute_udf(thread_id):
+        query = f"SELECT slow_multiply({thread_id}, 2) as result"
+        return tester.wait_and_execute(conn, query, f"UDF execution {thread_id}")
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(execute_udf, i) for i in range(num_threads)]
+        results = [
+            future.result() for future in concurrent.futures.as_completed(futures)
+        ]
+
+    conn.close()
+
+    assert all(results)
+
+
+@pytest.mark.parametrize("num_threads", [8, 12, 16])
+def test_mixed_udf_operations(num_threads):
+    """Test mixing UDF registration, UDF execution, and table operations concurrently."""
+    tester = UDFRaceTester()
+    tester.setup_barrier(num_threads)
+
+    def mixed_operations(thread_id):
+        conn = duckdb.connect(":memory:")
+
+        if thread_id % 3 == 0:
+            # Register and use a UDF
+            def thread_func(x: int) -> int:
+                return x * thread_id
+
+            udf_name = f"thread_func_{thread_id}"
+            conn.create_function(udf_name, thread_func)
+            result = tester.wait_and_execute(
+                conn, f"SELECT {udf_name}(5)", f"Register+Execute {thread_id}"
+            )
+        elif thread_id % 3 == 1:
+            # Run a plain query alongside the UDF threads
+            result = tester.wait_and_execute(
+                conn, "SELECT 42", f"Simple query {thread_id}"
+            )
+        else:
+            # Create a table and use built-in functions
+            conn.execute("CREATE TABLE test_table (x INTEGER)")
+            conn.execute("INSERT INTO test_table VALUES (1), (2), (3)")
+            result = tester.wait_and_execute(
+                conn, "SELECT COUNT(*) FROM test_table", f"Table ops {thread_id}"
+            )
+
+        conn.close()
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(mixed_operations, i) for i in range(num_threads)]
+        results = [
+            future.result() for future in concurrent.futures.as_completed(futures)
+        ]
+
+    assert all(results)
+
+
+@pytest.mark.parametrize("num_threads", [6, 8, 10])
+def test_scalar_udf_races(num_threads):
+    """Test concurrent execution of scalar UDFs."""
+    conn = duckdb.connect(":memory:")
+
+    # Create test data
+    conn.execute("CREATE TABLE numbers (x INTEGER)")
+    conn.execute("INSERT INTO numbers SELECT * FROM range(100)")
+
+    # Create a simple scalar UDF instead of a vectorized one (simpler for testing)
+    def simple_square(x: int) -> int:
+        """Square a single value."""
+        return x * x
+
+    conn.create_function("simple_square", simple_square)
+
+    tester = UDFRaceTester()
+    tester.setup_barrier(num_threads)
+
+    def execute_scalar_udf(thread_id):
+        start = thread_id * 10
+        end = start + 10
+        query = (
+            f"SELECT simple_square(x) FROM numbers WHERE x BETWEEN {start} AND {end}"
+        )
+        return tester.wait_and_execute(conn, query, f"Scalar UDF {thread_id}")
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(execute_scalar_udf, i) for i in range(num_threads)]
+        results = [
+            future.result() for future in concurrent.futures.as_completed(futures)
+        ]
+
+    conn.close()
+
+    assert all(results)
diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py
index 40bde07b..eddfc7d1 100644
--- a/tests/slow/test_h2oai_arrow.py
+++ b/tests/slow/test_h2oai_arrow.py
@@ -194,8 +194,10 @@ def test_join(self, threads, function, large_data):
 
 
 @fixture(scope="module")
-def arrow_dataset_register():
+def arrow_dataset_register(tmp_path_factory):
     """Single fixture to download files and register them on the given connection"""
+    temp_dir = tmp_path_factory.mktemp("h2oai_data")
+
     session = requests.Session()
     retries = urllib3_util.Retry(
         allowed_methods={'GET'},  # only retry on GETs (all we do)
@@ -212,19 +214,15 @@ def arrow_dataset_register():
         respect_retry_after_header=True,  # respect Retry-After headers
     )
     session.mount('https://', requests_adapters.HTTPAdapter(max_retries=retries))
-    saved_filenames = set()
 
     def _register(url, filename, con, tablename):
+        file_path = temp_dir / filename
         r = session.get(url)
-        with open(filename, 'wb') as f:
-            f.write(r.content)
-        con.register(tablename, read_csv(filename))
-        saved_filenames.add(filename)
+        file_path.write_bytes(r.content)
+        con.register(tablename, read_csv(str(file_path)))
 
     yield _register
 
-    for filename in saved_filenames:
-        os.remove(filename)
     session.close()