Skip to content

Commit cc5c61e

Browse files
committed
[MINOR] Mapping Add A Range Setting
This commit adds a range setting function for mappings, to enable subsequent parallel setting from integer arrays. Signed-off-by: Sebastian Baunsgaard <baunsgaard@apache.org>
1 parent 5f07f2b commit cc5c61e

File tree

8 files changed

+226
-20
lines changed

8 files changed

+226
-20
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
import java.io.DataOutput;
2323
import java.io.IOException;
2424
import java.io.Serializable;
25+
import java.util.ArrayList;
2526
import java.util.BitSet;
27+
import java.util.List;
28+
import java.util.concurrent.ExecutorService;
29+
import java.util.concurrent.Future;
2630

2731
import org.apache.commons.lang3.NotImplementedException;
2832
import org.apache.commons.logging.Log;
@@ -95,7 +99,6 @@ public final void setUnique(int nUnique) {
9599
*/
96100
public abstract int getIndex(int n);
97101

98-
99102
/**
100103
* Shortcut method to support Integer objects, not really efficient but for the purpose of reusing code.
101104
*
@@ -116,6 +119,18 @@ public void set(int n, Integer v) {
116119
*/
117120
public abstract void set(int n, int v);
118121

122+
/**
123+
* set a range of values from another map.
124+
*
125+
* The given tm must only contain supported values, and it is not verified.
126+
*
127+
* @param l lower bound
128+
* @param u upper bound (not inclusive)
129+
* @param off offset to take values from tm
130+
* @param tm the other map to copy values from
131+
*/
132+
public abstract void set(int l, int u, int off, AMapToData tm);
133+
119134
/**
120135
* Set the index to the value and get the contained value after.
121136
*
@@ -813,7 +828,11 @@ protected void copyInt(MapToInt d) {
813828
*
814829
* @param d The array to copy
815830
*/
816-
public abstract void copyInt(int[] d);
831+
public void copyInt(int[] d) {
832+
copyInt(d, 0, size());
833+
}
834+
835+
public abstract void copyInt(int[] d, int start, int end);
817836

818837
public abstract void copyBit(BitSet d);
819838

@@ -887,7 +906,8 @@ public int countRuns(AOffset off) {
887906

888907
@Override
889908
public boolean equals(Object e) {
890-
return e instanceof AMapToData && (this == e || this.equals((AMapToData) e));
909+
return this == e || // same object or
910+
(e instanceof AMapToData && this.equals((AMapToData) e));
891911
}
892912

893913
/**
@@ -903,7 +923,7 @@ public void verify() {
903923
if(CompressedMatrixBlock.debug) {
904924
for(int i = 0; i < size(); i++) {
905925
if(getIndex(i) >= nUnique) {
906-
throw new DMLCompressionException("invalid construction of Mapping data containing values above unique");
926+
throw new DMLCompressionException("Invalid construction of Mapping data containing values above unique");
907927
}
908928
}
909929
}
@@ -934,7 +954,7 @@ public void decompressToRange(double[] c, int rl, int ru, int offR, double[] val
934954
decompressToRangeOff(c, rl, ru, offR, values);
935955
}
936956

937-
public void decompressToRangeOff(double[] c, int rl, int ru, int offR, double[] values) {
957+
protected void decompressToRangeOff(double[] c, int rl, int ru, int offR, double[] values) {
938958
for(int i = rl, offT = rl + offR; i < ru; i++, offT++)
939959
c[offT] += values[getIndex(i)];
940960
}
@@ -950,14 +970,73 @@ protected void decompressToRangeNoOffBy8(double[] c, int r, double[] values) {
950970
c[r + 7] += values[getIndex(r + 7)];
951971
}
952972

953-
public void decompressToRangeNoOff(double[] c, int rl, int ru, double[] values) {
973+
protected void decompressToRangeNoOff(double[] c, int rl, int ru, double[] values) {
954974
final int h = (ru - rl) % 8;
955975
for(int rc = rl; rc < rl + h; rc++)
956976
c[rc] += values[getIndex(rc)];
957977
for(int rc = rl + h; rc < ru; rc += 8)
958978
decompressToRangeNoOffBy8(c, rc, values);
959979
}
960980

981+
/**
982+
* Split this mapping into x smaller mappings according to round robin.
983+
*
984+
* @param multiplier The number of smaller mappings to construct
985+
* @return The list of smaller mappings
986+
*/
987+
public AMapToData[] splitReshapeDDC(final int multiplier) {
988+
989+
final int s = size();
990+
final AMapToData[] ret = new AMapToData[multiplier];
991+
final int eachSize = s / multiplier;
992+
for(int i = 0; i < multiplier; i++)
993+
ret[i] = MapToFactory.create(eachSize, getUnique());
994+
995+
// for(int i = 0; i < s; i += multiplier)
996+
// splitReshapeDDCRow(ret, multiplier, i);
997+
998+
final int blkz = Math.max(eachSize / 8, 2048) * multiplier;
999+
for(int i = 0; i < s; i += blkz)
1000+
splitReshapeDDCBlock(ret, multiplier, i, Math.min(i + blkz, s));
1001+
1002+
return ret;
1003+
}
1004+
1005+
public AMapToData[] splitReshapeDDCPushDown(final int multiplier, final ExecutorService pool) throws Exception {
1006+
1007+
final int s = size();
1008+
final AMapToData[] ret = new AMapToData[multiplier];
1009+
final int eachSize = s / multiplier;
1010+
for(int i = 0; i < multiplier; i++)
1011+
ret[i] = MapToFactory.create(eachSize, getUnique());
1012+
1013+
final int blkz = Math.max(eachSize / 8, 2048) * multiplier;
1014+
List<Future<?>> tasks = new ArrayList<>();
1015+
for(int i = 0; i < s; i += blkz) {
1016+
final int start = i;
1017+
final int end = Math.min(i + blkz, s);
1018+
tasks.add(pool.submit(() -> splitReshapeDDCBlock(ret, multiplier, start, end)));
1019+
}
1020+
1021+
for(Future<?> t : tasks)
1022+
t.get();
1023+
1024+
return ret;
1025+
}
1026+
1027+
private void splitReshapeDDCBlock(final AMapToData[] ret, final int multiplier, final int start, final int end) {
1028+
1029+
for(int i = start; i < end; i += multiplier)
1030+
splitReshapeDDCRow(ret, multiplier, i);
1031+
}
1032+
1033+
private void splitReshapeDDCRow(final AMapToData[] ret, final int multiplier, final int i) {
1034+
final int off = i / multiplier;
1035+
final int end = i + multiplier;
1036+
for(int j = i; j < end; j++)
1037+
ret[j % multiplier].set(off, getIndex(j));
1038+
}
1039+
9611040
@Override
9621041
public String toString() {
9631042
final int sz = size();

src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,13 @@ public void set(int n, int v) {
145145
_data[wIdx] &= ~(1L << n);
146146
}
147147

148+
@Override
149+
public void set(int l, int u, int off, AMapToData tm){
150+
for(int i = l; i < u; i++, off++) {
151+
set(i, tm.getIndex(off));
152+
}
153+
}
154+
148155
@Override
149156
public int setAndGet(int n, int v) {
150157
set(n, v);
@@ -267,8 +274,8 @@ public void copy(AMapToData d) {
267274
}
268275

269276
@Override
270-
public void copyInt(int[] d) {
271-
for(int i = 0; i < _size; i++)
277+
public void copyInt(int[] d, int start, int end) {
278+
for(int i = start; i < end; i++)
272279
set(i, d[i]);
273280
}
274281

src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,23 @@ public void set(int n, int v) {
9595
_data[n] = (byte) v;
9696
}
9797

98+
@Override
99+
public void set(int l, int u, int off, AMapToData tm){
100+
if(tm instanceof MapToByte){
101+
MapToByte tbm = (MapToByte)tm;
102+
byte[] tbv = tbm._data;
103+
for(int i = l; i < u; i++, off++) {
104+
_data[i] = tbv[off];
105+
}
106+
}
107+
else{
108+
109+
for(int i = l; i < u; i++, off++) {
110+
_data[i] = (byte)tm.getIndex(off);
111+
}
112+
}
113+
}
114+
98115
@Override
99116
public int setAndGet(int n, int v) {
100117
_data[n] = (byte) v;
@@ -136,8 +153,8 @@ public void replace(int v, int r) {
136153
}
137154

138155
@Override
139-
public void copyInt(int[] d) {
140-
for(int i = 0; i < _data.length; i++)
156+
public void copyInt(int[] d, int start, int end) {
157+
for(int i = start; i < end; i++)
141158
_data[i] = (byte) d[i];
142159
}
143160

@@ -320,13 +337,13 @@ public void decompressToRange(double[] c, int rl, int ru, int offR, double[] val
320337
}
321338

322339
@Override
323-
public void decompressToRangeOff(double[] c, int rl, int ru, int offR, double[] values) {
340+
protected void decompressToRangeOff(double[] c, int rl, int ru, int offR, double[] values) {
324341
for(int i = rl, offT = rl + offR; i < ru; i++, offT++)
325342
c[offT] += values[getIndex(i)];
326343
}
327344

328345
@Override
329-
public void decompressToRangeNoOff(double[] c, int rl, int ru, double[] values) {
346+
protected void decompressToRangeNoOff(double[] c, int rl, int ru, double[] values) {
330347
// OVERWRITTEN FOR JIT COMPILE!
331348
final int h = (ru - rl) % 8;
332349
for(int rc = rl; rc < rl + h; rc++)

src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
import java.io.DataInput;
2323
import java.io.DataOutput;
2424
import java.io.IOException;
25+
import java.util.ArrayList;
2526
import java.util.Arrays;
2627
import java.util.BitSet;
28+
import java.util.List;
29+
import java.util.concurrent.ExecutorService;
30+
import java.util.concurrent.Future;
2731

2832
import org.apache.commons.lang3.NotImplementedException;
2933
import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup;
@@ -92,6 +96,26 @@ public void set(int n, int v) {
9296
_data[n] = (char) v;
9397
}
9498

99+
public void set(int n, char v) {
100+
_data[n] = v;
101+
}
102+
103+
@Override
104+
public void set(int l, int u, int off, AMapToData tm) {
105+
if(tm instanceof MapToChar) {
106+
MapToChar tbm = (MapToChar) tm;
107+
char[] tbv = tbm._data;
108+
for(int i = l; i < u; i++, off++) {
109+
_data[i] = tbv[off];
110+
}
111+
}
112+
else {
113+
for(int i = l; i < u; i++, off++) {
114+
set(i, tm.getIndex(off));
115+
}
116+
}
117+
}
118+
95119
@Override
96120
public int setAndGet(int n, int v) {
97121
return _data[n] = (char) v;
@@ -144,7 +168,7 @@ protected static MapToChar readFields(DataInput in) throws IOException {
144168
final int length = in.readInt();
145169
final char[] data = new char[length];
146170
for(int i = 0; i < length; i++)
147-
data[i] = in.readChar();
171+
data[i] = (char)in.readUnsignedShort();
148172
return new MapToChar(unique, data);
149173
}
150174

@@ -208,8 +232,8 @@ public int getUpperBoundValue() {
208232
}
209233

210234
@Override
211-
public void copyInt(int[] d) {
212-
for(int i = 0; i < _data.length; i++)
235+
public void copyInt(int[] d, int start, int end) {
236+
for(int i = start; i < end; i++)
213237
_data[i] = (char) d[i];
214238
}
215239

@@ -391,4 +415,38 @@ protected final void preAggregateDDC_DDCSingleCol_vecChar(MapToChar tm, double[]
391415
v[getIndex(r8)] += td[tm.getIndex(r8)];
392416
}
393417

418+
@Override
419+
public AMapToData[] splitReshapeDDCPushDown(final int multiplier, final ExecutorService pool) throws Exception {
420+
final int s = size();
421+
final MapToChar[] ret = new MapToChar[multiplier];
422+
final int eachSize = s / multiplier;
423+
for(int i = 0; i < multiplier; i++)
424+
ret[i] = new MapToChar(getUnique(), eachSize);
425+
426+
final int blkz = Math.max(eachSize / 8, 2048) * multiplier;
427+
List<Future<?>> tasks = new ArrayList<>();
428+
for(int i = 0; i < s; i += blkz) {
429+
final int start = i;
430+
final int end = Math.min(i + blkz, s);
431+
tasks.add(pool.submit(() -> splitReshapeDDCBlock(ret, multiplier, start, end)));
432+
}
433+
434+
for(Future<?> t : tasks)
435+
t.get();
436+
437+
return ret;
438+
}
439+
440+
private void splitReshapeDDCBlock(final MapToChar[] ret, final int multiplier, final int start, final int end) {
441+
for(int i = start; i < end; i += multiplier)
442+
splitReshapeDDCRow(ret, multiplier, i);
443+
}
444+
445+
private void splitReshapeDDCRow(final MapToChar[] ret, final int multiplier, final int i) {
446+
final int off = i / multiplier;
447+
final int end = i + multiplier;
448+
for(int j = i; j < end; j++)
449+
ret[j % multiplier]._data[off] = _data[j];
450+
}
451+
394452
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ public void set(int n, int v) {
101101
_data_b[n] = (byte) (m >> 16);
102102
}
103103

104+
@Override
105+
public void set(int l, int u, int off, AMapToData tm){
106+
for(int i = l; i < u; i++, off++) {
107+
set(i, tm.getIndex(off));
108+
}
109+
}
110+
104111
@Override
105112
public int setAndGet(int n, int v) {
106113
int m = v & 0xffffff;
@@ -167,8 +174,8 @@ public int getUpperBoundValue() {
167174
}
168175

169176
@Override
170-
public void copyInt(int[] d) {
171-
for(int i = 0; i < d.length; i++)
177+
public void copyInt(int[] d, int start, int end) {
178+
for(int i = start; i < end; i++)
172179
set(i, d[i]);
173180
}
174181

src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToFactory.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,15 @@
2121

2222
import java.io.DataInput;
2323
import java.io.IOException;
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
import java.util.concurrent.ExecutorService;
27+
import java.util.concurrent.Future;
2428

2529
import org.apache.commons.logging.Log;
2630
import org.apache.commons.logging.LogFactory;
31+
import org.apache.sysds.runtime.compress.utils.IntArrayList;
32+
import org.apache.sysds.runtime.util.CommonThreadPool;
2733

2834
/** Interface for the factory design pattern for construction all AMapToData. */
2935
public interface MapToFactory {
@@ -63,6 +69,26 @@ public static AMapToData create(int size, int[] values, int nUnique) {
6369
return _data;
6470
}
6571

72+
public static AMapToData create(int unique, IntArrayList values) {
73+
AMapToData _data = create(values.size(), unique);
74+
_data.copyInt(values.extractValues());
75+
return _data;
76+
}
77+
78+
public static AMapToData create(int size, int[] values, int nUnique, int k) {
79+
AMapToData _data = create(size, nUnique);
80+
ExecutorService pool = CommonThreadPool.get(k);
81+
int blk = Math.max((values.length / k), 1024);
82+
blk -= blk % 64; // ensure long size
83+
List<Future<?>> tasks = new ArrayList<>();
84+
for(int i = 0; i < values.length; i += blk){
85+
int start = i;
86+
int end = Math.min(i + blk, values.length);
87+
tasks.add(pool.submit(() -> _data.copyInt(values, start, end)));
88+
}
89+
return _data;
90+
}
91+
6692
/**
6793
* Create and allocate a map with the given size and support for upto the num tuples argument of values
6894
*

0 commit comments

Comments
 (0)