Skip to content

Commit 0369058

Browse files
committed
[SYSTEMDS-2854] Followup to bugfix in SUM_SQ CUDA codegen
in commit 197a14b Tested with org.apache.sysds.test.functions.codegen.CellwiseTmplTest.testCodegenCellwiseRewrite9()
1 parent 7858979 commit 0369058

File tree

3 files changed

+6
-18
lines changed

3 files changed

+6
-18
lines changed

src/main/cuda/headers/agg_ops.cuh

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,6 @@ struct IdentityOp {
9292
template<typename T>
9393
struct SumOp {
9494
__device__ __forceinline__ T operator()(T a, T b) const {
95-
// if(blockIdx.x==0 && threadIdx.x ==0)
96-
// printf("a=%f + b=%f => %f\n", a, b, a+b);
9795
return a + b;
9896
}
9997

@@ -124,16 +122,6 @@ struct MinusOp {
124122
}
125123
};
126124

127-
/**
128-
* Functor op for sum of squares operation (returns a + b * b)
129-
*/
130-
template<typename T>
131-
struct SumSqOp {
132-
__device__ __forceinline__ T operator()(T a, T b) const {
133-
return a + b * b;
134-
}
135-
};
136-
137125
/**
138126
* Functor op for min operation
139127
*/

src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ public String codegen(boolean sparse, GeneratorAPI _api) {
162162

163163
tmp = tmp.replace("%BODY_dense%", tmpDense);
164164

165-
//return last TMP
166-
tmp = tmp.replaceAll("%OUT%", _output.getVarname());
165+
//Return last TMP. Square it for CUDA+SUM_SQ
166+
tmp = (api.isJava() || _aggOp != AggOp.SUM_SQ) ? tmp.replaceAll("%OUT%", _output.getVarname()) :
167+
tmp.replaceAll("%OUT%", _output.getVarname() + " * " + _output.getVarname());
167168

168169
//replace meta data information
169170
tmp = tmp.replaceAll("%TYPE%", getCellType().name());
@@ -181,11 +182,8 @@ public String codegen(boolean sparse, GeneratorAPI _api) {
181182
if(_aggOp != null)
182183
switch(_aggOp) {
183184
case SUM:
184-
agg_op = "SumOp";
185-
initial_value = "(T)0.0";
186-
break;
187185
case SUM_SQ:
188-
agg_op = "SumSqOp";
186+
agg_op = "SumOp";
189187
initial_value = "(T)0.0";
190188
break;
191189
case MIN:

src/test/scripts/functions/codegen/SystemDS-config-codegen6.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,6 @@
2727

2828
<!-- The number of theads for the spark instance artificially selected-->
2929
<sysds.local.spark.number.threads>16</sysds.local.spark.number.threads>
30+
31+
<sysds.codegen.api>auto</sysds.codegen.api>
3032
</root>

0 commit comments

Comments
 (0)