Skip to content

Commit ee1c3a0

Browse files
committed
more to torch.
1 parent d648468 commit ee1c3a0

File tree

7 files changed

+23
-19
lines changed

7 files changed

+23
-19
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ Once more the problem is illustrated below:
4343

4444
$$ \min_{\mathbf{x}} \mathbf{x} \cdot \mathbf{x} + \cos(2 \pi x_0 ) + \sin(2 \pi x_1) + 0.5 \cdot \text{relu}(x_0) + 10 \cdot \tanh( \|\mathbf{x} \| ), \text{ with } \mathbf{x_0} = (2.9, -2.9) .$$
4545

46-
The function is already defined in `src/optimize_2d_momentum_bumpy_jax.py`. We dont have to find the gradient by hand!
47-
Use `jax.grad` [(jax-documentation)](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) to compute the gradient automatically. Use the result to find the minimum using momentum.
46+
The function is already defined in `src/optimize_2d_momentum_bumpy_torch.py`. We dont have to find the gradient by hand!
47+
Use `torch.func.grad` [(torch-documentation)](https://pytorch.org/docs/stable/generated/torch.func.grad.html) to compute the gradient automatically. Use the result to find the minimum using momentum.
4848

4949
While coding use `nox -s test`, `nox -s lint`, and `nox -s typing` to check your code.
5050
Autoformatting help is available via `nox -s format`.

noxfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module implements our CI function calls."""
2+
23
import nox
34

45

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
nox
22
numpy
33
matplotlib
4-
jax
4+
torch

src/optimize_1d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Implement gradient descent in 1d."""
2+
23
import matplotlib.pyplot as plt
34
import numpy as np
45

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Implement 2d gradient descent using jax."""
22

3-
import jax
4-
import jax.numpy as np
5-
import numpy as nnp
3+
import numpy as np
4+
import torch as th
5+
from torch.func import grad
66

77
from util import write_movie
88

99

10-
def bumpy_function(pos: np.ndarray) -> np.ndarray:
10+
def bumpy_function(pos: th.Tensor) -> th.Tensor:
1111
"""Return values from an even bumpier function.
1212
1313
This even bumpier function is hard to optimize.
@@ -22,10 +22,10 @@ def bumpy_function(pos: np.ndarray) -> np.ndarray:
2222
return (
2323
pos[0] * pos[0]
2424
+ pos[1] * pos[1]
25-
+ np.cos(pos[0] * 2 * np.pi)
26-
+ np.sin(pos[1] * 2 * np.pi)
27-
+ (pos[0] > 0).astype(pos.dtype) * 0.5
28-
+ np.tanh(np.sqrt(pos[0] ** 2 + pos[1] ** 2)) * 10
25+
+ th.cos(pos[0] * 2 * th.pi)
26+
+ th.sin(pos[1] * 2 * th.pi)
27+
+ (pos[0] > 0).type(pos.dtype) * 0.5
28+
+ th.tanh(th.sqrt(pos[0] ** 2 + pos[1] ** 2)) * 10
2929
)
3030

3131

@@ -36,10 +36,10 @@ def bumpy_function(pos: np.ndarray) -> np.ndarray:
3636
# TODO: use jax to find the gradient.
3737

3838
nx, ny = (1001, 1001)
39-
x = np.linspace(-3, 3, nx)
40-
y = np.linspace(-3, 3, ny)
41-
mx, my = np.meshgrid(x, y)
42-
pos = np.stack((mx, my))
39+
x = th.linspace(-3, 3, nx)
40+
y = th.linspace(-3, 3, ny)
41+
mx, my = th.meshgrid(x, y)
42+
pos = th.stack((mx, my))
4343
mz = bumpy_function(pos)
4444

4545
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
@@ -51,7 +51,7 @@ def bumpy_function(pos: np.ndarray) -> np.ndarray:
5151
plt.contourf(mx, my, mz)
5252
plt.colorbar()
5353

54-
start_pos = np.array((2.9, -2.9))
54+
start_pos = th.tensor((2.9, -2.9))
5555
step_size = 0.0 # TODO: Choose your step size.
5656
alpha = 0.0 # TODO: Choose your momentum term.
5757
step_total = 100
@@ -65,9 +65,9 @@ def bumpy_function(pos: np.ndarray) -> np.ndarray:
6565
plt.show()
6666

6767
write_movie(
68-
nnp.array(mx),
69-
nnp.array(my),
70-
nnp.array(mz),
68+
np.array(mx),
69+
np.array(my),
70+
np.array(mz),
7171
pos_list,
7272
"writer_grad_bumpy_plot_jax",
7373
)

src/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Code to export gradient descent sequences into a movie."""
2+
23
from typing import Optional
34

45
import matplotlib.animation as manimation

tests/test_function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"Difference Qutiones may sometimes be useful, too"
66
Andreas Griewank, Andrea Walther - Evaluating Derivatives.
77
"""
8+
89
import sys
910
from typing import Callable, Optional, Union
1011

0 commit comments

Comments
 (0)