Skip to content

Commit 5ea0698

Browse files
committed
Add: C++ argsort, intersect
1 parent 1ce830b commit 5ea0698

File tree

1 file changed

+133
-57
lines changed

1 file changed

+133
-57
lines changed

include/stringzilla/stringzilla.hpp

Lines changed: 133 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4011,68 +4011,105 @@ void lookup(basic_string_slice<char_type_> string, basic_look_up_table<char_type
40114011
* @brief Internal data-structure used to wrap arbitrary sequential containers with a random-order lookup.
40124012
* @sa try_argsort, argsort, try_join, join
40134013
*/
4014-
template <typename objects_type_, typename string_extractor_>
4014+
template <typename container_type_, typename string_extractor_>
40154015
struct _sequence_args {
4016-
objects_type_ const *begin;
4017-
std::size_t count;
4018-
sorted_idx_t *order;
4019-
string_extractor_ extractor;
4016+
container_type_ const &container;
4017+
string_extractor_ const &extractor;
40204018
};
40214019

4022-
template <typename objects_type_, typename string_extractor_>
4023-
sz_cptr_t _call_sequence_member_start(void const *sequence, sz_size_t i) {
4024-
using handle_type = _sequence_args<objects_type_, string_extractor_>;
4025-
handle_type const *args = reinterpret_cast<handle_type const *>(sequence);
4026-
string_view member = args->extractor(args->begin[i]);
4020+
template <typename container_type_, typename string_extractor_>
4021+
sz_cptr_t _call_sequence_member_start(void const *sequence_args_ptr, sz_size_t i) {
4022+
using sequence_args_t = _sequence_args<container_type_, string_extractor_>;
4023+
sequence_args_t const *args = reinterpret_cast<sequence_args_t const *>(sequence_args_ptr);
4024+
string_view member = args->extractor(args->container[i]);
40274025
return member.data();
40284026
}
40294027

4030-
template <typename objects_type_, typename string_extractor_>
4031-
sz_size_t _call_sequence_member_length(void const *sequence, sz_size_t i) {
4032-
using handle_type = _sequence_args<objects_type_, string_extractor_>;
4033-
handle_type const *args = reinterpret_cast<handle_type const *>(sequence);
4034-
string_view member = args->extractor(args->begin[i]);
4028+
template <typename container_type_, typename string_extractor_>
4029+
sz_size_t _call_sequence_member_length(void const *sequence_args_ptr, sz_size_t i) {
4030+
using sequence_args_t = _sequence_args<container_type_, string_extractor_>;
4031+
sequence_args_t const *args = reinterpret_cast<sequence_args_t const *>(sequence_args_ptr);
4032+
string_view member = args->extractor(args->container[i]);
40354033
return static_cast<sz_size_t>(member.size());
40364034
}
40374035

40384036
/**
40394037
* @brief Computes the permutation of an array, that would lead to sorted order.
40404038
* The elements of the array must be convertible to a `string_view` with the given extractor.
40414039
* Unlike the `sz_sequence_argsort` C interface, overwrites the output array.
4040+
* @sa sz_sequence_argsort
40424041
*
4043-
* @param[in] begin The pointer to the first element of the array.
4044-
* @param[in] end The pointer to the element after the last element of the array.
4045-
* @param[out] order The pointer to the output array of indices, that will be populated with the permutation.
4046-
* @param[in] extractor The function object that extracts the string from the object.
4047-
*
4048-
* @see sz_sequence_argsort
4042+
* @param[in] container The container of objects to sort; accessed through `size()` and `operator[]`.
4043+
*                      The container is only referenced by the callbacks, never copied.
4044+
* @param[in] extractor The function object that extracts the string from the object.
4045+
* @param[out] order The pointer to the output array of indices, that will be populated with the permutation.
40494046
*/
4050-
template <typename objects_type_, typename string_extractor_>
4051-
void argsort(objects_type_ const *begin, objects_type_ const *end, sorted_idx_t *order,
4052-
string_extractor_ &&extractor) noexcept {
4047+
template <typename container_type_, typename string_extractor_>
4048+
status_t try_argsort(container_type_ const &container, string_extractor_ const &extractor,
4049+
sorted_idx_t *order) noexcept {
40534050

40544051
// Pack the arguments into a single structure to reference it from the callback.
4055-
_sequence_args<objects_type_, string_extractor_> args = {begin, static_cast<std::size_t>(end - begin), order,
4056-
std::forward<string_extractor_>(extractor)};
4057-
// Populate the array with `iota`-style order.
4058-
for (std::size_t i = 0; i != args.count; ++i) order[i] = static_cast<sorted_idx_t>(i);
4052+
using args_t = _sequence_args<container_type_, string_extractor_>;
4053+
args_t args {container, extractor};
4054+
sz_sequence_t sequence;
4055+
sequence.handle = &args;
4056+
sequence.count = container.size();
4057+
sequence.get_start = _call_sequence_member_start<container_type_, string_extractor_>;
4058+
sequence.get_length = _call_sequence_member_length<container_type_, string_extractor_>;
40594059

4060-
sz_sequence_t array;
4061-
array.count = args.count;
4062-
array.handle = &args;
4063-
array.get_start = _call_sequence_member_start<objects_type_, string_extractor_>;
4064-
array.get_length = _call_sequence_member_length<objects_type_, string_extractor_>;
4060+
using sz_alloc_type = sz_memory_allocator_t;
4061+
return _with_alloc<std::allocator<sz_u8_t>>(
4062+
[&](sz_alloc_type &alloc) { return sz_sequence_argsort(&sequence, &alloc, order); });
4063+
}
4064+
4065+
/**
4066+
* @brief Locates the positions of the elements in 2 deduplicated string arrays that have identical values.
4067+
* @sa sz_sequence_join
4068+
*
4069+
* @param[in] first_container The first container of objects; accessed through `size()` and `operator[]`.
4070+
* @param[in] second_container The second container of objects; accessed through `size()` and `operator[]`.
4071+
* @param[in] seed The seed value forwarded to `sz_sequence_intersect`.
4072+
* @param[out] intersection_size_ptr The pointer to the variable receiving the number of matches found.
4073+
* @param[out] first_positions The pointer to the output array of indices from the first array.
4074+
* @param[out] second_positions The pointer to the output array of indices from the second array.
4075+
* @param[in] first_extractor The function object that extracts the string from the object in the first array.
4076+
* @param[in] second_extractor The function object that extracts the string from the object in the second array.
4077+
*/
4078+
template <typename first_container_, typename second_container_, typename first_extractor_, typename second_extractor_>
4079+
status_t try_intersect( //
4080+
first_container_ const &first_container, first_extractor_ const &first_extractor, //
4081+
second_container_ const &second_container, second_extractor_ const &second_extractor, //
4082+
std::uint64_t seed, std::size_t *intersection_size_ptr, //
4083+
sorted_idx_t *first_positions, sorted_idx_t *second_positions) noexcept {
4084+
4085+
// Pack the arguments into a single structure to reference it from the callback.
4086+
using first_t = _sequence_args<first_container_, first_extractor_>;
4087+
using second_t = _sequence_args<second_container_, second_extractor_>;
4088+
first_t first_args {first_container, first_extractor};
4089+
second_t second_args {second_container, second_extractor};
4090+
4091+
sz_sequence_t first_sequence, second_sequence;
4092+
first_sequence.count = first_container.size(), second_sequence.count = second_container.size();
4093+
first_sequence.handle = &first_args, second_sequence.handle = &second_args;
4094+
first_sequence.get_start = _call_sequence_member_start<first_container_, first_extractor_>;
4095+
first_sequence.get_length = _call_sequence_member_length<first_container_, first_extractor_>;
4096+
second_sequence.get_start = _call_sequence_member_start<second_container_, second_extractor_>;
4097+
second_sequence.get_length = _call_sequence_member_length<second_container_, second_extractor_>;
40654098

40664099
using sz_alloc_type = sz_memory_allocator_t;
4067-
_with_alloc<std::allocator<sz_u8_t>>(
4068-
[&](sz_alloc_type &alloc) { return sz_sequence_argsort(&array, &alloc, order); });
4100+
return _with_alloc<std::allocator<sz_u8_t>>([&](sz_alloc_type &alloc) {
4101+
static_assert(sizeof(sz_size_t) == sizeof(std::size_t), "sz_size_t must be the same size as std::size_t.");
4102+
return sz_sequence_intersect(&first_sequence, &second_sequence, &alloc, static_cast<sz_u64_t>(seed),
4103+
reinterpret_cast<sz_size_t *>(intersection_size_ptr), first_positions,
4104+
second_positions);
4105+
});
40694106
}
40704107

40714108
#if !SZ_AVOID_STL
40724109
#if _SZ_DEPRECATED_FINGERPRINTS
40734110
/**
4074-
* @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string.
4075-
* @see sz_hashes
4111+
* @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string.
4112+
* @sa sz_hashes
40764113
*/
40774114
template <std::size_t bitset_bits_, typename char_type_>
40784115
void hashes_fingerprint( //
@@ -4105,41 +4142,80 @@ std::bitset<bitset_bits_> hashes_fingerprint(basic_string<char_type_> const &str
41054142
#endif
41064143

41074144
/**
4108-
* @brief Computes the permutation of an array, that would lead to sorted order.
4145+
* @brief Computes the permutation of an array, that would lead to sorted order.
41094146
* @return The array of indices, that will be populated with the permutation.
4110-
* @throw `std::bad_alloc` if the allocation fails.
4147+
* @throw `std::bad_alloc` if the allocation fails.
41114148
*/
4112-
template <typename objects_type_, typename string_extractor_>
4149+
template <typename container_type_, typename string_extractor_>
41134150
std::vector<sorted_idx_t> argsort( //
4114-
objects_type_ const *begin, objects_type_ const *end, string_extractor_ &&extractor) noexcept(false) {
4115-
std::vector<sorted_idx_t> order(end - begin);
4116-
argsort(begin, end, order.data(), std::forward<string_extractor_>(extractor));
4151+
container_type_ const &container, string_extractor_ const &extractor) noexcept(false) {
4152+
std::vector<sorted_idx_t> order(container.size());
4153+
status_t status = try_argsort(container, extractor, order.data());
4154+
raise(status);
41174155
return order;
41184156
}
41194157

41204158
/**
4121-
* @brief Computes the permutation of an array, that would lead to sorted order.
4159+
* @brief Computes the permutation of an array, that would lead to sorted order.
41224160
* @return The array of indices, that will be populated with the permutation.
4123-
* @throw `std::bad_alloc` if the allocation fails.
4161+
* @throw `std::bad_alloc` if the allocation fails.
41244162
*/
4125-
template <typename string_like_type_>
4126-
std::vector<sorted_idx_t> argsort(string_like_type_ const *begin, string_like_type_ const *end) noexcept(false) {
4163+
template <typename container_type_>
4164+
std::vector<sorted_idx_t> argsort(container_type_ const &container) noexcept(false) {
4165+
using string_like_type = typename container_type_::value_type;
41274166
static_assert( //
4128-
std::is_convertible<string_like_type_, string_view>::value, "The type must be convertible to string_view.");
4129-
return argsort(begin, end, [](string_like_type_ const &s) -> string_view { return s; });
4167+
std::is_convertible<string_like_type, string_view>::value, "The type must be convertible to string_view.");
4168+
return argsort(container, [](string_like_type const &s) -> string_view { return s; });
41304169
}
41314170

4171+
struct intersect_result_t {
4172+
std::vector<std::size_t> first_offsets;
4173+
std::vector<std::size_t> second_offsets;
4174+
};
4175+
41324176
/**
4133-
* @brief Computes the permutation of an array, that would lead to sorted order.
4134-
* @return The array of indices, that will be populated with the permutation.
4135-
* @throw `std::bad_alloc` if the allocation fails.
4177+
* @brief Locates identical elements in two arrays.
4178+
* @return Two arrays of indices, mapping the elements of the first and the second array that have identical values.
4179+
* @throw `std::bad_alloc` if the allocation fails.
41364180
*/
4137-
template <typename string_like_type_>
4138-
std::vector<sorted_idx_t> argsort(std::vector<string_like_type_> const &array) noexcept(false) {
4181+
template <typename first_type_, typename second_type_, typename first_extractor_, typename second_extractor_>
4182+
intersect_result_t intersect(first_type_ const &first, second_type_ const &second,
4183+
first_extractor_ const &first_extractor, second_extractor_ const &second_extractor,
4184+
std::uint64_t seed = 0) noexcept(false) {
4185+
4186+
std::size_t const max_count = (std::min)(first.size(), second.size());
4187+
std::vector<sorted_idx_t> first_positions(max_count);
4188+
std::vector<sorted_idx_t> second_positions(max_count);
4189+
std::size_t count;
4190+
status_t status = try_intersect( //
4191+
first, first_extractor, //
4192+
second, second_extractor, //
4193+
seed, &count, first_positions.data(), second_positions.data());
4194+
raise(status);
4195+
first_positions.resize(count);
4196+
second_positions.resize(count);
4197+
return {std::move(first_positions), std::move(second_positions)};
4198+
}
4199+
4200+
/**
4201+
* @brief Locates identical elements in two arrays.
4202+
* @return Two arrays of indices, mapping the elements of the first and the second array that have identical values.
4203+
* @throw `std::bad_alloc` if the allocation fails.
4204+
*/
4205+
template <typename first_type_, typename second_type_>
4206+
intersect_result_t intersect(first_type_ const &first, second_type_ const &second,
4207+
std::uint64_t seed = 0) noexcept(false) {
4208+
using first_string_type = typename first_type_::value_type;
4209+
using second_string_type = typename second_type_::value_type;
4210+
static_assert( //
4211+
std::is_convertible<first_string_type, string_view>::value, "The type must be convertible to string_view.");
41394212
static_assert( //
4140-
std::is_convertible<string_like_type_, string_view>::value, "The type must be convertible to string_view.");
4141-
return argsort(array.data(), array.data() + array.size(),
4142-
[](string_like_type_ const &s) -> string_view { return s; });
4213+
std::is_convertible<second_string_type, string_view>::value, "The type must be convertible to string_view.");
4214+
return intersect(
4215+
first, second, //
4216+
[](first_string_type const &s) -> string_view { return s; }, //
4217+
[](second_string_type const &s) -> string_view { return s; }, //
4218+
seed);
41434219
}
41444220

41454221
#endif

0 commit comments

Comments
 (0)