Skip to content

Commit c16d288

Browse files
committed
Optimize insertion in dynamic_bitset
1 parent deae754 commit c16d288

File tree

1 file changed

+252
-47
lines changed

1 file changed

+252
-47
lines changed

include/sparrow/buffer/dynamic_bitset/dynamic_bitset_base.hpp

Lines changed: 252 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
#include <algorithm>
1919
#include <bit>
20+
#include <iterator>
2021
#include <stdexcept>
2122
#include <string>
2223
#include <type_traits>
24+
#include <vector>
2325

2426
#include "sparrow/buffer/dynamic_bitset/bitset_iterator.hpp"
2527
#include "sparrow/buffer/dynamic_bitset/bitset_reference.hpp"
@@ -537,6 +539,13 @@ namespace sparrow
537539
*/
538540
constexpr void update_null_count(bool old_value, bool new_value);
539541

542+
// Efficient bit manipulation helpers for insert operations
543+
constexpr void shift_bits_right(size_type start_pos, size_type bit_count, size_type shift_amount);
544+
constexpr void fill_bits_range(size_type start_pos, size_type bit_count, value_type value);
545+
template <std::random_access_iterator InputIt>
546+
constexpr iterator
547+
insert_range_random_access(const_iterator pos, InputIt first, size_type count, size_type index);
548+
540549
storage_type m_buffer; ///< The underlying storage for bit data
541550
size_type m_size; ///< The number of bits in the bitset
542551
size_type m_null_count; ///< The number of bits set to false
@@ -968,30 +977,31 @@ namespace sparrow
968977
SPARROW_ASSERT_TRUE(cbegin() <= pos);
969978
SPARROW_ASSERT_TRUE(pos <= cend());
970979
const auto index = static_cast<size_type>(std::distance(cbegin(), pos));
980+
971981
if (data() == nullptr && value)
972982
{
973983
m_size += count;
984+
return iterator(this, index);
974985
}
975-
else
976-
{
977-
const size_type old_size = size();
978-
const size_type new_size = old_size + count;
979986

980-
// TODO: The current implementation is not efficient. It can be improved.
987+
if (count == 0)
988+
{
989+
return iterator(this, index);
990+
}
981991

982-
resize(new_size);
992+
const size_type old_size = size();
993+
const size_type new_size = old_size + count;
994+
const size_type bits_to_move = old_size - index;
983995

984-
for (size_type i = old_size + count - 1; i >= index + count; --i)
985-
{
986-
set(i, test(i - count));
987-
}
996+
resize(new_size);
988997

989-
for (size_type i = 0; i < count; ++i)
990-
{
991-
set(index + i, value);
992-
}
998+
if (bits_to_move > 0)
999+
{
1000+
shift_bits_right(index, bits_to_move, count);
9931001
}
9941002

1003+
fill_bits_range(index, count, value);
1004+
9951005
return iterator(this, index);
9961006
}
9971007

@@ -1002,43 +1012,19 @@ namespace sparrow
10021012
dynamic_bitset_base<B>::insert(const_iterator pos, InputIt first, InputIt last)
10031013
{
10041014
const auto index = static_cast<size_type>(std::distance(cbegin(), pos));
1005-
const auto count = static_cast<size_type>(std::distance(first, last));
1006-
if (data() == nullptr)
1007-
{
1008-
if (std::all_of(
1009-
first,
1010-
last,
1011-
[](auto v)
1012-
{
1013-
return bool(v);
1014-
}
1015-
))
1016-
{
1017-
m_size += count;
1018-
}
1019-
return iterator(this, index);
1020-
}
1021-
SPARROW_ASSERT_TRUE(cbegin() <= pos);
1022-
SPARROW_ASSERT_TRUE(pos <= cend());
1023-
1024-
const size_type old_size = size();
1025-
const size_type new_size = old_size + count;
10261015

1027-
resize(new_size);
1028-
1029-
// TODO: The current implementation is not efficient. It can be improved.
1030-
1031-
for (size_type i = old_size + count - 1; i >= index + count; --i)
1016+
if constexpr (std::random_access_iterator<InputIt>)
10321017
{
1033-
set(i, test(i - count));
1018+
// Fast path for random access iterators
1019+
const auto count = static_cast<size_type>(std::distance(first, last));
1020+
return insert_range_random_access(pos, first, count, index);
10341021
}
1035-
1036-
for (size_type i = 0; i < count; ++i)
1022+
else
10371023
{
1038-
set(index + i, *first++);
1024+
// Slower path for input iterators - collect values first
1025+
std::vector<value_type> values(first, last);
1026+
return insert_range_random_access(pos, values.begin(), values.size(), index);
10391027
}
1040-
1041-
return iterator(this, index);
10421028
}
10431029

10441030
template <typename B>
@@ -1087,7 +1073,7 @@ namespace sparrow
10871073
// TODO: The current implementation is not efficient. It can be improved.
10881074

10891075
const size_type bit_to_move = size() - last_index;
1090-
for (size_type i = 0; i < bit_to_move; ++i)
1076+
for (size_t i = 0; i < bit_to_move; ++i)
10911077
{
10921078
set(first_index + i, test(last_index + i));
10931079
}
@@ -1114,4 +1100,223 @@ namespace sparrow
11141100
}
11151101
resize(size() - 1);
11161102
}
1103+
1104+
// Efficient helper functions for insert operations
1105+
1106+
template <typename B>
1107+
requires std::ranges::random_access_range<std::remove_pointer_t<B>>
1108+
constexpr void
1109+
dynamic_bitset_base<B>::shift_bits_right(size_type start_pos, size_type bit_count, size_type shift_amount)
1110+
{
1111+
if (bit_count == 0 || shift_amount == 0 || data() == nullptr)
1112+
{
1113+
return;
1114+
}
1115+
1116+
const size_type end_pos = start_pos + bit_count;
1117+
1118+
// Calculate block boundaries
1119+
const size_type start_block = block_index(start_pos);
1120+
const size_type end_block = block_index(end_pos - 1);
1121+
const size_type target_start_block = block_index(start_pos + shift_amount);
1122+
const size_type target_end_block = block_index(end_pos + shift_amount - 1);
1123+
1124+
// If the shift spans multiple blocks, use block-level operations
1125+
if (shift_amount >= s_bits_per_block && start_block != end_block)
1126+
{
1127+
const size_type block_shift = shift_amount / s_bits_per_block;
1128+
const size_type bit_shift = shift_amount % s_bits_per_block;
1129+
1130+
// Move whole blocks first
1131+
for (size_type i = end_block; i >= start_block && i != SIZE_MAX; --i)
1132+
{
1133+
const size_type target_block = i + block_shift;
1134+
if (target_block < buffer().size())
1135+
{
1136+
buffer().data()[target_block] = buffer().data()[i];
1137+
}
1138+
}
1139+
1140+
// Handle remaining bit shift within blocks
1141+
if (bit_shift > 0)
1142+
{
1143+
for (size_type i = target_end_block; i > target_start_block && i != SIZE_MAX; --i)
1144+
{
1145+
const block_type current = buffer().data()[i];
1146+
const block_type previous = (i > 0) ? buffer().data()[i - 1] : block_type(0);
1147+
buffer().data()[i] = static_cast<block_type>(
1148+
(current << bit_shift) | (previous >> (s_bits_per_block - bit_shift))
1149+
);
1150+
}
1151+
if (target_start_block < buffer().size())
1152+
{
1153+
buffer().data()[target_start_block] = static_cast<block_type>(
1154+
buffer().data()[target_start_block] << bit_shift
1155+
);
1156+
}
1157+
}
1158+
}
1159+
else
1160+
{
1161+
// For smaller shifts, use bit-level operations optimized for the shift amount
1162+
for (size_type i = bit_count; i > 0; --i)
1163+
{
1164+
const size_t src_pos = start_pos + i - 1;
1165+
const size_t dst_pos = src_pos + shift_amount;
1166+
set(dst_pos, test(src_pos));
1167+
}
1168+
}
1169+
}
1170+
1171+
template <typename B>
1172+
requires std::ranges::random_access_range<std::remove_pointer_t<B>>
1173+
constexpr void
1174+
dynamic_bitset_base<B>::fill_bits_range(size_type start_pos, size_type bit_count, value_type value)
1175+
{
1176+
if (bit_count == 0 || data() == nullptr)
1177+
{
1178+
return;
1179+
}
1180+
1181+
const size_type end_pos = start_pos + bit_count;
1182+
const size_type start_block = block_index(start_pos);
1183+
const size_type end_block = block_index(end_pos - 1);
1184+
1185+
const block_type fill_value = value ? block_type(~block_type(0)) : block_type(0);
1186+
1187+
if (start_block == end_block)
1188+
{
1189+
// All bits are in the same block - use efficient bit masking
1190+
const size_type start_bit = bit_index(start_pos);
1191+
const size_type end_bit = bit_index(end_pos - 1);
1192+
const size_type mask_width = end_bit - start_bit + 1;
1193+
const block_type mask = static_cast<block_type>(((block_type(1) << mask_width) - 1) << start_bit);
1194+
1195+
if (value)
1196+
{
1197+
buffer().data()[start_block] |= mask;
1198+
}
1199+
else
1200+
{
1201+
buffer().data()[start_block] &= ~mask;
1202+
}
1203+
}
1204+
else
1205+
{
1206+
// Handle first partial block
1207+
const size_type start_bit = bit_index(start_pos);
1208+
if (start_bit != 0)
1209+
{
1210+
const block_type mask = static_cast<block_type>(~block_type(0) << start_bit);
1211+
if (value)
1212+
{
1213+
buffer().data()[start_block] |= mask;
1214+
}
1215+
else
1216+
{
1217+
buffer().data()[start_block] &= ~mask;
1218+
}
1219+
}
1220+
else
1221+
{
1222+
buffer().data()[start_block] = fill_value;
1223+
}
1224+
1225+
// Handle full blocks in between
1226+
for (size_type block = start_block + 1; block < end_block; ++block)
1227+
{
1228+
buffer().data()[block] = fill_value;
1229+
}
1230+
1231+
// Handle last partial block
1232+
const size_type end_bit = bit_index(end_pos - 1);
1233+
const block_type mask = static_cast<block_type>((block_type(1) << (end_bit + 1)) - 1);
1234+
if (value)
1235+
{
1236+
buffer().data()[end_block] |= mask;
1237+
}
1238+
else
1239+
{
1240+
buffer().data()[end_block] &= ~mask;
1241+
}
1242+
}
1243+
1244+
m_null_count = m_size - count_non_null();
1245+
}
1246+
1247+
template <typename B>
1248+
requires std::ranges::random_access_range<std::remove_pointer_t<B>>
1249+
template <std::random_access_iterator InputIt>
1250+
constexpr auto dynamic_bitset_base<B>::insert_range_random_access(
1251+
const_iterator /* pos */,
1252+
InputIt first,
1253+
size_type count,
1254+
size_type index
1255+
) -> iterator
1256+
{
1257+
if (data() == nullptr)
1258+
{
1259+
if (std::all_of(
1260+
first,
1261+
std::next(first, static_cast<std::ptrdiff_t>(count)),
1262+
[](auto v)
1263+
{
1264+
return bool(v);
1265+
}
1266+
))
1267+
{
1268+
m_size += count;
1269+
}
1270+
return iterator(this, index);
1271+
}
1272+
1273+
if (count == 0)
1274+
{
1275+
return iterator(this, index);
1276+
}
1277+
1278+
const size_type old_size = size();
1279+
const size_type new_size = old_size + count;
1280+
const size_type bits_to_move = old_size - index;
1281+
1282+
resize(new_size);
1283+
1284+
if (bits_to_move > 0)
1285+
{
1286+
shift_bits_right(index, bits_to_move, count);
1287+
}
1288+
1289+
// Set bits efficiently in batches
1290+
constexpr size_type batch_size = s_bits_per_block;
1291+
for (size_type i = 0; i < count; i += batch_size)
1292+
{
1293+
const size_type current_batch_size = std::min(batch_size, count - i);
1294+
const size_type batch_start = index + i;
1295+
1296+
// Process bits in the current batch
1297+
if (current_batch_size == s_bits_per_block && bit_index(batch_start) == 0)
1298+
{
1299+
// Optimized path: entire block can be set at once
1300+
block_type block_value = 0;
1301+
for (size_type j = 0; j < s_bits_per_block; ++j)
1302+
{
1303+
if (bool(*(std::next(first, static_cast<std::ptrdiff_t>(i + j)))))
1304+
{
1305+
block_value |= static_cast<block_type>(block_type(1) << j);
1306+
}
1307+
}
1308+
buffer().data()[block_index(batch_start)] = block_value;
1309+
}
1310+
else
1311+
{
1312+
// Fallback to bit-by-bit setting for partial blocks
1313+
for (size_type j = 0; j < current_batch_size; ++j)
1314+
{
1315+
set(batch_start + j, bool(*(std::next(first, static_cast<std::ptrdiff_t>(i + j)))));
1316+
}
1317+
}
1318+
}
1319+
1320+
return iterator(this, index);
1321+
}
11171322
}

0 commit comments

Comments
 (0)