@@ -21,6 +21,8 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
@@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }
 
+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
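The new memset_tensor callback mirrors the set_tensor path right below it: bind the buffer's device, enqueue the fill on the per-thread default stream, then block until it completes. A minimal stand-alone sketch of that pattern follows; the helper name and the bare device/dst parameters are illustrative, CUDA_CHECK is the error-checking macro the file already uses, and the real code goes through ggml_cuda_set_device(ctx->device):

// Minimal sketch of the blocking-fill pattern used by the new callback:
// switch to the buffer's device, enqueue the memset on the per-thread default
// stream, then wait for it to finish before returning.
static void cuda_memset_blocking(int device, void * dst, uint8_t value, size_t offset, size_t size) {
    CUDA_CHECK(cudaSetDevice(device));
    CUDA_CHECK(cudaMemsetAsync((char *) dst + offset, value, size, cudaStreamPerThread));
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}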
@@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
     /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor = */ NULL,
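memset_tensor is an optional entry in ggml_backend_buffer_i: the split buffer leaves it NULL, just like cpy_tensor below it, so a generic caller has to guard the dispatch. An illustrative sketch of such a guard; the helper name and the bool-based fallback signal are assumptions, only the interface member itself comes from this diff:

// Illustrative guard only: the member may be NULL, as in the split-buffer
// table above, so a generic memset helper checks before dispatching.
static bool buffer_try_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor,
                                     uint8_t value, size_t offset, size_t size) {
    if (buffer->iface.memset_tensor == NULL) {
        return false; // caller falls back, e.g. to set_tensor with a pre-filled host staging buffer
    }
    buffer->iface.memset_tensor(buffer, tensor, value, offset, size);
    return true;
}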
@@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
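GGML_OP_REPEAT_BACK is the reduction counterpart of GGML_OP_REPEAT and shows up in backward passes: the copies produced by a repeat are summed back into the original extent. A 1-D reference loop for that assumed semantics, as a sketch only (the CUDA kernel behind ggml_cuda_op_repeat_back handles the general multi-dimensional broadcast case):

#include <stdint.h>

// 1-D reference: src holds nr repeated copies of an ne-element vector,
// dst receives their elementwise sum.
static void repeat_back_1d_ref(const float * src, float * dst, int64_t ne, int64_t nr) {
    for (int64_t i = 0; i < ne; ++i) {
        float sum = 0.0f;
        for (int64_t r = 0; r < nr; ++r) {
            sum += src[r*ne + i];
        }
        dst[i] = sum;
    }
}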
@@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cuda_op_step(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_GELU:
                     ggml_cuda_op_gelu(ctx, dst);
                     break;
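GGML_UNARY_OP_STEP is the elementwise Heaviside step, f(x) = 1 for x > 0 and 0 otherwise. A minimal CUDA kernel of that shape, shown only as a sketch; the kernel name and launch layout are illustrative, not the actual implementation behind ggml_cuda_op_step:

// Sketch of an elementwise step kernel: dst[i] = 1 if x[i] > 0, else 0.
static __global__ void step_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = x[i] > 0.0f ? 1.0f : 0.0f;
}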
@@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
             break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
             break;
@@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
         default:
             return false;
     }
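Together with the backward ops above, GGML_OP_OPT_STEP_ADAMW lets the optimizer update run on the GPU instead of copying gradients back to the host. Conceptually it is the textbook per-element AdamW update; the sketch below shows that form (how the kernel in opt-step-adamw.cu folds in bias correction and weight decay may differ):

#include <math.h>

// Textbook per-element AdamW update (sketch): m and v are the running first
// and second moments, t the 1-based step count, alpha the learning rate,
// wd the decoupled weight-decay factor. Returns the updated parameter.
static inline float adamw_step_ref(float x, float g, float * m, float * v,
                                   float alpha, float beta1, float beta2,
                                   float eps, float wd, int t) {
    *m = beta1*(*m) + (1.0f - beta1)*g;
    *v = beta2*(*v) + (1.0f - beta2)*g*g;
    const float mh = *m/(1.0f - powf(beta1, (float) t)); // bias-corrected first moment
    const float vh = *v/(1.0f - powf(beta2, (float) t)); // bias-corrected second moment
    return x*(1.0f - alpha*wd) - alpha*mh/(sqrtf(vh) + eps);
}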
@@ -2761,6 +2788,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -2813,6 +2841,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                         return false;
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
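The supports_op entry above restricts GGML_OP_OUT_PROD to F32 tensors with ne[2] == 1 and ne[3] == 1, i.e. plain 2-D matrices. For that case the op amounts to dst = src0 · src1ᵀ with the contraction running over the second dimension; a CPU reference loop under that assumed index convention (a sketch, not the new CUDA code):

#include <stdint.h>

// Reference for the 2-D F32 case admitted above: dst has shape [ne00, ne10]
// and dst[i, j] = sum_k src0[i, k] * src1[j, k], with k running over ne01
// (which is assumed to equal src1->ne[1]).
static void out_prod_f32_ref(const float * src0, const float * src1, float * dst,
                             int64_t ne00, int64_t ne01, int64_t ne10) {
    for (int64_t j = 0; j < ne10; ++j) {
        for (int64_t i = 0; i < ne00; ++i) {
            float sum = 0.0f;
            for (int64_t k = 0; k < ne01; ++k) {
                sum += src0[k*ne00 + i]*src1[k*ne10 + j];
            }
            dst[j*ne00 + i] = sum;
        }
    }
}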
@@ -2869,6 +2899,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             } break;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2935,9 +2971,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }