|
10 | 10 | from unittest.mock import patch |
11 | 11 | import os |
12 | 12 |
|
| 13 | +from core.common.models.index_build_parameters import DataType |
| 14 | +from core.common.models.index_builder import CagraGraphBuildAlgo |
| 15 | +from core.common.models.index_builder.faiss import FaissGPUIndexCagraBuilder |
13 | 16 | from core.index_builder.faiss.faiss_index_build_service import FaissIndexBuildService |
| 17 | +from core.index_builder.index_builder_utils import calculate_ivf_pq_n_lists |
14 | 18 |
|
15 | 19 |
|
16 | 20 | class TestFaissIndexBuildService: |
@@ -53,12 +57,43 @@ def test_build_binary_index_success( |
53 | 57 | def _do_test_build_index_success( |
54 | 58 | self, service, vectors_dataset, index_build_parameters, tmp_path |
55 | 59 | ): |
56 | | - output_path = str(tmp_path / "output.index") |
57 | | - service.build_index(index_build_parameters, vectors_dataset, output_path) |
58 | | - |
59 | | - # Verify OMP threads were set correctly |
60 | | - assert faiss.omp_get_num_threads() == 2 # 8 CPUs/4 = 2 threads |
61 | | - assert os.path.exists(output_path) |
| 60 | + with patch( |
| 61 | + "core.common.models.index_builder.faiss.FaissGPUIndexCagraBuilder.from_dict" |
| 62 | + ) as mock_gpu_from_dict: |
| 63 | + |
| 64 | + output_path = str(tmp_path / "output.index") |
| 65 | + mock_gpu_from_dict.return_value = FaissGPUIndexCagraBuilder() |
| 66 | + |
| 67 | + service.build_index(index_build_parameters, vectors_dataset, output_path) |
| 68 | + |
| 69 | + # Ensuring that FaissGPUIndexCagraBuilder parameters are set correctly |
| 70 | + expected_params = self._get_expected_gpu_params( |
| 71 | + service, index_build_parameters |
| 72 | + ) |
| 73 | + mock_gpu_from_dict.assert_called_once_with(expected_params) |
| 74 | + |
| 75 | + assert faiss.omp_get_num_threads() == 2 # 8 CPUs/4 = 2 threads |
| 76 | + assert os.path.exists(output_path) |
| 77 | + |
| 78 | + def _get_expected_gpu_params(self, service, index_build_parameters): |
| 79 | + if index_build_parameters.data_type != DataType.BINARY: |
| 80 | + return { |
| 81 | + "ivf_pq_params": { |
| 82 | + "n_lists": calculate_ivf_pq_n_lists( |
| 83 | + index_build_parameters.doc_count |
| 84 | + ), |
| 85 | + "pq_dim": int( |
| 86 | + index_build_parameters.dimension |
| 87 | + / service.PQ_DIM_COMPRESSION_FACTOR |
| 88 | + ), |
| 89 | + }, |
| 90 | + "graph_degree": index_build_parameters.index_parameters.algorithm_parameters.m |
| 91 | + * 2, |
| 92 | + "intermediate_graph_degree": index_build_parameters.index_parameters.algorithm_parameters.m |
| 93 | + * 4, |
| 94 | + } |
| 95 | + else: |
| 96 | + return {"graph_build_algo": CagraGraphBuildAlgo.NN_DESCENT} |
62 | 97 |
|
63 | 98 | def test_build_index_gpu_creation_error( |
64 | 99 | self, service, vectors_dataset, index_build_parameters, tmp_path |
|
0 commit comments