Skip to content

Commit ca8d209

Browse files
committed
[SYSTEMDS-3894] New out-of-core binary scalar-matrix operations
This patch completes the selected example operations for the new out-of-core backend and related test.
1 parent a6faf44 commit ca8d209

File tree

4 files changed

+121
-12
lines changed

4 files changed

+121
-12
lines changed

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -854,9 +854,6 @@ else if( (op == OpOp2.CBIND && getDataType().isList())
854854
_etype = ExecType.CP;
855855
}
856856

857-
if( _etype == ExecType.OOC ) //TODO
858-
setExecType(ExecType.CP);
859-
860857
//mark for recompile (forever)
861858
setRequiresRecompileIfNecessary();
862859

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.sysds.common.InstructionType;
2525
import org.apache.sysds.runtime.DMLRuntimeException;
2626
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
27+
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
2728
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
2829
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
2930

@@ -50,10 +51,9 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
5051
return ReblockOOCInstruction.parseInstruction(str);
5152
case AggregateUnary:
5253
return AggregateUnaryOOCInstruction.parseInstruction(str);
53-
54-
// TODO:
5554
case Binary:
56-
55+
return BinaryOOCInstruction.parseInstruction(str);
56+
5757
default:
5858
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
5959
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import java.util.concurrent.ExecutorService;
23+
24+
import org.apache.sysds.common.Types.DataType;
25+
import org.apache.sysds.runtime.DMLRuntimeException;
26+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
27+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
28+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
29+
import org.apache.sysds.runtime.instructions.InstructionUtils;
30+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
31+
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
32+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
33+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
34+
import org.apache.sysds.runtime.matrix.operators.Operator;
35+
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
36+
import org.apache.sysds.runtime.util.CommonThreadPool;
37+
38+
public class BinaryOOCInstruction extends ComputationOOCInstruction {
39+
40+
protected BinaryOOCInstruction(OOCType type, Operator bop,
41+
CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
42+
super(type, bop, in1, in2, out, opcode, istr);
43+
}
44+
45+
public static BinaryOOCInstruction parseInstruction(String str) {
46+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
47+
InstructionUtils.checkNumFields(parts, 3);
48+
String opcode = parts[0];
49+
CPOperand in1 = new CPOperand(parts[1]);
50+
CPOperand in2 = new CPOperand(parts[2]);
51+
CPOperand out = new CPOperand(parts[3]);
52+
Operator bop = InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2);
53+
54+
return new BinaryOOCInstruction(
55+
OOCType.Binary, bop, in1, in2, out, opcode, str);
56+
}
57+
58+
@Override
59+
public void processInstruction( ExecutionContext ec ) {
60+
//TODO support all types, currently only binary matrix-scalar
61+
62+
//get operator and scalar
63+
CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) ? input2 : input1;
64+
ScalarObject constant = ec.getScalarInput(scalar);
65+
ScalarOperator sc_op = ((ScalarOperator)_optr).setConstant(constant.getDoubleValue());
66+
67+
//create thread and process binary operation
68+
MatrixObject min = ec.getMatrixObject(input1);
69+
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
70+
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
71+
ec.getMatrixObject(output).setStreamHandle(qOut);
72+
73+
ExecutorService pool = CommonThreadPool.get();
74+
try {
75+
pool.submit(() -> {
76+
IndexedMatrixValue tmp = null;
77+
try {
78+
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
79+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
80+
tmpOut.set(tmp.getIndexes(),
81+
tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
82+
qOut.enqueueTask(tmpOut);
83+
}
84+
qOut.closeInput();
85+
}
86+
catch(Exception ex) {
87+
throw new DMLRuntimeException(ex);
88+
}
89+
});
90+
}
91+
finally {
92+
pool.shutdown();
93+
}
94+
}
95+
}

src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.sysds.common.Types;
2424
import org.apache.sysds.common.Types.FileFormat;
2525
import org.apache.sysds.common.Types.ValueType;
26+
import org.apache.sysds.hops.OptimizerUtils;
2627
import org.apache.sysds.runtime.instructions.Instruction;
2728
import org.apache.sysds.runtime.io.MatrixWriter;
2829
import org.apache.sysds.runtime.io.MatrixWriterFactory;
@@ -57,11 +58,26 @@ public void setUp() {
5758
* Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
5859
*/
5960
@Test
60-
public void testSumScalarMult() {
61-
61+
public void testSumScalarMultNoRewrite() {
62+
testSumScalarMult(false);
63+
}
64+
65+
/**
66+
* Test the sum of scalar multiplication, "sum(X)*7", with OOC backend.
67+
*/
68+
@Test
69+
public void testSumScalarMultRewrite() {
70+
testSumScalarMult(true);
71+
}
72+
73+
74+
public void testSumScalarMult(boolean rewrite)
75+
{
6276
Types.ExecMode platformOld = rtplatform;
6377
rtplatform = Types.ExecMode.SINGLE_NODE;
64-
78+
boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
79+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
80+
6581
try {
6682
getAndLoadTestConfiguration(TEST_NAME);
6783
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -92,16 +108,17 @@ public void testSumScalarMult() {
92108
String prefix = Instruction.OOC_INST_PREFIX;
93109
Assert.assertTrue("OOC wasn't used for RBLK",
94110
heavyHittersContainsString(prefix + Opcodes.RBLK));
111+
if(!rewrite)
112+
Assert.assertTrue("OOC wasn't used for SUM",
113+
heavyHittersContainsString(prefix + Opcodes.MULT));
95114
Assert.assertTrue("OOC wasn't used for SUM",
96115
heavyHittersContainsString(prefix + Opcodes.UAKP));
97-
98-
// boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
99-
// Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult);
100116
}
101117
catch(Exception ex) {
102118
Assert.fail(ex.getMessage());
103119
}
104120
finally {
121+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
105122
resetExecMode(platformOld);
106123
}
107124
}

0 commit comments

Comments
 (0)