Skip to content

Commit 025dee7

Browse files
authored
[DLMED] add first to from_engine (#2524)
Signed-off-by: Nic Ma <[email protected]>
1 parent cfcce58 commit 025dee7

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

monai/handlers/utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,16 @@ def _compute_op(op: str, d: np.ndarray):
225225
f.write(f"{class_labels[i]}{deli}{deli.join([f'{_compute_op(k, c):.4f}' for k in ops])}\n")
226226

227227

228-
def from_engine(keys: KeysCollection):
228+
def from_engine(keys: KeysCollection, first: bool = False):
229229
"""
230230
Utility function to simplify the `batch_transform` or `output_transform` args of ignite components
231231
when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`).
232232
Users only need to set the expected keys, then it will return a callable function to extract data from
233233
dictionary and construct a tuple respectively.
234+
235+
If data is a list of dictionaries after decollating, extract expected keys and construct lists respectively,
236+
for example, if data is `[{"A": 1, "B": 2}, {"A": 3, "B": 4}]`, from_engine(["A", "B"]): `([1, 3], [2, 4])`.
237+
234238
It can help avoid a complicated `lambda` function and make the arg of metrics more straight-forward.
235239
For example, set the first key as the prediction and the second key as label to get the expected data
236240
from `engine.state.output` for a metric::
@@ -242,15 +246,23 @@ def from_engine(keys: KeysCollection):
242246
output_transform=from_engine(["pred", "label"])
243247
)
244248
249+
Args:
250+
keys: specified keys to extract data from dictionary or decollated list of dictionaries.
251+
first: whether only extract sepcified keys from the first item if input data is a list of dictionaries,
252+
it's used to extract the scalar data which doesn't have batch dim and was replicated into every
253+
dictionary when decollating, like `loss`, etc.
254+
255+
245256
"""
246257
keys = ensure_tuple(keys)
247258

248259
def _wrapper(data):
249260
if isinstance(data, dict):
250261
return tuple(data[k] for k in keys)
251262
elif isinstance(data, list) and isinstance(data[0], dict):
252-
# if data is a list of dictionaries, extract expected keys and construct lists
253-
ret = [[i[k] for i in data] for k in keys]
263+
# if data is a list of dictionaries, extract expected keys and construct lists,
264+
# if `first=True`, only extract keys from the first item of the list
265+
ret = [data[0][k] if first else [i[k] for i in data] for k in keys]
254266
return tuple(ret) if len(ret) > 1 else ret[0]
255267

256268
return _wrapper

tests/test_integration_workflows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,9 @@ def _model_completed(self, engine):
182182
train_handlers = [
183183
LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
184184
ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
185-
StatsHandler(tag_name="train_loss", output_transform=lambda x: x[0]["loss"]),
185+
StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
186186
TensorBoardStatsHandler(
187-
summary_writer=summary_writer, tag_name="train_loss", output_transform=lambda x: x[0]["loss"]
187+
summary_writer=summary_writer, tag_name="train_loss", output_transform=from_engine("loss", first=True)
188188
),
189189
CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
190190
_TestTrainIterEvents(),

0 commit comments

Comments
 (0)