File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed
Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -86,6 +86,7 @@ def forward(self,
8686 self ._validate_fit ()
8787 if self .params ["sd" ] == 0 :
8888 # In the edge case where `X` is degenerate, avoid 0 divided by 0
89+ warnings .warn ('Transform constant target values by mean subtraction' , UserWarning )
8990 return zeros_like (X )
9091 else :
9192 return (X - self .params ['mu' ])/ self .params ['sd' ]
Original file line number Diff line number Diff line change @@ -255,7 +255,16 @@ def test_target_validation():
255255 with pytest .warns (UserWarning ):
256256 transform_func = Logit_Scaler (range_response = 100 )
257257 transform_func .forward (test_neg_response , fit = False )
258-
259-
258+
259+ # Transform constant target values
260+ test_constant_response = torch .zeros (10 ) + 9.0
261+ with pytest .warns (UserWarning ):
262+ Target ('Response1' , f_transform = 'Standard' ).transform_f (test_constant_response , fit = True )
263+
264+ # Corner case for Logit_Scaler
265+ transform_func = Logit_Scaler (standardize = False )
266+ transform_func .forward (test_response , fit = True )
267+
268+
260269if __name__ == '__main__' :
261270 pytest .main ([__file__ , '-m' , 'fast' ])
You can’t perform that action at this time.
0 commit comments