@@ -225,12 +225,16 @@ def _compute_op(op: str, d: np.ndarray):
225225 f .write (f"{ class_labels [i ]} { deli } { deli .join ([f'{ _compute_op (k , c ):.4f} ' for k in ops ])} \n " )
226226
227227
228- def from_engine (keys : KeysCollection ):
228+ def from_engine (keys : KeysCollection , first : bool = False ):
229229 """
230230 Utility function to simplify the `batch_transform` or `output_transform` args of ignite components
231231 when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`).
232232 Users only need to set the expected keys, then it will return a callable function to extract data from
233233 dictionary and construct a tuple respectively.
234+
235+ If data is a list of dictionaries after decollating, extract expected keys and construct lists respectively,
236+ for example, if data is `[{"A": 1, "B": 2}, {"A": 3, "B": 4}]`, from_engine(["A", "B"]): `([1, 3], [2, 4])`.
237+
234238 It can help avoid a complicated `lambda` function and make the arg of metrics more straight-forward.
235239 For example, set the first key as the prediction and the second key as label to get the expected data
236240 from `engine.state.output` for a metric::
@@ -242,15 +246,23 @@ def from_engine(keys: KeysCollection):
242246 output_transform=from_engine(["pred", "label"])
243247 )
244248
249+ Args:
250+ keys: specified keys to extract data from dictionary or decollated list of dictionaries.
251+ first: whether only extract sepcified keys from the first item if input data is a list of dictionaries,
252+ it's used to extract the scalar data which doesn't have batch dim and was replicated into every
253+ dictionary when decollating, like `loss`, etc.
254+
255+
245256 """
246257 keys = ensure_tuple (keys )
247258
248259 def _wrapper (data ):
249260 if isinstance (data , dict ):
250261 return tuple (data [k ] for k in keys )
251262 elif isinstance (data , list ) and isinstance (data [0 ], dict ):
252- # if data is a list of dictionaries, extract expected keys and construct lists
253- ret = [[i [k ] for i in data ] for k in keys ]
263+ # if data is a list of dictionaries, extract expected keys and construct lists,
264+ # if `first=True`, only extract keys from the first item of the list
265+ ret = [data [0 ][k ] if first else [i [k ] for i in data ] for k in keys ]
254266 return tuple (ret ) if len (ret ) > 1 else ret [0 ]
255267
256268 return _wrapper
0 commit comments