Skip to content

Commit 88bc2b9

Browse files
committed
Scope failures locally
1 parent 0324985 commit 88bc2b9

File tree

3 files changed

+13 -29 lines changed

3 files changed

+13 -29 lines changed

src/main/java/org/apache/sysds/api/DMLScript.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,6 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri
444444

445445
// optionally register for monitoring
446446
registerForMonitoring();
447-
448-
// reset any errors from the LocalTaskQueue
449-
LocalTaskQueue.resetFailures();
450447

451448
//Step 1: parse configuration files & write any configuration specific global variables
452449
loadConfiguration(fnameOptConfig);

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ public class LocalTaskQueue<T>
4242

4343
public static final int MAX_SIZE = 100000; //main memory constraint
4444
public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS
45-
private static volatile DMLRuntimeException FAILURE = null;
4645

4746
private LinkedList<T> _data = null;
4847
private boolean _closedInput = false;
48+
private DMLRuntimeException _failure = null;
4949
private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName());
5050

5151
public LocalTaskQueue()
@@ -63,14 +63,14 @@ public LocalTaskQueue()
6363
public synchronized void enqueueTask( T t )
6464
throws InterruptedException
6565
{
66-
while( _data.size() + 1 > MAX_SIZE && FAILURE == null )
66+
while( _data.size() + 1 > MAX_SIZE && _failure == null )
6767
{
6868
LOG.warn("MAX_SIZE of task queue reached.");
6969
wait(); //max constraint reached, wait for read
7070
}
7171

72-
if ( FAILURE != null )
73-
throw FAILURE;
72+
if ( _failure != null )
73+
throw _failure;
7474

7575
_data.addLast( t );
7676

@@ -87,16 +87,16 @@ public synchronized void enqueueTask( T t )
8787
public synchronized T dequeueTask()
8888
throws InterruptedException
8989
{
90-
while( _data.isEmpty() && FAILURE == null )
90+
while( _data.isEmpty() && _failure == null )
9191
{
9292
if( !_closedInput )
9393
wait(); // wait for writers
9494
else
9595
return (T)NO_MORE_TASKS;
9696
}
9797

98-
if ( FAILURE != null )
99-
throw FAILURE;
98+
if ( _failure != null )
99+
throw _failure;
100100

101101
T t = _data.removeFirst();
102102

@@ -119,8 +119,11 @@ public synchronized boolean isProcessed() {
119119
return _closedInput && _data.isEmpty();
120120
}
121121

122-
public synchronized void notifyFailure() {
123-
notifyAll();
122+
public synchronized void propagateFailure(DMLRuntimeException failure) {
123+
if (_failure == null) {
124+
_failure = failure;
125+
notifyAll();
126+
}
124127
}
125128

126129
@Override
@@ -147,18 +150,4 @@ public synchronized String toString()
147150

148151
return sb.toString();
149152
}
150-
151-
public static boolean failGlobally(DMLRuntimeException ex) {
152-
// Only register the first failure
153-
if (FAILURE == null) {
154-
FAILURE = ex;
155-
return true;
156-
}
157-
158-
return false;
159-
}
160-
161-
public static void resetFailures() {
162-
FAILURE = null;
163-
}
164153
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,8 @@ private Runnable oocTask(Runnable r, LocalTaskQueue<?>... queues) {
111111
catch (Exception ex) {
112112
DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
113113

114-
LocalTaskQueue.failGlobally(re);
115-
116114
for (LocalTaskQueue<?> q : queues) {
117-
q.notifyFailure();
115+
q.propagateFailure(re);
118116
}
119117
}
120118
};

0 commit comments

Comments
 (0)