Skip to content

Commit 61ac1f2

Browse files
author
yimingx
committed
feat: bst km
1 parent f131edd commit 61ac1f2

File tree

3 files changed

+203
-9
lines changed

3 files changed

+203
-9
lines changed

test/evaluation/kmbst/Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1+
obj-m += bstkm.o
12

3+
KDIR ?= /lib/modules/$(shell uname -r)/build
4+
PWD := $(shell pwd)
5+
6+
all:
7+
$(MAKE) -C $(KDIR) M=$(PWD) modules
8+
29
bst.o: bst_bpf.c
310
clang -O2 -I/usr/include/$(shell uname -m)-linux-gnu -target bpf -g -c $< -o bst.o
411

test/evaluation/kmbst/bstkm.c

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#include <linux/module.h>
2+
#include <linux/kernel.h>
3+
#include <linux/init.h>
4+
#include <linux/slab.h>
5+
#include <linux/ktime.h>
6+
7+
MODULE_LICENSE("GPL");
8+
MODULE_DESCRIPTION("BST benchmark in kernel module");
9+
10+
// ---------------- BST Node ----------------
11+
struct bst_node {
12+
int key;
13+
struct bst_node *left, *right;
14+
};
15+
16+
// BST Insert
17+
struct bst_node* bst_insert(struct bst_node* root, int key) {
18+
if (!root) {
19+
struct bst_node* n = kmalloc(sizeof(*n), GFP_KERNEL);
20+
n->key = key;
21+
n->left = n->right = NULL;
22+
return n;
23+
}
24+
if (key < root->key)
25+
root->left = bst_insert(root->left, key);
26+
else if (key > root->key)
27+
root->right = bst_insert(root->right, key);
28+
return root; // ignore duplicates
29+
}
30+
31+
// BST Lookup
32+
struct bst_node* bst_lookup(struct bst_node* root, int key) {
33+
while (root) {
34+
if (key < root->key)
35+
root = root->left;
36+
else if (key > root->key)
37+
root = root->right;
38+
else
39+
return root;
40+
}
41+
return NULL;
42+
}
43+
44+
// BST Delete
45+
struct bst_node* bst_delete(struct bst_node* root, int key) {
46+
if (!root) return NULL;
47+
48+
if (key < root->key)
49+
root->left = bst_delete(root->left, key);
50+
else if (key > root->key)
51+
root->right = bst_delete(root->right, key);
52+
else {
53+
if (!root->left) {
54+
struct bst_node* tmp = root->right;
55+
kfree(root);
56+
return tmp;
57+
} else if (!root->right) {
58+
struct bst_node* tmp = root->left;
59+
kfree(root);
60+
return tmp;
61+
} else {
62+
// find min in right subtree
63+
struct bst_node* tmp = root->right;
64+
while (tmp->left) tmp = tmp->left;
65+
root->key = tmp->key;
66+
root->right = bst_delete(root->right, tmp->key);
67+
}
68+
}
69+
return root;
70+
}
71+
72+
// Free BST
73+
void bst_free(struct bst_node* root) {
74+
if (!root) return;
75+
bst_free(root->left);
76+
bst_free(root->right);
77+
kfree(root);
78+
}
79+
80+
// ---------------- LCG PRNG ----------------
81+
static inline unsigned int lcg_next(unsigned int *state) {
82+
*state = (*state * 1664525 + 1013904223);
83+
return *state;
84+
}
85+
86+
static inline unsigned int rand_range(unsigned int *state, unsigned int max) {
87+
return lcg_next(state) % max;
88+
}
89+
90+
// Fisher-Yates shuffle
91+
void shuffle(int *arr, int n, unsigned int seed) {
92+
unsigned int state = seed;
93+
int i;
94+
for (i = n - 1; i > 0; i--) {
95+
int j = rand_range(&state, i + 1);
96+
int tmp = arr[i];
97+
arr[i] = arr[j];
98+
arr[j] = tmp;
99+
}
100+
}
101+
102+
// ---------------- Benchmark ----------------
103+
#define N (64 * 1024)
104+
105+
static int __init bst_benchmark_init(void)
106+
{
107+
int i;
108+
unsigned long long start, end, total;
109+
struct bst_node* root = NULL;
110+
111+
int *data = kmalloc_array(N, sizeof(int), GFP_KERNEL);
112+
for (i = 0; i < N; i++) data[i] = i + 1;
113+
shuffle(data, N, 114514); // fixed seed
114+
115+
pr_info("bst Benchmark: N=%d\n", N);
116+
117+
// ----- Insert -----
118+
total = 0;
119+
for (i = 0; i < N; i++) {
120+
start = ktime_get_ns();
121+
root = bst_insert(root, data[i]);
122+
end = ktime_get_ns();
123+
total += (end - start);
124+
}
125+
pr_info("bst Insert avg latency: %llu ns\n", total / N);
126+
127+
// ----- Lookup -----
128+
total = 0;
129+
for (i = 0; i < N; i++) {
130+
start = ktime_get_ns();
131+
if (!bst_lookup(root, data[i])) pr_info("Lookup error!\n");
132+
end = ktime_get_ns();
133+
total += (end - start);
134+
}
135+
pr_info("bst Lookup avg latency: %llu ns\n", total / N);
136+
137+
// ----- Delete -----
138+
total = 0;
139+
for (i = 0; i < N; i++) {
140+
start = ktime_get_ns();
141+
root = bst_delete(root, data[i]);
142+
end = ktime_get_ns();
143+
total += (end - start);
144+
}
145+
pr_info("bst Delete avg latency: %llu ns\n", total / N);
146+
147+
bst_free(root);
148+
return 0;
149+
}
150+
151+
static void __exit bst_benchmark_exit(void)
152+
{
153+
pr_info("BST benchmark module exit\n");
154+
}
155+
156+
module_init(bst_benchmark_init);
157+
module_exit(bst_benchmark_exit);

test/evaluation/kmbst/syscall_user.c

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <stdio.h>
2+
#include <stdlib.h>
13
#define _GNU_SOURCE
24
// #include <errno.h>
35
// #include <stdio.h>
@@ -25,18 +27,46 @@ void delete(int key) {
2527
mount(source, target, fstype, generate_flags(2, key), data);
2628
}
2729

30+
static inline unsigned int lcg_next(unsigned int *state) {
31+
*state = (*state * 1664525 + 1013904223);
32+
return *state;
33+
}
34+
35+
static inline unsigned int rand_range(unsigned int *state, unsigned int max) {
36+
return lcg_next(state) % max;
37+
}
38+
39+
void shuffle(int *arr, int n, unsigned int seed) {
40+
unsigned int state = seed;
41+
for (int i = n - 1; i > 0; i--) {
42+
int j = rand_range(&state, i + 1);
43+
int tmp = arr[i];
44+
arr[i] = arr[j];
45+
arr[j] = tmp;
46+
}
47+
}
48+
2849
int main() {
50+
int num = 64 * 1024;
51+
int *arr = malloc(num * sizeof(int));
52+
for(int i = 0; i < num; i++) {
53+
arr[i] = i + 1;
54+
}
55+
shuffle(arr, num, 114514);
56+
57+
for(int i = 0; i < num; i++) {
58+
insert(arr[i]);
59+
}
60+
61+
62+
for(int i = 0; i < num; i++) {
63+
search(arr[i]);
64+
}
2965

30-
for (int i = 1; i <= 10; i++) {
31-
insert(i);
66+
for(int i = 0; i < num; i++) {
67+
delete(arr[i]);
3268
}
33-
// for (int i = 1; i <= 10; i++) {
34-
// search(6);
35-
// }
36-
37-
search(6);
38-
delete(6);
39-
search(6);
4069

70+
free(arr);
4171
return 0;
4272
}

0 commit comments

Comments
 (0)