Skip to content

Commit 230b71f

Browse files
authored
Merge branch 'master' into feature/19743-tensorboard-histograms
2 parents 2add14d + 8ff43d4 commit 230b71f

File tree

1,555 files changed

+21520
-82319
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,555 files changed

+21520
-82319
lines changed

.actions/assistant.py

Lines changed: 48 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,18 @@
1414
import glob
1515
import logging
1616
import os
17-
import pathlib
1817
import re
1918
import shutil
20-
import tarfile
2119
import tempfile
2220
import urllib.request
23-
from distutils.version import LooseVersion
21+
from collections.abc import Iterable, Iterator, Sequence
2422
from itertools import chain
2523
from os.path import dirname, isfile
2624
from pathlib import Path
27-
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
25+
from typing import Any, Optional
2826

29-
from pkg_resources import Requirement, parse_requirements, yield_lines
27+
from packaging.requirements import Requirement
28+
from packaging.version import Version
3029

3130
REQUIREMENT_FILES = {
3231
"pytorch": (
@@ -35,11 +34,6 @@
3534
"requirements/pytorch/strategies.txt",
3635
"requirements/pytorch/examples.txt",
3736
),
38-
"app": (
39-
"requirements/app/app.txt",
40-
"requirements/app/cloud.txt",
41-
"requirements/app/ui.txt",
42-
),
4337
"fabric": (
4438
"requirements/fabric/base.txt",
4539
"requirements/fabric/strategies.txt",
@@ -87,14 +81,15 @@ def adjust(self, unfreeze: str) -> str:
8781
out = str(self)
8882
if self.strict:
8983
return f"{out} {self.strict_string}"
84+
specs = [(spec.operator, spec.version) for spec in self.specifier]
9085
if unfreeze == "major":
91-
for operator, version in self.specs:
86+
for operator, version in specs:
9287
if operator in ("<", "<="):
93-
major = LooseVersion(version).version[0]
88+
major = Version(version).major
9489
# replace upper bound with major version increased by one
9590
return out.replace(f"{operator}{version}", f"<{major + 1}.0")
9691
elif unfreeze == "all":
97-
for operator, version in self.specs:
92+
for operator, version in specs:
9893
if operator in ("<", "<="):
9994
# drop upper bound
10095
return out.replace(f"{operator}{version},", "")
@@ -103,33 +98,25 @@ def adjust(self, unfreeze: str) -> str:
10398
return out
10499

105100

106-
def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
101+
def _parse_requirements(lines: Iterable[str]) -> Iterator[_RequirementWithComment]:
107102
"""Adapted from `pkg_resources.parse_requirements` to include comments.
108103
109104
>>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt']
110105
>>> [r.adjust('none') for r in _parse_requirements(txt)]
111106
['this', 'example', 'foo # strict', 'thing']
112-
>>> txt = '\\n'.join(txt)
113-
>>> [r.adjust('none') for r in _parse_requirements(txt)]
114-
['this', 'example', 'foo # strict', 'thing']
115107
116108
"""
117-
lines = yield_lines(strs)
118109
pip_argument = None
119110
for line in lines:
111+
line = line.strip()
112+
if not line or line.startswith("#"):
113+
continue
120114
# Drop comments -- a hash without a space may be in a URL.
121115
if " #" in line:
122116
comment_pos = line.find(" #")
123117
line, comment = line[:comment_pos], line[comment_pos:]
124118
else:
125119
comment = ""
126-
# If there is a line continuation, drop it, and append the next line.
127-
if line.endswith("\\"):
128-
line = line[:-2].strip()
129-
try:
130-
line += next(lines)
131-
except StopIteration:
132-
return
133120
# If there's a pip argument, save it
134121
if line.startswith("--"):
135122
pip_argument = line
@@ -141,7 +128,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen
141128
pip_argument = None
142129

143130

144-
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]:
131+
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]:
145132
"""Loading requirements from a file.
146133
147134
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
@@ -155,7 +142,7 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str
155142
logging.warning(f"Folder {path_dir} does not have any base requirements.")
156143
return []
157144
assert path.exists(), (path_dir, file_name, path)
158-
text = path.read_text()
145+
text = path.read_text().splitlines()
159146
return [req.adjust(unfreeze) for req in _parse_requirements(text)]
160147

161148

@@ -167,8 +154,8 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
167154
168155
"""
169156
path_readme = os.path.join(path_dir, "README.md")
170-
with open(path_readme, encoding="utf-8") as fo:
171-
text = fo.read()
157+
with open(path_readme, encoding="utf-8") as fopen:
158+
text = fopen.read()
172159

173160
# drop images from readme
174161
text = text.replace(
@@ -216,30 +203,6 @@ def distribute_version(src_folder: str, ver_file: str = "version.info") -> None:
216203
shutil.copy2(ver_template, fpath)
217204

218205

219-
def _download_frontend(pkg_path: str, version: str = "v0.0.0"):
220-
"""Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
221-
directory."""
222-
223-
try:
224-
frontend_dir = pathlib.Path(pkg_path, "ui")
225-
download_dir = tempfile.mkdtemp()
226-
227-
shutil.rmtree(frontend_dir, ignore_errors=True)
228-
# TODO: remove this once lightning-ui package is ready as a dependency
229-
frontend_release_url = f"https://lightning-packages.s3.amazonaws.com/ui/{version}.tar.gz"
230-
response = urllib.request.urlopen(frontend_release_url)
231-
232-
file = tarfile.open(fileobj=response, mode="r|gz")
233-
file.extractall(path=download_dir) # noqa: S202
234-
235-
shutil.move(download_dir, frontend_dir)
236-
print("The Lightning UI has successfully been downloaded!")
237-
238-
# If installing from source without internet connection, we don't want to break the installation
239-
except Exception:
240-
print("The Lightning UI downloading has failed!")
241-
242-
243206
def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
244207
"""Load all base requirements from all particular packages and prune duplicates.
245208
@@ -260,7 +223,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
260223
fp.writelines([ln + os.linesep for ln in requires] + [os.linesep])
261224

262225

263-
def _retrieve_files(directory: str, *ext: str) -> List[str]:
226+
def _retrieve_files(directory: str, *ext: str) -> list[str]:
264227
all_files = []
265228
for root, _, files in os.walk(directory):
266229
for fname in files:
@@ -270,7 +233,7 @@ def _retrieve_files(directory: str, *ext: str) -> List[str]:
270233
return all_files
271234

272235

273-
def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]:
236+
def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]:
274237
"""Replace imports of standalone package to lightning.
275238
276239
>>> lns = [
@@ -345,20 +308,20 @@ def copy_replace_imports(
345308
if ext in (".pyc",):
346309
continue
347310
# Try to parse everything else
348-
with open(fp, encoding="utf-8") as fo:
311+
with open(fp, encoding="utf-8") as fopen:
349312
try:
350-
lines = fo.readlines()
313+
lines = fopen.readlines()
351314
except UnicodeDecodeError:
352315
# a binary file, skip
353316
print(f"Skipped replacing imports for {fp}")
354317
continue
355318
lines = _replace_imports(lines, list(zip(source_imports, target_imports)), lightning_by=lightning_by)
356319
os.makedirs(os.path.dirname(fp_new), exist_ok=True)
357-
with open(fp_new, "w", encoding="utf-8") as fo:
358-
fo.writelines(lines)
320+
with open(fp_new, "w", encoding="utf-8") as fopen:
321+
fopen.writelines(lines)
359322

360323

361-
def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None:
324+
def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> None:
362325
"""Create a mirror package with adjusted imports."""
363326
# replace imports and copy the code
364327
mapping = package_mapping.copy()
@@ -399,26 +362,12 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
399362
if not ln_ or ln_.startswith("#"):
400363
final.append(line)
401364
continue
402-
req = list(parse_requirements(ln_))[0]
365+
req = list(_parse_requirements([ln_]))[0]
403366
if req.name not in packages:
404367
final.append(line)
405368
print(final)
406369
path.write_text("\n".join(final) + "\n")
407370

408-
@staticmethod
409-
def _replace_min(fname: str) -> None:
410-
with open(fname, encoding="utf-8") as fo:
411-
req = fo.read().replace(">=", "==")
412-
with open(fname, "w", encoding="utf-8") as fw:
413-
fw.write(req)
414-
415-
@staticmethod
416-
def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
417-
"""Replace the min package version by fixed one."""
418-
for fname in requirement_fnames:
419-
print(fname)
420-
AssistantCLI._replace_min(fname)
421-
422371
@staticmethod
423372
def copy_replace_imports(
424373
source_dir: str,
@@ -466,7 +415,7 @@ def pull_docs_files(
466415
raise RuntimeError(f"Requesting file '{zip_url}' does not exist or it is just unavailable.")
467416

468417
with zipfile.ZipFile(zip_file, "r") as zip_ref:
469-
zip_ref.extractall(tmp) # noqa: S202
418+
zip_ref.extractall(tmp)
470419

471420
zip_dirs = [d for d in glob.glob(os.path.join(tmp, "*")) if os.path.isdir(d)]
472421
# check that the extracted archive has only repo folder
@@ -508,15 +457,34 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:
508457
"""Load the actual version and convert it to the nightly version."""
509458
from datetime import datetime
510459

511-
with open(ver_file) as fo:
512-
version = fo.read().strip()
460+
with open(ver_file) as fopen:
461+
version = fopen.read().strip()
513462
# parse X.Y.Z version and prune any suffix
514463
vers = re.match(r"(\d+)\.(\d+)\.(\d+).*", version)
515464
# create timestamp YYYYMMDD
516465
timestamp = datetime.now().strftime("%Y%m%d")
517466
version = f"{'.'.join(vers.groups())}.dev{timestamp}"
518-
with open(ver_file, "w") as fo:
519-
fo.write(version + os.linesep)
467+
with open(ver_file, "w") as fopen:
468+
fopen.write(version + os.linesep)
469+
470+
@staticmethod
471+
def generate_docker_tags(
472+
release_version: str,
473+
python_version: str,
474+
torch_version: str,
475+
cuda_version: str,
476+
docker_project: str = "pytorchlightning/pytorch_lightning",
477+
add_latest: bool = False,
478+
) -> None:
479+
"""Generate docker tags for the given versions."""
480+
tags = [f"latest-py{python_version}-torch{torch_version}-cuda{cuda_version}"]
481+
if release_version:
482+
tags += [f"{release_version}-py{python_version}-torch{torch_version}-cuda{cuda_version}"]
483+
if add_latest:
484+
tags += ["latest"]
485+
486+
tags = [f"{docker_project}:{tag}" for tag in tags]
487+
print(",".join(tags))
520488

521489

522490
if __name__ == "__main__":

.actions/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
jsonargparse >=4.16.0, <4.28.0
1+
jsonargparse
22
requests
3+
packaging

0 commit comments

Comments
 (0)