Skip to content

Commit d38e56c

Browse files
janniklindemboehm7
authored andcommitted
[SYSTEMDS-3923] Improve exception handling OOC instructions
Closes #2346.
1 parent 801d8e2 commit d38e56c

File tree

10 files changed

+199
-57
lines changed

10 files changed

+199
-57
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.apache.commons.logging.Log;
2525
import org.apache.commons.logging.LogFactory;
26+
import org.apache.sysds.runtime.DMLRuntimeException;
2627

2728
/**
2829
* This class provides a way of dynamic task distribution to multiple workers
@@ -43,7 +44,8 @@ public class LocalTaskQueue<T>
4344
public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS
4445

4546
private LinkedList<T> _data = null;
46-
private boolean _closedInput = false;
47+
private boolean _closedInput = false;
48+
private DMLRuntimeException _failure = null;
4749
private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName());
4850

4951
public LocalTaskQueue()
@@ -61,11 +63,14 @@ public LocalTaskQueue()
6163
public synchronized void enqueueTask( T t )
6264
throws InterruptedException
6365
{
64-
while( _data.size() + 1 > MAX_SIZE )
66+
while( _data.size() + 1 > MAX_SIZE && _failure == null )
6567
{
6668
LOG.warn("MAX_SIZE of task queue reached.");
6769
wait(); //max constraint reached, wait for read
6870
}
71+
72+
if ( _failure != null )
73+
throw _failure;
6974

7075
_data.addLast( t );
7176

@@ -82,13 +87,16 @@ public synchronized void enqueueTask( T t )
8287
public synchronized T dequeueTask()
8388
throws InterruptedException
8489
{
85-
while( _data.isEmpty() )
90+
while( _data.isEmpty() && _failure == null )
8691
{
8792
if( !_closedInput )
8893
wait(); // wait for writers
8994
else
9095
return (T)NO_MORE_TASKS;
9196
}
97+
98+
if ( _failure != null )
99+
throw _failure;
92100

93101
T t = _data.removeFirst();
94102

@@ -111,6 +119,13 @@ public synchronized boolean isProcessed() {
111119
return _closedInput && _data.isEmpty();
112120
}
113121

122+
public synchronized void propagateFailure(DMLRuntimeException failure) {
123+
if (_failure == null) {
124+
_failure = failure;
125+
notifyAll();
126+
}
127+
}
128+
114129
@Override
115130
public synchronized String toString()
116131
{

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,8 @@ public void processInstruction( ExecutionContext ec ) {
9090

9191
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
9292
ec.getMatrixObject(output).setStreamHandle(qOut);
93-
ExecutorService pool = CommonThreadPool.get();
94-
try {
95-
pool.submit(() -> {
93+
94+
submitOOCTask(() -> {
9695
IndexedMatrixValue tmp = null;
9796
try {
9897
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
@@ -152,12 +151,7 @@ public void processInstruction( ExecutionContext ec ) {
152151
catch(Exception ex) {
153152
throw new DMLRuntimeException(ex);
154153
}
155-
});
156-
} catch (Exception ex) {
157-
throw new DMLRuntimeException(ex);
158-
} finally {
159-
pool.shutdown();
160-
}
154+
}, q, qOut);
161155
}
162156
// full aggregation
163157
else {

src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ public void processInstruction( ExecutionContext ec ) {
7070
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
7171
ec.getMatrixObject(output).setStreamHandle(qOut);
7272

73-
ExecutorService pool = CommonThreadPool.get();
74-
try {
75-
pool.submit(() -> {
73+
submitOOCTask(() -> {
7674
IndexedMatrixValue tmp = null;
7775
try {
7876
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
@@ -86,10 +84,6 @@ public void processInstruction( ExecutionContext ec ) {
8684
catch(Exception ex) {
8785
throw new DMLRuntimeException(ex);
8886
}
89-
});
90-
}
91-
finally {
92-
pool.shutdown();
93-
}
87+
}, qIn, qOut);
9488
}
9589
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,7 @@ public void processInstruction( ExecutionContext ec ) {
9090
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
9191
ec.getMatrixObject(output).setStreamHandle(qOut);
9292

93-
ExecutorService pool = CommonThreadPool.get();
94-
try {
95-
// Core logic: background thread
96-
pool.submit(() -> {
93+
submitOOCTask(() -> {
9794
IndexedMatrixValue tmp = null;
9895
try {
9996
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
@@ -134,12 +131,6 @@ public void processInstruction( ExecutionContext ec ) {
134131
finally {
135132
qOut.closeInput();
136133
}
137-
});
138-
} catch (Exception e) {
139-
throw new DMLRuntimeException(e);
140-
}
141-
finally {
142-
pool.shutdown();
143-
}
134+
}, qIn, qOut);
144135
}
145136
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@
2222
import org.apache.commons.logging.Log;
2323
import org.apache.commons.logging.LogFactory;
2424
import org.apache.sysds.api.DMLScript;
25+
import org.apache.sysds.runtime.DMLRuntimeException;
2526
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
27+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
2628
import org.apache.sysds.runtime.instructions.Instruction;
2729
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2830
import org.apache.sysds.runtime.matrix.operators.Operator;
31+
import org.apache.sysds.runtime.util.CommonThreadPool;
2932

3033
import java.util.HashMap;
34+
import java.util.concurrent.ExecutorService;
3135

3236
public abstract class OOCInstruction extends Instruction {
3337
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
@@ -86,6 +90,37 @@ public void postprocessInstruction(ExecutionContext ec) {
8690
ec.maintainLineageDebuggerInfo(this);
8791
}
8892

93+
protected void submitOOCTask(Runnable r, LocalTaskQueue<?>... queues) {
94+
ExecutorService pool = CommonThreadPool.get();
95+
try {
96+
pool.submit(oocTask(r, queues));
97+
}
98+
catch (Exception ex) {
99+
throw new DMLRuntimeException(ex);
100+
}
101+
finally {
102+
pool.shutdown();
103+
}
104+
}
105+
106+
private Runnable oocTask(Runnable r, LocalTaskQueue<?>... queues) {
107+
return () -> {
108+
try {
109+
r.run();
110+
}
111+
catch (Exception ex) {
112+
DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
113+
114+
for (LocalTaskQueue<?> q : queues) {
115+
q.propagateFailure(re);
116+
}
117+
118+
// Rethrow to ensure proper future handling
119+
throw re;
120+
}
121+
};
122+
}
123+
89124
/**
90125
* Tracks blocks and their counts to enable early emission
91126
* once all blocks for a given index are processed.

src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,7 @@ public void processInstruction(ExecutionContext ec) {
7979

8080
//create queue, spawn thread for asynchronous reading, and return
8181
LocalTaskQueue<IndexedMatrixValue> q = new LocalTaskQueue<IndexedMatrixValue>();
82-
ExecutorService pool = CommonThreadPool.get();
83-
try {
84-
pool.submit(() -> readBinaryBlock(q, min.getFileName()));
85-
}
86-
finally {
87-
pool.shutdown();
88-
}
82+
submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q);
8983

9084
MatrixObject mout = ec.getMatrixObject(output);
9185
mout.setStreamHandle(q);

src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,7 @@ public void processInstruction( ExecutionContext ec ) {
6060
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
6161
ec.getMatrixObject(output).setStreamHandle(qOut);
6262

63-
64-
ExecutorService pool = CommonThreadPool.get();
65-
try {
66-
pool.submit(() -> {
63+
submitOOCTask(() -> {
6764
IndexedMatrixValue tmp = null;
6865
try {
6966
while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
@@ -79,11 +76,6 @@ public void processInstruction( ExecutionContext ec ) {
7976
catch(Exception ex) {
8077
throw new DMLRuntimeException(ex);
8178
}
82-
});
83-
} catch (Exception ex) {
84-
throw new DMLRuntimeException(ex);
85-
} finally {
86-
pool.shutdown();
87-
}
79+
}, qIn, qOut);
8880
}
8981
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ public void processInstruction( ExecutionContext ec ) {
6161
ec.getMatrixObject(output).setStreamHandle(qOut);
6262

6363

64-
ExecutorService pool = CommonThreadPool.get();
65-
try {
66-
pool.submit(() -> {
64+
submitOOCTask(() -> {
6765
IndexedMatrixValue tmp = null;
6866
try {
6967
while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
@@ -77,11 +75,6 @@ public void processInstruction( ExecutionContext ec ) {
7775
catch(Exception ex) {
7876
throw new DMLRuntimeException(ex);
7977
}
80-
});
81-
} catch (Exception ex) {
82-
throw new DMLRuntimeException(ex);
83-
} finally {
84-
pool.shutdown();
85-
}
78+
}, qIn, qOut);
8679
}
8780
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.runtime.io.MatrixWriter;
24+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
25+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
26+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
27+
import org.apache.sysds.runtime.util.DataConverter;
28+
import org.apache.sysds.runtime.util.HDFSTool;
29+
import org.apache.sysds.test.AutomatedTestBase;
30+
import org.apache.sysds.test.TestConfiguration;
31+
import org.apache.sysds.test.TestUtils;
32+
import org.junit.Test;
33+
34+
import java.io.IOException;
35+
36+
public class OOCExceptionHandlingTest extends AutomatedTestBase {
37+
private final static String TEST_NAME1 = "OOCExceptionHandling";
38+
private final static String TEST_DIR = "functions/ooc/";
39+
private final static String TEST_CLASS_DIR = TEST_DIR + OOCExceptionHandlingTest.class.getSimpleName() + "/";
40+
private static final String INPUT_NAME = "X";
41+
private static final String INPUT_NAME_2 = "Y";
42+
private static final String OUTPUT_NAME = "res";
43+
44+
private final static int rows = 1000;
45+
private final static int cols = 1000;
46+
47+
@Override
48+
public void setUp() {
49+
TestUtils.clearAssertionInformation();
50+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
51+
addTestConfiguration(TEST_NAME1, config);
52+
}
53+
54+
@Test
55+
public void runOOCExceptionHandlingTest1() {
56+
runOOCExceptionHandlingTest(500);
57+
}
58+
59+
@Test
60+
public void runOOCExceptionHandlingTest2() {
61+
runOOCExceptionHandlingTest(750);
62+
}
63+
64+
65+
private void runOOCExceptionHandlingTest(int misalignVals) {
66+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
67+
68+
try {
69+
getAndLoadTestConfiguration(TEST_NAME1);
70+
71+
String HOME = SCRIPT_DIR + TEST_DIR;
72+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
73+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), input(INPUT_NAME_2), output(OUTPUT_NAME)};
74+
75+
// 1. Generate the data in-memory as MatrixBlock objects
76+
double[][] A_data = getRandomMatrix(rows, cols, 1, 2, 1, 7);
77+
double[][] B_data = getRandomMatrix(rows, 1, 1, 2, 1, 7);
78+
79+
// 2. Convert the double arrays to MatrixBlock objects
80+
MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);
81+
MatrixBlock B_mb = DataConverter.convertToMatrixBlock(B_data);
82+
83+
// 3. Create a binary matrix writer
84+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
85+
86+
// 4. Write matrix A to a binary SequenceFile
87+
88+
// Here, we write two faulty matrices which will only be recognized at runtime
89+
writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, misalignVals, A_mb.getNonZeros());
90+
HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
91+
new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY);
92+
93+
writer.writeMatrixToHDFS(B_mb, input(INPUT_NAME_2), rows, 1, 1000, B_mb.getNonZeros());
94+
HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64,
95+
new MatrixCharacteristics(rows, 1, 1000, B_mb.getNonZeros()), Types.FileFormat.BINARY);
96+
97+
runTest(true, true, null, -1);
98+
}
99+
catch(IOException e) {
100+
throw new RuntimeException(e);
101+
}
102+
finally {
103+
resetExecMode(platformOld);
104+
}
105+
}
106+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read the input matrix as a stream
23+
X = read($1);
24+
b = read($2);
25+
26+
OOC = colSums(X %*% b);
27+
28+
write(OOC, $3, format="binary");

0 commit comments

Comments
 (0)