44
55import abc
66from collections .abc import Sequence
7- from inspect import signature
87from numbers import Number
98from typing import TYPE_CHECKING , Any , Callable , Union
109
@@ -375,8 +374,6 @@ def to_float(self, value: Any) -> NDArray[Float]:
375374 """
376375 res = np .zeros (len (self .categories ))
377376 one_hot_index = [i for i , val in enumerate (self .categories ) if val == value ]
378- if len (one_hot_index ) != 1 :
379- raise ValueError
380377 res [one_hot_index ] = 1
381378 return res .astype (float )
382379
@@ -432,7 +429,7 @@ def kernel_transform(self, value: NDArray[Float]) -> NDArray[Float]:
432429 """
433430 value = np .atleast_2d (value )
434431 res = np .zeros (value .shape )
435- res [np .argmax (value , axis = 0 )] = 1
432+ res [:, np .argmax (value , axis = 1 )] = 1
436433 return res
437434
438435 @property
@@ -441,52 +438,68 @@ def dim(self) -> int:
441438 return len (self .categories )
442439
443440
444- def wrap_kernel (kernel : kernels .Kernel , transform : Callable [[Any ], Any ]) -> kernels .Kernel :
445- """Wrap a kernel to transform input data before passing it to the kernel.
441+ class WrappedKernel (kernels .Kernel ):
442+ """Wrap a kernel with a parameter transformation.
443+
444+ The transform function is applied to the input before passing it to the base kernel.
446445
447446 Parameters
448447 ----------
449- kernel : kernels.Kernel
450- The kernel to wrap.
448+ base_kernel : kernels.Kernel
451449
452- transform : Callable
453- The transformation function to apply to the input data.
450+ transform : Callable[[Any], Any]
451+ """
454452
455- Returns
456- -------
457- kernels.Kernel
458- The wrapped kernel.
453+ def __init__ ( self , base_kernel : kernels . Kernel , transform : Callable [[ Any ], Any ]) -> None :
454+ super (). __init__ ()
455+ self . base_kernel = base_kernel
456+ self . transform = transform
459457
460- Notes
461- -----
462- See https://arxiv.org/abs/1805.03463 for more information.
463- """
464- kernel_type = type (kernel )
458+ def __call__ (self , X : NDArray [Float ], Y : NDArray [Float ] = None , eval_gradient : bool = False ) -> Any :
459+ """Return the kernel k(X, Y) and optionally its gradient after applying the transform.
465460
466- class WrappedKernel (kernel_type ):
467- @copy_signature (getattr (kernel_type .__init__ , "deprecated_original" , kernel_type .__init__ ))
468- def __init__ (self , ** kwargs : Any ) -> None :
469- super ().__init__ (** kwargs )
461+ For details, see the documentation of the base kernel.
470462
471- def __call__ (self , X : Any , Y : Any = None , eval_gradient : bool = False ) -> Any :
472- X = transform (X )
473- return super ().__call__ (X , Y , eval_gradient )
463+ Parameters
464+ ----------
465+ X : ndarray of shape (n_samples_X, n_features)
466+ Left argument of the returned kernel k(X, Y).
474467
475- def __reduce__ ( self ) -> str | tuple [ Any , ...]:
476- return ( wrap_kernel , ( kernel , transform ))
468+ Y : ndarray of shape (n_samples_Y, n_features), default=None
469+ Right argument of the returned kernel k(X, Y). If None, k(X, X) is evaluated.
477470
478- return WrappedKernel (** kernel .get_params ())
471+ eval_gradient : bool, default=False
472+ Determines whether the gradient with respect to the kernel hyperparameter is calculated.
479473
474+ Returns
475+ -------
476+ K : ndarray of shape (n_samples_X, n_samples_Y)
480477
481- def copy_signature (source_fct : Callable [..., Any ]) -> Callable [[Callable [..., Any ]], Callable [..., Any ]]:
482- """Clones a signature from a source function to a target function.
478+ K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims)
479+ """
480+ X = self .transform (X )
481+ return self .base_kernel (X , Y , eval_gradient )
483482
484- via
485- https://stackoverflow.com/a/58989918/
486- """
483+ def is_stationary ( self ):
484+ """Return whether the kernel is stationary."""
485+ return self . base_kernel . is_stationary ()
487486
488- def copy (target_fct : Callable [..., Any ]) -> Callable [..., Any ]:
489- target_fct .__signature__ = signature (source_fct )
490- return target_fct
487+ def diag (self , X : NDArray [Float ]) -> NDArray [Float ]:
488+ """Return the diagonal of k(X, X).
491489
492- return copy
490+ This method allows for more efficient calculations than calling
491+ np.diag(self(X)).
492+
493+
494+ Parameters
495+ ----------
496+ X : array-like of shape (n_samples,)
497+ Left argument of the returned kernel k(X, Y)
498+
499+ Returns
500+ -------
501+ K_diag : ndarray of shape (n_samples_X,)
502+ Diagonal of kernel k(X, X)
503+ """
504+ X = self .transform (X )
505+ return self .base_kernel .diag (X )
0 commit comments