Skip to content

Commit 1428f17

Browse files
Add _ignore_kwargs function to filter kwargs based on function signature
This new utility function filters keyword arguments to only those accepted by the specified function, logging any ignored keys. It is integrated into the model building process to ensure compatibility with model initialization.
1 parent 03f4f4d commit 1428f17

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

timm/models/_builder.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
import logging
33
import os
4+
import inspect
45
from copy import deepcopy
56
from pathlib import Path
67
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
@@ -303,6 +304,19 @@ def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None:
303304
for n in names:
304305
kwargs.pop(n, None)
305306

307+
def _ignore_kwargs(func, kwargs):
308+
""" Filter kwargs to those that func accepts.
309+
"""
310+
sig = inspect.signature(func)
311+
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
312+
return kwargs
313+
filter_keys = [p.name for p in sig.parameters.values() if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)]
314+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in filter_keys}
315+
ignored_keys = set(kwargs.keys()) - set(filtered_kwargs.keys())
316+
if ignored_keys:
317+
_logger.warning(
318+
f'Ignored attempt to pass arguments ({", ".join(ignored_keys)}) to function {func}.')
319+
return filtered_kwargs
306320

307321
def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None:
308322
""" Update the default_cfg and kwargs before passing to model
@@ -441,6 +455,7 @@ def build_model_with_cfg(
441455
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
442456

443457
# Instantiate the model
458+
kwargs = _ignore_kwargs(model_cls.__init__, kwargs)
444459
if model_cfg is None:
445460
model = model_cls(**kwargs)
446461
else:

0 commit comments

Comments
 (0)