Skip to content

Commit 68677d3

Browse files
committed
Patch argparse _parse_known_args
1 parent c09fc66 commit 68677d3

File tree

7 files changed

+40
-21
lines changed

7 files changed

+40
-21
lines changed

.actions/assistant.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,16 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:
483483

484484

485485
if __name__ == "__main__":
486+
import sys
487+
486488
import jsonargparse
489+
from jsonargparse import ArgumentParser
490+
491+
def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
492+
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
493+
return namespace, args
494+
495+
if sys.version_info >= (3, 12, 8):
496+
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
487497

488498
jsonargparse.CLI(AssistantCLI, as_positional=False)

.github/workflows/ci-tests-fabric.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ jobs:
4949
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5050
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5151
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
52-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
53-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
54-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
55-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
56-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
57-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
52+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
53+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
54+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
55+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
56+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
57+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
5858
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
59-
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
60-
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
61-
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
59+
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
60+
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
61+
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
6262
# "oldest" versions tests, only on minimum Python
6363
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
6464
- {

.github/workflows/ci-tests-pytorch.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ jobs:
5353
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5454
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5555
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
56-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
57-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
58-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
59-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
60-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
61-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
56+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
57+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
58+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
59+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
60+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
61+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
6262
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
63-
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
64-
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
65-
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
63+
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
64+
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
65+
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
6666
# "oldest" versions tests, only on minimum Python
6767
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
6868
- {

examples/fabric/tensor_parallel/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4-
from data import RandomTokenDataset
54
from lightning.fabric.strategies import ModelParallelStrategy
65
from model import ModelArgs, Transformer
76
from parallelism import parallelize
87
from torch.distributed.tensor.parallel import loss_parallel
98
from torch.utils.data import DataLoader
109

10+
from data import RandomTokenDataset
11+
1112

1213
def train():
1314
strategy = ModelParallelStrategy(

examples/pytorch/tensor_parallel/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4-
from data import RandomTokenDataset
54
from lightning.pytorch.strategies import ModelParallelStrategy
65
from model import ModelArgs, Transformer
76
from parallelism import parallelize
87
from torch.distributed.tensor.parallel import loss_parallel
98
from torch.utils.data import DataLoader
109

10+
from data import RandomTokenDataset
11+
1112

1213
class Llama3(L.LightningModule):
1314
def __init__(self):

src/lightning/pytorch/cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@
4848
set_config_read_mode,
4949
)
5050

51+
def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
52+
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
53+
return namespace, args
54+
55+
if sys.version_info >= (3, 12, 8):
56+
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
57+
5158
register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
5259
set_config_read_mode(fsspec_enabled=True)
5360
else:

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
import pytest
3030
import torch
3131
import yaml
32-
from jsonargparse import ArgumentParser
3332
from lightning.fabric.utilities.cloud_io import _load as pl_load
3433
from lightning.pytorch import Trainer, seed_everything
3534
from lightning.pytorch.callbacks import ModelCheckpoint
35+
from lightning.pytorch.cli import LightningArgumentParser as ArgumentParser
3636
from lightning.pytorch.demos.boring_classes import BoringModel
3737
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
3838
from lightning.pytorch.utilities.exceptions import MisconfigurationException

0 commit comments

Comments
 (0)