
Commit 8102a9c

Make usage examples match doc tests

All six files receive the same mechanical fix: `from torchjd.aggregation import UPGrad` is moved ahead of the `torchjd.autojac` import, so that import order is alphabetical and consistent between the usage examples and the doctests; amp.rst additionally sorts its torch.nn import list.

1 parent 42b245c commit 8102a9c

6 files changed: +7 -7 lines changed

docs/source/examples/amp.rst

Lines changed: 2 additions & 2 deletions

@@ -16,11 +16,11 @@ following example shows the resulting code for a multi-task learning use-case.
 
 import torch
 from torch.amp import GradScaler
-from torch.nn import Sequential, Linear, ReLU, MSELoss
+from torch.nn import Linear, MSELoss, ReLU, Sequential
 from torch.optim import SGD
 
-from torchjd.autojac import mtl_backward
 from torchjd.aggregation import UPGrad
+from torchjd.autojac import mtl_backward
 
 shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
 task1_module = Linear(3, 1)
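
For reference, a sketch of the import block in docs/source/examples/amp.rst as it stands after this commit, reconstructed from the hunk above (the rest of the example, such as the GradScaler training loop, lies outside the diff context):

    import torch
    from torch.amp import GradScaler
    from torch.nn import Linear, MSELoss, ReLU, Sequential  # now sorted alphabetically
    from torch.optim import SGD

    from torchjd.aggregation import UPGrad    # third-party group, sorted:
    from torchjd.autojac import mtl_backward  # aggregation before autojac

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)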

docs/source/examples/lightning_integration.rst

Lines changed: 1 addition & 1 deletion

@@ -21,8 +21,8 @@ The following code example demonstrates a basic multi-task learning setup using
 from torch.optim import Adam
 from torch.utils.data import DataLoader, TensorDataset
 
-from torchjd.autojac import mtl_backward
 from torchjd.aggregation import UPGrad
+from torchjd.autojac import mtl_backward
 
 class Model(LightningModule):
     def __init__(self):

docs/source/examples/monitoring.rst

Lines changed: 1 addition & 1 deletion

@@ -22,8 +22,8 @@ they have a negative inner product).
 from torch.nn.functional import cosine_similarity
 from torch.optim import SGD
 
-from torchjd.autojac import mtl_backward
 from torchjd.aggregation import UPGrad
+from torchjd.autojac import mtl_backward
 
 def print_weights(_, __, weights: torch.Tensor) -> None:
     """Prints the extracted weights."""

docs/source/examples/mtl.rst

Lines changed: 1 addition & 1 deletion

@@ -25,8 +25,8 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
 from torch.nn import Linear, MSELoss, ReLU, Sequential
 from torch.optim import SGD
 
-from torchjd.autojac import mtl_backward
 from torchjd.aggregation import UPGrad
+from torchjd.autojac import mtl_backward
 
 shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
 task1_module = Linear(3, 1)
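
As a rough illustration of what the surrounding mtl.rst example does, here is a minimal sketch of a two-task training step. Everything beyond the lines visible in the hunk is an assumption: task2_module, the loss/optimizer setup, and in particular the mtl_backward keyword names (losses, features, aggregator) are not confirmed by this diff.

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import UPGrad
    from torchjd.autojac import mtl_backward

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)  # assumed: a second task head, symmetric to task1

    loss_fn = MSELoss()
    params = (
        list(shared_module.parameters())
        + list(task1_module.parameters())
        + list(task2_module.parameters())
    )
    optimizer = SGD(params, lr=0.1)

    inputs = torch.randn(16, 10)   # a batch of 16 input vectors of dimension 10
    targets1 = torch.randn(16, 1)  # scalar labels for task 1
    targets2 = torch.randn(16, 1)  # scalar labels for task 2

    features = shared_module(inputs)  # shared representation fed to both heads
    loss1 = loss_fn(task1_module(features), targets1)
    loss2 = loss_fn(task2_module(features), targets2)

    optimizer.zero_grad()
    # Assumed call shape: aggregate the per-task gradients of the shared
    # parameters with UPGrad instead of simply summing the losses.
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
    optimizer.step()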

docs/source/examples/rnn.rst

Lines changed: 1 addition & 1 deletion

@@ -12,8 +12,8 @@ descent can be leveraged to enhance optimization.
 from torch.nn import RNN
 from torch.optim import SGD
 
-from torchjd.autojac import backward
 from torchjd.aggregation import UPGrad
+from torchjd.autojac import backward
 
 rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
 optimizer = SGD(rnn.parameters(), lr=0.1)

src/torchjd/autojac/_backward.py

Lines changed: 1 addition & 1 deletion

@@ -41,8 +41,8 @@ def backward(
 
 >>> import torch
 >>>
->>> from torchjd.autojac import backward
 >>> from torchjd.aggregation import UPGrad
+>>> from torchjd.autojac import backward
 >>>
 >>> param = torch.tensor([1., 2.], requires_grad=True)
 >>> # Compute arbitrary quantities that are function of param
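
The docstring example is cut off by the diff context right after the comment line. A hypothetical continuation, to show the shape of the doctest; the quantities y1/y2 and the aggregator argument name are assumptions, not part of this diff:

    >>> y1 = torch.tensor([-1., 1.]) @ param  # assumed: a scalar function of param
    >>> y2 = (param ** 2).sum()               # assumed: another scalar function of param
    >>> # Aggregate the two gradients with UPGrad and accumulate into param.grad
    >>> backward([y1, y2], aggregator=UPGrad())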
