diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index ea13b674..e151ccf1 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -94,12 +94,9 @@ jobs: env: CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }} CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - # PYTHON_GIL=1: Suppresses the RuntimeWarning that the GIL is enabled on free-threaded builds. - # TODO: Remove PYTHON_GIL=1 when free-threaded is supported. - CIBW_ENVIRONMENT: PYTHON_GIL=1 - name: Upload wheel uses: actions/upload-artifact@v4 with: - name: wheel-${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + name: wheel-${{ matrix.python }}-${{ matrix.platform.os }}_${{ matrix.platform.arch }} path: wheelhouse/*.whl compression-level: 0 diff --git a/pyproject.toml b/pyproject.toml index bcbb24f6..477e936f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -227,6 +227,7 @@ test = [ # dependencies used for running tests "pytest", "pytest-reraise", "pytest-timeout", + "pytest-run-parallel", "mypy", "coverage", "gcovr; python_version < '3.14'", diff --git a/src/duckdb_py/CMakeLists.txt b/src/duckdb_py/CMakeLists.txt index 2252ba29..7673d17f 100644 --- a/src/duckdb_py/CMakeLists.txt +++ b/src/duckdb_py/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(python_src OBJECT duckdb_python.cpp importer.cpp map.cpp + module_state.cpp path_like.cpp pyconnection.cpp pyexpression.cpp diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp index 939fa41a..f84377a7 100644 --- a/src/duckdb_py/duckdb_python.cpp +++ b/src/duckdb_py/duckdb_python.cpp @@ -20,6 +20,7 @@ #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" #include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/enums/statement_type.hpp" +#include "duckdb_python/module_state.hpp" #include "duckdb.hpp" @@ -31,6 +32,16 @@ namespace py = pybind11; namespace duckdb { +// Private function to initialize module state +void InitializeModuleState(py::module_ &m) { + auto state_ptr = new DuckDBPyModuleState(); + SetModuleState(state_ptr); + + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors + auto capsule = py::capsule(state_ptr, [](void *p) { delete static_cast<DuckDBPyModuleState *>(p); }); + m.attr("__duckdb_state") = capsule; +} + enum PySQLTokenType : uint8_t { PY_SQL_TOKEN_IDENTIFIER = 0, PY_SQL_TOKEN_NUMERIC_CONSTANT, @@ -1007,7 +1018,22 @@ static void RegisterExpectedResultType(py::handle &m) { expected_return_type.export_values(); } -PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT +// Only mark mod_gil_not_used for 3.14t or later, +// so as not to imply support for 3.13t +// The Py_GIL_DISABLED check is not strictly necessary +#if defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030e0000 +PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m, py::mod_gil_not_used(), + py::multiple_interpreters::not_supported()) { // NOLINT +#else +PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m, + py::multiple_interpreters::not_supported()) { // NOLINT +#endif + + // Initialize module state completely during module initialization + // PEP 489 wants module state to be module-local, but it is currently + // static via g_module_state.
+ InitializeModuleState(m); + py::enum_<duckdb::ExplainType>(m, "ExplainType") .value("STANDARD", duckdb::ExplainType::EXPLAIN_STANDARD) .value("ANALYZE", duckdb::ExplainType::EXPLAIN_ANALYZE) @@ -1046,9 +1072,10 @@ PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT m.attr("__version__") = std::string(DuckDB::LibraryVersion()).substr(1); m.attr("__standard_vector_size__") = DuckDB::StandardVectorSize(); m.attr("__git_revision__") = DuckDB::SourceID(); - m.attr("__interactive__") = DuckDBPyConnection::DetectAndGetEnvironment(); - m.attr("__jupyter__") = DuckDBPyConnection::IsJupyter(); - m.attr("__formatted_python_version__") = DuckDBPyConnection::FormattedPythonVersion(); + auto &module_state = GetModuleState(); + m.attr("__interactive__") = module_state.environment != PythonEnvironmentType::NORMAL; + m.attr("__jupyter__") = module_state.environment == PythonEnvironmentType::JUPYTER; + m.attr("__formatted_python_version__") = module_state.formatted_python_version; m.def("default_connection", &DuckDBPyConnection::DefaultConnection, "Retrieve the connection currently registered as the default to be used by the module"); m.def("set_default_connection", &DuckDBPyConnection::SetDefaultConnection, diff --git a/src/duckdb_py/include/duckdb_python/module_state.hpp b/src/duckdb_py/include/duckdb_python/module_state.hpp new file mode 100644 index 00000000..d7a4e377 --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/module_state.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/module_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb_python/pybind11/pybind_wrapper.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/main/db_instance_cache.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb_python/import_cache/python_import_cache.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include + +namespace duckdb { + +// Module state structure to hold per-interpreter state +struct DuckDBPyModuleState { + // Python environment tracking + PythonEnvironmentType environment = PythonEnvironmentType::NORMAL; + string formatted_python_version; + + DuckDBPyModuleState(); + + shared_ptr<DuckDBPyConnection> GetDefaultConnection(); + void SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection); + void ClearDefaultConnection(); + + PythonImportCache *GetImportCache(); + void ClearImportCache(); + + DBInstanceCache *GetInstanceCache(); + + static DuckDBPyModuleState &GetGlobalModuleState(); + static void SetGlobalModuleState(DuckDBPyModuleState *state); + +private: + shared_ptr<DuckDBPyConnection> default_connection_ptr; + PythonImportCache import_cache; + DBInstanceCache instance_cache; +#ifdef Py_GIL_DISABLED + py::object default_con_lock; +#endif + + // Implemented as static as a first step towards PEP 489 / multi-phase init + // The intent is to move to a per-module object, but frequent calls to import_cache + // need to be considered carefully.
+ // TODO: Replace with non-static per-interpreter state for multi-interpreter support + static DuckDBPyModuleState *g_module_state; + + // Non-copyable + DuckDBPyModuleState(const DuckDBPyModuleState &) = delete; + DuckDBPyModuleState &operator=(const DuckDBPyModuleState &) = delete; +}; + +DuckDBPyModuleState &GetModuleState(); +void SetModuleState(DuckDBPyModuleState *state); + +} // namespace duckdb \ No newline at end of file diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index 48ee055e..7998c14e 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -28,6 +28,7 @@ namespace duckdb { struct BoundParameterData; +struct DuckDBPyModuleState; enum class PythonEnvironmentType { NORMAL, INTERACTIVE, JUPYTER }; @@ -172,8 +173,7 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { case_insensitive_set_t registered_objects; public: - explicit DuckDBPyConnection() { - } + DuckDBPyConnection(); ~DuckDBPyConnection(); public: @@ -190,9 +190,17 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { static std::string FormattedPythonVersion(); static shared_ptr<DuckDBPyConnection> DefaultConnection(); static void SetDefaultConnection(shared_ptr<DuckDBPyConnection> conn); + static shared_ptr<DuckDBPyConnection> GetDefaultConnection(); + static void ClearDefaultConnection(); + static void ClearImportCache(); static PythonImportCache *ImportCache(); static bool IsInteractive(); + // Instance methods for optimized module state access + bool IsJupyterInstance() const; + bool IsInteractiveInstance() const; + std::string FormattedPythonVersionInstance() const; + unique_ptr<DuckDBPyRelation> ReadCSV(const py::object &name, py::kwargs &kwargs); py::list ExtractStatements(const string &query); @@ -337,11 +345,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { py::list ListFilesystems(); bool FileSystemIsRegistered(const string &name); - //! Default connection to an in-memory database - static DefaultConnectionHolder default_connection; - //!
Caches and provides an interface to get frequently used modules+subtypes - static shared_ptr<PythonImportCache> import_cache; - static bool IsPandasDataframe(const py::object &object); static PyArrowObjectType GetArrowType(const py::handle &obj); static bool IsAcceptedArrowObject(const py::object &object); @@ -357,10 +360,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> { bool side_effects); void RegisterArrowObject(const py::object &arrow_object, const string &name); vector<unique_ptr<SQLStatement>> GetStatements(const py::object &query); - - static PythonEnvironmentType environment; - static std::string formatted_python_version; - static void DetectEnvironment(); }; template diff --git a/src/duckdb_py/module_state.cpp b/src/duckdb_py/module_state.cpp new file mode 100644 index 00000000..1e0b6897 --- /dev/null +++ b/src/duckdb_py/module_state.cpp @@ -0,0 +1,128 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/module_state.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb_python/module_state.hpp" +#include +#include +#include + +#define DEBUG_MODULE_STATE 0 + +namespace duckdb { + +// Forward declaration from pyconnection.cpp +void InstantiateNewInstance(DuckDB &db); + +// Out-of-line definition of the static data member (required for non-inline statics) +DuckDBPyModuleState *DuckDBPyModuleState::g_module_state = nullptr; + +DuckDBPyModuleState::DuckDBPyModuleState() { + // Caches are constructed as direct objects - no heap allocation needed + +#ifdef Py_GIL_DISABLED + // Initialize lock object for critical sections + // TODO: Consider moving to finer-grained locks + default_con_lock = py::none(); +#endif + + // Detect the Python environment and version during initialization + // Moved from DuckDBPyConnection::DetectEnvironment() + py::module_ sys = py::module_::import("sys"); + py::object version_info = sys.attr("version_info"); + int major = py::cast<int>(version_info.attr("major")); + int minor = py::cast<int>(version_info.attr("minor")); + formatted_python_version = std::to_string(major) + "."
+ std::to_string(minor); + + // If __main__ does not have a __file__ attribute, we are in interactive mode + auto main_module = py::module_::import("__main__"); + if (!py::hasattr(main_module, "__file__")) { + environment = PythonEnvironmentType::INTERACTIVE; + + if (ModuleIsLoaded<IPythonCacheItem>()) { + // Check to see if we are in a Jupyter Notebook + auto get_ipython = import_cache.IPython.get_ipython(); + if (get_ipython.ptr() != nullptr) { + auto ipython = get_ipython(); + if (py::hasattr(ipython, "config")) { + py::dict ipython_config = ipython.attr("config"); + if (ipython_config.contains("IPKernelApp")) { + environment = PythonEnvironmentType::JUPYTER; + } + } + } + } + } +} + +DuckDBPyModuleState &DuckDBPyModuleState::GetGlobalModuleState() { + // TODO: Externalize this static cache when adding multi-interpreter support + // For now, the single-interpreter assumption allows simple static caching + if (!g_module_state) { + throw InternalException("Module state not initialized - call SetGlobalModuleState() during module init"); + } + return *g_module_state; +} + +void DuckDBPyModuleState::SetGlobalModuleState(DuckDBPyModuleState *state) { +#if DEBUG_MODULE_STATE + printf("DEBUG: SetGlobalModuleState() called - initializing static cache (built: %s %s)\n", __DATE__, __TIME__); +#endif + g_module_state = state; +} + +DuckDBPyModuleState &GetModuleState() { +#if DEBUG_MODULE_STATE + printf("DEBUG: GetModuleState() called\n"); +#endif + return DuckDBPyModuleState::GetGlobalModuleState(); } + +void SetModuleState(DuckDBPyModuleState *state) { + DuckDBPyModuleState::SetGlobalModuleState(state); +} + +shared_ptr<DuckDBPyConnection> DuckDBPyModuleState::GetDefaultConnection() { +#if defined(Py_GIL_DISABLED) + // TODO: Consider whether to use a mutex or a scoped_critical_section + py::scoped_critical_section guard(default_con_lock); +#endif + // Reproduce exact logic from original DefaultConnectionHolder::Get() + if (!default_connection_ptr || default_connection_ptr->con.ConnectionIsClosed()) { + py::dict config_dict; + default_connection_ptr = DuckDBPyConnection::Connect(py::str(":memory:"), false, config_dict); + } + return default_connection_ptr; +} + +void DuckDBPyModuleState::SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection) { +#if defined(Py_GIL_DISABLED) + py::scoped_critical_section guard(default_con_lock); +#endif + default_connection_ptr = std::move(connection); +} + +void DuckDBPyModuleState::ClearDefaultConnection() { +#if defined(Py_GIL_DISABLED) + py::scoped_critical_section guard(default_con_lock); +#endif + default_connection_ptr = nullptr; +} + +PythonImportCache *DuckDBPyModuleState::GetImportCache() { + return &import_cache; +} + +void DuckDBPyModuleState::ClearImportCache() { + import_cache = PythonImportCache(); +} + +DBInstanceCache *DuckDBPyModuleState::GetInstanceCache() { + return &instance_cache; +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index b88b88ed..ba394990 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -1,4 +1,5 @@ #include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/module_state.hpp" #include "duckdb/catalog/default/default_types.hpp" #include "duckdb/common/arrow/arrow.hpp" @@ -66,11 +67,8 @@ namespace duckdb { -DefaultConnectionHolder DuckDBPyConnection::default_connection; // NOLINT: allow global -DBInstanceCache instance_cache; // NOLINT: allow global -shared_ptr<PythonImportCache> DuckDBPyConnection::import_cache = nullptr; // NOLINT: allow global -PythonEnvironmentType
DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global -std::string DuckDBPyConnection::formatted_python_version = ""; +DuckDBPyConnection::DuckDBPyConnection() { +} DuckDBPyConnection::~DuckDBPyConnection() { try { @@ -82,53 +80,17 @@ DuckDBPyConnection::~DuckDBPyConnection() { } } -void DuckDBPyConnection::DetectEnvironment() { - // Get the formatted Python version - py::module_ sys = py::module_::import("sys"); - py::object version_info = sys.attr("version_info"); - int major = py::cast<int>(version_info.attr("major")); - int minor = py::cast<int>(version_info.attr("minor")); - DuckDBPyConnection::formatted_python_version = std::to_string(major) + "." + std::to_string(minor); - - // If __main__ does not have a __file__ attribute, we are in interactive mode - auto main_module = py::module_::import("__main__"); - if (py::hasattr(main_module, "__file__")) { - return; - } - DuckDBPyConnection::environment = PythonEnvironmentType::INTERACTIVE; - if (!ModuleIsLoaded<IPythonCacheItem>()) { - return; - } - - // Check to see if we are in a Jupyter Notebook - auto &import_cache_py = *DuckDBPyConnection::ImportCache(); - auto get_ipython = import_cache_py.IPython.get_ipython(); - if (get_ipython.ptr() == nullptr) { - // Could either not load the IPython module, or it has no 'get_ipython' attribute - return; - } - auto ipython = get_ipython(); - if (!py::hasattr(ipython, "config")) { - return; - } - py::dict ipython_config = ipython.attr("config"); - if (ipython_config.contains("IPKernelApp")) { - DuckDBPyConnection::environment = PythonEnvironmentType::JUPYTER; - } - return; -} - bool DuckDBPyConnection::DetectAndGetEnvironment() { - DuckDBPyConnection::DetectEnvironment(); + // Environment detection now happens during module state construction return DuckDBPyConnection::IsInteractive(); } bool DuckDBPyConnection::IsJupyter() { - return DuckDBPyConnection::environment == PythonEnvironmentType::JUPYTER; + return GetModuleState().environment == PythonEnvironmentType::JUPYTER; } std::string DuckDBPyConnection::FormattedPythonVersion() { - return DuckDBPyConnection::formatted_python_version; + return GetModuleState().formatted_python_version; } // NOTE: this function is generated by tools/pythonpkg/scripts/generate_connection_methods.py.
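For reference, the detection that moved into the DuckDBPyModuleState constructor is roughly equivalent to the following Python sketch (illustrative only, not part of this diff; `detect_environment` is a hypothetical name):

```python
import sys

def detect_environment() -> str:
    """Roughly mirror the C++ logic: NORMAL, INTERACTIVE, or JUPYTER."""
    # A script entry point gives __main__ a __file__ attribute.
    if hasattr(sys.modules["__main__"], "__file__"):
        return "NORMAL"
    environment = "INTERACTIVE"
    # Only consult IPython if it is already imported (mirrors the ModuleIsLoaded check).
    ipython_module = sys.modules.get("IPython")
    if ipython_module is not None:
        shell = ipython_module.get_ipython()
        config = getattr(shell, "config", None)
        if shell is not None and config is not None and "IPKernelApp" in config:
            environment = "JUPYTER"
    return environment
```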
@@ -2113,8 +2075,8 @@ static shared_ptr<DuckDBPyConnection> FetchOrCreateInstance(const string &databa D_ASSERT(py::gil_check()); py::gil_scoped_release release; unique_lock<std::mutex> lock(res->py_connection_lock); - auto database = - instance_cache.GetOrCreateInstance(database_path, config, cache_instance, InstantiateNewInstance); + auto database = GetModuleState().GetInstanceCache()->GetOrCreateInstance(database_path, config, cache_instance, + InstantiateNewInstance); res->con.SetDatabase(std::move(database)); res->con.SetConnection(make_uniq<Connection>(res->con.GetDatabase())); } @@ -2162,6 +2124,7 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Connect(const py::object &dat "python_scan_all_frames", "If set, restores the old behavior of scanning all preceding frames to locate the referenced variable.", LogicalType::BOOLEAN, Value::BOOLEAN(false)); + // Use static methods here since we don't have a connection instance yet if (!DuckDBPyConnection::IsJupyter()) { config_dict["duckdb_api"] = Value("python/" + DuckDBPyConnection::FormattedPythonVersion()); } else { @@ -2197,18 +2160,27 @@ case_insensitive_map_t<BoundParameterData> DuckDBPyConnection::TransformPythonPa } shared_ptr<DuckDBPyConnection> DuckDBPyConnection::DefaultConnection() { - return default_connection.Get(); + return GetModuleState().GetDefaultConnection(); } void DuckDBPyConnection::SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection) { - return default_connection.Set(std::move(connection)); + GetModuleState().SetDefaultConnection(std::move(connection)); } PythonImportCache *DuckDBPyConnection::ImportCache() { - if (!import_cache) { - import_cache = make_shared_ptr<PythonImportCache>(); - } - return import_cache.get(); + return GetModuleState().GetImportCache(); +} + +bool DuckDBPyConnection::IsJupyterInstance() const { + return GetModuleState().environment == PythonEnvironmentType::JUPYTER; +} + +bool DuckDBPyConnection::IsInteractiveInstance() const { + return GetModuleState().environment != PythonEnvironmentType::NORMAL; +} + +std::string DuckDBPyConnection::FormattedPythonVersionInstance() const { + return GetModuleState().formatted_python_version; } ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { @@ -2228,7 +2200,7 @@ ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { } bool DuckDBPyConnection::IsInteractive() { - return DuckDBPyConnection::environment != PythonEnvironmentType::NORMAL; + return GetModuleState().environment != PythonEnvironmentType::NORMAL; } shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Enter() { @@ -2246,8 +2218,25 @@ void DuckDBPyConnection::Exit(DuckDBPyConnection &self, const py::object &exc_ty } void DuckDBPyConnection::Cleanup() { - default_connection.Set(nullptr); - import_cache.reset(); + try { + GetModuleState().ClearDefaultConnection(); + GetModuleState().ClearImportCache(); + } catch (...) { // NOLINT + // TODO: Can we detect shutdown?
Py_IsFinalizing might be appropriate, although renamed from + // _Py_IsFinalizing + } +} + +shared_ptr<DuckDBPyConnection> DuckDBPyConnection::GetDefaultConnection() { + return GetModuleState().GetDefaultConnection(); +} + +void DuckDBPyConnection::ClearDefaultConnection() { + GetModuleState().ClearDefaultConnection(); +} + +void DuckDBPyConnection::ClearImportCache() { + GetModuleState().ClearImportCache(); } bool DuckDBPyConnection::IsPandasDataframe(const py::object &object) { diff --git a/tests/conftest.py b/tests/conftest.py index 5e297aee..fe8f2a20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -336,3 +336,13 @@ def finalizer(): duckdb.connect(test_dbfarm) return test_dbfarm + + +@pytest.fixture(scope="function") +def num_threads_testing(): + """Pick a thread count high enough to load the system while keeping tests fast.""" + import multiprocessing + + cpu_count = multiprocessing.cpu_count() + # Use 1.5x CPU count, max 12 for CI compatibility + return min(12, max(4, int(cpu_count * 1.5))) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index 4efd68b5..40a7b618 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -16,13 +16,14 @@ def test_connection_interrupt(self): def interrupt(): # Wait for query to start running before interrupting - time.sleep(0.1) + time.sleep(1) conn.interrupt() thread = threading.Thread(target=interrupt) thread.start() with pytest.raises(duckdb.InterruptException): - conn.execute("select count(*) from range(100000000000)").fetchall() + conn.execute('select * from range(100000) t1,range(100000) t2').fetchall() + thread.join() def test_interrupt_closed_connection(self): diff --git a/tests/fast/api/test_query_interrupt.py b/tests/fast/api/test_query_interrupt.py index 6334e475..312414a6 100644 --- a/tests/fast/api/test_query_interrupt.py +++ b/tests/fast/api/test_query_interrupt.py @@ -9,7 +9,7 @@ def send_keyboard_interrupt(): # Wait a little, so we're sure the 'execute' has started - time.sleep(0.1) + time.sleep(1) # Send an interrupt to the main thread thread.interrupt_main() @@ -25,7 +25,7 @@ def test_query_interruption(self): # Start the thread thread.start() try: - res = con.execute('select count(*) from range(100000000000)').fetchall() + con.execute('select * from range(100000) t1,range(100000) t2').fetchall() except RuntimeError: # If this is not reached, we could not cancel the query before it completed # indicating that the query interruption functionality is broken diff --git a/tests/fast/threading/README.md b/tests/fast/threading/README.md new file mode 100644 index 00000000..be5b8f53 --- /dev/null +++ b/tests/fast/threading/README.md @@ -0,0 +1,10 @@ +Tests in this directory are intended to be run with [pytest-run-parallel](https://github.com/Quansight-Labs/pytest-run-parallel) to exercise thread safety. + +Example usage: `pytest --parallel-threads=10 --iterations=5 --verbose tests/fast/threading -n 4 --durations=5` + +#### Thread Safety and DuckDB + +Not all DuckDB operations are thread safe - cursors in particular are not - so take care that tests do not concurrently hit the same connection or cursor.
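As a hedged illustration (not part of the PR's README itself), a cursor-per-thread pattern like the sketch below is safe, whereas sharing a single cursor between threads is not:

```python
import concurrent.futures

import duckdb

def cursor_per_thread_demo():
    conn = duckdb.connect(":memory:")
    conn.execute("CREATE TABLE t (x INTEGER)")

    def worker(i: int) -> None:
        # Each thread opens its own cursor; a cursor must never be
        # shared between threads.
        with conn.cursor() as cur:
            cur.execute("INSERT INTO t VALUES (?)", [i])

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        list(pool.map(worker, range(4)))

    assert conn.execute("SELECT COUNT(*) FROM t").fetchone()[0] == 4
    conn.close()
```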
+ +Tests can be marked as single threaded with: +- `pytest.mark.thread_unsafe` or the equivalent `pytest.mark.parallel_threads(1)` diff --git a/tests/fast/threading/test_basic_operations.py b/tests/fast/threading/test_basic_operations.py new file mode 100644 index 00000000..266fd295 --- /dev/null +++ b/tests/fast/threading/test_basic_operations.py @@ -0,0 +1,117 @@ +import gc +import random +import time +import weakref +from threading import get_ident + +import uuid + +import pytest + +import duckdb + + +def test_basic(): + with duckdb.connect(":memory:") as conn: + result = conn.execute("SELECT 1").fetchone() + assert result[0] == 1 + int_type = duckdb.type("INTEGER") + assert int_type is not None, "type creation failed" + + +def test_connection_instance_cache(tmp_path): + thread_id = get_ident() + for i in range(10): + with duckdb.connect(tmp_path / f"{thread_id}_{uuid.uuid4()}.db") as conn: + conn.execute( + f"CREATE TABLE IF NOT EXISTS thread_{thread_id}_data_{i} (x BIGINT)" + ) + conn.execute(f"INSERT INTO thread_{thread_id}_data_{i} VALUES (100), (100)") + + time.sleep(random.uniform(0.0001, 0.001)) + + result = conn.execute( + f"SELECT COUNT(*) FROM thread_{thread_id}_data_{i}" + ).fetchone()[0] + assert result == 2, f"Iteration {i}: expected 2 rows, got {result}" + + +def test_cleanup(): + weak_refs = [] + + for i in range(5): + conn = duckdb.connect(":memory:") + weak_refs.append(weakref.ref(conn)) + try: + conn.execute("CREATE TABLE test (x INTEGER)") + conn.execute("INSERT INTO test VALUES (1), (2), (3)") + result = conn.execute("SELECT COUNT(*) FROM test").fetchone() + assert result[0] == 3 + finally: + conn.close() + conn = None + + if i % 3 == 0: + with duckdb.connect(":memory:") as new_conn: + result = new_conn.execute("SELECT 1").fetchone() + assert result[0] == 1 + + if i % 10 == 0: + gc.collect() + time.sleep(random.uniform(0.0001, 0.0005)) + + gc.collect() + time.sleep(0.1) + gc.collect() + + alive_refs = [ref for ref in weak_refs if ref() is not None] + assert len(alive_refs) <= 10, ( + f"{len(alive_refs)} connections still alive (expected <= 10)" + ) + + +def test_default_connection(): + with duckdb.connect() as conn1: + r1 = conn1.execute("SELECT 1").fetchone()[0] + assert r1 == 1, f"expected 1, got {r1}" + + with duckdb.connect(":memory:") as conn2: + r2 = conn2.execute("SELECT 2").fetchone()[0] + assert r2 == 2, f"expected 2, got {r2}" + + +def test_type_system(): + for i in range(20): + types = [ + duckdb.type("INTEGER"), + duckdb.type("VARCHAR"), + duckdb.type("DOUBLE"), + duckdb.type("BOOLEAN"), + duckdb.list_type(duckdb.type("INTEGER")), + duckdb.struct_type( + {"a": duckdb.type("INTEGER"), "b": duckdb.type("VARCHAR")} + ), + ] + + for t in types: + assert t is not None, "type creation failed" + + if i % 5 == 0: + with duckdb.connect(":memory:") as conn: + conn.execute( + "CREATE TABLE test (a INTEGER, b VARCHAR, c DOUBLE, d BOOLEAN)" + ) + result = conn.execute("SELECT COUNT(*) FROM test").fetchone() + assert result[0] == 0 + + +def test_import_cache(): + with duckdb.connect(":memory:") as conn: + conn.execute("CREATE TABLE test AS SELECT range as x FROM range(10)") + result = conn.fetchdf() + assert len(result) > 0, "fetchdf failed" + + result = conn.execute("SELECT range as x FROM range(5)").fetchnumpy() + assert len(result["x"]) == 5, "fetchnumpy failed" + + conn.execute("DROP TABLE test") diff --git a/tests/fast/threading/test_concurrent_access.py b/tests/fast/threading/test_concurrent_access.py new file mode 100644 index 00000000..6cc8ea8a --- /dev/null 
+++ b/tests/fast/threading/test_concurrent_access.py @@ -0,0 +1,111 @@ +""" +Concurrent access tests for DuckDB Python bindings with free threading support. + +These tests verify that the DuckDB Python module can handle concurrent access +from multiple threads safely, testing module state isolation, memory management, +and connection handling under various stress conditions. +""" + +import gc +import random +import time +import concurrent.futures + +import pytest + +import duckdb + + +def test_concurrent_connections(): + with duckdb.connect() as conn: + result = conn.execute("SELECT random() as id, random()*2 as doubled").fetchone() + assert result is not None + + +@pytest.mark.parallel_threads(1) +def test_shared_connection_stress(num_threads_testing): + """Test concurrent operations on shared connection using cursors.""" + iterations = 10 + + with duckdb.connect(":memory:") as connection: + connection.execute( + "CREATE TABLE stress_test (id INTEGER, thread_id INTEGER, value TEXT)" + ) + + def worker_thread(thread_id: int) -> None: + cursor = connection.cursor() + for i in range(iterations): + cursor.execute( + "INSERT INTO stress_test VALUES (?, ?, ?)", + [i, thread_id, f"thread_{thread_id}_value_{i}"], + ) + cursor.execute( + "SELECT COUNT(*) FROM stress_test WHERE thread_id = ?", [thread_id] + ).fetchone() + time.sleep(random.uniform(0.0001, 0.001)) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads_testing + ) as executor: + futures = [ + executor.submit(worker_thread, i) for i in range(num_threads_testing) + ] + # Wait for all to complete, will raise if any fail + for future in concurrent.futures.as_completed(futures): + future.result() + + total_rows = connection.execute("SELECT COUNT(*) FROM stress_test").fetchone()[ + 0 + ] + expected_rows = num_threads_testing * iterations + assert total_rows == expected_rows + + +@pytest.mark.parallel_threads(1) +def test_module_state_isolation(): + """Test that module state is properly accessible.""" + with duckdb.connect(":memory:"): + assert hasattr(duckdb, "__version__") + + with duckdb.connect() as default_conn: + result = default_conn.execute("SELECT 'default' as type").fetchone() + assert result[0] == "default" + + int_type = duckdb.type("INTEGER") + string_type = duckdb.type("VARCHAR") + assert int_type is not None + assert string_type is not None + + +def test_rapid_connect_disconnect(): + """Test rapid connection creation and destruction.""" + connections_count = 10 + for i in range(connections_count): + conn = duckdb.connect(":memory:") + try: + result = conn.execute("SELECT 1").fetchone()[0] + assert result == 1 + finally: + conn.close() + + # Sometimes force GC to increase pressure + if i % 3 == 0: + gc.collect() + + +def test_exception_handling(): + """Test exception handling doesn't affect module state.""" + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE test (x INTEGER)") + conn.execute("INSERT INTO test VALUES (1), (2), (3)") + + for i in range(10): + if i % 3 == 0: + with pytest.raises(duckdb.CatalogException): + conn.execute("SELECT * FROM nonexistent_table") + else: + result = conn.execute("SELECT COUNT(*) FROM test").fetchone()[0] + assert result == 3 + finally: + conn.close() diff --git a/tests/fast/threading/test_connection_lifecycle_races.py b/tests/fast/threading/test_connection_lifecycle_races.py new file mode 100644 index 00000000..4e5922fc --- /dev/null +++ b/tests/fast/threading/test_connection_lifecycle_races.py @@ -0,0 +1,105 @@ +""" +Test connection lifecycle races.
+ +Focused on DuckDBPyConnection constructor and Close +""" + +import gc +import concurrent.futures + +import pytest + +import duckdb + + +def test_concurrent_connection_creation_destruction(): + conn = duckdb.connect() + try: + result = conn.execute("SELECT 1").fetchone() + assert result[0] == 1 + finally: + conn.close() + + +def test_connection_destructor_race(): + conn = duckdb.connect() + result = conn.execute("SELECT COUNT(*) FROM range(1)").fetchone() + assert result[0] == 1 + + del conn + gc.collect() + + +@pytest.mark.parallel_threads(1) +def test_concurrent_close_operations(num_threads_testing): + with duckdb.connect(":memory:") as conn: + conn.execute("CREATE TABLE shared_table (id INTEGER, data VARCHAR)") + conn.execute("INSERT INTO shared_table VALUES (1, 'test')") + + def attempt_close_connection(cursor, thread_id): + _result = cursor.execute("SELECT COUNT(*) FROM shared_table").fetchone() + + cursor.close() + + return True + + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads_testing + ) as executor: + futures = [ + executor.submit(attempt_close_connection, conn.cursor(), i) + for i in range(num_threads_testing) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) + + +@pytest.mark.parallel_threads(1) +def test_cursor_operations_race(num_threads_testing): + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE cursor_test (id INTEGER, name VARCHAR)") + conn.execute( + "INSERT INTO cursor_test SELECT i, 'name_' || i FROM range(100) t(i)" + ) + + def cursor_operations(thread_id): + """Perform cursor operations concurrently.""" + # Get a cursor + cursor = conn.cursor() + cursor.execute( + f"SELECT * FROM cursor_test WHERE id % {num_threads_testing} = {thread_id}" + ) + results = cursor.fetchall() + + return True + + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads_testing + ) as executor: + futures = [ + executor.submit(cursor_operations, i) + for i in range(num_threads_testing) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) + finally: + conn.close() + + +def test_rapid_connection_cycling(): + """Test rapid connection creation and destruction cycles.""" + num_cycles = 5 + for cycle in range(num_cycles): + conn = duckdb.connect(":memory:") + try: + result = conn.execute(f"SELECT 1 + {cycle}").fetchone() + assert result[0] == 1 + cycle + finally: + conn.close() diff --git a/tests/fast/threading/test_fetching.py b/tests/fast/threading/test_fetching.py new file mode 100644 index 00000000..dc7024b6 --- /dev/null +++ b/tests/fast/threading/test_fetching.py @@ -0,0 +1,46 @@ +""" +Test fetching operations. 
+""" + +from threading import get_ident + +import pytest + +import duckdb + + +def test_fetching(): + """Test different fetching methods.""" + iterations = 10 + thread_id = get_ident() + + conn = duckdb.connect() + try: + batch_data = [ + (thread_id * 100 + i, f"name_{thread_id}_{i}") for i in range(iterations) + ] + conn.execute("CREATE TABLE batch_data (id BIGINT, name VARCHAR)") + conn.executemany("INSERT INTO batch_data VALUES (?, ?)", batch_data) + + # Test different fetch methods + result1 = conn.execute( + f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'" + ).fetchone() + assert result1[0] == iterations + + result2 = conn.execute( + f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'" + ).fetchall() + assert result2[0][0] == iterations + + result3 = conn.execute( + f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'" + ).fetchdf() + assert len(result3) == 1 + + result4 = conn.execute( + f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'" + ).fetch_arrow_table() + assert result4.num_rows == 1 + finally: + conn.close() diff --git a/tests/fast/threading/test_module_lifecycle.py b/tests/fast/threading/test_module_lifecycle.py new file mode 100644 index 00000000..0b265108 --- /dev/null +++ b/tests/fast/threading/test_module_lifecycle.py @@ -0,0 +1,148 @@ +""" +Test module lifecycle + +Reloading and unload are not expected nor required behaviors - +these tests are to document current behavior so that changes +are visible. +""" + +import importlib +import sys +from threading import get_ident + +import pytest + + +@pytest.mark.parallel_threads(1) +def test_module_reload_safety(): + """Test module reloading scenarios to detect use-after-free issues.""" + import duckdb + + with duckdb.connect(":memory:") as conn1: + conn1.execute("CREATE TABLE test (id INTEGER)") + conn1.execute("INSERT INTO test VALUES (1)") + result1 = conn1.execute("SELECT * FROM test").fetchone()[0] + assert result1 == 1 + + initial_module_id = id(sys.modules["duckdb"]) + + # Test importlib.reload() - + # does NOT create new module in Python + importlib.reload(duckdb) + + # Verify module instance is the same (expected Python behavior) + reload_module_id = id(sys.modules["duckdb"]) + assert initial_module_id == reload_module_id, ( + "importlib.reload() should reuse same module instance" + ) + + # Test if old connection still works after importlib.reload() + result2 = conn1.execute("SELECT * FROM test").fetchone()[0] + assert result2 == 1 + + # Test new connection after importlib.reload() + with duckdb.connect(":memory:") as conn2: + conn2.execute("CREATE TABLE test2 (id INTEGER)") + conn2.execute("INSERT INTO test2 VALUES (2)") + result3 = conn2.execute("SELECT * FROM test2").fetchone()[0] + assert result3 == 2 + + +@pytest.mark.parallel_threads(1) +def test_dynamic_module_loading(): + import duckdb + + with duckdb.connect(":memory:") as conn: + conn.execute("SELECT 1").fetchone() + + module_id_1 = id(sys.modules["duckdb"]) + + # "Unload" module (not really, just to try it) + if "duckdb" in sys.modules: + del sys.modules["duckdb"] + + # Remove from local namespace + if "duckdb" in locals(): + del duckdb + + # Verify module is unloaded + assert "duckdb" not in sys.modules, "Module not properly unloaded" + + # import (load) module + import duckdb + + module_id_2 = id(sys.modules["duckdb"]) + + # Verify we have a new module instance + assert module_id_1 != module_id_2, "Module not actually reloaded" + + # Test functionality after reload + with 
duckdb.connect(":memory:") as conn: + conn.execute("CREATE TABLE test (id INTEGER)") + conn.execute("INSERT INTO test VALUES (42)") + result = conn.execute("SELECT * FROM test").fetchone()[0] + assert result == 42 + + +def test_import_cache_consistency(): + """Test that import cache remains consistent across module operations.""" + + import duckdb + import pandas as pd + + conn = duckdb.connect(":memory:") + + df = pd.DataFrame({"a": [1, 2, 3]}) + + conn.register("test_df", df) + result = conn.execute("SELECT COUNT(*) FROM test_df").fetchone()[0] + assert result == 3 + + conn.close() + + +def test_module_state_memory_safety(): + """Test memory safety of module state access patterns.""" + + import duckdb + + connections = [] + for i in range(10): + conn = duckdb.connect(":memory:") + conn.execute(f"CREATE TABLE test_{i} (id INTEGER)") + conn.execute(f"INSERT INTO test_{i} VALUES ({i})") + connections.append(conn) + + import gc + + gc.collect() + + for i, conn in enumerate(connections): + result = conn.execute(f"SELECT * FROM test_{i}").fetchone()[0] + assert result == i + + for conn in connections: + conn.close() + + +def test_static_cache_stress(): + """Test rapid module state access.""" + import duckdb + + iterations = 5 + for i in range(iterations): + conn = duckdb.connect(":memory:") + result = conn.execute("SELECT 1").fetchone() + assert result[0] == 1 + conn.close() + + +def test_concurrent_module_access(): + import duckdb + + thread_id = get_ident() + with duckdb.connect(":memory:") as conn: + conn.execute(f"CREATE TABLE test_{thread_id} (id BIGINT)") + conn.execute(f"INSERT INTO test_{thread_id} VALUES ({thread_id})") + result = conn.execute(f"SELECT * FROM test_{thread_id}").fetchone()[0] + assert result == thread_id diff --git a/tests/fast/threading/test_module_state.py b/tests/fast/threading/test_module_state.py new file mode 100644 index 00000000..7a1ad231 --- /dev/null +++ b/tests/fast/threading/test_module_state.py @@ -0,0 +1,38 @@ +from threading import get_ident + +import pytest + +import duckdb + + +def test_concurrent_connection_creation(): + thread_id = get_ident() + for i in range(5): + with duckdb.connect(":memory:") as conn: + conn.execute(f"CREATE TABLE test_{i} (x BIGINT)") + conn.execute(f"INSERT INTO test_{i} VALUES ({thread_id})") + result = conn.execute(f"SELECT * FROM test_{i}").fetchall() + assert result == [(thread_id,)], f"Table {i} failed" + + +def test_concurrent_instance_cache_access(tmp_path): + thread_id = get_ident() + for i in range(10): + db_path = str(tmp_path / f"test_{thread_id}_{i}.db") + with duckdb.connect(db_path) as conn: + conn.execute("CREATE TABLE IF NOT EXISTS test (x BIGINT, thread_id BIGINT)") + conn.execute(f"INSERT INTO test VALUES ({i}, {thread_id})") + result = conn.execute("SELECT COUNT(*) FROM test").fetchone() + assert result[0] >= 1 + + +def test_environment_detection(): + version = duckdb.__formatted_python_version__ + interactive = duckdb.__interactive__ + + assert isinstance(version, str), "version should be string" + assert isinstance(interactive, bool), "interactive should be boolean" + + with duckdb.connect(":memory:") as conn: + result = conn.execute("SELECT 1").fetchone() + assert result[0] == 1 diff --git a/tests/fast/threading/test_query_execution_races.py b/tests/fast/threading/test_query_execution_races.py new file mode 100644 index 00000000..e3128219 --- /dev/null +++ b/tests/fast/threading/test_query_execution_races.py @@ -0,0 +1,194 @@ +""" +Test concurrent query execution races. 
+ +This tests race conditions in query execution paths where the GIL is released +during query processing, as identified in pyconnection.cpp. +""" + +import concurrent.futures +import threading +from threading import get_ident + +import pytest + +import duckdb + + +class QueryRaceTester: + """Increases contention by aligning threads with a barrier""" + + def setup_barrier(self, num_threads): + self.barrier = threading.Barrier(num_threads) + + def synchronized_execute(self, db, query, description="query"): + with db.cursor() as conn: + self.barrier.wait() + result = conn.execute(query).fetchall() + return True + + +@pytest.mark.parallel_threads(1) +def test_concurrent_prepare_execute(): + num_threads = 5 + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE test_data (id INTEGER, value VARCHAR)") + conn.execute( + "INSERT INTO test_data SELECT i, 'value_' || i FROM range(1000) t(i)" + ) + + tester = QueryRaceTester() + tester.setup_barrier(num_threads) + + def prepare_and_execute(thread_id, conn): + queries = [ + f"SELECT COUNT(*) FROM test_data WHERE id > {thread_id * 10}", + f"SELECT value FROM test_data WHERE id = {thread_id + 1}", + f"SELECT id, value FROM test_data WHERE id BETWEEN {thread_id} AND {thread_id + 10}", + f"INSERT INTO test_data VALUES ({1000 + thread_id}, 'thread_{thread_id}')", + f"UPDATE test_data SET value = 'updated_{thread_id}' WHERE id = {thread_id + 500}", + ] + + query = queries[thread_id % len(queries)] + return tester.synchronized_execute( + conn, query, f"Prepared query {thread_id}" + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(prepare_and_execute, i, conn) + for i in range(num_threads) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert len(results) == num_threads and all(results) + finally: + conn.close() + + +@pytest.mark.parallel_threads(1) +def test_concurrent_pending_query_execution(): + conn = duckdb.connect(":memory:") + try: + conn.execute( + "CREATE TABLE large_data AS SELECT i, i*2 as double_val, 'row_' || i as str_val FROM range(10000) t(i)" + ) + + num_threads = 8 + tester = QueryRaceTester() + tester.setup_barrier(num_threads) + + def execute_long_query(thread_id): + queries = [ + "SELECT COUNT(*), AVG(double_val) FROM large_data", + "SELECT str_val, double_val FROM large_data WHERE i % 100 = 0 ORDER BY double_val", + f"SELECT * FROM large_data WHERE i BETWEEN {thread_id * 1000} AND {(thread_id + 1) * 1000}", + "SELECT i, double_val, str_val FROM large_data WHERE double_val > 5000 ORDER BY i DESC", + f"SELECT COUNT(*) as cnt FROM large_data WHERE str_val LIKE '%{thread_id}%'", + ] + + query = queries[thread_id % len(queries)] + return tester.synchronized_execute(conn, query, f"Long query {thread_id}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(execute_long_query, i) for i in range(num_threads) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + assert all(results) and len(results) == num_threads + finally: + conn.close() + + +def test_execute_many_race(): + """Test executemany operations.""" + iterations = 10 + thread_id = get_ident() + + conn = duckdb.connect() + try: + batch_data = [ + (thread_id * 100 + i, f"name_{thread_id}_{i}") for i in range(iterations) + ] + conn.execute("CREATE TABLE batch_data (id BIGINT, name VARCHAR)") + conn.executemany("INSERT INTO batch_data VALUES (?, ?)",
batch_data) + result = conn.execute( + f"SELECT COUNT(*) FROM batch_data WHERE name LIKE 'name_{thread_id}_%'" + ).fetchone() + assert result[0] == iterations + finally: + conn.close() + + +@pytest.mark.parallel_threads(1) +def test_query_interruption_race(): + conn = duckdb.connect(":memory:") + try: + conn.execute("CREATE TABLE interrupt_test AS SELECT i FROM range(100000) t(i)") + + num_threads = 6 + + def run_interruptible_query(thread_id): + with conn.cursor() as conn2: + if thread_id % 2 == 0: + # Fast query + result = conn2.execute( + "SELECT COUNT(*) FROM interrupt_test" + ).fetchall() + return True + else: + # Potentially slower query + result = conn2.execute( + "SELECT i, i*i FROM interrupt_test WHERE i % 1000 = 0 ORDER BY i" + ).fetchall() + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(run_interruptible_query, i) for i in range(num_threads) + ] + results = [ + future.result() + for future in concurrent.futures.as_completed(futures, timeout=30) + ] + + assert all(results) + finally: + conn.close() + + +def test_mixed_query_operations(): + """Test mixed query operations.""" + thread_id = get_ident() + + with duckdb.connect(":memory:") as conn: + conn.execute( + "CREATE TABLE mixed_ops (id BIGINT PRIMARY KEY, data VARCHAR, num_val DOUBLE)" + ) + conn.execute( + "INSERT INTO mixed_ops SELECT i, 'initial_' || i, i * 1.5 FROM range(1000) t(i)" + ) + + queries = [ + f"SELECT COUNT(*) FROM mixed_ops WHERE id > {thread_id * 50}", + f"INSERT INTO mixed_ops VALUES ({10000 + thread_id}, 'thread_{thread_id}', {thread_id * 2.5})", + f"UPDATE mixed_ops SET data = 'updated_{thread_id}' WHERE id = {thread_id + 100}", + "SELECT AVG(num_val), MAX(id) FROM mixed_ops WHERE data LIKE 'initial_%'", + """ + SELECT m1.id, m1.data, m2.num_val + FROM mixed_ops m1 + JOIN mixed_ops m2 ON m1.id = m2.id - 1 + LIMIT 10 + """, + ] + + for query in queries: + result = conn.execute(query) + if "SELECT" in query.upper(): + rows = result.fetchall() + assert len(rows) >= 0 diff --git a/tests/fast/threading/test_threading.py b/tests/fast/threading/test_threading.py new file mode 100644 index 00000000..db164b9c --- /dev/null +++ b/tests/fast/threading/test_threading.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +""" +Tests designed to expose specific threading bugs in the DuckDB implementation. 
+""" + +import sys +from threading import get_ident + +import pytest + +import duckdb + + +def test_gil_enabled(): + # Safeguard to ensure GIL is disabled if this is a free-threading build to ensure test validity + # this would fail if tests were run with PYTHON_GIL=1, as one example + if "free-threading" in sys.version: + import sysconfig + + print(f"Free-threading Python detected: {sys.version}") + print(f"Py_GIL_DISABLED = {sysconfig.get_config_var('Py_GIL_DISABLED')}") + + assert sysconfig.get_config_var("Py_GIL_DISABLED") == 1, ( + f"Py_GIL_DISABLED must be 1 in free-threading build, got: {sysconfig.get_config_var('Py_GIL_DISABLED')}" + ) + + +def test_instance_cache_race(tmp_path): + """Test opening connections to different files.""" + + tid = get_ident() + with duckdb.connect(tmp_path / f"{tid}_testing.db") as conn: + conn.execute("CREATE TABLE IF NOT EXISTS test (x INTEGER, y INTEGER)") + conn.execute(f"INSERT INTO test VALUES (123, 456)") + result = conn.execute("SELECT COUNT(*) FROM test").fetchone() + assert result[0] >= 1 diff --git a/tests/fast/threading/test_udf_threaded.py b/tests/fast/threading/test_udf_threaded.py new file mode 100644 index 00000000..7f84d763 --- /dev/null +++ b/tests/fast/threading/test_udf_threaded.py @@ -0,0 +1,87 @@ +""" +Test User Defined Function (UDF). +""" + +import concurrent.futures +import threading + +import pytest + +import duckdb + + +def test_concurrent_udf_registration(): + """Test UDF registration.""" + with duckdb.connect(":memory:") as conn: + + def my_add(x: int, y: int) -> int: + return x + y + + udf_name = "my_add_1" + conn.create_function(udf_name, my_add) + + result = conn.execute(f"SELECT {udf_name}(1, 2)").fetchone() + assert result[0] == 3 + + +def test_mixed_udf_operations(): + conn = duckdb.connect(":memory:") + try: + # Register and use UDF + def thread_func(x: int) -> int: + return x * 2 + + udf_name = "thread_func_1" + conn.create_function(udf_name, thread_func) + result1 = conn.execute(f"SELECT {udf_name}(5)").fetchone() + assert result1[0] == 10 + + # Simple query + result2 = conn.execute("SELECT 42").fetchone() + assert result2[0] == 42 + + # Create table and use built-in functions + conn.execute("CREATE TABLE test_table (x INTEGER)") + conn.execute("INSERT INTO test_table VALUES (1), (2), (3)") + result3 = conn.execute("SELECT COUNT(*) FROM test_table").fetchone() + assert result3[0] == 3 + finally: + conn.close() + + +@pytest.mark.parallel_threads(1) +def test_scalar_udf_concurrent(): + num_threads = 5 + conn = duckdb.connect(":memory:") + + # Create test data + conn.execute("CREATE TABLE numbers (x INTEGER)") + conn.execute("INSERT INTO numbers SELECT * FROM range(100)") + + # Create a simple scalar UDF instead of vectorized (simpler for testing) + def simple_square(x: int) -> int: + """Square a single value.""" + return x * x + + conn.create_function("simple_square", simple_square) + + def execute_scalar_udf(thread_id): + start = thread_id * 10 + end = start + 10 + query = ( + f"SELECT simple_square(x) FROM numbers WHERE x BETWEEN {start} AND {end}" + ) + with conn.cursor() as c: + assert c.execute(query).fetchone()[0] == (start**2) + + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(execute_scalar_udf, i) for i in range(num_threads)] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + conn.close() + + assert all(results)