Skip to content

Commit a53b431

Browse files
committed
Add support for custom SQL functions with DuckDB, closes #1047
1 parent 2c937db commit a53b431

File tree

5 files changed, with 46 additions and 8 deletions

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@
5555

5656
extras["console"] = ["rich>=12.0.1"]
5757

58-
extras["database"] = ["duckdb>=0.7.1", "pillow>=7.1.2", "sqlalchemy>=2.0.20"]
58+
extras["database"] = ["duckdb>=0.8.0", "pillow>=7.1.2", "sqlalchemy>=2.0.20"]
5959

6060
extras["graph"] = ["grand-cypher>=0.6.0", "grand-graph>=0.6.0", "networkx>=2.7.1", "sqlalchemy>=2.0.20"]
6161

src/python/txtai/database/base.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -263,11 +263,11 @@ def registerfunctions(self, config):
263263
if inputs:
264264
functions = []
265265
for fn in inputs:
266-
name, argcount = None, -1
266+
name, argcount, deterministic = None, -1, None
267267

268268
# Optional function configuration
269269
if isinstance(fn, dict):
270-
name, argcount, fn = fn.get("name"), fn.get("argcount", -1), fn["function"]
270+
name, argcount, fn, deterministic = (fn.get("name"), fn.get("argcount", -1), fn["function"], fn.get("deterministic"))
271271

272272
# Determine if this is a callable object or a function
273273
if not isinstance(fn, types.FunctionType) and hasattr(fn, "__call__"):
@@ -277,7 +277,7 @@ def registerfunctions(self, config):
277277
name = name if name else fn.__name__.lower()
278278

279279
# Store function details
280-
functions.append((name, argcount, fn))
280+
functions.append((name, argcount, fn, deterministic))
281281

282282
# pylint: disable=W0201
283283
self.functions = functions

src/python/txtai/database/duckdb.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import os
66
import re
77

8+
from typing import get_type_hints
89
from tempfile import TemporaryDirectory
910

1011
# Conditional import
@@ -80,8 +81,15 @@ def rows(self):
8081
rows = self.cursor.fetchmany(batch)
8182

8283
def addfunctions(self):
83-
# DuckDB doesn't currently support scalar functions
84-
return
84+
if self.connection and self.functions:
85+
for name, _, fn, deterministic in self.functions:
86+
# Get function type hints
87+
hints = get_type_hints(fn)
88+
89+
# Create database functions
90+
self.connection.create_function(
91+
name, fn, return_type=hints.get("return", str), side_effects=not deterministic if deterministic is not None else False
92+
)
8593

8694
def copy(self, path):
8795
# Delete existing file, if necessary

src/python/txtai/database/sqlite.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,9 @@ def addfunctions(self):
3434
# Enable callback tracebacks to show user-defined function errors
3535
sqlite3.enable_callback_tracebacks(True)
3636

37-
for name, argcount, fn in self.functions:
38-
self.connection.create_function(name, argcount, fn)
37+
# Create database functions
38+
for name, argcount, fn, deterministic in self.functions:
39+
self.connection.create_function(name, argcount, fn, deterministic=deterministic if deterministic is not None else False)
3940

4041
def copy(self, path):
4142
# Delete existing file, if necessary

test/python/testdatabase/testduckdb.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,3 +53,32 @@ def testArchive(self):
5353
"""
5454

5555
super().testArchive()
56+
57+
def testFunction(self):
58+
"""
59+
Test custom functions
60+
"""
61+
62+
embeddings = Embeddings(
63+
{
64+
"path": "sentence-transformers/nli-mpnet-base-v2",
65+
"content": self.backend,
66+
"functions": [{"name": "length", "function": "testdatabase.testduckdb.length"}],
67+
}
68+
)
69+
70+
# Create an index for the list of text
71+
embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
72+
73+
# Search for best match
74+
result = embeddings.search("select length(text) length from txtai where id = 0", 1)[0]
75+
76+
self.assertEqual(result["length"], 39)
77+
78+
79+
def length(text):
80+
"""
81+
Custom SQL function.
82+
"""
83+
84+
return len(text)

0 commit comments

Comments
 (0)