@@ -5757,19 +5757,32 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5757
5757
5758
5758
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
5759
5759
5760
- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
5761
- const int ne01 = src0 ? src0->ne [1 ] : 0 ;
5762
- const int ne02 = src0 ? src0->ne [2 ] : 0 ;
5763
- const int ne03 = src0 ? src0->ne [3 ] : 0 ;
5760
+ const int ne00 = src0->ne [0 ];
5761
+ const int ne01 = src0->ne [1 ];
5762
+ const int ne02 = src0->ne [2 ];
5763
+ const int ne03 = src0->ne [3 ];
5764
+
5765
+ const cl_long nb01 = src0->nb [1 ];
5766
+ const cl_long nb02 = src0->nb [2 ];
5767
+ const cl_long nb03 = src0->nb [3 ];
5768
+
5769
+ const int ne11 = src1 ? src1->ne [1 ] : 0 ;
5770
+ const int ne12 = src1 ? src1->ne [2 ] : 0 ;
5771
+ const int ne13 = src1 ? src1->ne [3 ] : 0 ;
5772
+
5773
+ const cl_long nb11 = src1 ? src1->nb [1 ] : 0 ;
5774
+ const cl_long nb12 = src1 ? src1->nb [2 ] : 0 ;
5775
+ const cl_long nb13 = src1 ? src1->nb [3 ] : 0 ;
5776
+
5777
+ const cl_long nb1 = dst->nb [1 ];
5778
+ const cl_long nb2 = dst->nb [2 ];
5779
+ const cl_long nb3 = dst->nb [3 ];
5764
5780
5765
5781
float scale, max_bias;
5766
5782
memcpy (&scale, dst->op_params + 0 , sizeof (float ));
5767
5783
memcpy (&max_bias, dst->op_params + 1 , sizeof (float ));
5768
5784
5769
- const int nrows_x = ggml_nrows (src0);
5770
- const int nrows_y = src0->ne [1 ];
5771
-
5772
- const int n_head = nrows_x/nrows_y;
5785
+ const int n_head = src0->ne [2 ];
5773
5786
const int n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
5774
5787
5775
5788
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
@@ -5816,11 +5829,23 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5816
5829
CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
5817
5830
CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
5818
5831
CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
5819
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (float ), &scale));
5820
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (float ), &max_bias));
5821
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &m0));
5822
- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float ), &m1));
5823
- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &n_head_log2));
5832
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb01));
5833
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb02));
5834
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb03));
5835
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne11));
5836
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne12));
5837
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne13));
5838
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb11));
5839
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb12));
5840
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb13));
5841
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb1));
5842
+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb2));
5843
+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb3));
5844
+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &scale));
5845
+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (float ), &max_bias));
5846
+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (float ), &m0));
5847
+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (float ), &m1));
5848
+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &n_head_log2));
5824
5849
5825
5850
size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
5826
5851
size_t local_work_size[] = {(size_t )nth, 1 , 1 };
0 commit comments