@@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
802802 if (op->src [0 ]->ne [0 ] == 256 ) {
803803 return false ;
804804 }
805- {
806- float logit_softcap;
807-
808- memcpy (&logit_softcap, ((const float *) op->op_params ) + 2 , sizeof (logit_softcap));
809-
810- if (logit_softcap != 0 .0f ) {
811- return false ;
812- }
813- }
814805 return ctx->support_simdgroup_mm ; // TODO: over-restricted for vec-kernels
815806 case GGML_OP_MUL_MAT:
816807 case GGML_OP_MUL_MAT_ID:
@@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute(
26332624
26342625 float scale;
26352626 float max_bias;
2627+ float logit_softcap;
2628+ memcpy (&scale, ((int32_t *) dst->op_params ) + 0 , sizeof (scale));
2629+ memcpy (&max_bias, ((int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
2630+ memcpy (&logit_softcap, ((int32_t *) dst->op_params ) + 2 , sizeof (logit_softcap));
26362631
2637- memcpy (&scale, ((int32_t *) dst->op_params ) + 0 , sizeof (scale));
2638- memcpy (&max_bias, ((int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
2632+ if (logit_softcap != 0 .0f ) {
2633+ scale /= logit_softcap;
2634+ }
26392635
26402636 const uint32_t n_head = src0->ne [2 ];
26412637 const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
@@ -2686,30 +2682,31 @@ static enum ggml_status ggml_metal_graph_compute(
26862682 } else {
26872683 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 3 ];
26882684 }
2689- [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
2690- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 5 ];
2691- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 6 ];
2692- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 7 ];
2693- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 8 ];
2694- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 9 ];
2695- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 10 ];
2696- [encoder setBytes: &ne11 length: sizeof ( int64_t ) atIndex: 11 ];
2697- [encoder setBytes: &ne12 length: sizeof ( int64_t ) atIndex: 12 ];
2698- [encoder setBytes: &ne13 length: sizeof ( int64_t ) atIndex: 13 ];
2699- [encoder setBytes: &nb11 length: sizeof (uint64_t ) atIndex: 14 ];
2700- [encoder setBytes: &nb12 length: sizeof (uint64_t ) atIndex: 15 ];
2701- [encoder setBytes: &nb13 length: sizeof (uint64_t ) atIndex: 16 ];
2702- [encoder setBytes: &nb21 length: sizeof (uint64_t ) atIndex: 17 ];
2703- [encoder setBytes: &nb22 length: sizeof (uint64_t ) atIndex: 18 ];
2704- [encoder setBytes: &nb23 length: sizeof (uint64_t ) atIndex: 19 ];
2705- [encoder setBytes: &nb31 length: sizeof (uint64_t ) atIndex: 20 ];
2706- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 21 ];
2707- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 22 ];
2708- [encoder setBytes: &scale length: sizeof ( float ) atIndex: 23 ];
2709- [encoder setBytes: &max_bias length: sizeof ( float ) atIndex: 24 ];
2710- [encoder setBytes: &m0 length: sizeof (m0) atIndex: 25 ];
2711- [encoder setBytes: &m1 length: sizeof (m1) atIndex: 26 ];
2712- [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 27 ];
2685+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
2686+ [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 5 ];
2687+ [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 6 ];
2688+ [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 7 ];
2689+ [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 8 ];
2690+ [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 9 ];
2691+ [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 10 ];
2692+ [encoder setBytes: &ne11 length: sizeof ( int64_t ) atIndex: 11 ];
2693+ [encoder setBytes: &ne12 length: sizeof ( int64_t ) atIndex: 12 ];
2694+ [encoder setBytes: &ne13 length: sizeof ( int64_t ) atIndex: 13 ];
2695+ [encoder setBytes: &nb11 length: sizeof (uint64_t ) atIndex: 14 ];
2696+ [encoder setBytes: &nb12 length: sizeof (uint64_t ) atIndex: 15 ];
2697+ [encoder setBytes: &nb13 length: sizeof (uint64_t ) atIndex: 16 ];
2698+ [encoder setBytes: &nb21 length: sizeof (uint64_t ) atIndex: 17 ];
2699+ [encoder setBytes: &nb22 length: sizeof (uint64_t ) atIndex: 18 ];
2700+ [encoder setBytes: &nb23 length: sizeof (uint64_t ) atIndex: 19 ];
2701+ [encoder setBytes: &nb31 length: sizeof (uint64_t ) atIndex: 20 ];
2702+ [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 21 ];
2703+ [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 22 ];
2704+ [encoder setBytes: &scale length: sizeof ( float ) atIndex: 23 ];
2705+ [encoder setBytes: &max_bias length: sizeof ( float ) atIndex: 24 ];
2706+ [encoder setBytes: &m0 length: sizeof (m0) atIndex: 25 ];
2707+ [encoder setBytes: &m1 length: sizeof (m1) atIndex: 26 ];
2708+ [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 27 ];
2709+ [encoder setBytes: &logit_softcap length: sizeof (logit_softcap) atIndex: 28 ];
27132710
27142711 if (!use_vec_kernel) {
27152712 // half8x8 kernel
0 commit comments