-
Notifications
You must be signed in to change notification settings - Fork 32
Python Table UDFs #99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
a9bb62d
013081f
262ccfd
8384dc3
e2f465e
88198b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,23 @@ | ||
from _duckdb.functional import ( | ||
FunctionNullHandling, | ||
PythonUDFType, | ||
PythonTVFType, | ||
SPECIAL, | ||
DEFAULT, | ||
NATIVE, | ||
ARROW | ||
ARROW, | ||
TUPLES, | ||
ARROW_TABLE | ||
) | ||
|
||
__all__ = [ | ||
"FunctionNullHandling", | ||
"PythonUDFType", | ||
"PythonTVFType", | ||
"SPECIAL", | ||
"DEFAULT", | ||
"NATIVE", | ||
"ARROW" | ||
"ARROW", | ||
"TUPLES", | ||
"ARROW_TABLE" | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#pragma once | ||
|
||
#include "duckdb/common/common.hpp" | ||
#include "duckdb/common/exception.hpp" | ||
#include "duckdb/common/string_util.hpp" | ||
|
||
using duckdb::InvalidInputException; | ||
using duckdb::string; | ||
using duckdb::StringUtil; | ||
|
||
namespace duckdb { | ||
|
||
enum class PythonTVFType : uint8_t { TUPLES, ARROW_TABLE }; | ||
|
||
} // namespace duckdb | ||
|
||
using duckdb::PythonTVFType; | ||
|
||
namespace py = pybind11; | ||
|
||
static PythonTVFType PythonTVFTypeFromString(const string &type) { | ||
auto ltype = StringUtil::Lower(type); | ||
if (ltype.empty() || ltype == "tuples") { | ||
return PythonTVFType::TUPLES; | ||
} else if (ltype == "arrow_table") { | ||
return PythonTVFType::ARROW_TABLE; | ||
} else { | ||
throw InvalidInputException("'%s' is not a recognized type for 'tvf_type'", type); | ||
} | ||
} | ||
|
||
static PythonTVFType PythonTVFTypeFromInteger(int64_t value) { | ||
if (value == 0) { | ||
return PythonTVFType::TUPLES; | ||
} else if (value == 1) { | ||
return PythonTVFType::ARROW_TABLE; | ||
} else { | ||
throw InvalidInputException("'%d' is not a recognized type for 'tvf_type'", value); | ||
} | ||
} | ||
|
||
namespace PYBIND11_NAMESPACE { | ||
namespace detail { | ||
|
||
template <> | ||
struct type_caster<PythonTVFType> : public type_caster_base<PythonTVFType> { | ||
using base = type_caster_base<PythonTVFType>; | ||
PythonTVFType tmp; | ||
|
||
public: | ||
bool load(handle src, bool convert) { | ||
if (base::load(src, convert)) { | ||
return true; | ||
} else if (py::isinstance<py::str>(src)) { | ||
tmp = PythonTVFTypeFromString(py::str(src)); | ||
value = &tmp; | ||
return true; | ||
} else if (py::isinstance<py::int_>(src)) { | ||
tmp = PythonTVFTypeFromInteger(src.cast<int64_t>()); | ||
value = &tmp; | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
static handle cast(PythonTVFType src, return_value_policy policy, handle parent) { | ||
return base::cast(src, policy, parent); | ||
} | ||
}; | ||
|
||
} // namespace detail | ||
} // namespace PYBIND11_NAMESPACE |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -393,7 +393,6 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud | |
auto scalar_function = CreateScalarUDF(name, udf, parameters_p, return_type_p, type == PythonUDFType::ARROW, | ||
null_handling, exception_handling, side_effects); | ||
CreateScalarFunctionInfo info(scalar_function); | ||
|
||
context.RegisterFunction(info); | ||
|
||
auto dependency = make_uniq<ExternalDependency>(); | ||
|
@@ -403,6 +402,57 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud | |
return shared_from_this(); | ||
} | ||
|
||
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::RegisterTableFunction(const string &name, | ||
const py::function &function, | ||
const py::object ¶meters, | ||
const py::object &schema, | ||
PythonTVFType type) { | ||
|
||
auto &connection = con.GetConnection(); | ||
auto &context = *connection.context; | ||
|
||
if (context.transaction.HasActiveTransaction()) { | ||
context.CancelTransaction(); | ||
} | ||
|
||
if (registered_table_functions.find(name) != registered_table_functions.end()) { | ||
throw NotImplementedException("A table function by the name of '%s' is already registered, " | ||
"please unregister it first", | ||
name); | ||
} | ||
|
||
auto table_function = CreateTableFunctionFromCallable(name, function, parameters, schema, type); | ||
CreateTableFunctionInfo info(table_function); | ||
|
||
// re-registration: changing the callable to another | ||
info.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; | ||
|
||
context.RegisterFunction(info); | ||
|
||
auto dependency = make_uniq<ExternalDependency>(); | ||
dependency->AddDependency("function", PythonDependencyItem::Create(function)); | ||
registered_table_functions[name] = std::move(dependency); | ||
|
||
return shared_from_this(); | ||
} | ||
|
||
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::UnregisterTableFunction(const string &name) { | ||
auto entry = registered_table_functions.find(name); | ||
if (entry == registered_table_functions.end()) { | ||
throw InvalidInputException( | ||
"No table function by the name of '%s' was found in the list of registered table functions", name); | ||
} | ||
|
||
auto &connection = con.GetConnection(); | ||
auto &context = *connection.context; | ||
|
||
// Remove from our registry. | ||
// TODO: Callable still exists in the function catalog, since duckdb doesn't (yet?) support removal | ||
registered_table_functions.erase(entry); | ||
|
||
return shared_from_this(); | ||
} | ||
|
||
void DuckDBPyConnection::Initialize(py::handle &m) { | ||
auto connection_module = | ||
py::class_<DuckDBPyConnection, shared_ptr<DuckDBPyConnection>>(m, "DuckDBPyConnection", py::module_local()); | ||
|
@@ -411,6 +461,14 @@ void DuckDBPyConnection::Initialize(py::handle &m) { | |
.def("__exit__", &DuckDBPyConnection::Exit, py::arg("exc_type"), py::arg("exc"), py::arg("traceback")); | ||
connection_module.def("__del__", &DuckDBPyConnection::Close); | ||
|
||
connection_module.def("create_table_function", &DuckDBPyConnection::RegisterTableFunction, | ||
"Register a table valued function via Callable", py::arg("name"), py::arg("callable"), | ||
py::arg("parameters") = py::none(), py::arg("schema") = py::none(), | ||
|
||
py::arg("type") = PythonTVFType::TUPLES); | ||
|
||
connection_module.def("unregister_table_function", &DuckDBPyConnection::UnregisterTableFunction, | ||
"Unregister a table valued function", py::arg("name")); | ||
|
||
InitializeConnectionMethods(connection_module); | ||
connection_module.def_property_readonly("description", &DuckDBPyConnection::GetDescription, | ||
"Get result set attributes, mainly column names"); | ||
|
@@ -1575,7 +1633,12 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::RunQuery(const py::object &quer | |
} | ||
if (res->type == QueryResultType::STREAM_RESULT) { | ||
auto &stream_result = res->Cast<StreamQueryResult>(); | ||
res = stream_result.Materialize(); | ||
{ | ||
// Release the GIL, as Materialize *may* need the GIL (TVFs, for instance) | ||
D_ASSERT(py::gil_check()); | ||
py::gil_scoped_release release; | ||
res = stream_result.Materialize(); | ||
} | ||
} | ||
auto &materialized_result = res->Cast<MaterializedQueryResult>(); | ||
relation = make_shared_ptr<MaterializedRelation>(connection.context, materialized_result.TakeCollection(), | ||
|
@@ -1826,6 +1889,7 @@ void DuckDBPyConnection::Close() { | |
// https://peps.python.org/pep-0249/#Connection.close | ||
cursors.ClearCursors(); | ||
registered_functions.clear(); | ||
registered_table_functions.clear(); | ||
} | ||
|
||
void DuckDBPyConnection::Interrupt() { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python_table_udf.cpp has my preference
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For types like PythonTVFType, would you prefer PythonTUDF, or PythonUDTF, or PythonTableUDF?
For what it's worth, SnowFlake and DataBrix have gone with UDTF: https://docs.snowflake.com/en/developer-guide/udf/sql/udf-sql-tabular-functions and https://docs.databricks.com/aws/en/udf/python-udtf
I don't like TUDF as an abbreviation, but UDTF or TableUDF both sound good to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My preference is table udf
So a search for "udf" finds both versions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done: renamed files to "table_udf" and in code to TableUDF