@@ -8831,7 +8831,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
88318831    GGML_ASSERT (ggml_are_same_shape (src0, src0_grad));
88328832    GGML_ASSERT (ggml_are_same_shape (src0, src0_grad_m));
88338833    GGML_ASSERT (ggml_are_same_shape (src0, src0_grad_v));
8834-     GGML_ASSERT (ggml_nelements (adamw_params) == 7 );
8834+     GGML_ASSERT (ggml_nelements (adamw_params) == 8 );
88358835
88368836    const  int  ith = params->ith ;
88378837    const  int  nth = params->nth ;
@@ -8849,14 +8849,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
88498849    const  int  ir1 = MIN (ir0 + dr, nr);
88508850
88518851    const  float  * adamw_params_ptr = ggml_get_data_f32 (adamw_params);
8852+ 
88528853    const  float  alpha  = adamw_params_ptr[0 ];
88538854    const  float  beta1  = adamw_params_ptr[1 ];
88548855    const  float  beta2  = adamw_params_ptr[2 ];
88558856    const  float  eps    = adamw_params_ptr[3 ];
8856-     const  float  wd     = adamw_params_ptr[4 ];
88578857    const  float  beta1h = adamw_params_ptr[5 ];
88588858    const  float  beta2h = adamw_params_ptr[6 ];
8859- 
8859+      const   float  keep   = adamw_params_ptr[ 7 ]; 
88608860    for  (int  ir = ir0; ir < ir1; ++ir) {
88618861        const  int64_t  i03 = ir/(ne02*ne01);
88628862        const  int64_t  i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8879,7 +8879,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
88798879            //  The weight decay is applied independently of the Adam momenta m and v.
88808880            //  This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
88818881            //  See: https://arxiv.org/pdf/1711.05101v3.pdf
8882-             w[i00] = w[i00]*( 1 . 0f  - alpha*wd)  - alpha*mh/ vh;
8882+             w[i00] = w[i00] * keep  - alpha * mh /  vh;
88838883        }
88848884    }
88858885}
@@ -8901,3 +8901,63 @@ void ggml_compute_forward_opt_step_adamw(
89018901            }
89028902    }
89038903}
8904+ 
8905+ static  void  ggml_compute_forward_opt_step_sgd_f32 (const  ggml_compute_params * params, ggml_tensor * dst) {
8906+     const  ggml_tensor * src0         = dst->src [0 ];
8907+     const  ggml_tensor * src0_grad    = dst->src [1 ];
8908+     const  ggml_tensor * adamw_params = dst->src [2 ];
8909+ 
8910+     GGML_ASSERT (ggml_are_same_shape (src0, src0_grad));
8911+     GGML_ASSERT (ggml_nelements (adamw_params) == 8 );
8912+ 
8913+     const  int  ith = params->ith ;
8914+     const  int  nth = params->nth ;
8915+ 
8916+     const  int  nr = ggml_nrows (src0);
8917+ 
8918+     GGML_TENSOR_UNARY_OP_LOCALS
8919+     GGML_ASSERT (nb00 == sizeof (float ));
8920+ 
8921+     //  rows per thread
8922+     const  int  dr = (nr + nth - 1 ) / nth;
8923+ 
8924+     //  row range for this thread
8925+     const  int  ir0 = dr * ith;
8926+     const  int  ir1 = MIN (ir0 + dr, nr);
8927+ 
8928+     //  using adamw param subset we care about - alpha, wd - could have a separate struct
8929+     const  float  * adamw_params_ptr = ggml_get_data_f32 (adamw_params);
8930+     const  float    alpha            = adamw_params_ptr[0 ];
8931+     const  float    keep             = adamw_params_ptr[7 ];
8932+ 
8933+     for  (int  ir = ir0; ir < ir1; ++ir) {
8934+         const  int64_t  i03 = ir / (ne02 * ne01);
8935+         const  int64_t  i02 = (ir - i03 * ne02 * ne01) / ne01;
8936+         const  int64_t  i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
8937+ 
8938+         const  size_t  offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
8939+ 
8940+         float  *       w = (float  *) ((char  *) src0->data  + offset);                   //  weight
8941+         const  float  * g = (const  float  *) ((const  char  *) src0_grad->data  + offset);  //  grad
8942+ 
8943+         for  (int  i00 = 0 ; i00 < ne00; ++i00) {
8944+             w[i00] = w[i00] * keep - alpha * g[i00];
8945+         }
8946+     }
8947+ }
8948+ 
8949+ void  ggml_compute_forward_opt_step_sgd (const  ggml_compute_params * params, ggml_tensor * dst) {
8950+     const  ggml_tensor * src0 = dst->src [0 ];
8951+ 
8952+     switch  (src0->type ) {
8953+         case  GGML_TYPE_F32:
8954+             {
8955+                 ggml_compute_forward_opt_step_sgd_f32 (params, dst);
8956+             }
8957+             break ;
8958+         default :
8959+             {
8960+                 GGML_ABORT (" fatal error - sgd is F32 only" 
8961+             }
8962+     }
8963+ }
0 commit comments