Skip to content

Commit f05ccdc

Browse files
committed
[SYSTEMDS-3253] Add combined rewrite and lop instruction for union
This patch replaces the current union implementation with an internal LOP operation. Currently, two consecutive operations, rbind() and unique(), are used to perform the union. We rewrite this pattern into a single internal LOP that uses a HashSet to compute the unique entries and returns them as a matrix. This improves the efficiency of the operation, as it avoids the separate unique() call. The order of the input entries is preserved in the output.
1 parent bd278ae commit f05ccdc

File tree

12 files changed

+281
-47
lines changed

12 files changed

+281
-47
lines changed

NAR

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
%%MatrixMarket matrix coordinate integer general
2+
10 1 10
3+
1 1 1
4+
2 1 4
5+
3 1 2
6+
4 1 5
7+
5 1 3
8+
6 1 6
9+
7 1 7
10+
8 1 8
11+
9 1 9
12+
10 1 10

src/main/java/org/apache/sysds/common/InstructionType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public enum InstructionType {
6161
MMTSJ,
6262
PMMJ,
6363
MMChain,
64+
Union,
6465

6566
//SP Types
6667
MAPMM,

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ public enum Opcodes {
9292
MULT2("*2", InstructionType.Binary), //special * case
9393
MINUS_NZ("-nz", InstructionType.Binary), //special - case
9494

95+
UNION_DISTINCT("union_distinct", InstructionType.Union),
96+
9597
// Boolean Instruction Opcodes
9698
AND("&&", InstructionType.Binary),
9799
OR("||", InstructionType.Binary),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,8 @@ public enum OpOp2 {
635635
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
636636
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
637637
MINUS1_MULT(false), //1-X*Y
638-
QUANTIZE_COMPRESS(false); //quantization-fused compression
638+
QUANTIZE_COMPRESS(false), //quantization-fused compression
639+
UNION_DISTINCT(true);
639640

640641
private final boolean _validOuter;
641642

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
182182
hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
183183
hi = simplifyScalarIndexing(hop, hi, i); //e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
184184
hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq;
185+
hi = simplifyUnionDistinct(hop, hi, i); //e.g., unique(rbind(A, B)) -> union_distinct(A, B);
185186
hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq;
186187
hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12))
187188
hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
@@ -1802,7 +1803,26 @@ private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
18021803
}
18031804
}
18041805
}
1806+
1807+
return hi;
1808+
}
1809+
18051810

1811+
private static Hop simplifyUnionDistinct(Hop parent, Hop hi, int pos) {
1812+
// pattern: unique(rbind(A, B)) -> union_distinct(A, B)
1813+
if(HopRewriteUtils.isAggUnaryOp(hi, AggOp.UNIQUE) && HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.RBIND)) {
1814+
Hop rbindAB = hi.getInput(0);
1815+
List<Hop> rbindABParents = rbindAB.getParent();
1816+
if(rbindABParents.size() == 1) {
1817+
// make sure that rbind is only used here
1818+
Hop A = rbindAB.getInput(0);
1819+
Hop B = rbindAB.getInput(1);
1820+
Hop unionDistinct = HopRewriteUtils.createBinary(A, B, OpOp2.UNION_DISTINCT);
1821+
HopRewriteUtils.replaceChildReference(parent, hi, unionDistinct, pos);
1822+
HopRewriteUtils.cleanupUnreferenced(hi, rbindAB);
1823+
LOG.debug("Applied simplifyUnionDistinct rewrite");
1824+
}
1825+
}
18061826
return hi;
18071827
}
18081828

src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java

Lines changed: 4 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,45 +25,7 @@
2525
import org.apache.sysds.common.Opcodes;
2626
import org.apache.sysds.common.Types.ExecType;
2727
import org.apache.sysds.runtime.DMLRuntimeException;
28-
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
29-
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
30-
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
31-
import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
32-
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
33-
import org.apache.sysds.runtime.instructions.cp.BroadcastCPInstruction;
34-
import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
35-
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
36-
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
37-
import org.apache.sysds.runtime.instructions.cp.CompressionCPInstruction;
38-
import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
39-
import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
40-
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
41-
import org.apache.sysds.runtime.instructions.cp.DeCompressionCPInstruction;
42-
import org.apache.sysds.runtime.instructions.cp.DnnCPInstruction;
43-
import org.apache.sysds.runtime.instructions.cp.EvictCPInstruction;
44-
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
45-
import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
46-
import org.apache.sysds.runtime.instructions.cp.LocalCPInstruction;
47-
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
48-
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
49-
import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
50-
import org.apache.sysds.runtime.instructions.cp.MultiReturnComplexMatrixBuiltinCPInstruction;
51-
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
52-
import org.apache.sysds.runtime.instructions.cp.PMMJCPInstruction;
53-
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
54-
import org.apache.sysds.runtime.instructions.cp.PrefetchCPInstruction;
55-
import org.apache.sysds.runtime.instructions.cp.QuantilePickCPInstruction;
56-
import org.apache.sysds.runtime.instructions.cp.QuantileSortCPInstruction;
57-
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
58-
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
59-
import org.apache.sysds.runtime.instructions.cp.ReshapeCPInstruction;
60-
import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
61-
import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
62-
import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
63-
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
64-
import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
65-
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
66-
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
28+
import org.apache.sysds.runtime.instructions.cp.*;
6729
import org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction;
6830

6931
public class CPInstructionParser extends InstructionParser {
@@ -218,6 +180,9 @@ public static CPInstruction parseSingleInstruction ( InstructionType cptype, Str
218180

219181
case EvictLineageCache:
220182
return EvictCPInstruction.parseInstruction(str);
183+
184+
case Union:
185+
return UnionCPInstruction.parseInstruction(str);
221186

222187
default:
223188
throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype );

src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public enum CPType {
4646
StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote,
4747
EvictLineageCache,
4848
NoOp,
49+
Union,
4950
QuantizeCompression
5051
}
5152

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package org.apache.sysds.runtime.instructions.cp;
2+
3+
import org.apache.sysds.common.Opcodes;
4+
import org.apache.sysds.runtime.DMLRuntimeException;
5+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
6+
import org.apache.sysds.runtime.instructions.InstructionUtils;
7+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
8+
import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
9+
import org.apache.sysds.runtime.matrix.operators.Operator;
10+
11+
public class UnionCPInstruction extends BinaryCPInstruction {
12+
13+
private UnionCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
14+
super(CPType.Union, op, in1, in2, out, opcode, istr);
15+
}
16+
17+
public static UnionCPInstruction parseInstruction(String str) {
18+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
19+
String opcode = parts[0];
20+
21+
if(!opcode.equalsIgnoreCase(Opcodes.UNION_DISTINCT.toString()))
22+
throw new DMLRuntimeException("Invalid opcode for UNION_DISTINCT: " + opcode);
23+
24+
CPOperand in1 = new CPOperand(parts[1]);
25+
CPOperand in2 = new CPOperand(parts[2]);
26+
CPOperand out = new CPOperand(parts[parts.length - 2]);
27+
MultiThreadedOperator operator = new MultiThreadedOperator();
28+
return new UnionCPInstruction(operator, in1, in2, out, opcode, str);
29+
}
30+
31+
@Override
32+
public void processInstruction(ExecutionContext ec) {
33+
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
34+
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
35+
MatrixBlock out = matBlock1.unionOperations(matBlock1, matBlock2);
36+
ec.releaseMatrixInput(input1.getName());
37+
ec.releaseMatrixInput(input2.getName());
38+
ec.setMatrixOutput(output.getName(), out);
39+
System.out.println("Processing UNION CP instruction");
40+
}
41+
42+
}

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@
2727
import java.io.ObjectInputStream;
2828
import java.io.ObjectOutput;
2929
import java.io.ObjectOutputStream;
30-
import java.util.ArrayList;
31-
import java.util.Arrays;
32-
import java.util.Collections;
33-
import java.util.HashMap;
34-
import java.util.Iterator;
35-
import java.util.List;
30+
import java.util.*;
3631
import java.util.concurrent.ExecutorService;
3732
import java.util.concurrent.Future;
3833
import java.util.stream.Collectors;
@@ -4925,7 +4920,61 @@ public MatrixBlock uaggouterchainOperations(MatrixBlock mbLeft, MatrixBlock mbR
49254920
LibMatrixOuterAgg.aggregateMatrix(mbLeft, mbOut, bv, bvi, bOp, uaggOp);
49264921
} else
49274922
throw new DMLRuntimeException("Unsupported operator for unary aggregate operations.");
4928-
4923+
4924+
return mbOut;
4925+
}
4926+
4927+
public MatrixBlock unionOperations(MatrixBlock m1, MatrixBlock m2) {
4928+
HashSet<List<Double>> set = new HashSet<>();
4929+
boolean[] toAddArr = new boolean[m1.getNumRows() + m2.getNumRows()];
4930+
int id = 0;
4931+
for(int i = 0; i < m1.getNumRows(); i++) {
4932+
List<Double> row = new ArrayList<>();
4933+
for(int j = 0; j < m1.getNumColumns(); j++) {
4934+
row.add(m1.get(i, j));
4935+
}
4936+
if(!set.contains(row)) {
4937+
set.add(row);
4938+
toAddArr[id] = true;
4939+
}
4940+
id++;
4941+
}
4942+
4943+
for(int i = 0; i < m2.getNumRows(); i++) {
4944+
List<Double> row = new ArrayList<>();
4945+
for(int j = 0; j < m2.getNumColumns(); j++) {
4946+
row.add(m2.get(i, j));
4947+
}
4948+
if(!set.contains(row)) {
4949+
set.add(row);
4950+
toAddArr[id] = true;
4951+
}
4952+
id++;
4953+
}
4954+
4955+
MatrixBlock mbOut = new MatrixBlock(set.size(), m1.getNumColumns(), false);
4956+
int rowOut = 0;
4957+
int rowId = 0;
4958+
for(boolean toAdd : toAddArr) {
4959+
if(toAdd) {
4960+
if(rowId < m1.getNumRows()) {
4961+
// is first matrix
4962+
for(int i = 0; i < m1.getNumColumns(); i++) {
4963+
mbOut.set(rowOut, i, m1.get(rowId, i));
4964+
}
4965+
}
4966+
else {
4967+
// is second matrix
4968+
int tempRowId = rowId - m1.getNumRows();
4969+
for(int i = 0; i < m2.getNumColumns(); i++) {
4970+
mbOut.set(rowOut, i, m2.get(tempRowId, i));
4971+
}
4972+
}
4973+
rowOut++;
4974+
}
4975+
rowId++;
4976+
}
4977+
49294978
return mbOut;
49304979
}
49314980

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package org.apache.sysds.test.functions.rewrite;
2+
3+
import org.apache.sysds.hops.OptimizerUtils;
4+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
5+
import org.apache.sysds.test.AutomatedTestBase;
6+
import org.apache.sysds.test.TestConfiguration;
7+
import org.apache.sysds.test.TestUtils;
8+
import org.junit.Test;
9+
10+
import java.util.HashMap;
11+
12+
public class RewriteSimplifyUnionDistinctTest extends AutomatedTestBase {
13+
private static final String TEST_NAME = "rewriteSimplifyUnionDistinct";
14+
private static final String TEST_DIR = "functions/rewrite/";
15+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyUnionDistinctTest.class.getSimpleName()
16+
+ "/";
17+
private static final double eps = Math.pow(10, -10);
18+
19+
@Override
20+
public void setUp() {
21+
TestUtils.clearAssertionInformation();
22+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
23+
}
24+
25+
@Test
26+
public void testPrint() {
27+
System.out.println("Test rewriteSimplifyUnionDistinct");
28+
}
29+
30+
@Test
31+
public void testUnionDistinctNoRewrite() {
32+
testRewriteSimplifyUnionDistinct(2, false);
33+
}
34+
35+
@Test
36+
public void testUnionDistinctRewrite() {
37+
testRewriteSimplifyUnionDistinct(2, true);
38+
}
39+
40+
private void testRewriteSimplifyUnionDistinct(int ID, boolean rewrites) {
41+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
42+
try {
43+
TestConfiguration config = getTestConfiguration(TEST_NAME);
44+
loadTestConfiguration(config);
45+
46+
String HOME = SCRIPT_DIR + TEST_DIR;
47+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
48+
programArgs = new String[] {"-explain", "-stats", "-args", String.valueOf(ID), output("R")};
49+
fullRScriptName = HOME + TEST_NAME + ".R";
50+
rCmd = getRCmd(String.valueOf(ID), expectedDir());
51+
52+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
53+
54+
runTest(true, false, null, -1);
55+
runRScript(true);
56+
57+
// compare matrices
58+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
59+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
60+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
61+
}
62+
finally {
63+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)