@@ -18,25 +18,24 @@ def update_input_schema(input_schema, columns):
1818 new_input_schema .append (column )
1919 return new_input_schema
2020
21+ def _add_column (name , type , schema , columns = None ):
22+ schema .append ({'type' : type , 'name' : name })
23+ if columns is not None :
24+ columns .append (name )
25+
2126def get_scored_df_schema (tree , schema , columns , output_probabilities , is_evaluation = False , check_prediction = False ):
2227 check_input_schema (tree , set (column ["name" ] for column in schema ), is_evaluation )
2328 if columns is not None :
2429 schema = update_input_schema (schema , columns )
2530 if output_probabilities :
2631 for value in tree .target_values :
27- schema .append ({'type' : 'double' , 'name' : "proba_" + safe_str (value )})
28- if columns is not None :
29- columns .append ("proba_" + safe_str (value ))
30- schema .append ({'type' : 'string' , 'name' : 'prediction' })
31- if columns is not None :
32- columns .append ("prediction" )
32+ _add_column ('proba_' + safe_str (value ), 'double' , schema , columns )
33+ _add_column ('prediction' , 'string' , schema , columns )
3334 if check_prediction :
34- schema .append ({'type' : 'boolean' , 'name' : 'prediction_correct' })
35- if columns is not None :
36- columns .append ("prediction_correct" )
37- schema .append ({'type' : 'string' , 'name' : 'label' })
38- if columns is not None :
39- columns .append ("label" )
35+ _add_column ('prediction_correct' , 'boolean' , schema , columns )
36+ _add_column ('decision_rule' , 'array' , schema , columns )
37+ _add_column ('node_id' , 'int' , schema , columns )
38+ _add_column ('label' , 'string' , schema , columns )
4039 return schema
4140
4241def get_metric_df_schema (metrics_dict , metrics , recipe_config ):
@@ -78,6 +77,7 @@ def add_scoring_columns(tree, df, output_probabilities, is_evaluation=False, che
7877 df .loc [filtered_df_indices , "prediction_correct" ] = filtered_df [tree .target ] == leaf .prediction
7978 df .loc [label_indices , "label" ] = leaf .label
8079
81- elif leaf .label is not None :
82- filtered_df = tree .get_filtered_df (leaf , df )
83- df .loc [filtered_df .index , "label" ] = leaf .label
80+ filtered_df = tree .get_filtered_df (leaf , df )
81+ df .loc [filtered_df .index , "decision_rule" ] = safe_str (tree .get_decision_rule (leaf_id ))
82+ df .loc [filtered_df .index , "node_id" ] = leaf_id
83+ df .loc [filtered_df .index , "label" ] = leaf .label
0 commit comments