We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 832b242 commit 0d4ed2aCopy full SHA for 0d4ed2a
README.md
@@ -137,7 +137,7 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
137
optimizer = SGD(model.parameters(), lr=0.1)
138
139
+ weighting = UPGradWeighting()
140
-+ engine = Engine(model.modules())
++ engine = Engine(model, batch_dim=0)
141
142
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
143
targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
0 commit comments