@@ -48,7 +48,7 @@ def test_validate_unsupported_task_type(self, forecasting_enabled):
4848 task_type = 'regre' ,
4949 forecasting_enabled = forecasting_enabled )
5050
51- def test_missing_data_warnings (self ):
51+ def test_missing_test_data (self ):
5252 train_data = {
5353 'Column1' : [10 , 20 , 90 , 40 , 50 ],
5454 'Column2' : [10 , 20 , 90 , 40 , 50 ],
@@ -57,7 +57,39 @@ def test_missing_data_warnings(self):
5757 train = pd .DataFrame (train_data )
5858
5959 test_data = {
60- 'Column1' : [10 , 20 , np .nan , 40 , 50 ],
60+ 'Column1' : [10 , 20 , 90 , 40 , 50 ],
61+ 'Column2' : [10 , 20 , 90 , 40 , 50 ],
62+ 'Target' : [10 , 20 , np .nan , 40 , 50 ]
63+ }
64+ test = pd .DataFrame (test_data )
65+
66+ X_train = train .drop (columns = ['Target' ])
67+ y_train = train ['Target' ].values
68+ model = create_complex_classification_pipeline (
69+ X_train , y_train , ['Column1' , 'Column2' ], [])
70+
71+ with pytest .raises (
72+ UserConfigValidationException ,
73+ match = "['Column1']" ) as ucve :
74+ RAIInsights (
75+ model = model ,
76+ train = train ,
77+ test = test ,
78+ target_column = 'Target' ,
79+ task_type = 'classification' )
80+ assert "Features ['Target'] have missing values in " + \
81+ "test data" in str (ucve .value )
82+
83+ def test_missing_train_data (self ):
84+ train_data = {
85+ 'Column1' : [10 , 20 , 90 , 40 , 50 ],
86+ 'Column2' : [10 , 20 , np .nan , 40 , 50 ],
87+ 'Target' : [10 , 20 , 90 , 40 , 50 ]
88+ }
89+ train = pd .DataFrame (train_data )
90+
91+ test_data = {
92+ 'Column1' : [10 , 20 , 90 , 40 , 50 ],
6193 'Column2' : [10 , 20 , 90 , 40 , 50 ],
6294 'Target' : [10 , 20 , 90 , 40 , 50 ]
6395 }
@@ -68,15 +100,17 @@ def test_missing_data_warnings(self):
68100 model = create_complex_classification_pipeline (
69101 X_train , y_train , ['Column1' , 'Column2' ], [])
70102
71- with pytest .warns (
72- UserWarning ,
73- match = "['Column1 ']" ):
103+ with pytest .raises (
104+ UserConfigValidationException ,
105+ match = "['Column2 ']" ) as ucve :
74106 RAIInsights (
75107 model = model ,
76108 train = train ,
77109 test = test ,
78110 target_column = 'Target' ,
79111 task_type = 'classification' )
112+ assert "Features ['Column2'] have missing values in " + \
113+ "train data" in str (ucve .value )
80114
81115 def test_validate_test_data_size (self ):
82116 X_train , X_test , y_train , y_test , _ , _ = \
0 commit comments