|
30 | 30 | import org.apache.sysds.runtime.matrix.operators.UnaryOperator; |
31 | 31 | import org.apache.sysds.runtime.util.CommonThreadPool; |
32 | 32 |
|
33 | | -import java.util.concurrent.ExecutionException; |
34 | 33 | import java.util.concurrent.ExecutorService; |
35 | | -import java.util.concurrent.Future; |
36 | 34 |
|
37 | 35 | public class UnaryOOCInstruction extends ComputationOOCInstruction { |
38 | | - private UnaryOperator _uop = null; |
| 36 | + private UnaryOperator _uop = null; |
39 | 37 |
|
40 | | - protected UnaryOOCInstruction(OOCType type, UnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) { |
41 | | - super(type, op, in1, out, opcode, istr); |
| 38 | + protected UnaryOOCInstruction(OOCType type, UnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) { |
| 39 | + super(type, op, in1, out, opcode, istr); |
42 | 40 |
|
43 | | - _uop = op; |
44 | | - } |
| 41 | + _uop = op; |
| 42 | + } |
45 | 43 |
|
46 | | - public static UnaryOOCInstruction parseInstruction(String str) { |
47 | | - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
48 | | - InstructionUtils.checkNumFields(parts, 2); |
49 | | - String opcode = parts[0]; |
50 | | - CPOperand in1 = new CPOperand(parts[1]); |
51 | | - CPOperand out = new CPOperand(parts[2]); |
| 44 | + public static UnaryOOCInstruction parseInstruction(String str) { |
| 45 | + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
| 46 | + InstructionUtils.checkNumFields(parts, 2); |
| 47 | + String opcode = parts[0]; |
| 48 | + CPOperand in1 = new CPOperand(parts[1]); |
| 49 | + CPOperand out = new CPOperand(parts[2]); |
52 | 50 |
|
53 | | - UnaryOperator uopcode = InstructionUtils.parseUnaryOperator(opcode); |
54 | | - return new UnaryOOCInstruction(OOCType.Unary, uopcode, in1, out, opcode, str); |
55 | | - } |
| 51 | + UnaryOperator uopcode = InstructionUtils.parseUnaryOperator(opcode); |
| 52 | + return new UnaryOOCInstruction(OOCType.Unary, uopcode, in1, out, opcode, str); |
| 53 | + } |
56 | 54 |
|
57 | | - public void processInstruction( ExecutionContext ec ) { |
58 | | - UnaryOperator uop = (UnaryOperator) _uop; |
59 | | - // Create thread and process the unary operation |
60 | | - MatrixObject min = ec.getMatrixObject(input1); |
61 | | - LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle(); |
62 | | - LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>(); |
63 | | - ec.getMatrixObject(output).setStreamHandle(qOut); |
| 55 | + public void processInstruction( ExecutionContext ec ) { |
| 56 | + UnaryOperator uop = (UnaryOperator) _uop; |
| 57 | + // Create thread and process the unary operation |
| 58 | + MatrixObject min = ec.getMatrixObject(input1); |
| 59 | + LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle(); |
| 60 | + LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>(); |
| 61 | + ec.getMatrixObject(output).setStreamHandle(qOut); |
64 | 62 |
|
65 | 63 |
|
66 | | - ExecutorService pool = CommonThreadPool.get(); |
67 | | - try { |
68 | | - Future<?> task =pool.submit(() -> { |
69 | | - IndexedMatrixValue tmp = null; |
70 | | - try { |
71 | | - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { |
72 | | - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); |
73 | | - tmpOut.set(tmp.getIndexes(), |
74 | | - tmp.getValue().unaryOperations(uop, new MatrixBlock())); |
75 | | - qOut.enqueueTask(tmpOut); |
76 | | - } |
77 | | - qOut.closeInput(); |
78 | | - } |
79 | | - catch(Exception ex) { |
80 | | - throw new DMLRuntimeException(ex); |
81 | | - } |
82 | | - }); |
83 | | - task.get(); |
84 | | - } catch (ExecutionException | InterruptedException e) { |
85 | | - throw new RuntimeException(e); |
86 | | - } finally { |
87 | | - pool.shutdown(); |
88 | | - } |
89 | | - } |
| 64 | + ExecutorService pool = CommonThreadPool.get(); |
| 65 | + try { |
| 66 | + pool.submit(() -> { |
| 67 | + IndexedMatrixValue tmp = null; |
| 68 | + try { |
| 69 | + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { |
| 70 | + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); |
| 71 | + tmpOut.set(tmp.getIndexes(), |
| 72 | + tmp.getValue().unaryOperations(uop, new MatrixBlock())); |
| 73 | + qOut.enqueueTask(tmpOut); |
| 74 | + } |
| 75 | + qOut.closeInput(); |
| 76 | + } |
| 77 | + catch(Exception ex) { |
| 78 | + throw new DMLRuntimeException(ex); |
| 79 | + } |
| 80 | + }); |
| 81 | + } catch (Exception ex) { |
| 82 | + throw new DMLRuntimeException(ex); |
| 83 | + } finally { |
| 84 | + pool.shutdown(); |
| 85 | + } |
| 86 | + } |
90 | 87 | } |
0 commit comments