Skip to content

Commit 661b186

Browse files
committed
combine uncompressed error
1 parent 8f4c716 commit 661b186

File tree

5 files changed

+77
-22
lines changed

5 files changed

+77
-22
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,12 @@ public class ColGroupUncompressed extends AColGroup {
8080
*/
8181
private final MatrixBlock _data;
8282

83-
private ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) {
83+
/**
84+
* Do not use this constructor of column group uncompressed; instead, use the create constructor.
85+
* @param mb The contained data.
86+
* @param colIndexes Column indexes for this column group
87+
*/
88+
protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) {
8489
super(colIndexes);
8590
_data = mb;
8691
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.colgroup;
21+
22+
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
23+
import org.apache.sysds.runtime.frame.data.columns.Array;
24+
25+
/**
26+
* Special sideways compressed column group that is not supposed to be used outside of the compressed transform encode.
27+
*/
28+
public class ColGroupUncompressedArray extends ColGroupUncompressed {
29+
30+
public final Array<?> array;
31+
public final int id; // columnID
32+
33+
public ColGroupUncompressedArray(Array<?> data, int id, IColIndex colIndexes){
34+
super(null, colIndexes);
35+
this.array = data;
36+
this.id = id;
37+
}
38+
39+
40+
}

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ public static AColGroup combineN(List<AColGroup> groups, int nRows, ExecutorServ
149149
else {
150150
return combineNSingleAtATime(groups, nRows);
151151
}
152-
153152
}
154153

155154
@SuppressWarnings("unchecked")

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
import org.apache.commons.logging.Log;
3030
import org.apache.commons.logging.LogFactory;
31-
import org.apache.sysds.api.DMLScript;
3231
import org.apache.sysds.conf.ConfigurationManager;
3332
import org.apache.sysds.conf.DMLConfig;
3433
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -43,8 +42,6 @@
4342
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
4443
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
4544
import org.apache.sysds.runtime.util.CommonThreadPool;
46-
import org.apache.sysds.utils.DMLCompressionStatistics;
47-
import org.apache.sysds.utils.stats.Timing;
4845

4946
public final class CLALibRightMultBy {
5047
private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName());

src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
3838
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
3939
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
40-
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
40+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressedArray;
4141
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
4242
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
4343
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
@@ -104,7 +104,9 @@ private MatrixBlock apply() throws Exception {
104104
final List<ColumnEncoderComposite> encoders = enc.getColumnEncoders();
105105
final List<AColGroup> groups = isParallel() ? multiThread(encoders) : singleThread(encoders);
106106
final int cols = shiftGroups(groups);
107-
final MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups);
107+
final CompressedMatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups);
108+
109+
combineUncompressed(mb);
108110
mb.setNonZeros(nnz.get());
109111
logging(mb);
110112
return mb;
@@ -193,7 +195,7 @@ private <T> AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception {
193195
if(containsNull && domain == 0)
194196
return new ColGroupEmpty(ColIndexFactory.create(1));
195197
IColIndex colIndexes = ColIndexFactory.create(0, domain);
196-
if(domain == 1 && !containsNull){
198+
if(domain == 1 && !containsNull) {
197199
nnz.addAndGet(in.getNumRows());
198200
return ColGroupConst.create(colIndexes, new double[] {1});
199201
}
@@ -347,10 +349,10 @@ private <T> AColGroup recode(ColumnEncoderComposite c) throws Exception {
347349

348350
// int domain = c.getDomainSize();
349351
IColIndex colIndexes = ColIndexFactory.create(1);
350-
if(domain == 0 && containsNull){
352+
if(domain == 0 && containsNull) {
351353
return new ColGroupEmpty(colIndexes);
352354
}
353-
if(domain == 1 && !containsNull){
355+
if(domain == 1 && !containsNull) {
354356
nnz.addAndGet(in.getNumRows());
355357
return ColGroupConst.create(colIndexes, new double[] {1});
356358
}
@@ -397,14 +399,7 @@ private <T> AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde
397399

398400
if(a.getValueType() != ValueType.BOOLEAN // if not booleans
399401
&& (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) {
400-
// stats.valueType;
401-
double[] vals = (double[]) a.changeType(ValueType.FP64).get();
402-
403-
MatrixBlock col = new MatrixBlock(a.size(), 1, vals);
404-
long nz = col.recomputeNonZeros(1);
405-
406-
nnz.addAndGet(nz);
407-
return ColGroupUncompressed.create(colIndexes, col, false);
402+
return new ColGroupUncompressedArray(a, c._colID - 1,colIndexes);
408403
}
409404
else {
410405
boolean containsNull = a.containsNull();
@@ -532,10 +527,10 @@ private AColGroup hash(ColumnEncoderComposite c) {
532527
int domain = (int) CEHash.getK();
533528
boolean nulls = a.containsNull();
534529
IColIndex colIndexes = ColIndexFactory.create(0, 1);
535-
if(domain == 0 && nulls){
530+
if(domain == 0 && nulls) {
536531
return new ColGroupEmpty(colIndexes);
537532
}
538-
if(domain == 1 && !nulls){
533+
if(domain == 1 && !nulls) {
539534
nnz.addAndGet(in.getNumRows());
540535
return ColGroupConst.create(colIndexes, new double[] {1});
541536
}
@@ -561,10 +556,10 @@ private AColGroup hashToDummy(ColumnEncoderComposite c) {
561556
int domain = (int) CEHash.getK();
562557
boolean nulls = a.containsNull();
563558
IColIndex colIndexes = ColIndexFactory.create(0, domain);
564-
if(domain == 0 && nulls){
559+
if(domain == 0 && nulls) {
565560
return new ColGroupEmpty(ColIndexFactory.create(1));
566561
}
567-
if(domain == 1 && !nulls){
562+
if(domain == 1 && !nulls) {
568563
nnz.addAndGet(in.getNumRows());
569564
return ColGroupConst.create(colIndexes, new double[] {1});
570565
}
@@ -609,6 +604,25 @@ private <T> void estimateRCDMapSize(ColumnEncoderComposite c) {
609604
c._estNumDistincts = estDistCount;
610605
}
611606

607+
private void combineUncompressed(CompressedMatrixBlock mb) {
608+
609+
List<ColGroupUncompressedArray> ucg = new ArrayList<>();
610+
List<AColGroup> ret = new ArrayList<>();
611+
for(AColGroup g : mb.getColGroups()) {
612+
if(g instanceof ColGroupUncompressedArray)
613+
ucg.add((ColGroupUncompressedArray) g);
614+
else
615+
ret.add(g);
616+
}
617+
ret.add(combine(ucg));
618+
nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows()));
619+
mb.allocateColGroupList(ret);
620+
}
621+
622+
private AColGroup combine(List<ColGroupUncompressedArray> ucg) {
623+
throw new NotImplementedException("Should combine " + ucg.size());
624+
}
625+
612626
private void logging(MatrixBlock mb) {
613627
if(LOG.isDebugEnabled()) {
614628
LOG.debug(String.format("Uncompressed transform encode Dense size: %16d", mb.estimateSizeDenseInMemory()));

0 commit comments

Comments
 (0)