Skip to content

Commit ce2ae5e

Browse files
authored
Expose C++ TemporalPolicy in Python (#362)
1 parent 5742661 commit ce2ae5e

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,37 @@ void init_type_erased_module(py::module_& m) {
120120
}))
121121
;
122122
#endif
123+
py::class_<TemporalPolicy>(m, "TemporalPolicy", py::buffer_protocol())
124+
// From 0 to UINT64_MAX.
125+
.def(py::init<>())
126+
// From 0 to timestamp_end.
127+
.def(
128+
"__init__",
129+
[](TemporalPolicy& instance,
130+
std::optional<uint64_t> timestamp_end_input) {
131+
uint64_t timestamp_end = timestamp_end_input.has_value() ?
132+
timestamp_end_input.value() :
133+
UINT64_MAX;
134+
new (&instance) TemporalPolicy(TimeTravel, timestamp_end);
135+
})
136+
// From timestamp_start to timestamp_end.
137+
.def(
138+
"__init__",
139+
[](TemporalPolicy& instance,
140+
std::optional<uint64_t> timestamp_start_input,
141+
std::optional<uint64_t> timestamp_end_input) {
142+
uint64_t timestamp_start = timestamp_start_input.has_value() ?
143+
timestamp_start_input.value() :
144+
0;
145+
uint64_t timestamp_end = timestamp_end_input.has_value() ?
146+
timestamp_end_input.value() :
147+
UINT64_MAX;
148+
new (&instance) TemporalPolicy(
149+
TimestampStartEnd, timestamp_start, timestamp_end);
150+
})
151+
.def("timestamp_start", &TemporalPolicy::timestamp_start)
152+
.def("timestamp_end", &TemporalPolicy::timestamp_end);
153+
123154
py::class_<FeatureVector>(m, "FeatureVector", py::buffer_protocol())
124155
.def(
125156
py::init<const tiledb::Context&, const std::string&>(),

apis/python/src/tiledb/vector_search/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import io
2+
from typing import Optional
23

34
import numpy as np
45

56
import tiledb
7+
from tiledb.vector_search import _tiledbvspy as vspy
68

79

810
def is_type_erased_index(index_type: str) -> bool:
@@ -21,6 +23,19 @@ def add_to_group(group, uri, name):
2123
group.add(name, name=name, relative=True)
2224

2325

26+
def to_temporal_policy(timestamp) -> Optional[vspy.TemporalPolicy]:
27+
temporal_policy = None
28+
if isinstance(timestamp, tuple):
29+
if len(timestamp) != 2:
30+
raise ValueError(
31+
"'timestamp' argument expects either int or tuple(start: int, end: int)"
32+
)
33+
temporal_policy = vspy.TemporalPolicy(timestamp[0], timestamp[1])
34+
elif timestamp is not None:
35+
temporal_policy = vspy.TemporalPolicy(timestamp)
36+
return temporal_policy
37+
38+
2439
def _load_vecs_t(uri, dtype, ctx_or_config=None):
2540
with tiledb.scope_ctx(ctx_or_config) as ctx:
2641
dtype = np.dtype(dtype)

apis/python/test/test_type_erased_module.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from tiledb.vector_search import _tiledbvspy as vspy
77
from tiledb.vector_search.utils import load_fvecs
8+
from tiledb.vector_search.utils import to_temporal_policy
89

910
ctx = vspy.Ctx({})
1011

@@ -152,6 +153,49 @@ def test_numpy_to_feature_vector_array():
152153
assert np.array_equal(a, np.transpose(np.array(b)))
153154

154155

156+
def test_TemporalPolicy():
157+
temporal_policy = vspy.TemporalPolicy()
158+
assert temporal_policy.timestamp_start() == 0
159+
assert temporal_policy.timestamp_end() == np.iinfo(np.uint64).max
160+
161+
temporal_policy = vspy.TemporalPolicy(99)
162+
assert temporal_policy.timestamp_start() == 0
163+
assert temporal_policy.timestamp_end() == 99
164+
165+
temporal_policy = vspy.TemporalPolicy(12, 99)
166+
assert temporal_policy.timestamp_start() == 12
167+
assert temporal_policy.timestamp_end() == 99
168+
169+
170+
def test_TemporalPolicy_from_timestamp():
171+
temporal_policy = to_temporal_policy(None)
172+
assert temporal_policy is None
173+
174+
temporal_policy = to_temporal_policy(3)
175+
assert temporal_policy.timestamp_start() == 0
176+
assert temporal_policy.timestamp_end() == 3
177+
178+
temporal_policy = to_temporal_policy((0, 33))
179+
assert temporal_policy.timestamp_start() == 0
180+
assert temporal_policy.timestamp_end() == 33
181+
182+
temporal_policy = to_temporal_policy((1, 33))
183+
assert temporal_policy.timestamp_start() == 1
184+
assert temporal_policy.timestamp_end() == 33
185+
186+
temporal_policy = to_temporal_policy((None, 333))
187+
assert temporal_policy.timestamp_start() == 0
188+
assert temporal_policy.timestamp_end() == 333
189+
190+
temporal_policy = to_temporal_policy((3333, None))
191+
assert temporal_policy.timestamp_start() == 3333
192+
assert temporal_policy.timestamp_end() == np.iinfo(np.uint64).max
193+
194+
temporal_policy = to_temporal_policy((None, None))
195+
assert temporal_policy.timestamp_start() == 0
196+
assert temporal_policy.timestamp_end() == np.iinfo(np.uint64).max
197+
198+
155199
def test_construct_IndexFlatL2():
156200
a = vspy.IndexFlatL2(ctx, siftsmall_inputs_uri)
157201
assert a.feature_type_string() == "float32"

0 commit comments

Comments
 (0)