Skip to content

Commit 1e2a62d

Browse files
committed
Support VECTOR(N) type
1 parent 34913e3 commit 1e2a62d

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

python/databend_udf/udf.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,7 @@ def _input_process_func(field: pa.Field) -> Callable:
11791179
- Json=pa.large_binary(): bytes -> Any
11801180
- Map=pa.map_(): list[tuple(k,v)] -> dict
11811181
"""
1182-
if pa.types.is_list(field.type):
1182+
if pa.types.is_list(field.type) or pa.types.is_fixed_size_list(field.type):
11831183
func = _input_process_func(field.type.value_field)
11841184
return (
11851185
lambda array: [func(v) if v is not None else None for v in array]
@@ -1225,7 +1225,7 @@ def _output_process_func(field: pa.Field) -> Callable:
12251225
- Json=pa.large_binary(): Any -> str
12261226
- Map=pa.map_(): dict -> list[tuple(k,v)]
12271227
"""
1228-
if pa.types.is_list(field.type):
1228+
if pa.types.is_list(field.type) or pa.types.is_fixed_size_list(field.type):
12291229
func = _output_process_func(field.type.value_field)
12301230
return (
12311231
lambda array: [func(v) if v is not None else None for v in array]
@@ -1402,6 +1402,10 @@ def _type_str_to_arrow_field_inner(type_str: str) -> pa.Field:
14021402
type_str = type_str.strip()
14031403
fields.append(_type_str_to_arrow_field_inner(type_str))
14041404
return pa.field("", pa.struct(fields), False)
1405+
elif type_str.startswith("VECTOR"):
1406+
# VECTOR(1024)
1407+
dim = int(type_str[6:].strip("()").strip())
1408+
return pa.field("", pa.list_(pa.float32(), dim), False)
14051409
else:
14061410
raise ValueError(f"Unsupported type: {type_str}")
14071411

@@ -1460,6 +1464,8 @@ def _field_type_to_string(field: pa.Field) -> str:
14601464
return "VARIANT"
14611465
else:
14621466
return "BINARY"
1467+
elif pa.types.is_fixed_size_list(t):
1468+
return f"VECTOR({t.list_size})"
14631469
elif pa.types.is_list(t):
14641470
return f"ARRAY({_inner_field_to_string(t.value_field)})"
14651471
elif pa.types.is_map(t):

python/tests/test_vector_type.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pyarrow as pa
2+
import pytest
3+
from databend_udf.udf import (
4+
_type_str_to_arrow_field,
5+
_field_type_to_string,
6+
_input_process_func,
7+
_output_process_func,
8+
_arrow_field_to_string,
9+
)
10+
11+
def test_vector_sql_generation():
12+
# Test nullable VECTOR (default)
13+
field = _type_str_to_arrow_field("VECTOR(1024)")
14+
sql_type = _arrow_field_to_string(field)
15+
assert sql_type == "VECTOR(1024)"
16+
17+
# Test NOT NULL VECTOR
18+
field_not_null = _type_str_to_arrow_field("VECTOR(1024) NOT NULL")
19+
sql_type_not_null = _arrow_field_to_string(field_not_null)
20+
assert sql_type_not_null == "VECTOR(1024) NOT NULL"
21+
22+
def test_vector_type_parsing():
23+
field = _type_str_to_arrow_field("VECTOR(1024)")
24+
assert pa.types.is_fixed_size_list(field.type)
25+
assert field.type.list_size == 1024
26+
assert pa.types.is_float32(field.type.value_type)
27+
assert field.nullable is True
28+
29+
def test_vector_type_formatting():
30+
field = pa.field("", pa.list_(pa.float32(), 1024), nullable=True)
31+
type_str = _field_type_to_string(field)
32+
assert type_str == "VECTOR(1024)"
33+
34+
def test_vector_input_processing():
35+
field = pa.field("", pa.list_(pa.float32(), 3), nullable=True)
36+
func = _input_process_func(field)
37+
38+
# Input is a list of floats
39+
input_data = [1.0, 2.0, 3.0]
40+
result = func(input_data)
41+
assert result == [1.0, 2.0, 3.0]
42+
43+
# Input is None
44+
assert func(None) is None
45+
46+
def test_vector_output_processing():
47+
field = pa.field("", pa.list_(pa.float32(), 3), nullable=True)
48+
func = _output_process_func(field)
49+
50+
# Output is a list of floats
51+
output_data = [1.0, 2.0, 3.0]
52+
result = func(output_data)
53+
assert result == [1.0, 2.0, 3.0]
54+
55+
# Output is None
56+
assert func(None) is None

0 commit comments

Comments
 (0)