Skip to content

Commit 5aa0307

Browse files
chihsinhmboehm7
authored and committed
[SYSTEMDS-3253] New native operator for union distinct
This patch replaces the current union implementation with a single internal LOP operation. Currently, two consecutive operations -- rbind() followed by unique() -- are used to perform the union. We rewrite this pattern into an internal LOP that uses a HashSet to compute the unique entries and returns them as a matrix. This improves the efficiency of the operation because the separate unique() call is avoided. The order of the input entries is preserved in the output. Closes #2286.
1 parent 7440ef7 commit 5aa0307

File tree

11 files changed

+360
-3
lines changed

11 files changed

+360
-3
lines changed

src/main/java/org/apache/sysds/common/InstructionType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public enum InstructionType {
6161
MMTSJ,
6262
PMMJ,
6363
MMChain,
64+
Union,
6465

6566
//SP Types
6667
MAPMM,

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ public enum Opcodes {
9090
MULT2("*2", InstructionType.Binary), //special * case
9191
MINUS_NZ("-nz", InstructionType.Binary), //special - case
9292

93+
UNION_DISTINCT("union_distinct", InstructionType.Union),
94+
9395
// Boolean Instruction Opcodes
9496
AND("&&", InstructionType.Binary),
9597
OR("||", InstructionType.Binary),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,8 @@ public enum OpOp2 {
637637
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
638638
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
639639
MINUS1_MULT(false), //1-X*Y
640-
QUANTIZE_COMPRESS(false); //quantization-fused compression
640+
QUANTIZE_COMPRESS(false), //quantization-fused compression
641+
UNION_DISTINCT(false);
641642

642643
private final boolean _validOuter;
643644

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
182182
hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
183183
hi = simplifyScalarIndexing(hop, hi, i); //e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
184184
hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq;
185+
hi = simplifyUnionDistinct(hop, hi, i); //e.g., unique(rbind(A, B)) -> union_distinct(A, B);
185186
hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq;
186187
hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12))
187188
hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
@@ -1837,7 +1838,26 @@ private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
18371838
}
18381839
}
18391840
}
1841+
1842+
return hi;
1843+
}
1844+
18401845

1846+
private static Hop simplifyUnionDistinct(Hop parent, Hop hi, int pos) {
1847+
// pattern: unique(rbind(A, B)) -> union_distinct(A, B)
1848+
if(HopRewriteUtils.isAggUnaryOp(hi, AggOp.UNIQUE)
1849+
&& HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.RBIND)) {
1850+
Hop rbindAB = hi.getInput(0);
1851+
if(rbindAB.getParent().size() == 1) {
1852+
// make sure that rbind is only used here
1853+
Hop A = rbindAB.getInput(0);
1854+
Hop B = rbindAB.getInput(1);
1855+
Hop unionDistinct = HopRewriteUtils.createBinary(A, B, OpOp2.UNION_DISTINCT);
1856+
HopRewriteUtils.replaceChildReference(parent, hi, unionDistinct, pos);
1857+
HopRewriteUtils.cleanupUnreferenced(hi, rbindAB);
1858+
LOG.debug("Applied simplifyUnionDistinct");
1859+
}
1860+
}
18411861
return hi;
18421862
}
18431863

src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
6565
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
6666
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
67+
import org.apache.sysds.runtime.instructions.cp.UnionCPInstruction;
6768
import org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction;
6869

6970
public class CPInstructionParser extends InstructionParser {
@@ -218,6 +219,9 @@ public static CPInstruction parseSingleInstruction ( InstructionType cptype, Str
218219

219220
case EvictLineageCache:
220221
return EvictCPInstruction.parseInstruction(str);
222+
223+
case Union:
224+
return UnionCPInstruction.parseInstruction(str);
221225

222226
default:
223227
throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype );

src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public enum CPType {
4646
StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote,
4747
EvictLineageCache,
4848
NoOp,
49+
Union,
4950
QuantizeCompression
5051
}
5152

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.cp;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.runtime.DMLRuntimeException;
24+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
25+
import org.apache.sysds.runtime.instructions.InstructionUtils;
26+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
27+
import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
28+
import org.apache.sysds.runtime.matrix.operators.Operator;
29+
30+
public class UnionCPInstruction extends BinaryCPInstruction {
31+
32+
private UnionCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
33+
super(CPType.Union, op, in1, in2, out, opcode, istr);
34+
}
35+
36+
public static UnionCPInstruction parseInstruction(String str) {
37+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
38+
String opcode = parts[0];
39+
40+
if(!opcode.equalsIgnoreCase(Opcodes.UNION_DISTINCT.toString()))
41+
throw new DMLRuntimeException("Invalid opcode for UNION_DISTINCT: " + opcode);
42+
43+
CPOperand in1 = new CPOperand(parts[1]);
44+
CPOperand in2 = new CPOperand(parts[2]);
45+
CPOperand out = new CPOperand(parts[parts.length - 2]);
46+
MultiThreadedOperator operator = new MultiThreadedOperator();
47+
return new UnionCPInstruction(operator, in1, in2, out, opcode, str);
48+
}
49+
50+
@Override
51+
public void processInstruction(ExecutionContext ec) {
52+
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
53+
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
54+
MatrixBlock out = matBlock1.unionOperations(matBlock1, matBlock2);
55+
ec.releaseMatrixInput(input1.getName());
56+
ec.releaseMatrixInput(input2.getName());
57+
ec.setMatrixOutput(output.getName(), out);
58+
}
59+
60+
}

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
import java.util.HashMap;
3434
import java.util.Iterator;
3535
import java.util.List;
36+
import java.util.HashSet;
37+
import java.util.Set;
38+
import java.util.TreeSet;
3639
import java.util.concurrent.ExecutorService;
3740
import java.util.concurrent.Future;
3841
import java.util.stream.Collectors;
@@ -4925,11 +4928,93 @@ public MatrixBlock uaggouterchainOperations(MatrixBlock mbLeft, MatrixBlock mbR
49254928
LibMatrixOuterAgg.aggregateMatrix(mbLeft, mbOut, bv, bvi, bOp, uaggOp);
49264929
} else
49274930
throw new DMLRuntimeException("Unsupported operator for unary aggregate operations.");
4928-
4931+
49294932
return mbOut;
49304933
}
4934+
4935+
public MatrixBlock unionOperations(MatrixBlock m1, MatrixBlock m2) {
4936+
if(m1.getNumColumns() == 1) {
4937+
HashSet<Double> set = new HashSet<>();
4938+
boolean[] toAddArr = new boolean[m1.getNumRows() + m2.getNumRows()];
4939+
int id = 0;
4940+
for(MatrixBlock m : new MatrixBlock[] {m1,m2}) {
4941+
for(int i = 0; i < m.getNumRows(); i++) {
4942+
Double val = m.get(i, 0);
4943+
if(!set.contains(val)) {
4944+
set.add(val);
4945+
toAddArr[id] = true;
4946+
}
4947+
id++;
4948+
}
4949+
}
4950+
4951+
MatrixBlock mbOut = new MatrixBlock(set.size(), m1.getNumColumns(), false);
4952+
int rowOut = 0;
4953+
int rowId = 0;
4954+
for(boolean toAdd : toAddArr) {
4955+
if(toAdd) {
4956+
if(rowId < m1.getNumRows()) { // is first matrix
4957+
mbOut.set(rowOut, 0, m1.get(rowId, 0));
4958+
}
4959+
else { // is second matrix
4960+
int tempRowId = rowId - m1.getNumRows();
4961+
mbOut.set(rowOut, 0, m2.get(tempRowId, 0));
4962+
}
4963+
rowOut++;
4964+
}
4965+
rowId++;
4966+
}
4967+
4968+
return mbOut;
4969+
}
4970+
else {
4971+
Set<double[]> set = new TreeSet<>((o1, o2) -> {
4972+
return Arrays.compare(o1, o2);
4973+
});
4974+
boolean[] toAddArr = new boolean[m1.getNumRows() + m2.getNumRows()];
4975+
int id = 0;
4976+
4977+
//TODO perf dense zero-copy and sparse
4978+
for(MatrixBlock m : new MatrixBlock[] {m1,m2}) {
4979+
for(int i = 0; i < m.getNumRows(); i++) {
4980+
double[] row = new double[m.getNumColumns()];
4981+
for(int j = 0; j < m.getNumColumns(); j++)
4982+
row[j] = m.get(i, j);
4983+
if(!set.contains(row)) {
4984+
set.add(row);
4985+
toAddArr[id] = true;
4986+
}
4987+
id++;
4988+
}
4989+
}
4990+
4991+
MatrixBlock mbOut = new MatrixBlock(set.size(), m1.getNumColumns(), false);
4992+
int rowOut = 0;
4993+
int rowId = 0;
4994+
for(boolean toAdd : toAddArr) {
4995+
if(toAdd) {
4996+
if(rowId < m1.getNumRows()) {
4997+
// is first matrix
4998+
for(int i = 0; i < m1.getNumColumns(); i++) {
4999+
mbOut.set(rowOut, i, m1.get(rowId, i));
5000+
}
5001+
}
5002+
else {
5003+
// is second matrix
5004+
int tempRowId = rowId - m1.getNumRows();
5005+
for(int i = 0; i < m2.getNumColumns(); i++) {
5006+
mbOut.set(rowOut, i, m2.get(tempRowId, i));
5007+
}
5008+
}
5009+
rowOut++;
5010+
}
5011+
rowId++;
5012+
}
5013+
5014+
return mbOut;
5015+
}
5016+
}
49315017

4932-
49335018
/**
49345019
* Invocation from CP instructions. The aggregate is computed on the groups object
49355020
* against target and weights.
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.apache.sysds.hops.OptimizerUtils;
23+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
24+
import org.apache.sysds.test.AutomatedTestBase;
25+
import org.apache.sysds.test.TestConfiguration;
26+
import org.apache.sysds.test.TestUtils;
27+
import org.junit.Test;
28+
29+
import java.util.HashMap;
30+
31+
public class RewriteSimplifyUnionDistinctTest extends AutomatedTestBase {
32+
private static final String TEST_NAME = "RewriteSimplifyUnion";
33+
private static final String TEST_DIR = "functions/rewrite/";
34+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyUnionDistinctTest.class.getSimpleName()
35+
+ "/";
36+
private static final double eps = Math.pow(10, -10);
37+
38+
@Override
39+
public void setUp() {
40+
TestUtils.clearAssertionInformation();
41+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
42+
}
43+
44+
@Test
45+
public void testUnionDistinctRewriteOne() {
46+
testRewriteSimplifyUnionDistinct(1, true);
47+
}
48+
49+
@Test
50+
public void testUnionDistinctRewriteFifty() {
51+
testRewriteSimplifyUnionDistinct(50, true);
52+
}
53+
54+
@Test
55+
public void testUnionDistinctRewriteOneThousand() {
56+
testRewriteSimplifyUnionDistinct(1000, true);
57+
}
58+
59+
@Test
60+
public void testUnionDistinctRewrite() {
61+
testRewriteSimplifyUnionDistinct(2, true);
62+
}
63+
64+
private void testRewriteSimplifyUnionDistinct(int ID, boolean rewrites) {
65+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
66+
try {
67+
TestConfiguration config = getTestConfiguration(TEST_NAME);
68+
loadTestConfiguration(config);
69+
70+
String HOME = SCRIPT_DIR + TEST_DIR;
71+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
72+
int rowNum = (int) (Math.random() * 1000);
73+
programArgs = new String[] {"-explain", "-stats", "-args", String.valueOf(ID), String.valueOf(rowNum),
74+
output("R")};
75+
fullRScriptName = HOME + TEST_NAME + ".R";
76+
rCmd = getRCmd(String.valueOf(ID), String.valueOf(rowNum), expectedDir());
77+
78+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
79+
80+
runTest(true, false, null, -1);
81+
runRScript(true);
82+
83+
// compare matrices
84+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
85+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
86+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
87+
}
88+
finally {
89+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
90+
}
91+
}
92+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
23+
args <- commandArgs(TRUE)

# Print numerics with full precision
options(digits=22)

# Load required libraries
library("Matrix")
library("matrixStats")

# Script arguments: args[1] = number of columns, args[2] = number of rows,
# args[3] = output directory prefix
colNum <- as.integer(args[1])
rowNum <- as.integer(args[2])

# Y holds the same kind of rows as X, shifted by half the row count
shift <- floor(rowNum / 2)
X <- matrix(rep(1, colNum), nrow=1, ncol=colNum)
Y <- matrix(rep(1 + shift, colNum), nrow=1, ncol=colNum)

if(rowNum != 1) {
  # ':' binds tighter than binary '-', so the range is (2:rowNum) - 1
  for(i in (2:rowNum) - 1) {
    X <- rbind(X, rep(i + 1, colNum))
    Y <- rbind(Y, rep(i + 1 + shift, colNum))
  }
}

# Reference result: distinct rows of the row-wise concatenation
combined <- rbind(X, Y)
R <- unique(combined)

# Write result matrix R
writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""))

0 commit comments

Comments
 (0)