Part 1 will require you to implement memory-saving techniques such as offloading and gradient checkpointing / accumulation. To implement offloading, you may either write your own low-level code, or use the recommended trick: write your own [autograd.Function](https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function) (similar to gradient checkpoint function) that moves the requisite modules on device just in time for computation. Our practice video ('25) contains some tips on extending autograd functions, but those are optional.
0 commit comments