@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
26912691 }
26922692}
26932693
2694+ // ggml_compute_forward_gelu_erf
2695+
2696+ static void ggml_compute_forward_gelu_erf_f32 (
2697+ const ggml_compute_params * params,
2698+ ggml_tensor * dst) {
2699+
2700+ const ggml_tensor * src0 = dst->src [0 ];
2701+
2702+ assert (ggml_is_contiguous_1 (src0));
2703+ assert (ggml_is_contiguous_1 (dst));
2704+ assert (ggml_are_same_shape (src0, dst));
2705+
2706+ const int ith = params->ith ;
2707+ const int nth = params->nth ;
2708+
2709+ const int nc = src0->ne [0 ];
2710+ const int nr = ggml_nrows (src0);
2711+
2712+ // rows per thread
2713+ const int dr = (nr + nth - 1 )/nth;
2714+
2715+ // row range for this thread
2716+ const int ir0 = dr*ith;
2717+ const int ir1 = MIN (ir0 + dr, nr);
2718+
2719+ for (int i1 = ir0; i1 < ir1; i1++) {
2720+ ggml_vec_gelu_erf_f32 (nc,
2721+ (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
2722+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])));
2723+
2724+ #ifndef NDEBUG
2725+ for (int k = 0 ; k < nc; k++) {
2726+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
2727+ GGML_UNUSED (x);
2728+ assert (!isnan (x));
2729+ assert (!isinf (x));
2730+ }
2731+ #endif
2732+ }
2733+ }
2734+
2735+ static void ggml_compute_forward_gelu_erf_f16 (
2736+ const ggml_compute_params * params,
2737+ ggml_tensor * dst) {
2738+
2739+ const ggml_tensor * src0 = dst->src [0 ];
2740+
2741+ assert (ggml_is_contiguous_1 (src0));
2742+ assert (ggml_is_contiguous_1 (dst));
2743+ assert (ggml_are_same_shape (src0, dst));
2744+
2745+ const int ith = params->ith ;
2746+ const int nth = params->nth ;
2747+
2748+ const int nc = src0->ne [0 ];
2749+ const int nr = ggml_nrows (src0);
2750+
2751+ // rows per thread
2752+ const int dr = (nr + nth - 1 )/nth;
2753+
2754+ // row range for this thread
2755+ const int ir0 = dr*ith;
2756+ const int ir1 = MIN (ir0 + dr, nr);
2757+
2758+ for (int i1 = ir0; i1 < ir1; i1++) {
2759+ ggml_vec_gelu_erf_f16 (nc,
2760+ (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
2761+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])));
2762+
2763+ #ifndef NDEBUG
2764+ for (int k = 0 ; k < nc; k++) {
2765+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
2766+ const float v = GGML_FP16_TO_FP32 (x);
2767+ GGML_UNUSED (v);
2768+ assert (!isnan (v));
2769+ assert (!isinf (v));
2770+ }
2771+ #endif
2772+ }
2773+ }
2774+
2775+ static void ggml_compute_forward_gelu_erf (
2776+ const ggml_compute_params * params,
2777+ ggml_tensor * dst) {
2778+
2779+ const ggml_tensor * src0 = dst->src [0 ];
2780+
2781+ switch (src0->type ) {
2782+ case GGML_TYPE_F32:
2783+ {
2784+ ggml_compute_forward_gelu_erf_f32 (params, dst);
2785+ } break ;
2786+ case GGML_TYPE_F16:
2787+ {
2788+ ggml_compute_forward_gelu_erf_f16 (params, dst);
2789+ } break ;
2790+ default :
2791+ {
2792+ GGML_ABORT (" fatal error" );
2793+ }
2794+ }
2795+ }
2796+
26942797// ggml_compute_forward_gelu_quick
26952798
26962799static void ggml_compute_forward_gelu_quick_f32 (
@@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary(
77497852 {
77507853 ggml_compute_forward_gelu (params, dst);
77517854 } break ;
7855+ case GGML_UNARY_OP_GELU_ERF:
7856+ {
7857+ ggml_compute_forward_gelu_erf (params, dst);
7858+ } break ;
77527859 case GGML_UNARY_OP_GELU_QUICK:
77537860 {
77547861 ggml_compute_forward_gelu_quick (params, dst);
0 commit comments