Skip to content

Commit 9e2aefc

Browse files
authored
Run pre-commit to lint files (#305)
1 parent 226b3ea commit 9e2aefc

File tree

17 files changed

+128
-111
lines changed

17 files changed

+128
-111
lines changed

.github/workflows/build-wheels.yml

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ on:
55
push:
66
branches:
77
- release-*
8-
- '*wheel*' # must quote since "*" is a YAML reserved character; we want a string
8+
- "*wheel*" # must quote since "*" is a YAML reserved character; we want a string
99
tags:
10-
- '*'
10+
- "*"
1111
pull_request:
1212
branches:
13-
- '*wheel*' # must quote since "*" is a YAML reserved character; we want a string
13+
- "*wheel*" # must quote since "*" is a YAML reserved character; we want a string
1414

1515
jobs:
1616
generate_backwards_compatibility_data:
@@ -19,7 +19,7 @@ jobs:
1919
steps:
2020
- name: Checkout code
2121
uses: actions/checkout@v3
22-
22+
2323
# Based on https://github.com/TileDB-Inc/conda-forge-nightly-controller/blob/51519a0f8340b32cf737fcb59b76c6a91c42dc47/.github/workflows/activity.yml#L19C10-L19C10
2424
- name: Setup git
2525
run: |
@@ -92,19 +92,19 @@ jobs:
9292
strategy:
9393
matrix:
9494
buildplat:
95-
- [ ubuntu-22.04, manylinux_x86_64 ]
96-
- [ macos-13, macosx_x86_64 ]
97-
- [ macos-13, macosx_arm64 ]
98-
- [ windows-2022, win_amd64 ]
99-
python: [ "cp39", "cp310", "cp311", "cp312", "pp39" ]
95+
- [ubuntu-22.04, manylinux_x86_64]
96+
- [macos-13, macosx_x86_64]
97+
- [macos-13, macosx_arm64]
98+
- [windows-2022, win_amd64]
99+
python: ["cp39", "cp310", "cp311", "cp312", "pp39"]
100100
exclude:
101-
- buildplat: [ macos-13, macosx_arm64 ]
101+
- buildplat: [macos-13, macosx_arm64]
102102
python: "pp39"
103103

104104
steps:
105105
- uses: actions/checkout@v3
106106

107-
- name: 'Brew setup on macOS' # x-ref c8e49ba8f8b9ce
107+
- name: "Brew setup on macOS" # x-ref c8e49ba8f8b9ce
108108
if: ${{ startsWith(matrix.os, 'macos-') == true }}
109109
run: |
110110
set -e pipefail
@@ -128,7 +128,7 @@ jobs:
128128
- uses: actions/upload-artifact@v4
129129
with:
130130
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
131-
path: './wheelhouse/*.whl'
131+
path: "./wheelhouse/*.whl"
132132

133133
build_sdist:
134134
name: Build source distribution
@@ -143,7 +143,6 @@ jobs:
143143
with:
144144
path: dist/*.tar.gz
145145

146-
147146
upload_pypi:
148147
needs: [build_wheels, build_sdist]
149148
runs-on: ubuntu-latest

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ jobs:
2424
- name: Fix Windows dll inclusion with ctest
2525
if: ${{ runner.os == 'Windows' }}
2626
run: |
27-
cp .\src\build\externals\install\bin\tiledb.dll D:/a/TileDB-Vector-Search/TileDB-Vector-Search/src/build/libtiledbvectorsearch/include/test
27+
cp .\src\build\externals\install\bin\tiledb.dll D:/a/TileDB-Vector-Search/TileDB-Vector-Search/src/build/libtiledbvectorsearch/include/test
2828
- name: Run Tests
2929
run: cmake --build ./src/build --target check-ci

apis/python/requirements-py.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
numpy==1.24.3
22
tiledb-cloud==0.10.24
33
tiledb==0.27.0
4-
scikit-learn==1.3.2
4+
scikit-learn==1.3.2

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,6 @@
2020
from .storage_formats import STORAGE_VERSION
2121
from .storage_formats import storage_formats
2222

23-
from ._tiledbvspy import FeatureVector
24-
from ._tiledbvspy import FeatureVectorArray
25-
from ._tiledbvspy import IndexFlatL2
26-
from ._tiledbvspy import IndexIVFFlat
27-
from ._tiledbvspy import Ctx
28-
2923
try:
3024
from tiledb.vector_search.version import version as __version__
3125
except ImportError:

apis/python/src/tiledb/vector_search/embeddings/langchain_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
# class LangChainEmbedding(ObjectEmbedding):
9-
class LangChainEmbedding():
9+
class LangChainEmbedding:
1010
"""
1111
Embedding functions from `langchain.embeddings` package.
1212
"""

apis/python/src/tiledb/vector_search/index.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import time
55
from typing import Any, Mapping, Optional
66

7+
from tiledb.vector_search import _tiledbvspy as vspy
78
from tiledb.vector_search.module import *
89
from tiledb.vector_search.storage_formats import storage_formats
9-
from tiledb.vector_search import _tiledbvspy as vspy
1010

1111
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
1212
MAX_INT32 = np.iinfo(np.dtype("int32")).max
@@ -286,7 +286,10 @@ def check_has_updates(self):
286286

287287
def set_has_updates(self, has_updates: bool = True):
288288
self.has_updates = has_updates
289-
if "has_updates" not in self.group.meta or self.group.meta["has_updates"] != has_updates:
289+
if (
290+
"has_updates" not in self.group.meta
291+
or self.group.meta["has_updates"] != has_updates
292+
):
290293
self.group.close()
291294
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
292295
self.group.meta["has_updates"] = has_updates

apis/python/src/tiledb/vector_search/module.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import numpy as np
55

66
import tiledb
7-
from tiledb.vector_search._tiledbvspy import *
87
from tiledb.vector_search import _tiledbvspy as vspy
8+
from tiledb.vector_search._tiledbvspy import *
9+
910

1011
def load_as_matrix(
1112
path: str,
@@ -35,7 +36,7 @@ def load_as_matrix(
3536

3637
a = tiledb.ArraySchema.load(path, ctx=tiledb.Ctx(config))
3738
dtype = a.attr(0).dtype
38-
# Read all rows from column 0 -> `size`. Set no upper_bound. Note that if `size` is None then
39+
# Read all rows from column 0 -> `size`. Set no upper_bound. Note that if `size` is None then
3940
# we'll read to the column domain length.
4041
if dtype == np.float32:
4142
m = tdbColMajorMatrix_f32(ctx, path, 0, None, 0, size, 0, timestamp)

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
1-
import json
2-
import multiprocessing
31
from typing import Any, Mapping
42

53
import numpy as np
6-
from tiledb.cloud.dag import Mode
74

5+
from tiledb.vector_search import _tiledbvspy as vspy
86
from tiledb.vector_search import index
97
from tiledb.vector_search.module import *
10-
from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
11-
storage_formats,
12-
validate_storage_version)
13-
from tiledb.vector_search.utils import add_to_group
14-
from tiledb.vector_search import _tiledbvspy as vspy
8+
from tiledb.vector_search.storage_formats import STORAGE_VERSION
9+
from tiledb.vector_search.storage_formats import storage_formats
1510

1611
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
1712
INDEX_TYPE = "VAMANA"
1813

14+
1915
class VamanaIndex(index.Index):
2016
"""
2117
Open a Vamana index
@@ -38,12 +34,16 @@ def __init__(
3834
super().__init__(uri=uri, config=config, timestamp=timestamp)
3935
self.index_type = INDEX_TYPE
4036
self.index = vspy.IndexVamana(vspy.Ctx(config), uri)
41-
self.db_uri = self.group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]].uri
42-
self.ids_uri = self.group[storage_formats[self.storage_version]["IDS_ARRAY_NAME"]].uri
43-
37+
self.db_uri = self.group[
38+
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
39+
].uri
40+
self.ids_uri = self.group[
41+
storage_formats[self.storage_version]["IDS_ARRAY_NAME"]
42+
].uri
43+
4444
schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
4545
self.dimensions = self.index.dimension()
46-
46+
4747
self.dtype = np.dtype(self.group.meta.get("dtype", None))
4848
if self.dtype is None:
4949
self.dtype = np.dtype(schema.attr("values").dtype)
@@ -86,6 +86,7 @@ def query_internal(
8686
# TODO(paris): Actually run the query.
8787
return [], []
8888

89+
8990
# TODO(paris): Pass more arguments to C++, i.e. storage_version.
9091
def create(
9192
uri: str,
@@ -98,22 +99,19 @@ def create(
9899
storage_version: str = STORAGE_VERSION,
99100
**kwargs,
100101
) -> VamanaIndex:
101-
if not group_exists:
102+
if not group_exists:
102103
ctx = vspy.Ctx(config)
103104
index = vspy.IndexVamana(
104-
feature_type=np.dtype(vector_type).name,
105-
id_type=np.dtype(id_type).name,
106-
adjacency_row_index_type=np.dtype(adjacency_row_index_type).name,
105+
feature_type=np.dtype(vector_type).name,
106+
id_type=np.dtype(id_type).name,
107+
adjacency_row_index_type=np.dtype(adjacency_row_index_type).name,
107108
dimension=dimensions,
108109
)
109110
# TODO(paris): Run all of this with a single C++ call.
110111
empty_vector = vspy.FeatureVectorArray(
111-
dimensions,
112-
0,
113-
np.dtype(vector_type).name,
114-
np.dtype(id_type).name
115-
)
112+
dimensions, 0, np.dtype(vector_type).name, np.dtype(id_type).name
113+
)
116114
index.train(empty_vector)
117115
index.add(empty_vector)
118116
index.write_index(ctx, uri)
119-
return VamanaIndex(uri=uri, config=config, memory_budget=1000000)
117+
return VamanaIndex(uri=uri, config=config, memory_budget=1000000)

apis/python/test/test_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def test_load_matrix(tmpdir):
2222
assert np.array_equal(m, data)
2323
assert np.array_equal(orig_matrix[0, 0], data[0, 0])
2424

25+
2526
def test_load_matrix_specify_size(tmpdir):
2627
p = str(tmpdir.mkdir("test").join("test.tdb"))
2728
data = np.random.rand(12).astype(np.float32).reshape(3, 4)
@@ -41,6 +42,7 @@ def test_load_matrix_specify_size(tmpdir):
4142
m = vs.load_as_array(p, size=0)
4243
assert m.shape == (3, 0)
4344

45+
4446
def test_vector(tmpdir):
4547
v = vspy._create_vector_u64()
4648
assert np.array_equal(np.array(v), np.arange(10))

apis/python/test/test_cloud.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22
import unittest
33

4-
from common import *
54
from array_paths import *
5+
from common import *
66

77
import tiledb.vector_search as vs
88
from tiledb.cloud import groups

0 commit comments

Comments
 (0)