Skip to content
This repository was archived by the owner on Jul 26, 2025. It is now read-only.

Commit 67f9167

Browse files
authored
Improve precondition checks (#62)
* Refactor argument checking; Introduce new unsafeSetup method * Fix tests * Fix bug * Add test with verbose true
1 parent ed6b406 commit 67f9167

File tree

2 files changed

+131
-70
lines changed

2 files changed

+131
-70
lines changed

src/main/java/com/ustermetrics/ecos4j/Model.java

Lines changed: 109 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
import java.lang.foreign.MemorySegment;
1111
import java.util.Arrays;
1212
import java.util.Optional;
13+
import java.util.stream.IntStream;
1314

1415
import static com.google.common.base.Preconditions.checkArgument;
1516
import static com.google.common.base.Preconditions.checkState;
1617
import static com.google.common.base.Verify.verify;
1718
import static com.ustermetrics.ecos4j.bindings.ecos_h.*;
19+
import static java.lang.Math.toIntExact;
1820
import static java.lang.foreign.MemorySegment.NULL;
1921

2022
/**
@@ -74,70 +76,60 @@ public static String version() {
7476
public void setup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr, long @NonNull [] gjc,
7577
long @NonNull [] gir, double @NonNull [] c, double @NonNull [] h, double[] apr, long[] ajc,
7678
long[] air, double[] b) {
77-
checkState(stage == Stage.NEW, "Model must be in stage new");
78-
79-
checkArgument(l >= 0, "dimension of the positive orthant l must be non-negative");
80-
val nCones = q.length;
81-
checkArgument(nCones == 0 || Arrays.stream(q).allMatch(d -> d > 0),
82-
"second-order cone dimensions q must be empty or each dimension q[i] must be positive");
83-
checkArgument(nExC >= 0, "number of exponential cones nExC must be non-negative");
84-
val nnzG = gpr.length;
85-
checkArgument(nnzG > 0, "number of non-zero elements in G (gpr.length) must be positive");
86-
checkArgument(nnzG == gir.length,
87-
"number of non-zero elements in G (gpr.length) must be equal to the number of elements in the row " +
88-
"index of G (gir.length)");
89-
val nColsG = gjc.length - 1;
90-
checkArgument(nColsG > 0, "number of columns of G (gjc.length - 1) must be positive");
91-
n = c.length;
92-
checkArgument(n > 0, "number of variables x (c.length) must be positive");
93-
m = h.length;
94-
checkArgument(m > 0, "dimension of all cones (h.length) must be positive");
95-
checkArgument(m == l + Arrays.stream(q).sum() + 3 * nExC,
96-
"dimension of all cones (h.length) must be equal to the sum of the positive orthant dimension l, the " +
97-
"second-order cone dimensions q[i], and three times the number of exponential cones 3 * nExC");
98-
checkArgument(nColsG == n, "number of columns of G (gjc.length - 1) must be equal to the number of variables " +
99-
"x (c.length)");
79+
checkArguments(l, q, nExC, gpr, gjc, gir, c, h, apr, ajc, air, b);
80+
unsafeSetup(l, q, nExC, gpr, gjc, gir, c, h, apr, ajc, air, b);
81+
}
10082

83+
private static void checkArguments(long l, long[] q, long nExC, double[] gpr, long[] gjc, long[] gir, double[] c,
84+
double[] h, double[] apr, long[] ajc, long[] air, double[] b) {
85+
checkCone(l, q, nExC, h);
86+
checkMatrix(gpr, gjc, gir, h.length, c.length, "G");
10187
checkArgument(apr != null && ajc != null && air != null && b != null || apr == null && ajc == null && air == null && b == null,
102-
"A (apr, ajc, air) and b must be supplied (all non-null) or omitted (all null) together");
88+
"all arguments of the equalities must be supplied together or must be null together");
10389
if (apr != null) {
104-
val nnzA = apr.length;
105-
checkArgument(nnzA > 0, "number of non-zero elements in A (apr.length) must be positive");
106-
checkArgument(nnzA == air.length,
107-
"number of non-zero elements in A (apr.length) must be equal to the number of elements in the row" +
108-
" index of A (air.length)");
109-
val nColsA = ajc.length - 1;
110-
checkArgument(nColsA > 0, "number of columns of A (ajc.length - 1) must be positive");
111-
p = b.length;
112-
checkArgument(p > 0, "number of equalities (b.length) must be positive");
113-
checkArgument(nColsA == n, "number of columns of A (ajc.length - 1) must be equal to the number of " +
114-
"variables x (c.length)");
115-
} else {
116-
p = 0;
90+
checkMatrix(apr, ajc, air, b.length, c.length, "A");
11791
}
92+
}
11893

119-
val qSeg = arena.allocateFrom(C_LONG_LONG, q);
120-
val gprSeg = arena.allocateFrom(C_DOUBLE, gpr);
121-
val gjcSeg = arena.allocateFrom(C_LONG_LONG, gjc);
122-
val girSeg = arena.allocateFrom(C_LONG_LONG, gir);
123-
val aprSeg = apr != null ? arena.allocateFrom(C_DOUBLE, apr) : NULL;
124-
val ajcSeg = ajc != null ? arena.allocateFrom(C_LONG_LONG, ajc) : NULL;
125-
val airSeg = air != null ? arena.allocateFrom(C_LONG_LONG, air) : NULL;
126-
val cSeg = arena.allocateFrom(C_DOUBLE, c);
127-
val hSeg = arena.allocateFrom(C_DOUBLE, h);
128-
val bSeg = b != null ? arena.allocateFrom(C_DOUBLE, b) : NULL;
129-
130-
workSeg = ECOS_setup(n, m, p, l, nCones, qSeg, nExC, gprSeg, gjcSeg, girSeg, aprSeg, ajcSeg, airSeg, cSeg,
131-
hSeg, bSeg).reinterpret(pwork.sizeof(), arena, null);
132-
verify(!NULL.equals(workSeg));
133-
134-
stgsSeg = pwork.stgs(workSeg).reinterpret(settings.sizeof(), arena, null);
135-
infoSeg = pwork.info(workSeg).reinterpret(stats.sizeof(), arena, null);
94+
private static void checkCone(long l, long[] q, long nExC, double[] h) {
95+
checkArgument(l >= 0, "dimension of the positive orthant must be non-negative");
96+
checkArgument(q.length == 0 || Arrays.stream(q).allMatch(i -> i > 0),
97+
"second-order cone dimensions must have zero length or each dimension must be positive");
98+
checkArgument(nExC >= 0, "number of exponential cones must be non-negative");
99+
checkArgument(h.length == l + Arrays.stream(q).sum() + 3 * nExC,
100+
"dimension of the convex cone K must be equal to the sum of the positive orthant dimension, the " +
101+
"second-order cone dimensions, and three times the number of exponential cones");
102+
}
136103

137-
stage = Stage.SETUP;
104+
private static void checkMatrix(double[] mpr, long[] mjc, long[] mir, int nRows, int nCols, String mName) {
105+
checkArgument(nRows > 0, "matrix %s: number of rows must be positive", mName);
106+
checkArgument(nCols > 0, "matrix %s: number of columns must be positive", mName);
107+
val nnz = mpr.length;
108+
checkArgument(0 < nnz && nnz <= nRows * nCols,
109+
"matrix %s: number of non-zero entries must be greater than zero and less equal than the number of " +
110+
"rows times the number of columns", mName);
111+
checkArgument(mir.length == nnz,
112+
"matrix %s: number of entries in the row index must be equal to the number of non-zero entries",
113+
mName);
114+
checkArgument(mjc.length == nCols + 1,
115+
"length of the column index must be equal to the number of columns plus one", mName);
116+
checkArgument(Arrays.stream(mir).allMatch(i -> 0 <= i && i < nRows),
117+
"matrix %s: entries of the row index must be greater equal zero and less than the number of rows",
118+
mName);
119+
checkArgument(mjc[0] == 0 && mjc[mjc.length - 1] == nnz,
120+
"matrix %s: the first entry of the column index must be equal to zero and the last entry must be " +
121+
"equal to the number of non-zero entries", mName);
122+
checkArgument(IntStream.range(0, mjc.length - 1).allMatch(i ->
123+
0 <= mjc[i] && mjc[i] <= nnz && mjc[i] <= mjc[i + 1]
124+
&& IntStream.range(toIntExact(mjc[i]), toIntExact(mjc[i + 1]) - 1).allMatch(j -> mir[j] < mir[j + 1])),
125+
"matrix %s: entries of the column index must be greater equal zero, less equal than the number of " +
126+
"non-zero entries, and must be ordered, entries of the row index within each column must be " +
127+
"strictly ordered", mName);
138128
}
139129

140130
/**
131+
* Set up the {@link Model} data.
132+
* <p>
141133
* Same as
142134
* {@link Model#setup(long l, long[] q, long nExC, double[] gpr, long[] gjc, long[] gir, double[] c, double[] h, double[] apr, long[] ajc, long[] air, double[] b)}
143135
* without equality constraint, i.e. {@code apr}, {@code ajc}, {@code air}, and {@code b} are empty arrays.
@@ -156,6 +148,60 @@ public void setup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr,
156148
setup(l, q, nExC, gpr, gjc, gir, c, h, null, null, null, null);
157149
}
158150

151+
/**
152+
* Unsafe set up the {@link Model} data.
153+
* <p>
154+
* Same as
155+
* {@link Model#setup(long l, long[] q, long nExC, double[] gpr, long[] gjc, long[] gir, double[] c, double[] h, double[] apr, long[] ajc, long[] air, double[] b)}
156+
* without any precondition checks on its arguments.
157+
* <p>
158+
* <b>Warning: Setting the arguments incorrectly may lead to incorrect results in the best case. In the worst
159+
* case, it can crash the JVM and may silently result in memory corruption.</b>
160+
*
161+
* @param l the dimension of the positive orthant.
162+
* @param q the dimensions of the second-order cones.
163+
* @param nExC the number of exponential cones.
164+
* @param gpr the sparse G matrix data (Column Compressed Storage CCS).
165+
* @param gjc the sparse G matrix column index (CCS).
166+
* @param gir the sparse G matrix row index (CCS).
167+
* @param c the cost function weights.
168+
* @param h the right-hand-side of the cone constraints.
169+
* @param apr the (optional) sparse A matrix data (CCS).
170+
* @param ajc the (optional) sparse A matrix column index (CCS).
171+
* @param air the (optional) sparse A matrix row index (CCS).
172+
* @param b the (optional) right-hand-side of the equalities.
173+
*/
174+
public void unsafeSetup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr, long @NonNull [] gjc,
175+
long @NonNull [] gir, double @NonNull [] c, double @NonNull [] h, double[] apr, long[] ajc,
176+
long[] air, double[] b) {
177+
checkState(stage == Stage.NEW, "model must be in stage new");
178+
179+
n = c.length;
180+
m = h.length;
181+
p = apr != null ? b.length : 0;
182+
val nCones = q.length;
183+
184+
val qSeg = arena.allocateFrom(C_LONG_LONG, q);
185+
val gprSeg = arena.allocateFrom(C_DOUBLE, gpr);
186+
val gjcSeg = arena.allocateFrom(C_LONG_LONG, gjc);
187+
val girSeg = arena.allocateFrom(C_LONG_LONG, gir);
188+
val aprSeg = apr != null ? arena.allocateFrom(C_DOUBLE, apr) : NULL;
189+
val ajcSeg = ajc != null ? arena.allocateFrom(C_LONG_LONG, ajc) : NULL;
190+
val airSeg = air != null ? arena.allocateFrom(C_LONG_LONG, air) : NULL;
191+
val cSeg = arena.allocateFrom(C_DOUBLE, c);
192+
val hSeg = arena.allocateFrom(C_DOUBLE, h);
193+
val bSeg = b != null ? arena.allocateFrom(C_DOUBLE, b) : NULL;
194+
195+
workSeg = ECOS_setup(n, m, p, l, nCones, qSeg, nExC, gprSeg, gjcSeg, girSeg, aprSeg, ajcSeg, airSeg, cSeg,
196+
hSeg, bSeg).reinterpret(pwork.sizeof(), arena, null);
197+
verify(!NULL.equals(workSeg));
198+
199+
stgsSeg = pwork.stgs(workSeg).reinterpret(settings.sizeof(), arena, null);
200+
infoSeg = pwork.info(workSeg).reinterpret(stats.sizeof(), arena, null);
201+
202+
stage = Stage.SETUP;
203+
}
204+
159205
/**
160206
* Sets the <a href="https://github.com/embotech/ecos">ECOS</a> solver options.
161207
* <p>
@@ -164,7 +210,7 @@ public void setup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr,
164210
* @param parameters the parameter object for the solver options.
165211
*/
166212
public void setParameters(@NonNull Parameters parameters) {
167-
checkState(stage != Stage.NEW, "Model must not be in stage new");
213+
checkState(stage != Stage.NEW, "model must not be in stage new");
168214

169215
Optional.ofNullable(parameters.feasTol())
170216
.ifPresent(feasTol -> settings.feastol(stgsSeg, feasTol));
@@ -192,7 +238,7 @@ public void setParameters(@NonNull Parameters parameters) {
192238
* @return the solver status.
193239
*/
194240
public Status optimize() {
195-
checkState(stage != Stage.NEW, "Model must not be in stage new");
241+
checkState(stage != Stage.NEW, "model must not be in stage new");
196242

197243
val status = ECOS_solve(workSeg);
198244
if (settings.verbose(stgsSeg) == 1) {
@@ -209,7 +255,7 @@ public Status optimize() {
209255
* Cleanup: free this {@link Model} native memory.
210256
*/
211257
public void cleanup() {
212-
checkState(stage != Stage.NEW, "Model must not be in stage new");
258+
checkState(stage != Stage.NEW, "model must not be in stage new");
213259
ECOS_cleanup(workSeg, 0);
214260
stage = Stage.NEW;
215261
}
@@ -349,6 +395,11 @@ public long iter() {
349395
return pwork.s(workSeg).reinterpret(C_DOUBLE.byteSize() * m, arena, null).toArray(C_DOUBLE);
350396
}
351397

398+
private void checkStageIsOptimizedAndStatusIsNotFatal() {
399+
checkState(stage == Stage.OPTIMIZED, "model must be in stage optimized");
400+
checkState(status != Status.FATAL, "solver status must not be fatal");
401+
}
402+
352403
@Override
353404
public void close() {
354405
if (stage != Stage.NEW) {
@@ -357,9 +408,4 @@ public void close() {
357408
arena.close();
358409
}
359410

360-
private void checkStageIsOptimizedAndStatusIsNotFatal() {
361-
checkState(stage == Stage.OPTIMIZED, "Model must be in stage optimized");
362-
checkState(status != Status.FATAL, "Solver status must not be fatal");
363-
}
364-
365411
}

src/test/java/com/ustermetrics/ecos4j/ModelTest.java

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,21 @@ void solvePortfolioOptimizationProblemReturnsExpectedSolution() {
7777
}
7878
}
7979

80+
@Test
81+
void solvePortfolioOptimizationProblemWithVerboseParameterTrueReturnsOptimal() {
82+
try (val model = new Model()) {
83+
model.setup(l, q, nExC, gpr, gjc, gir, c, h, apr, ajc, air, b);
84+
val parameters = Parameters.builder()
85+
.verbose(true)
86+
.build();
87+
model.setParameters(parameters);
88+
89+
val status = model.optimize();
90+
91+
assertEquals(Status.OPTIMAL, status);
92+
}
93+
}
94+
8095
@Test
8196
void solveModifiedPortfolioOptimizationProblemWithMaxitLimitReturnsMaxitStatus() {
8297
try (val model = new Model()) {
@@ -102,7 +117,7 @@ void optimizeWithoutSetupThrowsException() {
102117
}
103118
});
104119

105-
assertEquals("Model must not be in stage new", exception.getMessage());
120+
assertEquals("model must not be in stage new", exception.getMessage());
106121
}
107122

108123
@Test
@@ -113,7 +128,7 @@ void getPrimalVariablesWithoutSetupThrowsException() {
113128
}
114129
});
115130

116-
assertEquals("Model must be in stage optimized", exception.getMessage());
131+
assertEquals("model must be in stage optimized", exception.getMessage());
117132
}
118133

119134
@Test
@@ -124,7 +139,7 @@ void cleanupWithoutSetupThrowsException() {
124139
}
125140
});
126141

127-
assertEquals("Model must not be in stage new", exception.getMessage());
142+
assertEquals("model must not be in stage new", exception.getMessage());
128143
}
129144

130145
@Test
@@ -135,7 +150,7 @@ void setParametersWithoutSetupThrowsException() {
135150
}
136151
});
137152

138-
assertEquals("Model must not be in stage new", exception.getMessage());
153+
assertEquals("model must not be in stage new", exception.getMessage());
139154
}
140155

141156
@Test
@@ -149,7 +164,7 @@ void setupAfterOptimizeThrowsException() {
149164
}
150165
});
151166

152-
assertEquals("Model must be in stage new", exception.getMessage());
167+
assertEquals("model must be in stage new", exception.getMessage());
153168
}
154169

155170
@Test
@@ -160,7 +175,7 @@ void setupWithInvalidPositiveOrthantDimensionThrowsException() {
160175
}
161176
});
162177

163-
assertEquals("dimension of the positive orthant l must be non-negative", exception.getMessage());
178+
assertEquals("dimension of the positive orthant must be non-negative", exception.getMessage());
164179
}
165180

166181
@Test
@@ -171,7 +186,7 @@ void setupWithInvalidNumberOfExponentialConesThrowsException() {
171186
}
172187
});
173188

174-
assertEquals("number of exponential cones nExC must be non-negative", exception.getMessage());
189+
assertEquals("number of exponential cones must be non-negative", exception.getMessage());
175190
}
176191

177192
@Test

0 commit comments

Comments
 (0)