This landed (https://github.com/pytorch/torchtitan/pull/1646), so we can cut our memory usage in half once it's in the nightly builds.