@@ -153,33 +153,6 @@ class DenseMomentumFunctor<T, NoNesterov> {
153
153
}
154
154
};
155
155
156
- // TODO(dzh): enhance speed use eigen
157
- // template<typename T>
158
- // class CPUSparseMomentumFunctor {
159
- // private:
160
- // const T* p_;
161
- // const T* g_;
162
- // const T* v_;
163
- // const T* lr_;
164
- // const T mu_;
165
- // const bool use_nesterov_;
166
- // const int64_t* rows_;
167
- // const int64_t row_numel_;
168
- // const int64_t row_height_;
169
- // T* p_out_;
170
- // T* v_out_;
171
-
172
- // public:
173
- // CPUSparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
174
- // const T mu, const bool use_nesterov, const int64_t* rows, const int64_t
175
- // row_numel, const int64_t row_height, T* p_out, T* v_out) :p_(p), g_(g),
176
- // v_(v), lr_(lr), mu_(mu), rows_(rows), row_numel_(row_numel),
177
- // row_height_(row_height), p_out_(p_out), v_out_(v_out) {}
178
- // inline void operator()() {
179
-
180
- // }
181
- // };
182
-
183
156
template <typename T, typename UpdateMethod>
184
157
class SparseMomentumFunctor ;
185
158
@@ -367,7 +340,10 @@ class MomentumOpKernel : public framework::OpKernel<T> {
367
340
for_range (functor);
368
341
}
369
342
} else {
370
- PADDLE_THROW (" Unsupported Variable Type of Grad" );
343
+ PADDLE_THROW (
344
+ string::Sprintf (" MomentumOp only supports LoDTensor or SelectedRows "
345
+ " gradient, but the received Variable Type is %s" ,
346
+ grad_var->Type ().name ()));
371
347
}
372
348
}
373
349
};
0 commit comments