
Commit d6b76fb

Import torchjd.autojac directly and use autojac.backward in the basic usage example
1 parent 6daf5fa

File tree

2 files changed (+4, −4 lines)


docs/source/examples/basic_usage.rst

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@ Import several classes from ``torch`` and ``torchjd``:
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
 
-    import torchjd
+    from torchjd import autojac
     from torchjd.aggregation import UPGrad
 
 Define the model and the optimizer, as usual:
@@ -69,7 +69,7 @@ Perform the Jacobian descent backward pass:
 
 .. code-block:: python
 
-    torchjd.autojac.backward([loss1, loss2], aggregator)
+    autojac.backward([loss1, loss2], aggregator)
 
 This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
 Jacobian matrix.
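
For reference, a minimal end-to-end sketch of the basic usage example after this change. Only the imports, the model definition, the per-output losses, and the autojac.backward call come from the diffs in this commit; the batch size, learning rate, and target shapes below are illustrative assumptions.

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd import autojac
    from torchjd.aggregation import UPGrad

    # Model, loss and optimizer as in the example (the lr value is an assumption).
    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
    loss_fn = MSELoss()
    optimizer = SGD(model.parameters(), lr=0.1)
    aggregator = UPGrad()

    # Hypothetical batch: 16 samples with 10 features, and one scalar target per output.
    inputs = torch.randn(16, 10)
    target1 = torch.randn(16)
    target2 = torch.randn(16)

    output = model(inputs)
    loss1 = loss_fn(output[:, 0], target1)
    loss2 = loss_fn(output[:, 1], target2)

    optimizer.zero_grad()
    # Jacobian descent backward pass: aggregates the per-loss gradients with UPGrad
    # and populates each parameter's .grad field.
    autojac.backward([loss1, loss2], aggregator)
    optimizer.step()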

tests/doc/test_rst.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ def test_basic_usage():
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
 
-    import torchjd
+    from torchjd import autojac
     from torchjd.aggregation import UPGrad
 
     model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
@@ -27,7 +27,7 @@ def test_basic_usage():
     loss2 = loss_fn(output[:, 1], target2)
 
     optimizer.zero_grad()
-    torchjd.autojac.backward([loss1, loss2], aggregator)
+    autojac.backward([loss1, loss2], aggregator)
     optimizer.step()
 
 
