Skip to content

Commit 5545c49

Browse files
authored
Require keyword args for data iterator. (dmlc#8327)
1 parent e1f9f80 commit 5545c49

File tree

2 files changed

+55
-38
lines changed

2 files changed

+55
-38
lines changed

python-package/xgboost/core.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,8 @@ def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument
502502
pointer.
503503
504504
"""
505-
@_deprecate_positional_args
506-
def data_handle(
505+
@require_pos_args(True)
506+
def input_data(
507507
data: Any,
508508
*,
509509
feature_names: Optional[FeatureNames] = None,
@@ -528,7 +528,7 @@ def data_handle(
528528
**kwargs,
529529
)
530530
# pylint: disable=not-callable
531-
return self._handle_exception(lambda: self.next(data_handle), 0)
531+
return self._handle_exception(lambda: self.next(input_data), 0)
532532

533533
@abstractmethod
534534
def reset(self) -> None:
@@ -554,7 +554,7 @@ def next(self, input_data: Callable) -> int:
554554
raise NotImplementedError()
555555

556556

557-
# Notice for `_deprecate_positional_args`
557+
# Notice for `require_pos_args`
558558
# Authors: Olivier Grisel
559559
# Gael Varoquaux
560560
# Andreas Mueller
@@ -563,50 +563,63 @@ def next(self, input_data: Callable) -> int:
563563
# Nicolas Tresegnie
564564
# Sylvain Marie
565565
# License: BSD 3 clause
566-
def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
566+
def require_pos_args(error: bool) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
567567
"""Decorator for methods that issues warnings for positional arguments
568568
569569
Using the keyword-only argument syntax in pep 3102, arguments after the
570-
* will issue a warning when passed as a positional argument.
570+
* will issue a warning or error when passed as a positional argument.
571571
572572
Modified from sklearn utils.validation.
573573
574574
Parameters
575575
----------
576-
f : function
577-
function to check arguments on
576+
error :
577+
Whether to throw an error or raise a warning.
578578
"""
579-
sig = signature(f)
580-
kwonly_args = []
581-
all_args = []
582-
583-
for name, param in sig.parameters.items():
584-
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
585-
all_args.append(name)
586-
elif param.kind == Parameter.KEYWORD_ONLY:
587-
kwonly_args.append(name)
588-
589-
@wraps(f)
590-
def inner_f(*args: Any, **kwargs: Any) -> _T:
591-
extra_args = len(args) - len(all_args)
592-
if extra_args > 0:
593-
# ignore first 'self' argument for instance methods
594-
args_msg = [
595-
f"{name}" for name, _ in zip(
596-
kwonly_args[:extra_args], args[-extra_args:]
597-
)
598-
]
599-
# pylint: disable=consider-using-f-string
600-
warnings.warn(
601-
"Pass `{}` as keyword args. Passing these as positional "
602-
"arguments will be considered as error in future releases.".
603-
format(", ".join(args_msg)), FutureWarning
604-
)
605-
for k, arg in zip(sig.parameters, args):
606-
kwargs[k] = arg
607-
return f(**kwargs)
608579

609-
return inner_f
580+
def throw_if(func: Callable[..., _T]) -> Callable[..., _T]:
581+
"""Throw error/warning if there are positional arguments after the asterisk.
582+
583+
Parameters
584+
----------
585+
f :
586+
function to check arguments on.
587+
588+
"""
589+
sig = signature(func)
590+
kwonly_args = []
591+
all_args = []
592+
593+
for name, param in sig.parameters.items():
594+
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
595+
all_args.append(name)
596+
elif param.kind == Parameter.KEYWORD_ONLY:
597+
kwonly_args.append(name)
598+
599+
@wraps(func)
600+
def inner_f(*args: Any, **kwargs: Any) -> _T:
601+
extra_args = len(args) - len(all_args)
602+
if extra_args > 0:
603+
# ignore first 'self' argument for instance methods
604+
args_msg = [
605+
f"{name}"
606+
for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:])
607+
]
608+
# pylint: disable=consider-using-f-string
609+
msg = "Pass `{}` as keyword args.".format(", ".join(args_msg))
610+
if error:
611+
raise TypeError(msg)
612+
warnings.warn(msg, FutureWarning)
613+
for k, arg in zip(sig.parameters, args):
614+
kwargs[k] = arg
615+
return func(**kwargs)
616+
617+
return inner_f
618+
619+
return throw_if
620+
621+
622+
_deprecate_positional_args = require_pos_args(False)
610623

611624

612625
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods

tests/python/testing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def __init__(
198198
def next(self, input_data: Callable) -> int:
199199
if self.it == len(self.X):
200200
return 0
201+
202+
with pytest.raises(TypeError, match="keyword args"):
203+
input_data(self.X[self.it], self.y[self.it], None)
204+
201205
# Use copy to make sure the iterator doesn't hold a reference to the data.
202206
input_data(
203207
data=self.X[self.it].copy(),

0 commit comments

Comments
 (0)