Skip to content

Commit f3b57da

Browse files
committed
Save some WIP
1 parent 48fd705 commit f3b57da

File tree

3 files changed

+78
-12
lines changed

3 files changed

+78
-12
lines changed

c/experimental/stf/include/cccl/c/experimental/stf/stf.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ void stf_task_add_dep(stf_task_handle t, stf_logical_data_handle ld, stf_access_
4141
void stf_task_start(stf_task_handle t);
4242
void stf_task_end(stf_task_handle t);
4343
cudaStream_t stf_task_get_stream(stf_task_handle t);
44-
void *stf_task_get(stf_task_handle t, size_t submitted_index);
44+
void* stf_task_get(stf_task_handle t, size_t submitted_index);
4545
void stf_task_destroy(stf_task_handle t);
4646

47-
typedef struct stf_kernel_desc_handle_t *stf_kernel_desc_handle;
47+
typedef struct stf_kernel_desc_handle_t* stf_kernel_desc_handle;
4848

49-
void stf_kernel_create(stf_kernel_desc_handle *d);
49+
void stf_kernel_create(stf_kernel_desc_handle* d);
5050
void stf_kernel_destroy(stf_kernel_desc_handle d);
5151
// TODO stf_cuda_kernel_desc : symbol, deps, args... ?
5252
// void stf_kernel_set_symbol((stf_kernel_handle k, const char* symbol)

c/experimental/stf/src/stf.cu

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ void stf_ctx_finalize(stf_ctx_handle ctx)
3838

3939
cudaStream_t stf_fence(stf_ctx_handle ctx)
4040
{
41-
assert(ctx);
42-
return ctx->ctx.fence();
41+
assert(ctx);
42+
return ctx->ctx.fence();
4343
}
4444

4545
void stf_logical_data(stf_ctx_handle ctx, stf_logical_data_handle* ld, void* addr, size_t sz)
@@ -107,4 +107,70 @@ void stf_task_destroy(stf_task_handle t)
107107
assert(t);
108108
delete t;
109109
}
110+
111+
/**
112+
* Low level example of cuda_kernel(_chain)
113+
* auto t = ctx.cuda_kernel_chain();
114+
t.add_deps(lX.read());
115+
t.add_deps(lY.rw());
116+
t->*[&]() {
117+
auto dX = t.template get<slice<double>>(0);
118+
auto dY = t.template get<slice<double>>(1);
119+
return std::vector<cuda_kernel_desc> {
120+
{ axpy, 16, 128, 0, alpha, dX, dY },
121+
{ axpy, 16, 128, 0, beta, dX, dY },
122+
{ axpy, 16, 128, 0, gamma, dX, dY }
123+
};
124+
};
125+
126+
*
127+
*/
128+
struct stf_cuda_kernel_handle_t
129+
{
130+
// return type of ctx.cuda_kernel()
131+
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
132+
kernel_type k;
133+
};
134+
135+
void stf_cuda_kernel_create(stf_ctx_handle ctx, stf_cuda_kernel_handle* k)
136+
{
137+
assert(k);
138+
assert(ctx);
139+
140+
*k = new stf_cuda_kernel_handle_t{ctx->ctx.cuda_kernel()};
141+
}
142+
143+
void stf_cuda_kernel_set_symbol(stf_cuda_kernel_handle k, const char* symbol)
144+
{
145+
assert(k);
146+
assert(symbol);
147+
148+
k->k.set_symbol(symbol);
149+
}
150+
151+
void stf_cuda_kernel_add_dep(stf_cuda_kernel_handle k, stf_logical_data_handle ld, stf_access_mode m)
152+
{
153+
assert(k);
154+
assert(ld);
155+
156+
k->k.add_deps(cuda_kernel_dep_untyped(ld->ld, access_mode(m)));
157+
}
158+
159+
// void stf_cuda_kernel_start(stf_cuda_kernel_handle k)
160+
// {
161+
// assert(k);
162+
// k->k.start();
163+
// }
164+
//
165+
// void stf_cuda_kernel_end(stf_cuda_kernel_handle k)
166+
// {
167+
// assert(k);
168+
// k->k.end();
169+
// }
170+
171+
void stf_cuda_kernel_destroy(stf_cuda_kernel_handle t)
172+
{
173+
assert(t);
174+
delete t;
175+
}
110176
}

c/experimental/stf/test/test_task.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include <cuda_runtime.h>
12-
#include <c2h/catch2_test_helper.h>
1312

13+
#include <c2h/catch2_test_helper.h>
1414
#include <cccl/c/experimental/stf/stf.h>
1515

1616
C2H_TEST("empty stf tasks", "[task]")
@@ -23,13 +23,13 @@ C2H_TEST("empty stf tasks", "[task]")
2323
stf_logical_data_handle lX, lY, lZ;
2424

2525
float *X, *Y, *Z;
26-
X = (float *)malloc(N*sizeof(float));
27-
Y = (float *)malloc(N*sizeof(float));
28-
Z = (float *)malloc(N*sizeof(float));
26+
X = (float*) malloc(N * sizeof(float));
27+
Y = (float*) malloc(N * sizeof(float));
28+
Z = (float*) malloc(N * sizeof(float));
2929

30-
stf_logical_data(ctx, &lX, X, N*sizeof(float));
31-
stf_logical_data(ctx, &lY, Y, N*sizeof(float));
32-
stf_logical_data(ctx, &lZ, Z, N*sizeof(float));
30+
stf_logical_data(ctx, &lX, X, N * sizeof(float));
31+
stf_logical_data(ctx, &lY, Y, N * sizeof(float));
32+
stf_logical_data(ctx, &lZ, Z, N * sizeof(float));
3333

3434
stf_logical_data_set_symbol(lX, "X");
3535
stf_logical_data_set_symbol(lY, "Y");

0 commit comments

Comments
 (0)