@@ -85,6 +85,35 @@ bool test_arg_min_non_default_stream() {
85
85
return out.key == 5 && out.value == 0 ;
86
86
}
87
87
88
+ __global__ void test_arg_min_max_op_in_device (int *Res) {
89
+ cub::KeyValuePair<int , int > LHS{1 , 3 }, RHS{2 , 4 };
90
+ cub::ArgMin MinOp;
91
+ cub::ArgMax MaxOp;
92
+ auto Min = MinOp (LHS, RHS);
93
+ auto Max = MaxOp (LHS, RHS);
94
+ *Res = Min.key == 1 && Min.value == 3 && Max.key == 2 && Max.value == 4 ;
95
+ }
96
+
97
+ bool test_arg_min_max_op_in_host () {
98
+ cub::KeyValuePair<int , int > LHS{1 , 3 }, RHS{2 , 4 };
99
+ cub::ArgMin MinOp;
100
+ cub::ArgMax MaxOp;
101
+ auto Min = MinOp (LHS, RHS);
102
+ auto Max = MaxOp (LHS, RHS);
103
+ return Min.key == 1 && Min.value == 3 && Max.key == 2 && Max.value == 4 ;
104
+ }
105
+
106
+ bool test_arg_min_max_op () {
107
+ int *Res;
108
+ cudaMallocManaged (&Res, sizeof (int ));
109
+ *Res = 0 ;
110
+ test_arg_min_max_op_in_device<<<1 , 1 >>> (Res);
111
+ cudaDeviceSynchronize ();
112
+ bool Val = *Res;
113
+ cudaFree (Res);
114
+ return Val && test_arg_min_max_op_in_host ();
115
+ }
116
+
88
117
int main () {
89
118
int res = 0 ;
90
119
if (!test_arg_max ()) {
@@ -107,5 +136,10 @@ int main() {
107
136
std::cout << " cub::DeviceReduce::ArgMin(Non default stream) test failed\n " ;
108
137
}
109
138
139
+ if (!test_arg_min_max_op ()) {
140
+ res = 1 ;
141
+ std::cout << " cub::{ArgMin, ArgMax} binary operator test failed\n " ;
142
+ }
143
+
110
144
return res;
111
145
}
0 commit comments