Skip to content

Commit a2651be

Browse files
committed
Code refactor: CUB DeviceReduce for bound norm
1 parent dddf0df commit a2651be

File tree

1 file changed

+32
-42
lines changed

1 file changed

+32
-42
lines changed

src/preconditioner.cu

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <math.h>
2222
#include <cuda_runtime.h>
2323
#include <cublas_v2.h>
24+
#include <cub/cub.cuh>
2425

2526
#define SCALING_EPSILON 1e-12
2627

@@ -58,11 +59,11 @@ __global__ void clamp_sqrt_and_accum_kernel(double *__restrict__ scaling_factors
5859
double *__restrict__ inverse_scaling_factors,
5960
double *__restrict__ cumulative_rescaling,
6061
int num_variables);
61-
__global__ void reduce_bound_norm_sq_kernel(
62+
__global__ void compute_bound_contrib_kernel(
6263
const double *__restrict__ constraint_lower_bound,
6364
const double *__restrict__ constraint_upper_bound,
6465
int num_constraints,
65-
double *__restrict__ block_sums);
66+
double *__restrict__ contrib);
6667
__global__ void scale_bounds_kernel(
6768
double *__restrict__ constraint_lower_bound,
6869
double *__restrict__ constraint_upper_bound,
@@ -210,14 +211,23 @@ static void bound_objective_rescaling(
210211
const int num_constraints = state->num_constraints;
211212
const int num_variables = state->num_variables;
212213

213-
double *bnd_norm_sq_d = NULL;
214-
CUDA_CHECK(cudaMalloc(&bnd_norm_sq_d, sizeof(double)));
215-
CUDA_CHECK(cudaMemset(bnd_norm_sq_d, 0, sizeof(double)));
216-
reduce_bound_norm_sq_kernel<<<1, THREADS_PER_BLOCK, THREADS_PER_BLOCK * sizeof(double)>>>(
214+
double *contrib_d = nullptr;
215+
CUDA_CHECK(cudaMalloc(&contrib_d, num_constraints * sizeof(double)));
216+
compute_bound_contrib_kernel<<<state->num_blocks_dual, THREADS_PER_BLOCK>>>(
217217
state->constraint_lower_bound,
218218
state->constraint_upper_bound,
219219
num_constraints,
220-
bnd_norm_sq_d);
220+
contrib_d);
221+
222+
double *bnd_norm_sq_d = nullptr;
223+
CUDA_CHECK(cudaMalloc(&bnd_norm_sq_d, sizeof(double)));
224+
void *temp_storage = nullptr;
225+
size_t temp_bytes = 0;
226+
CUDA_CHECK(cub::DeviceReduce::Sum(temp_storage, temp_bytes, contrib_d, bnd_norm_sq_d, num_constraints));
227+
CUDA_CHECK(cudaMalloc(&temp_storage, temp_bytes));
228+
CUDA_CHECK(cub::DeviceReduce::Sum(temp_storage, temp_bytes, contrib_d, bnd_norm_sq_d, num_constraints));
229+
CUDA_CHECK(cudaFree(contrib_d));
230+
CUDA_CHECK(cudaFree(temp_storage));
221231

222232
double bnd_norm_sq_h = 0.0;
223233
CUDA_CHECK(cudaMemcpy(&bnd_norm_sq_h, bnd_norm_sq_d, sizeof(double), cudaMemcpyDeviceToHost));
@@ -432,49 +442,29 @@ __global__ void clamp_sqrt_and_accum_kernel(double *__restrict__ scaling_factors
432442
}
433443
}
434444

435-
__global__ void reduce_bound_norm_sq_kernel(
445+
__global__ void compute_bound_contrib_kernel(
436446
const double *__restrict__ constraint_lower_bound,
437447
const double *__restrict__ constraint_upper_bound,
438448
int num_constraints,
439-
double *__restrict__ block_sums)
449+
double *__restrict__ contrib)
440450
{
441-
extern __shared__ double sdata[];
442-
int tid = threadIdx.x;
443-
int global_tid = blockIdx.x * blockDim.x + tid;
444-
int stride = blockDim.x * gridDim.x;
451+
int i = blockIdx.x * blockDim.x + threadIdx.x;
452+
if (i >= num_constraints) return;
453+
454+
double Li = constraint_lower_bound[i];
455+
double Ui = constraint_upper_bound[i];
456+
bool fL = isfinite(Li);
457+
bool fU = isfinite(Ui);
445458

446459
double acc = 0.0;
447-
for (int i = global_tid; i < num_constraints; i += stride)
448-
{
449-
double Li = constraint_lower_bound[i], Ui = constraint_upper_bound[i];
450-
bool fL = isfinite(Li), fU = isfinite(Ui);
451-
452-
if (fL && (!fU || fabs(Li - Ui) > SCALING_EPSILON))
453-
{
454-
acc += Li * Li;
455-
}
456-
if (fU)
457-
{
458-
acc += Ui * Ui;
459-
}
460-
}
461460

462-
sdata[tid] = acc;
463-
__syncthreads();
461+
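// Same rule as the removed reduce_bound_norm_sq_kernel: a finite Li counts unless Ui is finite and |Li - Ui| <= SCALING_EPSILON; a finite Ui always counts.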
462+
if (fL && (!fU || fabs(Li - Ui) > SCALING_EPSILON))
463+
acc += Li * Li;
464+
if (fU)
465+
acc += Ui * Ui;
464466

465-
for (int s = blockDim.x / 2; s > 0; s >>= 1)
466-
{
467-
if (tid < s)
468-
{
469-
sdata[tid] += sdata[tid + s];
470-
}
471-
__syncthreads();
472-
}
473-
474-
if (tid == 0)
475-
{
476-
block_sums[blockIdx.x] = sdata[0];
477-
}
467+
contrib[i] = acc;
478468
}
479469

480470
__global__ void scale_bounds_kernel(

0 commit comments

Comments
 (0)