|
108 | 108 | "metadata": {}, |
109 | 109 | "outputs": [], |
110 | 110 | "source": [ |
111 | | - "def print_info(model_counts, losses, outputs, G, C, D=None):\n", |
| 111 | + "def print_info(model_counts, outputs, losses, G, C, D=None):\n", |
112 | 112 | " def get_shape(v):\n", |
113 | 113 | " if isinstance(v, torch.Tensor):\n", |
114 | 114 | " return v.shape\n", |
|
152 | 152 | "G.count, C.count = 0, 0\n", |
153 | 153 | "hook = ClassifierHook(opts)\n", |
154 | 154 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
155 | | - "losses, outputs = hook({}, {**models, **data})\n", |
156 | | - "print_info(model_counts, losses, outputs, G, C)" |
| 155 | + "outputs, losses = hook({**models, **data})\n", |
| 156 | + "print_info(model_counts, outputs, losses, G, C)" |
157 | 157 | ] |
158 | 158 | }, |
159 | 159 | { |
|
200 | 200 | "weighter = MeanWeighter(weights={\"bsp_loss\": 1e-5})\n", |
201 | 201 | "hook = ClassifierHook(opts, post=[BSPHook(), BNMHook()], weighter=weighter)\n", |
202 | 202 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
203 | | - "losses, outputs = hook({}, {**models, **data})\n", |
204 | | - "print_info(model_counts, losses, outputs, G, C)" |
| 203 | + "outputs, losses = hook({**models, **data})\n", |
| 204 | + "print_info(model_counts, outputs, losses, G, C)" |
205 | 205 | ] |
206 | 206 | }, |
207 | 207 | { |
|
237 | 237 | "G.count, C.count, D.count = 0, 0, 0\n", |
238 | 238 | "hook = DANNHook(opts)\n", |
239 | 239 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
240 | | - "losses, outputs = hook({}, {**models, **data})\n", |
241 | | - "print_info(model_counts, losses, outputs, G, C, D)" |
| 240 | + "outputs, losses = hook({**models, **data})\n", |
| 241 | + "print_info(model_counts, outputs, losses, G, C, D)" |
242 | 242 | ] |
243 | 243 | }, |
244 | 244 | { |
|
271 | 271 | "\n", |
272 | 272 | "hook = DANNHook(opts, post_g=[mcc, atdoc])\n", |
273 | 273 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
274 | | - "losses, outputs = hook({}, {**models, **data})\n", |
275 | | - "print_info(model_counts, losses, outputs, G, C, D)" |
| 274 | + "outputs, losses = hook({**models, **data})\n", |
| 275 | + "print_info(model_counts, outputs, losses, G, C, D)" |
276 | 276 | ] |
277 | 277 | }, |
278 | 278 | { |
|
310 | 310 | "\n", |
311 | 311 | "hook = CDANHook(d_opts=d_opts, g_opts=g_opts)\n", |
312 | 312 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
313 | | - "losses, outputs = hook({}, {**models, **misc, **data})\n", |
314 | | - "print_info(model_counts, losses, outputs, G, C, D)" |
| 313 | + "outputs, losses = hook({**models, **misc, **data})\n", |
| 314 | + "print_info(model_counts, outputs, losses, G, C, D)" |
315 | 315 | ] |
316 | 316 | }, |
317 | 317 | { |
|
341 | 341 | "misc[\"combined_model\"] = torch.nn.Sequential(G, C)\n", |
342 | 342 | "hook = CDANHook(d_opts=d_opts, g_opts=g_opts, post_g=[VATHook()])\n", |
343 | 343 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
344 | | - "losses, outputs = hook({}, {**models, **misc, **data})\n", |
345 | | - "print_info(model_counts, losses, outputs, G, C, D)" |
| 344 | + "outputs, losses = hook({**models, **misc, **data})\n", |
| 345 | + "print_info(model_counts, outputs, losses, G, C, D)" |
346 | 346 | ] |
347 | 347 | }, |
348 | 348 | { |
|
384 | 384 | "\n", |
385 | 385 | "hook = MCDHook(g_opts=g_opts, c_opts=c_opts)\n", |
386 | 386 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
387 | | - "losses, outputs = hook({}, {**models, **data})\n", |
388 | | - "print_info(model_counts, losses, outputs, G, C_multiple)" |
| 387 | + "outputs, losses = hook({**models, **data})\n", |
| 388 | + "print_info(model_counts, outputs, losses, G, C_multiple)" |
389 | 389 | ] |
390 | 390 | }, |
391 | 391 | { |
|
415 | 415 | "G.count, C_multiple.count = 0, 0\n", |
416 | 416 | "hook = MCDHook(g_opts=g_opts, c_opts=c_opts, post_x=[AFNHook()], post_z=[AlignerHook()])\n", |
417 | 417 | "model_counts = validate_hook(hook, list(data.keys()))\n", |
418 | | - "losses, outputs = hook({}, {**models, **data})\n", |
419 | | - "print_info(model_counts, losses, outputs, G, C_multiple)" |
| 418 | + "outputs, losses = hook({**models, **data})\n", |
| 419 | + "print_info(model_counts, outputs, losses, G, C_multiple)" |
420 | 420 | ] |
421 | 421 | }, |
422 | 422 | { |
|
445 | 445 | "name": "python", |
446 | 446 | "nbconvert_exporter": "python", |
447 | 447 | "pygments_lexer": "ipython3", |
448 | | - "version": "3.8.10" |
| 448 | + "version": "3.9.7" |
449 | 449 | } |
450 | 450 | }, |
451 | 451 | "nbformat": 4, |
|
0 commit comments