Skip to content

Commit 8d076c5

Browse files
committed
Include POSITIONAL_OR_KEYWORD params in MultiMetric kwarg validation
1 parent 6babc72 commit 8d076c5

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

flax/nnx/training/metrics.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -418,24 +418,27 @@ def __init__(self, **metrics):
418418
if self._expected_kwargs is None:
419419
continue
420420
sig = inspect.signature(metric.update)
421-
has_keyword_only = False
421+
has_named_params = False
422422
has_var_keyword = False
423-
keyword_only_names: set[str] = set()
423+
named_param_names: set[str] = set()
424424
for pname, param in sig.parameters.items():
425425
if pname == 'self':
426426
continue
427-
if param.kind == param.KEYWORD_ONLY:
428-
keyword_only_names.add(pname)
429-
has_keyword_only = True
427+
if param.kind in (
428+
param.POSITIONAL_OR_KEYWORD,
429+
param.KEYWORD_ONLY,
430+
):
431+
named_param_names.add(pname)
432+
has_named_params = True
430433
elif param.kind == param.VAR_KEYWORD:
431434
has_var_keyword = True
432-
if has_keyword_only and has_var_keyword:
435+
if has_named_params and has_var_keyword:
433436
# Metric declares specific params but also absorbs
434437
# extras (e.g. Accuracy's **_); can't validate
435438
# without false positives.
436439
self._expected_kwargs = None
437-
elif has_keyword_only:
438-
self._expected_kwargs.update(keyword_only_names)
440+
elif has_named_params:
441+
self._expected_kwargs.update(named_param_names)
439442
elif hasattr(metric, 'argname'):
440443
# Use argname convention (e.g. Average, Welford).
441444
self._expected_kwargs.add(metric.argname)

0 commit comments

Comments
 (0)