Skip to content

Commit 157d86a

Browse files
committed
Add deterministic setting to database functions closes #1048. Add ability to index database expressions #1049
1 parent 696fb5c commit 157d86a

File tree

8 files changed

+87
-20
lines changed

8 files changed

+87
-20
lines changed

docs/embeddings/configuration/general.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,13 @@ Sets the auto id generation method. When this is not set, an autogenerated numer
6666
columns:
6767
text: name of the text column
6868
object: name of the object column
69+
store: limit json data fields to this list of columns
6970
```
7071

7172
Sets the `text` and `object` column names. Defaults to `text` and `object` if not provided.
7273

74+
`store` sets a list of columns to store in the JSON data field. When this isn't provided, all columns are stored (default). When `store` is set to `None`, no JSON columns are stored. This is useful if a field is only needed at indexing time but not at search time.
75+
7376
## format
7477
```yaml
7578
format: json|pickle

docs/embeddings/indexing.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ As mentioned above, computed vectors are stored in an ANN. There are various ind
2323

2424
Embeddings indexes can optionally [store content](../configuration/database#content). When this is enabled, the input content is saved in a database alongside the computed vectors. This enables filtering on additional fields and content retrieval.
2525

26+
The columns used for text, object and JSON data storage are set via [column configuration](../configuration/general#columns).
27+
2628
## Index vs Upsert
2729

2830
Data is loaded into an index with either an [index](../methods#txtai.embeddings.base.Embeddings.index) or [upsert](../methods#txtai.embeddings.base.Embeddings.upsert) call.

src/python/txtai/database/duckdb.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,7 @@ def rows(self):
8181
rows = self.cursor.fetchmany(batch)
8282

8383
def addfunctions(self):
84-
if self.connection and self.functions:
85-
for name, _, fn, deterministic in self.functions:
86-
# Get function type hints
87-
hints = get_type_hints(fn)
88-
89-
# Create database functions
90-
self.connection.create_function(
91-
name, fn, return_type=hints.get("return", str), side_effects=not deterministic if deterministic is not None else False
92-
)
84+
self.loadfunctions(self.connection)
9385

9486
def copy(self, path):
9587
# Delete existing file, if necessary
@@ -116,8 +108,14 @@ def copy(self, path):
116108
for table in tables:
117109
connection.execute(f"COPY {table} FROM '{directory}/{table}.parquet' (FORMAT parquet)")
118110

119-
# Create indexes and sync data to database file
120-
connection.execute(Statement.CREATE_SECTIONS_INDEX)
111+
# Copy functions
112+
self.loadfunctions(connection)
113+
114+
# Copy indexes
115+
for (sql,) in self.connection.execute("SELECT sql FROM duckdb_indexes()").fetchall():
116+
connection.execute(sql)
117+
118+
# Sync data to database file
121119
connection.execute("CHECKPOINT")
122120

123121
# Start transaction
@@ -156,3 +154,24 @@ def formatargs(self, args):
156154
args = (query, [value for _, value in sorted(params, key=lambda x: x[0])])
157155

158156
return args
157+
158+
def loadfunctions(self, connection):
159+
"""
160+
Load database functions.
161+
162+
Args:
163+
connection: connection to create functions
164+
"""
165+
166+
if self.functions and connection:
167+
for name, _, fn, deterministic in self.functions:
168+
# Create function if it doesn't already exist
169+
result = connection.execute("SELECT 1 FROM duckdb_functions() WHERE function_name = ?", [name]).fetchone()
170+
if not result:
171+
# Get function type hints
172+
hints = get_type_hints(fn)
173+
174+
# Create database functions
175+
connection.create_function(
176+
name, fn, return_type=hints.get("return", str), side_effects=not deterministic if deterministic is not None else False
177+
)

src/python/txtai/database/rdbms.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def resolve(self, name, alias=None):
145145

146146
# Resolve expression
147147
if self.expressions and name in self.expressions:
148-
return self.expressions[name]
148+
return self.expressions[name]["expression"]
149149

150150
# Name is already resolved, skip
151151
if name.startswith(self.jsonprefix()) or any(f"s.{s}" == name for s in sections):
@@ -251,6 +251,9 @@ def initialize(self):
251251
# Create initial table schema
252252
self.createtables()
253253

254+
# Create indexes
255+
self.createindexes()
256+
254257
def session(self, path=None, connection=None):
255258
"""
256259
Starts a new database session.
@@ -281,6 +284,23 @@ def createtables(self):
281284
self.cursor.execute(Statement.CREATE_SECTIONS % "sections")
282285
self.cursor.execute(Statement.CREATE_SECTIONS_INDEX)
283286

287+
def createindexes(self):
288+
"""
289+
Creates expression indexes
290+
"""
291+
292+
if self.expressions:
293+
for key, values in self.expressions.items():
294+
# Create index for expression, if enabled
295+
if values["index"]:
296+
# Get parameters
297+
name = f"expression_{key}".lower()
298+
expression = values["expression"]
299+
table = "documents" if expression.startswith(self.jsonprefix()) else "sections"
300+
301+
# Execute statement
302+
self.cursor.execute(Statement.CREATE_EXPRESSION_INDEX % (name, table, expression))
303+
284304
def finalize(self):
285305
"""
286306
Post processing logic run after inserting a batch of documents. Default method is no-op.
@@ -306,9 +326,13 @@ def loaddocument(self, uid, document, tags, entry):
306326
# Get and remove object field from document
307327
obj = document.pop(self.object) if self.object in document else None
308328

309-
# Insert document as JSON
310329
if document:
311-
self.insertdocument(uid, json.dumps(document, allow_nan=False), tags, entry)
330+
# Apply data filters, if necessary
331+
data = {key: value for key, value in document.items() if key in self.store} if self.store is not None else document
332+
333+
# Insert document as JSON
334+
if data:
335+
self.insertdocument(uid, json.dumps(data, allow_nan=False), tags, entry)
312336

313337
# If text and object are both available, load object as it won't otherwise be used
314338
if self.text in document and obj:

src/python/txtai/database/schema/statement.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,6 @@ class Statement:
9696
+ "LEFT JOIN scores sc ON s.indexid = sc.indexid"
9797
)
9898
IDS_CLAUSE = "s.indexid in (SELECT indexid from batch WHERE batch=%s)"
99+
100+
# Expression indexes
101+
CREATE_EXPRESSION_INDEX = "CREATE INDEX IF NOT EXISTS %s ON %s(%s)"

test/python/testdatabase/testduckdb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,17 @@ def testFunction(self):
6363
{
6464
"path": "sentence-transformers/nli-mpnet-base-v2",
6565
"content": self.backend,
66-
"functions": [{"name": "length", "function": "testdatabase.testduckdb.length"}],
66+
"functions": [{"name": "textlength", "function": "testdatabase.testduckdb.length"}],
6767
}
6868
)
6969

7070
# Create an index for the list of text
7171
embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
7272

7373
# Search for best match
74-
result = embeddings.search("select length(text) length from txtai where id = 0", 1)[0]
74+
result = embeddings.search("select textlength(text) length from txtai where id = 0", 1)[0]
7575

76-
self.assertEqual(result["length"], 39)
76+
self.assertEqual(int(result["length"]), 39)
7777

7878

7979
def length(text):

test/python/testdatabase/testrdbms.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,22 @@ def testExplainEmpty(self):
247247

248248
self.assertEqual(self.embeddings.explain("select * from txtai limit 1")[0]["id"], "0")
249249

250+
def testExpressions(self):
251+
"""
252+
Test expressions
253+
"""
254+
255+
# Test indexed expressions
256+
embeddings = Embeddings(
257+
path="sentence-transformers/nli-mpnet-base-v2",
258+
content=self.backend,
259+
expressions=[{"name": "textlength", "expression": "length(text)", "index": True}],
260+
)
261+
embeddings.index(self.data)
262+
263+
result = embeddings.search("SELECT textlength FROM txtai WHERE id = 0", 1)[0]
264+
self.assertEqual(result["textlength"], len(self.data[0]))
265+
250266
def testGenerator(self):
251267
"""
252268
Test index with a generator

test/python/testdatabase/testsqlite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ def testFunction(self):
5252
{
5353
"path": "sentence-transformers/nli-mpnet-base-v2",
5454
"content": self.backend,
55-
"functions": [{"name": "length", "function": "testdatabase.testsqlite.length"}],
55+
"functions": [{"name": "textlength", "function": "testdatabase.testsqlite.length"}],
5656
}
5757
)
5858

5959
# Create an index for the list of text
6060
embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
6161

6262
# Search for best match
63-
result = embeddings.search("select length(text) length from txtai where id = 0", 1)[0]
63+
result = embeddings.search("select textlength(text) length from txtai where id = 0", 1)[0]
6464

65-
self.assertEqual(result["length"], 39)
65+
self.assertEqual(int(result["length"]), 39)
6666

6767

6868
def length(text):

0 commit comments

Comments
 (0)