
Commit 95f3d64

Fix bugs in Executor (#306)
* Fix: saved `meta_dict` contains excessive data
  - `meta_dict` contains the training and evaluation metric objects, but it is unnecessary to save the objects -- only their values are needed.
  - A more serious problem is with the `LR` metric, which stores a reference to the optimizer. This results in the entire optimizer being saved in the meta-info.
  - The solution has two parts: 1) save only the metric values; 2) store the optimizer as a weakref (sketched below).
* Fix: don't close files when calling `test` within `train`
  - `_open_files` and `_close_files` are called at the beginning and end of `train` and `test`, to avoid holding an open file object for an unnecessarily long time.
  - However, `test` may be called within `train`, for instance from an action triggered by the validation event. In that case, the files would be closed before training ends.
  - The solution is to check whether the files need to be opened, and if not, neither open nor close them.
* Fix: missing call to tracker in `_validate_loop`
  - For some reason `_valid_tracker.add` was never called in `_validate_loop`, so the status was never updated during validation.
* Revert c89e0e4: fix `meta_dict` issue
  - It turns out we must store the metric objects -- otherwise we cannot even compare two metric values.
  - Instead, the pickle behavior of `LR` is changed so that it does not save the optimizer. It feels like a hack, but let's leave it at this.
* Fix doc building issues
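
The `LR` issue in the first bullet is easy to reproduce in isolation. The sketch below uses hypothetical `StrongRefLR` / `WeakRefLR` classes (not the texar API): a metric that keeps a strong reference to the optimizer drags the optimizer and its parameters into any pickle of the metric, while a weakref whose pickle state is skipped does not.

```python
# Stand-alone sketch; StrongRefLR / WeakRefLR are hypothetical, not texar classes.
import pickle
import weakref

import torch


class StrongRefLR:
    """Keeps a strong reference: the optimizer is pickled along with the metric."""
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def value(self):
        return self.optimizer.param_groups[0]["lr"]


class WeakRefLR:
    """Keeps a weakref and skips pickling it (weakrefs are not picklable)."""
    def __init__(self, optimizer):
        self.optimizer = weakref.ref(optimizer)

    def value(self):
        return self.optimizer().param_groups[0]["lr"]

    def __getstate__(self):
        return None  # drop all state, mirroring the fix to `LR`


model = torch.nn.Linear(256, 256)
optim = torch.optim.Adam(model.parameters())

print("strong ref:", len(pickle.dumps(StrongRefLR(optim))), "bytes")  # roughly a few hundred KB here
print("weak ref:  ", len(pickle.dumps(WeakRefLR(optim))), "bytes")    # a few dozen bytes
```
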
1 parent 27fe398 commit 95f3d64

File tree

4 files changed: +41, -18 lines


docs/code/data.rst

Lines changed: 0 additions & 3 deletions
@@ -132,9 +132,6 @@ Data Loaders
 .. autoclass:: texar.torch.data.DatasetBase
     :members:

-    .. automethod:: process
-    .. automethod:: collate
-
 :hidden:`MonoTextData`
 ~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: texar.torch.data.MonoTextData

tests/run/executor_test.py

Lines changed: 7 additions & 2 deletions
@@ -99,6 +99,7 @@ def tearDown(self) -> None:
         shutil.rmtree(self.tbx_logging_dir)

     def test_train_loop(self):
+        optimizer = torch.optim.Adam(self.model.parameters())
         executor = Executor(
             model=self.model,
             train_data=self.datasets["train"],

@@ -110,8 +111,9 @@ def test_train_loop(self):
             save_every=[cond.time(seconds=10), cond.validation(better=True)],
             train_metrics=[("loss", metric.RunningAverage(20)),
                            metric.F1(pred_name="preds", mode="macro"),
-                           metric.Accuracy(pred_name="preds")],
-            optimizer={"type": torch.optim.Adam, "kwargs": {}},
+                           metric.Accuracy(pred_name="preds"),
+                           metric.LR(optimizer)],
+            optimizer=optimizer,
             stop_training_on=cond.epoch(10),
             valid_metrics=[metric.F1(pred_name="preds", mode="micro"),
                            ("loss", metric.Average())],

@@ -129,6 +131,9 @@ def test_train_loop(self):
         executor.train()
         executor.test()

+        executor.save()
+        executor.load()
+
     def test_tbx_logging(self):
         executor = Executor(
             model=self.model,

texar/torch/run/executor.py

Lines changed: 21 additions & 11 deletions
@@ -1280,8 +1280,12 @@ def remove_action(self) -> None:
     def train(self):
         r"""Start the training loop.
         """
-        # open the log files
-        self._open_files()
+        # Check whether files have been opened, to avoid re-opening and closing.
+        # This could happen when, e.g., `test` is called in a registered hook
+        # during training.
+        should_open_file = (len(self._opened_files) == 0)
+        if should_open_file:
+            self._open_files()

         if self._directory_exists:
             self.write_log(
@@ -1351,8 +1355,9 @@ def _try_get_data_size(executor: 'Executor'):

         self._fire_event(Event.Training, True)

-        # close the log files
-        self._close_files()
+        # Close the log files if we opened them here.
+        if should_open_file:
+            self._close_files()

     def test(self, dataset: OptionalDict[DatasetBase] = None):
         r"""Start the test loop.
@@ -1369,8 +1374,12 @@ def test(self, dataset: OptionalDict[DatasetBase] = None):
                 If `None`, :attr:`test_data` from the constructor arguments is
                 used. Defaults to `None`.
         """
-        # open the log files
-        self._open_files()
+        # Check whether files have been opened, to avoid re-opening and closing.
+        # This could happen when, e.g., `test` is called in a registered hook
+        # during training.
+        should_open_file = (len(self._opened_files) == 0)
+        if should_open_file:
+            self._open_files()

         if dataset is None and self.test_data is None:
             raise ValueError("No testing dataset is specified")
@@ -1417,8 +1426,9 @@ def test(self, dataset: OptionalDict[DatasetBase] = None):

         self.model.train(model_mode)

-        # close the log files
-        self._close_files()
+        # Close the log files if we opened them here.
+        if should_open_file:
+            self._close_files()

     def _register_logging_actions(self, show_live_progress: List[str]):
         # Register logging actions.
@@ -1728,6 +1738,7 @@ def _open_files(self):
     def _close_files(self):
         for file in self._opened_files:
             file.close()
+        self._opened_files = []

         if hasattr(self, 'summary_writer'):
             self.summary_writer.close()
@@ -1890,7 +1901,7 @@ def _validate_loop(self, iterator: DataIterator) -> None:
             self._fire_event(Event.ValidationIteration, False)
             return_dict = self._validate_step(batch)

-            # Update metrics.
+            self._valid_tracker.add(len(batch))
             utils.update_metrics(return_dict, batch, self.valid_metrics)

             self._fire_event(Event.ValidationIteration, True)
@@ -1906,8 +1917,7 @@ def _test_loop(self, iterator: DataIterator) -> None:
             return_dict = self._test_step(batch)

             self._test_tracker.add(len(batch))
-            utils.update_metrics(
-                return_dict, batch, self.test_metrics)
+            utils.update_metrics(return_dict, batch, self.test_metrics)

             self._fire_event(Event.TestingIteration, True)
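
To see why the `should_open_file` guard above matters, here is a stand-alone sketch of the re-entrancy scenario (the `Runner` class is hypothetical, not the texar `Executor`): `test()` may run inside `train()`, e.g. via a validation-triggered action, so only the call that actually opened the log files should close them.

```python
# Hypothetical Runner class illustrating the open-once / close-by-owner pattern.
class Runner:
    def __init__(self, path):
        self.path = path
        self._opened_files = []

    def _open_files(self):
        self._opened_files.append(open(self.path, "a"))

    def _close_files(self):
        for file in self._opened_files:
            file.close()
        self._opened_files = []  # same reset the commit adds to `Executor._close_files`

    def test(self):
        should_open_file = not self._opened_files
        if should_open_file:
            self._open_files()
        self._opened_files[0].write("test\n")
        if should_open_file:  # only the caller that opened the files closes them
            self._close_files()

    def train(self):
        should_open_file = not self._opened_files
        if should_open_file:
            self._open_files()
        self._opened_files[0].write("train\n")
        self.test()  # nested call: the log file must stay open afterwards
        self._opened_files[0].write("still training\n")
        if should_open_file:
            self._close_files()


Runner("run.log").train()  # without the guard, the last write would hit a closed file
```
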

texar/torch/run/metric/summary.py

Lines changed: 13 additions & 2 deletions
@@ -17,6 +17,7 @@

 from collections import deque
 from typing import Any, Deque, Optional, Sequence
+import weakref

 import numpy as np
 from torch.optim.optimizer import Optimizer
@@ -152,15 +153,25 @@ class LR(StreamingMetric[Any, float]):

     def __init__(self, optimizer: Optimizer, param_group: int = 0):
         super().__init__(pred_name=None)
-        self.optimizer = optimizer
+        self.optimizer = weakref.ref(optimizer)
         self.group = param_group

     def add(self, _, __):
         pass

     def value(self) -> float:
-        return self.optimizer.param_groups[self.group]['lr']  # type: ignore
+        return self.optimizer().param_groups[self.group]['lr']  # type: ignore

     def better(self, cur: float, prev: float) -> Optional[bool]:
         # Always return `None` to indicate values are uncomparable.
         return None
+
+    def __getstate__(self):
+        # There's no point in pickling an `LR` metric; just ignore it.
+        return None
+
+    def __getnewargs__(self):
+        # But when unpickling, we need to make sure we can construct something.
+        # This requires passing a dummy `optimizer` to which a weakref can be
+        # constructed. In this case, we use an arbitrary built-in class.
+        return (int,)
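
The pickling pattern added above can be exercised with a minimal stand-alone sketch (`FakeLR` and `DummyOptimizer` are hypothetical, not the real classes): `__getstate__` returning `None` drops the unpicklable weakref, and `__getnewargs__` supplies a placeholder argument so unpickling can still construct an instance (pickle never calls `__init__`).

```python
# Hypothetical FakeLR / DummyOptimizer classes mimicking the pattern used by `LR`.
import pickle
import weakref


class FakeLR:
    def __init__(self, optimizer):
        self.optimizer = weakref.ref(optimizer)  # weakrefs cannot be pickled

    def __getstate__(self):
        return None  # falsy state: nothing is stored, `__setstate__` is never called

    def __getnewargs__(self):
        # Passed to `__new__` during unpickling; since `__init__` is defined,
        # `object.__new__` simply ignores the extra argument.
        return (int,)


class DummyOptimizer:
    pass


restored = pickle.loads(pickle.dumps(FakeLR(DummyOptimizer())))
print(type(restored).__name__)         # FakeLR
print(hasattr(restored, "optimizer"))  # False -- the state was dropped on purpose
```

With this in place, saving the executor's meta-info no longer serializes the optimizer through the `LR` metric.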
