|
4 | 4 |
|
5 | 5 | import abc |
6 | 6 | from collections.abc import Sequence |
| 7 | +from inspect import signature |
7 | 8 | from numbers import Number |
8 | 9 | from typing import TYPE_CHECKING, Any, Callable, Union |
9 | 10 |
|
@@ -458,69 +459,53 @@ def dim(self) -> int: |
458 | 459 | return len(self.categories) |
459 | 460 |
|
460 | 461 |
|
461 | | -class WrappedKernel(kernels.Kernel): |
462 | | - """Wrap a kernel with a parameter transformation. |
463 | | -
|
464 | | - The transform function is applied to the input before passing it to the base kernel. |
| 462 | +def wrap_kernel(kernel: kernels.Kernel, transform: Callable[[Any], Any]) -> kernels.Kernel: |
| 463 | + """Wrap a kernel to transform input data before passing it to the kernel. |
465 | 464 |
|
466 | 465 | Parameters |
467 | 466 | ---------- |
468 | | - base_kernel : kernels.Kernel |
| 467 | + kernel : kernels.Kernel |
| 468 | + The kernel to wrap. |
469 | 469 |
|
470 | | - transform : Callable[[Any], Any] |
471 | | - """ |
| 470 | + transform : Callable |
| 471 | + The transformation function to apply to the input data. |
472 | 472 |
|
473 | | - def __init__(self, base_kernel: kernels.Kernel, transform: Callable[[Any], Any]) -> None: |
474 | | - super().__init__() |
475 | | - self.base_kernel = base_kernel |
476 | | - self.transform = transform |
| 473 | + Returns |
| 474 | + ------- |
| 475 | + kernels.Kernel |
| 476 | + The wrapped kernel. |
477 | 477 |
|
478 | | - def __call__(self, X: NDArray[Float], Y: NDArray[Float] = None, eval_gradient: bool = False) -> Any: |
479 | | - """Return the kernel k(X, Y) and optionally its gradient after applying the transform. |
| 478 | + Notes |
| 479 | + ----- |
| 480 | + See https://arxiv.org/abs/1805.03463 for more information. |
| 481 | + """ |
| 482 | + kernel_type = type(kernel) |
480 | 483 |
|
481 | | - For details, see the documentation of the base kernel. |
| 484 | + class WrappedKernel(kernel_type): |
| 485 | + @copy_signature(getattr(kernel_type.__init__, "deprecated_original", kernel_type.__init__)) |
| 486 | + def __init__(self, **kwargs: Any) -> None: |
| 487 | + super().__init__(**kwargs) |
482 | 488 |
|
483 | | - Parameters |
484 | | - ---------- |
485 | | - X : ndarray of shape (n_samples_X, n_features) |
486 | | - Left argument of the returned kernel k(X, Y). |
| 489 | + def __call__(self, X: Any, Y: Any = None, eval_gradient: bool = False) -> Any: |
| 490 | + X = transform(X) |
| 491 | + Y = transform(Y) if Y is not None else None |
| 492 | + return super().__call__(X, Y, eval_gradient) |
487 | 493 |
|
488 | | - Y : ndarray of shape (n_samples_Y, n_features), default=None |
489 | | - Right argument of the returned kernel k(X, Y). If None, k(X, X) is evaluated. |
| 494 | + def __reduce__(self) -> str | tuple[Any, ...]: |
| 495 | + return (wrap_kernel, (kernel, transform)) |
490 | 496 |
|
491 | | - eval_gradient : bool, default=False |
492 | | - Determines whether the gradient with respect to the kernel hyperparameter is calculated. |
| 497 | + return WrappedKernel(**kernel.get_params()) |
493 | 498 |
|
494 | | - Returns |
495 | | - ------- |
496 | | - K : ndarray of shape (n_samples_X, n_samples_Y) |
497 | | -
|
498 | | - K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims) |
499 | | - """ |
500 | | - X = self.transform(X) |
501 | | - Y = self.transform(Y) if Y is not None else None |
502 | | - return self.base_kernel(X, Y, eval_gradient) |
503 | 499 |
|
504 | | - def is_stationary(self): |
505 | | - """Return whether the kernel is stationary.""" |
506 | | - return self.base_kernel.is_stationary() |
507 | | - |
508 | | - def diag(self, X: NDArray[Float]) -> NDArray[Float]: |
509 | | - """Return the diagonal of k(X, X). |
510 | | -
|
511 | | - This method allows for more efficient calculations than calling |
512 | | - np.diag(self(X)). |
| 500 | +def copy_signature(source_fct: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: |
| 501 | + """Clones a signature from a source function to a target function. |
513 | 502 |
|
| 503 | + via |
| 504 | + https://stackoverflow.com/a/58989918/ |
| 505 | + """ |
514 | 506 |
|
515 | | - Parameters |
516 | | - ---------- |
517 | | - X : array-like of shape (n_samples,) |
518 | | - Left argument of the returned kernel k(X, Y) |
| 507 | + def copy(target_fct: Callable[..., Any]) -> Callable[..., Any]: |
| 508 | + target_fct.__signature__ = signature(source_fct) |
| 509 | + return target_fct |
519 | 510 |
|
520 | | - Returns |
521 | | - ------- |
522 | | - K_diag : ndarray of shape (n_samples_X,) |
523 | | - Diagonal of kernel k(X, X) |
524 | | - """ |
525 | | - X = self.transform(X) |
526 | | - return self.base_kernel.diag(X) |
| 511 | + return copy |
0 commit comments