Skip to content

Commit bc3216a

Browse files
janniklinde authored and mboehm7 committed
[SYSTEMDS-3927] Out-of-core centralMoment operations
Closes #2339.
1 parent 3024817 commit bc3216a

File tree

10 files changed

+525
-2
lines changed

10 files changed

+525
-2
lines changed

src/main/java/org/apache/sysds/lops/CentralMoment.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ public String getInstructions(String input1, String input2, String input3, Strin
9797
getInputs().get(2).prepScalarInputOperand(getExecType()),
9898
prepOutputOperand(output));
9999
}
100-
if( getExecType() == ExecType.CP || getExecType() == ExecType.FED ) {
100+
if(getExecType() == ExecType.CP || getExecType() == ExecType.FED || getExecType() == ExecType.OOC) {
101101
sb.append(OPERAND_DELIMITOR);
102102
sb.append(_numThreads);
103103
if ( getExecType() == ExecType.FED ){

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.sysds.runtime.DMLRuntimeException;
2626
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
2727
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
28+
import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
2829
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
2930
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
3031
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
@@ -69,6 +70,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
6970
return TransposeOOCInstruction.parseInstruction(str);
7071
case Tee:
7172
return TeeOOCInstruction.parseInstruction(str);
73+
case CentralMoment:
74+
return CentralMomentOOCInstruction.parseInstruction(str);
7275

7376
default:
7477
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
3535
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
3636
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
37+
import org.apache.sysds.runtime.matrix.operators.Operator;
3738
import org.apache.sysds.runtime.meta.DataCharacteristics;
3839
import org.apache.sysds.runtime.util.CommonThreadPool;
3940

@@ -49,6 +50,12 @@ protected AggregateUnaryOOCInstruction(OOCType type, AggregateUnaryOperator auop
4950
_aop = aop;
5051
}
5152

53+
/**
 * Creates an aggregate-unary OOC instruction from a generic operator and up to
 * three input operands. All arguments are forwarded unchanged to the parent
 * constructor; the aggregate operator {@code _aop} is explicitly left unset
 * ({@code null}) for instructions created this way.
 *
 * @param type   OOC instruction type
 * @param op     operator of the instruction
 * @param in1    first input operand
 * @param in2    second input operand
 * @param in3    third input operand
 * @param out    output operand
 * @param opcode instruction opcode
 * @param istr   instruction string
 */
protected AggregateUnaryOOCInstruction(
	OOCType type, Operator op,
	CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
	String opcode, String istr)
{
	super(type, op, in1, in2, in3, out, opcode, istr);
	// no aggregate operator for this constructor variant
	_aop = null;
}
58+
5259
public static AggregateUnaryOOCInstruction parseInstruction(String str) {
5360
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
5461
InstructionUtils.checkNumFields(parts, 2);
src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.sysds.runtime.DMLRuntimeException;
23+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
24+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
25+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
26+
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
27+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
28+
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
29+
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
30+
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
31+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
32+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
33+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
34+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
35+
import org.apache.sysds.runtime.matrix.operators.CMOperator;
36+
import org.apache.sysds.runtime.meta.DataCharacteristics;
37+
38+
import java.util.ArrayList;
39+
import java.util.HashMap;
40+
import java.util.List;
41+
import java.util.Map;
42+
import java.util.Optional;
43+
44+
public class CentralMomentOOCInstruction extends AggregateUnaryOOCInstruction {
45+
46+
private CentralMomentOOCInstruction(CMOperator cm, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
47+
String opcode, String str) {
48+
super(OOCType.CM, cm, in1, in2, in3, out, opcode, str);
49+
}
50+
51+
public static CentralMomentOOCInstruction parseInstruction(String str) {
52+
CentralMomentCPInstruction cpInst = CentralMomentCPInstruction.parseInstruction(str);
53+
return parseInstruction(cpInst);
54+
}
55+
56+
public static CentralMomentOOCInstruction parseInstruction(CentralMomentCPInstruction inst) {
57+
return new CentralMomentOOCInstruction((CMOperator) inst.getOperator(), inst.input1, inst.input2, inst.input3,
58+
inst.output, inst.getOpcode(), inst.getInstructionString());
59+
}
60+
61+
@Override
62+
public void processInstruction(ExecutionContext ec) {
63+
String output_name = output.getName();
64+
65+
/*
66+
* The "order" of the central moment in the instruction can
67+
* be set to INVALID when the exact value is unknown at
68+
* compilation time. We first need to determine the exact
69+
* order and update the CMOperator, if needed.
70+
*/
71+
72+
MatrixObject matObj = ec.getMatrixObject(input1.getName());
73+
LocalTaskQueue<IndexedMatrixValue> qIn = matObj.getStreamHandle();
74+
75+
CPOperand scalarInput = (input3 == null ? input2 : input3);
76+
ScalarObject order = ec.getScalarInput(scalarInput);
77+
78+
CMOperator cm_op = ((CMOperator) _optr);
79+
if(cm_op.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID)
80+
cm_op = cm_op.setCMAggOp((int) order.getLongValue());
81+
82+
CMOperator finalCm_op = cm_op;
83+
84+
List<CM_COV_Object> cmObjs = new ArrayList<>();
85+
86+
if(input3 == null) {
87+
try {
88+
IndexedMatrixValue tmp;
89+
90+
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
91+
// We only handle MatrixBlock, other types of MatrixValue will fail here
92+
cmObjs.add(((MatrixBlock) tmp.getValue()).cmOperations(cm_op));
93+
}
94+
}
95+
catch(Exception ex) {
96+
throw new DMLRuntimeException(ex);
97+
}
98+
}
99+
else {
100+
// Here we use a hash join approach
101+
// Note that this may keep blocks in the cache for a while, depending on when a matching block arrives in the stream
102+
MatrixObject wtObj = ec.getMatrixObject(input2.getName());
103+
104+
DataCharacteristics dc = ec.getDataCharacteristics(input1.getName());
105+
DataCharacteristics dcW = ec.getDataCharacteristics(input2.getName());
106+
107+
if (dc.getBlocksize() != dcW.getBlocksize())
108+
throw new DMLRuntimeException("Different block sizes are not yet supported");
109+
110+
LocalTaskQueue<IndexedMatrixValue> wIn = wtObj.getStreamHandle();
111+
112+
try {
113+
IndexedMatrixValue tmp = qIn.dequeueTask();
114+
IndexedMatrixValue tmpW = wIn.dequeueTask();
115+
Map<MatrixIndexes, MatrixValue> left = new HashMap<>();
116+
Map<MatrixIndexes, MatrixValue> right = new HashMap<>();
117+
118+
boolean cont = tmp != LocalTaskQueue.NO_MORE_TASKS || tmpW != LocalTaskQueue.NO_MORE_TASKS;
119+
120+
while(cont) {
121+
cont = false;
122+
123+
if(tmp != LocalTaskQueue.NO_MORE_TASKS) {
124+
MatrixValue weights = right.remove(tmp.getIndexes());
125+
126+
if(weights != null)
127+
cmObjs.add(((MatrixBlock) tmp.getValue()).cmOperations(cm_op, (MatrixBlock) weights));
128+
else
129+
left.put(tmp.getIndexes(), tmp.getValue());
130+
131+
tmp = qIn.dequeueTask();
132+
cont = tmp != LocalTaskQueue.NO_MORE_TASKS;
133+
}
134+
135+
if(tmpW != LocalTaskQueue.NO_MORE_TASKS) {
136+
MatrixValue q = left.remove(tmpW.getIndexes());
137+
138+
if(q != null)
139+
cmObjs.add(((MatrixBlock) q).cmOperations(cm_op, (MatrixBlock) tmpW.getValue()));
140+
else
141+
right.put(tmpW.getIndexes(), tmpW.getValue());
142+
143+
tmpW = wIn.dequeueTask();
144+
cont |= tmpW != LocalTaskQueue.NO_MORE_TASKS;
145+
}
146+
}
147+
148+
if (!left.isEmpty() || !right.isEmpty())
149+
throw new DMLRuntimeException("Unmatched blocks: values=" + left.size() + ", weights=" + right.size());
150+
}
151+
catch(Exception ex) {
152+
throw new DMLRuntimeException(ex);
153+
}
154+
}
155+
156+
Optional<CM_COV_Object> res = cmObjs.stream()
157+
.reduce((arg0, arg1) -> (CM_COV_Object) finalCm_op.fn.execute(arg0, arg1));
158+
159+
try {
160+
ec.setScalarOutput(output_name, new DoubleObject(res.get().getRequiredResult(finalCm_op)));
161+
}
162+
catch(Exception ex) {
163+
throw new DMLRuntimeException(ex);
164+
}
165+
}
166+
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP
4242
output = out;
4343
}
4444

45+
/**
 * Creates a computation OOC instruction with three input operands and one
 * output operand. Type, operator, opcode and instruction string are passed to
 * the parent constructor; the operands are stored in the corresponding fields.
 *
 * @param type   OOC instruction type
 * @param op     operator of the instruction
 * @param in1    first input operand
 * @param in2    second input operand
 * @param in3    third input operand
 * @param out    output operand
 * @param opcode instruction opcode
 * @param istr   instruction string
 */
protected ComputationOOCInstruction(
	OOCType type, Operator op,
	CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
	String opcode, String istr)
{
	super(type, op, opcode, istr);
	this.input1 = in1;
	this.input2 = in2;
	this.input3 = in3;
	this.output = out;
}
52+
4553
public String getOutputVariableName() {
4654
return output.getName();
4755
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction {
3333
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
3434

3535
public enum OOCType {
36-
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg,
36+
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM
3737
}
3838

3939
protected final OOCInstruction.OOCType _ooctype;
src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.runtime.instructions.Instruction;
25+
import org.apache.sysds.runtime.io.MatrixWriter;
26+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
27+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
28+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
29+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
30+
import org.apache.sysds.runtime.util.DataConverter;
31+
import org.apache.sysds.runtime.util.HDFSTool;
32+
import org.apache.sysds.test.AutomatedTestBase;
33+
import org.apache.sysds.test.TestConfiguration;
34+
import org.apache.sysds.test.TestUtils;
35+
import org.junit.Assert;
36+
import org.junit.Test;
37+
38+
import java.io.IOException;
39+
import java.util.HashMap;
40+
41+
public class CentralMomentTest extends AutomatedTestBase {
42+
private final static String TEST_NAME1 = "CentralMoment";
43+
private final static String TEST_DIR = "functions/ooc/";
44+
private final static String TEST_CLASS_DIR = TEST_DIR + CentralMomentTest.class.getSimpleName() + "/";
45+
private final static double eps = 1e-8;
46+
private static final String INPUT_NAME = "X";
47+
private static final String OUTPUT_NAME = "res";
48+
49+
private final static int rows = 1871;
50+
private final static int maxVal = 7;
51+
private final static double sparsity1 = 0.65;
52+
private final static double sparsity2 = 0.05;
53+
54+
@Override
55+
public void setUp() {
56+
TestUtils.clearAssertionInformation();
57+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
58+
addTestConfiguration(TEST_NAME1, config);
59+
}
60+
61+
@Test
62+
public void testCentralMoment2Dense() {
63+
runCentralMomentTest(2, false);
64+
}
65+
66+
@Test
67+
public void testCentralMoment3Dense() {
68+
runCentralMomentTest(3, false);
69+
}
70+
71+
@Test
72+
public void testCentralMoment4Dense() {
73+
runCentralMomentTest(4, false);
74+
}
75+
76+
@Test
77+
public void testCentralMoment2Sparse() {
78+
runCentralMomentTest(2, true);
79+
}
80+
81+
@Test
82+
public void testCentralMoment3Sparse() {
83+
runCentralMomentTest(3, true);
84+
}
85+
86+
@Test
87+
public void testCentralMoment4Sparse() {
88+
runCentralMomentTest(4, true);
89+
}
90+
91+
private void runCentralMomentTest(int order, boolean sparse) {
92+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
93+
94+
try {
95+
getAndLoadTestConfiguration(TEST_NAME1);
96+
97+
String HOME = SCRIPT_DIR + TEST_DIR;
98+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
99+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME),
100+
Integer.toString(order), output(OUTPUT_NAME)};
101+
102+
// 1. Generate the data in-memory as MatrixBlock objects
103+
double[][] A_data = getRandomMatrix(rows, 1, 1, maxVal, sparse ? sparsity2 : sparsity1, 7);
104+
105+
// 2. Convert the double arrays to MatrixBlock objects
106+
MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);
107+
108+
// 3. Create a binary matrix writer
109+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
110+
111+
// 4. Write matrix A to a binary SequenceFile
112+
writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, 1, 1000, A_mb.getNonZeros());
113+
HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
114+
new MatrixCharacteristics(rows, 1, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY);
115+
116+
runTest(true, false, null, -1);
117+
118+
//check Central Moment OOC
119+
Assert.assertTrue("OOC wasn't used for CentralMoment",
120+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CM));
121+
122+
//compare results
123+
124+
// rerun without ooc flag
125+
programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME), Integer.toString(order),
126+
output(OUTPUT_NAME + "_target")};
127+
runTest(true, false, null, -1);
128+
129+
// compare matrices
130+
HashMap<MatrixValue.CellIndex, Double> ret1 = readDMLMatrixFromOutputDir(OUTPUT_NAME);
131+
HashMap<MatrixValue.CellIndex, Double> ret2 = readDMLMatrixFromOutputDir(OUTPUT_NAME + "_target");
132+
TestUtils.compareMatrices(ret1, ret2, eps, "Ret-1", "Ret-2");
133+
}
134+
catch(IOException e) {
135+
throw new RuntimeException(e);
136+
}
137+
finally {
138+
resetExecMode(platformOld);
139+
}
140+
}
141+
}

0 commit comments

Comments
 (0)