Skip to content

Commit 0816a4a

Browse files
Merge pull request #47 from KevinMusgrave/dev
v0.0.60
2 parents 0b0fb63 + f273fc2 commit 0816a4a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+7009
-7013
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ for data in tqdm(dataloader):
4444
data = batch_to_device(data, device)
4545
# Optimization is done inside the hook.
4646
# The returned loss is for logging.
47-
loss, _ = hook({}, {**models, **data})
47+
_, loss = hook({**models, **data})
4848
```
4949

5050
### Build complex algorithms
@@ -62,7 +62,7 @@ misc = {"combined_model": torch.nn.Sequential(G, C)}
6262
hook = DANNHook(optimizers, post_g=[MCCHook(), VATHook()])
6363
for data in tqdm(dataloader):
6464
data = batch_to_device(data, device)
65-
loss, _ = hook({}, {**models, **data, **misc})
65+
_, loss = hook({**models, **data, **misc})
6666
```
6767

6868
### Wrap with your favorite PyTorch framework

examples/getting_started/CustomizingAlgorithms.ipynb

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
"metadata": {},
109109
"outputs": [],
110110
"source": [
111-
"def print_info(model_counts, losses, outputs, G, C, D=None):\n",
111+
"def print_info(model_counts, outputs, losses, G, C, D=None):\n",
112112
" def get_shape(v):\n",
113113
" if isinstance(v, torch.Tensor):\n",
114114
" return v.shape\n",
@@ -152,8 +152,8 @@
152152
"G.count, C.count = 0, 0\n",
153153
"hook = ClassifierHook(opts)\n",
154154
"model_counts = validate_hook(hook, list(data.keys()))\n",
155-
"losses, outputs = hook({}, {**models, **data})\n",
156-
"print_info(model_counts, losses, outputs, G, C)"
155+
"outputs, losses = hook({**models, **data})\n",
156+
"print_info(model_counts, outputs, losses, G, C)"
157157
]
158158
},
159159
{
@@ -200,8 +200,8 @@
200200
"weighter = MeanWeighter(weights={\"bsp_loss\": 1e-5})\n",
201201
"hook = ClassifierHook(opts, post=[BSPHook(), BNMHook()], weighter=weighter)\n",
202202
"model_counts = validate_hook(hook, list(data.keys()))\n",
203-
"losses, outputs = hook({}, {**models, **data})\n",
204-
"print_info(model_counts, losses, outputs, G, C)"
203+
"outputs, losses = hook({**models, **data})\n",
204+
"print_info(model_counts, outputs, losses, G, C)"
205205
]
206206
},
207207
{
@@ -237,8 +237,8 @@
237237
"G.count, C.count, D.count = 0, 0, 0\n",
238238
"hook = DANNHook(opts)\n",
239239
"model_counts = validate_hook(hook, list(data.keys()))\n",
240-
"losses, outputs = hook({}, {**models, **data})\n",
241-
"print_info(model_counts, losses, outputs, G, C, D)"
240+
"outputs, losses = hook({**models, **data})\n",
241+
"print_info(model_counts, outputs, losses, G, C, D)"
242242
]
243243
},
244244
{
@@ -271,8 +271,8 @@
271271
"\n",
272272
"hook = DANNHook(opts, post_g=[mcc, atdoc])\n",
273273
"model_counts = validate_hook(hook, list(data.keys()))\n",
274-
"losses, outputs = hook({}, {**models, **data})\n",
275-
"print_info(model_counts, losses, outputs, G, C, D)"
274+
"outputs, losses = hook({**models, **data})\n",
275+
"print_info(model_counts, outputs, losses, G, C, D)"
276276
]
277277
},
278278
{
@@ -310,8 +310,8 @@
310310
"\n",
311311
"hook = CDANHook(d_opts=d_opts, g_opts=g_opts)\n",
312312
"model_counts = validate_hook(hook, list(data.keys()))\n",
313-
"losses, outputs = hook({}, {**models, **misc, **data})\n",
314-
"print_info(model_counts, losses, outputs, G, C, D)"
313+
"outputs, losses = hook({**models, **misc, **data})\n",
314+
"print_info(model_counts, outputs, losses, G, C, D)"
315315
]
316316
},
317317
{
@@ -341,8 +341,8 @@
341341
"misc[\"combined_model\"] = torch.nn.Sequential(G, C)\n",
342342
"hook = CDANHook(d_opts=d_opts, g_opts=g_opts, post_g=[VATHook()])\n",
343343
"model_counts = validate_hook(hook, list(data.keys()))\n",
344-
"losses, outputs = hook({}, {**models, **misc, **data})\n",
345-
"print_info(model_counts, losses, outputs, G, C, D)"
344+
"outputs, losses = hook({**models, **misc, **data})\n",
345+
"print_info(model_counts, outputs, losses, G, C, D)"
346346
]
347347
},
348348
{
@@ -384,8 +384,8 @@
384384
"\n",
385385
"hook = MCDHook(g_opts=g_opts, c_opts=c_opts)\n",
386386
"model_counts = validate_hook(hook, list(data.keys()))\n",
387-
"losses, outputs = hook({}, {**models, **data})\n",
388-
"print_info(model_counts, losses, outputs, G, C_multiple)"
387+
"outputs, losses = hook({**models, **data})\n",
388+
"print_info(model_counts, outputs, losses, G, C_multiple)"
389389
]
390390
},
391391
{
@@ -415,8 +415,8 @@
415415
"G.count, C_multiple.count = 0, 0\n",
416416
"hook = MCDHook(g_opts=g_opts, c_opts=c_opts, post_x=[AFNHook()], post_z=[AlignerHook()])\n",
417417
"model_counts = validate_hook(hook, list(data.keys()))\n",
418-
"losses, outputs = hook({}, {**models, **data})\n",
419-
"print_info(model_counts, losses, outputs, G, C_multiple)"
418+
"outputs, losses = hook({**models, **data})\n",
419+
"print_info(model_counts, outputs, losses, G, C_multiple)"
420420
]
421421
},
422422
{
@@ -445,7 +445,7 @@
445445
"name": "python",
446446
"nbconvert_exporter": "python",
447447
"pygments_lexer": "ipython3",
448-
"version": "3.8.10"
448+
"version": "3.9.7"
449449
}
450450
},
451451
"nbformat": 4,

0 commit comments

Comments
 (0)