Skip to content

Commit a6faf44

Browse files
committed
[SYSTEMDS-3895] New out-of-core unary aggregate operations
This patch introduces the out-of-core unary aggregate operations as an example of how to implement operations against the input stream of blocks.
1 parent c168aa1 commit a6faf44

File tree

4 files changed

+99
-6
lines changed

4 files changed

+99
-6
lines changed

src/main/java/org/apache/sysds/hops/AggUnaryOp.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,6 @@ else if(getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector(
409409
else
410410
setRequiresRecompileIfNecessary();
411411

412-
if( _etype == ExecType.OOC ) //TODO
413-
setExecType(ExecType.CP);
414-
415412
return _etype;
416413
}
417414

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.commons.logging.LogFactory;
2424
import org.apache.sysds.common.InstructionType;
2525
import org.apache.sysds.runtime.DMLRuntimeException;
26+
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
2627
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
2728
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
2829

@@ -47,9 +48,10 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
4748
switch(ooctype) {
4849
case Reblock:
4950
return ReblockOOCInstruction.parseInstruction(str);
51+
case AggregateUnary:
52+
return AggregateUnaryOOCInstruction.parseInstruction(str);
5053

5154
// TODO:
52-
case AggregateUnary:
5355
case Binary:
5456

5557
default:
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.sysds.common.Types.CorrectionLocationType;
23+
import org.apache.sysds.conf.ConfigurationManager;
24+
import org.apache.sysds.runtime.DMLRuntimeException;
25+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
26+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
27+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
28+
import org.apache.sysds.runtime.instructions.InstructionUtils;
29+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
30+
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
31+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
32+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
33+
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
34+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
35+
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
36+
37+
38+
public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
39+
private AggregateOperator _aop = null;
40+
41+
protected AggregateUnaryOOCInstruction(OOCType type, AggregateUnaryOperator auop, AggregateOperator aop,
42+
CPOperand in, CPOperand out, String opcode, String istr) {
43+
super(type, auop, in, out, opcode, istr);
44+
_aop = aop;
45+
}
46+
47+
public static AggregateUnaryOOCInstruction parseInstruction(String str) {
48+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
49+
InstructionUtils.checkNumFields(parts, 2);
50+
String opcode = parts[0];
51+
CPOperand in1 = new CPOperand(parts[1]);
52+
CPOperand out = new CPOperand(parts[2]);
53+
54+
String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode);
55+
CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode);
56+
AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
57+
AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrLoc.toString());
58+
return new AggregateUnaryOOCInstruction(
59+
OOCType.AggregateUnary, aggun, aop, in1, out, opcode, str);
60+
}
61+
62+
@Override
63+
public void processInstruction( ExecutionContext ec ) {
64+
//TODO support all types of aggregations, currently only full aggregation
65+
66+
//setup operators and input queue
67+
AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator();
68+
MatrixObject min = ec.getMatrixObject(input1);
69+
LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
70+
IndexedMatrixValue tmp = null;
71+
int blen = ConfigurationManager.getBlocksize();
72+
73+
//read blocks and aggregate immediately into result
74+
int extra = _aop.correction.getNumRemovedRowsColumns();
75+
MatrixBlock ret = new MatrixBlock(1,1+extra,false);
76+
MatrixBlock corr = new MatrixBlock(1,1+extra,false);
77+
try {
78+
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
79+
//block aggregation
80+
MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
81+
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
82+
//accumulation into final result
83+
OperationsOnMatrixValues.incrementalAggregation(
84+
ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true);
85+
}
86+
}
87+
catch(Exception ex) {
88+
throw new DMLRuntimeException(ex);
89+
}
90+
91+
//create scalar output
92+
ec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0)));
93+
}
94+
}

src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ public void testSumScalarMult() {
9292
String prefix = Instruction.OOC_INST_PREFIX;
9393
Assert.assertTrue("OOC wasn't used for RBLK",
9494
heavyHittersContainsString(prefix + Opcodes.RBLK));
95+
Assert.assertTrue("OOC wasn't used for SUM",
96+
heavyHittersContainsString(prefix + Opcodes.UAKP));
9597

9698
// boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
9799
// Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult);
98-
// boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP);
99-
// Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum);
100100
}
101101
catch(Exception ex) {
102102
Assert.fail(ex.getMessage());

0 commit comments

Comments
 (0)