Skip to content

Commit ffe7fce

Browse files
committed
Implement union operations with HashSet<Double> and TreeSet<double[]> after benchmarking
1 parent d23d386 commit ffe7fce

File tree

4 files changed

+144
-58
lines changed

4 files changed

+144
-58
lines changed

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 99 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import java.util.Iterator;
3535
import java.util.List;
3636
import java.util.HashSet;
37+
import java.util.Set;
38+
import java.util.TreeSet;
3739
import java.util.concurrent.ExecutorService;
3840
import java.util.concurrent.Future;
3941
import java.util.stream.Collectors;
@@ -4931,57 +4933,113 @@ public MatrixBlock uaggouterchainOperations(MatrixBlock mbLeft, MatrixBlock mbR
49314933
}
49324934

49334935
public MatrixBlock unionOperations(MatrixBlock m1, MatrixBlock m2) {
4934-
HashSet<List<Double>> set = new HashSet<>();
4935-
boolean[] toAddArr = new boolean[m1.getNumRows() + m2.getNumRows()];
4936-
int id = 0;
4937-
for(int i = 0; i < m1.getNumRows(); i++) {
4938-
List<Double> row = new ArrayList<>();
4939-
for(int j = 0; j < m1.getNumColumns(); j++) {
4940-
row.add(m1.get(i, j));
4941-
}
4942-
if(!set.contains(row)) {
4943-
set.add(row);
4944-
toAddArr[id] = true;
4936+
if(m1.getNumColumns() == 1) {
4937+
HashSet<Double> set = new HashSet<>();
4938+
boolean[] toAddArr = new boolean[m1.getNumRows() + m2.getNumRows()];
4939+
int id = 0;
4940+
for(int i = 0; i < m1.getNumRows(); i++) {
4941+
Double val = m1.get(i, 0);
4942+
if(!set.contains(val)) {
4943+
set.add(val);
4944+
toAddArr[id] = true;
4945+
}
4946+
id++;
49454947
}
4946-
id++;
4947-
}
49484948

4949-
for(int i = 0; i < m2.getNumRows(); i++) {
4950-
List<Double> row = new ArrayList<>();
4951-
for(int j = 0; j < m2.getNumColumns(); j++) {
4952-
row.add(m2.get(i, j));
4949+
for(int i = 0; i < m2.getNumRows(); i++) {
4950+
Double val = m2.get(i, 0);
4951+
if(!set.contains(val)) {
4952+
set.add(val);
4953+
toAddArr[id] = true;
4954+
}
4955+
id++;
49534956
}
4954-
if(!set.contains(row)) {
4955-
set.add(row);
4956-
toAddArr[id] = true;
4957+
4958+
MatrixBlock mbOut = new MatrixBlock(set.size(), m1.getNumColumns(), false);
4959+
int rowOut = 0;
4960+
int rowId = 0;
4961+
for(boolean toAdd : toAddArr) {
4962+
if(toAdd) {
4963+
if(rowId < m1.getNumRows()) {
4964+
// is first matrix
4965+
mbOut.set(rowOut, 0, m1.get(rowId, 0));
4966+
}
4967+
else {
4968+
// is second matrix
4969+
int tempRowId = rowId - m1.getNumRows();
4970+
mbOut.set(rowOut, 0, m2.get(tempRowId, 0));
4971+
}
4972+
rowOut++;
4973+
}
4974+
rowId++;
49574975
}
4958-
id++;
4959-
}
4960-
4961-
MatrixBlock mbOut = new MatrixBlock(set.size(), m1.getNumColumns(), false);
4962-
int rowOut = 0;
4963-
int rowId = 0;
4964-
for(boolean toAdd : toAddArr) {
4965-
if(toAdd) {
4966-
if(rowId < m1.getNumRows()) {
4967-
// is first matrix
4968-
for(int i = 0; i < m1.getNumColumns(); i++) {
4969-
mbOut.set(rowOut, i, m1.get(rowId, i));
4976+
4977+
return mbOut;
4978+
}
4979+
else {
4980+
Set<double[]> set = new TreeSet<>((o1, o2) -> {
4981+
for(int i = 0; i < o1.length; i++) {
4982+
if(o1[i] < o2[i]) {
4983+
return -1;
4984+
}
4985+
else if(o1[i] > o2[i]) {
4986+
return 1;
49704987
}
49714988
}
4972-
else {
4973-
// is second matrix
4974-
int tempRowId = rowId - m1.getNumRows();
4975-
for(int i = 0; i < m2.getNumColumns(); i++) {
4976-
mbOut.set(rowOut, i, m2.get(tempRowId, i));
4989+
return 0;
4990+
});
4991+
boolean[] toAddArr = new boolean[m1.getNumRows() + m2.getNumRows()];
4992+
int id = 0;
4993+
for(int i = 0; i < m1.getNumRows(); i++) {
4994+
double[] row = new double[m1.getNumColumns()];
4995+
for(int j = 0; j < m1.getNumColumns(); j++) {
4996+
// row.add(m1.get(i, j));
4997+
row[j] = m1.get(i, j);
4998+
}
4999+
if(!set.contains(row)) {
5000+
set.add(row);
5001+
toAddArr[id] = true;
5002+
}
5003+
id++;
5004+
}
5005+
5006+
for(int i = 0; i < m2.getNumRows(); i++) {
5007+
double[] row = new double[m2.getNumColumns()];
5008+
for(int j = 0; j < m2.getNumColumns(); j++) {
5009+
row[j] = m2.get(i, j);
5010+
}
5011+
if(!set.contains(row)) {
5012+
set.add(row);
5013+
toAddArr[id] = true;
5014+
}
5015+
id++;
5016+
}
5017+
5018+
MatrixBlock mbOut = new MatrixBlock(set.size(), m1.getNumColumns(), false);
5019+
int rowOut = 0;
5020+
int rowId = 0;
5021+
for(boolean toAdd : toAddArr) {
5022+
if(toAdd) {
5023+
if(rowId < m1.getNumRows()) {
5024+
// is first matrix
5025+
for(int i = 0; i < m1.getNumColumns(); i++) {
5026+
mbOut.set(rowOut, i, m1.get(rowId, i));
5027+
}
5028+
}
5029+
else {
5030+
// is second matrix
5031+
int tempRowId = rowId - m1.getNumRows();
5032+
for(int i = 0; i < m2.getNumColumns(); i++) {
5033+
mbOut.set(rowOut, i, m2.get(tempRowId, i));
5034+
}
49775035
}
5036+
rowOut++;
49785037
}
4979-
rowOut++;
5038+
rowId++;
49805039
}
4981-
rowId++;
4982-
}
49835040

4984-
return mbOut;
5041+
return mbOut;
5042+
}
49855043
}
49865044

49875045

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnionDistinctTest.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,18 @@ public void setUp() {
4242
}
4343

4444
@Test
45-
public void testUnionDistinctNoRewrite() {
46-
testRewriteSimplifyUnionDistinct(2, false);
45+
public void testUnionDistinctRewriteOne() {
46+
testRewriteSimplifyUnionDistinct(1, true);
47+
}
48+
49+
@Test
50+
public void testUnionDistinctRewriteFifty() {
51+
testRewriteSimplifyUnionDistinct(50, true);
52+
}
53+
54+
@Test
55+
public void testUnionDistinctRewriteOneThousand() {
56+
testRewriteSimplifyUnionDistinct(1000, true);
4757
}
4858

4959
@Test
@@ -59,9 +69,11 @@ private void testRewriteSimplifyUnionDistinct(int ID, boolean rewrites) {
5969

6070
String HOME = SCRIPT_DIR + TEST_DIR;
6171
fullDMLScriptName = HOME + TEST_NAME + ".dml";
62-
programArgs = new String[] {"-explain", "-stats", "-args", String.valueOf(ID), output("R")};
72+
int rowNum = (int) (Math.random() * 1000);
73+
programArgs = new String[] {"-explain", "-stats", "-args", String.valueOf(ID), String.valueOf(rowNum),
74+
output("R")};
6375
fullRScriptName = HOME + TEST_NAME + ".R";
64-
rCmd = getRCmd(String.valueOf(ID), expectedDir());
76+
rCmd = getRCmd(String.valueOf(ID), String.valueOf(rowNum), expectedDir());
6577

6678
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
6779

src/test/scripts/functions/rewrite/RewriteSimplifyUnion.R

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,21 @@ library("Matrix")
3030
library("matrixStats")
3131

3232
# Read matrices
33-
X = seq(1, 7)
34-
Y = seq(4, 10)
33+
colNum = as.integer(args[1])
34+
rowNum = as.integer(args[2])
35+
X = matrix(rep(1, colNum), nrow=1, ncol=colNum)
36+
Y = matrix(rep(1 + floor(rowNum / 2), colNum), nrow=1, ncol=colNum)
37+
38+
if(rowNum != 1) {
39+
for(i in 2 : rowNum - 1) {
40+
X = rbind(X, rep(i + 1, colNum))
41+
Y = rbind(Y, rep(i + 1 + floor(rowNum / 2), colNum))
42+
}
43+
}
3544

3645
# Perform operations
37-
combined = c(rbind(X,Y));
46+
combined = rbind(X,Y);
3847
R = unique(combined);
39-
R = sort(R);
4048

4149
#Write result matrix R
42-
writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
50+
writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""))

src/test/scripts/functions/rewrite/RewriteSimplifyUnion.dml

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,21 @@
2121

2222

2323

24-
# type = $1
25-
A = matrix("1 2 3 4 5 6 7", rows=7, cols=1)
26-
B = matrix("4 5 6 7 8 9 10", rows=7, cols=1)
27-
C = rbind(A,B)
28-
D = unique(C)
29-
D = order(target=D)
30-
# print(D)
24+
colNum = $1
25+
rowNum = $2
26+
27+
X = matrix(1, rows=rowNum, cols=colNum)
28+
Y = matrix(1, rows=rowNum, cols=colNum)
29+
for (i in 1 : rowNum) {
30+
for (j in 1 : colNum) {
31+
X[i, j] = i
32+
Y[i, j] = i + floor(rowNum/2)
33+
}
34+
}
35+
36+
C = rbind(X,Y)
37+
R = unique(C)
38+
R = order(target=R)
3139

3240
# Write the result matrix R
33-
write(D, $2)
41+
write(R, $3)

0 commit comments

Comments
 (0)