Skip to content

Commit ebd3aed

Browse files
jessicapriebemboehm7
authored andcommitted
[SYSTEMDS-3915] Out-of-core ctable operations
Closes #2342.
1 parent 9e0a481 commit ebd3aed

File tree

8 files changed

+472
-2
lines changed

8 files changed

+472
-2
lines changed

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
2727
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
2828
import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
29+
import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction;
2930
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
3031
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
3132
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
@@ -72,7 +73,9 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
7273
return TeeOOCInstruction.parseInstruction(str);
7374
case CentralMoment:
7475
return CentralMomentOOCInstruction.parseInstruction(str);
75-
76+
case Ctable:
77+
return CtableOOCInstruction.parseInstruction(str);
78+
7679
default:
7780
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
7881
}
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import java.util.HashMap;
23+
24+
import org.apache.sysds.common.Types;
25+
import org.apache.sysds.lops.Ctable;
26+
import org.apache.sysds.runtime.DMLRuntimeException;
27+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
28+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
29+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
30+
import org.apache.sysds.runtime.instructions.Instruction;
31+
import org.apache.sysds.runtime.instructions.InstructionUtils;
32+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
33+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
34+
import org.apache.sysds.runtime.matrix.data.CTableMap;
35+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
36+
import org.apache.sysds.runtime.matrix.operators.Operator;
37+
import org.apache.sysds.runtime.util.DataConverter;
38+
import org.apache.sysds.runtime.util.LongLongDoubleHashMap;
39+
40+
public class CtableOOCInstruction extends ComputationOOCInstruction {
41+
private final CPOperand _outDim1;
42+
private final CPOperand _outDim2;
43+
private final boolean _ignoreZeros;
44+
45+
protected CtableOOCInstruction(OOCType type, Operator op, CPOperand in1,
46+
CPOperand in2, CPOperand in3, CPOperand out, CPOperand outDim1, CPOperand outDim2,
47+
boolean ignoreZeros, String opcode, String istr)
48+
{
49+
super(type, op, in1, in2, in3, out, opcode, istr);
50+
_ignoreZeros = ignoreZeros;
51+
_outDim1 = outDim1;
52+
_outDim2 = outDim2;
53+
}
54+
55+
public static CtableOOCInstruction parseInstruction(String str) {
56+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
57+
InstructionUtils.checkNumFields(parts, 8);
58+
59+
String opcode = parts[0];
60+
CPOperand in1 = new CPOperand(parts[1]);
61+
CPOperand in2 = new CPOperand(parts[2]);
62+
CPOperand in3 = new CPOperand(parts[3]);
63+
CPOperand out = new CPOperand(parts[6]);
64+
65+
String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX);
66+
String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX);
67+
CPOperand outDim1 = new CPOperand(dim1Fields[0], Types.ValueType.FP64,
68+
Types.DataType.SCALAR, Boolean.parseBoolean(dim1Fields[1]));
69+
CPOperand outDim2 = new CPOperand(dim2Fields[0], Types.ValueType.FP64,
70+
Types.DataType.SCALAR, Boolean.parseBoolean(dim2Fields[1]));
71+
72+
boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
73+
74+
// does not require any op
75+
return new CtableOOCInstruction(OOCType.Ctable, null, in1, in2, in3, out, outDim1, outDim2, ignoreZeros, opcode, str);
76+
}
77+
78+
@Override
79+
public void processInstruction( ExecutionContext ec ) {
80+
81+
MatrixObject in1 = ec.getMatrixObject(input1); // stream
82+
LocalTaskQueue<IndexedMatrixValue> qIn1 = in1.getStreamHandle();
83+
IndexedMatrixValue tmp1 = null;
84+
85+
long outputDim1 = ec.getScalarInput(_outDim1).getLongValue();
86+
long outputDim2 = ec.getScalarInput(_outDim2).getLongValue();
87+
88+
long cols = in1.getDataCharacteristics().getNumColBlocks();
89+
CTableMap map = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
90+
91+
Ctable.OperationTypes ctableOp = findCtableOperation();
92+
MatrixObject in2 = null, in3 = null;
93+
LocalTaskQueue<IndexedMatrixValue> qIn2 = null, qIn3 = null;
94+
double cst2 = 0, cst3 = 0;
95+
96+
// init vars based on ctableOp
97+
if (ctableOp.hasSecondInput()){
98+
in2 = ec.getMatrixObject(input2); // stream
99+
qIn2 = in2.getStreamHandle();
100+
} else
101+
cst2 = ec.getScalarInput(input2).getDoubleValue();
102+
103+
if (ctableOp.hasThirdInput()){
104+
in3 = ec.getMatrixObject(input3); // stream
105+
qIn3 = in3.getStreamHandle();
106+
} else
107+
cst3 = ec.getScalarInput(input3).getDoubleValue();
108+
109+
HashMap<Long, MatrixBlock> blocksIn2 = new HashMap<>(), blocksIn3 = new HashMap<>();
110+
MatrixBlock block2, block3;
111+
112+
// only init result block if output dims known and dense
113+
MatrixBlock result = null;
114+
boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1);
115+
if (outputDimsKnown){
116+
long totalRows = in1.getDataCharacteristics().getRows();
117+
long totalCols = in1.getDataCharacteristics().getCols();
118+
boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, totalRows*totalCols);
119+
if(!sparse)
120+
result = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
121+
}
122+
123+
try {
124+
while((tmp1 = qIn1.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
125+
126+
MatrixBlock block1 = (MatrixBlock) tmp1.getValue();
127+
long r = tmp1.getIndexes().getRowIndex();
128+
long c = tmp1.getIndexes().getColumnIndex();
129+
long key = (r-1) * cols + (c-1);
130+
131+
switch(ctableOp) {
132+
case CTABLE_TRANSFORM:
133+
// ctable(A,B,W)
134+
block2 = getOrDequeueBlock(key, cols, blocksIn2, qIn2);
135+
block3 = getOrDequeueBlock(key, cols, blocksIn3, qIn3);
136+
block1.ctableOperations(_optr, block2, block3, map, result);
137+
break;
138+
case CTABLE_TRANSFORM_SCALAR_WEIGHT:
139+
// ctable(A,B) or ctable(A,B,1)
140+
block2 = getOrDequeueBlock(key, cols, blocksIn2, qIn2);
141+
block1.ctableOperations(_optr, block2, cst3, _ignoreZeros, map, result);
142+
break;
143+
case CTABLE_TRANSFORM_HISTOGRAM:
144+
// ctable(A,1) or ctable(A,1,1)
145+
block1.ctableOperations(_optr, cst2, cst3, map, result);
146+
break;
147+
case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
148+
// ctable(A,1,W)
149+
block3 = getOrDequeueBlock(key, cols, blocksIn3, qIn3);
150+
block1.ctableOperations(_optr, cst2, block3, map, result);
151+
break;
152+
153+
default:
154+
throw new DMLRuntimeException("Encountered an invalid OOC ctable operation "
155+
+ "("+ctableOp+") while executing instruction: " + this);
156+
}
157+
}
158+
if (result == null){
159+
if(outputDimsKnown)
160+
result = DataConverter.convertToMatrixBlock(map, (int)outputDim1, (int)outputDim2);
161+
else
162+
result = DataConverter.convertToMatrixBlock(map);
163+
}
164+
else
165+
result.examSparsity();
166+
167+
ec.setMatrixOutput(output.getName(), result);
168+
}
169+
catch(Exception ex) {
170+
throw new DMLRuntimeException(ex);
171+
}
172+
}
173+
174+
private MatrixBlock getOrDequeueBlock(long key, long cols, HashMap<Long, MatrixBlock> blocks,
175+
LocalTaskQueue<IndexedMatrixValue> queue) throws InterruptedException
176+
{
177+
MatrixBlock block = blocks.get(key);
178+
if (block == null) {
179+
IndexedMatrixValue tmp;
180+
// corresponding block still in queue, dequeue until found
181+
while ((tmp = queue.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
182+
block = (MatrixBlock) tmp.getValue();
183+
long r = tmp.getIndexes().getRowIndex();
184+
long c = tmp.getIndexes().getColumnIndex();
185+
long tmpKey = (r-1) * cols + (c-1);
186+
// found corresponding block
187+
if (tmpKey == key) break;
188+
// store all dequeued blocks in cache that we don't need yet
189+
blocks.put(tmpKey, block);
190+
}
191+
}
192+
else
193+
blocks.remove(key); // needed only once
194+
195+
return block;
196+
}
197+
198+
private Ctable.OperationTypes findCtableOperation() {
199+
Types.DataType dt1 = input1.getDataType();
200+
Types.DataType dt2 = input2.getDataType();
201+
Types.DataType dt3 = input3.getDataType();
202+
return Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
203+
}
204+
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction {
3333
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
3434

3535
public enum OOCType {
36-
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM
36+
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable
3737
}
3838

3939
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.runtime.instructions.Instruction;
25+
import org.apache.sysds.runtime.io.MatrixWriter;
26+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
27+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
28+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
29+
import org.apache.sysds.runtime.util.DataConverter;
30+
import org.apache.sysds.runtime.util.HDFSTool;
31+
import org.apache.sysds.test.AutomatedTestBase;
32+
import org.apache.sysds.test.TestConfiguration;
33+
import org.apache.sysds.test.TestUtils;
34+
import org.junit.Assert;
35+
import org.junit.Test;
36+
37+
import java.util.HashMap;
38+
39+
public class CTableTest extends AutomatedTestBase{
40+
private static final String TEST_NAME1 = "CTableTest";
41+
private static final String TEST_NAME2 = "WeightedCTableTest";
42+
private static final String TEST_DIR = "functions/ooc/";
43+
private static final String TEST_CLASS_DIR = TEST_DIR + CTableTest.class.getSimpleName() + "/";
44+
45+
private static final String INPUT_NAME1 = "v";
46+
private static final String INPUT_NAME2 = "w";
47+
private static final String INPUT_NAME3 = "weights";
48+
private static final String OUTPUT_NAME = "res";
49+
50+
private final static double eps = 1e-10;
51+
52+
@Override
53+
public void setUp() {
54+
TestUtils.clearAssertionInformation();
55+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
56+
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
57+
}
58+
59+
@Test
60+
public void testCTableSimple(){ testCTable(1372, 1012, 5, 5, false);}
61+
62+
@Test
63+
public void testCTableValueSetDifferencesNonEmpty(){ testCTable(2000, 37, 4995, 5, false);}
64+
65+
@Test
66+
public void testWeightedCTableSimple(){ testCTable(1372, 1012, 5, 5, true);}
67+
68+
@Test
69+
public void testWeightedCTableValueSetDifferencesNonEmpty(){ testCTable(2000, 37, 4995, 5, true);}
70+
71+
72+
public void testCTable(int rows, int cols, int maxValV, int maxValW, boolean isWeighted)
73+
{
74+
Types.ExecMode platformOld = rtplatform;
75+
rtplatform = Types.ExecMode.SINGLE_NODE;
76+
77+
try {
78+
String TEST_NAME = isWeighted? TEST_NAME2:TEST_NAME1;
79+
80+
getAndLoadTestConfiguration(TEST_NAME);
81+
String HOME = SCRIPT_DIR + TEST_DIR;
82+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
83+
if (isWeighted)
84+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME1), input(INPUT_NAME2), input(INPUT_NAME3), output(OUTPUT_NAME)};
85+
else
86+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME1), input(INPUT_NAME2), output(OUTPUT_NAME)};
87+
88+
fullRScriptName = HOME + TEST_NAME + ".R";
89+
rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
90+
91+
// values <=0 invalid
92+
double[][] v = TestUtils.floor(getRandomMatrix(rows, cols, 1, maxValV, 1.0, 7));
93+
double[][] w = TestUtils.floor(getRandomMatrix(rows, cols, 1, maxValW, 1.0, 13));
94+
double[][] weights = null;
95+
96+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
97+
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(v), input(INPUT_NAME1), rows, cols, 1000, rows*cols);
98+
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(w), input(INPUT_NAME2), rows, cols, 1000, rows*cols);
99+
100+
HDFSTool.writeMetaDataFile(input(INPUT_NAME1+".mtd"), Types.ValueType.FP64, new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
101+
HDFSTool.writeMetaDataFile(input(INPUT_NAME2+".mtd"), Types.ValueType.FP64, new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
102+
103+
// for RScript
104+
writeInputMatrixWithMTD("vR", v, true);
105+
writeInputMatrixWithMTD("wR", w, true);
106+
107+
if (isWeighted) {
108+
weights = TestUtils.floor(getRandomMatrix(rows, cols, 1, maxValW, 1.0, 17));
109+
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(weights),
110+
input(INPUT_NAME3), rows, cols, 1000, rows * cols);
111+
HDFSTool.writeMetaDataFile(input(INPUT_NAME3 + ".mtd"), Types.ValueType.FP64,
112+
new MatrixCharacteristics(rows, cols, 1000, rows * cols), Types.FileFormat.BINARY);
113+
writeInputMatrixWithMTD("weightsR", weights, true);
114+
}
115+
116+
runTest(true, false, null, -1);
117+
runRScript(true);
118+
119+
// compare matrices
120+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("resR");
121+
122+
double[][] rRes = TestUtils.convertHashMapToDoubleArray(rfile);
123+
double[][] dmlRes = DataConverter.convertToDoubleMatrix(
124+
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
125+
Types.FileFormat.BINARY, rRes.length, rRes[0].length, 1000, 1000));
126+
TestUtils.compareMatrices(rRes, dmlRes, eps);
127+
128+
String prefix = Instruction.OOC_INST_PREFIX;
129+
Assert.assertTrue("OOC wasn't used for RBLK",
130+
heavyHittersContainsString(prefix + Opcodes.RBLK));
131+
}
132+
catch(Exception ex) {
133+
Assert.fail(ex.getMessage());
134+
}
135+
finally {
136+
resetExecMode(platformOld);
137+
}
138+
}
139+
}

0 commit comments

Comments
 (0)