Skip to content

Commit fee0445

Browse files
committed
Fix the in-memory kernel cache to not ignore function names.
1 parent e47d526 commit fee0445

File tree

1 file changed

+58
-28
lines changed

1 file changed

+58
-28
lines changed

src/gpuarray_buffer_cuda.c

Lines changed: 58 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -57,10 +57,15 @@ typedef struct _disk_key {
5757
strb src;
5858
} disk_key;
5959

60+
typedef struct _kernel_key {
61+
const char *fname;
62+
strb src;
63+
} kernel_key;
64+
6065
/* Size in bytes of disk_key without its strb member: the part of the key
   that is compared, hashed and serialized as one raw block of memory
   (the strb src member is always handled separately). */
6166
#define DISK_KEY_MM (sizeof(disk_key) - sizeof(strb))
6267

63-
static void key_free(cache_key_t _k) {
68+
static void disk_free(cache_key_t _k) {
6469
disk_key *k = (disk_key *)_k;
6570
strb_clear(&k->src);
6671
free(k);
@@ -71,30 +76,45 @@ static int strb_eq(strb *k1, strb *k2) {
7176
memcmp(k1->s, k2->s, k1->l) == 0);
7277
}
7378

74-
static uint32_t strb_hash(strb *k) {
75-
return XXH32(k->s, k->l, 42);
79+
/*
 * Equality predicate for in-memory kernel cache keys.
 *
 * Returns 1 when both keys have the same function name and the same source
 * buffer, 0 otherwise.
 */
static int kernel_eq(kernel_key *k1, kernel_key *k2) {
  /* Different function names can never be the same kernel, whatever the
     source is. */
  if (strcmp(k1->fname, k2->fname) != 0)
    return 0;

  return strb_eq(&k1->src, &k2->src) != 0;
}
83+
84+
/*
 * Hash function for in-memory kernel cache keys.
 *
 * Feeds the function name (without its terminating NUL) followed by the
 * source bytes into a streaming XXH32 state seeded with 42 and returns the
 * resulting digest.
 */
static uint32_t kernel_hash(kernel_key *k) {
  XXH32_state_t hs;
  const size_t name_len = strlen(k->fname);

  XXH32_reset(&hs, 42);
  XXH32_update(&hs, k->fname, name_len);
  XXH32_update(&hs, k->src.s, k->src.l);

  return XXH32_digest(&hs);
}
7791

78-
static int key_eq(disk_key *k1, disk_key *k2) {
92+
/*
 * Release a kernel cache key: frees the function name, clears the source
 * buffer and finally frees the key structure itself.
 */
static void kernel_free(kernel_key *k) {
  /* fname is declared const in the struct, so the qualifier has to be
     dropped before handing the pointer to free(). */
  void *name = (void *)k->fname;

  free(name);
  strb_clear(&k->src);
  free(k);
}
97+
98+
/*
 * Equality predicate for disk cache keys.
 *
 * Compares the first DISK_KEY_MM bytes of both keys as raw memory, then the
 * source buffers. Returns 1 when everything matches, 0 otherwise.
 */
static int disk_eq(disk_key *k1, disk_key *k2) {
  if (memcmp(k1, k2, DISK_KEY_MM) != 0)
    return 0;

  return strb_eq(&k1->src, &k2->src) != 0;
}
82102

83-
static int key_hash(disk_key *k) {
103+
/*
 * Hash function for disk cache keys.
 *
 * Feeds the first DISK_KEY_MM bytes of the key, followed by the source
 * bytes, into a streaming XXH32 state seeded with 42 and returns the digest.
 *
 * The return type is uint32_t, the same as kernel_hash(): XXH32_digest()
 * produces an unsigned 32-bit value, and this function is used through the
 * same cache_hash_fn cast as kernel_hash(), so both must agree on the
 * return type instead of squeezing the hash through a signed int.
 */
static uint32_t disk_hash(disk_key *k) {
  XXH32_state_t state;

  XXH32_reset(&state, 42);
  XXH32_update(&state, k, DISK_KEY_MM);
  XXH32_update(&state, k->src.s, k->src.l);

  return XXH32_digest(&state);
}
90110

91-
static int key_write(strb *res, disk_key *k) {
111+
/*
 * Serialize a disk cache key into `res`.
 *
 * Appends the first DISK_KEY_MM bytes of the key as raw memory, then the
 * contents of the source buffer. Returns whatever strb_error() reports for
 * `res` after both appends.
 */
static int disk_write(strb *res, disk_key *k) {
  const char *raw = (const char *)k;

  strb_appendn(res, raw, DISK_KEY_MM);
  strb_appendb(res, &k->src);

  return strb_error(res);
}
96116

97-
static disk_key *key_read(const strb *b) {
117+
static disk_key *disk_read(const strb *b) {
98118
disk_key *k;
99119
if (b->l < DISK_KEY_MM) return NULL;
100120
k = calloc(1, sizeof(*k));
@@ -238,9 +258,9 @@ cuda_context *cuda_make_ctx(CUcontext ctx, int flags) {
238258
}
239259

240260
res->kernel_cache = cache_twoq(64, 128, 64, 8,
241-
(cache_eq_fn)strb_eq,
242-
(cache_hash_fn)strb_hash,
243-
(cache_freek_fn)strb_free,
261+
(cache_eq_fn)kernel_eq,
262+
(cache_hash_fn)kernel_hash,
263+
(cache_freek_fn)kernel_free,
244264
(cache_freev_fn)cuda_freekernel, global_err);
245265
if (res->kernel_cache == NULL) {
246266
error_cuda(global_err, "cuStreamCreate", err);
@@ -250,9 +270,9 @@ cuda_context *cuda_make_ctx(CUcontext ctx, int flags) {
250270
cache_path = getenv("GPUARRAY_CACHE_PATH");
251271
if (cache_path != NULL) {
252272
mem_cache = cache_lru(64, 8,
253-
(cache_eq_fn)key_eq,
254-
(cache_hash_fn)key_hash,
255-
(cache_freek_fn)key_free,
273+
(cache_eq_fn)disk_eq,
274+
(cache_hash_fn)disk_hash,
275+
(cache_freek_fn)disk_free,
256276
(cache_freev_fn)strb_free,
257277
global_err);
258278
if (mem_cache == NULL) {
@@ -261,11 +281,11 @@ cuda_context *cuda_make_ctx(CUcontext ctx, int flags) {
261281
goto fail_disk_cache;
262282
}
263283
res->disk_cache = cache_disk(cache_path, mem_cache,
264-
(kwrite_fn)key_write,
284+
(kwrite_fn)disk_write,
265285
(vwrite_fn)kernel_write,
266-
(kread_fn)key_read,
286+
(kread_fn)disk_read,
267287
(vread_fn)kernel_read,
268-
res->err);
288+
global_err);
269289
if (res->disk_cache == NULL) {
270290
fprintf(stderr, "Error initializing disk cache, disabling: %s\n",
271291
global_err->msg);
@@ -1230,23 +1250,23 @@ static int compile(cuda_context *ctx, strb *src, strb* bin, strb *log) {
12301250
error_sys(ctx->err, "strb_appendb");
12311251
fprintf(stderr, "Error adding kernel to disk cache %s\n",
12321252
ctx->err->msg);
1233-
key_free((cache_key_t)pk);
1253+
disk_free((cache_key_t)pk);
12341254
return GA_NO_ERROR;
12351255
}
12361256
cbin = strb_alloc(bin->l);
12371257
if (cbin == NULL) {
12381258
error_sys(ctx->err, "strb_alloc");
12391259
fprintf(stderr, "Error adding kernel to disk cache: %s\n",
12401260
ctx->err->msg);
1241-
key_free((cache_key_t)pk);
1261+
disk_free((cache_key_t)pk);
12421262
return GA_NO_ERROR;
12431263
}
12441264
strb_appendb(cbin, bin);
12451265
if (strb_error(cbin)) {
12461266
error_sys(ctx->err, "strb_appendb");
12471267
fprintf(stderr, "Error adding kernel to disk cache %s\n",
12481268
ctx->err->msg);
1249-
key_free((cache_key_t)pk);
1269+
disk_free((cache_key_t)pk);
12501270
strb_free(cbin);
12511271
return GA_NO_ERROR;
12521272
}
@@ -1284,8 +1304,9 @@ static int cuda_newkernel(gpukernel **k, gpucontext *c, unsigned int count,
12841304
strb src = STRB_STATIC_INIT;
12851305
strb bin = STRB_STATIC_INIT;
12861306
strb log = STRB_STATIC_INIT;
1287-
strb *psrc;
12881307
gpukernel *res;
1308+
kernel_key k_key;
1309+
kernel_key *p_key;
12891310
CUdevice dev;
12901311
CUresult err;
12911312
unsigned int i;
@@ -1350,7 +1371,10 @@ static int cuda_newkernel(gpukernel **k, gpucontext *c, unsigned int count,
13501371
return error_sys(ctx->err, "strb");
13511372
}
13521373

1353-
res = (gpukernel *)cache_get(ctx->kernel_cache, &src);
1374+
k_key.fname = fname;
1375+
k_key.src = src;
1376+
1377+
res = (gpukernel *)cache_get(ctx->kernel_cache, &k_key);
13541378
if (res != NULL) {
13551379
res->refcnt++;
13561380
strb_clear(&src);
@@ -1434,13 +1458,19 @@ static int cuda_newkernel(gpukernel **k, gpucontext *c, unsigned int count,
14341458
ctx->refcnt++;
14351459
cuda_exit(ctx);
14361460
TAG_KER(res);
1437-
psrc = memdup(&src, sizeof(strb));
1438-
if (psrc != NULL) {
1439-
/* One of the refs is for the cache */
1440-
res->refcnt++;
1441-
/* If this fails, it will free the key and remove a ref from the
1442-
kernel. */
1443-
cache_add(ctx->kernel_cache, psrc, res);
1461+
p_key = memdup(&k_key, sizeof(kernel_key));
1462+
if (p_key != NULL) {
1463+
p_key->fname = strdup(fname);
1464+
if (p_key->fname != NULL) {
1465+
/* One of the refs is for the cache */
1466+
res->refcnt++;
1467+
/* If this fails, it will free the key and remove a ref from the
1468+
kernel. */
1469+
cache_add(ctx->kernel_cache, p_key, res);
1470+
} else {
1471+
free(p_key);
1472+
strb_clear(&src);
1473+
}
14441474
} else {
14451475
strb_clear(&src);
14461476
}

0 commit comments

Comments
 (0)