2222import java .util .Arrays ;
2323
2424import org .apache .commons .lang3 .StringUtils ;
25+ import org .apache .sysds .api .DMLScript ;
2526import org .apache .sysds .common .Opcodes ;
2627import org .apache .sysds .hops .codegen .template .TemplateUtils ;
2728import org .apache .sysds .common .Types .DataType ;
@@ -126,7 +127,8 @@ public boolean isNotSupportedBySpoofCUDA() {
126127 }
127128
128129 private final BinType _type ;
129-
130+ private boolean sparseTemplate ;
131+
130132 public CNodeBinary ( CNode in1 , CNode in2 , BinType type ) {
131133 //canonicalize commutative matrix-scalar operations
132134 //to increase reuse potential
@@ -143,6 +145,23 @@ public CNodeBinary( CNode in1, CNode in2, BinType type ) {
143145 setOutputDims ();
144146 }
145147
148+ public CNodeBinary ( CNode in1 , CNode in2 , BinType type , double sparsityEst , double scalarVal ) {
149+ //canonicalize commutative matrix-scalar operations
150+ //to increase reuse potential
151+ if ( type .isCommutative () && in1 instanceof CNodeData
152+ && in1 .getDataType ()==DataType .SCALAR ) {
153+ CNode tmp = in1 ;
154+ in1 = in2 ;
155+ in2 = tmp ;
156+ }
157+
158+ _inputs .add (in1 );
159+ _inputs .add (in2 );
160+ _type = type ;
161+ setOutputDims ();
162+ sparseTemplate = getTemplateType (sparsityEst , scalarVal );
163+ }
164+
146165 public BinType getType () {
147166 return _type ;
148167 }
@@ -157,60 +176,63 @@ public String codegen(boolean sparse, GeneratorAPI api) {
157176 //generate children
158177 sb .append (_inputs .get (0 ).codegen (sparse , api ));
159178 sb .append (_inputs .get (1 ).codegen (sparse , api ));
160-
179+
161180 //generate binary operation (use sparse template, if data input)
162- boolean lsparseLhs = sparse && _inputs .get (0 ) instanceof CNodeData
163- && _inputs .get (0 ).getVarname ().startsWith ("a" );
164- boolean lsparseRhs = sparse && _inputs .get (1 ) instanceof CNodeData
165- && _inputs .get (1 ).getVarname ().startsWith ("a" );
181+ boolean lsparseLhs = sparse ? _inputs .get (0 ) instanceof CNodeData
182+ && _inputs .get (0 ).getVarname ().startsWith ("a" ) ||
183+ _inputs .get (0 ).getVarname ().startsWith ("STMP" ) : false ;
184+ boolean lsparseRhs = sparse ? _inputs .get (1 ) instanceof CNodeData
185+ && _inputs .get (1 ).getVarname ().startsWith ("a" ) ||
186+ _inputs .get (1 ).getVarname ().startsWith ("STMP" ) : false ;
166187 boolean scalarInput = _inputs .get (0 ).getDataType ().isScalar ();
167188 boolean scalarVector = (_inputs .get (0 ).getDataType ().isScalar ()
168189 && _inputs .get (1 ).getDataType ().isMatrix ());
169190 boolean vectorVector = _inputs .get (0 ).getDataType ().isMatrix ()
170191 && _inputs .get (1 ).getDataType ().isMatrix ();
171- String var = createVarname ();
192+ String var = createVarname (sparse && sparseTemplate && getOutputType ( scalarVector , lsparseLhs , lsparseRhs ) );
172193 String tmp = getLanguageTemplateClass (this , api )
173- .getTemplate (_type , lsparseLhs , lsparseRhs , scalarVector , scalarInput , vectorVector );
194+ .getTemplate (_type , lsparseLhs , lsparseRhs , scalarVector , scalarInput , vectorVector , sparseTemplate );
174195
175196 tmp = tmp .replace ("%TMP%" , var );
176-
197+
177198 //replace input references and start indexes
178199 for ( int j =0 ; j <2 ; j ++ ) {
179200 String varj = _inputs .get (j ).getVarname (api );
180-
181201 //replace sparse and dense inputs
182- tmp = tmp .replace ("%IN" +(j +1 )+"v%" , varj +"vals" );
183- tmp = tmp .replace ("%IN" +(j +1 )+"i%" , varj +"ix" );
202+ tmp = tmp .replace ("%IN" +(j +1 )+"v%" , varj . startsWith ( "STMP" ) ? varj + ".values()" : varj +"vals" );
203+ tmp = tmp .replace ("%IN" +(j +1 )+"i%" , varj . startsWith ( "STMP" ) ? varj + ".indexes()" : varj +"ix" );
184204 tmp = tmp .replace ("%IN" +(j +1 )+"%" ,
185- varj .startsWith ("a" ) ? (api == GeneratorAPI .JAVA ? varj :
186- (_inputs .get (j ).getDataType () == DataType .MATRIX ? varj + ".vals(0)" : varj )) :
187- varj .startsWith ("b" ) ? (api == GeneratorAPI .JAVA ? varj + ".values(rix)" :
188- (_type == BinType .VECT_MATRIXMULT ? varj : varj + ".vals(0)" )) :
189- _inputs .get (j ).getDataType () == DataType .MATRIX ? (api == GeneratorAPI .JAVA ? varj : varj + ".vals(0)" ) : varj );
190-
205+ varj .startsWith ("a" ) ? (api == GeneratorAPI .JAVA ? varj :
206+ (_inputs .get (j ).getDataType () == DataType .MATRIX ? varj + ".vals(0)" : varj )) :
207+ varj .startsWith ("b" ) ? (api == GeneratorAPI .JAVA ? varj + ".values(rix)" :
208+ (_type == BinType .VECT_MATRIXMULT ? varj : varj + ".vals(0)" )) :
209+ _inputs .get (j ).getDataType () == DataType .MATRIX ? (api == GeneratorAPI .JAVA ? varj : varj + ".vals(0)" ) : varj );
210+
211+ tmp = tmp .replace ("%SLEN" +(j +1 )+"%" , varj .startsWith ("STMP" ) ? varj +".size()" : varj .startsWith ("a" ) ? "alen" : "blen" );
212+
191213 //replace start position of main input
192- tmp = tmp .replace ("%POS" +(j +1 )+"%" , (_inputs .get (j ) instanceof CNodeData
193- && _inputs .get (j ).getDataType ().isMatrix ()) ? (!varj .startsWith ("b" )) ? varj +"i" :
194- ((TemplateUtils .isMatrix (_inputs .get (j )) || (_type .isElementwise ()
195- && TemplateUtils .isColVector (_inputs .get (j )))) && _type !=BinType .VECT_MATRIXMULT ) ?
214+ tmp = tmp .replace ("%POS" +(j +1 )+"%" , (_inputs .get (j ) instanceof CNodeData
215+ && _inputs .get (j ).getDataType ().isMatrix ()) ? (!varj .startsWith ("b" )) ? varj +"i" :
216+ ((TemplateUtils .isMatrix (_inputs .get (j )) || (_type .isElementwise ()
217+ && TemplateUtils .isColVector (_inputs .get (j )))) && _type !=BinType .VECT_MATRIXMULT ) ?
196218 varj + ".pos(rix)" : "0" : "0" );
197219 }
198220 //replace length information (e.g., after matrix mult)
199- if ( _type == BinType .VECT_OUTERMULT_ADD || (_type == BinType .VECT_CBIND && vectorVector ) ) {
221+ if ( _type == BinType .VECT_OUTERMULT_ADD || (_type == BinType .VECT_CBIND && vectorVector )) {
200222 for ( int j =0 ; j <2 ; j ++ )
201223 tmp = tmp .replace ("%LEN" +(j +1 )+"%" , _inputs .get (j ).getVectorLength (api ));
202224 }
203- else { //general case
225+ else { //general case
204226 CNode mInput = getIntermediateInputVector ();
205227 if ( mInput != null )
206228 tmp = tmp .replace ("%LEN%" , mInput .getVectorLength (api ));
207229 }
208-
230+
209231 sb .append (tmp );
210-
232+
211233 //mark as generated
212234 _generated = true ;
213-
235+
214236 return sb .toString ();
215237 }
216238
@@ -219,7 +241,126 @@ private CNode getIntermediateInputVector() {
219241 if ( getInput ().get (i ).getDataType ().isMatrix () )
220242 return getInput ().get (i );
221243 return null ;
222- }
244+ }
245+
246+ private boolean getTemplateType (double sparsityEst , double scalarVal ) {
247+ if (!DMLScript .SPARSE_INTERMEDIATE )
248+ return false ;
249+ else {
250+ switch (_type ) {
251+ case VECT_MULT :
252+ case VECT_DIV :
253+ case VECT_LESS :
254+ case VECT_MINUS :
255+ case VECT_PLUS :
256+ case VECT_XOR :
257+ case VECT_BITWAND :
258+ case VECT_BIASADD :
259+ case VECT_BIASMULT :
260+ case VECT_MIN :
261+ case VECT_MAX :
262+ case VECT_NOTEQUAL :
263+ case VECT_GREATER :
264+ case VECT_EQUAL :
265+ case VECT_LESSEQUAL :
266+ case VECT_GREATEREQUAL : return sparsityEst < 0.1 ;
267+ case VECT_MULT_SCALAR :
268+ case VECT_DIV_SCALAR :
269+ case VECT_XOR_SCALAR :
270+ case VECT_BITWAND_SCALAR : return sparsityEst < 0.3 ;
271+ case VECT_GREATER_SCALAR : {
272+ if (scalarVal != Double .NaN ) {
273+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal >= 0 && sparsityEst < 0.2
274+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal < 0 && sparsityEst < 0.2 ;
275+ } else
276+ return false ;
277+ }
278+ case VECT_GREATEREQUAL_SCALAR : {
279+ if (scalarVal != Double .NaN ) {
280+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal > 0 && sparsityEst < 0.2
281+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal <= 0 && sparsityEst < 0.2 ;
282+ } else
283+ return false ;
284+ }
285+ case VECT_MIN_SCALAR : {
286+ if (scalarVal != Double .NaN ) {
287+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal >= 0 && sparsityEst < 0.2
288+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal >= 0 && sparsityEst < 0.2 ;
289+ } else
290+ return false ;
291+ }
292+ case VECT_LESS_SCALAR : {
293+ if (scalarVal != Double .NaN ) {
294+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal <= 0 && sparsityEst < 0.2
295+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal > 0 && sparsityEst < 0.2 ;
296+ } else
297+ return false ;
298+ }
299+ case VECT_LESSEQUAL_SCALAR : {
300+ if (scalarVal != Double .NaN ) {
301+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal < 0 && sparsityEst < 0.2
302+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal >= 0 && sparsityEst < 0.2 ;
303+ } else
304+ return false ;
305+ }
306+ case VECT_MAX_SCALAR : {
307+ if (scalarVal != Double .NaN ) {
308+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal <= 0 && sparsityEst < 0.2
309+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal <= 0 && sparsityEst < 0.2 ;
310+ } else
311+ return false ;
312+ }
313+ case VECT_POW_SCALAR :
314+ case VECT_EQUAL_SCALAR :{
315+ if (scalarVal != Double .NaN ) {
316+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal != 0 && sparsityEst < 0.2
317+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal != 0 && sparsityEst < 0.2 ;
318+ } else
319+ return false ;
320+ }
321+ case VECT_NOTEQUAL_SCALAR :{
322+ if (scalarVal != Double .NaN ) {
323+ return _inputs .get (1 ).getDataType ().isScalar () ? scalarVal == 0 && sparsityEst < 0.2
324+ : _inputs .get (0 ).getDataType ().isScalar () && scalarVal == 0 && sparsityEst < 0.2 ;
325+ } else
326+ return false ;
327+ }
328+ default : return sparsityEst < 0.3 ;
329+ }
330+ }
331+ }
332+
333+ public boolean getOutputType (boolean scalarVector , boolean lsparseLhs , boolean lsparseRhs ) {
334+ switch (_type ) {
335+ case VECT_POW_SCALAR : return !scalarVector && lsparseLhs ;
336+ case VECT_MULT_SCALAR :
337+ case VECT_DIV_SCALAR :
338+ case VECT_XOR_SCALAR :
339+ case VECT_MIN_SCALAR :
340+ case VECT_MAX_SCALAR :
341+ case VECT_EQUAL_SCALAR :
342+ case VECT_NOTEQUAL_SCALAR :
343+ case VECT_LESS_SCALAR :
344+ case VECT_LESSEQUAL_SCALAR :
345+ case VECT_GREATER_SCALAR :
346+ case VECT_GREATEREQUAL_SCALAR :
347+ case VECT_BITWAND_SCALAR : return lsparseLhs || lsparseRhs ;
348+ case VECT_MULT :
349+ case VECT_DIV :
350+ case VECT_MINUS :
351+ case VECT_PLUS :
352+ case VECT_XOR :
353+ case VECT_BITWAND :
354+ case VECT_BIASADD :
355+ case VECT_BIASMULT :
356+ case VECT_MIN :
357+ case VECT_MAX :
358+ case VECT_NOTEQUAL :
359+ case VECT_LESS :
360+ case VECT_GREATER : return lsparseLhs && lsparseRhs ;
361+ default : return false ;
362+ }
363+ }
223364
224365 @ Override
225366 public String toString () {
0 commit comments