Skip to content

Commit 2967234

Browse files
committed
cuda: emit is_first/broadcast_first
1 parent d1122d5 commit 2967234

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/shady/emit/c/cuda_builtins.cu

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,20 @@ __device__ void __shady_prepare_builtins() {
1212
GlobalInvocationId.arr[1] = threadIdx.y + blockDim.y * blockIdx.y;
1313
GlobalInvocationId.arr[2] = threadIdx.z + blockDim.z * blockIdx.z;
1414
}
15+
16+
__device__ bool __shady_elect_first() {
17+
unsigned int writemask = __activemask();
18+
// Find the lowest-numbered active lane
19+
int elected_lane = __ffs(writemask) - 1;
20+
return threadIdx.x == __shfl_sync(writemask, threadIdx.x, elected_lane)
21+
&& threadIdx.y == __shfl_sync(writemask, threadIdx.y, elected_lane)
22+
&& threadIdx.z == __shfl_sync(writemask, threadIdx.z, elected_lane);
23+
}
24+
25+
template<typename T>
26+
__device__ T __shady_broadcast_first(T t) {
27+
unsigned int writemask = __activemask();
28+
// Find the lowest-numbered active lane
29+
int elected_lane = __ffs(writemask) - 1;
30+
return __shfl_sync(writemask, t, elected_lane);
31+
}

src/shady/emit/c/emit_c_instructions.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -618,13 +618,13 @@ static void emit_primop(Emitter* emitter, Printer* p, const Node* node, Instruct
618618
case get_stack_base_op:
619619
case push_stack_op:
620620
case pop_stack_op:
621-
case get_stack_pointer_op:
622-
case set_stack_pointer_op: error("Stack operations need to be lowered.");
621+
case get_stack_size_op:
622+
case set_stack_size_op: error("Stack operations need to be lowered.");
623623
case default_join_point_op:
624624
case create_joint_point_op: error("lowered in lower_tailcalls.c");
625625
case subgroup_elect_first_op: {
626626
switch (emitter->config.dialect) {
627-
case CDialect_CUDA: error("TODO")
627+
case CDialect_CUDA: term = term_from_cvalue(format_string_arena(emitter->arena->arena, "__shady_elect_first()")); break;
628628
case CDialect_ISPC: term = term_from_cvalue(format_string_arena(emitter->arena->arena, "(programIndex == count_trailing_zeros(lanemask()))")); break;
629629
case CDialect_C11:
630630
case CDialect_GLSL: error("TODO")
@@ -641,7 +641,7 @@ static void emit_primop(Emitter* emitter, Printer* p, const Node* node, Instruct
641641
case subgroup_broadcast_first_op: {
642642
CValue value = to_cvalue(emitter, emit_value(emitter, p, first(prim_op->operands)));
643643
switch (emitter->config.dialect) {
644-
case CDialect_CUDA: error("TODO")
644+
case CDialect_CUDA: term = term_from_cvalue(format_string_arena(emitter->arena->arena, "__shady_broadcast_first(%s)", value)); break;
645645
case CDialect_ISPC: term = term_from_cvalue(format_string_arena(emitter->arena->arena, "extract(%s, count_trailing_zeros(lanemask()))", value)); break;
646646
case CDialect_C11:
647647
case CDialect_GLSL: error("TODO")

0 commit comments

Comments
 (0)