3 files changed: +7 -6 lines changed
@@ -454,6 +454,8 @@ typedef struct {
454454 int64_t ne00 ;
455455 int64_t ne01 ;
456456 int64_t ne02 ;
457+ uint64_t nb11 ;
458+ uint64_t nb12 ;
457459 float scale ;
458460 float max_bias ;
459461 float m0 ;
@@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node(
25622562 memcpy (&scale, ((const int32_t *) dst->op_params ) + 0 , sizeof (scale));
25632563 memcpy (&max_bias, ((const int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
25642564
2565- const int64_t nrows_x = ggml_nrows (src0);
2566- const int64_t nrows_y = src0->ne [1 ];
2567-
2568- const uint32_t n_head = nrows_x/nrows_y;
2565+ const uint32_t n_head = src0->ne [2 ];
25692566 const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
25702567
25712568 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
@@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node(
26252622 /* .ne00 =*/ ne00,
26262623 /* .ne01 =*/ ne01,
26272624 /* .ne02 =*/ ne02,
2625+ /* .nb11 =*/ nb11,
2626+ /* .nb12 =*/ nb12,
26282627 /* .scale =*/ scale,
26292628 /* .max_bias =*/ max_bias,
26302629 /* .m0 =*/ m0,
@@ -1263,7 +1263,7 @@ kernel void kernel_soft_max(
12631263 const int64_t i01 = (tgpig - i03*args.ne02 *args.ne01 - i02*args.ne01 );
12641264
12651265 device const float * psrc0 = (device const float *) src0 + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 );
1266- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01 *args.ne00 : nullptr ;
1266+ device const T * pmask = src1 != src0 ? (device const T *) ( src1 + i01*args. nb11 + i03 *args.nb12 ) : nullptr ;
12671267 device float * pdst = (device float *) dst + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 );
12681268
12691269 float slope = 1 .0f ;
@@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4(
13591359 const int64_t i01 = (tgpig - i03*args.ne02 *args.ne01 - i02*args.ne01 );
13601360
13611361 device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 )/4 ;
1362- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01 *args.ne00 / 4 : nullptr ;
1362+ device const T * pmask = src1 != src0 ? (device const T *) ( src1 + i01*args. nb11 + i03 *args.nb12 ) : nullptr ;
13631363 device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 )/4 ;
13641364
13651365 float slope = 1 .0f ;
You can’t perform that action at this time.
0 commit comments