Skip to content

Commit f84739c

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: Make MatchingEngineIndexConfig's algorithmConfig optional
PiperOrigin-RevId: 822655431
1 parent 0fc74de commit f84739c

File tree

3 files changed

+54
-7
lines changed

3 files changed

+54
-7
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -596,10 +596,15 @@ def create_tree_ah_index(
596596
597597
"""
598598

599-
algorithm_config = matching_engine_index_config.TreeAhConfig(
600-
leaf_node_embedding_count=leaf_node_embedding_count,
601-
leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
602-
)
599+
algorithm_config = None
600+
if (
601+
leaf_node_embedding_count is not None
602+
or leaf_nodes_to_search_percent is not None
603+
):
604+
algorithm_config = matching_engine_index_config.TreeAhConfig(
605+
leaf_node_embedding_count=leaf_node_embedding_count,
606+
leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
607+
)
603608

604609
config = matching_engine_index_config.MatchingEngineIndexConfig(
605610
dimensions=dimensions,

google/cloud/aiplatform/matching_engine/matching_engine_index_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class MatchingEngineIndexConfig:
120120
dimensions (int):
121121
Required. The number of dimensions of the input vectors.
122122
algorithm_config (AlgorithmConfig):
123-
Required. The configuration with regard to the algorithms used for efficient search.
123+
Optional. The configuration with regard to the algorithms used for efficient search.
124124
approximate_neighbors_count (int):
125125
Optional. The default number of neighbors to find via approximate search before exact reordering is
126126
performed. Exact reordering is a procedure where results returned by an
@@ -139,7 +139,7 @@ class MatchingEngineIndexConfig:
139139
"""
140140

141141
dimensions: int
142-
algorithm_config: AlgorithmConfig
142+
algorithm_config: Optional[AlgorithmConfig] = None
143143
approximate_neighbors_count: Optional[int] = None
144144
distance_measure_type: Optional[DistanceMeasureType] = None
145145
feature_norm_type: Optional[FeatureNormType] = None
@@ -153,10 +153,13 @@ def as_dict(self) -> Dict[str, Any]:
153153
"""
154154
res = {
155155
"dimensions": self.dimensions,
156-
"algorithmConfig": self.algorithm_config.as_dict(),
157156
"approximateNeighborsCount": self.approximate_neighbors_count,
158157
"distanceMeasureType": self.distance_measure_type,
159158
"featureNormType": self.feature_norm_type,
160159
"shardSize": self.shard_size,
161160
}
161+
if self.algorithm_config:
162+
res["algorithmConfig"] = self.algorithm_config.as_dict()
163+
else:
164+
res["algorithmConfig"] = None
162165
return res

tests/unit/aiplatform/test_matching_engine_index.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,45 @@ def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
618618
timeout=None,
619619
)
620620

621+
@pytest.mark.usefixtures("get_index_mock")
622+
def test_create_tree_ah_index_empty_algorithm_config(self, create_index_mock):
623+
aiplatform.init(project=_TEST_PROJECT)
624+
625+
aiplatform.MatchingEngineIndex.create_tree_ah_index(
626+
display_name=_TEST_INDEX_DISPLAY_NAME,
627+
contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
628+
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
629+
approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
630+
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE.value,
631+
feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE.value,
632+
description=_TEST_INDEX_DESCRIPTION,
633+
labels=_TEST_LABELS,
634+
)
635+
636+
expected = gca_index.Index(
637+
display_name=_TEST_INDEX_DISPLAY_NAME,
638+
metadata={
639+
"config": {
640+
"algorithmConfig": None,
641+
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
642+
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
643+
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
644+
"featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE,
645+
"shardSize": None,
646+
},
647+
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
648+
},
649+
description=_TEST_INDEX_DESCRIPTION,
650+
labels=_TEST_LABELS,
651+
)
652+
653+
create_index_mock.assert_called_once_with(
654+
parent=_TEST_PARENT,
655+
index=expected,
656+
metadata=_TEST_REQUEST_METADATA,
657+
timeout=None,
658+
)
659+
621660
@pytest.mark.usefixtures("get_index_mock")
622661
@pytest.mark.parametrize("sync", [True, False])
623662
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)