Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion _duckdb-stubs/_func.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import typing as pytyping

__all__: list[str] = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"]
__all__: list[str] = [
"ARROW",
"DEFAULT",
"NATIVE",
"SPECIAL",
"FunctionNullHandling",
"PythonTableUDFType",
"PythonUDFType",
]

class FunctionNullHandling:
DEFAULT: pytyping.ClassVar[FunctionNullHandling] # value = <FunctionNullHandling.DEFAULT: 0>
Expand All @@ -21,6 +29,25 @@ class FunctionNullHandling:
@property
def value(self) -> int: ...

class PythonTableUDFType:
ARROW_TABLE: pytyping.ClassVar[PythonTableUDFType] # value = <PythonTableUDFType.ARROW_TABLE: 1>
TUPLES: pytyping.ClassVar[PythonTableUDFType] # value = <PythonTableUDFType.TUPLES: 0>
__members__: pytyping.ClassVar[
dict[str, PythonTableUDFType]
] # value = {'TUPLES': <PythonTableUDFType.TUPLES: 0>, 'ARROW_TABLE': <PythonTableUDFType.ARROW_TABLE: 1>}
def __eq__(self, other: object) -> bool: ...
def __getstate__(self) -> int: ...
def __hash__(self) -> int: ...
def __index__(self) -> int: ...
def __init__(self, value: pytyping.SupportsInt) -> None: ...
def __int__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __setstate__(self, state: pytyping.SupportsInt) -> None: ...
@property
def name(self) -> str: ...
@property
def value(self) -> int: ...

class PythonUDFType:
ARROW: pytyping.ClassVar[PythonUDFType] # value = <PythonUDFType.ARROW: 1>
NATIVE: pytyping.ClassVar[PythonUDFType] # value = <PythonUDFType.NATIVE: 0>
Expand Down
12 changes: 10 additions & 2 deletions duckdb/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from _duckdb._func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType # noqa: D104
from _duckdb._func import ( # noqa: D104
ARROW,
DEFAULT,
NATIVE,
SPECIAL,
FunctionNullHandling,
PythonTableUDFType,
PythonUDFType,
)

__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"]
__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTableUDFType", "PythonUDFType"]
4 changes: 2 additions & 2 deletions duckdb/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import warnings

from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType
from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonTableUDFType, PythonUDFType

__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"]
__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTableUDFType", "PythonUDFType"]

warnings.warn(
"`duckdb.functional` is deprecated and will be removed in a future version. Please use `duckdb.func` instead.",
Expand Down
48 changes: 46 additions & 2 deletions scripts/connection_methods.json
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,51 @@
],
"return": "DuckDBPyConnection"
},
{
"name": "create_table_function",
"function": "RegisterTableFunction",
"docs": "Register a table valued function via Callable",
"args": [
{
"name": "name",
"type": "str"
},
{
"name": "callable",
"type": "Callable"
}
],
"kwargs": [
{
"name": "parameters",
"type": "Optional[Any]",
"default": "None"
},
{
"name": "schema",
"type": "Optional[Any]",
"default": "None"
},
{
"name": "type",
"type": "Optional[PythonTableUDFType]",
"default": "PythonTableUDFType.TUPLES"
}
],
"return": "DuckDBPyConnection"
},
{
"name": "unregister_table_function",
"function": "UnregisterTableFunction",
"docs": "Unregister a table valued function",
"args": [
{
"name": "name",
"type": "str"
}
],
"return": "DuckDBPyConnection"
},
{
"name": [
"sqltype",
Expand Down Expand Up @@ -412,7 +457,6 @@
"fetch_record_batch",
"arrow"
],

"function": "FetchRecordBatchReader",
"docs": "Fetch an Arrow RecordBatchReader following execute()",
"args": [
Expand Down Expand Up @@ -1094,4 +1138,4 @@
],
"return": "None"
}
]
]
1 change: 1 addition & 0 deletions src/duckdb_py/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_library(
python_dependency.cpp
python_import_cache.cpp
python_replacement_scan.cpp
python_table_udf.cpp
python_udf.cpp)

target_link_libraries(python_src PRIVATE _duckdb_dependencies)
5 changes: 5 additions & 0 deletions src/duckdb_py/functional/functional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ void DuckDBPyFunctional::Initialize(py::module_ &parent) {
.value("ARROW", duckdb::PythonUDFType::ARROW)
.export_values();

py::enum_<duckdb::PythonTableUDFType>(m, "PythonTableUDFType")
.value("TUPLES", duckdb::PythonTableUDFType::TUPLES)
.value("ARROW_TABLE", duckdb::PythonTableUDFType::ARROW_TABLE)
.export_values();

py::enum_<duckdb::FunctionNullHandling>(m, "FunctionNullHandling")
.value("DEFAULT", duckdb::FunctionNullHandling::DEFAULT_NULL_HANDLING)
.value("SPECIAL", duckdb::FunctionNullHandling::SPECIAL_HANDLING)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

#include "duckdb/common/common.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/string_util.hpp"

using duckdb::InvalidInputException;
using duckdb::string;
using duckdb::StringUtil;

namespace duckdb {

enum class PythonTableUDFType : uint8_t { TUPLES, ARROW_TABLE };

} // namespace duckdb

using duckdb::PythonTableUDFType;

namespace py = pybind11;

static PythonTableUDFType PythonTableUDFTypeFromString(const string &type) {
auto ltype = StringUtil::Lower(type);
if (ltype.empty() || ltype == "tuples") {
return PythonTableUDFType::TUPLES;
} else if (ltype == "arrow_table") {
return PythonTableUDFType::ARROW_TABLE;
} else {
throw InvalidInputException("'%s' is not a recognized type for 'tvf_type'", type);
}
}

static PythonTableUDFType PythonTableUDFTypeFromInteger(int64_t value) {
if (value == 0) {
return PythonTableUDFType::TUPLES;
} else if (value == 1) {
return PythonTableUDFType::ARROW_TABLE;
} else {
throw InvalidInputException("'%d' is not a recognized type for 'tvf_type'", value);
}
}

namespace PYBIND11_NAMESPACE {
namespace detail {

template <>
struct type_caster<PythonTableUDFType> : public type_caster_base<PythonTableUDFType> {
using base = type_caster_base<PythonTableUDFType>;
PythonTableUDFType tmp;

public:
bool load(handle src, bool convert) {
if (base::load(src, convert)) {
return true;
} else if (py::isinstance<py::str>(src)) {
tmp = PythonTableUDFTypeFromString(py::str(src));
value = &tmp;
return true;
} else if (py::isinstance<py::int_>(src)) {
tmp = PythonTableUDFTypeFromInteger(src.cast<int64_t>());
value = &tmp;
return true;
}
return false;
}

static handle cast(PythonTableUDFType src, return_value_policy policy, handle parent) {
return base::cast(src, policy, parent);
}
};

} // namespace detail
} // namespace PYBIND11_NAMESPACE
15 changes: 15 additions & 0 deletions src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "duckdb/function/scalar_function.hpp"
#include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp"
#include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp"
#include "duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp"
#include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp"
#include "duckdb/common/shared_ptr.hpp"

Expand Down Expand Up @@ -169,6 +170,8 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
//! MemoryFileSystem used to temporarily store file-like objects for reading
shared_ptr<ModifiedMemoryFileSystem> internal_object_filesystem;
case_insensitive_map_t<unique_ptr<ExternalDependency>> registered_functions;
case_insensitive_map_t<unique_ptr<ExternalDependency>> registered_table_functions;

case_insensitive_set_t registered_objects;

public:
Expand Down Expand Up @@ -232,6 +235,13 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR,
bool side_effects = false);

shared_ptr<DuckDBPyConnection> RegisterTableFunction(const string &name, const py::function &function,
const py::object &schema,
PythonTableUDFType type = PythonTableUDFType::TUPLES,
const py::object &parameters = py::none());

shared_ptr<DuckDBPyConnection> UnregisterTableFunction(const string &name);

shared_ptr<DuckDBPyConnection> UnregisterUDF(const string &name);

shared_ptr<DuckDBPyConnection> ExecuteMany(const py::object &query, py::object params = py::list());
Expand Down Expand Up @@ -355,6 +365,11 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
const shared_ptr<DuckDBPyType> &return_type, bool vectorized,
FunctionNullHandling null_handling, PythonExceptionHandling exception_handling,
bool side_effects);

duckdb::TableFunction CreateTableFunctionFromCallable(const std::string &name, const py::function &callable,
const py::object &parameters, const py::object &schema,
PythonTableUDFType type);

void RegisterArrowObject(const py::object &arrow_object, const string &name);
vector<unique_ptr<SQLStatement>> GetStatements(const py::object &query);

Expand Down
66 changes: 64 additions & 2 deletions src/duckdb_py/pyconnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,6 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud
auto scalar_function = CreateScalarUDF(name, udf, parameters_p, return_type_p, type == PythonUDFType::ARROW,
null_handling, exception_handling, side_effects);
CreateScalarFunctionInfo info(scalar_function);

context.RegisterFunction(info);

auto dependency = make_uniq<ExternalDependency>();
Expand All @@ -403,6 +402,55 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud
return shared_from_this();
}

shared_ptr<DuckDBPyConnection>
DuckDBPyConnection::RegisterTableFunction(const string &name, const py::function &function, const py::object &schema,
PythonTableUDFType type, const py::object &parameters) {

auto &connection = con.GetConnection();
auto &context = *connection.context;

if (context.transaction.HasActiveTransaction()) {
context.CancelTransaction();
}

if (registered_table_functions.find(name) != registered_table_functions.end()) {
throw NotImplementedException("A table function by the name of '%s' is already registered, "
"please unregister it first",
name);
}

auto table_function = CreateTableFunctionFromCallable(name, function, parameters, schema, type);
CreateTableFunctionInfo info(table_function);

// re-registration: changing the callable to another
info.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT;

context.RegisterFunction(info);

auto dependency = make_uniq<ExternalDependency>();
dependency->AddDependency("function", PythonDependencyItem::Create(function));
registered_table_functions[name] = std::move(dependency);

return shared_from_this();
}

shared_ptr<DuckDBPyConnection> DuckDBPyConnection::UnregisterTableFunction(const string &name) {
auto entry = registered_table_functions.find(name);
if (entry == registered_table_functions.end()) {
throw InvalidInputException(
"No table function by the name of '%s' was found in the list of registered table functions", name);
}

auto &connection = con.GetConnection();
auto &context = *connection.context;

// Remove from our registry.
// TODO: Callable still exists in the function catalog, since duckdb doesn't (yet?) support removal
registered_table_functions.erase(entry);

return shared_from_this();
}

void DuckDBPyConnection::Initialize(py::handle &m) {
auto connection_module =
py::class_<DuckDBPyConnection, shared_ptr<DuckDBPyConnection>>(m, "DuckDBPyConnection", py::module_local());
Expand All @@ -411,6 +459,14 @@ void DuckDBPyConnection::Initialize(py::handle &m) {
.def("__exit__", &DuckDBPyConnection::Exit, py::arg("exc_type"), py::arg("exc"), py::arg("traceback"));
connection_module.def("__del__", &DuckDBPyConnection::Close);

connection_module.def("create_table_function", &DuckDBPyConnection::RegisterTableFunction,
"Register a table user defined function via Callable", py::arg("name"), py::arg("callable"),
py::arg("schema"), py::kw_only(), py::arg("type") = PythonTableUDFType::TUPLES,
py::arg("parameters") = py::none());

connection_module.def("unregister_table_function", &DuckDBPyConnection::UnregisterTableFunction,
"Unregister a table user defined function", py::arg("name"));

InitializeConnectionMethods(connection_module);
connection_module.def_property_readonly("description", &DuckDBPyConnection::GetDescription,
"Get result set attributes, mainly column names");
Expand Down Expand Up @@ -1575,7 +1631,12 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::RunQuery(const py::object &quer
}
if (res->type == QueryResultType::STREAM_RESULT) {
auto &stream_result = res->Cast<StreamQueryResult>();
res = stream_result.Materialize();
{
// Release the GIL, as Materialize *may* need the GIL (TVFs, for instance)
D_ASSERT(py::gil_check());
py::gil_scoped_release release;
res = stream_result.Materialize();
}
}
auto &materialized_result = res->Cast<MaterializedQueryResult>();
relation = make_shared_ptr<MaterializedRelation>(connection.context, materialized_result.TakeCollection(),
Expand Down Expand Up @@ -1826,6 +1887,7 @@ void DuckDBPyConnection::Close() {
// https://peps.python.org/pep-0249/#Connection.close
cursors.ClearCursors();
registered_functions.clear();
registered_table_functions.clear();
}

void DuckDBPyConnection::Interrupt() {
Expand Down
Loading
Loading