@@ -5757,19 +5757,32 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
57575757
57585758 cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
57595759
5760- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
5761- const int ne01 = src0 ? src0->ne [1 ] : 0 ;
5762- const int ne02 = src0 ? src0->ne [2 ] : 0 ;
5763- const int ne03 = src0 ? src0->ne [3 ] : 0 ;
5760+ const int ne00 = src0->ne [0 ];
5761+ const int ne01 = src0->ne [1 ];
5762+ const int ne02 = src0->ne [2 ];
5763+ const int ne03 = src0->ne [3 ];
5764+
5765+ const cl_long nb01 = src0->nb [1 ];
5766+ const cl_long nb02 = src0->nb [2 ];
5767+ const cl_long nb03 = src0->nb [3 ];
5768+
5769+ const int ne11 = src1 ? src1->ne [1 ] : 0 ;
5770+ const int ne12 = src1 ? src1->ne [2 ] : 0 ;
5771+ const int ne13 = src1 ? src1->ne [3 ] : 0 ;
5772+
5773+ const cl_long nb11 = src1 ? src1->nb [1 ] : 0 ;
5774+ const cl_long nb12 = src1 ? src1->nb [2 ] : 0 ;
5775+ const cl_long nb13 = src1 ? src1->nb [3 ] : 0 ;
5776+
5777+ const cl_long nb1 = dst->nb [1 ];
5778+ const cl_long nb2 = dst->nb [2 ];
5779+ const cl_long nb3 = dst->nb [3 ];
57645780
57655781 float scale, max_bias;
57665782 memcpy (&scale, dst->op_params + 0 , sizeof (float ));
57675783 memcpy (&max_bias, dst->op_params + 1 , sizeof (float ));
57685784
5769- const int nrows_x = ggml_nrows (src0);
5770- const int nrows_y = src0->ne [1 ];
5771-
5772- const int n_head = nrows_x/nrows_y;
5785+ const int n_head = src0->ne [2 ];
57735786 const int n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
57745787
57755788 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
@@ -5816,11 +5829,23 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
58165829 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
58175830 CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
58185831 CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
5819- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (float ), &scale));
5820- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (float ), &max_bias));
5821- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &m0));
5822- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float ), &m1));
5823- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &n_head_log2));
5832+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb01));
5833+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb02));
5834+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb03));
5835+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne11));
5836+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne12));
5837+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne13));
5838+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb11));
5839+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb12));
5840+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb13));
5841+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb1));
5842+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb2));
5843+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb3));
5844+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &scale));
5845+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (float ), &max_bias));
5846+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (float ), &m0));
5847+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (float ), &m1));
5848+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &n_head_log2));
58245849
58255850 size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
58265851 size_t local_work_size[] = {(size_t )nth, 1 , 1 };
0 commit comments