Skip to content

Commit b356303

Browse files
ClodLingxiZeyi-Lin
andauthored
fix: implement "finalize" in SwanLabLogger (#1315)
* fix: implement "finalize" in SwanLabLogger (#1261) * change test script location --------- Co-authored-by: ZeYi Lin <944270057@qq.com>
1 parent 31169dd commit b356303

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

swanlab/integration/pytorch_lightning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"or \n pip install lightning"
3838
)
3939
import swanlab
40-
from ..data.run import SwanLabRun
40+
from ..data.run import SwanLabRun, SwanLabRunState
4141
from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
4242

4343

@@ -218,7 +218,7 @@ def name(self) -> Optional[str]:
218218
@rank_zero_only
219219
def finalize(self, status: str) -> None:
220220
if status != "success":
221-
return
221+
self.experiment.finish(SwanLabRunState.CRASHED, f"Closed by pytorch lightning. Status: {status}")
222222

223223
@property
224224
def version(self) -> Optional[str]:
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
try:
2+
import pytorch_lightning as pl
3+
from pytorch_lightning import demos # Need Import Again
4+
except ImportError:
5+
pl = None
6+
7+
import swanlab
8+
from swanlab.data.run import SwanLabRunState
9+
from swanlab.integration.pytorch_lightning import SwanLabLogger
10+
11+
12+
def pl_strategy_for_swanlab(raw_fn, swanlab_logger: SwanLabLogger):
13+
def wrapper(exception: BaseException):
14+
if isinstance(exception, KeyboardInterrupt):
15+
swanlab_logger.experiment.finish(SwanLabRunState.CRASHED, "KeyboardInterrupt by user", interrupt=True)
16+
raw_fn(exception)
17+
return wrapper
18+
19+
20+
def test_for_pytorch_lighting():
21+
# Also See SwanLab/test/integration/lightning
22+
if pl is None:
23+
Warning("Need PyTorch Lightning To Test")
24+
return
25+
26+
model = pl.demos.boring_classes.BoringModel()
27+
datamodule = pl.demos.boring_classes.BoringDataModule()
28+
logger = SwanLabLogger(project="test-interrupt")
29+
30+
# log_every_n_steps != None for enable Logger
31+
trainer = pl.Trainer(max_epochs=10, logger=logger, log_every_n_steps=1)
32+
33+
# 加上可使pl被KeyboardInterrupt中断时,Web上的状态为 "Ctrl+C" 而不是 "中断"
34+
# Add support so that when pl is interrupted by a KeyboardInterrupt, the status displayed on the web is "Ctrl+C" instead of "Crashed"
35+
# trainer.strategy.on_exception = pl_strategy_for_swanlab(trainer.strategy.on_exception, logger) # 在pl的策略补充即可
36+
37+
trainer.fit(model, datamodule)
38+
39+
def test_for_normal_interrupt():
40+
from tqdm import tqdm
41+
# import time
42+
43+
swanlab.init(project="test-interrupt")
44+
45+
for i in tqdm(range(10000)):
46+
# time.sleep(0.01)
47+
48+
for _ in range(100):
49+
for _ in range(100):
50+
for _ in range(100):
51+
pass
52+
53+
if i % 10 == 0:
54+
swanlab.log({"progress": i}, print_to_console=True)
55+
56+
57+
if __name__ == '__main__':
58+
test_for_pytorch_lighting()
59+
60+
# test_for_normal_interrupt()
61+

0 commit comments

Comments
 (0)