Skip to content

Commit d935c17

Browse files
authored
feat: add NumPy array support for vector representations (#586)
* chore: update example code to use NDArray for vectors * feat: add numpy array support for vector representations * feat: refactor vector decoding with DtypeRegistry for NumPy * test: add unit tests for numpy array support * Update docs to reflect NDArray support with NumPy dtypes * feat: add NDArray support for `Vector[T]` with `list[T]` fallback, optimize pgvector queries * feat: update engine value encoding to return ndarray directly
1 parent 48a0331 commit d935c17

File tree

11 files changed

+908
-39
lines changed

11 files changed

+908
-39
lines changed

docs/docs/core/data_types.mdx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,17 @@ This is the list of all basic types supported by CocoIndex:
3636
| LocalDatetime | Date and time without timezone | `cocoindex.LocalDateTime` | `datetime.datetime` |
3737
| OffsetDatetime | Date and time with a timezone offset | `cocoindex.OffsetDateTime` | `datetime.datetime` |
3838
| TimeDelta | A duration of time | `datetime.timedelta` | `datetime.timedelta` |
39-
| Vector[*T*, *Dim*?] | *T* must be basic type. *Dim* is a positive integer and optional. |`cocoindex.Vector[T]` or `cocoindex.Vector[T, Dim]` | `list[T]` |
4039
| Json | | `cocoindex.Json` | Any data convertible to JSON by `json` package |
40+
| Vector[*T*, *Dim*?] | *T* can be a basic type or a numeric type. *Dim* is a positive integer and optional. | `cocoindex.Vector[T]` or `cocoindex.Vector[T, Dim]` | `numpy.typing.NDArray[T]` or `list[T]` |
4141

4242
Values of all data types can be represented by values in Python's native types (as described under the Native Python Type column).
4343
However, the underlying execution engine and some storage system (like Postgres) has finer distinctions for some types, specifically:
4444

4545
* *Float32* and *Float64* for `float`, with different precision.
4646
* *LocalDateTime* and *OffsetDateTime* for `datetime.datetime`, with different timezone awareness.
47-
* *Vector* has optional dimension information.
4847
* *Range* and *Json* provide a clear tag for the type, to clearly distinguish the type in CocoIndex.
48+
* *Vector* holds elements of type *T*. If *T* is numeric (e.g., `np.float32` or `np.float64`), it's represented as `NDArray[T]`; otherwise, as `list[T]`.
49+
* *Vector* also has optional dimension information.
4950

5051
The native Python type is always more permissive and can represent a superset of possible values.
5152
* Only when you annotate the return type of a custom function, you should use the specific type,

docs/docs/getting_started/quickstart.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,11 @@ The goal of transforming your data is usually to query against it.
154154
Once you already have your index built, you can directly access the transformed data in the target database.
155155
CocoIndex also provides utilities for you to do this more seamlessly.
156156
157-
In this example, we'll use the [`psycopg` library](https://www.psycopg.org/) to connect to the database and run queries.
158-
Please make sure it's installed:
157+
In this example, we'll use the [`psycopg` library](https://www.psycopg.org/) along with pgvector to connect to the database and run queries on vector data.
158+
Please make sure the required packages are installed:
159159

160160
```bash
161-
pip install psycopg[binary,pool]
161+
pip install numpy psycopg[binary,pool] pgvector
162162
```
163163

164164
### Step 4.1: Extract common transformations
@@ -169,8 +169,11 @@ i.e. they should use exactly the same embedding model and parameters.
169169
Let's extract that into a function:
170170
171171
```python title="quickstart.py"
172+
from numpy.typing import NDArray
173+
import numpy as np
174+
172175
@cocoindex.transform_flow()
173-
def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:
176+
def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[NDArray[np.float32]]:
174177
return text.transform(
175178
cocoindex.functions.SentenceTransformerEmbed(
176179
model="sentence-transformers/all-MiniLM-L6-v2"))
@@ -207,6 +210,7 @@ Now we can create a function to query the index upon a given input query:
207210
208211
```python title="quickstart.py"
209212
from psycopg_pool import ConnectionPool
213+
from pgvector.psycopg import register_vector
210214
211215
def search(pool: ConnectionPool, query: str, top_k: int = 5):
212216
# Get the table name, for the export target in the text_embedding_flow above.
@@ -215,9 +219,10 @@ def search(pool: ConnectionPool, query: str, top_k: int = 5):
215219
query_vector = text_to_embedding.eval(query)
216220
# Run the query and get the results.
217221
with pool.connection() as conn:
222+
register_vector(conn)
218223
with conn.cursor() as cur:
219224
cur.execute(f"""
220-
SELECT filename, text, embedding <=> %s::vector AS distance
225+
SELECT filename, text, embedding <=> %s AS distance
221226
FROM {table_name} ORDER BY distance LIMIT %s
222227
""", (query_vector, top_k))
223228
return [
@@ -236,7 +241,7 @@ There're two CocoIndex-specific logic:
236241
237242
2. Evaluate the transform flow defined above with the input query, to get the embedding.
238243
It's done by the `eval()` method of the transform flow `text_to_embedding`.
239-
The return type of this method is `list[float]` as declared in the `text_to_embedding()` function (`cocoindex.DataSlice[list[float]]`).
244+
The return type of this method is `NDArray[np.float32]` as declared in the `text_to_embedding()` function (`cocoindex.DataSlice[NDArray[np.float32]]`).
240245
241246
### Step 4.3: Add the main script logic
242247

docs/docs/query.mdx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ The [quickstart](getting_started/quickstart#step-41-extract-common-transformatio
4141

4242
```python
4343
@cocoindex.transform_flow()
44-
def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:
44+
def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[NDArray[np.float32]]:
4545
return text.transform(
4646
cocoindex.functions.SentenceTransformerEmbed(
4747
model="sentence-transformers/all-MiniLM-L6-v2"))
@@ -61,7 +61,7 @@ with doc["chunks"].row() as chunk:
6161
chunk["embedding"] = chunk["text"].call(text_to_embedding)
6262
```
6363

64-
Any time, you can call the `eval()` method with specific string, which will return a `list[float]`:
64+
Any time, you can call the `eval()` method with specific string, which will return a `NDArray[np.float32]`:
6565

6666
```python
6767
print(text_to_embedding.eval("Hello, world!"))
@@ -93,7 +93,7 @@ For example:
9393

9494
```python
9595
table_name = cocoindex.utils.get_target_storage_default_name(text_embedding_flow, "doc_embeddings")
96-
query = f"SELECT filename, text FROM {table_name} ORDER BY embedding <=> %s::vector DESC LIMIT 5"
96+
query = f"SELECT filename, text FROM {table_name} ORDER BY embedding <=> %s DESC LIMIT 5"
9797
...
9898
```
9999

examples/text_embedding/Text_Embedding.ipynb

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
},
4646
"outputs": [],
4747
"source": [
48-
"%pip install cocoindex python-dotenv psycopg[binary,pool]"
48+
"%pip install cocoindex numpy python-dotenv psycopg[binary,pool] pgvector"
4949
]
5050
},
5151
{
@@ -164,7 +164,10 @@
164164
"from dotenv import load_dotenv\n",
165165
"import os\n",
166166
"from psycopg_pool import ConnectionPool\n",
167-
"import cocoindex\n"
167+
"from pgvector.psycopg import register_vector\n",
168+
"import cocoindex\n",
169+
"from numpy.typing import NDArray\n",
170+
"import numpy as np\n"
168171
]
169172
},
170173
{
@@ -187,7 +190,7 @@
187190
"%%writefile -a main.py\n",
188191
"\n",
189192
"@cocoindex.transform_flow()\n",
190-
"def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:\n",
193+
"def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[NDArray[np.float32]]:\n",
191194
" \"\"\"\n",
192195
" Embed the text using a SentenceTransformer model.\n",
193196
" This is shared logic between indexing and querying.\n",
@@ -274,9 +277,10 @@
274277
" query_vector = text_to_embedding.eval(query)\n",
275278
" # Run the query and get the results.\n",
276279
" with pool.connection() as conn:\n",
280+
" register_vector(conn)\n",
277281
" with conn.cursor() as cur:\n",
278282
" cur.execute(f\"\"\"\n",
279-
" SELECT filename, text, embedding <=> %s::vector AS distance\n",
283+
" SELECT filename, text, embedding <=> %s AS distance\n",
280284
" FROM {table_name} ORDER BY distance LIMIT %s\n",
281285
" \"\"\", (query_vector, top_k))\n",
282286
" return [\n",

examples/text_embedding/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from dotenv import load_dotenv
22
from psycopg_pool import ConnectionPool
3+
from pgvector.psycopg import register_vector
34
import cocoindex
45
import os
6+
from numpy.typing import NDArray
7+
import numpy as np
58

69

710
@cocoindex.transform_flow()
811
def text_to_embedding(
912
text: cocoindex.DataSlice[str],
10-
) -> cocoindex.DataSlice[list[float]]:
13+
) -> cocoindex.DataSlice[NDArray[np.float32]]:
1114
"""
1215
Embed the text using a SentenceTransformer model.
1316
This is a shared logic between indexing and querying, so extract it as a function.
@@ -71,10 +74,11 @@ def search(pool: ConnectionPool, query: str, top_k: int = 5):
7174
query_vector = text_to_embedding.eval(query)
7275
# Run the query and get the results.
7376
with pool.connection() as conn:
77+
register_vector(conn)
7478
with conn.cursor() as cur:
7579
cur.execute(
7680
f"""
77-
SELECT filename, text, embedding <=> %s::vector AS distance
81+
SELECT filename, text, embedding <=> %s AS distance
7882
FROM {table_name} ORDER BY distance LIMIT %s
7983
""",
8084
(query_vector, top_k),

examples/text_embedding/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ requires-python = ">=3.10"
66
dependencies = [
77
"cocoindex>=0.1.42",
88
"python-dotenv>=1.0.1",
9+
"pgvector>=0.4.1",
910
"psycopg[binary,pool]",
11+
"numpy",
1012
]
1113

1214
[tool.setuptools]

python/cocoindex/convert.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import datetime
77
import inspect
88
import uuid
9+
import numpy as np
910

1011
from enum import Enum
1112
from typing import Any, Callable, get_origin, Mapping
@@ -15,6 +16,7 @@
1516
is_namedtuple_type,
1617
TABLE_TYPES,
1718
KEY_FIELD_NAME,
19+
DtypeRegistry,
1820
)
1921

2022

@@ -27,6 +29,8 @@ def encode_engine_value(value: Any) -> Any:
2729
]
2830
if is_namedtuple_type(type(value)):
2931
return [encode_engine_value(getattr(value, name)) for name in value._fields]
32+
if isinstance(value, np.ndarray):
33+
return value
3034
if isinstance(value, (list, tuple)):
3135
return [encode_engine_value(v) for v in value]
3236
if isinstance(value, dict):
@@ -122,6 +126,38 @@ def decode(value: Any) -> Any | None:
122126
if src_type_kind == "Uuid":
123127
return lambda value: uuid.UUID(bytes=value)
124128

129+
if src_type_kind == "Vector":
130+
elem_coco_type_info = analyze_type_info(dst_type_info.elem_type)
131+
dtype_info = DtypeRegistry.get_by_kind(elem_coco_type_info.kind)
132+
133+
def decode_vector(value: Any) -> Any | None:
134+
if value is None:
135+
if dst_type_info.nullable:
136+
return None
137+
raise ValueError(
138+
f"Received null for non-nullable vector `{''.join(field_path)}`"
139+
)
140+
141+
if not isinstance(value, (np.ndarray, list)):
142+
raise TypeError(
143+
f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}"
144+
)
145+
expected_dim = (
146+
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
147+
)
148+
if expected_dim is not None and len(value) != expected_dim:
149+
raise ValueError(
150+
f"Vector dimension mismatch for `{''.join(field_path)}`: "
151+
f"expected {expected_dim}, got {len(value)}"
152+
)
153+
154+
# Use NDArray for supported numeric dtypes, else return list
155+
if dtype_info is not None:
156+
return np.array(value, dtype=dtype_info.numpy_dtype)
157+
return value
158+
159+
return decode_vector
160+
125161
return lambda value: value
126162

127163

python/cocoindex/functions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""All builtin functions."""
22

3-
from typing import Annotated, Any, TYPE_CHECKING
3+
from typing import Annotated, Any, TYPE_CHECKING, Literal
4+
import numpy as np
5+
from numpy.typing import NDArray
46
import dataclasses
57

68
from .typing import Float32, Vector, TypeAttr
@@ -66,11 +68,11 @@ def analyze(self, text: Any) -> type:
6668
self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
6769
dim = self._model.get_sentence_embedding_dimension()
6870
result: type = Annotated[
69-
Vector[Float32, dim], # type: ignore
71+
Vector[np.float32, Literal[dim]], # type: ignore
7072
TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value),
7173
]
7274
return result
7375

74-
def __call__(self, text: str) -> list[Float32]:
75-
result: list[Float32] = self._model.encode(text).tolist()
76+
def __call__(self, text: str) -> NDArray[np.float32]:
77+
result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
7678
return result

0 commit comments

Comments
 (0)