Skip to content

Commit 0202a11

Browse files
committed
feat: add quick select algorithm using median of medians
1 parent e5dad3f commit 0202a11

File tree

1 file changed

+285
-0
lines changed

1 file changed

+285
-0
lines changed

misc/quick_select.c

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
/**
2+
* @file
3+
* @brief Kth largest element in linear time, see [Wikipedia: Selection Algorithm](https://en.wikipedia.org/wiki/Selection_algorithm) and [Median of Medians](https://en.wikipedia.org/wiki/Median_of_medians)
4+
* @details
5+
* Quick Select is a linear-time algorithm for finding the kth largest element in an unsorted array.
6+
* It uses the median-of-medians algorithm to guarantee a good pivot selection, achieving O(n)
7+
* average time complexity and avoiding the O(n^2) worst-case of naive quickselect.
8+
* @author [Nothinormuch](https://github.com/Nothinormuch/)
9+
*/
10+
11+
#include<stdio.h>
12+
#include<stdlib.h>
13+
#include<time.h>
14+
#include<assert.h>
15+
16+
/**
 * @brief Prints the inclusive slice arr[start..stop] as "[a,b,c]"
 * @details
 * No trailing newline is written. The element at index start is always
 * printed, even when stop is smaller than start.
 * @param arr Pointer to the array
 * @param start Index of the first element to print
 * @param stop Index of the last element to print
 */
void print_arr(int *arr, int start, int stop) {
    int idx = start;

    // The opening bracket and the first element share one call
    printf("[%d", arr[idx]);

    // Every following element is prefixed by a comma
    while (++idx <= stop) {
        printf(",%d", arr[idx]);
    }

    printf("]");
}
26+
27+
/**
 * @brief Exchanges the values stored at two positions of an array
 * @param arr array pointer
 * @param i Index of first element
 * @param j Index of second element
 */
void swap(int *arr, int i, int j) {
    int *first = &arr[i];
    int *second = &arr[j];

    // A temporary keeps the exchange correct even when i == j
    const int held = *first;
    *first = *second;
    *second = held;
}
38+
39+
int partition(int * arr, int start, int stop, int pivot_value);
40+
int median_of_medians(int * arr, int start, int stop);
41+
int median_of_medians_helper(int * arr, int start, int stop);
42+
43+
/**
 * @brief Descending partition of arr[start..stop] around pivot_value
 * @details
 * First pass: every element strictly greater than pivot_value is swapped to
 * the front of the range, in scan order. Second pass: the first remaining
 * element equal to pivot_value (if there is one) is swapped onto the boundary
 * slot that follows the "greater" group. Any further copies of pivot_value
 * stay to the right of the boundary together with the smaller elements.
 * @param arr Pointer to the array
 * @param start Starting index of the partition range
 * @param stop Ending index of the partition range (inclusive)
 * @param pivot_value The value to partition around
 * @returns The boundary index, i.e. start plus the number of elements
 * strictly greater than pivot_value; the pivot sits there when it was found
 */
int partition(int *arr, int start, int stop, int pivot_value) {
    int boundary = start;  // first slot not yet holding a "greater" element

    // Pass 1: gather everything greater than the pivot at the front
    for (int scan = start; scan <= stop; ++scan) {
        if (arr[scan] <= pivot_value) {
            continue;
        }
        swap(arr, boundary, scan);
        ++boundary;
    }

    // Pass 2: locate the first copy of the pivot at or after the boundary
    int probe = boundary;
    while (probe <= stop && arr[probe] != pivot_value) {
        ++probe;
    }
    if (probe <= stop) {
        swap(arr, boundary, probe);
    }

    return boundary;
}
75+
76+
/**
 * @brief Sorts the short inclusive range arr[lo..hi] in ascending order
 * (insertion sort; median_of_medians only passes ranges of at most five
 * elements).
 */
static void mom_sort_group(int *arr, int lo, int hi) {
    for (int i = lo + 1; i <= hi; i++) {
        const int key = arr[i];
        int j = i - 1;
        while (j >= lo && arr[j] > key) {
            arr[j + 1] = arr[j];
            j--;
        }
        arr[j + 1] = key;
    }
}

/**
 * @brief Picks a pivot value with the group-of-five median-of-medians scheme
 * @details
 * While the range holds more than five elements it is cut into groups of five
 * (the last group may be shorter), each group is sorted, and the group's
 * middle element (index len / 2 within the group, i.e. the upper middle for
 * even sizes) is swapped to the front of the range. The range is then shrunk
 * to just those medians and the process repeats. Once at most five elements
 * remain they are sorted and the middle one is returned.
 *
 * The work is done in place: no memory is allocated, so the routine cannot
 * fail on allocation. The elements of arr[start..stop] are permuted but never
 * changed in value, and nothing outside the range is touched.
 *
 * The result is always an element of the range, but because each round takes
 * the median of group medians (not an exact selection) it is a pivot
 * heuristic and need not be the true median of the range.
 * @param arr Pointer to the array
 * @param start Starting index of the range
 * @param stop Ending index of the range (inclusive)
 * @returns The chosen pivot value, or -1 when arr is NULL or the range is
 * empty (start > stop)
 */
int median_of_medians(int *arr, int start, int stop) {
    if (arr == NULL || start > stop) {
        return -1;
    }

    int len = stop - start + 1;

    while (len > 5) {
        const int num_groups = (len + 4) / 5;  // ceiling division

        for (int g = 0; g < num_groups; g++) {
            const int sub_start = start + g * 5;
            // Last group may have fewer than 5 elements
            const int sub_stop = (sub_start + 4 > stop) ? stop : sub_start + 4;
            const int sub_len = sub_stop - sub_start + 1;

            mom_sort_group(arr, sub_start, sub_stop);

            // Park this group's median at slot start + g. That slot lies in
            // a group that has already been processed, so groups still
            // waiting to be sorted are left undisturbed.
            const int median_idx = sub_start + sub_len / 2;
            const int dest = start + g;
            const int tmp = arr[dest];
            arr[dest] = arr[median_idx];
            arr[median_idx] = tmp;
        }

        // Continue on the block of collected medians only
        stop = start + num_groups - 1;
        len = num_groups;
    }

    // Base case: small range gets sorted and yields its middle element
    mom_sort_group(arr, start, stop);
    return arr[start + len / 2];
}
128+
129+
/**
 * @brief Compatibility wrapper around median_of_medians()
 * @details
 * This helper used to carry a line-for-line copy of the group-of-five
 * median-of-medians logic found in median_of_medians(), including its own
 * scratch allocation. Both routines compute the same value and both reorder
 * the range they are given, so the helper simply forwards the call to keep a
 * single implementation.
 * @param arr Pointer to the array (elements in the range may be reordered)
 * @param start Starting index of the range
 * @param stop Ending index of the range (inclusive)
 * @returns Whatever median_of_medians() returns for the same arguments
 */
int median_of_medians_helper(int *arr, int start, int stop) {
    return median_of_medians(arr, start, stop);
}
174+
175+
/**
 * @brief Finds the kth largest element in an array
 * @details
 * Repeatedly picks a pivot with median_of_medians(), partitions the current
 * range in descending order with partition(), and narrows the range to the
 * side that must contain the answer. The narrowing is done in a loop rather
 * than by recursion, so stack usage stays constant however many rounds are
 * needed.
 *
 * k is 1-based: k=1 returns the largest, k=2 returns the 2nd largest, etc.
 * The elements of arr[start..stop] are reordered by the search.
 *
 * Note on cost: partition() positions only a single copy of the pivot, so on
 * inputs with many equal values a round may shrink the range by just one
 * element; such inputs take more rounds than inputs with distinct values.
 * @param arr Pointer to the array
 * @param k The rank to find (1 = largest, 2 = 2nd largest, ..., n = smallest)
 * @param start Starting index of the search range
 * @param stop Ending index of the search range (inclusive)
 * @returns The kth largest element, or -1 if arr is NULL, the range is empty
 * (start > stop), or k lies outside 1..(stop - start + 1). Since -1 can also
 * be a legitimate element, callers storing negative values should validate
 * their arguments themselves.
 */
int kth_largest(int *arr, int k, int start, int stop) {
    if (arr == NULL || start > stop) {
        return -1;
    }

    // Reject impossible ranks up front instead of partitioning the array
    // over and over only to run out of elements and return -1 anyway.
    const int len = stop - start + 1;
    if (k < 1 || k > len) {
        return -1;
    }

    while (start <= stop) {
        // Use median-of-medians to pick a pivot present in the range
        const int pivot_value = median_of_medians(arr, start, stop);

        // Partition: larger elements go left of the pivot, the rest go right
        const int pivot_index = partition(arr, start, stop, pivot_value);

        // 1-based position of the pivot in the range: the number of elements
        // strictly greater than the pivot, plus one for the pivot itself
        const int rank = pivot_index - start + 1;

        if (rank == k) {
            return pivot_value;
        }

        if (rank > k) {
            // Answer is among the larger elements on the left
            stop = pivot_index - 1;
        } else {
            // Answer is on the right; discount the rank elements skipped
            k -= rank;
            start = pivot_index + 1;
        }
    }

    return -1;  // not reached for validated arguments; kept as a safe default
}
212+
213+
214+
/**
 * @brief Self-test: runs kth_largest() on ten fixed arrays
 * @details
 * Each case lives in its own scope, derives the last valid index from the
 * array size, asserts the expected answer and prints a confirmation line.
 */
static void test(void) {
    {
        // Case 1: unsorted permutation of 1..19, 3rd largest is 17
        int data[] = {7, 1, 15, 3, 19, 11, 5, 18, 2, 14, 9, 4, 16, 8, 12, 6, 17, 10, 13};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 3, 0, last);
        assert(got == 17);
        printf("Test 1 passed: 3rd largest in unsorted array is 17\n");
    }

    {
        // Case 2: k = 1 yields the maximum
        int data[] = {5, 2, 8, 1, 9, 3};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 1, 0, last);
        assert(got == 9);
        printf("Test 2 passed: 1st largest (max) is 9\n");
    }

    {
        // Case 3: k = n yields the minimum
        int data[] = {5, 2, 8, 1, 9, 3};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 6, 0, last);
        assert(got == 1);
        printf("Test 3 passed: 6th largest (min) in 6-element array is 1\n");
    }

    {
        // Case 4: a single element is its own answer
        int data[] = {42};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 1, 0, last);
        assert(got == 42);
        printf("Test 4 passed: 1st largest in single-element array is 42\n");
    }

    {
        // Case 5: two elements, largest
        int data[] = {10, 20};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 1, 0, last);
        assert(got == 20);
        printf("Test 5 passed: 1st largest in two-element array is 20\n");
    }

    {
        // Case 6: two elements, smallest
        int data[] = {10, 20};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 2, 0, last);
        assert(got == 10);
        printf("Test 6 passed: 2nd largest in two-element array is 10\n");
    }

    {
        // Case 7: duplicates; sorted descending is [5,5,5,5,3,2,1], 4th is 5
        int data[] = {5, 3, 5, 2, 5, 1, 5};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 4, 0, last);
        assert(got == 5);
        printf("Test 7 passed: 4th largest with duplicates is 5\n");
    }

    {
        // Case 8: input already sorted in descending order
        int data[] = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 5, 0, last);
        assert(got == 6);
        printf("Test 8 passed: 5th largest in sorted descending array is 6\n");
    }

    {
        // Case 9: input already sorted in ascending order
        int data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 3, 0, last);
        assert(got == 8);
        printf("Test 9 passed: 3rd largest in sorted ascending array is 8\n");
    }

    {
        // Case 10: larger mixed array; top five are 98, 90, 89, 88, 78
        int data[] = {45, 23, 78, 12, 89, 34, 56, 90, 67, 21, 98, 54, 32, 11, 88, 77, 42};
        const int last = (int)(sizeof data / sizeof data[0]) - 1;
        const int got = kth_largest(data, 5, 0, last);
        assert(got == 78);
        printf("Test 10 passed: 5th largest in random array is 78\n");
    }

    printf("\nAll tests have successfully passed!\n");
}
280+
281+
/**
 * @brief Program entry point: runs the self-tests
 * @returns 0 once test() has returned
 */
int main(void) {
    test();
    return 0;
}

0 commit comments

Comments
 (0)