@@ -29,9 +29,9 @@ namespace py = pybind11;
 
 namespace adam_op {
 
-TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
-                                  const torch::Tensor &mu,
-                                  const torch::Tensor &nu,
+TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
+                                  const torch::Tensor& mu,
+                                  const torch::Tensor& nu,
                                   const pyfloat_t b1,
                                   const pyfloat_t b2,
                                   const pyfloat_t eps,
@@ -49,8 +49,8 @@ TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
   }
 }
 
-torch::Tensor adamForwardMu(const torch::Tensor &updates,
-                            const torch::Tensor &mu,
+torch::Tensor adamForwardMu(const torch::Tensor& updates,
+                            const torch::Tensor& mu,
                             const pyfloat_t b1) {
 #if defined(__USE_CUDA__)
   if (updates.device().is_cuda()) {
@@ -64,8 +64,8 @@ torch::Tensor adamForwardMu(const torch::Tensor &updates,
   }
 }
 
-torch::Tensor adamForwardNu(const torch::Tensor &updates,
-                            const torch::Tensor &nu,
+torch::Tensor adamForwardNu(const torch::Tensor& updates,
+                            const torch::Tensor& nu,
                             const pyfloat_t b2) {
 #if defined(__USE_CUDA__)
   if (updates.device().is_cuda()) {
@@ -79,8 +79,8 @@ torch::Tensor adamForwardNu(const torch::Tensor &updates,
   }
 }
 
-torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
-                                 const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
+                                 const torch::Tensor& new_nu,
                                  const pyfloat_t b1,
                                  const pyfloat_t b2,
                                  const pyfloat_t eps,
@@ -98,9 +98,9 @@ torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
   }
 }
 
-TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
-                              const torch::Tensor &updates,
-                              const torch::Tensor &mu,
+TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
+                              const torch::Tensor& updates,
+                              const torch::Tensor& mu,
                               const pyfloat_t b1) {
 #if defined(__USE_CUDA__)
   if (dmu.device().is_cuda()) {
@@ -114,9 +114,9 @@ TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
   }
 }
 
-TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
-                              const torch::Tensor &updates,
-                              const torch::Tensor &nu,
+TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
+                              const torch::Tensor& updates,
+                              const torch::Tensor& nu,
                               const pyfloat_t b2) {
 #if defined(__USE_CUDA__)
   if (dnu.device().is_cuda()) {
@@ -130,10 +130,10 @@ TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
   }
 }
 
-TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
-                                   const torch::Tensor &updates,
-                                   const torch::Tensor &new_mu,
-                                   const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
+                                   const torch::Tensor& updates,
+                                   const torch::Tensor& new_mu,
+                                   const torch::Tensor& new_nu,
                                    const pyfloat_t b1,
                                    const pyfloat_t b2,
                                    const pyfloat_t eps_root,
@@ -152,7 +152,7 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
   }
 }
 
-void buildSubmodule(py::module &mod) {  // NOLINT[runtime/references]
+void buildSubmodule(py::module& mod) {  // NOLINT[runtime/references]
   py::module m = mod.def_submodule("adam_op", "Adam Ops");
   m.def("forward_",
         &adamForwardInplace,
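Note that every hunk in this commit only moves the reference `&` from the parameter name to the type (`const torch::Tensor &x` becomes `const torch::Tensor& x`); this is a pure formatting change and does not affect the compiled code. For orientation, here is a minimal sketch of the pybind11 binding pattern that `buildSubmodule` follows. The extension module name `_C`, the exported name `forward_mu`, the use of `double` in place of the repository's `pyfloat_t` alias, and the function body are illustrative assumptions, not the repository's actual implementation.

// Minimal pybind11 sketch -- illustrative only. Assumptions: built as a torch
// C++ extension, the module name "_C" is made up, and `double` stands in for
// the repository's pyfloat_t alias.
#include <torch/extension.h>  // brings in torch::Tensor and the pybind11 headers

namespace py = pybind11;

// Stand-in with the same shape as adamForwardMu in the diff. It computes the
// standard Adam first-moment update: new_mu = b1 * mu + (1 - b1) * updates.
torch::Tensor adamForwardMu(const torch::Tensor& updates,
                            const torch::Tensor& mu,
                            const double b1) {
  return mu * b1 + updates * (1.0 - b1);
}

// Same pattern as buildSubmodule in the diff: attach an "adam_op" submodule
// to the parent module and expose the op to Python.
void buildSubmodule(py::module& mod) {
  py::module m = mod.def_submodule("adam_op", "Adam Ops");
  m.def("forward_mu", &adamForwardMu,
        py::arg("updates"), py::arg("mu"), py::arg("b1"));
}

PYBIND11_MODULE(_C, m) {  // "_C" is an assumed module name
  buildSubmodule(m);
}

Once such an extension is built, the op would be reachable from Python as, for example, `_C.adam_op.forward_mu(updates, mu, b1)` (again under the assumed module name).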