Skip to content

Commit 1a904f2

Browse files
author
Aldric PIERRAIN
committed
Update comments for params
1 parent b089d58 commit 1a904f2

File tree

1 file changed

+60
-15
lines changed

1 file changed

+60
-15
lines changed

dlib/cuda/tensor_tools.h

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2406,12 +2406,19 @@ namespace dlib { namespace tt
24062406
long feature_dim
24072407
);
24082408
/*!
2409-
requires
2410-
- halt_params.size() == feature_dim + 1 (weights + bias)
2411-
- input_data dimensions match batch_size x seq_len x ...
2412-
ensures
2413-
- halt_probs contains sigmoid(W_halt^T * input + b_halt) for each position
2414-
- logits contains the pre-sigmoid values
2409+
requires
2410+
- halt_params.size() == feature_dim + 1 (weights + bias)
2411+
- input_data.num_samples() == batch_size
2412+
- input_data.k() == num_channels where feature_dim = num_channels * d_model
2413+
- input_data.nr() == seq_len
2414+
- input_data.nc() == d_model
2415+
ensures
2416+
- Computes halting probabilities for Adaptive Computation Time:
2417+
- halt_probs contains sigmoid(W_halt^T * input + b_halt) for each position
2418+
- logits contains the pre-sigmoid values
2419+
- batch_size: number of samples in the batch
2420+
- seq_len: sequence length (number of positions to process)
2421+
- feature_dim: total feature dimension (num_channels * d_model)
24152422
!*/
24162423

24172424
void update_act_state(
@@ -2432,10 +2439,24 @@ namespace dlib { namespace tt
24322439
requires
24332440
- 0 < halt_threshold <= 1.0
24342441
- current_step >= 0
2435-
ensures
2436-
- Updates ACT state for all positions
2437-
- Accumulates weighted outputs: output += α_t^n · input_data
2438-
- Updates cumulative_halting, remainders, and n_steps
2442+
- input_data.num_samples() == batch_size
2443+
- input_data.k() == num_channels
2444+
- input_data.nr() == seq_len
2445+
- input_data.nc() == d_model
2446+
- output has the same dimensions as input_data
2447+
- halt_probs.size() == batch_size * seq_len
2448+
- cumulative_halting.size() == remainders.size() == n_steps.size() == batch_size * seq_len
2449+
ensures
2450+
- Core ACT update step that accumulates weighted outputs:
2451+
- Updates ACT state for all positions
2452+
- Accumulates weighted outputs: output += α_t^n * input_data
2453+
- Updates cumulative_halting, remainders, and n_steps
2454+
- batch_size: number of samples in the batch
2455+
- seq_len: sequence length (number of positions to process)
2456+
- d_model: model dimension per channel
2457+
- num_channels: number of feature channels
2458+
- halt_threshold: halting threshold (typically 0.99)
2459+
- current_step: current computation step index (0-based)
24392460
!*/
24402461

24412462
void finalize_act_output(
@@ -2448,9 +2469,21 @@ namespace dlib { namespace tt
24482469
long num_channels
24492470
);
24502471
/*!
2472+
requires
2473+
- input_data.num_samples() == batch_size
2474+
- input_data.k() == num_channels
2475+
- input_data.nr() == seq_len
2476+
- input_data.nc() == d_model
2477+
- output has the same dimensions as input_data
2478+
- remainders.size() == batch_size * seq_len
24512479
ensures
2452-
- Adds final remainder contributions: output += ρ_t · input_data
2453-
- Applied only to positions with significant remainder (> 1e-6)
2480+
- Finalizes ACT output by adding remainder contributions:
2481+
- Adds final remainder contributions: output += ρ_t * input_data
2482+
- Applied only to positions with significant remainder (> 1e-6)
2483+
- batch_size: number of samples in the batch
2484+
- seq_len: sequence length (number of positions to process)
2485+
- d_model: model dimension per channel
2486+
- num_channels: number of feature channels
24542487
!*/
24552488

24562489
void apply_act_depth_scaling(
@@ -2466,9 +2499,21 @@ namespace dlib { namespace tt
24662499
/*!
24672500
requires
24682501
- scale_factor >= 0
2469-
ensures
2470-
- Applies depth-dependent gradient scaling
2471-
- scale = 1 + scale_factor * (n_steps[pos] / max_steps)
2502+
- max_steps > 0
2503+
- gradients.num_samples() == batch_size
2504+
- gradients.k() == num_channels
2505+
- gradients.nr() == seq_len
2506+
- gradients.nc() == d_model
2507+
- n_steps.size() == batch_size * seq_len
2508+
ensures
2509+
- Applies gradient scaling based on computation depth:
2510+
- Applies depth-dependent gradient scaling
2511+
- scale = 1 + scale_factor * (n_steps[pos] / max_steps)
2512+
- seq_len: sequence length (number of positions to process)
2513+
- d_model: model dimension per channel
2514+
- num_channels: number of feature channels
2515+
- max_steps: maximum allowed computation steps
2516+
- scale_factor: scaling strength (0 = no scaling)
24722517
!*/
24732518

24742519
// ----------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)