Skip to content

Commit 3779d50

Browse files
janniklindemboehm7
authored andcommitted
[SYSTEMDS-3891] New out-of-core seq instruction
Closes #2357.
1 parent b0ef875 commit 3779d50

File tree

5 files changed

+257
-1
lines changed

5 files changed

+257
-1
lines changed

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
3030
import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction;
3131
import org.apache.sysds.runtime.instructions.ooc.IndexingOOCInstruction;
32+
import org.apache.sysds.runtime.instructions.ooc.DataGenOOCInstruction;
3233
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
3334
import org.apache.sysds.runtime.instructions.ooc.ParameterizedBuiltinOOCInstruction;
3435
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
@@ -84,6 +85,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
8485
return ParameterizedBuiltinOOCInstruction.parseInstruction(str);
8586
case MatrixIndexing:
8687
return IndexingOOCInstruction.parseInstruction(str);
88+
case Rand:
89+
return DataGenOOCInstruction.parseInstruction(str);
8790

8891
default:
8992
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.commons.lang3.NotImplementedException;
23+
import org.apache.sysds.common.Opcodes;
24+
import org.apache.sysds.common.Types;
25+
import org.apache.sysds.runtime.DMLRuntimeException;
26+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
27+
import org.apache.sysds.runtime.instructions.InstructionUtils;
28+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
29+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
30+
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
31+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
32+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
33+
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
34+
import org.apache.sysds.runtime.util.UtilFunctions;
35+
36+
public class DataGenOOCInstruction extends UnaryOOCInstruction {
37+
38+
private final int blen;
39+
private Types.OpOpDG method;
40+
41+
// sequence specific attributes
42+
private final CPOperand seq_from, seq_to, seq_incr;
43+
44+
public DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd, CPOperand in, CPOperand out, int blen, CPOperand seqFrom,
45+
CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) {
46+
super(OOCType.Rand, op, in, out, opcode, istr);
47+
this.blen = blen;
48+
this.method = mthd;
49+
this.seq_from = seqFrom;
50+
this.seq_to = seqTo;
51+
this.seq_incr = seqIncr;
52+
}
53+
54+
public static DataGenOOCInstruction parseInstruction(String str) {
55+
Types.OpOpDG method = null;
56+
String[] s = InstructionUtils.getInstructionPartsWithValueType(str);
57+
String opcode = s[0];
58+
59+
if(opcode.equalsIgnoreCase(Opcodes.SEQUENCE.toString())) {
60+
method = Types.OpOpDG.SEQ;
61+
// 8 operands: rows, cols, blen, from, to, incr, outvar
62+
InstructionUtils.checkNumFields(s, 7);
63+
}
64+
else
65+
throw new NotImplementedException(); // TODO
66+
67+
CPOperand out = new CPOperand(s[s.length - 1]);
68+
UnaryOperator op = null;
69+
70+
if(method == Types.OpOpDG.SEQ) {
71+
int blen = Integer.parseInt(s[3]);
72+
CPOperand from = new CPOperand(s[4]);
73+
CPOperand to = new CPOperand(s[5]);
74+
CPOperand incr = new CPOperand(s[6]);
75+
76+
return new DataGenOOCInstruction(op, method, null, out, blen, from, to, incr, opcode, str);
77+
}
78+
else
79+
throw new NotImplementedException();
80+
}
81+
82+
@Override
83+
public void processInstruction(ExecutionContext ec) {
84+
final OOCStream<IndexedMatrixValue> qOut = createWritableStream();
85+
86+
// process specific datagen operator
87+
if(method == Types.OpOpDG.SEQ) {
88+
double lfrom = ec.getScalarInput(seq_from).getDoubleValue();
89+
double lto = ec.getScalarInput(seq_to).getDoubleValue();
90+
double lincr = ec.getScalarInput(seq_incr).getDoubleValue();
91+
92+
// handle default 1 to -1 for special case of from>to
93+
lincr = LibMatrixDatagen.updateSeqIncr(lfrom, lto, lincr);
94+
95+
if(LOG.isTraceEnabled())
96+
LOG.trace(
97+
"Process DataGenOOCInstruction seq with seqFrom=" + lfrom + ", seqTo=" + lto + ", seqIncr" + lincr);
98+
99+
final int maxK = (int) UtilFunctions.getSeqLength(lfrom, lto, lincr);
100+
final double finalLincr = lincr;
101+
102+
103+
submitOOCTask(() -> {
104+
int k = 0;
105+
double curFrom = lfrom;
106+
double curTo;
107+
MatrixBlock mb;
108+
109+
while (k < maxK) {
110+
long desiredLen = Math.min(blen, maxK - k);
111+
curTo = curFrom + (desiredLen - 1) * finalLincr;
112+
long actualLen = UtilFunctions.getSeqLength(curFrom, curTo, finalLincr);
113+
114+
if (actualLen != desiredLen) {
115+
// Then we add / subtract a small correction term
116+
curTo += (actualLen < desiredLen) ? finalLincr / 2 : -finalLincr / 2;
117+
118+
if (UtilFunctions.getSeqLength(curFrom, curTo, finalLincr) != desiredLen)
119+
throw new DMLRuntimeException("OOC seq could not construct the right number of elements.");
120+
}
121+
122+
mb = MatrixBlock.seqOperations(curFrom, curTo, finalLincr);
123+
qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(1 + k / blen, 1), mb));
124+
curFrom = mb.get(mb.getNumRows() - 1, 0) + finalLincr;
125+
k += blen;
126+
}
127+
128+
qOut.closeInput();
129+
}, qOut);
130+
}
131+
else
132+
throw new NotImplementedException();
133+
134+
ec.getMatrixObject(output).setStreamHandle(qOut);
135+
}
136+
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ public abstract class OOCInstruction extends Instruction {
5454
private static final AtomicInteger nextStreamId = new AtomicInteger(0);
5555

5656
public enum OOCType {
57-
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin
57+
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ,
58+
Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand
5859
}
5960

6061
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.runtime.instructions.Instruction;
25+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
26+
import org.apache.sysds.test.AutomatedTestBase;
27+
import org.apache.sysds.test.TestConfiguration;
28+
import org.apache.sysds.test.TestUtils;
29+
import org.junit.Assert;
30+
import org.junit.Test;
31+
32+
public class SeqTest extends AutomatedTestBase {
33+
private final static String TEST_NAME1 = "Seq";
34+
private final static String TEST_DIR = "functions/ooc/";
35+
private final static String TEST_CLASS_DIR = TEST_DIR + SeqTest.class.getSimpleName() + "/";
36+
private final static double eps = 1e-8;
37+
private static final String OUTPUT_NAME = "res";
38+
39+
@Override
40+
public void setUp() {
41+
TestUtils.clearAssertionInformation();
42+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
43+
addTestConfiguration(TEST_NAME1, config);
44+
}
45+
46+
@Test
47+
public void testSeq1() {
48+
runSeqTest(0, 10, 0.1);
49+
}
50+
51+
@Test
52+
public void testSeq2() {
53+
runSeqTest(0, 15.9, 0.01);
54+
}
55+
56+
private void runSeqTest(double from, double to, double incr) {
57+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
58+
59+
try {
60+
getAndLoadTestConfiguration(TEST_NAME1);
61+
62+
String HOME = SCRIPT_DIR + TEST_DIR;
63+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
64+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", Double.toString(from), Double.toString(to), Double.toString(incr), output(OUTPUT_NAME)};
65+
66+
runTest(true, false, null, -1);
67+
68+
//check seq OOC
69+
Assert.assertTrue("OOC wasn't used for seq",
70+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.SEQUENCE));
71+
//compare results
72+
73+
// rerun without ooc flag
74+
programArgs = new String[] {"-explain", "-stats", "-args", Double.toString(from), Double.toString(to), Double.toString(incr), output(OUTPUT_NAME + "_target")};
75+
runTest(true, false, null, -1);
76+
77+
// compare matrices
78+
MatrixBlock ret1 = TestUtils.readBinary(output(OUTPUT_NAME));
79+
MatrixBlock ret2 = TestUtils.readBinary(output(OUTPUT_NAME + "_target"));
80+
81+
TestUtils.compareMatrices(ret1, ret2, eps);
82+
}
83+
finally {
84+
resetExecMode(platformOld);
85+
}
86+
}
87+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read the input matrix as a stream
23+
from = $1;
24+
to = $2;
25+
incr = $3;
26+
27+
res = seq(from, to, incr);
28+
29+
write(res, $4, format="binary");

0 commit comments

Comments
 (0)