
Commit d520c2a

Add Adaptive Computation Time (ACT) layer with CPU/CUDA support
1 parent: 80a6e0e

File tree: 9 files changed (+1506, −75 lines)
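For orientation, this is the per-position budgeting rule the new CPU kernels below implement, written out in my own notation (the remainder r is assumed to be initialised to 1 somewhere outside this hunk). At each ponder step, every sequence position (n, s) maps the current state x to a halting probability, spends at most its remaining budget on that step's contribution to the output y, and stops contributing once the cumulative spend C reaches the threshold \tau:

    p_{n,s} = \sigma\Big(b + \sum_{c,d} W_{c,d}\, x_{n,c,s,d}\Big), \qquad
    e_{n,s} = \min\big(p_{n,s}\, r_{n,s},\; \tau - C_{n,s}\big)

    C_{n,s} \mathrel{+}= e_{n,s}, \qquad
    r_{n,s} \mathrel{-}= e_{n,s}, \qquad
    y_{n,c,s,d} \mathrel{+}= e_{n,s}\, x_{n,c,s,d}

After the last step, finalize_act_output adds the leftover mass, y_{n,c,s,d} \mathrel{+}= r_{n,s}\, x_{n,c,s,d}, wherever r_{n,s} > 10^{-6}.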


dlib/cuda/cpu_dlib.cpp

Lines changed: 210 additions & 0 deletions
@@ -3219,6 +3219,216 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void compute_act_halt_probabilities(
            resizable_tensor& halt_probs,
            resizable_tensor& logits,
            const tensor& input_data,
            const tensor& halt_params,
            long batch_size,
            long seq_len,
            long feature_dim
        )
        {
            const float* in_ptr = input_data.host();
            const float* W_halt = halt_params.host();
            const float b_halt = halt_params.host()[feature_dim];
            float* logits_ptr = logits.host();
            float* halt_probs_ptr = halt_probs.host();

            const long d_model = feature_dim / input_data.k();
            const long num_channels = input_data.k();

            // Compute logits in parallel
            #pragma omp parallel for
            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
                const long n = pos / seq_len;
                const long s = pos % seq_len;

                float logit = b_halt;

                // Dot product across all channels and model dimensions
                for (long c = 0; c < num_channels; ++c) {
                    for (long d = 0; d < d_model; ++d) {
                        const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
                        const long weight_idx = c * d_model + d;
                        logit += in_ptr[in_idx] * W_halt[weight_idx];
                    }
                }

                logits_ptr[pos] = logit;

                // Apply sigmoid: p = 1 / (1 + exp(-logit))
                halt_probs_ptr[pos] = 1.0f / (1.0f + std::exp(-logit));
            }
        }

        void update_act_state(
            resizable_tensor& output,
            const tensor& input_data,
            const tensor& halt_probs,
            resizable_tensor& cumulative_halting,
            resizable_tensor& remainders,
            resizable_tensor& n_steps,
            long batch_size,
            long seq_len,
            long d_model,
            long num_channels,
            float halt_threshold,
            long current_step
        )
        {
            const float* in_ptr = input_data.host();
            const float* p_halt = halt_probs.host();
            float* out_ptr = output.host();
            float* cum_halt = cumulative_halting.host();
            float* remain = remainders.host();
            float* steps = n_steps.host();

            #pragma omp parallel for
            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
                if (cum_halt[pos] < halt_threshold) {
                    const long n = pos / seq_len;
                    const long s = pos % seq_len;

                    float p = p_halt[pos];
                    float r = remain[pos];
                    float effective = std::min(p * r, halt_threshold - cum_halt[pos]);

                    // Update ACT state
                    cum_halt[pos] += effective;
                    remain[pos] -= effective;
                    steps[pos] = static_cast<float>(current_step + 1);

                    // Accumulate weighted output
                    for (long c = 0; c < num_channels; ++c) {
                        for (long d = 0; d < d_model; ++d) {
                            const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
                            out_ptr[idx] += effective * in_ptr[idx];
                        }
                    }
                }
            }
        }

        void finalize_act_output(
            resizable_tensor& output,
            const tensor& input_data,
            const tensor& remainders,
            long batch_size,
            long seq_len,
            long d_model,
            long num_channels
        )
        {
            const float* in_ptr = input_data.host();
            const float* remain = remainders.host();
            float* out_ptr = output.host();

            #pragma omp parallel for
            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
                float r = remain[pos];
                if (r > 1e-6f) {
                    const long n = pos / seq_len;
                    const long s = pos % seq_len;

                    for (long c = 0; c < num_channels; ++c) {
                        for (long d = 0; d < d_model; ++d) {
                            const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
                            out_ptr[idx] += r * in_ptr[idx];
                        }
                    }
                }
            }
        }

        void compute_act_gradients(
            tensor& params_grad,
            resizable_tensor& gradient_logits,
            const tensor& input_cache,
            const tensor& halt_probs,
            const tensor& n_steps,
            long batch_size,
            long seq_len,
            long feature_dim,
            float ponder_penalty,
            float max_steps
        )
        {
            const float* p_halt = halt_probs.host();
            const float* steps = n_steps.host();
            const float* in_ptr = input_cache.host();
            float* p_grad = params_grad.host();
            float* g_logits = gradient_logits.host();

            const long total = batch_size * seq_len;
            const long d_model = feature_dim / input_cache.k();
            const long num_channels = input_cache.k();

            // Compute gradient w.r.t. logits
            #pragma omp parallel for
            for (long i = 0; i < total; ++i) {
                float p = p_halt[i];
                float sigmoid_grad = p * (1.0f - p);
                float ponder_grad = ponder_penalty * steps[i] / max_steps;
                g_logits[i] = sigmoid_grad * ponder_grad;
            }

            // Compute gradient w.r.t. weights
            #pragma omp parallel for
            for (long f = 0; f < feature_dim; ++f) {
                const long c = f / d_model;
                const long d = f % d_model;
                float grad_w = 0;

                for (long pos = 0; pos < total; ++pos) {
                    const long n = pos / seq_len;
                    const long s = pos % seq_len;
                    const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
                    grad_w += in_ptr[in_idx] * g_logits[pos];
                }

                p_grad[f] += grad_w / total + 0.0001f * params_grad.host()[f]; // L2 reg
            }

            // Compute gradient w.r.t. bias
            float grad_b = 0;
            for (long i = 0; i < total; ++i) {
                grad_b += g_logits[i];
            }
            p_grad[feature_dim] += grad_b / total;
        }

        void apply_act_depth_scaling(
            tensor& gradients,
            const tensor& n_steps,
            long batch_size,
            long seq_len,
            long d_model,
            long num_channels,
            float max_steps,
            float scale_factor
        )
        {
            const float* steps = n_steps.host();
            float* grad_ptr = gradients.host();

            #pragma omp parallel for
            for (long pos = 0; pos < batch_size * seq_len; ++pos) {
                float scale = 1.0f + scale_factor * (steps[pos] / max_steps);
                const long n = pos / seq_len;
                const long s = pos % seq_len;

                for (long c = 0; c < num_channels; ++c) {
                    for (long d = 0; d < d_model; ++d) {
                        const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
                        grad_ptr[idx] *= scale;
                    }
                }
            }
        }

    // ------------------------------------------------------------------------------------

    }
}
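Read back from compute_act_gradients and apply_act_depth_scaling above, and in my own notation (\lambda = ponder_penalty, N_{n,s} = n_steps, N_{\max} = max_steps, T = batch_size \cdot seq_len), the backward path pushes only the ponder cost into the halting parameters and then rescales the data gradient by pondering depth:

    \frac{\partial L}{\partial z_{n,s}} = p_{n,s}\,(1 - p_{n,s})\,\frac{\lambda\, N_{n,s}}{N_{\max}}, \qquad
    \Delta W_{c,d} = \frac{1}{T}\sum_{n,s} x_{n,c,s,d}\,\frac{\partial L}{\partial z_{n,s}}, \qquad
    \Delta b = \frac{1}{T}\sum_{n,s} \frac{\partial L}{\partial z_{n,s}}

plus the extra 10^{-4} term on each weight-gradient entry that the code labels "L2 reg", while apply_act_depth_scaling multiplies the incoming gradient for position (n, s) by 1 + scale_factor \cdot N_{n,s} / N_{\max}.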

dlib/cuda/cpu_dlib.h

Lines changed: 61 additions & 0 deletions
@@ -536,6 +536,67 @@ namespace dlib
            bool scale
        );

    // -----------------------------------------------------------------------------------

        void compute_act_halt_probabilities(
            resizable_tensor& halt_probs,
            resizable_tensor& logits,
            const tensor& input_data,
            const tensor& halt_params,
            long batch_size,
            long seq_len,
            long feature_dim
        );

        void update_act_state(
            resizable_tensor& output,
            const tensor& input_data,
            const tensor& halt_probs,
            resizable_tensor& cumulative_halting,
            resizable_tensor& remainders,
            resizable_tensor& n_steps,
            long batch_size,
            long seq_len,
            long d_model,
            long num_channels,
            float halt_threshold,
            long current_step
        );

        void finalize_act_output(
            resizable_tensor& output,
            const tensor& input_data,
            const tensor& remainders,
            long batch_size,
            long seq_len,
            long d_model,
            long num_channels
        );

        void compute_act_gradients(
            tensor& params_grad,
            resizable_tensor& gradient_logits,
            const tensor& input_cache,
            const tensor& halt_probs,
            const tensor& n_steps,
            long batch_size,
            long seq_len,
            long feature_dim,
            float ponder_penalty,
            float max_steps
        );

        void apply_act_depth_scaling(
            tensor& gradients,
            const tensor& n_steps,
            long batch_size,
            long seq_len,
            long d_model,
            long num_channels,
            float max_steps,
            float scale_factor
        );

    // -----------------------------------------------------------------------------------

        class pooling
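The layer class that drives these kernels is not part of this section, so the following is a minimal, hypothetical sketch of how the three forward-pass functions could be chained for one ACT block on the CPU. It assumes the declarations above sit in dlib::cpu like the rest of cpu_dlib.h, that the input tensor is laid out as (samples, k = channels, nr = seq_len, nc = d_model) to match the indexing in cpu_dlib.cpp, that halt_params holds feature_dim weights followed by one bias, and that run_step() stands in for the wrapped sub-network applied at each ponder step; none of the names introduced here come from the commit itself.

#include <dlib/dnn.h>

using namespace dlib;

// Hypothetical driver for one ACT block (forward pass only, CPU path).
resizable_tensor act_forward_sketch (
    const tensor& x,            // (batch, channels, seq_len, d_model) block input
    const tensor& halt_params,  // feature_dim weights followed by one bias
    long max_steps,
    float halt_threshold = 0.99f
)
{
    const long batch_size   = x.num_samples();
    const long num_channels = x.k();
    const long seq_len      = x.nr();
    const long d_model      = x.nc();
    const long feature_dim  = num_channels * d_model;

    // Accumulated output and per-position ACT bookkeeping (one scalar per position).
    resizable_tensor output;  output.copy_size(x);  output = 0;
    resizable_tensor halt_probs(batch_size, 1, seq_len, 1);
    resizable_tensor logits(batch_size, 1, seq_len, 1);
    resizable_tensor cumulative(batch_size, 1, seq_len, 1);  cumulative = 0;
    resizable_tensor remainders(batch_size, 1, seq_len, 1);  remainders = 1;  // unspent budget
    resizable_tensor n_steps(batch_size, 1, seq_len, 1);     n_steps = 0;

    // The pondered state starts out as the block input.
    resizable_tensor state;  state.copy_size(x);  memcpy(state, x);

    for (long step = 0; step < max_steps; ++step)
    {
        // state = run_step(state);  // wrapped sub-network, not shown in this commit

        cpu::compute_act_halt_probabilities(halt_probs, logits, state, halt_params,
                                            batch_size, seq_len, feature_dim);

        // Positions whose budget is already spent are skipped inside this call.
        cpu::update_act_state(output, state, halt_probs, cumulative, remainders, n_steps,
                              batch_size, seq_len, d_model, num_channels,
                              halt_threshold, step);
    }

    // Fold any leftover probability mass back into the output.
    cpu::finalize_act_output(output, state, remainders,
                             batch_size, seq_len, d_model, num_channels);
    return output;
}

In this reading, update_act_state records the 1-based index of the last step each position was still active in n_steps, which is what compute_act_gradients and apply_act_depth_scaling consume on the backward pass.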
