@@ -22,8 +22,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
2222 if (reduce == " sum" ) {
2323 SWITCH_BITS (bits, DType, {
2424 SWITCH_OP (op, Op, {
25- DType *out_off = out.Ptr <DType>();
26- std::fill (out_off, out_off + csr.num_rows * dim, 0 );
2725 cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
2826 });
2927 });
@@ -33,8 +31,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
3331 DType *out_off = out.Ptr <DType>();
3432 IdType* argX = Op::use_lhs ? static_cast <IdType*>(out_aux[0 ]->data ) : nullptr ;
3533 IdType* argW = Op::use_rhs ? static_cast <IdType*>(out_aux[1 ]->data ) : nullptr ;
36- if (Op::use_lhs) std::fill (argX, argX + csr.num_rows * dim, 0 );
37- if (Op::use_rhs) std::fill (argW, argW + csr.num_rows * dim, 0 );
3834 if (reduce == " max" ) {
3935 std::fill (out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
4036 cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
@@ -66,11 +62,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
6662 if (reduce == " sum" ) {
6763 SWITCH_BITS (bits, DType, {
6864 SWITCH_OP (op, Op, {
69- // TODO(Israt): Ideally the for loop should go over num_ntypes
70- for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
71- DType *out_off = vec_out[out_node_tids[etype]].Ptr <DType>();
72- std::fill (out_off, out_off + vec_csr[etype].num_rows * dim, 0 );
73- }
7465 /* Call SpMM for each relation type */
7566 for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
7667 const dgl_type_t src_id = ufeat_node_tids[etype];
@@ -86,13 +77,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
8677 } else if (reduce == " max" || reduce == " min" ) {
8778 SWITCH_BITS (bits, DType, {
8879 SWITCH_OP (op, Op, {
89- // TODO(Israt): Ideally the for loop should go over num_ntypes
90- for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
91- IdType* argX = Op::use_lhs ? static_cast <IdType*>(out_aux[0 ]->data ) : nullptr ;
92- IdType* argW = Op::use_rhs ? static_cast <IdType*>(out_aux[1 ]->data ) : nullptr ;
93- if (Op::use_lhs) std::fill (argX, argX + vec_csr[etype].num_rows * dim, 0 );
94- if (Op::use_rhs) std::fill (argW, argW + vec_csr[etype].num_rows * dim, 0 );
95- }
9680 /* Call SpMM for each relation type */
9781 for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
9882 const dgl_type_t src_id = ufeat_node_tids[etype];
0 commit comments