Skip to content

Commit c730416

Browse files
Merge pull request #33467 from danielsuo:host-offloading-fix
PiperOrigin-RevId: 835236780
2 parents 34a10de + 1940d78 commit c730416

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

docs/notebooks/host-offloading.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@
617617
"\n",
618618
"**Total Memory Savings**: 33.5 MB (20 MB + 10.75 MB + 1.75 MB)\n",
619619
"\n",
620-
"This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness. \n",
620+
"This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness.\n",
621621
"\n",
622622
"### Limitations of Parameter Offloading\n",
623623
"\n",
@@ -846,13 +846,13 @@
846846
" - Total memory size without offloading: 4.59 GB\n",
847847
" - Net memory saving: 1.72 GB\n",
848848
"\n",
849-
"while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage. \n",
849+
"while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage.\n",
850850
"\n",
851851
"Note: The optimizer states can be compared for numerical equivalence using `jax.tree_util.tree_map` and `jnp.allclose`, but this verification step is omitted here for brevity.\n",
852852
"\n",
853853
"## Tools for Host Offloading\n",
854854
"\n",
855-
"{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to :doc:`device_memory_profiling`. The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading."
855+
"{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to {doc}`../device_memory_profiling`. The profiling tools described in {doc}`../profiling` can help measure memory savings and performance impact from host offloading."
856856
]
857857
}
858858
],

docs/notebooks/host-offloading.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ This implementation demonstrates how offloading model parameters together with a
442442

443443
**Total Memory Savings**: 33.5 MB (20 MB + 10.75 MB + 1.75 MB)
444444

445-
This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness.
445+
This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness.
446446

447447
### Limitations of Parameter Offloading
448448

@@ -631,10 +631,10 @@ Memory Analysis:
631631
- Total memory size without offloading: 4.59 GB
632632
- Net memory saving: 1.72 GB
633633

634-
while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage.
634+
while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage.
635635

636636
Note: The optimizer states can be compared for numerical equivalence using `jax.tree_util.tree_map` and `jnp.allclose`, but this verification step is omitted here for brevity.
637637

638638
## Tools for Host Offloading
639639

640-
{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to :doc:`device_memory_profiling`. The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading.
640+
{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to {doc}`../device_memory_profiling`. The profiling tools described in {doc}`../profiling` can help measure memory savings and performance impact from host offloading.

0 commit comments

Comments
 (0)