@@ -2481,6 +2481,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2481
2481
case GGML_OP_SCALE:
2482
2482
return op->src [0 ]->type == GGML_TYPE_F32 && ggml_is_contiguous (op->src [0 ]);
2483
2483
case GGML_OP_ADD:
2484
+ if (op->type == GGML_TYPE_F16) {
2485
+ const bool src0_ok = op->src [0 ]->type == GGML_TYPE_F16 || op->src [0 ]->type == GGML_TYPE_F32;
2486
+ const bool src1_ok = op->src [1 ]->type == GGML_TYPE_F16 || op->src [1 ]->type == GGML_TYPE_F32;
2487
+ if (src0_ok && src1_ok) {
2488
+ return true ;
2489
+ }
2490
+ }
2484
2491
case GGML_OP_MUL:
2485
2492
case GGML_OP_DIV:
2486
2493
case GGML_OP_SUB:
@@ -3717,34 +3724,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
3717
3724
GGML_ASSERT (dst);
3718
3725
GGML_ASSERT (dst->extra );
3719
3726
3720
- GGML_ASSERT (src0->type == src1->type );
3721
- GGML_ASSERT (src0->type == dst->type );
3722
- GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3723
-
3724
- const int ne00 = src0->ne [0 ];
3725
- const int ne01 = src0->ne [1 ];
3726
- const int ne02 = src0->ne [2 ];
3727
- const int ne03 = src0->ne [3 ];
3727
+ const int ne00 = src0->ne [0 ];
3728
+ const int ne01 = src0->ne [1 ];
3729
+ const int ne02 = src0->ne [2 ];
3730
+ const int ne03 = src0->ne [3 ];
3728
3731
3729
3732
const cl_ulong nb00 = src0->nb [0 ];
3730
3733
const cl_ulong nb01 = src0->nb [1 ];
3731
3734
const cl_ulong nb02 = src0->nb [2 ];
3732
3735
const cl_ulong nb03 = src0->nb [3 ];
3733
3736
3734
- const int ne10 = src1->ne [0 ];
3735
- const int ne11 = src1->ne [1 ];
3736
- const int ne12 = src1->ne [2 ];
3737
- const int ne13 = src1->ne [3 ]; UNUSED (ne13) ;
3737
+ const int ne10 = src1->ne [0 ];
3738
+ const int ne11 = src1->ne [1 ];
3739
+ const int ne12 = src1->ne [2 ];
3740
+ const int ne13 = src1->ne [3 ];
3738
3741
3739
3742
const cl_ulong nb10 = src1->nb [0 ];
3740
3743
const cl_ulong nb11 = src1->nb [1 ];
3741
3744
const cl_ulong nb12 = src1->nb [2 ];
3742
- const cl_ulong nb13 = src1->nb [3 ]; UNUSED (nb13);
3745
+ const cl_ulong nb13 = src1->nb [3 ];
3743
3746
3744
- const int ne0 = dst->ne [0 ];
3745
- const int ne1 = dst->ne [1 ];
3746
- const int ne2 = dst->ne [2 ];
3747
- const int ne3 = dst->ne [3 ];
3747
+ const int ne0 = dst->ne [0 ];
3748
+ const int ne1 = dst->ne [1 ];
3749
+ const int ne2 = dst->ne [2 ];
3750
+ const int ne3 = dst->ne [3 ];
3748
3751
3749
3752
const cl_ulong nb0 = dst->nb [0 ];
3750
3753
const cl_ulong nb1 = dst->nb [1 ];
@@ -3761,68 +3764,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
3761
3764
cl_ulong offset1 = extra1->offset + src1->view_offs ;
3762
3765
cl_ulong offsetd = extrad->offset + dst->view_offs ;
3763
3766
3764
- bool bcast_row = false ;
3765
3767
cl_kernel kernel;
3766
3768
3767
- if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
3768
- GGML_ASSERT (ggml_is_contiguous (src0));
3769
+ const bool bcast_row = ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ;
3769
3770
3770
- // src1 is a row
3771
+ if (bcast_row) {
3772
+ GGML_ASSERT (ggml_is_contiguous (src0));
3771
3773
GGML_ASSERT (ne11 == 1 );
3774
+ }
3772
3775
3773
- bcast_row = true ;
3774
- int ne = ne00 / 4 ;
3775
-
3776
- if (src0->type == GGML_TYPE_F32) {
3776
+ if (dst->type == GGML_TYPE_F32) {
3777
+ GGML_ASSERT (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
3778
+ if (bcast_row) {
3777
3779
kernel = backend_ctx->kernel_add_row ;
3780
+ const int ne = ne00 / 4 ;
3781
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3782
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3783
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3784
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3785
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3786
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3787
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
3778
3788
} else {
3779
- kernel = backend_ctx->kernel_add_row_f16 ;
3780
- }
3781
-
3782
- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3783
- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3784
- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3785
- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3786
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3787
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3788
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
3789
- } else {
3790
- if (src0->type == GGML_TYPE_F32) {
3791
3789
kernel = backend_ctx->kernel_add ;
3790
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3791
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3792
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3793
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3794
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3795
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3796
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3797
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3798
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3799
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3800
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3801
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3802
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3803
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3804
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3805
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3806
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3807
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3808
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3809
+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3810
+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3811
+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3812
+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3813
+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3814
+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3815
+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3816
+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3817
+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3818
+ CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3819
+ CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3820
+ }
3821
+ } else if (dst->type == GGML_TYPE_F16) {
3822
+ GGML_ASSERT (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
3823
+ GGML_ASSERT (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
3824
+ const int type_src0 = (src0->type == GGML_TYPE_F32);
3825
+ const int type_src1 = (src1->type == GGML_TYPE_F32);
3826
+ if (bcast_row) {
3827
+ kernel = backend_ctx->kernel_add_row_f16 ;
3828
+ const int ne = ne00 / 4 ;
3829
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3830
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3831
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3832
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3833
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3834
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3835
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
3836
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &type_src0));
3837
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &type_src1));
3792
3838
} else {
3793
3839
kernel = backend_ctx->kernel_add_f16 ;
3840
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3841
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3842
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3843
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3844
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3845
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3846
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3847
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3848
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3849
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3850
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3851
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3852
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3853
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3854
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3855
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3856
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3857
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3858
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3859
+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3860
+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3861
+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3862
+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3863
+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3864
+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3865
+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3866
+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3867
+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3868
+ CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3869
+ CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3870
+ CL_CHECK (clSetKernelArg (kernel, 30 , sizeof (int ), &type_src0));
3871
+ CL_CHECK (clSetKernelArg (kernel, 31 , sizeof (int ), &type_src1));
3794
3872
}
3795
-
3796
- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3797
- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3798
- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3799
- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3800
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3801
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3802
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3803
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3804
- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3805
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3806
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3807
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3808
- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3809
- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3810
- CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3811
- CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3812
- CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3813
- CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3814
- CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3815
- CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3816
- CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3817
- CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3818
- CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3819
- CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3820
- CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3821
- CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3822
- CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3823
- CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3824
- CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3825
- CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3873
+ } else {
3874
+ GGML_ASSERT (false && " unsupported data types for add" );
3826
3875
}
3827
3876
3828
3877
if (bcast_row) {
@@ -3832,13 +3881,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
3832
3881
3833
3882
size_t * local_work_size_ptr = local_work_size;
3834
3883
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups ) {
3835
- local_work_size_ptr = nullptr ; // Let driver choose the work-group sizes.
3884
+ local_work_size_ptr = nullptr ;
3836
3885
}
3837
3886
3838
- backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size_ptr, dst);
3887
+ backend_ctx->enqueue_ndrange_kernel (kernel, 1 , global_work_size, local_work_size_ptr, dst);
3839
3888
} else {
3840
3889
unsigned int nth = MIN (64 , ne0);
3841
- size_t global_work_size[] = {ne01*nth, (size_t )ne02, (size_t )ne03};
3890
+ size_t global_work_size[] = {( size_t ) ne01*nth, (size_t )ne02, (size_t )ne03};
3842
3891
size_t local_work_size[] = {nth, 1 , 1 };
3843
3892
3844
3893
backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
0 commit comments