Skip to content

Commit 0263279

Browse files
jessicapriebemboehm7
authored andcommitted
[SYSTEMDS-3895] Add OOC row and column aggregations with tests
Closes #2309.
1 parent c9a54fe commit 0263279

File tree

5 files changed

+384
-20
lines changed

5 files changed

+384
-20
lines changed

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java

Lines changed: 110 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,15 @@
3030
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
3131
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
3232
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
33+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
3334
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
3435
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
3536
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
37+
import org.apache.sysds.runtime.meta.DataCharacteristics;
38+
import org.apache.sysds.runtime.util.CommonThreadPool;
3639

40+
import java.util.HashMap;
41+
import java.util.concurrent.ExecutorService;
3742

3843
public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
3944
private AggregateOperator _aop = null;
@@ -61,34 +66,119 @@ public static AggregateUnaryOOCInstruction parseInstruction(String str) {
6166

6267
@Override
6368
public void processInstruction( ExecutionContext ec ) {
64-
//TODO support all types of aggregations, currently only full aggregation
69+
//TODO support all types of aggregations, currently only full aggregation, row aggregation and column aggregation
6570

6671
//setup operators and input queue
6772
AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator();
6873
MatrixObject min = ec.getMatrixObject(input1);
6974
LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
70-
IndexedMatrixValue tmp = null;
7175
int blen = ConfigurationManager.getBlocksize();
72-
73-
//read blocks and aggregate immediately into result
74-
int extra = _aop.correction.getNumRemovedRowsColumns();
75-
MatrixBlock ret = new MatrixBlock(1,1+extra,false);
76-
MatrixBlock corr = new MatrixBlock(1,1+extra,false);
77-
try {
78-
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
79-
//block aggregation
80-
MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
81-
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
82-
//accumulation into final result
83-
OperationsOnMatrixValues.incrementalAggregation(
84-
ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true);
76+
77+
if (aggun.isRowAggregate() || aggun.isColAggregate()) {
78+
// intermediate state per aggregation index
79+
HashMap<Long, MatrixBlock> aggs = new HashMap<>(); // partial aggregates
80+
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks
81+
HashMap<Long, Integer> cnt = new HashMap<>(); // processed block count per agg idx
82+
83+
DataCharacteristics chars = ec.getDataCharacteristics(input1.getName());
84+
// number of blocks to process per aggregation idx (row or column dim)
85+
long nBlocks = aggun.isRowAggregate()? chars.getNumColBlocks() : chars.getNumRowBlocks();
86+
87+
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
88+
ec.getMatrixObject(output).setStreamHandle(qOut);
89+
ExecutorService pool = CommonThreadPool.get();
90+
try {
91+
pool.submit(() -> {
92+
IndexedMatrixValue tmp = null;
93+
try {
94+
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
95+
long idx = aggun.isRowAggregate() ?
96+
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
97+
if(aggs.containsKey(idx)) {
98+
// update existing partial aggregate for this idx
99+
MatrixBlock ret = aggs.get(idx);
100+
MatrixBlock corr = corrs.get(idx);
101+
102+
// aggregation
103+
MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
104+
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
105+
OperationsOnMatrixValues.incrementalAggregation(ret,
106+
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
107+
108+
aggs.replace(idx, ret);
109+
corrs.replace(idx, corr);
110+
cnt.replace(idx, cnt.get(idx) + 1);
111+
}
112+
else {
113+
// first block for this idx - init aggregate and correction
114+
// TODO avoid corr block for inplace incremental aggregation
115+
int rows = tmp.getValue().getNumRows();
116+
int cols = tmp.getValue().getNumColumns();
117+
int extra = _aop.correction.getNumRemovedRowsColumns();
118+
MatrixBlock ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);
119+
MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);
120+
121+
// aggregation
122+
MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()).aggregateUnaryOperations(
123+
aggun, new MatrixBlock(), blen, tmp.getIndexes());
124+
OperationsOnMatrixValues.incrementalAggregation(ret,
125+
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
126+
127+
aggs.put(idx, ret);
128+
corrs.put(idx, corr);
129+
cnt.put(idx, 1);
130+
}
131+
132+
if(cnt.get(idx) == nBlocks) {
133+
// all input blocks for this idx processed - emit aggregated block
134+
MatrixBlock ret = aggs.get(idx);
135+
// drop correction row/col
136+
ret.dropLastRowsOrColumns(_aop.correction);
137+
MatrixIndexes midx = aggun.isRowAggregate()? new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
138+
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
139+
140+
qOut.enqueueTask(tmpOut);
141+
// drop intermediate states
142+
aggs.remove(idx);
143+
corrs.remove(idx);
144+
cnt.remove(idx);
145+
}
146+
}
147+
qOut.closeInput();
148+
}
149+
catch(Exception ex) {
150+
throw new DMLRuntimeException(ex);
151+
}
152+
});
153+
} catch (Exception ex) {
154+
throw new DMLRuntimeException(ex);
155+
} finally {
156+
pool.shutdown();
85157
}
86158
}
87-
catch(Exception ex) {
88-
throw new DMLRuntimeException(ex);
159+
// full aggregation
160+
else {
161+
IndexedMatrixValue tmp = null;
162+
//read blocks and aggregate immediately into result
163+
int extra = _aop.correction.getNumRemovedRowsColumns();
164+
MatrixBlock ret = new MatrixBlock(1,1+extra,false);
165+
MatrixBlock corr = new MatrixBlock(1,1+extra,false);
166+
try {
167+
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
168+
//block aggregation
169+
MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
170+
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
171+
//accumulation into final result
172+
OperationsOnMatrixValues.incrementalAggregation(
173+
ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true);
174+
}
175+
}
176+
catch(Exception ex) {
177+
throw new DMLRuntimeException(ex);
178+
}
179+
180+
//create scalar output
181+
ec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0)));
89182
}
90-
91-
//create scalar output
92-
ec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0)));
93183
}
94184
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.hops.OptimizerUtils;
25+
import org.apache.sysds.runtime.instructions.Instruction;
26+
import org.apache.sysds.runtime.io.MatrixWriter;
27+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
28+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
29+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
30+
import org.apache.sysds.runtime.util.DataConverter;
31+
import org.apache.sysds.runtime.util.HDFSTool;
32+
import org.apache.sysds.test.AutomatedTestBase;
33+
import org.apache.sysds.test.TestConfiguration;
34+
import org.apache.sysds.test.TestUtils;
35+
import org.junit.Assert;
36+
import org.junit.Test;
37+
38+
public class ColAggregationTest extends AutomatedTestBase{
39+
private static final String TEST_NAME = "ColAggregationTest";
40+
private static final String TEST_DIR = "functions/ooc/";
41+
private static final String TEST_CLASS_DIR = TEST_DIR + ColAggregationTest.class.getSimpleName() + "/";
42+
private static final String INPUT_NAME = "X";
43+
private static final String OUTPUT_NAME = "res";
44+
45+
@Override
46+
public void setUp() {
47+
TestUtils.clearAssertionInformation();
48+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
49+
addTestConfiguration(TEST_NAME, config);
50+
}
51+
52+
@Test
53+
public void testColAggregationNoRewrite() {
54+
testColAggregation(false);
55+
}
56+
57+
/**
58+
* Test the col aggregation, "colSums(X)", with OOC backend.
59+
*/
60+
@Test
61+
public void testColAggregationRewrite() {
62+
testColAggregation(true);
63+
}
64+
65+
public void testColAggregation(boolean rewrite)
66+
{
67+
Types.ExecMode platformOld = rtplatform;
68+
rtplatform = Types.ExecMode.SINGLE_NODE;
69+
boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
70+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
71+
72+
try {
73+
getAndLoadTestConfiguration(TEST_NAME);
74+
String HOME = SCRIPT_DIR + TEST_DIR;
75+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
76+
programArgs = new String[] {"-explain", "-stats", "-ooc",
77+
"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
78+
79+
int rows = 4200, cols = 2700;
80+
MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
81+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
82+
writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols);
83+
HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), Types.ValueType.FP64,
84+
new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
85+
86+
runTest(true, false, null, -1);
87+
88+
double[][] res = DataConverter.convertToDoubleMatrix(DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), Types.FileFormat.BINARY, 1, cols, 1000, 1000));
89+
for(int j = 0; j < cols; j++) {
90+
double expected = 0.0;
91+
for(int i = 0; i < rows; i++) {
92+
expected += mb.get(i, j);
93+
}
94+
Assert.assertEquals(expected, res[0][j], 1e-10);
95+
}
96+
97+
String prefix = Instruction.OOC_INST_PREFIX;
98+
Assert.assertTrue("OOC wasn't used for RBLK",
99+
heavyHittersContainsString(prefix + Opcodes.RBLK));
100+
// uack+
101+
Assert.assertTrue("OOC wasn't used for COLSUMS",
102+
heavyHittersContainsString(prefix + Opcodes.UACKP));
103+
}
104+
catch(Exception ex) {
105+
Assert.fail(ex.getMessage());
106+
}
107+
finally {
108+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
109+
resetExecMode(platformOld);
110+
}
111+
}
112+
113+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.hops.OptimizerUtils;
25+
import org.apache.sysds.runtime.instructions.Instruction;
26+
import org.apache.sysds.runtime.io.MatrixWriter;
27+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
28+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
29+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
30+
import org.apache.sysds.runtime.util.DataConverter;
31+
import org.apache.sysds.runtime.util.HDFSTool;
32+
import org.apache.sysds.test.AutomatedTestBase;
33+
import org.apache.sysds.test.TestConfiguration;
34+
import org.apache.sysds.test.TestUtils;
35+
import org.junit.Assert;
36+
import org.junit.Test;
37+
38+
public class RowAggregationTest extends AutomatedTestBase{
39+
private static final String TEST_NAME = "RowAggregationTest";
40+
private static final String TEST_DIR = "functions/ooc/";
41+
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggregationTest.class.getSimpleName() + "/";
42+
private static final String INPUT_NAME = "X";
43+
private static final String OUTPUT_NAME = "res";
44+
45+
@Override
46+
public void setUp() {
47+
TestUtils.clearAssertionInformation();
48+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
49+
addTestConfiguration(TEST_NAME, config);
50+
}
51+
52+
@Test
53+
public void testRowAggregationNoRewrite() {
54+
testRowAggregation(false);
55+
}
56+
57+
/**
58+
* Test the row aggregation, "rowSums(X)", with OOC backend.
59+
*/
60+
@Test
61+
public void testRowAggregationRewrite() {
62+
testRowAggregation(true);
63+
}
64+
65+
public void testRowAggregation(boolean rewrite)
66+
{
67+
Types.ExecMode platformOld = rtplatform;
68+
rtplatform = Types.ExecMode.SINGLE_NODE;
69+
boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
70+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
71+
72+
try {
73+
getAndLoadTestConfiguration(TEST_NAME);
74+
String HOME = SCRIPT_DIR + TEST_DIR;
75+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
76+
programArgs = new String[] {"-explain", "-stats", "-ooc",
77+
"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
78+
79+
int rows = 3900, cols = 1700;
80+
MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
81+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
82+
writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols);
83+
HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), Types.ValueType.FP64,
84+
new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
85+
86+
runTest(true, false, null, -1);
87+
88+
double[][] res = DataConverter.convertToDoubleMatrix(DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, 1, 1000, 1000));
89+
for(int i = 0; i < rows; i++) {
90+
double expected = 0.0;
91+
for(int j = 0; j < cols; j++) {
92+
expected += mb.get(i, j);
93+
}
94+
Assert.assertEquals(expected, res[i][0], 1e-10);
95+
}
96+
97+
String prefix = Instruction.OOC_INST_PREFIX;
98+
Assert.assertTrue("OOC wasn't used for RBLK",
99+
heavyHittersContainsString(prefix + Opcodes.RBLK));
100+
// uark+
101+
Assert.assertTrue("OOC wasn't used for ROWSUMS",
102+
heavyHittersContainsString(prefix + Opcodes.UARKP));
103+
}
104+
catch(Exception ex) {
105+
Assert.fail(ex.getMessage());
106+
}
107+
finally {
108+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
109+
resetExecMode(platformOld);
110+
}
111+
}
112+
113+
}

0 commit comments

Comments
 (0)