Skip to content

Commit b0b1037

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #92 from r-devulap/NAN
Add hasnan = false to all the sort methods
2 parents 64908e7 + 05400e0 commit b0b1037

14 files changed

+106
-148
lines changed

Makefile

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,10 @@
1-
# When unset, discover g++. Prioritise the latest version on the path.
2-
ifeq (, $(and $(strip $(CXX)), $(filter-out default undefined, $(origin CXX))))
3-
override CXX := $(shell which g++-13 g++-12 g++-11 g++-10 g++-9 g++-8 g++ 2>/dev/null | head -n 1)
4-
ifeq (, $(strip $(CXX)))
5-
$(error Could not locate the g++ compiler. Please manually specify its path using the CXX variable)
6-
endif
7-
endif
8-
9-
export CXX
10-
CXXFLAGS += $(OPTIMFLAG) $(MARCHFLAG)
11-
override CXXFLAGS += -I$(SRCDIR) -I$(UTILSDIR)
12-
GTESTCFLAGS := `pkg-config --cflags gtest_main`
13-
GTESTLDFLAGS := `pkg-config --static --libs gtest_main`
14-
GBENCHCFLAGS := `pkg-config --cflags benchmark`
15-
GBENCHLDFLAGS := `pkg-config --static --libs benchmark`
16-
OPTIMFLAG := -O3
17-
MARCHFLAG := -march=sapphirerapids
18-
19-
SRCDIR := ./src
20-
TESTDIR := ./tests
21-
BENCHDIR := ./benchmarks
22-
UTILSDIR := ./utils
23-
24-
SRCS := $(wildcard $(addprefix $(SRCDIR)/, *.hpp *.h))
25-
UTILSRCS := $(wildcard $(addprefix $(UTILSDIR)/, *.hpp *.h))
26-
TESTSRCS := $(wildcard $(addprefix $(TESTDIR)/, *.hpp *.h))
27-
BENCHSRCS := $(wildcard $(addprefix $(BENCHDIR)/, *.hpp *.h))
28-
UTILS := $(wildcard $(UTILSDIR)/*.cpp)
29-
TESTS := $(wildcard $(TESTDIR)/*.cpp)
30-
BENCHS := $(wildcard $(BENCHDIR)/*.cpp)
31-
32-
test_cxx_flag = $(shell 2>/dev/null $(CXX) -o /dev/null $(1) -c -x c++ /dev/null; echo $$?)
33-
34-
# Compiling AVX512-FP16 instructions wasn't possible until GCC 12
35-
ifeq ($(call test_cxx_flag,-mavx512fp16), 1)
36-
BENCHS_SKIP += bench-qsortfp16.cpp
37-
TESTS_SKIP += test-qsortfp16.cpp
38-
endif
39-
40-
# Sapphire Rapids was otherwise supported from GCC 11. Downgrade if required.
41-
ifeq ($(call test_cxx_flag,$(MARCHFLAG)), 1)
42-
MARCHFLAG := -march=icelake-client
43-
endif
44-
45-
BENCHOBJS := $(patsubst %.cpp, %.o, $(filter-out $(addprefix $(BENCHDIR)/, $(BENCHS_SKIP)), $(BENCHS)))
46-
TESTOBJS := $(patsubst %.cpp, %.o, $(filter-out $(addprefix $(TESTDIR)/, $(TESTS_SKIP)), $(TESTS)))
47-
UTILOBJS := $(UTILS:.cpp=.o)
48-
49-
# Stops make from wondering if it needs to generate the .hpp files (.cpp and .h have equivalent rules by default)
50-
%.hpp:
51-
52-
.PHONY: all
53-
.DEFAULT_GOAL := all
54-
all: test bench
55-
56-
.PHONY: test
57-
test: testexe
58-
59-
.PHONY: bench
60-
bench: benchexe
61-
62-
$(UTILOBJS): $(UTILSRCS)
63-
64-
$(TESTOBJS): $(TESTSRCS) $(UTILSRCS) $(SRCS)
65-
$(TESTDIR)/%.o: override CXXFLAGS += $(GTESTCFLAGS)
66-
67-
testexe: $(TESTOBJS) $(UTILOBJS)
68-
$(CXX) $(CXXFLAGS) $^ $(LDLIBS) $(LDFLAGS) -lgtest_main $(GTESTLDFLAGS) -o $@
69-
70-
$(BENCHOBJS): $(BENCHSRCS) $(UTILSRCS) $(SRCS)
71-
$(BENCHDIR)/%.o: override CXXFLAGS += $(GBENCHCFLAGS)
72-
73-
benchexe: $(BENCHOBJS) $(UTILOBJS)
74-
$(CXX) $(CXXFLAGS) $^ $(LDLIBS) $(LDFLAGS) -lbenchmark_main $(GBENCHLDFLAGS) -o $@
75-
76-
.PHONY: meson
771
meson:
782
meson setup --warnlevel 2 --werror --buildtype release builddir
793
cd builddir && ninja
804

81-
.PHONY: mesondebug
825
mesondebug:
836
meson setup --warnlevel 2 --werror --buildtype debug debug
847
cd debug && ninja
858

86-
.PHONY: clean
879
clean:
8810
$(RM) -rf $(TESTOBJS) $(BENCHOBJS) $(UTILOBJS) testexe benchexe builddir

lib/x86simdsort-avx2.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
#define DEFINE_ALL_METHODS(type) \
66
template <> \
7-
void qsort(type *arr, size_t arrsize) \
7+
void qsort(type *arr, size_t arrsize, bool hasnan) \
88
{ \
9-
avx2_qsort(arr, arrsize); \
9+
avx2_qsort(arr, arrsize, hasnan); \
1010
} \
1111
template <> \
1212
void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \
@@ -24,5 +24,5 @@ namespace avx2 {
2424
DEFINE_ALL_METHODS(uint32_t)
2525
DEFINE_ALL_METHODS(int32_t)
2626
DEFINE_ALL_METHODS(float)
27-
} // namespace avx512
27+
} // namespace avx2
2828
} // namespace xss

lib/x86simdsort-icl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
namespace xss {
66
namespace avx512 {
77
template <>
8-
void qsort(uint16_t *arr, size_t size)
8+
void qsort(uint16_t *arr, size_t size, bool hasnan)
99
{
10-
avx512_qsort(arr, size);
10+
avx512_qsort(arr, size, hasnan);
1111
}
1212
template <>
1313
void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan)
@@ -20,9 +20,9 @@ namespace avx512 {
2020
avx512_partial_qsort(arr, k, arrsize, hasnan);
2121
}
2222
template <>
23-
void qsort(int16_t *arr, size_t size)
23+
void qsort(int16_t *arr, size_t size, bool hasnan)
2424
{
25-
avx512_qsort(arr, size);
25+
avx512_qsort(arr, size, hasnan);
2626
}
2727
template <>
2828
void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan)

lib/x86simdsort-internal.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace xss {
88
namespace avx512 {
99
// quicksort
1010
template <typename T>
11-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize);
11+
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
1212
// quickselect
1313
template <typename T>
1414
XSS_HIDE_SYMBOL void
@@ -19,16 +19,17 @@ namespace avx512 {
1919
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
2020
// argsort
2121
template <typename T>
22-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
22+
XSS_HIDE_SYMBOL std::vector<size_t>
23+
argsort(T *arr, size_t arrsize, bool hasnan = false);
2324
// argselect
2425
template <typename T>
2526
XSS_HIDE_SYMBOL std::vector<size_t>
26-
argselect(T *arr, size_t k, size_t arrsize);
27+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
2728
} // namespace avx512
2829
namespace avx2 {
2930
// quicksort
3031
template <typename T>
31-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize);
32+
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
3233
// quickselect
3334
template <typename T>
3435
XSS_HIDE_SYMBOL void
@@ -39,16 +40,17 @@ namespace avx2 {
3940
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
4041
// argsort
4142
template <typename T>
42-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
43+
XSS_HIDE_SYMBOL std::vector<size_t>
44+
argsort(T *arr, size_t arrsize, bool hasnan = false);
4345
// argselect
4446
template <typename T>
4547
XSS_HIDE_SYMBOL std::vector<size_t>
46-
argselect(T *arr, size_t k, size_t arrsize);
48+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
4749
} // namespace avx2
4850
namespace scalar {
4951
// quicksort
5052
template <typename T>
51-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize);
53+
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
5254
// quickselect
5355
template <typename T>
5456
XSS_HIDE_SYMBOL void
@@ -59,11 +61,12 @@ namespace scalar {
5961
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
6062
// argsort
6163
template <typename T>
62-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
64+
XSS_HIDE_SYMBOL std::vector<size_t>
65+
argsort(T *arr, size_t arrsize, bool hasnan = false);
6366
// argselect
6467
template <typename T>
6568
XSS_HIDE_SYMBOL std::vector<size_t>
66-
argselect(T *arr, size_t k, size_t arrsize);
69+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
6770
} // namespace scalar
6871
} // namespace xss
6972
#endif

lib/x86simdsort-scalar.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
namespace xss {
66
namespace scalar {
77
template <typename T>
8-
void qsort(T *arr, size_t arrsize)
8+
void qsort(T *arr, size_t arrsize, bool hasnan)
99
{
10-
std::sort(arr, arr + arrsize, compare<T, std::less<T>>());
10+
if (hasnan) {
11+
std::sort(arr, arr + arrsize, compare<T, std::less<T>>());
12+
}
13+
else {
14+
std::sort(arr, arr + arrsize);
15+
}
1116
}
1217
template <typename T>
1318
void qselect(T *arr, size_t k, size_t arrsize, bool hasnan)
@@ -32,16 +37,18 @@ namespace scalar {
3237
}
3338
}
3439
template <typename T>
35-
std::vector<size_t> argsort(T *arr, size_t arrsize)
40+
std::vector<size_t> argsort(T *arr, size_t arrsize, bool hasnan)
3641
{
42+
UNUSED(hasnan);
3743
std::vector<size_t> arg(arrsize);
3844
std::iota(arg.begin(), arg.end(), 0);
3945
std::sort(arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
4046
return arg;
4147
}
4248
template <typename T>
43-
std::vector<size_t> argselect(T *arr, size_t k, size_t arrsize)
49+
std::vector<size_t> argselect(T *arr, size_t k, size_t arrsize, bool hasnan)
4450
{
51+
UNUSED(hasnan);
4552
std::vector<size_t> arg(arrsize);
4653
std::iota(arg.begin(), arg.end(), 0);
4754
std::nth_element(arg.begin(),

lib/x86simdsort-skx.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
#define DEFINE_ALL_METHODS(type) \
88
template <> \
9-
void qsort(type *arr, size_t arrsize) \
9+
void qsort(type *arr, size_t arrsize, bool hasnan) \
1010
{ \
11-
avx512_qsort(arr, arrsize); \
11+
avx512_qsort(arr, arrsize, hasnan); \
1212
} \
1313
template <> \
1414
void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \
@@ -21,14 +21,15 @@
2121
avx512_partial_qsort(arr, k, arrsize, hasnan); \
2222
} \
2323
template <> \
24-
std::vector<size_t> argsort(type *arr, size_t arrsize) \
24+
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
2525
{ \
26-
return avx512_argsort(arr, arrsize); \
26+
return avx512_argsort(arr, arrsize, hasnan); \
2727
} \
2828
template <> \
29-
std::vector<size_t> argselect(type *arr, size_t k, size_t arrsize) \
29+
std::vector<size_t> argselect( \
30+
type *arr, size_t k, size_t arrsize, bool hasnan) \
3031
{ \
31-
return avx512_argselect(arr, k, arrsize); \
32+
return avx512_argselect(arr, k, arrsize, hasnan); \
3233
}
3334

3435
namespace xss {

lib/x86simdsort-spr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
namespace xss {
66
namespace avx512 {
77
template <>
8-
void qsort(_Float16 *arr, size_t size)
8+
void qsort(_Float16 *arr, size_t size, bool hasnan)
99
{
10-
avx512_qsort(arr, size);
10+
avx512_qsort(arr, size, hasnan);
1111
}
1212
template <>
1313
void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan)

lib/x86simdsort.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ dispatch_requested(std::string_view cpurequested,
5555
#define CAT(a, b) CAT_(a, b)
5656

5757
#define DECLARE_INTERNAL_qsort(TYPE) \
58-
static void (*internal_qsort##TYPE)(TYPE *, size_t) = NULL; \
58+
static void (*internal_qsort##TYPE)(TYPE *, size_t, bool) = NULL; \
5959
template <> \
60-
void qsort(TYPE *arr, size_t arrsize) \
60+
void qsort(TYPE *arr, size_t arrsize, bool hasnan) \
6161
{ \
62-
(*internal_qsort##TYPE)(arr, arrsize); \
62+
(*internal_qsort##TYPE)(arr, arrsize, hasnan); \
6363
}
6464

6565
#define DECLARE_INTERNAL_qselect(TYPE) \
@@ -81,22 +81,23 @@ dispatch_requested(std::string_view cpurequested,
8181
}
8282

8383
#define DECLARE_INTERNAL_argsort(TYPE) \
84-
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t) \
84+
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t, bool) \
8585
= NULL; \
8686
template <> \
87-
std::vector<size_t> argsort(TYPE *arr, size_t arrsize) \
87+
std::vector<size_t> argsort(TYPE *arr, size_t arrsize, bool hasnan) \
8888
{ \
89-
return (*internal_argsort##TYPE)(arr, arrsize); \
89+
return (*internal_argsort##TYPE)(arr, arrsize, hasnan); \
9090
}
9191

9292
#define DECLARE_INTERNAL_argselect(TYPE) \
9393
static std::vector<size_t> (*internal_argselect##TYPE)( \
94-
TYPE *, size_t, size_t) \
94+
TYPE *, size_t, size_t, bool) \
9595
= NULL; \
9696
template <> \
97-
std::vector<size_t> argselect(TYPE *arr, size_t k, size_t arrsize) \
97+
std::vector<size_t> argselect( \
98+
TYPE *arr, size_t k, size_t arrsize, bool hasnan) \
9899
{ \
99-
return (*internal_argselect##TYPE)(arr, k, arrsize); \
100+
return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \
100101
}
101102

102103
/* runtime dispatch mechanism */

lib/x86simdsort.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,33 @@
66

77
#define XSS_EXPORT_SYMBOL __attribute__((visibility("default")))
88
#define XSS_HIDE_SYMBOL __attribute__((visibility("hidden")))
9+
#define UNUSED(x) (void)(x)
910

1011
namespace x86simdsort {
12+
1113
// quicksort
1214
template <typename T>
13-
XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize);
15+
XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
16+
1417
// quickselect
1518
template <typename T>
1619
XSS_EXPORT_SYMBOL void
1720
qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
21+
1822
// partial sort
1923
template <typename T>
2024
XSS_EXPORT_SYMBOL void
2125
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
26+
2227
// argsort
2328
template <typename T>
24-
XSS_EXPORT_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
29+
XSS_EXPORT_SYMBOL std::vector<size_t>
30+
argsort(T *arr, size_t arrsize, bool hasnan = false);
31+
2532
// argselect
2633
template <typename T>
2734
XSS_EXPORT_SYMBOL std::vector<size_t>
28-
argselect(T *arr, size_t k, size_t arrsize);
35+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
36+
2937
} // namespace x86simdsort
3038
#endif

src/avx512-16bit-qsort.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -519,12 +519,14 @@ bool is_a_nan<uint16_t>(uint16_t elem)
519519
}
520520

521521
X86_SIMD_SORT_INLINE
522-
void avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize)
522+
void avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false)
523523
{
524524
if (arrsize > 1) {
525-
arrsize_t nan_count
526-
= replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr,
527-
arrsize);
525+
arrsize_t nan_count = 0;
526+
if (UNLIKELY(hasnan)) {
527+
nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(
528+
arr, arrsize);
529+
}
528530
qsort_<zmm_vector<float16>, uint16_t>(
529531
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
530532
replace_inf_with_nan(arr, arrsize, nan_count);
@@ -535,7 +537,7 @@ X86_SIMD_SORT_INLINE
535537
void avx512_qselect_fp16(uint16_t *arr,
536538
arrsize_t k,
537539
arrsize_t arrsize,
538-
bool hasnan = true)
540+
bool hasnan = false)
539541
{
540542
arrsize_t indx_last_elem = arrsize - 1;
541543
if (UNLIKELY(hasnan)) {

0 commit comments

Comments
 (0)