Commit ab29fc4

Fixes and improvements
1 parent 1a904f2 commit ab29fc4

3 files changed: +48 -34 lines changed

dlib/cuda/cpu_dlib.cpp

Lines changed: 0 additions & 4 deletions
@@ -3238,7 +3238,6 @@ namespace dlib
         const long d_model = feature_dim / input_data.k();
         const long num_channels = input_data.k();
 
-        #pragma omp parallel for
         for (long pos = 0; pos < batch_size * seq_len; ++pos) {
             const long n = pos / seq_len;
             const long s = pos % seq_len;
@@ -3281,7 +3280,6 @@ namespace dlib
         float* remain = remainders.host();
         float* steps = n_steps.host();
 
-        #pragma omp parallel for
         for (long pos = 0; pos < batch_size * seq_len; ++pos) {
             if (cum_halt[pos] < halt_threshold) {
                 const long n = pos / seq_len;
@@ -3319,7 +3317,6 @@ namespace dlib
         const float* remain = remainders.host();
         float* out_ptr = output.host();
 
-        #pragma omp parallel for
         for (long pos = 0; pos < batch_size * seq_len; ++pos) {
             float r = remain[pos];
             if (r > 1e-6f) {
@@ -3350,7 +3347,6 @@ namespace dlib
         const float* steps = n_steps.host();
         float* grad_ptr = gradients.host();
 
-        #pragma omp parallel for
         for (long pos = 0; pos < batch_size * seq_len; ++pos)
         {
             const float scale = 1.0f + scale_factor * (steps[pos] / max_steps);

dlib/cuda/cuda_dlib.cu

Lines changed: 46 additions & 23 deletions
@@ -2749,23 +2749,35 @@ namespace dlib
             size_t feature_dim
         )
         {
-            for (auto pos : grid_stride_range(0, batch_size * seq_len))
+            const long total_positions = batch_size * seq_len;
+
+            for (auto pos : grid_stride_range_y(0, total_positions))
+                for (auto i : grid_stride_range(0, 1))
+                    logits[pos] = b_halt;
+            __syncthreads();
+
+            for (auto pos : grid_stride_range_y(0, total_positions))
             {
-                const size_t n = pos / seq_len;
-                const size_t s = pos % seq_len;
+                const long n = pos / seq_len;
+                const long s = pos % seq_len;
 
-                float logit = b_halt;
+                float temp = 0;
+                for (auto feat_idx : grid_stride_range(0, feature_dim))
+                {
+                    const long c = feat_idx / d_model;
+                    const long d = feat_idx % d_model;
 
-                for (size_t c = 0; c < num_channels; ++c) {
-                    for (size_t d = 0; d < d_model; ++d) {
-                        const size_t in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
-                        const size_t weight_idx = c * d_model + d;
-                        logit += input_data[in_idx] * W_halt[weight_idx];
-                    }
+                    const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
+                    temp += input_data[in_idx] * W_halt[feat_idx];
                 }
 
-                logits[pos] = logit;
-                halt_probs[pos] = 1.0f / (1.0f + ::expf(-logit));
+                warp_reduce_atomic_add(logits[pos], temp);
+            }
+            __syncthreads();
+
+            for (auto pos : grid_stride_range(0, total_positions))
+            {
+                halt_probs[pos] = 1.0f / (1.0f + expf(-logits[pos]));
             }
         }
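
The rewritten kernel above follows dlib's usual two-dimensional reduction idiom: stride over output positions in the launch's y dimension, let the x-dimension threads each accumulate a partial sum over the feature axis, and merge the partials with warp_reduce_atomic_add (one atomicAdd per warp). A minimal standalone sketch of the idiom, with hypothetical names and not part of this commit, assuming the helpers from dlib/cuda/cuda_utils.h:

    // Row-wise dot products: out[r] += sum over c of m[r*nc + c] * v[c].
    // out must be pre-seeded (zero, or a bias such as b_halt above), because
    // every warp adds its partial sum into it atomically.
    __global__ void _cuda_row_dots(
        float* out, const float* m, const float* v, size_t nr, size_t nc
    )
    {
        for (auto r : grid_stride_range_y(0, nr))
        {
            float partial = 0;
            for (auto c : grid_stride_range(0, nc))  // x-threads split the row
                partial += m[r * nc + c] * v[c];
            warp_reduce_atomic_add(out[r], partial); // one atomicAdd per warp
        }
    }

Such a kernel would be launched as launch_kernel(_cuda_row_dots, max_jobs(nc, nr), ...): the first max_jobs argument sizes the x dimension walked by grid_stride_range, the second the y dimension walked by grid_stride_range_y. This mirrors the max_jobs(feature_dim, total_positions) change in the next hunk, which spreads each position's dot product across threads instead of computing it serially per position.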

@@ -2783,8 +2795,11 @@
         const long d_model = feature_dim / input_data.k();
         const long num_channels = input_data.k();
 
+        halt_probs.set_size(total_positions, 1, 1, 1);
+        logits.set_size(total_positions, 1, 1, 1);
+
         launch_kernel(_cuda_compute_act_halt_probabilities,
-                      max_jobs(total_positions),
+                      max_jobs(feature_dim, total_positions),
                       halt_probs.device(),
                       logits.device(),
                       input_data.device(),
@@ -2814,7 +2829,8 @@
         {
             for (auto pos : grid_stride_range(0, batch_size * seq_len))
             {
-                if (cumulative_halting[pos] < halt_threshold) {
+                if (cumulative_halting[pos] < halt_threshold)
+                {
                     const size_t n = pos / seq_len;
                     const size_t s = pos % seq_len;
@@ -2930,17 +2946,21 @@
             float scale_factor
         )
        {
-            for (auto pos : grid_stride_range(0, batch_size * seq_len))
+            const long total_positions = batch_size * seq_len;
+            const long feature_dim = num_channels * d_model;
+
+            for (auto pos : grid_stride_range_y(0, total_positions))
             {
+                const long n = pos / seq_len;
+                const long s = pos % seq_len;
                 const float scale = 1.0f + scale_factor * (n_steps[pos] / max_steps);
-                const size_t n = pos / seq_len;
-                const size_t s = pos % seq_len;
 
-                for (size_t c = 0; c < num_channels; ++c) {
-                    for (size_t d = 0; d < d_model; ++d) {
-                        const size_t idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
-                        gradients[idx] *= scale;
-                    }
+                for (auto feat_idx : grid_stride_range(0, feature_dim))
+                {
+                    const long c = feat_idx / d_model;
+                    const long d = feat_idx % d_model;
+                    const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
+                    gradients[idx] *= scale;
                 }
             }
         }
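
For intuition about the scaling rule this kernel applies (illustrative numbers, not from the commit): with scale_factor = 0.5 and max_steps = 10, a position whose n_steps is 4 gets every one of its feature gradients multiplied by 1.0 + 0.5 * (4 / 10) = 1.2, so positions that pondered for more ACT steps receive proportionally larger gradients.
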
@@ -2956,8 +2976,11 @@
             float scale_factor
         )
        {
+            const long total_positions = batch_size * seq_len;
+            const long feature_dim = num_channels * d_model;
+
             launch_kernel(_cuda_apply_act_depth_scaling,
-                          max_jobs(batch_size * seq_len),
+                          max_jobs(feature_dim, total_positions),
                           gradients.device(),
                           n_steps.device(),
                           batch_size,

dlib/dnn/layers.h

Lines changed: 2 additions & 7 deletions
@@ -5720,7 +5720,7 @@ namespace dlib
             max_steps_(max_steps),
             halt_threshold_(0.99f),   // theta in Graves' notation
             ponder_penalty_(0.01f),   // lambda (ponder cost weight)
-            enable_depth_scaling_(false),
+            enable_depth_scaling_(true),
             batch_size_(0),
             seq_len_(0),
             d_model_(0),
@@ -5857,8 +5857,7 @@ namespace dlib
                 halting_probs_, logits_, input, params,
                 batch_size_, seq_len_, feature_dim_);
 
-            // CRITICAL: Capture effective weights before state update
-            // This ensures numerical precision in backward pass
+            // Capture effective weights before state update
             const float* p_halt = halting_probs_.host();
             const float* cum_halt = cum_halt_ptr;
             const float* remainders = remainders_ptr;
@@ -5911,10 +5910,6 @@ namespace dlib
         void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) {
             tensor& input_grad = sub.get_gradient_input();
 
-            // Propagate gradients to input using instrumented effective weights
-            // This approach ensures numerical precision by using the exact weights
-            // computed during the forward pass, avoiding reconstruction errors
-
             const float* grad_in = gradient_input.host();
             const float* eff_weights = true_effective_weights_.host();
             float* grad_out = input_grad.host();
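
The constants touched by the first hunk come from Graves' Adaptive Computation Time (ACT): halt_threshold_ is theta, the cutoff on the accumulated halting probability, and ponder_penalty_ is lambda, the weight on the ponder cost N(t) + R(t). A minimal sketch of that halting rule (illustrative code with hypothetical names, not the layer's actual implementation):

    struct act_state { float cum_halt = 0, remainder = 0; long n_steps = 0; };

    // Accumulate per-step halting probabilities until they would cross theta;
    // the halting step contributes only the remainder R = 1 - sum of prior p's.
    inline bool act_step(act_state& st, float p_halt, float theta = 0.99f)
    {
        ++st.n_steps;
        if (st.cum_halt + p_halt >= theta)
        {
            st.remainder = 1.0f - st.cum_halt;  // R(t) in Graves' notation
            return true;                        // halt at this step
        }
        st.cum_halt += p_halt;
        return false;                           // keep pondering
    }

    // Ponder cost added to the task loss: lambda * (N(t) + R(t)).
    inline float ponder_cost(const act_state& st, float lambda = 0.01f)
    {
        return lambda * (st.n_steps + st.remainder);
    }

This matches the kernels above, which test cumulative_halting[pos] < halt_threshold to decide whether a position is still active.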
