|
| 1 | +/** |
| 2 | + * @file |
| 3 | + * @author Michal Sedlak <[email protected]> |
| 4 | + * @brief Threshold algorithm implementation |
| 5 | + * |
| 6 | + * Copyright: (C) 2024 CESNET, z.s.p.o. |
| 7 | + * SPDX-License-Identifier: BSD-3-Clause |
| 8 | + */ |
| 9 | + |
| 10 | +#include <aggregator/thresholdAlgorithm.hpp> |
| 11 | +#include <aggregator/aggregatedField.hpp> |
| 12 | +#include <common/logger.hpp> |
| 13 | + |
| 14 | +namespace fdsdump { |
| 15 | +namespace aggregator { |
| 16 | + |
| 17 | + |
| 18 | +static void min_aggregate_value(DataType data_type, Value& value, const Value &other) |
| 19 | +{ |
| 20 | + switch (data_type) { |
| 21 | + case DataType::Unsigned8: |
| 22 | + value.u8 = std::min<uint8_t>(other.u8, value.u8); |
| 23 | + break; |
| 24 | + case DataType::Unsigned16: |
| 25 | + value.u16 = std::min<uint16_t>(other.u16, value.u16); |
| 26 | + break; |
| 27 | + case DataType::Unsigned32: |
| 28 | + value.u32 = std::min<uint32_t>(other.u32, value.u32); |
| 29 | + break; |
| 30 | + case DataType::Unsigned64: |
| 31 | + value.u64 = std::min<uint64_t>(other.u64, value.u64); |
| 32 | + break; |
| 33 | + case DataType::Signed8: |
| 34 | + value.i8 = std::min<int8_t>(other.i8, value.i8); |
| 35 | + break; |
| 36 | + case DataType::Signed16: |
| 37 | + value.i16 = std::min<int16_t>(other.i16, value.i16); |
| 38 | + break; |
| 39 | + case DataType::Signed32: |
| 40 | + value.i32 = std::min<int32_t>(other.i32, value.i32); |
| 41 | + break; |
| 42 | + case DataType::Signed64: |
| 43 | + value.i64 = std::min<int64_t>(other.i64, value.i64); |
| 44 | + break; |
| 45 | + case DataType::DateTime: |
| 46 | + value.ts_millisecs = std::min<uint64_t>(other.ts_millisecs, value.ts_millisecs); |
| 47 | + break; |
| 48 | + default: |
| 49 | + assert(0); |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +void max_aggregate_value(DataType data_type, Value &value, Value &other) |
| 54 | +{ |
| 55 | + switch (data_type) { |
| 56 | + case DataType::Unsigned8: |
| 57 | + value.u8 = std::max<uint8_t>(other.u8, value.u8); |
| 58 | + break; |
| 59 | + case DataType::Unsigned16: |
| 60 | + value.u16 = std::max<uint16_t>(other.u16, value.u16); |
| 61 | + break; |
| 62 | + case DataType::Unsigned32: |
| 63 | + value.u32 = std::max<uint32_t>(other.u32, value.u32); |
| 64 | + break; |
| 65 | + case DataType::Unsigned64: |
| 66 | + value.u64 = std::max<uint64_t>(other.u64, value.u64); |
| 67 | + break; |
| 68 | + case DataType::Signed8: |
| 69 | + value.i8 = std::max<int8_t>(other.i8, value.i8); |
| 70 | + break; |
| 71 | + case DataType::Signed16: |
| 72 | + value.i16 = std::max<int16_t>(other.i16, value.i16); |
| 73 | + break; |
| 74 | + case DataType::Signed32: |
| 75 | + value.i32 = std::max<int32_t>(other.i32, value.i32); |
| 76 | + break; |
| 77 | + case DataType::Signed64: |
| 78 | + value.i64 = std::max<int64_t>(other.i64, value.i64); |
| 79 | + break; |
| 80 | + case DataType::DateTime: |
| 81 | + value.ts_millisecs = std::max<uint64_t>(other.ts_millisecs, value.ts_millisecs); |
| 82 | + break; |
| 83 | + default: |
| 84 | + assert(0); |
| 85 | + } |
| 86 | +} |
| 87 | + |
| 88 | +void sum_aggregate_value(DataType data_type, Value &value, const Value &other) |
| 89 | +{ |
| 90 | + switch (data_type) { |
| 91 | + case DataType::Unsigned64: |
| 92 | + value.u64 += other.u64; |
| 93 | + break; |
| 94 | + case DataType::Signed64: |
| 95 | + value.i64 += other.i64; |
| 96 | + break; |
| 97 | + default: assert(0); |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +static std::vector<uint8_t> estabilish_threshold(std::vector<HashTable *> &tables, View &view, unsigned int row) |
| 102 | +{ |
| 103 | + std::vector<uint8_t> buffer; |
| 104 | + |
| 105 | + if (view.is_fixed_size()) { |
| 106 | + for (auto *table : tables) { |
| 107 | + if (table->items().size() <= row) { |
| 108 | + continue; |
| 109 | + } |
| 110 | + |
| 111 | + uint8_t *rec = table->items()[row]; |
| 112 | + if (buffer.empty()) { |
| 113 | + std::size_t size = view.record_size(rec); |
| 114 | + buffer.resize(size); |
| 115 | + std::memcpy(&buffer[0], rec, size); |
| 116 | + continue; |
| 117 | + } |
| 118 | + |
| 119 | + for (const auto &order_field : view.order_fields()) { |
| 120 | + Value &a = view.access_field(*order_field.field, buffer.data()); |
| 121 | + Value &b = view.access_field(*order_field.field, rec); |
| 122 | + |
| 123 | + switch (order_field.dir) { |
| 124 | + case View::OrderDirection::Ascending: { |
| 125 | + if (order_field.field->is_of_type<MinAggregatedField>()) { |
| 126 | + min_aggregate_value(order_field.field->data_type(), a, b); |
| 127 | + } else if (order_field.field->is_of_type<MaxAggregatedField>()) { |
| 128 | + min_aggregate_value(order_field.field->data_type(), a, b); |
| 129 | + } else if (order_field.field->is_of_type<SumAggregatedField>()) { |
| 130 | + min_aggregate_value(order_field.field->data_type(), a, b); |
| 131 | + } else { |
| 132 | + if (order_field.field->compare(a, b) == CmpResult::Gt) { |
| 133 | + std::memcpy(&a, &b, order_field.field->size()); |
| 134 | + } |
| 135 | + } |
| 136 | + } break; |
| 137 | + |
| 138 | + case View::OrderDirection::Descending: { |
| 139 | + if (order_field.field->is_of_type<MinAggregatedField>()) { |
| 140 | + max_aggregate_value(order_field.field->data_type(), a, b); |
| 141 | + } else if (order_field.field->is_of_type<MaxAggregatedField>()) { |
| 142 | + max_aggregate_value(order_field.field->data_type(), a, b); |
| 143 | + } else if (order_field.field->is_of_type<SumAggregatedField>()) { |
| 144 | + sum_aggregate_value(order_field.field->data_type(), a, b); |
| 145 | + } else { |
| 146 | + if (order_field.field->compare(a, b) == CmpResult::Lt) { |
| 147 | + std::memcpy(&a, &b, order_field.field->size()); |
| 148 | + } |
| 149 | + } |
| 150 | + } break; |
| 151 | + } |
| 152 | + } |
| 153 | + } |
| 154 | + } else { |
| 155 | + //FIXME |
| 156 | + throw std::runtime_error("not implemented"); |
| 157 | + } |
| 158 | + |
| 159 | + return buffer; |
| 160 | +} |
| 161 | + |
| 162 | +ThresholdAlgorithm::ThresholdAlgorithm(std::vector<HashTable *> &tables, View &view, unsigned int top_count) : |
| 163 | + m_result_table(new HashTable(view)), |
| 164 | + m_tables(tables), |
| 165 | + m_view(view), |
| 166 | + m_top_count(top_count), |
| 167 | + m_min_queue(view.rec_orderer()) |
| 168 | +{} |
| 169 | + |
| 170 | +void ThresholdAlgorithm::process_row() |
| 171 | +{ |
| 172 | + for (auto *table : m_tables) { |
| 173 | + if (m_row >= table->items().size()) { |
| 174 | + continue; |
| 175 | + } |
| 176 | + uint8_t *rec = table->items()[m_row]; |
| 177 | + uint8_t *result_rec = nullptr; |
| 178 | + bool found = m_result_table->find_or_create(rec, result_rec); |
| 179 | + if (found) { |
| 180 | + continue; |
| 181 | + } |
| 182 | + // If not found - copy over |
| 183 | + // Key is already copied by find_or_create, so only value needs to be copied |
| 184 | + unsigned int key_size = m_view.key_size(rec); |
| 185 | + unsigned int value_size = m_view.value_size(); |
| 186 | + std::memcpy(result_rec + key_size, rec + key_size, value_size); |
| 187 | + |
| 188 | + for (auto *other_table : m_tables) { |
| 189 | + if (table == other_table) { |
| 190 | + continue; |
| 191 | + } |
| 192 | + if (m_row >= other_table->items().size()) { |
| 193 | + continue; |
| 194 | + } |
| 195 | + uint8_t *other_rec = nullptr; |
| 196 | + bool found = other_table->find(rec, other_rec); |
| 197 | + if (!found) { |
| 198 | + continue; |
| 199 | + } |
| 200 | + for (auto x : m_view.iter_values(result_rec, other_rec)) { |
| 201 | + x.field.merge(x.value1, x.value2); |
| 202 | + } |
| 203 | + } |
| 204 | + m_min_queue.push(result_rec); |
| 205 | + if (m_min_queue.size() > m_top_count) { |
| 206 | + m_min_queue.pop(); |
| 207 | + } |
| 208 | + } |
| 209 | + m_row++; |
| 210 | +} |
| 211 | + |
| 212 | +bool ThresholdAlgorithm::out_of_items() |
| 213 | +{ |
| 214 | + if (m_row >= m_max_row) { |
| 215 | + return true; |
| 216 | + } |
| 217 | + for (auto *table : m_tables) { |
| 218 | + if (table->items().size() > m_row) { |
| 219 | + return false; |
| 220 | + } |
| 221 | + } |
| 222 | + return true; |
| 223 | +} |
| 224 | + |
| 225 | +bool ThresholdAlgorithm::check_finish_condition() |
| 226 | +{ |
| 227 | + if (m_min_queue.size() < m_top_count) { |
| 228 | + return false; |
| 229 | + } |
| 230 | + m_threshold = estabilish_threshold(m_tables, m_view, m_row); |
| 231 | + |
| 232 | + if (m_view.compare(m_threshold.data(), m_min_queue.top()) == CmpResult::Lt) { |
| 233 | + return false; |
| 234 | + } |
| 235 | + return true; |
| 236 | +} |
| 237 | + |
| 238 | +} // aggregator |
| 239 | +} // fdsdump |
0 commit comments