Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions torchft/checkpointing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Checkpointing
==============

This module implements methods for checkpointing and resuming training from a checkpoint.
"""

from torchft.checkpointing.http_transport import HTTPTransport
from torchft.checkpointing.transport import CheckpointTransport

__all__ = [
"HTTPTransport",
"CheckpointTransport",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Checkpointing
==============

This module implements methods for checkpointing and resuming training from a checkpoint.
"""

import io
import logging
import socket
Expand All @@ -24,70 +17,14 @@

import torch

from torchft.checkpointing.transport import CheckpointTransport
from torchft.http import _IPv6HTTPServer

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")


class CheckpointTransport(Generic[T], ABC):
@abstractmethod
def metadata(self) -> str:
"""
Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
"""
...

@abstractmethod
def send_checkpoint(
self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
) -> None:
"""
Sends the checkpoint, only called when there is a rank that is behind.

This may be async.

Args:
dst_ranks: the ranks to send to
step: the step number to send
state_dict: the state dict to send
timeout: the timeout to wait for the checkpoint to be sent
"""
...

def disallow_checkpoint(self) -> None:
"""
Called after send_checkpoint to wait for the checkpoint to be sent.

Once this returns, the state_dict may be mutated so no further data should be sent.
"""
...

@abstractmethod
def recv_checkpoint(
self, src_rank: int, metadata: str, step: int, timeout: timedelta
) -> T:
"""
Receives the checkpoint from the given rank.

Args:
src_rank: the rank to receive the checkpoint from
metadata: the metadata returned by the remote CheckpointTransport
step: the step number to receive
timeout: the timeout to wait for the checkpoint
"""
...

def shutdown(self, wait: bool = True) -> None:
"""
Called to shutdown the checkpoint transport.

Args:
wait: whether to wait for the transport to shutdown
"""


@contextmanager
def _timed_acquire(
lock: threading.Lock, timeout: timedelta
Expand All @@ -107,7 +44,7 @@ def _timed_acquire(
lock.release()


class CheckpointServer(CheckpointTransport[T]):
class HTTPTransport(CheckpointTransport[T]):
"""
This is an HTTP server that can be used to transfer checkpoints
between workers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from unittest import TestCase
from unittest.mock import MagicMock

from torchft.checkpointing import CheckpointServer, _timed_acquire
from torchft.checkpointing.http_transport import HTTPTransport, _timed_acquire


class TestCheckpointing(TestCase):
def test_checkpoint_server(self) -> None:
expected = {"state": "dict"}
state_dict_fn = MagicMock()
state_dict_fn.return_value = expected
server = CheckpointServer(
server = HTTPTransport(
timeout=timedelta(seconds=10),
)

Expand Down Expand Up @@ -58,7 +58,7 @@ def test_checkpoint_server(self) -> None:
server.shutdown()

def test_checkpoint_server_locking(self) -> None:
server = CheckpointServer(
server = HTTPTransport(
timeout=timedelta(seconds=10),
)

Expand Down
68 changes: 68 additions & 0 deletions torchft/checkpointing/transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Generic, List, TypeVar

T = TypeVar("T")


class CheckpointTransport(Generic[T], ABC):
    """
    Abstract interface for moving a checkpoint (a state_dict of type ``T``)
    from one replica to another during live recovery.

    Implementations (e.g. an HTTP-based transport) are responsible for the
    actual transfer mechanism; this class only defines the contract used by
    the manager on both the sending and receiving sides.
    """

    @abstractmethod
    def metadata(self) -> str:
        """
        Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.

        Returns:
            an opaque, transport-specific string (e.g. an address/URL) that the
            receiving side passes back to :meth:`recv_checkpoint` as ``metadata``.
        """
        ...

    @abstractmethod
    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        """
        Sends the checkpoint, only called when there is a rank that is behind.

        This may be async.

        Args:
            dst_ranks: the ranks to send to
            step: the step number to send
            state_dict: the state dict to send
            timeout: the timeout to wait for the checkpoint to be sent
        """
        ...

    def disallow_checkpoint(self) -> None:
        """
        Called after send_checkpoint to wait for the checkpoint to be sent.

        Once this returns, the state_dict may be mutated so no further data should be sent.

        Default implementation is a no-op; transports that serve the
        state_dict asynchronously should override this to block until the
        transfer is complete.
        """
        ...

    @abstractmethod
    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        """
        Receives the checkpoint from the given rank.

        Args:
            src_rank: the rank to receive the checkpoint from
            metadata: the metadata returned by the remote CheckpointTransport
            step: the step number to receive
            timeout: the timeout to wait for the checkpoint

        Returns:
            the received state dict
        """
        ...

    def shutdown(self, wait: bool = True) -> None:
        """
        Called to shutdown the checkpoint transport.

        Default implementation is a no-op for transports with no resources
        to release.

        Args:
            wait: whether to wait for the transport to shutdown
        """
4 changes: 2 additions & 2 deletions torchft/fsdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def _test_fsdp(world_size: int, rank: int) -> None:
# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
def test_fsdp(self) -> None:
multiprocessing.set_start_method("spawn")
with ProcessPoolExecutor(max_workers=4) as executor:
context = multiprocessing.get_context("spawn")
with ProcessPoolExecutor(max_workers=4, mp_context=context) as executor:
futures = []
for i in range(4):
future = executor.submit(self._test_fsdp, 4, i)
Expand Down
6 changes: 3 additions & 3 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.checkpointing import CheckpointServer, CheckpointTransport
from torchft.checkpointing import CheckpointTransport, HTTPTransport
from torchft.futures import future_timeout
from torchft.torchft import Manager as _Manager, ManagerClient

Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(
replica_id: if rank==0, the replica_id for this group
hostname: if rank==0, the hostname to advertise to the lighthouse server
checkpoint_transport: the checkpoint transport to use for
transfering checkpoints to recovering replicas
transferring checkpoints to recovering replicas, defaults to HTTPTransport
"""
self._load_state_dict = load_state_dict
self._user_state_dict = state_dict
Expand All @@ -160,7 +160,7 @@ def __init__(
self._min_replica_size = min_replica_size

if checkpoint_transport is None:
checkpoint_transport = CheckpointServer[Dict[str, T]](
checkpoint_transport = HTTPTransport[Dict[str, T]](
timeout=timeout,
)

Expand Down