@@ -2406,12 +2406,19 @@ namespace dlib { namespace tt
         long feature_dim
     );
     /*!
-        requires
-            - halt_params.size() == feature_dim + 1 (weights + bias)
-            - input_data dimensions match batch_size x seq_len x ...
-        ensures
-            - halt_probs contains sigmoid(W_halt^T * input + b_halt) for each position
-            - logits contains the pre-sigmoid values
+        requires
+            - halt_params.size() == feature_dim + 1 (weights + bias)
+            - input_data.num_samples() == batch_size
+            - input_data.k() == num_channels where feature_dim = num_channels * d_model
+            - input_data.nr() == seq_len
+            - input_data.nc() == d_model
+        ensures
+            - Computes halting probabilities for Adaptive Computation Time:
+                - halt_probs contains sigmoid(W_halt^T * input + b_halt) for each position
+                - logits contains the pre-sigmoid values
+            - batch_size: number of samples in the batch
+            - seq_len: sequence length (number of positions to process)
+            - feature_dim: total feature dimension (num_channels × d_model)
     !*/
 
     void update_act_state (
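
The spec above amounts to a per-position linear projection followed by a sigmoid. As a reading aid, here is a minimal sketch in plain C++ rather than the dlib tensor API; the function name, the flat row-major [batch_size x seq_len x feature_dim] layout, and the weights-then-bias ordering of halt_params are assumptions for illustration only.

#include <cmath>
#include <vector>

// Illustrative sketch only: per-position halting probabilities,
// assuming a flat row-major layout and halt_params = feature_dim weights
// followed by a single bias term.
void compute_halting_probs_sketch (
    const std::vector<float>& input_data,   // batch_size*seq_len*feature_dim values
    const std::vector<float>& halt_params,  // feature_dim weights + 1 bias
    long batch_size,
    long seq_len,
    long feature_dim,
    std::vector<float>& logits,             // one pre-sigmoid value per position
    std::vector<float>& halt_probs          // one probability per position
)
{
    const long num_positions = batch_size * seq_len;
    logits.assign(num_positions, 0.0f);
    halt_probs.assign(num_positions, 0.0f);
    const float bias = halt_params[feature_dim];
    for (long pos = 0; pos < num_positions; ++pos)
    {
        float z = bias;                                              // b_halt
        for (long f = 0; f < feature_dim; ++f)
            z += halt_params[f] * input_data[pos*feature_dim + f];   // W_halt^T * input
        logits[pos] = z;                                             // pre-sigmoid value
        halt_probs[pos] = 1.0f / (1.0f + std::exp(-z));              // sigmoid(z)
    }
}
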
@@ -2432,10 +2439,24 @@ namespace dlib { namespace tt
         requires
             - 0 < halt_threshold <= 1.0
             - current_step >= 0
-        ensures
-            - Updates ACT state for all positions
-            - Accumulates weighted outputs: output += α_t^n · input_data
-            - Updates cumulative_halting, remainders, and n_steps
+            - input_data.num_samples() == batch_size
+            - input_data.k() == num_channels
+            - input_data.nr() == seq_len
+            - input_data.nc() == d_model
+            - output has the same dimensions as input_data
+            - halt_probs.size() == batch_size * seq_len
+            - cumulative_halting.size() == remainders.size() == n_steps.size() == batch_size * seq_len
+        ensures
+            - Core ACT update step that accumulates weighted outputs:
+                - Updates ACT state for all positions
+                - Accumulates weighted outputs: output += α_t^n * input_data
+                - Updates cumulative_halting, remainders, and n_steps
+            - batch_size: number of samples in the batch
+            - seq_len: sequence length (number of positions to process)
+            - d_model: model dimension per channel
+            - num_channels: number of feature channels
+            - halt_threshold: halting threshold (typically 0.99)
+            - current_step: current computation step index (0-based)
     !*/
 
     void finalize_act_output (
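
The update step is the core ACT bookkeeping. The sketch below shows one plausible reading that follows the standard Adaptive Computation Time formulation, with flat per-position buffers assumed; in particular, how the halting step itself is weighted is an assumption here, since the spec defers the remainder contribution to finalize_act_output.

#include <vector>

// Illustrative sketch only: one ACT update step under assumed flat layouts
// (state buffers of length batch_size*seq_len, d = num_channels*d_model
// feature values per position).  The actual kernel may differ in details.
void update_act_state_sketch (
    const std::vector<float>& input_data,     // current step's features
    std::vector<float>& output,               // accumulated output, same size
    const std::vector<float>& halt_probs,     // halting probability per position
    std::vector<float>& cumulative_halting,
    std::vector<float>& remainders,
    std::vector<float>& n_steps,
    long num_positions,
    long d,                                   // feature values per position
    float halt_threshold,
    long current_step
)
{
    for (long pos = 0; pos < num_positions; ++pos)
    {
        if (cumulative_halting[pos] >= halt_threshold)
            continue;                         // this position has already halted

        const float p = halt_probs[pos];
        n_steps[pos] = static_cast<float>(current_step + 1);

        if (cumulative_halting[pos] + p >= halt_threshold)
        {
            // Halting step: record the remainder ρ_t = 1 - Σ previous probabilities.
            // Its contribution is added later by the finalize step.
            remainders[pos] = 1.0f - cumulative_halting[pos];
            cumulative_halting[pos] += p;     // now >= halt_threshold, so it halts
        }
        else
        {
            // Intermediate step: accumulate with weight α_t^n = p.
            cumulative_halting[pos] += p;
            for (long i = 0; i < d; ++i)
                output[pos*d + i] += p * input_data[pos*d + i];
        }
    }
}
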
@@ -2448,9 +2469,21 @@ namespace dlib { namespace tt
         long num_channels
     );
     /*!
+        requires
+            - input_data.num_samples() == batch_size
+            - input_data.k() == num_channels
+            - input_data.nr() == seq_len
+            - input_data.nc() == d_model
+            - output has the same dimensions as input_data
+            - remainders.size() == batch_size * seq_len
         ensures
-            - Adds final remainder contributions: output += ρ_t · input_data
-            - Applied only to positions with significant remainder (> 1e-6)
+            - Finalizes ACT output by adding remainder contributions:
+                - Adds final remainder contributions: output += ρ_t * input_data
+                - Applied only to positions with significant remainder (> 1e-6)
+            - batch_size: number of samples in the batch
+            - seq_len: sequence length (number of positions to process)
+            - d_model: model dimension per channel
+            - num_channels: number of feature channels
     !*/
 
     void apply_act_depth_scaling (
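
Finalization just adds the remainder-weighted contribution of the final step's features, skipping positions whose remainder is negligible. A sketch under the same flat-layout assumptions as above:

#include <vector>

// Illustrative sketch only: add the remainder-weighted contribution of the
// final step's features to the accumulated ACT output.
void finalize_act_output_sketch (
    const std::vector<float>& input_data,   // final step's features
    const std::vector<float>& remainders,   // ρ_t per position
    std::vector<float>& output,             // accumulated ACT output, updated in place
    long num_positions,
    long d                                  // feature values per position
)
{
    for (long pos = 0; pos < num_positions; ++pos)
    {
        const float rho = remainders[pos];
        if (rho <= 1e-6f)
            continue;                       // no significant remainder for this position
        for (long i = 0; i < d; ++i)
            output[pos*d + i] += rho * input_data[pos*d + i];   // output += ρ_t * input
    }
}
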
@@ -2466,9 +2499,21 @@ namespace dlib { namespace tt
     /*!
         requires
             - scale_factor >= 0
-        ensures
-            - Applies depth-dependent gradient scaling
-            - scale = 1 + scale_factor * (n_steps[pos] / max_steps)
+            - max_steps > 0
+            - gradients.num_samples() == batch_size
+            - gradients.k() == num_channels
+            - gradients.nr() == seq_len
+            - gradients.nc() == d_model
+            - n_steps.size() == batch_size * seq_len
+        ensures
+            - Applies gradient scaling based on computation depth:
+                - Applies depth-dependent gradient scaling
+                - scale = 1 + scale_factor * (n_steps[pos] / max_steps)
+            - seq_len: sequence length (number of positions to process)
+            - d_model: model dimension per channel
+            - num_channels: number of feature channels
+            - max_steps: maximum allowed computation steps
+            - scale_factor: scaling strength (0 = no scaling)
     !*/
 
 // ----------------------------------------------------------------------------------------
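
Depth scaling is a single multiplicative factor per position. A sketch under the same flat-layout assumptions (the function name is illustrative):

#include <vector>

// Illustrative sketch only: scale each position's gradients by how many ACT
// steps that position used, scale = 1 + scale_factor * (n_steps[pos] / max_steps).
void apply_act_depth_scaling_sketch (
    std::vector<float>& gradients,      // per-position gradients, d values each
    const std::vector<float>& n_steps,  // steps actually taken per position
    long num_positions,
    long d,                             // feature values per position
    long max_steps,
    float scale_factor
)
{
    for (long pos = 0; pos < num_positions; ++pos)
    {
        const float scale = 1.0f + scale_factor * (n_steps[pos] / static_cast<float>(max_steps));
        for (long i = 0; i < d; ++i)
            gradients[pos*d + i] *= scale;
    }
}
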