@@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
57635763
57645764 cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
57655765
5766- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
5767- const int ne01 = src0 ? src0->ne [1 ] : 0 ;
5768- const int ne02 = src0 ? src0->ne [2 ] : 0 ;
5769- const int ne03 = src0 ? src0->ne [3 ] : 0 ;
5766+ const int ne00 = src0->ne [0 ];
5767+ const int ne01 = src0->ne [1 ];
5768+ const int ne02 = src0->ne [2 ];
5769+ const int ne03 = src0->ne [3 ];
5770+
5771+ const cl_long nb01 = src0->nb [1 ];
5772+ const cl_long nb02 = src0->nb [2 ];
5773+ const cl_long nb03 = src0->nb [3 ];
5774+
5775+ const int ne12 = src1 ? src1->ne [2 ] : 0 ;
5776+ const int ne13 = src1 ? src1->ne [3 ] : 0 ;
5777+
5778+ const cl_long nb11 = src1 ? src1->nb [1 ] : 0 ;
5779+ const cl_long nb12 = src1 ? src1->nb [2 ] : 0 ;
5780+ const cl_long nb13 = src1 ? src1->nb [3 ] : 0 ;
5781+
5782+ const cl_long nb1 = dst->nb [1 ];
5783+ const cl_long nb2 = dst->nb [2 ];
5784+ const cl_long nb3 = dst->nb [3 ];
57705785
57715786 float scale, max_bias;
57725787 memcpy (&scale, dst->op_params + 0 , sizeof (float ));
57735788 memcpy (&max_bias, dst->op_params + 1 , sizeof (float ));
57745789
5775- const int nrows_x = ggml_nrows (src0);
5776- const int nrows_y = src0->ne [1 ];
5777-
5778- const int n_head = nrows_x/nrows_y;
5790+ const int n_head = src0->ne [2 ];
57795791 const int n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
57805792
57815793 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
@@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
58205832 CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
58215833 CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
58225834 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
5823- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
5824- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
5825- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (float ), &scale));
5826- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (float ), &max_bias));
5827- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &m0));
5828- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float ), &m1));
5829- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &n_head_log2));
5835+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
5836+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
5837+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
5838+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12));
5839+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne13));
5840+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
5841+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
5842+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb13));
5843+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb1));
5844+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb2));
5845+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb3));
5846+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (float ), &scale));
5847+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (float ), &max_bias));
5848+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (float ), &m0));
5849+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &m1));
5850+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &n_head_log2));
58305851
58315852 size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
58325853 size_t local_work_size[] = {(size_t )nth, 1 , 1 };
0 commit comments