 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
@@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }
 
+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
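The new memset_tensor callback fills a region of a tensor's CUDA buffer with a byte value using cudaMemsetAsync on cudaStreamPerThread, then synchronizes that stream before returning. A minimal usage sketch, assuming the generic ggml_backend_tensor_memset() entry point added alongside this interface field dispatches to the buffer's callback:

// Usage sketch (not part of this diff): zero a tensor stored in a CUDA buffer.
// Assumes ggml_backend_tensor_memset() routes to the buffer's .memset_tensor callback,
// i.e. to ggml_backend_cuda_buffer_memset_tensor() above.
#include "ggml.h"
#include "ggml-backend.h"

static void zero_cuda_tensor(struct ggml_tensor * t) {
    ggml_backend_tensor_memset(t, /*value =*/ 0, /*offset =*/ 0, /*size =*/ ggml_nbytes(t));
}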
@@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor      = */ NULL,
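The split-buffer type leaves the new slot NULL, so byte-fill requests are not supported for tensors split across devices. A simplified sketch of how the generic layer is assumed to dispatch on this field (the real ggml-backend code may handle the NULL case differently):

// Dispatch sketch (assumption, not the actual ggml-backend implementation).
static void tensor_memset_dispatch(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                   uint8_t value, size_t offset, size_t size) {
    GGML_ASSERT(buffer->iface.memset_tensor != NULL && "memset_tensor not supported by this buffer type");
    buffer->iface.memset_tensor(buffer, tensor, value, offset, size);
}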
@@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
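GGML_OP_REPEAT_BACK is the adjoint of GGML_OP_REPEAT: it sums a broadcast tensor back down to the original shape, which is what the backward pass of a broadcast needs. A small illustration with placeholder tensor names:

// Illustration only; ctx, grad_big and small are placeholders.
#include "ggml.h"

static struct ggml_tensor * repeat_grad(struct ggml_context * ctx,
                                        struct ggml_tensor * grad_big, // gradient of the repeated tensor
                                        struct ggml_tensor * small) {  // original, pre-repeat tensor
    // sums grad_big back into the shape of small -> creates a GGML_OP_REPEAT_BACK node
    return ggml_repeat_back(ctx, grad_big, small);
}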
@@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cuda_op_step(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_GELU:
                     ggml_cuda_op_gelu(ctx, dst);
                     break;
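The step activation is 1 for positive inputs and 0 otherwise. A generic element-wise kernel sketch of that function (not the actual ggml kernel behind ggml_cuda_op_step):

// Generic sketch of the step function applied element-wise to an F32 array.
#include <cstdint>

__global__ void step_f32_sketch(const float * x, float * dst, const int64_t n) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    dst[i] = x[i] > 0.0f ? 1.0f : 0.0f;
}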
@@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
             break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
             break;
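GGML_OP_OUT_PROD computes an outer product and appears in the backward pass of matrix multiplication. A generic scalar reference of the operation family, deliberately not committing to ggml_out_prod's exact operand order or shape convention:

// Generic reference only; ggml's ggml_out_prod may lay out its operands differently.
// dst has m rows and n columns, stored as dst[i + j*m] = a[i]*b[j].
static void out_prod_ref(const float * a, const float * b, float * dst, const int m, const int n) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            dst[i + j*m] = a[i]*b[j];
        }
    }
}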
@@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
         default:
             return false;
     }
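The two new cases wire up the backward pass of the fused cross-entropy loss and the AdamW optimizer step. For reference, the standard per-element AdamW update with decoupled weight decay is sketched below; this is an illustration only, and the actual CUDA kernel behind ggml_cuda_opt_step_adamw may organize its hyper-parameters differently:

// Scalar AdamW reference: m/v are first/second moment accumulators,
// beta1h/beta2h are the bias-correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t).
#include <math.h>
#include <stdint.h>

static void adamw_step_ref(float * x, const float * g, float * m, float * v, int64_t n,
                           float alpha, float beta1, float beta2, float eps, float wd,
                           float beta1h, float beta2h) {
    for (int64_t i = 0; i < n; ++i) {
        m[i] = beta1*m[i] + (1.0f - beta1)*g[i];
        v[i] = beta2*v[i] + (1.0f - beta2)*g[i]*g[i];
        const float mh = m[i]*beta1h;                  // bias-corrected first moment
        const float vh = sqrtf(v[i]*beta2h) + eps;     // bias-corrected second moment
        x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;   // decoupled weight decay + update
    }
}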
@@ -2761,6 +2788,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -2813,6 +2841,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                         return false;
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -2869,6 +2899,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             } break;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2935,9 +2971,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }