@@ -2749,23 +2749,35 @@ namespace dlib
27492749 size_t feature_dim
27502750 )
27512751 {
2752- for (auto pos : grid_stride_range (0 , batch_size * seq_len))
2752+ const long total_positions = batch_size * seq_len;
2753+
2754+ for (auto pos : grid_stride_range_y (0 , total_positions))
2755+ for (auto i : grid_stride_range (0 , 1 ))
2756+ logits[pos] = b_halt;
2757+ __syncthreads ();
2758+
2759+ for (auto pos : grid_stride_range_y (0 , total_positions))
27532760 {
2754- const size_t n = pos / seq_len;
2755- const size_t s = pos % seq_len;
2761+ const long n = pos / seq_len;
2762+ const long s = pos % seq_len;
27562763
2757- float logit = b_halt;
2764+ float temp = 0 ;
2765+ for (auto feat_idx : grid_stride_range (0 , feature_dim))
2766+ {
2767+ const long c = feat_idx / d_model;
2768+ const long d = feat_idx % d_model;
27582769
2759- for (size_t c = 0 ; c < num_channels; ++c) {
2760- for (size_t d = 0 ; d < d_model; ++d) {
2761- const size_t in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
2762- const size_t weight_idx = c * d_model + d;
2763- logit += input_data[in_idx] * W_halt[weight_idx];
2764- }
2770+ const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
2771+ temp += input_data[in_idx] * W_halt[feat_idx];
27652772 }
27662773
2767- logits[pos] = logit;
2768- halt_probs[pos] = 1 .0f / (1 .0f + ::expf (-logit));
2774+ warp_reduce_atomic_add (logits[pos], temp);
2775+ }
2776+ __syncthreads ();
2777+
2778+ for (auto pos : grid_stride_range (0 , total_positions))
2779+ {
2780+ halt_probs[pos] = 1 .0f / (1 .0f + expf (-logits[pos]));
27692781 }
27702782 }
27712783
@@ -2783,8 +2795,11 @@ namespace dlib
27832795 const long d_model = feature_dim / input_data.k ();
27842796 const long num_channels = input_data.k ();
27852797
2798+ halt_probs.set_size (total_positions, 1 , 1 , 1 );
2799+ logits.set_size (total_positions, 1 , 1 , 1 );
2800+
27862801 launch_kernel (_cuda_compute_act_halt_probabilities,
2787- max_jobs (total_positions),
2802+ max_jobs (feature_dim, total_positions),
27882803 halt_probs.device (),
27892804 logits.device (),
27902805 input_data.device (),
@@ -2814,7 +2829,8 @@ namespace dlib
28142829 {
28152830 for (auto pos : grid_stride_range (0 , batch_size * seq_len))
28162831 {
2817- if (cumulative_halting[pos] < halt_threshold) {
2832+ if (cumulative_halting[pos] < halt_threshold)
2833+ {
28182834 const size_t n = pos / seq_len;
28192835 const size_t s = pos % seq_len;
28202836
@@ -2930,17 +2946,21 @@ namespace dlib
29302946 float scale_factor
29312947 )
29322948 {
2933- for (auto pos : grid_stride_range (0 , batch_size * seq_len))
2949+ const long total_positions = batch_size * seq_len;
2950+ const long feature_dim = num_channels * d_model;
2951+
2952+ for (auto pos : grid_stride_range_y (0 , total_positions))
29342953 {
2954+ const long n = pos / seq_len;
2955+ const long s = pos % seq_len;
29352956 const float scale = 1 .0f + scale_factor * (n_steps[pos] / max_steps);
2936- const size_t n = pos / seq_len;
2937- const size_t s = pos % seq_len;
29382957
2939- for (size_t c = 0 ; c < num_channels; ++c) {
2940- for (size_t d = 0 ; d < d_model; ++d) {
2941- const size_t idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
2942- gradients[idx] *= scale;
2943- }
2958+ for (auto feat_idx : grid_stride_range (0 , feature_dim))
2959+ {
2960+ const long c = feat_idx / d_model;
2961+ const long d = feat_idx % d_model;
2962+ const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
2963+ gradients[idx] *= scale;
29442964 }
29452965 }
29462966 }
@@ -2956,8 +2976,11 @@ namespace dlib
29562976 float scale_factor
29572977 )
29582978 {
2979+ const long total_positions = batch_size * seq_len;
2980+ const long feature_dim = num_channels * d_model;
2981+
29592982 launch_kernel (_cuda_apply_act_depth_scaling,
2960- max_jobs (batch_size * seq_len ),
2983+ max_jobs (feature_dim, total_positions ),
29612984 gradients.device (),
29622985 n_steps.device (),
29632986 batch_size,
0 commit comments