|
1 | 1 | /* |
2 | | - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. |
3 | 3 | * SPDX-License-Identifier: Apache-2.0 |
4 | 4 | */ |
5 | 5 | #pragma once |
|
17 | 17 | #include <utils.h> |
18 | 18 |
|
19 | 19 | #include <chrono> |
| 20 | +#include <cmath> |
| 21 | +#include <cstdio> |
20 | 22 | #include <memory> |
21 | 23 | #include <vector> |
22 | 24 |
|
@@ -172,12 +174,14 @@ class diskann_ssd : public algo<T> { |
172 | 174 | struct build_param { |
173 | 175 | uint32_t R; |
174 | 176 | uint32_t L_build; |
175 | | - uint32_t build_pq_bytes = 0; |
176 | | - float alpha = 1.2; |
177 | | - int num_threads = omp_get_max_threads(); |
178 | | - uint32_t QD = 192; |
179 | | - std::string dataset_base_file = ""; |
180 | | - std::string index_file = ""; |
| 177 | + uint32_t build_pq_bytes = 0; |
| 178 | + float alpha = 1.2; |
| 179 | + int num_threads = omp_get_max_threads(); |
| 180 | + uint32_t QD = 192; |
| 181 | + std::string dataset_base_file = ""; |
| 182 | + std::string index_file = ""; |
| 183 | + uint32_t build_dram_budget_megabytes = std::numeric_limits<uint32_t>::max(); |
| 184 | + uint32_t search_dram_budget_megabytes = std::numeric_limits<uint32_t>::max(); |
181 | 185 | }; |
182 | 186 | using search_param_base = typename algo<T>::search_param; |
183 | 187 |
|
@@ -232,12 +236,17 @@ template <typename T> |
232 | 236 | diskann_ssd<T>::diskann_ssd(Metric metric, int dim, const build_param& param) : algo<T>(metric, dim) |
233 | 237 | { |
234 | 238 | // Currently set the indexing RAM budget and the search RAM budget to max value to avoid sharding |
235 | | - uint32_t build_dram_budget = std::numeric_limits<uint32_t>::max(); |
236 | | - uint32_t search_dram_budget = std::numeric_limits<uint32_t>::max(); |
| 239 | + float build_dram_budget = static_cast<float>(param.build_dram_budget_megabytes) / 1024.0f; |
| 240 | + float search_dram_budget = static_cast<float>(param.search_dram_budget_megabytes) / 1024.0f; |
| 241 | + char search_buf[16]; |
| 242 | + char build_buf[16]; |
| 243 | + std::snprintf(search_buf, sizeof(search_buf), "%.2f", search_dram_budget); |
| 244 | + std::snprintf(build_buf, sizeof(build_buf), "%.2f", build_dram_budget); |
| 245 | + const std::string search_dram_budget_gb(search_buf); |
| 246 | + const std::string build_dram_budget_gb(build_buf); |
237 | 247 | index_build_params_str = |
238 | 248 | std::string(std::to_string(param.R)) + " " + std::string(std::to_string(param.L_build)) + " " + |
239 | | - std::string(std::to_string(search_dram_budget)) + " " + |
240 | | - std::string(std::to_string(build_dram_budget)) + " " + |
| 249 | + search_dram_budget_gb + " " + build_dram_budget_gb + " " + |
241 | 250 | std::string(std::to_string(param.num_threads)) + " " + std::string(std::to_string(false)) + |
242 | 251 | " " + std::string(std::to_string(false)) + " " + std::string(std::to_string(0)) + " " + |
243 | 252 | std::string(std::to_string(param.QD)); |
|
0 commit comments