Skip to content

Commit 7440ef7

Browse files
jessicapriebemboehm7
authored andcommitted
[SYSTEMDS-3892] Initial out-of-core base instruction and parser
Closes #2289.
1 parent 12a2094 commit 7440ef7

File tree

11 files changed

+298
-2
lines changed

11 files changed

+298
-2
lines changed

src/main/java/org/apache/sysds/api/DMLOptions.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ public class DMLOptions {
6565
public ExecMode execMode = OptimizerUtils.getDefaultExecutionMode(); // Execution mode standalone, MR, Spark or a hybrid
6666
public boolean gpu = false; // Whether to use the GPU
6767
public boolean forceGPU = false; // Whether to ignore memory & estimates and always use the GPU
68+
public boolean ooc = false; // Whether to use the OOC backend
6869
public boolean debug = false; // to go into debug mode to be able to step through a program
6970
public String filePath = null; // path to script
7071
public String script = null; // the script itself
@@ -109,6 +110,7 @@ public String toString() {
109110
", execMode=" + execMode +
110111
", gpu=" + gpu +
111112
", forceGPU=" + forceGPU +
113+
", ooc=" + ooc +
112114
", debug=" + debug +
113115
", filePath='" + filePath + '\'' +
114116
", script='" + script + '\'' +
@@ -182,6 +184,7 @@ else if (lineageType.equalsIgnoreCase("debugger"))
182184
}
183185
}
184186
}
187+
dmlOptions.ooc = line.hasOption("ooc");
185188
if (line.hasOption("exec")){
186189
String execMode = line.getOptionValue("exec");
187190
if (execMode.equalsIgnoreCase("singlenode")) dmlOptions.execMode = ExecMode.SINGLE_NODE;
@@ -388,6 +391,8 @@ private static Options createCLIOptions() {
388391
Option gpuOpt = OptionBuilder.withArgName("force")
389392
.withDescription("uses CUDA instructions when reasonable; set <force> option to skip conservative memory estimates and use GPU wherever possible; default off")
390393
.hasOptionalArg().create("gpu");
394+
Option oocOpt = OptionBuilder.withDescription("uses OOC backend")
395+
.create("ooc");
391396
Option debugOpt = OptionBuilder.withDescription("runs in debug mode; default off")
392397
.create("debug");
393398
Option pythonOpt = OptionBuilder
@@ -441,6 +446,7 @@ private static Options createCLIOptions() {
441446
options.addOption(explainOpt);
442447
options.addOption(execOpt);
443448
options.addOption(gpuOpt);
449+
options.addOption(oocOpt);
444450
options.addOption(debugOpt);
445451
options.addOption(lineageOpt);
446452
options.addOption(fedOpt);

src/main/java/org/apache/sysds/api/DMLScript.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ public class DMLScript
147147
public static boolean FORCE_ACCELERATOR = DMLOptions.defaultOptions.forceGPU;
148148
// Enable synchronizing GPU after every instruction
149149
public static boolean SYNCHRONIZE_GPU = true;
150+
// Set OOC backend
151+
public static boolean USE_OOC = DMLOptions.defaultOptions.ooc;
150152
// Enable eager CUDA free on rmvar
151153
public static boolean EAGER_CUDA_FREE = false;
152154

@@ -266,6 +268,7 @@ public static boolean executeScript( String[] args )
266268
JMLC_MEM_STATISTICS = dmlOptions.memStats;
267269
USE_ACCELERATOR = dmlOptions.gpu;
268270
FORCE_ACCELERATOR = dmlOptions.forceGPU;
271+
USE_OOC = dmlOptions.ooc;
269272
EXPLAIN = dmlOptions.explainType;
270273
EXEC_MODE = dmlOptions.execMode;
271274
LINEAGE = dmlOptions.lineage;

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public interface Types {
3232
* Execution mode for entire script. This setting specify which {@link ExecType}s are allowed.
3333
*/
3434
public enum ExecMode {
35-
/** Execute all operations in {@link ExecType#CP} and if available {@link ExecType#GPU} */
35+
/** Execute all operations in {@link ExecType#CP}, {@link ExecType#OOC} and if available {@link ExecType#GPU} */
3636
SINGLE_NODE,
3737
/**
3838
* The default and encouraged ExecMode. Execute operations while leveraging all available options:
@@ -58,6 +58,8 @@ public enum ExecType {
5858
GPU,
5959
/** FED: indicate that the instruction should be executed as a Federated instruction */
6060
FED,
61+
/** Out of Core: indicate that the operation should be executed out of core. */
62+
OOC,
6163
/** invalid is used for debugging or if it is undecided where the current instruction should be executed */
6264
INVALID
6365
}

src/main/java/org/apache/sysds/hops/Hop.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ else if ( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE && _etypeForced
263263
if(_etypeForced != ExecType.CP && _etypeForced != ExecType.GPU)
264264
_etypeForced = ExecType.CP;
265265
}
266+
else if (DMLScript.USE_OOC){
267+
_etypeForced = ExecType.OOC;
268+
}
266269
else {
267270
// enabled with -exec singlenode option
268271
_etypeForced = ExecType.CP;

src/main/java/org/apache/sysds/runtime/instructions/Instruction.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ public enum IType {
3535
BREAKPOINT,
3636
SPARK,
3737
GPU,
38-
FEDERATED
38+
FEDERATED,
39+
OUT_OF_CORE
3940
}
4041

4142
protected static final Log LOG = LogFactory.getLog(Instruction.class.getName());
@@ -53,6 +54,7 @@ protected Instruction(Operator _optr){
5354
public static final String SP_INST_PREFIX = "sp_";
5455
public static final String GPU_INST_PREFIX = "gpu_";
5556
public static final String FEDERATED_INST_PREFIX = "fed_";
57+
public static final String OOC_INST_PREFIX = "ooc_";
5658

5759
//basic instruction meta data
5860
protected String instString = null;
@@ -184,6 +186,8 @@ else if( getType() == IType.GPU )
184186
extendedOpcode = GPU_INST_PREFIX + getOpcode();
185187
else if( getType() == IType.FEDERATED)
186188
extendedOpcode = FEDERATED_INST_PREFIX + getOpcode();
189+
else if( getType() == IType.OUT_OF_CORE)
190+
extendedOpcode = OOC_INST_PREFIX + getOpcode();
187191
else
188192
extendedOpcode = getOpcode();
189193
}

src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ public static Instruction parseSingleInstruction ( String str ) {
5353
if( fedtype == null )
5454
throw new DMLRuntimeException("Unknown FEDERATED instruction: " + str);
5555
return FEDInstructionParser.parseSingleInstruction (fedtype, str);
56+
case OOC:
57+
InstructionType ooctype = InstructionUtils.getOOCType(str);
58+
if( ooctype == null )
59+
throw new DMLRuntimeException("Unknown OOC instruction: " + str);
60+
return OOCInstructionParser.parseSingleInstruction (ooctype, str);
5661
default:
5762
throw new DMLRuntimeException("Unknown execution type in instruction: " + str);
5863
}

src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ public static InstructionType getFEDType(String str) {
281281
return Opcodes.getTypeByOpcode(op, Types.ExecType.FED);
282282
}
283283

284+
public static InstructionType getOOCType(String str) {
285+
String op = getOpCode(str);
286+
return Opcodes.getTypeByOpcode(op, Types.ExecType.OOC);
287+
}
288+
284289
public static boolean isBuiltinFunction( String opcode ) {
285290
Builtin.BuiltinCode bfc = Builtin.String2BuiltinCode.get(opcode);
286291
return (bfc != null);
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions;
21+
22+
import org.apache.commons.logging.Log;
23+
import org.apache.commons.logging.LogFactory;
24+
import org.apache.sysds.common.InstructionType;
25+
import org.apache.sysds.runtime.DMLRuntimeException;
26+
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
27+
28+
public class OOCInstructionParser extends InstructionParser {
29+
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
30+
31+
public static OOCInstruction parseSingleInstruction(String str) {
32+
if(str == null || str.isEmpty())
33+
return null;
34+
InstructionType ooctype = InstructionUtils.getOOCType(str);
35+
if(ooctype == null)
36+
throw new DMLRuntimeException("Unable derive ooctype for instruction: " + str);
37+
OOCInstruction oocinst = parseSingleInstruction(ooctype, str);
38+
if(oocinst == null)
39+
throw new DMLRuntimeException("Unable to parse instruction: " + str);
40+
return oocinst;
41+
}
42+
43+
public static OOCInstruction parseSingleInstruction(InstructionType ooctype, String str) {
44+
if(str == null || str.isEmpty())
45+
return null;
46+
switch(ooctype) {
47+
48+
// TODO:
49+
case AggregateUnary:
50+
case Binary:
51+
52+
default:
53+
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
54+
}
55+
}
56+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.commons.logging.Log;
23+
import org.apache.commons.logging.LogFactory;
24+
import org.apache.sysds.api.DMLScript;
25+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
26+
import org.apache.sysds.runtime.instructions.Instruction;
27+
import org.apache.sysds.runtime.matrix.operators.Operator;
28+
29+
public abstract class OOCInstruction extends Instruction {
30+
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
31+
32+
public enum OOCType {
33+
AggregateUnary, Binary
34+
}
35+
36+
protected final OOCInstruction.OOCType _ooctype;
37+
protected final boolean _requiresLabelUpdate;
38+
39+
protected OOCInstruction(OOCInstruction.OOCType type, String opcode, String istr) {
40+
this(type, null, opcode, istr);
41+
}
42+
43+
protected OOCInstruction(OOCInstruction.OOCType type, Operator op, String opcode, String istr) {
44+
super(op);
45+
_ooctype = type;
46+
instString = istr;
47+
instOpcode = opcode;
48+
49+
_requiresLabelUpdate = super.requiresLabelUpdate();
50+
}
51+
52+
@Override
53+
public IType getType() {
54+
return IType.OUT_OF_CORE;
55+
}
56+
57+
public OOCInstruction.OOCType getOOCInstructionType() {
58+
return _ooctype;
59+
}
60+
61+
@Override
62+
public boolean requiresLabelUpdate() {
63+
return _requiresLabelUpdate;
64+
}
65+
66+
@Override
67+
public String getGraphString() {
68+
return getOpcode();
69+
}
70+
71+
@Override
72+
public Instruction preprocessInstruction(ExecutionContext ec) {
73+
// TODO
74+
return super.preprocessInstruction(ec);
75+
}
76+
77+
@Override
78+
public abstract void processInstruction(ExecutionContext ec);
79+
80+
@Override
81+
public void postprocessInstruction(ExecutionContext ec) {
82+
if(DMLScript.LINEAGE_DEBUGGER)
83+
ec.maintainLineageDebuggerInfo(this);
84+
}
85+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.runtime.instructions.Instruction;
25+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
26+
import org.apache.sysds.test.AutomatedTestBase;
27+
import org.apache.sysds.test.TestConfiguration;
28+
import org.apache.sysds.test.TestUtils;
29+
import org.apache.sysds.utils.Statistics;
30+
import org.junit.Assert;
31+
import org.junit.Ignore;
32+
import org.junit.Test;
33+
34+
import java.util.HashMap;
35+
36+
public class SumScalarMultiplicationTest extends AutomatedTestBase {
37+
38+
private static final String TEST_NAME = "SumScalarMultiplication";
39+
private static final String TEST_DIR = "functions/ooc/";
40+
private static final String TEST_CLASS_DIR = TEST_DIR + SumScalarMultiplicationTest.class.getSimpleName() + "/";
41+
private static final String INPUT_NAME = "X";
42+
private static final String OUTPUT_NAME = "res";
43+
44+
@Override
45+
public void setUp() {
46+
TestUtils.clearAssertionInformation();
47+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
48+
addTestConfiguration(TEST_NAME, config);
49+
}
50+
51+
/**
52+
* Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
53+
*/
54+
@Test
55+
@Ignore
56+
public void testSumScalarMult() {
57+
58+
Types.ExecMode platformOld = rtplatform;
59+
rtplatform = Types.ExecMode.SINGLE_NODE;
60+
61+
try {
62+
getAndLoadTestConfiguration(TEST_NAME);
63+
String HOME = SCRIPT_DIR + TEST_DIR;
64+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
65+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), output(OUTPUT_NAME)};
66+
67+
int rows = 3;
68+
int cols = 4;
69+
double sparsity = 0.8;
70+
71+
double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsity, 7);
72+
writeInputMatrixWithMTD(INPUT_NAME, X, true);
73+
74+
runTest(true, false, null, -1);
75+
76+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME);
77+
// only one entry
78+
Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1));
79+
80+
double expected = 0.0;
81+
for(int i = 0; i < rows; i++) {
82+
for(int j = 0; j < cols; j++) {
83+
expected += X[i][j] * 7;
84+
}
85+
}
86+
87+
Assert.assertEquals(expected, result, 1e-10);
88+
89+
String prefix = Instruction.OOC_INST_PREFIX;
90+
91+
boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
92+
Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult);
93+
94+
boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP);
95+
Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum);
96+
97+
}
98+
finally {
99+
// reset
100+
rtplatform = platformOld;
101+
}
102+
}
103+
}

0 commit comments

Comments
 (0)