
Commit e7227d7

use fma
1 parent 2270209 commit e7227d7

File tree: 1 file changed (+39, -20 lines)

1 file changed

+39
-20
lines changed

highs/pdlp/hipdlp/pdhg.cu

Lines changed: 39 additions & 20 deletions
@@ -1,6 +1,8 @@
 #include <cuda_runtime.h>
 #include <device_launch_parameters.h>
 #include <cmath>
+#include <cstdio>
+#include <cublas_v2.h>
 
 // Define Infinity for GPU
 #define GPU_INF 1e20
@@ -11,6 +13,15 @@
 #define IDX_PRIMAL_OBJ 2
 #define IDX_DUAL_OBJ 3
 
+// Add to pdhg.cu
+#define FULL_WARP_REDUCE(val) { \
+  val += __shfl_down_sync(0xFFFFFFFF, val, 16); \
+  val += __shfl_down_sync(0xFFFFFFFF, val, 8);  \
+  val += __shfl_down_sync(0xFFFFFFFF, val, 4);  \
+  val += __shfl_down_sync(0xFFFFFFFF, val, 2);  \
+  val += __shfl_down_sync(0xFFFFFFFF, val, 1);  \
+}
+
 // Utility for robust 1D kernel launches
 #define CUDA_GRID_STRIDE_LOOP(i, n) \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
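Note on the new macro: FULL_WARP_REDUCE is a warp-level tree sum built on __shfl_down_sync; after the five shuffles, lane 0 of each warp holds that warp's total. The hunk adds only the definition, so the sketch below is a hypothetical call site rather than code from commit e7227d7 (the kernel name blockSumSketch, its buffers, and the assumption that blockDim.x is a multiple of 32 are illustrative):

// Hypothetical usage sketch (not part of this commit): a block-wide sum that
// pairs FULL_WARP_REDUCE with a small shared-memory staging buffer.
__global__ void blockSumSketch(const double* d_in, double* d_block_sums, int n) {
  __shared__ double warp_sums[32];        // one slot per warp, blocks up to 1024 threads

  double val = 0.0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x)
    val += d_in[i];

  FULL_WARP_REDUCE(val);                  // lane 0 of each warp now holds its partial sum

  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  if (lane == 0) warp_sums[warp] = val;
  __syncthreads();

  if (warp == 0) {                        // first warp reduces the per-warp partials
    val = (lane < (blockDim.x + 31) / 32) ? warp_sums[lane] : 0.0;
    FULL_WARP_REDUCE(val);
    if (lane == 0) d_block_sums[blockIdx.x] = val;
  }
}

A second kernel pass (or an atomicAdd into a single accumulator) would then combine the per-block sums into the final scalar.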
@@ -22,21 +33,30 @@ static dim3 GetLaunchConfig(int n, int block_size = 256) {
   return dim3(num_blocks, 1, 1);
 }
 
+// ============================================================================
+// FMA (Fused Multiply-Add) Helper
+// ============================================================================
+// Use __fma_rn for different precision:
+//   __fma_rn  - double precision, round to nearest
+//   __fmaf_rn - single precision, round to nearest
+// FMA computes (a * b + c) in a single operation with only one rounding,
+// which is faster and more numerically accurate.
+
+__device__ __forceinline__ double fma_rn(double a, double b, double c) {
+  return __fma_rn(a, b, c);  // a * b + c
+}
+
 // === KERNEL 1: Update X (Primal Step) ===
 __global__ void kernelUpdateX(
     double* d_x_new, const double* d_x_old, const double* d_aty,
     const double* d_cost, const double* d_lower, const double* d_upper,
     double primal_step, int n_cols)
 {
   CUDA_GRID_STRIDE_LOOP(i, n_cols) {
-    // 1. Compute gradient: gradient = c - A'y
-    double gradient = d_cost[i] - d_aty[i];
-
-    // 2. Perform gradient step: x_updated = x_old - step * gradient
-    double x_updated = d_x_old[i] - primal_step * gradient;
+    double x_updated = fma_rn(primal_step, d_aty[i] - d_cost[i], d_x_old[i]);
 
     // 3. Project to bounds [l, u]
-    d_x_new[i] = fmax(d_lower[i], fmin(x_updated, d_upper[i]));
+    d_x_new[i] = fmin(fmax(x_updated, d_lower[i]), d_upper[i]);
   }
 }
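The rewritten primal step is the same update regrouped for a fused multiply-add: x_old - step * (c - A'y) = step * (A'y - c) + x_old, which is exactly one fma_rn with a = primal_step, b = d_aty[i] - d_cost[i], c = d_x_old[i]. The projection change is order only: fmin(fmax(x, l), u) and fmax(l, fmin(x, u)) agree whenever l <= u. A minimal host-side check of the fma identity, using std::fma (the C++ counterpart of __fma_rn) and made-up values, might look like this sketch; it is an illustration, not code from the commit:

// Illustrative only: compare the two-operation update against the fused form.
// The results can differ in the last bit because fma rounds once.
#include <cmath>
#include <cstdio>

int main() {
  const double x_old = 1.25, cost = 3.0, aty = 2.5, step = 0.1;  // made-up values

  double two_op = x_old - step * (cost - aty);        // multiply, then subtract: two roundings
  double fused  = std::fma(step, aty - cost, x_old);  // single fused multiply-add: one rounding

  std::printf("two-op: %.17g\nfused : %.17g\n", two_op, fused);
  return 0;
}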

@@ -48,30 +68,29 @@ __global__ void kernelUpdateY(
     double dual_step, int n_rows)
 {
   CUDA_GRID_STRIDE_LOOP(j, n_rows) {
-    double extra_ax = 2.0 * d_ax_new[j] - d_ax_old[j];
-    double dual_update = d_y_old[j] + dual_step * (d_rhs[j] - extra_ax);
-    if (d_is_equality[j]) {  // to be optimized
-      d_y_new[j] = dual_update;  // No bounds for equality constraints
-    } else {
-      d_y_new[j] = fmax(0.0, dual_update);  // Project to non-negative orthant
-    }
+    double residual = fma_rn(-2.0, d_ax_new[j], d_rhs[j] + d_ax_old[j]);
+    double dual_update = fma_rn(dual_step, residual, d_y_old[j]);
+    d_y_new[j] = d_is_equality[j] ? dual_update : fmax(0.0, dual_update);
   }
 }
 
 // === KERNEL 3: Update Averages ===
-// x_sum = x_sum + weight * x_next
+// x_sum = x_sum + weight * x_next.
 // y_sum = y_sum + weight * y_next
 __global__ void kernelUpdateAverages(
     double* d_x_sum, double* d_y_sum,
     const double* d_x_next, const double* d_y_next,
     double weight, int n_cols, int n_rows)
 {
-  CUDA_GRID_STRIDE_LOOP(i, n_cols) {
-    d_x_sum[i] += weight * d_x_next[i];
-  }
-  CUDA_GRID_STRIDE_LOOP(j, n_rows) {
-    d_y_sum[j] += weight * d_y_next[j];
-  }
+  // Update x_sum
+  CUDA_GRID_STRIDE_LOOP(i, n_cols) {
+    d_x_sum[i] = fma_rn(weight, d_x_next[i], d_x_sum[i]);
+  }
+
+  // Update y_sum
+  CUDA_GRID_STRIDE_LOOP(j, n_rows) {
+    d_y_sum[j] = fma_rn(weight, d_y_next[j], d_y_sum[j]);
+  }
 }
 
 __global__ void kernelScaleVector(
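The dual step above is likewise regrouped rather than changed: rhs - (2 * ax_new - ax_old) = (-2) * ax_new + (rhs + ax_old), so the first fma_rn takes a = -2.0, b = d_ax_new[j], c = d_rhs[j] + d_ax_old[j], and the second fuses y_old[j] + dual_step * residual into one more operation. kernelUpdateAverages now walks two grid-stride loops of different lengths in a single launch, so the grid only has to cover the longer vector. A hypothetical host-side wrapper (the name LaunchUpdateAverages and the sizing choice are assumptions, not shown in this diff) could look like:

// Hypothetical launch sketch (not in this commit): one grid serves both the
// n_cols-long x_sum loop and the n_rows-long y_sum loop.
static void LaunchUpdateAverages(double* d_x_sum, double* d_y_sum,
                                 const double* d_x_next, const double* d_y_next,
                                 double weight, int n_cols, int n_rows) {
  const int n_max = (n_cols > n_rows) ? n_cols : n_rows;
  dim3 grid = GetLaunchConfig(n_max);  // static helper defined earlier in pdhg.cu
  kernelUpdateAverages<<<grid, 256>>>(d_x_sum, d_y_sum, d_x_next, d_y_next,
                                      weight, n_cols, n_rows);
}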
