Skip to content

Commit f265845

Browse files
janniklinde authored and mboehm7 committed
[SYSTEMDS-3891] New OOC replace and contains instructions
Closes #2356.
1 parent 6f3cdb3 commit f265845

File tree

7 files changed

+452
-2
lines changed

7 files changed

+452
-2
lines changed

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
3030
import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction;
3131
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
32+
import org.apache.sysds.runtime.instructions.ooc.ParameterizedBuiltinOOCInstruction;
3233
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
3334
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
3435
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
@@ -78,6 +79,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
7879
return CentralMomentOOCInstruction.parseInstruction(str);
7980
case Ctable:
8081
return CtableOOCInstruction.parseInstruction(str);
82+
case ParameterizedBuiltin:
83+
return ParameterizedBuiltinOOCInstruction.parseInstruction(str);
8184

8285
default:
8386
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.sysds.runtime.util.OOCJoin;
3535

3636
import java.util.ArrayList;
37+
import java.util.Collections;
3738
import java.util.HashMap;
3839
import java.util.HashSet;
3940
import java.util.List;
@@ -53,7 +54,7 @@ public abstract class OOCInstruction extends Instruction {
5354
private static final AtomicInteger nextStreamId = new AtomicInteger(0);
5455

5556
public enum OOCType {
56-
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing
57+
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin
5758
}
5859

5960
protected final OOCInstruction.OOCType _ooctype;
@@ -208,6 +209,8 @@ protected <T> CompletableFuture<Void> submitOOCTasks(final List<OOCStream<T>> qu
208209

209210
final AtomicInteger globalTaskCtr = new AtomicInteger(0);
210211
final CompletableFuture<Void> globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new));
212+
if (_outQueues == null)
213+
_outQueues = Collections.emptySet();
211214
final Runnable oocFinalizer = oocTask(finalizer, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new));
212215
final Object globalLock = new Object();
213216

@@ -278,7 +281,14 @@ protected <T> CompletableFuture<Void> submitOOCTasks(final List<OOCStream<T>> qu
278281

279282
globalFuture.whenComplete((res, e) -> {
280283
if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally())
281-
futures.forEach(f -> f.cancel(true));
284+
futures.forEach(f -> {
285+
if (!f.isDone()) {
286+
if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally())
287+
f.cancel(true);
288+
else
289+
f.complete(null);
290+
}
291+
});
282292

283293
boolean runFinalizer;
284294

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.commons.lang3.NotImplementedException;
23+
import org.apache.commons.lang3.mutable.MutableObject;
24+
import org.apache.sysds.common.Opcodes;
25+
import org.apache.sysds.common.Types;
26+
import org.apache.sysds.runtime.DMLRuntimeException;
27+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
28+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
29+
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
30+
import org.apache.sysds.runtime.functionobjects.ValueFunction;
31+
import org.apache.sysds.runtime.instructions.InstructionUtils;
32+
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
33+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
34+
import org.apache.sysds.runtime.instructions.cp.Data;
35+
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
36+
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
37+
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
38+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
39+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
40+
import org.apache.sysds.runtime.matrix.operators.Operator;
41+
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
42+
43+
import java.util.LinkedHashMap;
44+
import java.util.concurrent.CompletableFuture;
45+
import java.util.concurrent.ExecutionException;
46+
import java.util.concurrent.atomic.AtomicBoolean;
47+
48+
public class ParameterizedBuiltinOOCInstruction extends ComputationOOCInstruction {
49+
50+
protected final LinkedHashMap<String, String> params;
51+
52+
protected ParameterizedBuiltinOOCInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out,
53+
String opcode, String istr) {
54+
super(OOCInstruction.OOCType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
55+
params = paramsMap;
56+
}
57+
58+
public static ParameterizedBuiltinOOCInstruction parseInstruction(String str) {
59+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
60+
// first part is always the opcode
61+
String opcode = parts[0];
62+
// last part is always the output
63+
CPOperand out = new CPOperand(parts[parts.length - 1]);
64+
65+
// process remaining parts and build a hash map
66+
LinkedHashMap<String, String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
67+
68+
// determine the appropriate value function
69+
ValueFunction func = null;
70+
71+
if(opcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
72+
func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
73+
return new ParameterizedBuiltinOOCInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
74+
}
75+
else if(opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) {
76+
return new ParameterizedBuiltinOOCInstruction(null, paramsMap, out, opcode, str);
77+
}
78+
else
79+
throw new NotImplementedException(); // TODO
80+
}
81+
82+
@Override
83+
public void processInstruction(ExecutionContext ec) {
84+
if(instOpcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
85+
if(ec.isFrameObject(params.get("target"))){
86+
throw new NotImplementedException();
87+
} else{
88+
MatrixObject targetObj = ec.getMatrixObject(params.get("target"));
89+
OOCStream<IndexedMatrixValue> qIn = targetObj.getStreamHandle();
90+
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
91+
92+
double pattern = Double.parseDouble(params.get("pattern"));
93+
double replacement = Double.parseDouble(params.get("replacement"));
94+
95+
mapOOC(qIn, qOut, tmp -> new IndexedMatrixValue(tmp.getIndexes(), tmp.getValue().replaceOperations(new MatrixBlock(), pattern, replacement)));
96+
97+
ec.getMatrixObject(output).setStreamHandle(qOut);
98+
}
99+
}
100+
else if(instOpcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) {
101+
MatrixObject targetObj = ec.getMatrixObject(params.get("target"));
102+
OOCStream<IndexedMatrixValue> qIn = targetObj.getStreamHandle();
103+
Data pattern = ec.getVariable(params.get("pattern"));
104+
105+
if( pattern == null ) //literal
106+
pattern = ScalarObjectFactory.createScalarObject(Types.ValueType.FP64, params.get("pattern"));
107+
108+
if (!pattern.getDataType().isScalar())
109+
throw new NotImplementedException();
110+
111+
Data finalPattern = pattern;
112+
113+
AtomicBoolean found = new AtomicBoolean(false);
114+
115+
MutableObject<CompletableFuture<Void>> futureRef = new MutableObject<>();
116+
CompletableFuture<Void> future = submitOOCTasks(qIn, tmp -> {
117+
boolean contains = ((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue());
118+
119+
if (contains) {
120+
found.set(true);
121+
122+
// Now we may complete the future
123+
if (futureRef.getValue() != null)
124+
futureRef.getValue().complete(null);
125+
}
126+
}, () -> {});
127+
futureRef.setValue(future);
128+
129+
try {
130+
futureRef.getValue().get();
131+
} catch (InterruptedException | ExecutionException e) {
132+
throw new DMLRuntimeException(e);
133+
}
134+
135+
boolean ret = found.get();
136+
ec.setScalarOutput(output.getName(), new BooleanObject(ret));
137+
}
138+
}
139+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.runtime.instructions.Instruction;
25+
import org.apache.sysds.runtime.io.MatrixWriter;
26+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
27+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
28+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
29+
import org.apache.sysds.runtime.util.DataConverter;
30+
import org.apache.sysds.runtime.util.HDFSTool;
31+
import org.apache.sysds.test.AutomatedTestBase;
32+
import org.apache.sysds.test.TestConfiguration;
33+
import org.apache.sysds.test.TestUtils;
34+
import org.junit.Assert;
35+
import org.junit.Test;
36+
37+
import java.io.IOException;
38+
39+
public class ContainsTest extends AutomatedTestBase {
40+
private final static String TEST_NAME1 = "Contains";
41+
private final static String TEST_DIR = "functions/ooc/";
42+
private final static String TEST_CLASS_DIR = TEST_DIR + ContainsTest.class.getSimpleName() + "/";
43+
private final static double eps = 1e-8;
44+
private static final String INPUT_NAME_1 = "X";
45+
private static final String OUTPUT_NAME = "res";
46+
47+
private final static int rows = 1500;
48+
private final static int cols = 1200;
49+
private final static int maxVal = 2;
50+
private final static double sparsity1 = 1;
51+
private final static double sparsity2 = 0.05;
52+
53+
@Override
54+
public void setUp() {
55+
TestUtils.clearAssertionInformation();
56+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
57+
addTestConfiguration(TEST_NAME1, config);
58+
}
59+
60+
@Test
61+
public void testContainsDense() {
62+
runContainsTest(false);
63+
}
64+
65+
@Test
66+
public void testContainsSparse() {
67+
runContainsTest(true);
68+
}
69+
70+
private void runContainsTest(boolean sparse) {
71+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
72+
73+
try {
74+
getAndLoadTestConfiguration(TEST_NAME1);
75+
76+
String HOME = SCRIPT_DIR + TEST_DIR;
77+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
78+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)};
79+
80+
// 1. Generate the data in-memory as MatrixBlock objects
81+
double[][] X_data = getRandomMatrix(rows, cols, 0, maxVal, sparse ? sparsity2 : sparsity1, 7);
82+
X_data[rows-1][cols-1] = -1;
83+
84+
// 2. Convert the double arrays to MatrixBlock objects
85+
MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data);
86+
87+
// 3. Create a binary matrix writer
88+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
89+
90+
// 4. Write matrix A to a binary SequenceFile
91+
writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros());
92+
HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64,
93+
new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY);
94+
95+
runTest(true, false, null, -1);
96+
97+
//check replace OOC op
98+
Assert.assertTrue("OOC wasn't used for contains",
99+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CONTAINS));
100+
101+
//compare results
102+
103+
// rerun without ooc flag
104+
programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")};
105+
runTest(true, false, null, -1);
106+
107+
// compare matrices
108+
MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
109+
Types.FileFormat.BINARY, 1, 1, 1000);
110+
MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
111+
Types.FileFormat.BINARY, 1, 1, 1000);
112+
TestUtils.compareMatrices(ret1, ret2, eps);
113+
}
114+
catch(IOException e) {
115+
throw new RuntimeException(e);
116+
}
117+
finally {
118+
resetExecMode(platformOld);
119+
}
120+
}
121+
}

0 commit comments

Comments
 (0)