Skip to content

Commit b1c5d64

Browse files
Frxmsmboehm7
authored andcommitted
[SYSTEMDS-3860] Extended sparsity exploitation in codegen row templates
Finalized runtime kernels, code generation, and optimization Closes #2297. Closes #2277. Closes #2276.
1 parent 0263279 commit b1c5d64

File tree

23 files changed

+2436
-82
lines changed

23 files changed

+2436
-82
lines changed

src/main/java/org/apache/sysds/api/DMLOptions.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ public class DMLOptions {
8686
public boolean federatedCompilation = false; // Compile federated instructions based on input federation state and privacy constraints.
8787
public boolean noFedRuntimeConversion = false; // If activated, no runtime conversion of CP instructions to FED instructions will be performed.
8888
public int seed = -1; // The general seed for the execution, if -1 random (system time).
89+
public boolean sparseIntermediate = false; // whether SparseRowIntermediates should be used for rowwise operations
8990

9091
public final static DMLOptions defaultOptions = new DMLOptions(null);
9192

@@ -119,7 +120,8 @@ public String toString() {
119120
", w=" + fedWorker +
120121
", federatedCompilation=" + federatedCompilation +
121122
", noFedRuntimeConversion=" + noFedRuntimeConversion +
122-
", seed=" + seed +
123+
", seed=" + seed +
124+
", sparseIntermediate=" + sparseIntermediate +
123125
'}';
124126
}
125127

@@ -353,6 +355,11 @@ else if (lineageType.equalsIgnoreCase("debugger"))
353355
dmlOptions.seed = Integer.parseInt(line.getOptionValue("seed"));
354356
}
355357

358+
//TODO move to systemds-config instead of command-line arg
359+
if(line.hasOption("sparseIntermediate")){
360+
dmlOptions.sparseIntermediate = true;
361+
}
362+
356363
return dmlOptions;
357364
}
358365

@@ -436,7 +443,10 @@ private static Options createCLIOptions() {
436443
Option commandlineSeed = OptionBuilder
437444
.withDescription("A general seed for the execution through the commandline")
438445
.hasArg().create("seed");
439-
446+
Option sparseRowIntermediates = OptionBuilder
447+
.withDescription("If activated, sparseRowVector intermediates will be used to calculate rowwise operations.")
448+
.create("sparseIntermediate");
449+
440450
options.addOption(configOpt);
441451
options.addOption(cleanOpt);
442452
options.addOption(statsOpt);
@@ -457,6 +467,7 @@ private static Options createCLIOptions() {
457467
options.addOption(federatedCompilation);
458468
options.addOption(noFedRuntimeConversion);
459469
options.addOption(commandlineSeed);
470+
options.addOption(sparseRowIntermediates);
460471

461472
// Either a clean(-clean), a file(-f), a script(-s) or help(-help) needs to be specified
462473
OptionGroup fileOrScriptOpt = new OptionGroup()

src/main/java/org/apache/sysds/api/DMLScript.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ public class DMLScript
155155
// Global seed
156156
public static int SEED = -1;
157157

158+
// Sparse row flag
159+
public static boolean SPARSE_INTERMEDIATE = false;
160+
158161
public static String MONITORING_ADDRESS = null;
159162

160163
// flag that indicates whether or not to suppress any prints to stdout
@@ -278,6 +281,7 @@ public static boolean executeScript( String[] args )
278281
LINEAGE_ESTIMATE = dmlOptions.lineage_estimate;
279282
LINEAGE_DEBUGGER = dmlOptions.lineage_debugger;
280283
SEED = dmlOptions.seed;
284+
SPARSE_INTERMEDIATE = dmlOptions.sparseIntermediate;
281285

282286

283287
String fnameOptConfig = dmlOptions.configFile;

src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ public static Hop optimize( Hop root, boolean recompile ) {
470470
* @param recompile true if invoked during dynamic recompilation
471471
* @return dag root nodes of modified dag
472472
*/
473-
public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean recompile)
473+
public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean recompile)
474474
{
475475
if( roots == null || roots.isEmpty() )
476476
return roots;

src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.sysds.hops.codegen.cplan;
2121

22+
import org.apache.sysds.api.DMLScript;
2223
import org.apache.sysds.common.Types.DataType;
2324
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
2425
import org.apache.sysds.hops.codegen.template.TemplateUtils;
@@ -77,6 +78,14 @@ public String createVarname() {
7778
_genVar = "TMP"+_seqVar.getNextID();
7879
return _genVar;
7980
}
81+
82+
public String createVarname(boolean sparse) {
83+
if(!sparse) {
84+
return createVarname();
85+
} else {
86+
return _genVar = "S" + createVarname();
87+
}
88+
}
8089

8190
public String getVarname() {
8291
return _genVar;
@@ -98,6 +107,8 @@ public String getVectorLength(GeneratorAPI api) {
98107
return "len";
99108
if(getVarname().startsWith("b"))
100109
return getVarname() + ".clen";
110+
else if(getVarname().startsWith("STMP"))
111+
return "len";
101112
else if(_dataType == DataType.MATRIX)
102113
return getVarname() + ".length";
103114
}
@@ -222,8 +233,13 @@ public boolean equals(Object that) {
222233

223234
protected String replaceUnaryPlaceholders(String tmp, String varj, boolean vectIn, GeneratorAPI api) {
224235
//replace sparse and dense inputs
225-
tmp = tmp.replace("%IN1v%", varj+"vals");
226-
tmp = tmp.replace("%IN1i%", varj+"ix");
236+
if(DMLScript.SPARSE_INTERMEDIATE) {
237+
tmp = tmp.replace("%IN1v%", varj.startsWith("STMP") ? varj+".values()" : varj+"vals");
238+
tmp = tmp.replace("%IN1i%", varj.startsWith("STMP") ? varj+".indexes()" :varj+"ix");
239+
} else {
240+
tmp = tmp.replace("%IN1v%", varj+"vals");
241+
tmp = tmp.replace("%IN1i%", varj+"ix");
242+
}
227243
tmp = tmp.replace("%IN1%",
228244
(vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ?
229245
((api == GeneratorAPI.JAVA) ? varj + ".values(rix)" : varj + ".vals(0)" ) :

src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java

Lines changed: 169 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Arrays;
2323

2424
import org.apache.commons.lang3.StringUtils;
25+
import org.apache.sysds.api.DMLScript;
2526
import org.apache.sysds.common.Opcodes;
2627
import org.apache.sysds.hops.codegen.template.TemplateUtils;
2728
import org.apache.sysds.common.Types.DataType;
@@ -126,7 +127,8 @@ public boolean isNotSupportedBySpoofCUDA() {
126127
}
127128

128129
private final BinType _type;
129-
130+
private boolean sparseTemplate;
131+
130132
public CNodeBinary( CNode in1, CNode in2, BinType type ) {
131133
//canonicalize commutative matrix-scalar operations
132134
//to increase reuse potential
@@ -143,6 +145,23 @@ public CNodeBinary( CNode in1, CNode in2, BinType type ) {
143145
setOutputDims();
144146
}
145147

148+
public CNodeBinary( CNode in1, CNode in2, BinType type, double sparsityEst, double scalarVal ) {
149+
//canonicalize commutative matrix-scalar operations
150+
//to increase reuse potential
151+
if( type.isCommutative() && in1 instanceof CNodeData
152+
&& in1.getDataType()==DataType.SCALAR ) {
153+
CNode tmp = in1;
154+
in1 = in2;
155+
in2 = tmp;
156+
}
157+
158+
_inputs.add(in1);
159+
_inputs.add(in2);
160+
_type = type;
161+
setOutputDims();
162+
sparseTemplate = getTemplateType(sparsityEst, scalarVal);
163+
}
164+
146165
public BinType getType() {
147166
return _type;
148167
}
@@ -157,60 +176,63 @@ public String codegen(boolean sparse, GeneratorAPI api) {
157176
//generate children
158177
sb.append(_inputs.get(0).codegen(sparse, api));
159178
sb.append(_inputs.get(1).codegen(sparse, api));
160-
179+
161180
//generate binary operation (use sparse template, if data input)
162-
boolean lsparseLhs = sparse && _inputs.get(0) instanceof CNodeData
163-
&& _inputs.get(0).getVarname().startsWith("a");
164-
boolean lsparseRhs = sparse && _inputs.get(1) instanceof CNodeData
165-
&& _inputs.get(1).getVarname().startsWith("a");
181+
boolean lsparseLhs = sparse ? _inputs.get(0) instanceof CNodeData
182+
&& _inputs.get(0).getVarname().startsWith("a") ||
183+
_inputs.get(0).getVarname().startsWith("STMP") : false;
184+
boolean lsparseRhs = sparse ? _inputs.get(1) instanceof CNodeData
185+
&& _inputs.get(1).getVarname().startsWith("a") ||
186+
_inputs.get(1).getVarname().startsWith("STMP") : false;
166187
boolean scalarInput = _inputs.get(0).getDataType().isScalar();
167188
boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
168189
&& _inputs.get(1).getDataType().isMatrix());
169190
boolean vectorVector = _inputs.get(0).getDataType().isMatrix()
170191
&& _inputs.get(1).getDataType().isMatrix();
171-
String var = createVarname();
192+
String var = createVarname(sparse && sparseTemplate && getOutputType(scalarVector, lsparseLhs, lsparseRhs));
172193
String tmp = getLanguageTemplateClass(this, api)
173-
.getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput, vectorVector);
194+
.getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput, vectorVector, sparseTemplate);
174195

175196
tmp = tmp.replace("%TMP%", var);
176-
197+
177198
//replace input references and start indexes
178199
for( int j=0; j<2; j++ ) {
179200
String varj = _inputs.get(j).getVarname(api);
180-
181201
//replace sparse and dense inputs
182-
tmp = tmp.replace("%IN"+(j+1)+"v%", varj+"vals");
183-
tmp = tmp.replace("%IN"+(j+1)+"i%", varj+"ix");
202+
tmp = tmp.replace("%IN"+(j+1)+"v%", varj.startsWith("STMP") ? varj+".values()" : varj+"vals");
203+
tmp = tmp.replace("%IN"+(j+1)+"i%", varj.startsWith("STMP") ? varj+".indexes()" : varj+"ix");
184204
tmp = tmp.replace("%IN"+(j+1)+"%",
185-
varj.startsWith("a") ? (api == GeneratorAPI.JAVA ? varj :
186-
(_inputs.get(j).getDataType() == DataType.MATRIX ? varj + ".vals(0)" : varj)) :
187-
varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" :
188-
(_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
189-
_inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj);
190-
205+
varj.startsWith("a") ? (api == GeneratorAPI.JAVA ? varj :
206+
(_inputs.get(j).getDataType() == DataType.MATRIX ? varj + ".vals(0)" : varj)) :
207+
varj.startsWith("b") ? (api == GeneratorAPI.JAVA ? varj + ".values(rix)" :
208+
(_type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
209+
_inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj);
210+
211+
tmp = tmp.replace("%SLEN"+(j+1)+"%", varj.startsWith("STMP") ? varj+".size()" : varj.startsWith("a") ? "alen" : "blen");
212+
191213
//replace start position of main input
192-
tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData
193-
&& _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" :
194-
((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise()
195-
&& TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ?
214+
tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData
215+
&& _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" :
216+
((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise()
217+
&& TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ?
196218
varj + ".pos(rix)" : "0" : "0");
197219
}
198220
//replace length information (e.g., after matrix mult)
199-
if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector) ) {
221+
if( _type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector)) {
200222
for( int j=0; j<2; j++ )
201223
tmp = tmp.replace("%LEN"+(j+1)+"%", _inputs.get(j).getVectorLength(api));
202224
}
203-
else { //general case
225+
else { //general case
204226
CNode mInput = getIntermediateInputVector();
205227
if( mInput != null )
206228
tmp = tmp.replace("%LEN%", mInput.getVectorLength(api));
207229
}
208-
230+
209231
sb.append(tmp);
210-
232+
211233
//mark as generated
212234
_generated = true;
213-
235+
214236
return sb.toString();
215237
}
216238

@@ -219,7 +241,126 @@ private CNode getIntermediateInputVector() {
219241
if( getInput().get(i).getDataType().isMatrix() )
220242
return getInput().get(i);
221243
return null;
222-
}
244+
}
245+
246+
private boolean getTemplateType(double sparsityEst, double scalarVal) {
247+
if(!DMLScript.SPARSE_INTERMEDIATE)
248+
return false;
249+
else {
250+
switch(_type) {
251+
case VECT_MULT:
252+
case VECT_DIV:
253+
case VECT_LESS:
254+
case VECT_MINUS:
255+
case VECT_PLUS:
256+
case VECT_XOR:
257+
case VECT_BITWAND:
258+
case VECT_BIASADD:
259+
case VECT_BIASMULT:
260+
case VECT_MIN:
261+
case VECT_MAX:
262+
case VECT_NOTEQUAL:
263+
case VECT_GREATER:
264+
case VECT_EQUAL:
265+
case VECT_LESSEQUAL:
266+
case VECT_GREATEREQUAL: return sparsityEst < 0.1;
267+
case VECT_MULT_SCALAR:
268+
case VECT_DIV_SCALAR:
269+
case VECT_XOR_SCALAR:
270+
case VECT_BITWAND_SCALAR: return sparsityEst < 0.3;
271+
case VECT_GREATER_SCALAR: {
272+
if(scalarVal != Double.NaN) {
273+
return _inputs.get(1).getDataType().isScalar() ? scalarVal >= 0 && sparsityEst < 0.2
274+
: _inputs.get(0).getDataType().isScalar() && scalarVal < 0 && sparsityEst < 0.2;
275+
} else
276+
return false;
277+
}
278+
case VECT_GREATEREQUAL_SCALAR: {
279+
if(scalarVal != Double.NaN) {
280+
return _inputs.get(1).getDataType().isScalar() ? scalarVal > 0 && sparsityEst < 0.2
281+
: _inputs.get(0).getDataType().isScalar() && scalarVal <= 0 && sparsityEst < 0.2;
282+
} else
283+
return false;
284+
}
285+
case VECT_MIN_SCALAR: {
286+
if(scalarVal != Double.NaN) {
287+
return _inputs.get(1).getDataType().isScalar() ? scalarVal >= 0 && sparsityEst < 0.2
288+
: _inputs.get(0).getDataType().isScalar() && scalarVal >= 0 && sparsityEst < 0.2;
289+
} else
290+
return false;
291+
}
292+
case VECT_LESS_SCALAR: {
293+
if(scalarVal != Double.NaN) {
294+
return _inputs.get(1).getDataType().isScalar() ? scalarVal <= 0 && sparsityEst < 0.2
295+
: _inputs.get(0).getDataType().isScalar() && scalarVal > 0 && sparsityEst < 0.2;
296+
} else
297+
return false;
298+
}
299+
case VECT_LESSEQUAL_SCALAR: {
300+
if(scalarVal != Double.NaN) {
301+
return _inputs.get(1).getDataType().isScalar() ? scalarVal < 0 && sparsityEst < 0.2
302+
: _inputs.get(0).getDataType().isScalar() && scalarVal >= 0 && sparsityEst < 0.2;
303+
} else
304+
return false;
305+
}
306+
case VECT_MAX_SCALAR: {
307+
if(scalarVal != Double.NaN) {
308+
return _inputs.get(1).getDataType().isScalar() ? scalarVal <= 0 && sparsityEst < 0.2
309+
: _inputs.get(0).getDataType().isScalar() && scalarVal <= 0 && sparsityEst < 0.2;
310+
} else
311+
return false;
312+
}
313+
case VECT_POW_SCALAR:
314+
case VECT_EQUAL_SCALAR:{
315+
if(scalarVal != Double.NaN) {
316+
return _inputs.get(1).getDataType().isScalar() ? scalarVal != 0 && sparsityEst < 0.2
317+
: _inputs.get(0).getDataType().isScalar() && scalarVal != 0 && sparsityEst < 0.2;
318+
} else
319+
return false;
320+
}
321+
case VECT_NOTEQUAL_SCALAR:{
322+
if(scalarVal != Double.NaN) {
323+
return _inputs.get(1).getDataType().isScalar() ? scalarVal == 0 && sparsityEst < 0.2
324+
: _inputs.get(0).getDataType().isScalar() && scalarVal == 0 && sparsityEst < 0.2;
325+
} else
326+
return false;
327+
}
328+
default: return sparsityEst < 0.3;
329+
}
330+
}
331+
}
332+
333+
public boolean getOutputType(boolean scalarVector, boolean lsparseLhs, boolean lsparseRhs) {
334+
switch(_type) {
335+
case VECT_POW_SCALAR: return !scalarVector && lsparseLhs;
336+
case VECT_MULT_SCALAR:
337+
case VECT_DIV_SCALAR:
338+
case VECT_XOR_SCALAR:
339+
case VECT_MIN_SCALAR:
340+
case VECT_MAX_SCALAR:
341+
case VECT_EQUAL_SCALAR:
342+
case VECT_NOTEQUAL_SCALAR:
343+
case VECT_LESS_SCALAR:
344+
case VECT_LESSEQUAL_SCALAR:
345+
case VECT_GREATER_SCALAR:
346+
case VECT_GREATEREQUAL_SCALAR:
347+
case VECT_BITWAND_SCALAR: return lsparseLhs || lsparseRhs;
348+
case VECT_MULT:
349+
case VECT_DIV:
350+
case VECT_MINUS:
351+
case VECT_PLUS:
352+
case VECT_XOR:
353+
case VECT_BITWAND:
354+
case VECT_BIASADD:
355+
case VECT_BIASMULT:
356+
case VECT_MIN:
357+
case VECT_MAX:
358+
case VECT_NOTEQUAL:
359+
case VECT_LESS:
360+
case VECT_GREATER: return lsparseLhs && lsparseRhs;
361+
default: return false;
362+
}
363+
}
223364

224365
@Override
225366
public String toString() {

src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ public String getTemplate(boolean sparseGen, long len, ArrayList<CNode> inputs,
6060
sb.append( sparseInput ?
6161
" LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
6262
+varj+"ix, "+pos+", "+off+", "+input._cols+");\n" :
63-
" LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
63+
varj.startsWith("STMP") ?
64+
" LibSpoofPrimitives.vectWrite("+varj+".values(), %TMP%, "
65+
+varj+".indexes(), "+pos+", "+off+", "+varj+".size());\n" :
66+
" LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
6467
+", %TMP%, "+pos+", "+off+", "+input._cols+");\n");
6568
off += input._cols;
6669
}

0 commit comments

Comments
 (0)