|
28 | 28 |
|
29 | 29 | #include "curve_trees.h" |
30 | 30 |
|
| 31 | +#include "common/threadpool.h" |
31 | 32 | #include "crypto/crypto.h" |
32 | 33 | #include "fcmp_pp_crypto.h" |
33 | 34 | #include "fcmp_pp_types.h" |
34 | 35 | #include "profile_tools.h" |
35 | 36 |
|
| 37 | +#include <cstring> |
| 38 | + |
36 | 39 | namespace |
37 | 40 | { |
38 | 41 | // Struct composed of ec elems needed to get a full-fledged leaf tuple |
@@ -146,5 +149,198 @@ static PreLeafTuple output_to_pre_leaf_tuple(const OutputPair &output_pair) |
146 | 149 | } |
147 | 150 | //---------------------------------------------------------------------------------------------------------------------- |
148 | 151 | //---------------------------------------------------------------------------------------------------------------------- |
| 152 | +// CurveTrees private member functions |
| 153 | +//---------------------------------------------------------------------------------------------------------------------- |
| 154 | +template<typename C1, typename C2> |
| 155 | +void CurveTrees<C1, C2>::outputs_to_leaves(std::vector<UnifiedOutput> &&new_outputs, |
| 156 | + std::vector<typename C1::Scalar> &flattened_leaves_out, |
| 157 | + std::vector<UnifiedOutput> &valid_outputs_out) |
| 158 | +{ |
| 159 | + flattened_leaves_out.clear(); |
| 160 | + valid_outputs_out.clear(); |
| 161 | + |
| 162 | + TIME_MEASURE_START(set_valid_leaves); |
| 163 | + |
| 164 | + // Keep track of valid outputs to make sure we only use leaves from valid outputs. Can't use std::vector<bool> |
| 165 | + // because std::vector<bool> concurrent access is not thread safe. |
| 166 | + enum Boolean : uint8_t { |
| 167 | + False = 0, |
| 168 | + True = 1, |
| 169 | + }; |
| 170 | + std::vector<Boolean> valid_outputs(new_outputs.size(), False); |
| 171 | + |
| 172 | + tools::threadpool& tpool = tools::threadpool::getInstanceForCompute(); |
| 173 | + tools::threadpool::waiter waiter(tpool); |
| 174 | + const std::size_t n_threads = std::max<std::size_t>(1, tpool.get_max_concurrency()); |
| 175 | + |
| 176 | + TIME_MEASURE_START(convert_valid_leaves); |
| 177 | + // Step 1. Multithreaded convert valid outputs into Edwards derivatives needed to get Wei coordinates |
| 178 | + std::vector<PreLeafTuple> pre_leaves; |
| 179 | + pre_leaves.resize(new_outputs.size()); |
| 180 | + const std::size_t LEAF_CONVERT_BATCH_SIZE = std::max<std::size_t>(1, (new_outputs.size() / n_threads)); |
| 181 | + for (std::size_t i = 0; i < new_outputs.size(); i += LEAF_CONVERT_BATCH_SIZE) |
| 182 | + { |
| 183 | + const std::size_t end = std::min(i + LEAF_CONVERT_BATCH_SIZE, new_outputs.size()); |
| 184 | + tpool.submit(&waiter, |
| 185 | + [ |
| 186 | + &new_outputs, |
| 187 | + &valid_outputs, |
| 188 | + &pre_leaves, |
| 189 | + i, |
| 190 | + end |
| 191 | + ]() |
| 192 | + { |
| 193 | + for (std::size_t j = i; j < end; ++j) |
| 194 | + { |
| 195 | + CHECK_AND_ASSERT_THROW_MES(!valid_outputs.at(j), "unexpected already valid output"); |
| 196 | + |
| 197 | + const auto &output_pair = new_outputs.at(j).output_pair; |
| 198 | + try |
| 199 | + { |
| 200 | + pre_leaves.at(j) = output_to_pre_leaf_tuple(output_pair); |
| 201 | + } |
| 202 | + catch(...) |
| 203 | + { |
| 204 | + /* Invalid outputs can't be added to the tree */ |
| 205 | + LOG_PRINT_L2("Output " << new_outputs.at(j).unified_id << " is invalid (out pubkey " |
| 206 | + << output_pubkey_cref(output_pair) |
| 207 | + << " , commitment " << commitment_cref(output_pair) << ")"); |
| 208 | + continue; |
| 209 | + } |
| 210 | + |
| 211 | + valid_outputs.at(j) = True; |
| 212 | + } |
| 213 | + }, |
| 214 | + true |
| 215 | + ); |
| 216 | + } |
| 217 | + |
| 218 | + CHECK_AND_ASSERT_THROW_MES(waiter.wait(), "failed to convert outputs to ed derivatives"); |
| 219 | + TIME_MEASURE_FINISH(convert_valid_leaves); |
| 220 | + |
| 221 | + TIME_MEASURE_START(collect_derivatives); |
| 222 | + // Step 2. Collect valid Edwards y derivatives |
| 223 | + const std::size_t n_valid_outputs = std::count(valid_outputs.begin(), valid_outputs.end(), True); |
| 224 | + const std::size_t n_valid_leaf_points = n_valid_outputs * LEAF_TUPLE_POINTS; |
| 225 | + |
| 226 | + // Collecting [(1+y),(1-y),((1-y)*x)] for batch inversion |
| 227 | + std::unique_ptr<fe[]> one_plus_y_vec = std::make_unique<fe[]>(n_valid_leaf_points); |
| 228 | + std::unique_ptr<fe[]> fe_batch = std::make_unique<fe[]>(n_valid_leaf_points * 2); |
| 229 | + std::unique_ptr<fe[]> batch_inv_res = std::make_unique<fe[]>(n_valid_leaf_points * 2); |
| 230 | + |
| 231 | + std::size_t valid_i = 0, batch_i = 0; |
| 232 | + for (std::size_t i = 0; i < valid_outputs.size(); ++i) |
| 233 | + { |
| 234 | + if (!valid_outputs[i]) |
| 235 | + continue; |
| 236 | + |
| 237 | + CHECK_AND_ASSERT_THROW_MES(n_valid_leaf_points > valid_i, "unexpected valid_i"); |
| 238 | + |
| 239 | + auto &pl = pre_leaves.at(i); |
| 240 | + |
| 241 | + auto &O_derivatives = pl.O_derivatives; |
| 242 | + auto &I_derivatives = pl.I_derivatives; |
| 243 | + auto &C_derivatives = pl.C_derivatives; |
| 244 | + |
| 245 | + static_assert(LEAF_TUPLE_POINTS == 3, "unexpected n leaf tuple points"); |
| 246 | + |
| 247 | + // TODO: avoid copying underlying (tried using pointer to pointers, but wasn't clean) |
| 248 | + memcpy(&one_plus_y_vec[valid_i++], &O_derivatives.one_plus_y, sizeof(fe)); |
| 249 | + memcpy(&one_plus_y_vec[valid_i++], &I_derivatives.one_plus_y, sizeof(fe)); |
| 250 | + memcpy(&one_plus_y_vec[valid_i++], &C_derivatives.one_plus_y, sizeof(fe)); |
| 251 | + |
| 252 | + memcpy(&fe_batch[batch_i++], &O_derivatives.one_minus_y, sizeof(fe)); |
| 253 | + memcpy(&fe_batch[batch_i++], &O_derivatives.one_minus_y_mul_x, sizeof(fe)); |
| 254 | + |
| 255 | + memcpy(&fe_batch[batch_i++], &I_derivatives.one_minus_y, sizeof(fe)); |
| 256 | + memcpy(&fe_batch[batch_i++], &I_derivatives.one_minus_y_mul_x, sizeof(fe)); |
| 257 | + |
| 258 | + memcpy(&fe_batch[batch_i++], &C_derivatives.one_minus_y, sizeof(fe)); |
| 259 | + memcpy(&fe_batch[batch_i++], &C_derivatives.one_minus_y_mul_x, sizeof(fe)); |
| 260 | + } |
| 261 | + |
| 262 | + CHECK_AND_ASSERT_THROW_MES(n_valid_leaf_points == valid_i, "unexpected end valid_i"); |
| 263 | + CHECK_AND_ASSERT_THROW_MES((n_valid_leaf_points * 2) == batch_i, "unexpected end batch_i"); |
| 264 | + TIME_MEASURE_FINISH(collect_derivatives); |
| 265 | + |
| 266 | + TIME_MEASURE_START(batch_invert); |
| 267 | + // Step 3. Get batch inverse of all valid (1-y)'s and ((1-y)*x)'s |
| 268 | + // - Batch inversion is significantly faster than inverting 1 at a time |
| 269 | + fe_batch_invert(batch_inv_res.get(), fe_batch.get(), n_valid_leaf_points * 2); |
| 270 | + TIME_MEASURE_FINISH(batch_invert); |
| 271 | + |
| 272 | + TIME_MEASURE_START(get_selene_scalars); |
| 273 | + // Step 4. Multithreaded get Wei coordinates and convert to Selene scalars |
| 274 | + const std::size_t n_valid_leaf_elems = n_valid_outputs * LEAF_TUPLE_SIZE; |
| 275 | + flattened_leaves_out.resize(n_valid_leaf_elems); |
| 276 | + CHECK_AND_ASSERT_THROW_MES(flattened_leaves_out.size() == (2 * n_valid_leaf_points), |
| 277 | + "unexpected size of flattened leaves"); |
| 278 | + |
| 279 | + const std::size_t DERIVATION_BATCH_SIZE = std::max<std::size_t>(1, (n_valid_leaf_points / n_threads)); |
| 280 | + for (std::size_t i = 0; i < n_valid_leaf_points; i += DERIVATION_BATCH_SIZE) |
| 281 | + { |
| 282 | + const std::size_t end = std::min(n_valid_leaf_points, i + DERIVATION_BATCH_SIZE); |
| 283 | + tpool.submit(&waiter, |
| 284 | + [ |
| 285 | + &batch_inv_res, |
| 286 | + &one_plus_y_vec, |
| 287 | + &flattened_leaves_out, |
| 288 | + i, |
| 289 | + end |
| 290 | + ]() |
| 291 | + { |
| 292 | + std::size_t point_idx = i * 2; |
| 293 | + for (std::size_t j = i; j < end; ++j) |
| 294 | + { |
| 295 | + crypto::ec_coord wei_x; |
| 296 | + crypto::ec_coord wei_y; |
| 297 | + fe_ed_derivatives_to_wei_x_y( |
| 298 | + to_bytes(wei_x), |
| 299 | + to_bytes(wei_y), |
| 300 | + batch_inv_res[point_idx]/*inv_one_minus_y*/, |
| 301 | + one_plus_y_vec[j], |
| 302 | + batch_inv_res[point_idx+1]/*inv_one_minus_y_mul_x*/ |
| 303 | + ); |
| 304 | + |
| 305 | + flattened_leaves_out[point_idx++] = tower_cycle::selene_scalar_from_bytes(wei_x); |
| 306 | + flattened_leaves_out[point_idx++] = tower_cycle::selene_scalar_from_bytes(wei_y); |
| 307 | + } |
| 308 | + }, |
| 309 | + true |
| 310 | + ); |
| 311 | + } |
| 312 | + |
| 313 | + CHECK_AND_ASSERT_THROW_MES(waiter.wait(), "failed to convert outputs to wei coords"); |
| 314 | + TIME_MEASURE_FINISH(get_selene_scalars); |
| 315 | + |
| 316 | + // Step 5. Set valid tuples to be stored in the db |
| 317 | + valid_outputs_out.reserve(n_valid_outputs); |
| 318 | + for (std::size_t i = 0; i < valid_outputs.size(); ++i) |
| 319 | + { |
| 320 | + if (!valid_outputs[i]) |
| 321 | + continue; |
| 322 | + |
| 323 | + // We can derive leaf tuples from output pairs, so we store just the unified output in the db to save 32 bytes |
| 324 | + valid_outputs_out.emplace_back(std::move(new_outputs.at(i))); |
| 325 | + } |
| 326 | + CHECK_AND_ASSERT_THROW_MES(valid_outputs_out.size() == n_valid_outputs, "unexpected size of valid_outputs_out"); |
| 327 | + |
| 328 | + TIME_MEASURE_FINISH(set_valid_leaves); |
| 329 | + |
| 330 | + m_convert_valid_leaves_ms += convert_valid_leaves; |
| 331 | + m_collect_derivatives_ms += collect_derivatives; |
| 332 | + m_batch_invert_ms += batch_invert; |
| 333 | + m_get_selene_scalars_ms += get_selene_scalars; |
| 334 | + |
| 335 | + m_set_valid_leaves_ms += set_valid_leaves; |
| 336 | + |
| 337 | + LOG_PRINT_L2("Total time spent setting leaves: " << m_set_valid_leaves_ms / 1000 |
| 338 | + << " , converting valid leaves: " << m_convert_valid_leaves_ms / 1000 |
| 339 | + << " , collecting derivatives: " << m_collect_derivatives_ms / 1000 |
| 340 | + << " , batch invert: " << m_batch_invert_ms / 1000 |
| 341 | + << " , get selene scalars: " << m_get_selene_scalars_ms / 1000); |
| 342 | +} |
| 343 | +//---------------------------------------------------------------------------------------------------------------------- |
| 344 | +//---------------------------------------------------------------------------------------------------------------------- |
149 | 345 | } //namespace curve_trees |
150 | 346 | } //namespace fcmp_pp |
0 commit comments