Skip to content

Commit 9e23ad8

Browse files
j143mboehm7
authored andcommitted
[SYSTEMDS-3899] Fix incorrect barrier in unary OOC operations
Closes #2306.
1 parent a80e3dc commit 9e23ad8

File tree

1 file changed

+44
-47
lines changed

1 file changed

+44
-47
lines changed

src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java

Lines changed: 44 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,61 +30,58 @@
3030
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
3131
import org.apache.sysds.runtime.util.CommonThreadPool;
3232

33-
import java.util.concurrent.ExecutionException;
3433
import java.util.concurrent.ExecutorService;
35-
import java.util.concurrent.Future;
3634

3735
public class UnaryOOCInstruction extends ComputationOOCInstruction {
38-
private UnaryOperator _uop = null;
36+
private UnaryOperator _uop = null;
3937

40-
protected UnaryOOCInstruction(OOCType type, UnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) {
41-
super(type, op, in1, out, opcode, istr);
38+
protected UnaryOOCInstruction(OOCType type, UnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) {
39+
super(type, op, in1, out, opcode, istr);
4240

43-
_uop = op;
44-
}
41+
_uop = op;
42+
}
4543

46-
public static UnaryOOCInstruction parseInstruction(String str) {
47-
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
48-
InstructionUtils.checkNumFields(parts, 2);
49-
String opcode = parts[0];
50-
CPOperand in1 = new CPOperand(parts[1]);
51-
CPOperand out = new CPOperand(parts[2]);
44+
public static UnaryOOCInstruction parseInstruction(String str) {
45+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
46+
InstructionUtils.checkNumFields(parts, 2);
47+
String opcode = parts[0];
48+
CPOperand in1 = new CPOperand(parts[1]);
49+
CPOperand out = new CPOperand(parts[2]);
5250

53-
UnaryOperator uopcode = InstructionUtils.parseUnaryOperator(opcode);
54-
return new UnaryOOCInstruction(OOCType.Unary, uopcode, in1, out, opcode, str);
55-
}
51+
UnaryOperator uopcode = InstructionUtils.parseUnaryOperator(opcode);
52+
return new UnaryOOCInstruction(OOCType.Unary, uopcode, in1, out, opcode, str);
53+
}
5654

57-
public void processInstruction( ExecutionContext ec ) {
58-
UnaryOperator uop = (UnaryOperator) _uop;
59-
// Create thread and process the unary operation
60-
MatrixObject min = ec.getMatrixObject(input1);
61-
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
62-
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
63-
ec.getMatrixObject(output).setStreamHandle(qOut);
55+
public void processInstruction( ExecutionContext ec ) {
56+
UnaryOperator uop = (UnaryOperator) _uop;
57+
// Create thread and process the unary operation
58+
MatrixObject min = ec.getMatrixObject(input1);
59+
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
60+
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
61+
ec.getMatrixObject(output).setStreamHandle(qOut);
6462

6563

66-
ExecutorService pool = CommonThreadPool.get();
67-
try {
68-
Future<?> task =pool.submit(() -> {
69-
IndexedMatrixValue tmp = null;
70-
try {
71-
while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
72-
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
73-
tmpOut.set(tmp.getIndexes(),
74-
tmp.getValue().unaryOperations(uop, new MatrixBlock()));
75-
qOut.enqueueTask(tmpOut);
76-
}
77-
qOut.closeInput();
78-
}
79-
catch(Exception ex) {
80-
throw new DMLRuntimeException(ex);
81-
}
82-
});
83-
task.get();
84-
} catch (ExecutionException | InterruptedException e) {
85-
throw new RuntimeException(e);
86-
} finally {
87-
pool.shutdown();
88-
}
89-
}
64+
ExecutorService pool = CommonThreadPool.get();
65+
try {
66+
pool.submit(() -> {
67+
IndexedMatrixValue tmp = null;
68+
try {
69+
while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
70+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
71+
tmpOut.set(tmp.getIndexes(),
72+
tmp.getValue().unaryOperations(uop, new MatrixBlock()));
73+
qOut.enqueueTask(tmpOut);
74+
}
75+
qOut.closeInput();
76+
}
77+
catch(Exception ex) {
78+
throw new DMLRuntimeException(ex);
79+
}
80+
});
81+
} catch (Exception ex) {
82+
throw new DMLRuntimeException(ex);
83+
} finally {
84+
pool.shutdown();
85+
}
86+
}
9087
}

0 commit comments

Comments
 (0)