Skip to content

Commit 4ed3585

Browse files
committed
Fix code style
1 parent 810a6a2 commit 4ed3585

File tree

5 files changed

+27
-23
lines changed

5 files changed

+27
-23
lines changed

examples/tutorial.ipynb

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"%matplotlib inline\n",
1414
"\n",
1515
"import sys\n",
16+
"\n",
1617
"sys.path.append(\"../\")\n",
1718
"import tensorflow_riemopt as riemopt"
1819
]
@@ -119,7 +120,7 @@
119120
"\n",
120121
"opt = riemopt.optimizers.RiemannianAdam(learning_rate=0.2)\n",
121122
"\n",
122-
"npole = tf.constant([0., 1.])\n",
123+
"npole = tf.constant([0.0, 1.0])\n",
123124
"phi = np.linspace(-np.pi, np.pi, 100)\n",
124125
"\n",
125126
"for _ in range(STEPS):\n",
@@ -136,8 +137,14 @@
136137
" plt.plot(np.cos(phi), np.sin(phi))\n",
137138
" plt.plot(var_np[:, 0], var_np[:, 1], '+', color='black')\n",
138139
" for i in range(len(egrad_np)):\n",
139-
" plt.arrow(var_np[i][0], var_np[i][1], -egrad_np[i][0], -egrad_np[i][1],\n",
140-
" width=0.01, color='green')\n",
140+
" plt.arrow(\n",
141+
" var_np[i][0],\n",
142+
" var_np[i][1],\n",
143+
" -egrad_np[i][0],\n",
144+
" -egrad_np[i][1],\n",
145+
" width=0.01,\n",
146+
" color='green',\n",
147+
" )\n",
141148
" plt.plot(var_t_np[:, 0], var_t_np[:, 1], '+', color='red')\n",
142149
" plt.show()"
143150
]
@@ -152,7 +159,7 @@
152159
],
153160
"metadata": {
154161
"kernelspec": {
155-
"display_name": "Python 3",
162+
"display_name": "Python 3 (ipykernel)",
156163
"language": "python",
157164
"name": "python3"
158165
},
@@ -166,7 +173,7 @@
166173
"name": "python",
167174
"nbconvert_exporter": "python",
168175
"pygments_lexer": "ipython3",
169-
"version": "3.7.3"
176+
"version": "3.10.6"
170177
}
171178
},
172179
"nbformat": 4,

examples/usage.ipynb

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"from mpl_toolkits.mplot3d import Axes3D\n",
1515
"\n",
1616
"import sys\n",
17+
"\n",
1718
"sys.path.append(\"../\")\n",
1819
"import tensorflow_riemopt as riemopt"
1920
]
@@ -32,15 +33,17 @@
3233
" s_z = np.outer(np.abs(np.cos(phi)), np.ones_like(psi))\n",
3334
" return ax.plot_wireframe(s_x, s_y, s_z, color=color, alpha=0.3)\n",
3435
"\n",
36+
"\n",
3537
"def plot_vector(ax, x, u, color=\"darkorange\"):\n",
3638
" return ax.quiver(*x, *u, length=0.6, normalize=True, color=color)\n",
3739
"\n",
40+
"\n",
3841
"def plot_hyperplane(ax, p, u, v, color=\"limegreen\"):\n",
39-
" xx = np.linspace(-0.05, 1., 10)\n",
40-
" yy = np.linspace(-1., 0.1, 10)\n",
42+
" xx = np.linspace(-0.05, 1.0, 10)\n",
43+
" yy = np.linspace(-1.0, 0.1, 10)\n",
4144
" x, y = np.meshgrid(xx, yy)\n",
4245
" n = np.cross(u, v)\n",
43-
" z = (- n[0] * x - n[1] * y + p.dot(n)) / n[2]\n",
46+
" z = (-n[0] * x - n[1] * y + p.dot(n)) / n[2]\n",
4447
" return ax.plot_wireframe(x, y, z, color=color, alpha=0.4)"
4548
]
4649
},
@@ -53,7 +56,7 @@
5356
"S = riemopt.manifolds.Sphere()\n",
5457
"\n",
5558
"x = S.projx(tf.constant([0.1, -0.1, 0.1]))\n",
56-
"u = S.proju(x, tf.constant([1., 1., 1.]))\n",
59+
"u = S.proju(x, tf.constant([1.0, 1.0, 1.0]))\n",
5760
"v = S.proju(x, tf.constant([-0.7, -1.4, 1.4]))\n",
5861
"y = S.exp(x, v)\n",
5962
"u_ = S.transp(x, y, u)\n",
@@ -866,9 +869,9 @@
866869
"fig = plt.figure()\n",
867870
"ax = fig.gca(projection=\"3d\")\n",
868871
"ax.axis(\"off\")\n",
869-
"ax.set_zlim3d(-1.5, 1.5) \n",
870-
"ax.set_ylim3d(-1.5, 1.5) \n",
871-
"ax.set_xlim3d(-1.5, 1.5) \n",
872+
"ax.set_zlim3d(-1.5, 1.5)\n",
873+
"ax.set_ylim3d(-1.5, 1.5)\n",
874+
"ax.set_xlim3d(-1.5, 1.5)\n",
872875
"\n",
873876
"plot_halfsphere(ax)\n",
874877
"plot_hyperplane(ax, x, u, v)\n",
@@ -885,7 +888,7 @@
885888
],
886889
"metadata": {
887890
"kernelspec": {
888-
"display_name": "Python 3",
891+
"display_name": "Python 3 (ipykernel)",
889892
"language": "python",
890893
"name": "python3"
891894
},
@@ -899,7 +902,7 @@
899902
"name": "python",
900903
"nbconvert_exporter": "python",
901904
"pygments_lexer": "ipython3",
902-
"version": "3.7.3"
905+
"version": "3.10.6"
903906
}
904907
},
905908
"nbformat": 4,

tensorflow_riemopt/optimizers/constrained_rmsprop.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ def _create_slots(self, var_list):
8787
self.add_slot(var, "mg")
8888

8989
def _prepare_local(self, var_device, var_dtype, apply_state):
90-
super()._prepare_local(
91-
var_device, var_dtype, apply_state
92-
)
90+
super()._prepare_local(var_device, var_dtype, apply_state)
9391

9492
rho = array_ops.identity(self._get_hyper("rho", var_dtype))
9593
apply_state[(var_device, var_dtype)].update(

tensorflow_riemopt/optimizers/riemannian_adam.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,7 @@ def _create_slots(self, var_list):
9292
self.add_slot(var, "vhat")
9393

9494
def _prepare_local(self, var_device, var_dtype, apply_state):
95-
super()._prepare_local(
96-
var_device, var_dtype, apply_state
97-
)
95+
super()._prepare_local(var_device, var_dtype, apply_state)
9896

9997
local_step = math_ops.cast(self.iterations + 1, var_dtype)
10098
beta_1_t = array_ops.identity(self._get_hyper("beta_1", var_dtype))

tensorflow_riemopt/optimizers/riemannian_gradient_descent.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ def _create_slots(self, var_list):
8585
self.add_slot(var, "momentum")
8686

8787
def _prepare_local(self, var_device, var_dtype, apply_state):
88-
super()._prepare_local(
89-
var_device, var_dtype, apply_state
90-
)
88+
super()._prepare_local(var_device, var_dtype, apply_state)
9189
apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
9290
self._get_hyper("momentum", var_dtype)
9391
)

0 commit comments

Comments
 (0)