Skip to content

Commit 8054c9b

Browse files
committed
merge files from Java17V5
1 parent ee553b6 commit 8054c9b

File tree

11 files changed

+636
-42
lines changed

11 files changed

+636
-42
lines changed

src/main/java/org/apache/sysds/hops/TernaryOp.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.apache.sysds.api.DMLScript;
2323
import org.apache.sysds.common.Types.DataType;
24+
import org.apache.sysds.common.Types.ExecType;
2425
import org.apache.sysds.common.Types.OpOp2;
2526
import org.apache.sysds.common.Types.OpOp3;
2627
import org.apache.sysds.common.Types.OpOpDG;
@@ -33,8 +34,8 @@
3334
import org.apache.sysds.lops.CentralMoment;
3435
import org.apache.sysds.lops.CoVariance;
3536
import org.apache.sysds.lops.Ctable;
37+
import org.apache.sysds.lops.Data;
3638
import org.apache.sysds.lops.Lop;
37-
import org.apache.sysds.common.Types.ExecType;
3839
import org.apache.sysds.lops.LopsException;
3940
import org.apache.sysds.lops.PickByCount;
4041
import org.apache.sysds.lops.SortKeys;
@@ -273,14 +274,19 @@ private void constructLopsCtable() {
273274
// F=ctable(A,B,W)
274275

275276
DataType dt1 = getInput().get(0).getDataType();
277+
278+
276279
DataType dt2 = getInput().get(1).getDataType();
277280
DataType dt3 = getInput().get(2).getDataType();
278281
Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
279282

280283
// Compute lops for all inputs
281284
Lop[] inputLops = new Lop[getInput().size()];
282285
for(int i=0; i < getInput().size(); i++) {
283-
inputLops[i] = getInput().get(i).constructLops();
286+
if(i == 0 && HopRewriteUtils.isSequenceSizeOfA(getInput(0), getInput(1)))
287+
inputLops[i] = Data.createLiteralLop(ValueType.INT64, "" +getInput(1).getDim(0));
288+
else
289+
inputLops[i] = getInput().get(i).constructLops();
284290
}
285291

286292
ExecType et = optFindExecType();

src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,25 @@ public static boolean isBasicN1Sequence(Hop hop)
13921392
return ret;
13931393
}
13941394

1395+
public static boolean isSequenceSizeOfA(Hop hop, Hop A)
1396+
{
1397+
boolean ret = false;
1398+
1399+
if( hop instanceof DataGenOp )
1400+
{
1401+
DataGenOp dgop = (DataGenOp) hop;
1402+
if( dgop.getOp() == OpOpDG.SEQ ){
1403+
Hop from = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_FROM));
1404+
Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO));
1405+
Hop incr = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR));
1406+
ret = (from instanceof LiteralOp && getIntValueSafe((LiteralOp) from) == 1) &&
1407+
(to instanceof LiteralOp && getIntValueSafe((LiteralOp) to) == A.getDim(0)) &&
1408+
(incr instanceof LiteralOp && getIntValueSafe((LiteralOp)incr)==1);
1409+
}
1410+
}
1411+
1412+
return ret;
1413+
}
13951414

13961415
public static Hop getBasic1NSequenceMax(Hop hop) {
13971416
if( isDataGenOp(hop, OpOpDG.SEQ) ) {

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,8 +1974,8 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
19741974
case DECOMPRESS:
19751975
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
19761976
checkNumParameters(1);
1977-
checkMatrixParam(getFirstExpr());
1978-
output.setDataType(DataType.MATRIX);
1977+
checkMatrixFrameParam(getFirstExpr());
1978+
output.setDataType(getFirstExpr().getOutput().getDataType());
19791979
output.setDimensions(id.getDim1(), id.getDim2());
19801980
output.setBlocksize (id.getBlocksize());
19811981
output.setValueType(id.getValueType());
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.lib;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.concurrent.ExecutorService;
25+
import java.util.concurrent.Future;
26+
27+
import org.apache.commons.logging.Log;
28+
import org.apache.commons.logging.LogFactory;
29+
import org.apache.sysds.runtime.DMLRuntimeException;
30+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
31+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
32+
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
33+
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
34+
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
35+
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
36+
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
37+
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
38+
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
39+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
40+
import org.apache.sysds.runtime.util.CommonThreadPool;
41+
import org.apache.sysds.runtime.util.UtilFunctions;
42+
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
43+
44+
public class CLALibTable {
45+
46+
protected static final Log LOG = LogFactory.getLog(CLALibTable.class.getName());
47+
48+
private CLALibTable() {
49+
// empty constructor
50+
}
51+
52+
public static MatrixBlock tableSeqOperations(int seqHeight, MatrixBlock A, int nColOut){
53+
54+
int k = InfrastructureAnalyzer.getLocalParallelism();
55+
try{
56+
57+
final int[] map = new int[seqHeight];
58+
int maxCol = constructInitialMapping(map, A, k);
59+
boolean containsNull = maxCol < 0;
60+
maxCol = Math.abs(maxCol);
61+
62+
if(nColOut == -1)
63+
nColOut = maxCol;
64+
else if(nColOut < maxCol)
65+
throw new DMLRuntimeException("invalid nColOut, requested: " + nColOut + " but have to be : " + maxCol);
66+
67+
final int nNulls = containsNull ? correctNulls(map, nColOut) : 0;
68+
if(nColOut == 0) // edge case of empty zero dimension block.
69+
return new MatrixBlock(seqHeight, 0, 0.0);
70+
return createCompressedReturn(map, nColOut, seqHeight, nNulls, containsNull, k);
71+
}
72+
catch(Exception e){
73+
throw new RuntimeException("Failed table seq operator",e);
74+
}
75+
}
76+
77+
private static CompressedMatrixBlock createCompressedReturn(int[] map, int nColOut, int seqHeight, int nNulls,
78+
boolean containsNull, int k) throws Exception {
79+
// create a single DDC Column group.
80+
final IColIndex i = ColIndexFactory.create(0, nColOut);
81+
final ADictionary d = new IdentityDictionary(nColOut, containsNull);
82+
final AMapToData m = MapToFactory.create(seqHeight, map, nColOut + (containsNull ? 1 : 0), k);
83+
final AColGroup g = ColGroupDDC.create(i, d, m, null);
84+
85+
final CompressedMatrixBlock cmb = new CompressedMatrixBlock(seqHeight, nColOut);
86+
cmb.allocateColGroup(g);
87+
cmb.setNonZeros(seqHeight - nNulls);
88+
return cmb;
89+
}
90+
91+
private static int correctNulls(int[] map, int nColOut) {
92+
int nNulls = 0;
93+
for(int i = 0; i < map.length; i++) {
94+
if(map[i] == -1) {
95+
map[i] = nColOut;
96+
nNulls++;
97+
}
98+
}
99+
return nNulls;
100+
}
101+
102+
private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
103+
if(A.isEmpty() || A.isInSparseFormat())
104+
throw new DMLRuntimeException("not supported empty or sparse construction of seq table");
105+
106+
ExecutorService pool = CommonThreadPool.get(k);
107+
try {
108+
109+
int blkz = Math.max((map.length / k), 1000);
110+
List<Future<Integer>> tasks = new ArrayList<>();
111+
for(int i = 0; i < map.length; i+= blkz){
112+
final int start = i;
113+
final int end = Math.min(i + blkz, map.length);
114+
tasks.add(pool.submit(() -> partialMapping(map, A, start, end)));
115+
}
116+
117+
int maxCol = 0;
118+
for( Future<Integer> f : tasks){
119+
int tmp = f.get();
120+
if(Math.abs(tmp) >Math.abs(maxCol))
121+
maxCol = tmp;
122+
}
123+
return maxCol;
124+
}
125+
catch(Exception e) {
126+
throw new DMLRuntimeException(e);
127+
}
128+
finally {
129+
pool.shutdown();
130+
}
131+
132+
}
133+
134+
private static int partialMapping(int[] map, MatrixBlock A, int start, int end) {
135+
136+
int maxCol = 0;
137+
boolean containsNull = false;
138+
final double[] aVals = A.getDenseBlockValues();
139+
140+
for(int i = start; i < end; i++) {
141+
final double v2 = aVals[i];
142+
if(Double.isNaN(v2)) {
143+
map[i] = -1; // assign temporarily to -1
144+
containsNull = true;
145+
}
146+
else {
147+
// safe casts to long for consistent behavior with indexing
148+
int col = UtilFunctions.toInt(v2);
149+
if(col <= 0)
150+
throw new DMLRuntimeException(
151+
"Erroneous input while computing the contingency table (value <= zero): " + v2);
152+
153+
map[i] = col - 1;
154+
// maintain max seen col
155+
maxCol = Math.max(col, maxCol);
156+
}
157+
}
158+
159+
return containsNull ? maxCol * -1 : maxCol;
160+
}
161+
162+
}

src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.sysds.runtime.lineage.LineageItem;
3131
import org.apache.sysds.runtime.lineage.LineageItemUtils;
3232
import org.apache.sysds.runtime.matrix.data.CTableMap;
33+
import org.apache.sysds.runtime.matrix.data.LibMatrixTable;
3334
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
3435
import org.apache.sysds.runtime.util.DataConverter;
3536
import org.apache.sysds.runtime.util.LongLongDoubleHashMap.EntryType;
@@ -88,9 +89,11 @@ private Ctable.OperationTypes findCtableOperation() {
8889

8990
@Override
9091
public void processInstruction(ExecutionContext ec) {
91-
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
92-
MatrixBlock matBlock2=null, wtBlock=null;
92+
MatrixBlock matBlock1 = null;
93+
MatrixBlock matBlock2 = null, wtBlock=null;
9394
double cst1, cst2;
95+
if(!input1.isScalar())
96+
matBlock1 = ec.getMatrixInput(input1.getName());
9497

9598
CTableMap resultMap = new CTableMap(EntryType.INT);
9699
MatrixBlock resultBlock = null;
@@ -111,7 +114,8 @@ public void processInstruction(ExecutionContext ec) {
111114
resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
112115
}
113116
if( _isExpand ){
114-
resultBlock = new MatrixBlock( matBlock1.getNumRows(), Integer.MAX_VALUE, true );
117+
if(matBlock1 != null)
118+
resultBlock = new MatrixBlock( matBlock1.getNumRows(), Integer.MAX_VALUE, true );
115119
}
116120

117121
switch(ctableOp) {
@@ -132,7 +136,7 @@ public void processInstruction(ExecutionContext ec) {
132136
matBlock2 = ec.getMatrixInput(input2.getName());
133137
cst1 = ec.getScalarInput(input3).getDoubleValue();
134138
// only resultBlock.rlen known, resultBlock.clen set in operation
135-
matBlock1.ctableSeqOperations(matBlock2, cst1, resultBlock);
139+
resultBlock = LibMatrixTable.tableSeqOperations((int)input1.getLiteral().getLongValue(), matBlock2, cst1, resultBlock, true);
136140
break;
137141
case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
138142
// F=ctable(A,1) or F = ctable(A,1,1)

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5351,46 +5351,20 @@ public void ctableOperations(Operator op, MatrixValue thatVal, double scalarThat
53515351
}
53525352

53535353
/**
5354+
* D = ctable(seq,A,w)
5355+
* <p>
5356+
* this = seq; thatMatrix = A; thatScalar = w; ret = D
5357+
*
53545358
* @param thatMatrix matrix value
53555359
* @param thatScalar scalar double
5356-
* @param ret result matrix block
5360+
* @param ret result matrix block D (the weight multiplied into the table output is thatScalar)
53575361
* @param updateClen when this matrix already has the desired number of columns updateClen can be set to false
53585362
* @return result matrix block
53595363
*/
5360-
public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock ret, boolean updateClen) {
5364+
public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock ret,
5365+
boolean updateClen) {
53615366
MatrixBlock that = checkType(thatMatrix);
5362-
CTable ctable = CTable.getCTableFnObject();
5363-
double w = thatScalar;
5364-
5365-
//prepare allocation of CSR sparse block
5366-
int[] rptr = new int[rlen+1];
5367-
int[] indexes = new int[rlen];
5368-
double[] values = new double[rlen];
5369-
5370-
//sparse-unsafe ctable execution
5371-
//(because input values of 0 are invalid and have to result in errors)
5372-
//resultBlock guaranteed to be allocated for ctableexpand
5373-
//each row in resultBlock will be allocated and will contain exactly one value
5374-
int maxCol = 0;
5375-
for( int i=0; i<rlen; i++ ) {
5376-
double v2 = that.get(i, 0);
5377-
maxCol = ctable.execute(i+1, v2, w, maxCol, indexes, values);
5378-
rptr[i] = i;
5379-
}
5380-
rptr[rlen] = rlen;
5381-
5382-
//construct sparse CSR block from filled arrays
5383-
ret.sparseBlock = new SparseBlockCSR(rptr, indexes, values, rlen);
5384-
((SparseBlockCSR)ret.sparseBlock).compact();
5385-
ret.setNonZeros(ret.sparseBlock.size());
5386-
5387-
//update meta data (initially unknown number of columns)
5388-
//note: nnz maintained in ctable (via quickset)
5389-
if(updateClen) {
5390-
ret.clen = maxCol;
5391-
}
5392-
5393-
return ret;
5367+
return LibMatrixTable.tableSeqOperations(this.getNumRows(), that, thatScalar, ret, updateClen);
53945368
}
53955369

53965370
/**

0 commit comments

Comments
 (0)