Skip to content

Commit 7db35a7

Browse files
authored
CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops (ggml-org#16917)
1 parent a864132 commit 7db35a7

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2499,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
24992499
case GGML_UNARY_OP_XIELU:
25002500
ggml_cuda_op_xielu(ctx, dst);
25012501
break;
2502+
case GGML_UNARY_OP_FLOOR:
2503+
ggml_cuda_op_floor(ctx, dst);
2504+
break;
2505+
case GGML_UNARY_OP_CEIL:
2506+
ggml_cuda_op_ceil(ctx, dst);
2507+
break;
2508+
case GGML_UNARY_OP_ROUND:
2509+
ggml_cuda_op_round(ctx, dst);
2510+
break;
2511+
case GGML_UNARY_OP_TRUNC:
2512+
ggml_cuda_op_trunc(ctx, dst);
2513+
break;
25022514
default:
25032515
return false;
25042516
}
@@ -3769,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
37693781
case GGML_UNARY_OP_TANH:
37703782
case GGML_UNARY_OP_EXP:
37713783
case GGML_UNARY_OP_ELU:
3784+
case GGML_UNARY_OP_FLOOR:
3785+
case GGML_UNARY_OP_CEIL:
3786+
case GGML_UNARY_OP_ROUND:
3787+
case GGML_UNARY_OP_TRUNC:
37723788
return ggml_is_contiguous(op->src[0]);
37733789
default:
37743790
return false;

ggml/src/ggml-cuda/unary.cu

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
8585
return (x > 0.f) ? x : expm1f(x);
8686
}
8787

88+
// Element-wise FLOOR: returns the largest integral value that is not greater
// than x, computed with the single-precision floorf().
static __device__ __forceinline__ float op_floor(float x) {
    const float rounded_down = floorf(x);
    return rounded_down;
}
91+
92+
// Element-wise CEIL: returns the smallest integral value that is not less
// than x, computed with the single-precision ceilf().
static __device__ __forceinline__ float op_ceil(float x) {
    const float rounded_up = ceilf(x);
    return rounded_up;
}
95+
96+
// Element-wise ROUND: rounds x to the nearest integral value, with halfway
// cases rounded away from zero (2.5f -> 3.0f, -2.5f -> -3.0f).
//
// Uses the explicit single-precision roundf() so the call never depends on
// overload resolution to avoid the double-precision round(), and so it matches
// the floorf()/ceilf() spelling used by the sibling helpers.
static __device__ __forceinline__ float op_round(float x) {
    return roundf(x);
}
99+
100+
// Element-wise TRUNC: rounds x toward zero, discarding the fractional part
// (2.7f -> 2.0f, -2.7f -> -2.0f).
//
// Uses the explicit single-precision truncf() so the call never depends on
// overload resolution to avoid the double-precision trunc(), and so it matches
// the floorf()/ceilf() spelling used by the sibling helpers.
static __device__ __forceinline__ float op_trunc(float x) {
    return truncf(x);
}
103+
88104
template <float (*op)(float), typename T>
89105
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
90106
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
201217
// Entry point for the ELU unary op: instantiates the shared unary driver
// ggml_cuda_op_unary with the op_elu device helper and forwards ctx and dst
// to it unchanged.
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_elu>(ctx, dst);
}
220+
221+
// Entry point for the FLOOR unary op: instantiates the shared unary driver
// ggml_cuda_op_unary with the op_floor device helper and forwards ctx and dst
// to it unchanged.
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_floor>(ctx, dst);
}
224+
225+
// Entry point for the CEIL unary op: instantiates the shared unary driver
// ggml_cuda_op_unary with the op_ceil device helper and forwards ctx and dst
// to it unchanged.
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_ceil>(ctx, dst);
}
228+
229+
// Entry point for the ROUND unary op: instantiates the shared unary driver
// ggml_cuda_op_unary with the op_round device helper and forwards ctx and dst
// to it unchanged.
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_round>(ctx, dst);
}
232+
233+
// Entry point for the TRUNC unary op: instantiates the shared unary driver
// ggml_cuda_op_unary with the op_trunc device helper and forwards ctx and dst
// to it unchanged.
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_trunc>(ctx, dst);
}
204236
/* gated ops */
205237

206238
template <float (*op)(float), typename T>

ggml/src/ggml-cuda/unary.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6363

6464
// ELU unary op entry point; the definition in unary.cu forwards ctx and dst
// to ggml_cuda_op_unary<op_elu>.
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6565

66+
// FLOOR unary op entry point; the definition in unary.cu forwards ctx and dst
// to ggml_cuda_op_unary<op_floor>.
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
67+
68+
// CEIL unary op entry point; the definition in unary.cu forwards ctx and dst
// to ggml_cuda_op_unary<op_ceil>.
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
69+
70+
// ROUND unary op entry point; the definition in unary.cu forwards ctx and dst
// to ggml_cuda_op_unary<op_round>.
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
71+
72+
// TRUNC unary op entry point; the definition in unary.cu forwards ctx and dst
// to ggml_cuda_op_unary<op_trunc>.
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
73+
6674
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6775

6876
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)