Skip to content

Commit bb9ab3b

Browse files
committed
Merge branch 'Sparse-Row-Primitives' into Sparse-Row-Optimizer
2 parents 7bca4e3 + e5943a3 commit bb9ab3b

File tree

2 files changed

+72
-77
lines changed

2 files changed

+72
-77
lines changed

src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java

Lines changed: 46 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -2183,9 +2183,9 @@ public static SparseRowVector vectMultWrite(int len, double[] a, double bval, in
21832183
if( a == null ) return c;
21842184
int[] indexes = c.indexes();
21852185
double[] values = c.values();
2186-
for(int j = ai; j < ai+alen; j++) {
2187-
indexes[j] = aix[j];
2188-
values[j] = a[j]*bval;
2186+
for(int j = 0; j < alen; j++) {
2187+
indexes[j] = aix[ai+j];
2188+
values[j] = a[ai+j]*bval;
21892189
}
21902190
c.setSize(alen);
21912191
return c;
@@ -2195,7 +2195,7 @@ public static SparseRowVector vectMultWrite(int len, double bval, double[] a, in
21952195
return vectMultWrite(len, a, bval, aix, ai, alen);
21962196
}
21972197

2198-
//version with branching
2198+
//old version with branching (not used)
21992199
public static SparseRowVector vectMultWriteB(int len, double[] a, double[] b, int[] aix, int[] bix, int ai, int bi, int alen, int blen) {
22002200
SparseRowVector c = allocSparseVector(Math.min(alen, blen));
22012201
if( a == null || b == null ) return c;
@@ -2251,17 +2251,17 @@ public static void vectWrite(double[] a, int[] aix, double[] c, int ci, int len)
22512251

22522252
public static void vectWrite(double[] a, double[] c, int[] aix, int ai, int ci, int alen) {
22532253
if( a == null ) return;
2254-
for(int j = ai; j < ai+alen; j++)
2255-
c[ci+aix[j]] = a[j];
2254+
for(int j = 0; j < alen; j++)
2255+
c[ci+aix[ai+j]] = a[ai+j];
22562256
}
22572257

22582258
public static SparseRowVector vectDivWrite(int len, double[] a, double bval, int[] aix, int ai, int alen) {
22592259
SparseRowVector c = allocSparseVector(alen);
22602260
int[] indexes = c.indexes();
22612261
double[] values = c.values();
2262-
for( int j = ai; j < ai+alen; j++ ) {
2263-
indexes[j] = aix[j];
2264-
values[j] = a[j] / bval;
2262+
for( int j = 0; j < alen; j++ ) {
2263+
indexes[j] = aix[ai+j];
2264+
values[j] = a[ai+j] / bval;
22652265
}
22662266
c.setSize(alen);
22672267
return c;
@@ -2271,15 +2271,15 @@ public static SparseRowVector vectDivWrite(int len, double bval, double[] a, int
22712271
SparseRowVector c = allocSparseVector(alen);
22722272
int[] indexes = c.indexes();
22732273
double[] values = c.values();
2274-
for(int j = ai; j < ai+alen; j++) {
2275-
indexes[j] = aix[j];
2276-
values[j] = bval / a[j];
2274+
for(int j = 0; j < alen; j++) {
2275+
indexes[j] = aix[ai+j];
2276+
values[j] = bval / a[ai+j];
22772277
}
22782278
c.setSize(alen);
22792279
return c;
22802280
}
22812281

2282-
//version with branching
2282+
//old version with branching (not used)
22832283
public static SparseRowVector vectDivWriteB(int len, double[] a, double[] b, int[] aix, int[] bix, int ai, int bi, int alen, int blen) {
22842284
SparseRowVector c = allocSparseVector(alen);
22852285
int aItr = ai;
@@ -2471,9 +2471,9 @@ public static SparseRowVector vectXorWrite(int len, double[] a, double bval, int
24712471
SparseRowVector c = allocSparseVector(alen);
24722472
int[] indexes = c.indexes();
24732473
double[] values = c.values();
2474-
for(int j = ai; j < ai+alen; j++) {
2475-
indexes[j] = aix[j];
2476-
values[j] = (a[j] != 0) ? 1 : 0;
2474+
for(int j = 0; j < alen; j++) {
2475+
indexes[j] = aix[ai+j];
2476+
values[j] = (a[ai+j] != 0) ? 1 : 0;
24772477
}
24782478
c.setSize(alen);
24792479
return c;
@@ -2548,9 +2548,9 @@ public static SparseRowVector vectPowWrite(int len, double[] a, double bval, int
25482548
SparseRowVector c = allocSparseVector(alen);
25492549
int[] indexes = c.indexes();
25502550
double[] values = c.values();
2551-
for(int j = ai; j < ai+alen; j++) {
2552-
indexes[j] = aix[j];
2553-
values[j] = Math.pow(a[j], bval);
2551+
for(int j = 0; j < alen; j++) {
2552+
indexes[j] = aix[ai+j];
2553+
values[j] = Math.pow(a[ai+j], bval);
25542554
}
25552555
c.setSize(alen);
25562556
return c;
@@ -2584,9 +2584,9 @@ public static SparseRowVector vectMinWrite(int len, double[] a, double bval, int
25842584
SparseRowVector c = allocSparseVector(alen);
25852585
int[] indexes = c.indexes();
25862586
double[] values = c.values();
2587-
for(int j = ai; j < ai+alen; j++) {
2588-
indexes[j] = aix[j];
2589-
values[j] = Math.min(a[j], bval);
2587+
for(int j = 0; j < alen; j++) {
2588+
indexes[j] = aix[ai+j];
2589+
values[j] = Math.min(a[ai+j], bval);
25902590
}
25912591
c.setSize(alen);
25922592
return c;
@@ -2661,9 +2661,9 @@ public static SparseRowVector vectMaxWrite(int len, double[] a, double bval, int
26612661
SparseRowVector c = allocSparseVector(alen);
26622662
int[] indexes = c.indexes();
26632663
double[] values = c.values();
2664-
for(int j = ai; j < ai+alen; j++) {
2665-
indexes[j] = aix[j];
2666-
values[j] = Math.max(a[j], bval);
2664+
for(int j = 0; j < alen; j++) {
2665+
indexes[j] = aix[ai+j];
2666+
values[j] = Math.max(a[ai+j], bval);
26672667
}
26682668
c.setSize(alen);
26692669
return c;
@@ -2738,9 +2738,9 @@ public static SparseRowVector vectEqualWrite(int len, double[] a, double bval, i
27382738
SparseRowVector c = allocSparseVector(alen);
27392739
int[] indexes = c.indexes();
27402740
double[] values = c.values();
2741-
for(int j = ai; j < ai+alen; j++) {
2742-
indexes[j] = aix[j];
2743-
values[j] = a[j] == bval ? 1 : 0;
2741+
for(int j = 0; j < alen; j++) {
2742+
indexes[j] = aix[ai+j];
2743+
values[j] = a[ai+j] == bval ? 1 : 0;
27442744
}
27452745
c.setSize(alen);
27462746
return c;
@@ -2803,9 +2803,9 @@ public static SparseRowVector vectNotequalWrite(int len, double[] a, double bval
28032803
SparseRowVector c = allocSparseVector(alen);
28042804
int[] indexes = c.indexes();
28052805
double[] values = c.values();
2806-
for(int j = ai; j < ai+alen; j++) {
2807-
indexes[j] = aix[j];
2808-
values[j] = a[j] != bval ? 1 : 0;
2806+
for(int j = 0; j < alen; j++) {
2807+
indexes[j] = aix[ai+j];
2808+
values[j] = a[ai+j] != bval ? 1 : 0;
28092809
}
28102810
c.setSize(alen);
28112811
return c;
@@ -2880,9 +2880,9 @@ public static SparseRowVector vectLessWrite(int len, double[] a, double bval, in
28802880
SparseRowVector c = allocSparseVector(alen);
28812881
int[] indexes = c.indexes();
28822882
double[] values = c.values();
2883-
for(int j = ai; j < ai+alen; j++) {
2884-
indexes[j] = aix[j];
2885-
values[j] = a[j] < bval ? 1 : 0;
2883+
for(int j = 0; j < alen; j++) {
2884+
indexes[j] = aix[ai+j];
2885+
values[j] = a[ai+j] < bval ? 1 : 0;
28862886
}
28872887
c.setSize(alen);
28882888
return c;
@@ -2957,9 +2957,9 @@ public static SparseRowVector vectLessequalWrite(int len, double[] a, double bva
29572957
SparseRowVector c = allocSparseVector(alen);
29582958
int[] indexes = c.indexes();
29592959
double[] values = c.values();
2960-
for(int j = ai; j < ai+alen; j++) {
2961-
indexes[j] = aix[j];
2962-
values[j] = a[j] <= bval ? 1 : 0;
2960+
for(int j = 0; j < alen; j++) {
2961+
indexes[j] = aix[ai+j];
2962+
values[j] = a[ai+j] <= bval ? 1 : 0;
29632963
}
29642964
c.setSize(alen);
29652965
return c;
@@ -3022,9 +3022,9 @@ public static SparseRowVector vectGreaterWrite(int len, double[] a, double bval,
30223022
SparseRowVector c = allocSparseVector(alen);
30233023
int[] indexes = c.indexes();
30243024
double[] values = c.values();
3025-
for(int j = ai; j < ai+alen; j++) {
3026-
indexes[j] = aix[j];
3027-
values[j] = a[j] > bval ? 1 : 0;
3025+
for(int j = 0; j < alen; j++) {
3026+
indexes[j] = aix[ai+j];
3027+
values[j] = a[ai+j] > bval ? 1 : 0;
30283028
}
30293029
c.setSize(alen);
30303030
return c;
@@ -3099,9 +3099,9 @@ public static SparseRowVector vectGreaterequalWrite(int len, double[] a, double
30993099
SparseRowVector c = allocSparseVector(alen);
31003100
int[] indexes = c.indexes();
31013101
double[] values = c.values();
3102-
for(int j = ai; j < ai+alen; j++) {
3103-
indexes[j] = aix[j];
3104-
values[j] = a[j] >= bval ? 1 : 0;
3102+
for(int j = 0; j < alen; j++) {
3103+
indexes[j] = aix[ai+j];
3104+
values[j] = a[ai+j] >= bval ? 1 : 0;
31053105
}
31063106
c.setSize(alen);
31073107
return c;
@@ -3142,9 +3142,9 @@ public static SparseRowVector vectBitwandWrite(int len, double[] a, double bval,
31423142
int[] indexes = c.indexes();
31433143
double[] values = c.values();
31443144
int bval1 = (int) bval;
3145-
for( int j = ai; j < ai+alen; j++ ) {
3146-
indexes[j] = aix[j];
3147-
values[j] = bwAnd(a[j], bval1);
3145+
for( int j = 0; j < alen; j++ ) {
3146+
indexes[j] = aix[ai+j];
3147+
values[j] = bwAnd(a[ai+j], bval1);
31483148
}
31493149
c.setSize(alen);
31503150
return c;
@@ -3318,31 +3318,6 @@ public static SparseRowVector vectSignWrite(int len, double[] a, int[] aix, int
33183318
return c;
33193319
}
33203320

3321-
//todo MatrixMult, pow2 and mult2 drafts
3322-
// public static SparseRowVector vectMatrixMult(int len, double[] a, double[] b, int[] aix, int[] bix, int ai, int bi, int alen, int blen) {
3323-
// //note: assumption b is already transposed for efficient dot products
3324-
// int m2clen = b.length / len;
3325-
// SparseRowVector c = allocSparseVector(m2clen);
3326-
// for(int i = 0; i < m2clen; i++) {
3327-
// c.set(bix[i], LibMatrixMult.dotProduct(a, aix, ai, alen, b, bix, bi, blen));
3328-
// }
3329-
// return c;
3330-
// }
3331-
//
3332-
// public static SparseRowVector vectPow2Write(int len, double[] a, int[] aix, int ai, int alen) {
3333-
// SparseRowVector c = allocSparseVector(len);
3334-
// for(int j = 0; j < ai+alen; j++)
3335-
// c.set(aix[j], a[j] * a[j]);
3336-
// return c;
3337-
// }
3338-
//
3339-
// public static SparseRowVector vectMult2Write(int len, double[] a, int[] aix, int ai, int alen) {
3340-
// SparseRowVector c = allocSparseVector(len);
3341-
// for(int j = 0; j < ai+alen; j++)
3342-
// c.set(aix[j], a[j] + a[j]);
3343-
// return c;
3344-
// }
3345-
33463321
//complex builtin functions that are not directly generated
33473322
//(included here in order to reduce the number of imports)
33483323

src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java

Lines changed: 26 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,7 @@ public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject
189189
//setup thread-local memory if necessary
190190
if( allocTmp &&_reqVectMem > 0 )
191191
if(inputs.get(0).isInSparseFormat() && DMLScript.SPARSE_INTERMEDIATE) {
192-
LibSpoofPrimitives.setupSparseThreadLocalMemory(_reqVectMem, n/2, n2);
192+
LibSpoofPrimitives.setupSparseThreadLocalMemory(_reqVectMem, n, n2);
193193
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, n, n2);
194194
} else {
195195
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, n, n2);
@@ -442,7 +442,12 @@ public DenseBlock call() {
442442

443443
//allocate vector intermediates and partial output
444444
if( _reqVectMem > 0 )
445-
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
445+
if(_a.isInSparseFormat() && DMLScript.SPARSE_INTERMEDIATE) {
446+
LibSpoofPrimitives.setupSparseThreadLocalMemory(_reqVectMem, _clen, _clen2);
447+
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
448+
} else {
449+
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
450+
}
446451
DenseBlock c = DenseBlockFactory.createDenseBlock(1, _outLen);
447452

448453
if( !_a.isInSparseFormat() )
@@ -451,7 +456,12 @@ public DenseBlock call() {
451456
executeSparse(_a.getSparseBlock(), _b, _scalars, c, _clen, _rl, _ru, 0);
452457

453458
if( _reqVectMem > 0 )
454-
LibSpoofPrimitives.cleanupThreadLocalMemory();
459+
if(_a.isInSparseFormat() && DMLScript.SPARSE_INTERMEDIATE) {
460+
LibSpoofPrimitives.cleanupSparseThreadLocalMemory();
461+
LibSpoofPrimitives.cleanupThreadLocalMemory();
462+
} else {
463+
LibSpoofPrimitives.cleanupThreadLocalMemory();
464+
}
455465
return c;
456466
}
457467
}
@@ -485,15 +495,25 @@ protected ParExecTask( MatrixBlock a, SideInput[] b, MatrixBlock c, double[] sca
485495
public Long call() {
486496
//allocate vector intermediates
487497
if( _reqVectMem > 0 )
488-
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
498+
if(_a.isInSparseFormat() && DMLScript.SPARSE_INTERMEDIATE) {
499+
LibSpoofPrimitives.setupSparseThreadLocalMemory(_reqVectMem, _clen, _clen2);
500+
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
501+
} else {
502+
LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
503+
}
489504

490505
if( !_a.isInSparseFormat() )
491506
executeDense(_a.getDenseBlock(), _b, _scalars, _c.getDenseBlock(), _clen, _rl, _ru, 0);
492507
else
493508
executeSparse(_a.getSparseBlock(), _b, _scalars, _c.getDenseBlock(), _clen, _rl, _ru, 0);
494-
509+
495510
if( _reqVectMem > 0 )
496-
LibSpoofPrimitives.cleanupThreadLocalMemory();
511+
if(_a.isInSparseFormat() && DMLScript.SPARSE_INTERMEDIATE) {
512+
LibSpoofPrimitives.cleanupSparseThreadLocalMemory();
513+
LibSpoofPrimitives.cleanupThreadLocalMemory();
514+
} else {
515+
LibSpoofPrimitives.cleanupThreadLocalMemory();
516+
}
497517

498518
//maintain nnz for row partition
499519
return _c.recomputeNonZeros(_rl, _ru-1, 0, _c.getNumColumns()-1);

0 commit comments

Comments
 (0)