Skip to content

Commit a5b298c

Browse files
j143mboehm7
authored andcommitted
[SYSTEMDS-3899] Unary out-of-core-operations
Closes #2298.
1 parent 0b11ef7 commit a5b298c

File tree

5 files changed

+237
-1
lines changed

5 files changed

+237
-1
lines changed

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
2828
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
2929
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
30+
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
3031

3132
public class OOCInstructionParser extends InstructionParser {
3233
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -51,6 +52,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
5152
return ReblockOOCInstruction.parseInstruction(str);
5253
case AggregateUnary:
5354
return AggregateUnaryOOCInstruction.parseInstruction(str);
55+
case Unary:
56+
return UnaryOOCInstruction.parseInstruction(str);
5457
case Binary:
5558
return BinaryOOCInstruction.parseInstruction(str);
5659

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction {
3030
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
3131

3232
public enum OOCType {
33-
Reblock, AggregateUnary, Binary
33+
Reblock, AggregateUnary, Binary, Unary
3434
}
3535

3636
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.sysds.runtime.DMLRuntimeException;
23+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
24+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
25+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
26+
import org.apache.sysds.runtime.instructions.InstructionUtils;
27+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
28+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
29+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
30+
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
31+
import org.apache.sysds.runtime.util.CommonThreadPool;
32+
33+
import java.util.concurrent.ExecutionException;
34+
import java.util.concurrent.ExecutorService;
35+
import java.util.concurrent.Future;
36+
37+
public class UnaryOOCInstruction extends ComputationOOCInstruction {
38+
private UnaryOperator _uop = null;
39+
40+
protected UnaryOOCInstruction(OOCType type, UnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) {
41+
super(type, op, in1, out, opcode, istr);
42+
43+
_uop = op;
44+
}
45+
46+
public static UnaryOOCInstruction parseInstruction(String str) {
47+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
48+
InstructionUtils.checkNumFields(parts, 2);
49+
String opcode = parts[0];
50+
CPOperand in1 = new CPOperand(parts[1]);
51+
CPOperand out = new CPOperand(parts[2]);
52+
53+
UnaryOperator uopcode = InstructionUtils.parseUnaryOperator(opcode);
54+
return new UnaryOOCInstruction(OOCType.Unary, uopcode, in1, out, opcode, str);
55+
}
56+
57+
public void processInstruction( ExecutionContext ec ) {
58+
UnaryOperator uop = (UnaryOperator) _uop;
59+
// Create thread and process the unary operation
60+
MatrixObject min = ec.getMatrixObject(input1);
61+
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
62+
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
63+
ec.getMatrixObject(output).setStreamHandle(qOut);
64+
65+
66+
ExecutorService pool = CommonThreadPool.get();
67+
try {
68+
Future<?> task =pool.submit(() -> {
69+
IndexedMatrixValue tmp = null;
70+
try {
71+
while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
72+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
73+
tmpOut.set(tmp.getIndexes(),
74+
tmp.getValue().unaryOperations(uop, new MatrixBlock()));
75+
qOut.enqueueTask(tmpOut);
76+
}
77+
qOut.closeInput();
78+
}
79+
catch(Exception ex) {
80+
throw new DMLRuntimeException(ex);
81+
}
82+
});
83+
task.get();
84+
} catch (ExecutionException | InterruptedException e) {
85+
throw new RuntimeException(e);
86+
} finally {
87+
pool.shutdown();
88+
}
89+
}
90+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.common.Types.ExecMode;
25+
import org.apache.sysds.common.Types.FileFormat;
26+
import org.apache.sysds.common.Types.ValueType;
27+
import org.apache.sysds.hops.OptimizerUtils;
28+
import org.apache.sysds.runtime.instructions.Instruction;
29+
import org.apache.sysds.runtime.io.MatrixWriter;
30+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
31+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
32+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
33+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
34+
import org.apache.sysds.runtime.util.HDFSTool;
35+
import org.apache.sysds.test.AutomatedTestBase;
36+
import org.apache.sysds.test.TestConfiguration;
37+
import org.apache.sysds.test.TestUtils;
38+
import org.junit.Assert;
39+
import org.junit.Test;
40+
41+
import java.util.HashMap;
42+
43+
public class UnaryTest extends AutomatedTestBase {
44+
45+
private static final String TEST_NAME = "Unary";
46+
private static final String TEST_DIR = "functions/ooc/";
47+
private static final String TEST_CLASS_DIR = TEST_DIR + UnaryTest.class.getSimpleName() + "/";
48+
private static final String INPUT_NAME = "X";
49+
private static final String OUTPUT_NAME = "res";
50+
51+
@Override
52+
public void setUp() {
53+
TestUtils.clearAssertionInformation();
54+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
55+
addTestConfiguration(TEST_NAME, config);
56+
}
57+
58+
/**
59+
* Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
60+
*/
61+
@Test
62+
public void testUnary() {
63+
testUnaryOperation(false);
64+
}
65+
66+
67+
public void testUnaryOperation(boolean rewrite)
68+
{
69+
Types.ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
70+
boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
71+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
72+
73+
try {
74+
getAndLoadTestConfiguration(TEST_NAME);
75+
String HOME = SCRIPT_DIR + TEST_DIR;
76+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
77+
programArgs = new String[] {"-explain", "-stats", "-ooc",
78+
"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
79+
80+
int rows = 1000, cols = 1000;
81+
MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
82+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY);
83+
writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols);
84+
HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64,
85+
new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY);
86+
87+
runTest(true, false, null, -1);
88+
89+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME);
90+
Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1));
91+
double expected = 0.0;
92+
for(int i = 0; i < rows; i++) {
93+
for(int j = 0; j < cols; j++) {
94+
expected += Math.ceil(mb.get(i, j));
95+
}
96+
}
97+
98+
Assert.assertEquals(expected, result, 1e-10);
99+
100+
String prefix = Instruction.OOC_INST_PREFIX;
101+
Assert.assertTrue("OOC wasn't used for RBLK",
102+
heavyHittersContainsString(prefix + Opcodes.RBLK));
103+
Assert.assertTrue("OOC wasn't used for CEIL",
104+
heavyHittersContainsString(prefix + Opcodes.CEIL));
105+
}
106+
catch(Exception ex) {
107+
Assert.fail(ex.getMessage());
108+
}
109+
finally {
110+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
111+
resetExecMode(platformOld);
112+
}
113+
}
114+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read input matrix and operator from command line args
23+
X = read($1);
24+
#print(toString(X))
25+
Y = ceil(X);
26+
#print(toString(Y))
27+
res = as.matrix(sum(Y));
28+
# Write the final matrix result
29+
write(res, $2);

0 commit comments

Comments
 (0)