Skip to content

Commit d239cf2

Browse files
committed
use binary search. test=develop
1 parent a9f5f82 commit d239cf2

File tree

1 file changed

+4
-28
lines changed

1 file changed

+4
-28
lines changed

paddle/fluid/operators/momentum_op.h

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -153,33 +153,6 @@ class DenseMomentumFunctor<T, NoNesterov> {
153153
}
154154
};
155155

156-
// TODO(dzh): enhance speed use eigen
157-
// template<typename T>
158-
// class CPUSparseMomentumFunctor {
159-
// private:
160-
// const T* p_;
161-
// const T* g_;
162-
// const T* v_;
163-
// const T* lr_;
164-
// const T mu_;
165-
// const bool use_nesterov_;
166-
// const int64_t* rows_;
167-
// const int64_t row_numel_;
168-
// const int64_t row_height_;
169-
// T* p_out_;
170-
// T* v_out_;
171-
172-
// public:
173-
// CPUSparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
174-
// const T mu, const bool use_nesterov, const int64_t* rows, const int64_t
175-
// row_numel, const int64_t row_height, T* p_out, T* v_out) :p_(p), g_(g),
176-
// v_(v), lr_(lr), mu_(mu), rows_(rows), row_numel_(row_numel),
177-
// row_height_(row_height), p_out_(p_out), v_out_(v_out) {}
178-
// inline void operator()() {
179-
180-
// }
181-
// };
182-
183156
template <typename T, typename UpdateMethod>
184157
class SparseMomentumFunctor;
185158

@@ -367,7 +340,10 @@ class MomentumOpKernel : public framework::OpKernel<T> {
367340
for_range(functor);
368341
}
369342
} else {
370-
PADDLE_THROW("Unsupported Variable Type of Grad");
343+
PADDLE_THROW(
344+
string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows "
345+
"gradient, but the received Variable Type is %s",
346+
grad_var->Type().name()));
371347
}
372348
}
373349
};

0 commit comments

Comments
 (0)