Skip to content

Commit 2ce3fb7

Browse files
committed
Improve clarity of usage examples
1 parent 8ee2ed4 commit 2ce3fb7

File tree

5 files changed

+14
-10
lines changed

5 files changed

+14
-10
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
149149

150150
optimizer.zero_grad()
151151
- loss.backward()
152-
+ gramian = engine.compute_gramian(losses)
153-
+ losses.backward(weighting(gramian))
152+
+ gramian = engine.compute_gramian(losses) # shape: [16, 16]
153+
+ weights = weighting(gramian) # shape: [16]
154+
+ losses.backward(weights)
154155
optimizer.step()
155156
```
156157

docs/source/examples/iwrm.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
135135
y_hat = model(x).squeeze(dim=1) # shape: [16]
136136
losses = loss_fn(y_hat, y) # shape: [16]
137137
optimizer.zero_grad()
138-
gramian = engine.compute_gramian(losses)
139-
weights = weighting(gramian)
138+
gramian = engine.compute_gramian(losses) # shape: [16, 16]
139+
weights = weighting(gramian) # shape: [16]
140140
losses.backward(weights)
141141
optimizer.step()
142142

src/torchjd/autogram/_engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ class Engine:
8686
>>> losses = criterion(output, target) # shape: [16]
8787
>>>
8888
>>> optimizer.zero_grad()
89-
>>> gramian = engine.compute_gramian(losses)
90-
>>> losses.backward(weighting(gramian))
89+
>>> gramian = engine.compute_gramian(losses) # shape: [16, 16]
90+
>>> weights = weighting(gramian) # shape: [16]
91+
>>> losses.backward(weights)
9192
>>> optimizer.step()
9293
9394
.. warning::

tests/doc/test_autogram.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def test_engine():
2525
losses = criterion(output, target) # shape: [16]
2626

2727
optimizer.zero_grad()
28-
gramian = engine.compute_gramian(losses)
29-
losses.backward(weighting(gramian))
28+
gramian = engine.compute_gramian(losses) # shape: [16, 16]
29+
weights = weighting(gramian) # shape: [16]
30+
losses.backward(weights)
3031
optimizer.step()

tests/doc/test_rst.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,9 @@ def test_autogram():
101101
y_hat = model(x).squeeze(dim=1) # shape: [16]
102102
losses = loss_fn(y_hat, y) # shape: [16]
103103
optimizer.zero_grad()
104-
gramian = engine.compute_gramian(losses)
105-
losses.backward(weighting(gramian))
104+
gramian = engine.compute_gramian(losses) # shape: [16, 16]
105+
weights = weighting(gramian) # shape: [16]
106+
losses.backward(weights)
106107
optimizer.step()
107108

108109
test_autograd()

0 commit comments

Comments
 (0)