Skip to content

Commit f06f72e

Browse files
committed
Start to experiment with bindings for cuda_kernel
1 parent 49064ee commit f06f72e

File tree

3 files changed

+87
-25
lines changed

3 files changed

+87
-25
lines changed

c/experimental/stf/include/cccl/c/experimental/stf/stf.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,14 @@ cudaStream_t stf_task_get_stream(stf_task_handle t);
4444
void* stf_task_get(stf_task_handle t, size_t submitted_index);
4545
void stf_task_destroy(stf_task_handle t);
4646

47-
typedef struct stf_kernel_desc_handle_t* stf_kernel_desc_handle;
48-
49-
void stf_kernel_create(stf_kernel_desc_handle* d);
50-
void stf_kernel_destroy(stf_kernel_desc_handle d);
51-
// TODO stf_cuda_kernel_desc : symbol, deps, args... ?
52-
// void stf_kernel_set_symbol((stf_kernel_handle k, const char* symbol)
53-
// void stf_kernel_add_dep(stf_kernel_handle k, stf_logical_data_handle ld, stf_access_mode m);
54-
// void stf_kernel_start(stf_kernel_handle k);
55-
// void stf_kernel_set_args(stf_kernel_handle k, size_t cnt, void **args);
56-
// void stf_kernel_end(stf_kernel_handle k);
57-
// void stf_kernel_destroy(stf_kernel_handle k);
47+
typedef struct stf_cuda_kernel_handle_t* stf_cuda_kernel_handle;
48+
49+
void stf_cuda_kernel_create(stf_ctx_handle ctx, stf_cuda_kernel_handle* k);
50+
void stf_cuda_kernel_set_symbol(stf_cuda_kernel_handle k, const char* symbol);
51+
void stf_cuda_kernel_add_dep(stf_cuda_kernel_handle k, stf_logical_data_handle ld, stf_access_mode m);
52+
void stf_cuda_kernel_start(stf_cuda_kernel_handle k);
53+
void stf_cuda_kernel_end(stf_cuda_kernel_handle k);
54+
void stf_cuda_kernel_destroy(stf_cuda_kernel_handle t);
5855

5956
#ifdef __cplusplus
6057
}

c/experimental/stf/src/stf.cu

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,20 @@ void stf_cuda_kernel_add_dep(stf_cuda_kernel_handle k, stf_logical_data_handle l
153153
assert(k);
154154
assert(ld);
155155

156-
k->k.add_deps(cuda_kernel_dep_untyped(ld->ld, access_mode(m)));
157-
}
158-
159-
// void stf_cuda_kernel_start(stf_cuda_kernel_handle k)
160-
// {
161-
// assert(k);
162-
// k->k.start();
163-
// }
164-
//
165-
// void stf_cuda_kernel_end(stf_cuda_kernel_handle k)
166-
// {
167-
// assert(k);
168-
// k->k.end();
169-
// }
156+
k->k.add_deps(task_dep_untyped(ld->ld, access_mode(m)));
157+
}
158+
159+
void stf_cuda_kernel_start(stf_cuda_kernel_handle k)
160+
{
161+
assert(k);
162+
k->k.start();
163+
}
164+
165+
void stf_cuda_kernel_end(stf_cuda_kernel_handle k)
166+
{
167+
assert(k);
168+
k->k.end();
169+
}
170170

171171
void stf_cuda_kernel_destroy(stf_cuda_kernel_handle t)
172172
{
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of CUDA Experimental in CUDA C++ Core Libraries,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include <cuda_runtime.h>
12+
13+
#include <c2h/catch2_test_helper.h>
14+
#include <cccl/c/experimental/stf/stf.h>
15+
16+
using namespace cuda::experimental::stf;
17+
18+
__global__ void axpy(int cnt, double a, const double *x, double *y)
19+
{
20+
int tid = blockIdx.x * blockDim.x + threadIdx.x;
21+
int nthreads = gridDim.x * blockDim.x;
22+
23+
for (int i = tid; i < cnt; i += nthreads)
24+
{
25+
y[i] += a * x[i];
26+
}
27+
}
28+
29+
C2H_TEST("axpy with stf cuda_kernel", "[cuda_kernel]")
30+
{
31+
size_t N = 1000000;
32+
33+
stf_ctx_handle ctx;
34+
stf_ctx_create(&ctx);
35+
36+
stf_logical_data_handle lX, lY;
37+
38+
float *X, *Y;
39+
X = (float*) malloc(N * sizeof(float));
40+
Y = (float*) malloc(N * sizeof(float));
41+
42+
stf_logical_data(ctx, &lX, X, N * sizeof(float));
43+
stf_logical_data(ctx, &lY, Y, N * sizeof(float));
44+
45+
stf_logical_data_set_symbol(lX, "X");
46+
stf_logical_data_set_symbol(lY, "Y");
47+
48+
stf_cuda_kernel_handle k;
49+
stf_cuda_kernel_create(ctx, &k);
50+
stf_cuda_kernel_set_symbol(k, "axpy");
51+
stf_cuda_kernel_add_dep(k, lX, STF_READ);
52+
stf_cuda_kernel_add_dep(k, lY, STF_RW);
53+
stf_cuda_kernel_start(k);
54+
// TODO add descs
55+
stf_cuda_kernel_end(k);
56+
stf_cuda_kernel_destroy(k);
57+
58+
stf_logical_data_destroy(lX);
59+
stf_logical_data_destroy(lY);
60+
61+
stf_ctx_finalize(ctx);
62+
63+
free(X);
64+
free(Y);
65+
}

0 commit comments

Comments
 (0)