Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .actions/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,21 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:


if __name__ == "__main__":
import sys

import jsonargparse
from jsonargparse import ArgumentParser

def patch_jsonargparse_python_3_12_8():
if sys.version_info < (3, 12, 8):
return

def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
return namespace, args

setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)

patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

jsonargparse.CLI(AssistantCLI, as_positional=False)
18 changes: 9 additions & 9 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
# "oldest" versions tests, only on minimum Python
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
- {
Expand Down
18 changes: 9 additions & 9 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
# "oldest" versions tests, only on minimum Python
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
- {
Expand Down
3 changes: 2 additions & 1 deletion examples/fabric/tensor_parallel/train.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import lightning as L
import torch
import torch.nn.functional as F
from data import RandomTokenDataset
from lightning.fabric.strategies import ModelParallelStrategy
from model import ModelArgs, Transformer
from parallelism import parallelize
from torch.distributed.tensor.parallel import loss_parallel
from torch.utils.data import DataLoader

from data import RandomTokenDataset


def train():
strategy = ModelParallelStrategy(
Expand Down
3 changes: 2 additions & 1 deletion examples/pytorch/tensor_parallel/train.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import lightning as L
import torch
import torch.nn.functional as F
from data import RandomTokenDataset
from lightning.pytorch.strategies import ModelParallelStrategy
from model import ModelArgs, Transformer
from parallelism import parallelize
from torch.distributed.tensor.parallel import loss_parallel
from torch.utils.data import DataLoader

from data import RandomTokenDataset


class Llama3(L.LightningModule):
def __init__(self):
Expand Down
14 changes: 14 additions & 0 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@

_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7")


def patch_jsonargparse_python_3_12_8() -> None:
if sys.version_info < (3, 12, 8):
return

def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
return namespace, args

setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)


if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
from jsonargparse import (
Expand All @@ -48,6 +60,8 @@
set_config_read_mode,
)

patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
set_config_read_mode(fsspec_enabled=True)
else:
Expand Down
3 changes: 3 additions & 0 deletions tests/parity_fabric/test_parity_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float

if __name__ == "__main__":
from jsonargparse import CLI
from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8

patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

CLI(run_parity_test)
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import pytest
import torch
import yaml
from jsonargparse import ArgumentParser
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningArgumentParser as ArgumentParser
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down
Loading