@@ -3219,6 +3219,216 @@ namespace dlib
        // ------------------------------------------------------------------------------------
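+        // The helpers below implement host-side (CPU) kernels for adaptive
+        // computation time (ACT, in the style of Graves 2016): each sequence
+        // position learns how many ponder steps to take before halting.
+
+        // Computes per-position halting probabilities.  For each (sample,
+        // position) pair the halting logit is a linear readout of the input
+        // features, logit = dot(W_halt, x) + b_halt, squashed through a sigmoid.
+        // halt_params stores the feature_dim weights followed by the bias as its
+        // last element.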
+        void compute_act_halt_probabilities (
+            resizable_tensor& halt_probs,
+            resizable_tensor& logits,
+            const tensor& input_data,
+            const tensor& halt_params,
+            long batch_size,
+            long seq_len,
+            long feature_dim
+        )
+        {
+            const float* in_ptr = input_data.host();
+            const float* W_halt = halt_params.host();
+            const float b_halt = W_halt[feature_dim];
+            float* logits_ptr = logits.host();
+            float* halt_probs_ptr = halt_probs.host();
+
+            const long d_model = feature_dim / input_data.k();
+            const long num_channels = input_data.k();
+
+            // Compute logits in parallel
+            #pragma omp parallel for
+            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
+                const long n = pos / seq_len;
+                const long s = pos % seq_len;
+
+                float logit = b_halt;
+
+                // Dot product across all channels and model dimensions
+                for (long c = 0; c < num_channels; ++c) {
+                    for (long d = 0; d < d_model; ++d) {
+                        const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
+                        const long weight_idx = c * d_model + d;
+                        logit += in_ptr[in_idx] * W_halt[weight_idx];
+                    }
+                }
+
+                logits_ptr[pos] = logit;
+
+                // Apply sigmoid: p = 1 / (1 + exp(-logit))
+                halt_probs_ptr[pos] = 1.0f / (1.0f + std::exp(-logit));
+            }
+        }
+
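+        // Accumulates one ACT step.  Positions whose cumulative halting
+        // probability is still below halt_threshold receive an effective weight
+        // min(p * R, halt_threshold - cumulative), which is added to the
+        // cumulative total, subtracted from the remainder R, and used to blend
+        // the current input into the weighted output.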
+        void update_act_state (
+            resizable_tensor& output,
+            const tensor& input_data,
+            const tensor& halt_probs,
+            resizable_tensor& cumulative_halting,
+            resizable_tensor& remainders,
+            resizable_tensor& n_steps,
+            long batch_size,
+            long seq_len,
+            long d_model,
+            long num_channels,
+            float halt_threshold,
+            long current_step
+        )
+        {
+            const float* in_ptr = input_data.host();
+            const float* p_halt = halt_probs.host();
+            float* out_ptr = output.host();
+            float* cum_halt = cumulative_halting.host();
+            float* remain = remainders.host();
+            float* steps = n_steps.host();
+
+            #pragma omp parallel for
+            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
+                if (cum_halt[pos] < halt_threshold) {
+                    const long n = pos / seq_len;
+                    const long s = pos % seq_len;
+
+                    const float p = p_halt[pos];
+                    const float r = remain[pos];
+                    const float effective = std::min(p * r, halt_threshold - cum_halt[pos]);
+
+                    // Update ACT state
+                    cum_halt[pos] += effective;
+                    remain[pos] -= effective;
+                    steps[pos] = static_cast<float>(current_step + 1);
+
+                    // Accumulate weighted output
+                    for (long c = 0; c < num_channels; ++c) {
+                        for (long d = 0; d < d_model; ++d) {
+                            const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
+                            out_ptr[idx] += effective * in_ptr[idx];
+                        }
+                    }
+                }
+            }
+        }
+
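+        // After the final ACT step, positions that never fully halted still carry
+        // a positive remainder; blend the last input into the output with that
+        // remainder so each position's blending weights sum to 1.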
+        void finalize_act_output (
+            resizable_tensor& output,
+            const tensor& input_data,
+            const tensor& remainders,
+            long batch_size,
+            long seq_len,
+            long d_model,
+            long num_channels
+        )
+        {
+            const float* in_ptr = input_data.host();
+            const float* remain = remainders.host();
+            float* out_ptr = output.host();
+
+            #pragma omp parallel for
+            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
+                const float r = remain[pos];
+                if (r > 1e-6f) {
+                    const long n = pos / seq_len;
+                    const long s = pos % seq_len;
+
+                    for (long c = 0; c < num_channels; ++c) {
+                        for (long d = 0; d < d_model; ++d) {
+                            const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
+                            out_ptr[idx] += r * in_ptr[idx];
+                        }
+                    }
+                }
+            }
+        }
+
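+        // Backward pass for the halting unit.  The ponder cost (ponder_penalty
+        // scaled by how many steps each position took) is pushed through the
+        // sigmoid to give a per-logit gradient, which is then accumulated into
+        // the weight and bias gradients stored in params_grad.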
+        void compute_act_gradients (
+            tensor& params_grad,
+            resizable_tensor& gradient_logits,
+            const tensor& input_cache,
+            const tensor& halt_probs,
+            const tensor& n_steps,
+            long batch_size,
+            long seq_len,
+            long feature_dim,
+            float ponder_penalty,
+            float max_steps
+        )
+        {
+            const float* p_halt = halt_probs.host();
+            const float* steps = n_steps.host();
+            const float* in_ptr = input_cache.host();
+            float* p_grad = params_grad.host();
+            float* g_logits = gradient_logits.host();
+
+            const long total = batch_size * seq_len;
+            const long d_model = feature_dim / input_cache.k();
+            const long num_channels = input_cache.k();
+
+            // Gradient w.r.t. logits: the ponder cost grows with the number of
+            // steps taken, and d(sigmoid)/d(logit) = p * (1 - p).
+            #pragma omp parallel for
+            for (long i = 0; i < total; ++i) {
+                const float p = p_halt[i];
+                const float sigmoid_grad = p * (1.0f - p);
+                const float ponder_grad = ponder_penalty * steps[i] / max_steps;
+                g_logits[i] = sigmoid_grad * ponder_grad;
+            }
+
+            // Gradient w.r.t. weights, averaged over all positions
+            #pragma omp parallel for
+            for (long f = 0; f < feature_dim; ++f) {
+                const long c = f / d_model;
+                const long d = f % d_model;
+                float grad_w = 0;
+
+                for (long pos = 0; pos < total; ++pos) {
+                    const long n = pos / seq_len;
+                    const long s = pos % seq_len;
+                    const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
+                    grad_w += in_ptr[in_idx] * g_logits[pos];
+                }
+
+                // Note: this decay term acts on the accumulated gradient, not the
+                // weights (which are not passed to this function), so it is only
+                // an approximation of L2 regularization.
+                p_grad[f] += grad_w / total + 0.0001f * p_grad[f];
+            }
+
+            // Gradient w.r.t. bias
+            float grad_b = 0;
+            for (long i = 0; i < total; ++i) {
+                grad_b += g_logits[i];
+            }
+            p_grad[feature_dim] += grad_b / total;
+        }
+
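+        // Rescales upstream gradients by pondering depth: positions that took
+        // more ACT steps receive a proportionally larger gradient,
+        // scale = 1 + scale_factor * (steps / max_steps).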
+        void apply_act_depth_scaling (
+            tensor& gradients,
+            const tensor& n_steps,
+            long batch_size,
+            long seq_len,
+            long d_model,
+            long num_channels,
+            float max_steps,
+            float scale_factor
+        )
+        {
+            const float* steps = n_steps.host();
+            float* grad_ptr = gradients.host();
+
+            #pragma omp parallel for
+            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
+                const float scale = 1.0f + scale_factor * (steps[pos] / max_steps);
+                const long n = pos / seq_len;
+                const long s = pos % seq_len;
+
+                for (long c = 0; c < num_channels; ++c) {
+                    for (long d = 0; d < d_model; ++d) {
+                        const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
+                        grad_ptr[idx] *= scale;
+                    }
+                }
+            }
+        }
+
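+        // A minimal sketch (the driver loop below is an assumption, not code from
+        // this file) of how these helpers might compose into one ACT forward
+        // pass.  `run_block` is a hypothetical stand-in for the transformation
+        // applied between ponder steps; the per-position state tensors are sized
+        // batch_size*seq_len, with output, cumulative_halting, and n_steps
+        // zero-initialized and remainders initialized to 1:
+        //
+        //     for (long step = 0; step < max_steps; ++step)
+        //     {
+        //         compute_act_halt_probabilities(halt_probs, logits, input,
+        //             halt_params, batch_size, seq_len, feature_dim);
+        //         update_act_state(output, input, halt_probs, cumulative_halting,
+        //             remainders, n_steps, batch_size, seq_len, d_model,
+        //             num_channels, halt_threshold, step);
+        //         run_block(input);  // hypothetical: produce the next ponder state
+        //     }
+        //     finalize_act_output(output, input, remainders, batch_size, seq_len,
+        //         d_model, num_channels);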
+        // ------------------------------------------------------------------------------------
+
    }
}