|
1 | 1 | { |
2 | | - "cells": [ |
| 2 | + "cells": [ |
3 | 3 | { |
4 | 4 | "cell_type": "markdown", |
5 | 5 | "metadata": {}, |
|
200 | 200 | "outer_loss = F.mse_loss(net(x), y)\n", |
201 | 201 | "display(\n", |
202 | 202 | " torchopt.visual.make_dot(\n", |
203 | | - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", |
204 | | - " )\n", |
| 203 | + " outer_loss,\n", |
| 204 | + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", |
| 205 | + " ),\n", |
205 | 206 | ")" |
206 | 207 | ] |
207 | 208 | }, |
|
247 | 248 | "outer_loss = F.mse_loss(net(x), y)\n", |
248 | 249 | "display(\n", |
249 | 250 | " torchopt.visual.make_dot(\n", |
250 | | - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", |
251 | | - " )\n", |
| 251 | + " outer_loss,\n", |
| 252 | + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", |
| 253 | + " ),\n", |
252 | 254 | ")" |
253 | 255 | ] |
254 | 256 | }, |
|
513 | 515 | "source": [ |
514 | 516 | "functional_adam = torchopt.adam(\n", |
515 | 517 | " lr=torchopt.schedule.linear_schedule(\n", |
516 | | - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", |
517 | | - " )\n", |
| 518 | + " init_value=1e-3,\n", |
| 519 | + " end_value=1e-4,\n", |
| 520 | + " transition_steps=10000,\n", |
| 521 | + " transition_begin=2000,\n", |
| 522 | + " ),\n", |
518 | 523 | ")\n", |
519 | 524 | "\n", |
520 | 525 | "adam = torchopt.Adam(\n", |
521 | 526 | " net.parameters(),\n", |
522 | 527 | " lr=torchopt.schedule.linear_schedule(\n", |
523 | | - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", |
| 528 | + " init_value=1e-3,\n", |
| 529 | + " end_value=1e-4,\n", |
| 530 | + " transition_steps=10000,\n", |
| 531 | + " transition_begin=2000,\n", |
524 | 532 | " ),\n", |
525 | 533 | ")\n", |
526 | 534 | "\n", |
527 | 535 | "meta_adam = torchopt.MetaAdam(\n", |
528 | 536 | " net,\n", |
529 | 537 | " lr=torchopt.schedule.linear_schedule(\n", |
530 | | - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", |
| 538 | + " init_value=1e-3,\n", |
| 539 | + " end_value=1e-4,\n", |
| 540 | + " transition_steps=10000,\n", |
| 541 | + " transition_begin=2000,\n", |
531 | 542 | " ),\n", |
532 | 543 | ")" |
533 | 544 | ] |
|
610 | 621 | "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n", |
611 | 622 | "\n", |
612 | 623 | "net_state_0 = torchopt.extract_state_dict(\n", |
613 | | - " net, by='reference', enable_visual=True, visual_prefix='step0.'\n", |
| 624 | + " net,\n", |
| 625 | + " by='reference',\n", |
| 626 | + " enable_visual=True,\n", |
| 627 | + " visual_prefix='step0.',\n", |
614 | 628 | ")\n", |
615 | 629 | "inner_loss = F.mse_loss(net(x), y)\n", |
616 | 630 | "optim.step(inner_loss)\n", |
617 | 631 | "net_state_1 = torchopt.extract_state_dict(\n", |
618 | | - " net, by='reference', enable_visual=True, visual_prefix='step1.'\n", |
| 632 | + " net,\n", |
| 633 | + " by='reference',\n", |
| 634 | + " enable_visual=True,\n", |
| 635 | + " visual_prefix='step1.',\n", |
619 | 636 | ")\n", |
620 | 637 | "\n", |
621 | 638 | "outer_loss = F.mse_loss(net(x), y)\n", |
622 | 639 | "display(\n", |
623 | 640 | " torchopt.visual.make_dot(\n", |
624 | | - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", |
625 | | - " )\n", |
| 641 | + " outer_loss,\n", |
| 642 | + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", |
| 643 | + " ),\n", |
626 | 644 | ")" |
627 | 645 | ] |
628 | 646 | }, |
|
0 commit comments