@@ -3238,15 +3238,13 @@ namespace dlib
         const long d_model = feature_dim / input_data.k();
         const long num_channels = input_data.k();

-        // Compute logits in parallel
         #pragma omp parallel for
         for (long pos = 0; pos < batch_size * seq_len; ++pos) {
             const long n = pos / seq_len;
             const long s = pos % seq_len;

             float logit = b_halt;

-            // Dot product across all channels and model dimensions
             for (long c = 0; c < num_channels; ++c) {
                 for (long d = 0; d < d_model; ++d) {
                     const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
@@ -3257,7 +3255,6 @@ namespace dlib

             logits_ptr[pos] = logit;

-            // Apply sigmoid: p = 1 / (1 + exp(-logit))
             halt_probs_ptr[pos] = 1.0f / (1.0f + std::exp(-logit));
         }
     }
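The hunk above computes one halting logit per (sample, position) pair and squashes it with a sigmoid. A minimal standalone sketch of the same computation on a flat feature vector, using the hypothetical names halting_probability and w_halt (not part of dlib's API):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Halting probability p = sigmoid(w_halt . x + b_halt), where x holds the
    // feature_dim values of a single position.
    float halting_probability(const std::vector<float>& x,
                              const std::vector<float>& w_halt,
                              float b_halt)
    {
        float logit = b_halt;
        for (std::size_t i = 0; i < x.size(); ++i)
            logit += w_halt[i] * x[i];            // dot product over all features
        return 1.0f / (1.0f + std::exp(-logit));  // sigmoid maps the logit into (0, 1)
    }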
@@ -3294,12 +3291,10 @@ namespace dlib
             float r = remain[pos];
             float effective = std::min(p * r, halt_threshold - cum_halt[pos]);

-            // Update ACT state
             cum_halt[pos] += effective;
             remain[pos] -= effective;
             steps[pos] = static_cast<float>(current_step + 1);

-            // Accumulate weighted output
             for (long c = 0; c < num_channels; ++c) {
                 for (long d = 0; d < d_model; ++d) {
                     const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
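This hunk caps the halting mass added at the current step so that the running total never exceeds halt_threshold, then subtracts that mass from the remaining budget and uses it to weight the step's output. A hedged per-position sketch of the same bookkeeping (the act_state struct and act_update function are illustrative, not dlib's API):

    #include <algorithm>

    struct act_state { float cum_halt = 0.0f; float remain = 1.0f; float steps = 0.0f; };

    // Returns the effective halting probability used to weight this step's output.
    float act_update(act_state& st, float p, float halt_threshold, long current_step)
    {
        const float effective = std::min(p * st.remain, halt_threshold - st.cum_halt);
        st.cum_halt += effective;   // cumulative halting probability so far
        st.remain   -= effective;   // probability mass still unspent
        st.steps     = static_cast<float>(current_step + 1);
        return effective;
    }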
@@ -3341,63 +3336,6 @@ namespace dlib
         }
     }

-    void compute_act_gradients (
-        tensor& params_grad,
-        resizable_tensor& gradient_logits,
-        const tensor& input_cache,
-        const tensor& halt_probs,
-        const tensor& n_steps,
-        long batch_size,
-        long seq_len,
-        long feature_dim,
-        float ponder_penalty,
-        float max_steps
-    )
-    {
-        const float* p_halt = halt_probs.host();
-        const float* steps = n_steps.host();
-        const float* in_ptr = input_cache.host();
-        float* p_grad = params_grad.host();
-        float* g_logits = gradient_logits.host();
-
-        const long total = batch_size * seq_len;
-        const long d_model = feature_dim / input_cache.k();
-        const long num_channels = input_cache.k();
-
-        // Compute gradient w.r.t. logits
-        #pragma omp parallel for
-        for (long i = 0; i < total; ++i) {
-            float p = p_halt[i];
-            float sigmoid_grad = p * (1.0f - p);
-            float ponder_grad = ponder_penalty * steps[i] / max_steps;
-            g_logits[i] = sigmoid_grad * ponder_grad;
-        }
-
-        // Compute gradient w.r.t. weights
-        #pragma omp parallel for
-        for (long f = 0; f < feature_dim; ++f) {
-            const long c = f / d_model;
-            const long d = f % d_model;
-            float grad_w = 0;
-
-            for (long pos = 0; pos < total; ++pos) {
-                const long n = pos / seq_len;
-                const long s = pos % seq_len;
-                const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
-                grad_w += in_ptr[in_idx] * g_logits[pos];
-            }
-
-            p_grad[f] += grad_w / total + 0.0001f * params_grad.host()[f]; // L2 reg
-        }
-
-        // Compute gradient w.r.t. bias
-        float grad_b = 0;
-        for (long i = 0; i < total; ++i) {
-            grad_b += g_logits[i];
-        }
-        p_grad[feature_dim] += grad_b / total;
-    }
-
     void apply_act_depth_scaling (
         tensor& gradients,
         const tensor& n_steps,
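The compute_act_gradients block removed above backpropagated the ponder cost through the halting sigmoid: per position it formed the logit gradient p * (1 - p) * ponder_penalty * steps / max_steps and then reduced it into the weight and bias gradients (plus a small L2 term on the weights). A hedged one-liner for that per-position term, under the same naming:

    // Gradient of the ponder cost w.r.t. one halting logit (illustrative helper).
    float ponder_logit_grad(float p, float steps, float max_steps, float ponder_penalty)
    {
        return p * (1.0f - p) * ponder_penalty * steps / max_steps;
    }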
@@ -3413,13 +3351,16 @@ namespace dlib
         float* grad_ptr = gradients.host();

         #pragma omp parallel for
-        for (long pos = 0; pos < batch_size * seq_len; ++pos) {
-            float scale = 1.0f + scale_factor * (steps[pos] / max_steps);
+        for (long pos = 0; pos < batch_size * seq_len; ++pos)
+        {
+            const float scale = 1.0f + scale_factor * (steps[pos] / max_steps);
             const long n = pos / seq_len;
             const long s = pos % seq_len;

-            for (long c = 0; c < num_channels; ++c) {
-                for (long d = 0; d < d_model; ++d) {
+            for (long c = 0; c < num_channels; ++c)
+            {
+                for (long d = 0; d < d_model; ++d)
+                {
                     const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
                     grad_ptr[idx] *= scale;
                 }
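For reference, the scale applied in this hunk grows linearly with the number of pondering steps a position used: scale = 1 + scale_factor * (steps / max_steps). With, say, scale_factor = 0.5 and max_steps = 4, a position that halted after one step has its gradients multiplied by 1.125, while one that ran all four steps is multiplied by 1.5 (values chosen only to illustrate the formula):

    // Illustrative helper, not part of the patch.
    float depth_scale(float steps_taken, float max_steps, float scale_factor)
    {
        return 1.0f + scale_factor * (steps_taken / max_steps);
    }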