  */
 
 #include <executorch/extension/training/optimizer/sgd.h>
-#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/portable/NativeFunctions.h>
 
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 
 using exec_aten::Tensor;
 using exec_aten::TensorImpl;
 using ::executorch::runtime::Error;
-using ::executorch::runtime::KernelRuntimeContext;
 
 namespace executorch {
 namespace extension {
@@ -73,10 +72,7 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
         auto p = param_iter->second;
         if (weight_decay != 0) {
           // uses weight_decay specified and adds it to the gradient
-          torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
-          if (context.failure_state() != Error::Ok) {
-            return context.failure_state();
-          }
+          torch::executor::native::add_out(d_p, p, weight_decay, d_p);
         }
         if (momentum != 0) {
           Tensor buf(nullptr);
@@ -100,11 +96,8 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
                 const_cast<TensorImpl::DimOrderType*>(d_p.dim_order().data()));
             buf = Tensor(buf_impl);
 #endif
-            torch::executor::aten::clone_outf(
-                context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
-            if (context.failure_state() != Error::Ok) {
-              return context.failure_state();
-            }
+            torch::executor::native::clone_out(
+                d_p, exec_aten::MemoryFormat::Contiguous, buf);
 
             // save the state of the momentum buffer to be reused in later
             // epochs
@@ -115,31 +108,18 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
                       .momentum_buffer();
 
             // update the momentum buffer and apply dampening
-            torch::executor::aten::mul_outf(context, buf, momentum, buf);
-            if (context.failure_state() != Error::Ok) {
-              return context.failure_state();
-            }
-            torch::executor::aten::add_outf(
-                context, buf, d_p, 1 - dampening, buf);
-            if (context.failure_state() != Error::Ok) {
-              return context.failure_state();
-            }
+            torch::executor::native::mul_out(context, buf, momentum, buf);
+            torch::executor::native::add_out(buf, d_p, 1 - dampening, buf);
           }
           if (nesterov) {
             // apply nesterov momentum
-            torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
-            if (context.failure_state() != Error::Ok) {
-              return context.failure_state();
-            }
+            torch::executor::native::add_out(d_p, buf, momentum, d_p);
           } else {
             d_p = buf;
           }
         }
         // update the parameter using the gradient and learning rate
-        torch::executor::aten::add_outf(context, p, d_p, -1 * options.lr(), p);
-        if (context.failure_state() != Error::Ok) {
-          return context.failure_state();
-        }
+        torch::executor::native::add_out(p, d_p, -1 * options.lr(), p);
       }
     }
   }
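
For readers following the arithmetic behind these kernel calls, the sequence above is the standard SGD update with weight decay, momentum, dampening, and optional Nesterov momentum. Below is a minimal, self-contained scalar sketch of that update; the variable values are made up for illustration, and this is not the actual implementation, which operates on Tensors through the portable add_out/mul_out/clone_out kernels shown in the diff.

// Scalar sketch of one SGD step, mirroring the per-element math above.
// Hypothetical illustration only.
#include <cstdio>

int main() {
  double p = 1.0;    // parameter value
  double d_p = 0.5;  // gradient for this step
  double buf = 0.2;  // momentum buffer carried over from earlier epochs
  const double lr = 0.1;
  const double weight_decay = 0.01;
  const double momentum = 0.9;
  const double dampening = 0.0;
  const bool nesterov = false;

  if (weight_decay != 0) {
    // add_out(d_p, p, weight_decay, d_p): d_p += weight_decay * p
    d_p = d_p + weight_decay * p;
  }
  if (momentum != 0) {
    // mul_out(buf, momentum, buf) then add_out(buf, d_p, 1 - dampening, buf)
    buf = buf * momentum + (1 - dampening) * d_p;
    if (nesterov) {
      // add_out(d_p, buf, momentum, d_p): apply nesterov momentum
      d_p = d_p + momentum * buf;
    } else {
      d_p = buf;
    }
  }
  // add_out(p, d_p, -1 * lr, p): update the parameter using the learning rate
  p = p + (-1 * lr) * d_p;

  std::printf("updated parameter: %f\n", p);
  return 0;
}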