Skip to content

Commit 64c0c4f

Browse files
Vectorize mismatch for clang-cl for odd element sizes (#5591)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent c95e938 commit 64c0c4f

File tree

4 files changed

+110 -21 lines changed

benchmarks/src/mismatch.cpp

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,24 @@ enum class op {
1919
lexi,
2020
};
2121

22-
template <class T, op Op>
22+
struct color {
23+
uint16_t h;
24+
uint16_t s;
25+
uint16_t l;
26+
27+
bool operator==(const color&) const = default;
28+
};
29+
30+
constexpr color c1{30000, 40000, 20000};
31+
constexpr color c2{30000, 40000, 30000};
32+
33+
template <class T, op Op, T MatchVal = T{'.'}, T MismatchVal = T{'x'}>
2334
void bm(benchmark::State& state) {
24-
vector<T, not_highly_aligned_allocator<T>> a(static_cast<size_t>(state.range(0)), T{'.'});
25-
vector<T, not_highly_aligned_allocator<T>> b(static_cast<size_t>(state.range(0)), T{'.'});
35+
vector<T, not_highly_aligned_allocator<T>> a(static_cast<size_t>(state.range(0)), MatchVal);
36+
vector<T, not_highly_aligned_allocator<T>> b(static_cast<size_t>(state.range(0)), MatchVal);
2637

2738
if (state.range(1) != no_pos) {
28-
b.at(static_cast<size_t>(state.range(1))) = 'x';
39+
b.at(static_cast<size_t>(state.range(1))) = MismatchVal;
2940
}
3041

3142
for (auto _ : state) {
@@ -45,6 +56,7 @@ BENCHMARK(bm<uint8_t, op::mismatch>)->Apply(common_args);
4556
BENCHMARK(bm<uint16_t, op::mismatch>)->Apply(common_args);
4657
BENCHMARK(bm<uint32_t, op::mismatch>)->Apply(common_args);
4758
BENCHMARK(bm<uint64_t, op::mismatch>)->Apply(common_args);
59+
BENCHMARK(bm<color, op::mismatch, c1, c2>)->Apply(common_args);
4860

4961
BENCHMARK(bm<uint8_t, op::lexi>)->Apply(common_args); // still optimized without vector algorithms using memcmp
5062
BENCHMARK(bm<int8_t, op::lexi>)->Apply(common_args); // optimized with vector algorithms only

stl/inc/algorithm

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -885,7 +885,7 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(_InIt1 _First1, const _InI
885885
const auto _ULast1 = _STD _Get_unwrapped(_Last1);
886886
auto _UFirst2 = _STD _Get_unwrapped_n(_First2, _STD _Idl_distance<_InIt1>(_UFirst1, _ULast1));
887887
#if _USE_STD_VECTOR_ALGORITHMS
888-
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
888+
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
889889
if (!_STD _Is_constant_evaluated()) {
890890
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);
891891

@@ -949,7 +949,7 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(
949949
const auto _Count = static_cast<_Iter_diff_t<_InIt1>>((_STD min) (_Count1, _Count2));
950950
_ULast1 = _UFirst1 + _Count;
951951
#if _USE_STD_VECTOR_ALGORITHMS
952-
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
952+
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
953953
if (!_STD _Is_constant_evaluated()) {
954954
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);
955955

stl/inc/xutility

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -447,16 +447,14 @@ auto _Max_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
447447

448448
template <size_t _Element_size>
449449
size_t _Mismatch_vectorized(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
450-
if constexpr (_Element_size == 1) {
451-
return __std_mismatch_1(_First1, _First2, _Count);
452-
} else if constexpr (_Element_size == 2) {
453-
return __std_mismatch_2(_First1, _First2, _Count);
454-
} else if constexpr (_Element_size == 4) {
455-
return __std_mismatch_4(_First1, _First2, _Count);
456-
} else if constexpr (_Element_size == 8) {
457-
return __std_mismatch_8(_First1, _First2, _Count);
450+
if constexpr (_Element_size % 8 == 0) {
451+
return __std_mismatch_8(_First1, _First2, _Count * (_Element_size / 8)) / (_Element_size / 8);
452+
} else if constexpr (_Element_size % 4 == 0) {
453+
return __std_mismatch_4(_First1, _First2, _Count * (_Element_size / 4)) / (_Element_size / 4);
454+
} else if constexpr (_Element_size % 2 == 0) {
455+
return __std_mismatch_2(_First1, _First2, _Count * (_Element_size / 2)) / (_Element_size / 2);
458456
} else {
459-
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
457+
return __std_mismatch_1(_First1, _First2, _Count * _Element_size) / _Element_size;
460458
}
461459
}
462460
_STD_END
@@ -5772,7 +5770,7 @@ namespace ranges {
57725770
_It1 _First1, _It2 _First2, iter_difference_t<_It1> _Count, _Pr _Pred, _Pj1 _Proj1, _Pj2 _Proj2) {
57735771
_STL_INTERNAL_CHECK(_Count >= 0);
57745772
#if _USE_STD_VECTOR_ALGORITHMS
5775-
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
5773+
if constexpr (_Equal_memcmp_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
57765774
&& is_same_v<_Pj2, identity>) {
57775775
if (!_STD is_constant_evaluated()) {
57785776
constexpr size_t _Elem_size = sizeof(iter_value_t<_It1>);

tests/std/tests/VSO_0000000_vector_algorithms_mismatch_and_lex_compare/test.cpp

Lines changed: 84 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -64,20 +64,28 @@ auto last_known_good_lex_compare_3way(pair<FwdIt, FwdIt> expected_mismatch, FwdI
6464
#endif // _HAS_CXX20
6565

6666
template <class T>
67-
void test_case_mismatch_and_lex_compare_family(const vector<T>& a, const vector<T>& b) {
67+
auto test_case_mismatch_only(const vector<T>& a, const vector<T>& b) {
6868
auto expected_mismatch = last_known_good_mismatch(a.begin(), a.end(), b.begin(), b.end());
6969
auto actual_mismatch = mismatch(a.begin(), a.end(), b.begin(), b.end());
7070
assert(expected_mismatch == actual_mismatch);
7171

72-
auto expected_lex = last_known_good_lex_compare(expected_mismatch, a.end(), b.end());
73-
auto actual_lex = lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
74-
assert(expected_lex == actual_lex);
75-
7672
#if _HAS_CXX20
7773
auto ranges_actual_mismatch = ranges::mismatch(a, b);
7874
assert(get<0>(expected_mismatch) == ranges_actual_mismatch.in1);
7975
assert(get<1>(expected_mismatch) == ranges_actual_mismatch.in2);
76+
#endif // _HAS_CXX20
77+
return expected_mismatch;
78+
}
79+
80+
template <class T>
81+
void test_case_mismatch_and_lex_compare_family(const vector<T>& a, const vector<T>& b) {
82+
auto expected_mismatch = test_case_mismatch_only(a, b);
83+
84+
auto expected_lex = last_known_good_lex_compare(expected_mismatch, a.end(), b.end());
85+
auto actual_lex = lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
86+
assert(expected_lex == actual_lex);
8087

88+
#if _HAS_CXX20
8189
auto ranges_actual_lex = ranges::lexicographical_compare(a, b);
8290
assert(expected_lex == ranges_actual_lex);
8391

@@ -130,6 +138,65 @@ void test_mismatch_and_lex_compare_family(mt19937_64& gen) {
130138
}
131139
}
132140

141+
#if _HAS_CXX20
142+
template <class T>
143+
struct triplet {
144+
T x;
145+
T y;
146+
T z;
147+
148+
bool operator==(const triplet&) const = default;
149+
};
150+
151+
template <class T>
152+
void test_mismatch_only_triplets(mt19937_64& gen) {
153+
constexpr size_t shrinkCount = 4;
154+
constexpr size_t mismatchCount = 10;
155+
using TD = conditional_t<sizeof(T) == 1, int, T>;
156+
uniform_int_distribution<TD> dis('a', 'z');
157+
vector<triplet<T>> input_a;
158+
vector<triplet<T>> input_b;
159+
input_a.reserve(dataCount);
160+
input_b.reserve(dataCount);
161+
162+
for (;;) {
163+
// equal
164+
test_case_mismatch_only(input_a, input_b);
165+
166+
// different sizes
167+
for (size_t i = 0; i != shrinkCount && !input_b.empty(); ++i) {
168+
input_b.pop_back();
169+
test_case_mismatch_only(input_a, input_b);
170+
test_case_mismatch_only(input_b, input_a);
171+
}
172+
173+
// actual mismatch (or maybe not, depending on random)
174+
if (!input_b.empty()) {
175+
uniform_int_distribution<size_t> mismatch_dis(0, input_a.size() - 1);
176+
177+
for (size_t attempts = 0; attempts < mismatchCount; ++attempts) {
178+
const size_t possible_mismatch_pos = mismatch_dis(gen);
179+
input_a[possible_mismatch_pos].x = static_cast<T>(dis(gen));
180+
input_a[possible_mismatch_pos].y = static_cast<T>(dis(gen));
181+
input_a[possible_mismatch_pos].z = static_cast<T>(dis(gen));
182+
test_case_mismatch_only(input_a, input_b);
183+
test_case_mismatch_only(input_b, input_a);
184+
}
185+
}
186+
187+
if (input_a.size() == dataCount) {
188+
break;
189+
}
190+
191+
input_a.emplace_back();
192+
input_a.back().x = static_cast<T>(dis(gen));
193+
input_a.back().y = static_cast<T>(dis(gen));
194+
input_a.back().z = static_cast<T>(dis(gen));
195+
input_b = input_a;
196+
}
197+
}
198+
#endif // _HAS_CXX20
199+
133200
template <class C1, class C2>
134201
void test_mismatch_and_lex_compare_family_containers() {
135202
C1 a{'m', 'e', 'o', 'w', ' ', 'C', 'A', 'T', 'S'};
@@ -245,6 +312,18 @@ void test_vector_algorithms(mt19937_64& gen) {
245312
test_mismatch_and_lex_compare_family<long long>(gen);
246313
test_mismatch_and_lex_compare_family<unsigned long long>(gen);
247314

315+
#if _HAS_CXX20
316+
test_mismatch_only_triplets<char>(gen);
317+
test_mismatch_only_triplets<signed char>(gen);
318+
test_mismatch_only_triplets<unsigned char>(gen);
319+
test_mismatch_only_triplets<short>(gen);
320+
test_mismatch_only_triplets<unsigned short>(gen);
321+
test_mismatch_only_triplets<int>(gen);
322+
test_mismatch_only_triplets<unsigned int>(gen);
323+
test_mismatch_only_triplets<long long>(gen);
324+
test_mismatch_only_triplets<unsigned long long>(gen);
325+
#endif // _HAS_CXX20
326+
248327
test_mismatch_and_lex_compare_family_containers<vector<char>, vector<signed char>>();
249328
test_mismatch_and_lex_compare_family_containers<vector<char>, vector<unsigned char>>();
250329
test_mismatch_and_lex_compare_family_containers<vector<wchar_t>, vector<char>>();

0 commit comments

Comments (0)