Skip to content

Commit 418b8d4

Browse files
authored
Merge branch 'master' into master
2 parents 965fc03 + 8ff43d4 commit 418b8d4

File tree

1,659 files changed

+25429
-97497
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,659 files changed

+25429
-97497
lines changed

.actions/assistant.py

Lines changed: 64 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,18 @@
1414
import glob
1515
import logging
1616
import os
17-
import pathlib
1817
import re
1918
import shutil
20-
import tarfile
2119
import tempfile
2220
import urllib.request
23-
from distutils.version import LooseVersion
21+
from collections.abc import Iterable, Iterator, Sequence
2422
from itertools import chain
2523
from os.path import dirname, isfile
2624
from pathlib import Path
27-
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
25+
from typing import Any, Optional
2826

29-
from pkg_resources import Requirement, parse_requirements, yield_lines
27+
from packaging.requirements import Requirement
28+
from packaging.version import Version
3029

3130
REQUIREMENT_FILES = {
3231
"pytorch": (
@@ -35,20 +34,11 @@
3534
"requirements/pytorch/strategies.txt",
3635
"requirements/pytorch/examples.txt",
3736
),
38-
"app": (
39-
"requirements/app/app.txt",
40-
"requirements/app/cloud.txt",
41-
"requirements/app/ui.txt",
42-
),
4337
"fabric": (
4438
"requirements/fabric/base.txt",
4539
"requirements/fabric/strategies.txt",
4640
),
47-
"data": (
48-
"requirements/data/data.txt",
49-
"requirements/data/cloud.txt",
50-
"requirements/data/examples.txt",
51-
),
41+
"data": ("requirements/data/data.txt",),
5242
}
5343
REQUIREMENT_FILES_ALL = list(chain(*REQUIREMENT_FILES.values()))
5444

@@ -91,14 +81,15 @@ def adjust(self, unfreeze: str) -> str:
9181
out = str(self)
9282
if self.strict:
9383
return f"{out} {self.strict_string}"
84+
specs = [(spec.operator, spec.version) for spec in self.specifier]
9485
if unfreeze == "major":
95-
for operator, version in self.specs:
86+
for operator, version in specs:
9687
if operator in ("<", "<="):
97-
major = LooseVersion(version).version[0]
88+
major = Version(version).major
9889
# replace upper bound with major version increased by one
9990
return out.replace(f"{operator}{version}", f"<{major + 1}.0")
10091
elif unfreeze == "all":
101-
for operator, version in self.specs:
92+
for operator, version in specs:
10293
if operator in ("<", "<="):
10394
# drop upper bound
10495
return out.replace(f"{operator}{version},", "")
@@ -107,33 +98,25 @@ def adjust(self, unfreeze: str) -> str:
10798
return out
10899

109100

110-
def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
101+
def _parse_requirements(lines: Iterable[str]) -> Iterator[_RequirementWithComment]:
111102
"""Adapted from `pkg_resources.parse_requirements` to include comments.
112103
113104
>>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt']
114105
>>> [r.adjust('none') for r in _parse_requirements(txt)]
115106
['this', 'example', 'foo # strict', 'thing']
116-
>>> txt = '\\n'.join(txt)
117-
>>> [r.adjust('none') for r in _parse_requirements(txt)]
118-
['this', 'example', 'foo # strict', 'thing']
119107
120108
"""
121-
lines = yield_lines(strs)
122109
pip_argument = None
123110
for line in lines:
111+
line = line.strip()
112+
if not line or line.startswith("#"):
113+
continue
124114
# Drop comments -- a hash without a space may be in a URL.
125115
if " #" in line:
126116
comment_pos = line.find(" #")
127117
line, comment = line[:comment_pos], line[comment_pos:]
128118
else:
129119
comment = ""
130-
# If there is a line continuation, drop it, and append the next line.
131-
if line.endswith("\\"):
132-
line = line[:-2].strip()
133-
try:
134-
line += next(lines)
135-
except StopIteration:
136-
return
137120
# If there's a pip argument, save it
138121
if line.startswith("--"):
139122
pip_argument = line
@@ -145,7 +128,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen
145128
pip_argument = None
146129

147130

148-
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]:
131+
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]:
149132
"""Loading requirements from a file.
150133
151134
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
@@ -159,7 +142,7 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str
159142
logging.warning(f"Folder {path_dir} does not have any base requirements.")
160143
return []
161144
assert path.exists(), (path_dir, file_name, path)
162-
text = path.read_text()
145+
text = path.read_text().splitlines()
163146
return [req.adjust(unfreeze) for req in _parse_requirements(text)]
164147

165148

@@ -171,8 +154,8 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
171154
172155
"""
173156
path_readme = os.path.join(path_dir, "README.md")
174-
with open(path_readme, encoding="utf-8") as fo:
175-
text = fo.read()
157+
with open(path_readme, encoding="utf-8") as fopen:
158+
text = fopen.read()
176159

177160
# drop images from readme
178161
text = text.replace(
@@ -220,30 +203,6 @@ def distribute_version(src_folder: str, ver_file: str = "version.info") -> None:
220203
shutil.copy2(ver_template, fpath)
221204

222205

223-
def _download_frontend(pkg_path: str, version: str = "v0.0.0"):
224-
"""Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
225-
directory."""
226-
227-
try:
228-
frontend_dir = pathlib.Path(pkg_path, "ui")
229-
download_dir = tempfile.mkdtemp()
230-
231-
shutil.rmtree(frontend_dir, ignore_errors=True)
232-
# TODO: remove this once lightning-ui package is ready as a dependency
233-
frontend_release_url = f"https://lightning-packages.s3.amazonaws.com/ui/{version}.tar.gz"
234-
response = urllib.request.urlopen(frontend_release_url)
235-
236-
file = tarfile.open(fileobj=response, mode="r|gz")
237-
file.extractall(path=download_dir) # noqa: S202
238-
239-
shutil.move(download_dir, frontend_dir)
240-
print("The Lightning UI has successfully been downloaded!")
241-
242-
# If installing from source without internet connection, we don't want to break the installation
243-
except Exception:
244-
print("The Lightning UI downloading has failed!")
245-
246-
247206
def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
248207
"""Load all base requirements from all particular packages and prune duplicates.
249208
@@ -264,7 +223,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
264223
fp.writelines([ln + os.linesep for ln in requires] + [os.linesep])
265224

266225

267-
def _retrieve_files(directory: str, *ext: str) -> List[str]:
226+
def _retrieve_files(directory: str, *ext: str) -> list[str]:
268227
all_files = []
269228
for root, _, files in os.walk(directory):
270229
for fname in files:
@@ -274,7 +233,7 @@ def _retrieve_files(directory: str, *ext: str) -> List[str]:
274233
return all_files
275234

276235

277-
def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]:
236+
def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]:
278237
"""Replace imports of standalone package to lightning.
279238
280239
>>> lns = [
@@ -349,31 +308,33 @@ def copy_replace_imports(
349308
if ext in (".pyc",):
350309
continue
351310
# Try to parse everything else
352-
with open(fp, encoding="utf-8") as fo:
311+
with open(fp, encoding="utf-8") as fopen:
353312
try:
354-
lines = fo.readlines()
313+
lines = fopen.readlines()
355314
except UnicodeDecodeError:
356315
# a binary file, skip
357316
print(f"Skipped replacing imports for {fp}")
358317
continue
359318
lines = _replace_imports(lines, list(zip(source_imports, target_imports)), lightning_by=lightning_by)
360319
os.makedirs(os.path.dirname(fp_new), exist_ok=True)
361-
with open(fp_new, "w", encoding="utf-8") as fo:
362-
fo.writelines(lines)
320+
with open(fp_new, "w", encoding="utf-8") as fopen:
321+
fopen.writelines(lines)
363322

364323

365-
def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None:
324+
def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> None:
325+
"""Create a mirror package with adjusted imports."""
366326
# replace imports and copy the code
367327
mapping = package_mapping.copy()
368328
mapping.pop("lightning", None) # pop this key to avoid replacing `lightning` to `lightning.lightning`
369329

370330
mapping = {f"lightning.{sp}": sl for sp, sl in mapping.items()}
371331
for pkg_from, pkg_to in mapping.items():
332+
source_imports, target_imports = zip(*mapping.items())
372333
copy_replace_imports(
373334
source_dir=os.path.join(source_dir, pkg_from.replace(".", os.sep)),
374335
# pytorch_lightning uses lightning_fabric, so we need to replace all imports for all directories
375-
source_imports=mapping.keys(),
376-
target_imports=mapping.values(),
336+
source_imports=source_imports,
337+
target_imports=target_imports,
377338
target_dir=os.path.join(source_dir, pkg_to.replace(".", os.sep)),
378339
lightning_by=pkg_from,
379340
)
@@ -401,26 +362,12 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
401362
if not ln_ or ln_.startswith("#"):
402363
final.append(line)
403364
continue
404-
req = list(parse_requirements(ln_))[0]
365+
req = list(_parse_requirements([ln_]))[0]
405366
if req.name not in packages:
406367
final.append(line)
407368
print(final)
408369
path.write_text("\n".join(final) + "\n")
409370

410-
@staticmethod
411-
def _replace_min(fname: str) -> None:
412-
with open(fname, encoding="utf-8") as fo:
413-
req = fo.read().replace(">=", "==")
414-
with open(fname, "w", encoding="utf-8") as fw:
415-
fw.write(req)
416-
417-
@staticmethod
418-
def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
419-
"""Replace the min package version by fixed one."""
420-
for fname in requirement_fnames:
421-
print(fname)
422-
AssistantCLI._replace_min(fname)
423-
424371
@staticmethod
425372
def copy_replace_imports(
426373
source_dir: str,
@@ -468,7 +415,7 @@ def pull_docs_files(
468415
raise RuntimeError(f"Requesting file '{zip_url}' does not exist or it is just unavailable.")
469416

470417
with zipfile.ZipFile(zip_file, "r") as zip_ref:
471-
zip_ref.extractall(tmp) # noqa: S202
418+
zip_ref.extractall(tmp)
472419

473420
zip_dirs = [d for d in glob.glob(os.path.join(tmp, "*")) if os.path.isdir(d)]
474421
# check that the extracted archive has only repo folder
@@ -505,6 +452,40 @@ def _copy_rst(rst_in, rst_out, as_orphan: bool = False):
505452
with open(rst_out, "w", encoding="utf-8") as fopen:
506453
fopen.write(page)
507454

455+
@staticmethod
456+
def convert_version2nightly(ver_file: str = "src/version.info") -> None:
457+
"""Load the actual version and convert it to the nightly version."""
458+
from datetime import datetime
459+
460+
with open(ver_file) as fopen:
461+
version = fopen.read().strip()
462+
# parse X.Y.Z version and prune any suffix
463+
vers = re.match(r"(\d+)\.(\d+)\.(\d+).*", version)
464+
# create timestamp YYYYMMDD
465+
timestamp = datetime.now().strftime("%Y%m%d")
466+
version = f"{'.'.join(vers.groups())}.dev{timestamp}"
467+
with open(ver_file, "w") as fopen:
468+
fopen.write(version + os.linesep)
469+
470+
@staticmethod
471+
def generate_docker_tags(
472+
release_version: str,
473+
python_version: str,
474+
torch_version: str,
475+
cuda_version: str,
476+
docker_project: str = "pytorchlightning/pytorch_lightning",
477+
add_latest: bool = False,
478+
) -> None:
479+
"""Generate docker tags for the given versions."""
480+
tags = [f"latest-py{python_version}-torch{torch_version}-cuda{cuda_version}"]
481+
if release_version:
482+
tags += [f"{release_version}-py{python_version}-torch{torch_version}-cuda{cuda_version}"]
483+
if add_latest:
484+
tags += ["latest"]
485+
486+
tags = [f"{docker_project}:{tag}" for tag in tags]
487+
print(",".join(tags))
488+
508489

509490
if __name__ == "__main__":
510491
import jsonargparse

.actions/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
jsonargparse >=4.16.0, <4.28.0
1+
jsonargparse
22
requests
3+
packaging

0 commit comments

Comments
 (0)