Skip to content

Commit 4284762

Browse files
awaelchlilexierule
authored andcommitted
Fix import error when running examples in fresh environment (#19431)
(cherry picked from commit 5aea3b1)
1 parent 0de52d1 commit 4284762

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

requirements/pytorch/examples.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

4+
requests <2.32.0
45
torchvision >=0.14.0, <0.17.0
56
gym[classic_control] >=0.17.0, <0.27.0
67
ipython[all] <8.15.0

src/lightning/pytorch/demos/transformer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@
99
from pathlib import Path
1010
from typing import Dict, List, Optional, Tuple
1111

12-
import requests
1312
import torch
1413
import torch.nn as nn
1514
import torch.nn.functional as F
15+
from lightning_utilities.core.imports import RequirementCache
1616
from torch import Tensor
1717
from torch.nn.modules import MultiheadAttention
1818
from torch.utils.data import DataLoader, Dataset
1919

2020
from lightning.pytorch import LightningModule
2121

22+
_REQUESTS_AVAILABLE = RequirementCache("requests")
23+
24+
2225
if hasattr(MultiheadAttention, "_reset_parameters") and not hasattr(MultiheadAttention, "reset_parameters"):
2326
# See https://github.com/pytorch/pytorch/issues/107909
2427
MultiheadAttention.reset_parameters = MultiheadAttention._reset_parameters
@@ -125,6 +128,11 @@ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
125128

126129
@staticmethod
127130
def download(destination: Path) -> None:
131+
if not _REQUESTS_AVAILABLE:
132+
raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE))
133+
134+
import requests
135+
128136
os.makedirs(destination.parent, exist_ok=True)
129137
url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt"
130138
if os.path.exists(destination):

0 commit comments

Comments
 (0)