@@ -21,6 +21,7 @@ limitations under the License.
2121#include < math.h>
2222#include < cuda_runtime.h>
2323#include < cublas_v2.h>
24+ #include < cub/cub.cuh>
2425
2526#define SCALING_EPSILON 1e-12
2627
@@ -58,11 +59,11 @@ __global__ void clamp_sqrt_and_accum_kernel(double *__restrict__ scaling_factors
5859 double *__restrict__ inverse_scaling_factors,
5960 double *__restrict__ cumulative_rescaling,
6061 int num_variables);
61- __global__ void reduce_bound_norm_sq_kernel (
62+ __global__ void compute_bound_contrib_kernel (
6263 const double *__restrict__ constraint_lower_bound,
6364 const double *__restrict__ constraint_upper_bound,
6465 int num_constraints,
65- double *__restrict__ block_sums );
66+ double *__restrict__ contrib );
6667__global__ void scale_bounds_kernel (
6768 double *__restrict__ constraint_lower_bound,
6869 double *__restrict__ constraint_upper_bound,
@@ -210,14 +211,23 @@ static void bound_objective_rescaling(
210211 const int num_constraints = state->num_constraints ;
211212 const int num_variables = state->num_variables ;
212213
213- double *bnd_norm_sq_d = NULL ;
214- CUDA_CHECK (cudaMalloc (&bnd_norm_sq_d, sizeof (double )));
215- CUDA_CHECK (cudaMemset (bnd_norm_sq_d, 0 , sizeof (double )));
216- reduce_bound_norm_sq_kernel<<<1 , THREADS_PER_BLOCK, THREADS_PER_BLOCK * sizeof (double )>>> (
214+ double *contrib_d = nullptr ;
215+ CUDA_CHECK (cudaMalloc (&contrib_d, num_constraints * sizeof (double )));
216+ compute_bound_contrib_kernel<<<state->num_blocks_dual, THREADS_PER_BLOCK>>> (
217217 state->constraint_lower_bound ,
218218 state->constraint_upper_bound ,
219219 num_constraints,
220- bnd_norm_sq_d);
220+ contrib_d);
221+
222+ double *bnd_norm_sq_d = nullptr ;
223+ CUDA_CHECK (cudaMalloc (&bnd_norm_sq_d, sizeof (double )));;
224+ void *temp_storage = nullptr ;
225+ size_t temp_bytes = 0 ;
226+ CUDA_CHECK (cub::DeviceReduce::Sum (temp_storage, temp_bytes, contrib_d, bnd_norm_sq_d, num_constraints));
227+ CUDA_CHECK (cudaMalloc (&temp_storage, temp_bytes));
228+ CUDA_CHECK (cub::DeviceReduce::Sum (temp_storage, temp_bytes, contrib_d, bnd_norm_sq_d, num_constraints));
229+ CUDA_CHECK (cudaFree (contrib_d));
230+ CUDA_CHECK (cudaFree (temp_storage));
221231
222232 double bnd_norm_sq_h = 0.0 ;
223233 CUDA_CHECK (cudaMemcpy (&bnd_norm_sq_h, bnd_norm_sq_d, sizeof (double ), cudaMemcpyDeviceToHost));
@@ -432,49 +442,29 @@ __global__ void clamp_sqrt_and_accum_kernel(double *__restrict__ scaling_factors
432442 }
433443}
434444
435- __global__ void reduce_bound_norm_sq_kernel (
445+ __global__ void compute_bound_contrib_kernel (
436446 const double *__restrict__ constraint_lower_bound,
437447 const double *__restrict__ constraint_upper_bound,
438448 int num_constraints,
439- double *__restrict__ block_sums )
449+ double *__restrict__ contrib )
440450{
441- extern __shared__ double sdata[];
442- int tid = threadIdx .x ;
443- int global_tid = blockIdx .x * blockDim .x + tid;
444- int stride = blockDim .x * gridDim .x ;
451+ int i = blockIdx .x * blockDim .x + threadIdx .x ;
452+ if (i >= num_constraints) return ;
453+
454+ double Li = constraint_lower_bound[i];
455+ double Ui = constraint_upper_bound[i];
456+ bool fL = isfinite (Li);
457+ bool fU = isfinite (Ui);
445458
446459 double acc = 0.0 ;
447- for (int i = global_tid; i < num_constraints; i += stride)
448- {
449- double Li = constraint_lower_bound[i], Ui = constraint_upper_bound[i];
450- bool fL = isfinite (Li), fU = isfinite (Ui);
451-
452- if (fL && (!fU || fabs (Li - Ui) > SCALING_EPSILON))
453- {
454- acc += Li * Li;
455- }
456- if (fU )
457- {
458- acc += Ui * Ui;
459- }
460- }
461460
462- sdata[tid] = acc;
463- __syncthreads ();
461+ // follow the existing semantics
462+ if (fL && (!fU || fabs (Li - Ui) > SCALING_EPSILON))
463+ acc += Li * Li;
464+ if (fU )
465+ acc += Ui * Ui;
464466
465- for (int s = blockDim .x / 2 ; s > 0 ; s >>= 1 )
466- {
467- if (tid < s)
468- {
469- sdata[tid] += sdata[tid + s];
470- }
471- __syncthreads ();
472- }
473-
474- if (tid == 0 )
475- {
476- block_sums[blockIdx .x ] = sdata[0 ];
477- }
467+ contrib[i] = acc;
478468}
479469
480470__global__ void scale_bounds_kernel (
0 commit comments