55# LICENSE file in the root directory of this source tree.
66
77import itertools
8+ import warnings
89from copy import deepcopy
910
1011import torch
12+ from botorch import settings
1113from botorch .exceptions .errors import BotorchTensorDimensionError
1214from botorch .models .transforms .input import (
1315 AffineInputTransform ,
2931from torch import Tensor
3032from torch .distributions import Kumaraswamy
3133from torch .nn import Module
34+ from torch .nn .functional import one_hot
3235
3336
3437def get_test_warp (indices , ** kwargs ):
@@ -534,19 +537,45 @@ def test_chained_input_transform(self):
534537 def test_round_transform (self ):
535538 for dtype in (torch .float , torch .double ):
536539 # basic init
537- int_idcs = [0 , 2 ]
538- round_tf = Round (indices = [0 , 2 ])
539- self .assertEqual (round_tf .indices .tolist (), int_idcs )
540+ int_idcs = [0 , 4 ]
541+ categorical_feats = {2 : 2 , 5 : 3 }
542+ # test deprecation warning
543+ with warnings .catch_warnings (record = True ) as ws , settings .debug (True ):
544+ Round (indices = int_idcs )
545+ self .assertTrue (
546+ any (issubclass (w .category , DeprecationWarning ) for w in ws )
547+ )
548+ round_tf = Round (
549+ integer_indices = int_idcs , categorical_features = categorical_feats
550+ )
551+ self .assertEqual (round_tf .integer_indices .tolist (), int_idcs )
552+ self .assertEqual (round_tf .categorical_features , categorical_feats )
540553 self .assertTrue (round_tf .training )
541- self .assertTrue (round_tf .approximate )
554+ self .assertFalse (round_tf .approximate )
542555 self .assertEqual (round_tf .tau , 1e-3 )
543556
544557 # basic usage
545- for batch_shape , approx in itertools .product (
546- (torch .Size (), torch .Size ([3 ])), (False , True )
558+ for batch_shape , approx , categorical_features in itertools .product (
559+ (torch .Size (), torch .Size ([3 ])),
560+ (False , True ),
561+ (None , categorical_feats ),
547562 ):
548- X = 5 * torch .rand (* batch_shape , 4 , 3 , device = self .device , dtype = dtype )
549- round_tf = Round (indices = [0 , 2 ], approximate = approx )
563+ X = torch .rand (* batch_shape , 4 , 8 , device = self .device , dtype = dtype )
564+ X [..., int_idcs ] *= 5
565+ if categorical_features is not None and approx :
566+ with self .assertRaises (NotImplementedError ):
567+ Round (
568+ integer_indices = int_idcs ,
569+ categorical_features = categorical_features ,
570+ approximate = approx ,
571+ )
572+ continue
573+ round_tf = Round (
574+ integer_indices = int_idcs ,
575+ categorical_features = categorical_features ,
576+ approximate = approx ,
577+ tau = 1e-1 ,
578+ )
550579 X_rounded = round_tf (X )
551580 exact_rounded_X_ints = X [..., int_idcs ].round ()
552581 # check non-integers parameters are unchanged
@@ -560,17 +589,39 @@ def test_round_transform(self):
560589 <= (X [..., int_idcs ] - exact_rounded_X_ints ).abs ()
561590 ).all ()
562591 )
592+ self .assertFalse (
593+ torch .equal (X_rounded [..., int_idcs ], exact_rounded_X_ints )
594+ )
563595 else :
564- # check that exact rounding behaves as expected
596+ # check that exact rounding behaves as expected for integers
565597 self .assertTrue (
566598 torch .equal (X_rounded [..., int_idcs ], exact_rounded_X_ints )
567599 )
600+ if categorical_features is not None :
601+ # test that discretization works as expected for categoricals
602+ for start , card in categorical_features .items ():
603+ end = start + card
604+ expected_categorical = one_hot (
605+ X [..., start :end ].argmax (dim = - 1 ), num_classes = card
606+ ).to (X )
607+ self .assertTrue (
608+ torch .equal (
609+ X_rounded [..., start :end ], expected_categorical
610+ )
611+ )
612+ # test that gradient information is passed via STE
613+ X2 = X .clone ().requires_grad_ (True )
614+ round_tf (X2 ).sum ().backward ()
615+ self .assertTrue (torch .equal (X2 .grad , torch .ones_like (X2 )))
568616 with self .assertRaises (NotImplementedError ):
569617 round_tf .untransform (X_rounded )
570618
571619 # test no transform on eval
572620 round_tf = Round (
573- indices = int_idcs , approximate = approx , transform_on_eval = False
621+ integer_indices = int_idcs ,
622+ categorical_features = categorical_features ,
623+ approximate = approx ,
624+ transform_on_eval = False ,
574625 )
575626 X_rounded = round_tf (X )
576627 self .assertFalse (torch .equal (X , X_rounded ))
@@ -580,7 +631,10 @@ def test_round_transform(self):
580631
581632 # test no transform on train
582633 round_tf = Round (
583- indices = int_idcs , approximate = approx , transform_on_train = False
634+ integer_indices = int_idcs ,
635+ categorical_features = categorical_features ,
636+ approximate = approx ,
637+ transform_on_train = False ,
584638 )
585639 X_rounded = round_tf (X )
586640 self .assertTrue (torch .equal (X , X_rounded ))
@@ -590,27 +644,48 @@ def test_round_transform(self):
590644
591645 # test equals
592646 round_tf2 = Round (
593- indices = int_idcs , approximate = approx , transform_on_train = False
647+ integer_indices = int_idcs ,
648+ categorical_features = categorical_features ,
649+ approximate = approx ,
650+ transform_on_train = False ,
594651 )
595652 self .assertTrue (round_tf .equals (round_tf2 ))
596653 # test different transform_on_train
597- round_tf2 = Round (indices = int_idcs , approximate = approx )
654+ round_tf2 = Round (
655+ integer_indices = int_idcs ,
656+ categorical_features = categorical_features ,
657+ approximate = approx ,
658+ )
598659 self .assertFalse (round_tf .equals (round_tf2 ))
599660 # test different approx
661+ round_tf = Round (
662+ integer_indices = int_idcs ,
663+ )
600664 round_tf2 = Round (
601- indices = int_idcs , approximate = not approx , transform_on_train = False
665+ integer_indices = int_idcs ,
666+ approximate = not approx ,
667+ transform_on_train = False ,
602668 )
603669 self .assertFalse (round_tf .equals (round_tf2 ))
604670 # test different indices
671+ round_tf = Round (
672+ integer_indices = int_idcs ,
673+ categorical_features = categorical_features ,
674+ transform_on_train = False ,
675+ )
605676 round_tf2 = Round (
606- indices = [0 , 1 ], approximate = approx , transform_on_train = False
677+ integer_indices = [0 , 1 ],
678+ categorical_features = categorical_features ,
679+ approximate = approx ,
680+ transform_on_train = False ,
607681 )
608682 self .assertFalse (round_tf .equals (round_tf2 ))
609683
610684 # test preprocess_transform
611685 round_tf .transform_on_train = False
612686 self .assertTrue (torch .equal (round_tf .preprocess_transform (X ), X ))
613687 round_tf .transform_on_train = True
688+ X_rounded = round_tf (X )
614689 self .assertTrue (
615690 torch .equal (round_tf .preprocess_transform (X ), X_rounded )
616691 )
0 commit comments