Skip to content

Commit 406b01e

Browse files
committed
More intelligent initial point selection
1 parent e2f2148 commit 406b01e

File tree

3 files changed

+203
-11
lines changed

3 files changed

+203
-11
lines changed

src/scs.c

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -174,21 +174,46 @@ static inline scs_int _is_nan(scs_float x) {
174174
}
175175

176176
/* given x,y,s warm start, set v = [x; s / R + y; 1]
177-
* check for nans and set to zero if present
177+
* check for nans:
178+
* x NaN -> w->g
179+
* y NaN -> -w->g
180+
* s NaN -> b - Ax (primal feasibility)
181+
* g is set by update_work_cache
178182
*/
179183
static void warm_start_vars(ScsWork *w, ScsSolution *sol) {
180184
scs_int n = w->d->n, m = w->d->m, i;
181185
scs_float *v = w->v;
186+
scs_float *g = w->g;
187+
scs_int s_has_nan = 0;
182188
/* normalize the warm-start */
183189
if (w->stgs->normalize) {
184190
SCS(normalize_sol)(w->scal, sol);
185191
}
186192
for (i = 0; i < n; ++i) {
187-
v[i] = _is_nan(sol->x[i]) ? 0. : sol->x[i];
193+
v[i] = _is_nan(sol->x[i]) ? g[i] : sol->x[i];
188194
}
195+
/* check if s has any NaN entries */
189196
for (i = 0; i < m; ++i) {
190-
v[i + n] = sol->y[i] + sol->s[i] / w->diag_r[i + n];
191-
v[i + n] = _is_nan(v[i + n]) ? 0. : v[i + n];
197+
if (_is_nan(sol->s[i])) {
198+
s_has_nan = 1;
199+
break;
200+
}
201+
}
202+
if (s_has_nan) {
203+
/* compute Ax into v[n:n+m] as scratch space */
204+
memset(v + n, 0, m * sizeof(scs_float));
205+
SCS(accum_by_a)(w->d->A, v, v + n); /* v[n:n+m] = A * x */
206+
/* for NaN entries of s, use b - Ax (primal feasibility) */
207+
for (i = 0; i < m; ++i) {
208+
scs_float si = _is_nan(sol->s[i]) ? (w->d->b[i] - v[i + n]) : sol->s[i];
209+
scs_float yi = _is_nan(sol->y[i]) ? -g[i + n] : sol->y[i];
210+
v[i + n] = yi + si / w->diag_r[i + n];
211+
}
212+
} else {
213+
for (i = 0; i < m; ++i) {
214+
scs_float yi = _is_nan(sol->y[i]) ? -g[i + n] : sol->y[i];
215+
v[i + n] = yi + sol->s[i] / w->diag_r[i + n];
216+
}
192217
}
193218
v[n + m] = 1.0; /* tau = 1 */
194219
/* un-normalize so sol unchanged */
@@ -339,9 +364,23 @@ static void populate_residual_struct(ScsWork *w, scs_int iter) {
339364
}
340365

341366
static void cold_start_vars(ScsWork *w) {
342-
scs_int l = w->d->n + w->d->m + 1;
343-
memset(w->v, 0, l * sizeof(scs_float));
344-
w->v[l - 1] = 1.;
367+
scs_int n = w->d->n, m = w->d->m, i;
368+
scs_float *v = w->v;
369+
scs_float *g = w->g;
370+
371+
for (i = 0; i < n; ++i) {
372+
v[i] = g[i];
373+
}
374+
375+
memset(v + n, 0, m * sizeof(scs_float));
376+
SCS(accum_by_a)(w->d->A, v, v + n); /* v[n:n+m] = A * x */
377+
/* for NaN entries of s, use b - Ax (primal feasibility) */
378+
for (i = 0; i < m; ++i) {
379+
scs_float si = (w->d->b[i] - v[i + n]);
380+
scs_float yi = -g[i + n];
381+
v[i + n] = yi + si / w->diag_r[i + n];
382+
}
383+
v[n + m] = 1.;
345384
}
346385

347386
/* utility function that computes x'Ry */
@@ -944,16 +983,17 @@ static void reset_tracking(ScsWork *w) {
944983
static scs_int update_work(ScsWork *w, ScsSolution *sol) {
945984
reset_tracking(w);
946985

986+
/* h = [c;b] */
987+
memcpy(w->h, w->d->c, w->d->n * sizeof(scs_float));
988+
memcpy(&(w->h[w->d->n]), w->d->b, w->d->m * sizeof(scs_float));
989+
update_work_cache(w);
990+
947991
if (w->stgs->warm_start) {
948992
warm_start_vars(w, sol);
949993
} else {
950994
cold_start_vars(w);
951995
}
952996

953-
/* h = [c;b] */
954-
memcpy(w->h, w->d->c, w->d->n * sizeof(scs_float));
955-
memcpy(&(w->h[w->d->n]), w->d->b, w->d->m * sizeof(scs_float));
956-
update_work_cache(w);
957997
return 0;
958998
}
959999

test/problems/partial_warm_start.h

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#include "glbopts.h"
2+
#include "linalg.h"
3+
#include "minunit.h"
4+
#include "problem_utils.h"
5+
#include "scs.h"
6+
#include "scs_matrix.h"
7+
#include "util.h"
8+
9+
static const char *partial_warm_start(void) {
10+
ScsCone *k = (ScsCone *)scs_calloc(1, sizeof(ScsCone));
11+
ScsData *d = (ScsData *)scs_calloc(1, sizeof(ScsData));
12+
ScsSettings *stgs = (ScsSettings *)scs_calloc(1, sizeof(ScsSettings));
13+
ScsSolution *sol = (ScsSolution *)scs_calloc(1, sizeof(ScsSolution));
14+
ScsInfo info = {0};
15+
scs_int exitflag;
16+
scs_int cold_iters;
17+
scs_int i;
18+
scs_float perr, derr;
19+
scs_int success;
20+
const char *fail;
21+
22+
/* data (same as hs21_tiny_qp) */
23+
scs_float Ax[] = {-10., -1., 1., -1.};
24+
scs_int Ai[] = {1, 2, 1, 3};
25+
scs_int Ap[] = {0, 2, 4};
26+
27+
scs_float Px[] = {0.02, 2.};
28+
scs_int Pi[] = {0, 1};
29+
scs_int Pp[] = {0, 1, 2};
30+
31+
scs_float b[] = {1., 0., 0., 0.};
32+
scs_float c[] = {0., 0.};
33+
34+
scs_int m = 4;
35+
scs_int n = 2;
36+
37+
scs_float bl[] = {10.0, 2.0, -50.0};
38+
scs_float bu[] = {1e+20, 50.0, 50.0};
39+
scs_int bsize = 4;
40+
41+
scs_float opt = 0.04000000000000625;
42+
/* end data */
43+
44+
d->m = m;
45+
d->n = n;
46+
d->b = b;
47+
d->c = c;
48+
49+
d->A = (ScsMatrix *)scs_calloc(1, sizeof(ScsMatrix));
50+
d->P = (ScsMatrix *)scs_calloc(1, sizeof(ScsMatrix));
51+
52+
d->A->m = m;
53+
d->A->n = n;
54+
55+
d->A->x = Ax;
56+
d->A->i = Ai;
57+
d->A->p = Ap;
58+
59+
d->P->m = n;
60+
d->P->n = n;
61+
62+
d->P->x = Px;
63+
d->P->i = Pi;
64+
d->P->p = Pp;
65+
66+
k->bsize = bsize;
67+
k->bl = bl;
68+
k->bu = bu;
69+
70+
scs_set_default_settings(stgs);
71+
stgs->eps_abs = 1e-9;
72+
stgs->eps_rel = 1e-9;
73+
stgs->eps_infeas = 0.;
74+
stgs->acceleration_lookback = 0; /* disable acceleration for consistent iters */
75+
76+
/* Step 1: Cold solve to get reference solution and iteration count */
77+
exitflag = scs(d, k, stgs, sol, &info);
78+
79+
perr = info.pobj - opt;
80+
derr = info.dobj - opt;
81+
82+
success = ABS(perr) < 1e-3 && ABS(derr) < 1e-3 && exitflag == SCS_SOLVED;
83+
mu_assert("partial_warm_start: cold solve failed", success);
84+
fail = verify_solution_correct(d, k, stgs, &info, sol, exitflag);
85+
if (fail) {
86+
SCS(free_sol)(sol);
87+
scs_free(d->A);
88+
scs_free(d->P);
89+
scs_free(k);
90+
scs_free(stgs);
91+
scs_free(d);
92+
return fail;
93+
}
94+
95+
cold_iters = info.iter;
96+
scs_printf("partial_warm_start: cold solve took %li iters\n",
97+
(long)cold_iters);
98+
99+
/* Step 2: Partial warm start - keep x, set s and y to NaN */
100+
for (i = 0; i < m; ++i) {
101+
sol->s[i] = NAN;
102+
sol->y[i] = NAN;
103+
}
104+
105+
stgs->warm_start = 1;
106+
exitflag = scs(d, k, stgs, sol, &info);
107+
108+
perr = info.pobj - opt;
109+
derr = info.dobj - opt;
110+
111+
success = ABS(perr) < 1e-3 && ABS(derr) < 1e-3 && exitflag == SCS_SOLVED;
112+
mu_assert("partial_warm_start: partial warm start (x only) failed to solve",
113+
success);
114+
115+
scs_printf("partial_warm_start: partial warm start (x only) took %li iters\n",
116+
(long)info.iter);
117+
mu_assert(
118+
"partial_warm_start: partial warm start should take fewer iters than "
119+
"cold start",
120+
info.iter < cold_iters);
121+
122+
/* Step 3: All-NaN warm start - should still converge */
123+
for (i = 0; i < n; ++i) {
124+
sol->x[i] = NAN;
125+
}
126+
for (i = 0; i < m; ++i) {
127+
sol->s[i] = NAN;
128+
sol->y[i] = NAN;
129+
}
130+
131+
stgs->warm_start = 1;
132+
exitflag = scs(d, k, stgs, sol, &info);
133+
134+
perr = info.pobj - opt;
135+
derr = info.dobj - opt;
136+
137+
success = ABS(perr) < 1e-3 && ABS(derr) < 1e-3 && exitflag == SCS_SOLVED;
138+
mu_assert("partial_warm_start: all-NaN warm start failed to solve", success);
139+
140+
scs_printf("partial_warm_start: all-NaN warm start took %li iters\n",
141+
(long)info.iter);
142+
143+
SCS(free_sol)(sol);
144+
scs_free(d->A);
145+
scs_free(d->P);
146+
scs_free(k);
147+
scs_free(stgs);
148+
scs_free(d);
149+
return fail;
150+
}

test/run_tests.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "problems/small_lp.h"
1414
#include "problems/small_qp.h"
1515
#include "problems/test_exp_cone.h"
16+
#include "problems/partial_warm_start.h"
1617
#include "problems/unbounded_tiny_qp.h"
1718

1819
int tests_run = 0;
@@ -86,6 +87,7 @@ static const char *all_tests(void) {
8687
mu_run_test(complex_PSD);
8788
mu_run_test(sd_and_complex_sd);
8889
mu_run_test(hs21_tiny_qp);
90+
mu_run_test(partial_warm_start);
8991
mu_run_test(hs21_tiny_qp_rw);
9092
mu_run_test(qafiro_tiny_qp);
9193
mu_run_test(infeasible_tiny_qp);

0 commit comments

Comments
 (0)