@@ -891,24 +891,23 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split
891
891
this ->param_ .num_nodes = this ->p_mt_tree_ ->Size ();
892
892
}
893
893
894
- void RegTree::ExpandCategorical (bst_node_t nid, bst_feature_t split_index,
895
- common::Span<const uint32_t > split_cat, bool default_left,
896
- bst_float base_weight, bst_float left_leaf_weight,
897
- bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
898
- float left_sum, float right_sum) {
894
+ void RegTree::ExpandCategorical (bst_node_t nidx, bst_feature_t split_index,
895
+ common::Span<common::KCatBitField::value_type> split_cat,
896
+ bool default_left, bst_float base_weight,
897
+ bst_float left_leaf_weight, bst_float right_leaf_weight,
898
+ bst_float loss_change, float sum_hess, float left_sum,
899
+ float right_sum) {
899
900
CHECK (!IsMultiTarget ());
900
- this ->ExpandNode (nid, split_index, std::numeric_limits<float >::quiet_NaN (),
901
- default_left, base_weight,
902
- left_leaf_weight, right_leaf_weight, loss_change, sum_hess,
903
- left_sum, right_sum);
901
+ this ->ExpandNode (nidx, split_index, DftBadValue (), default_left, base_weight, left_leaf_weight,
902
+ right_leaf_weight, loss_change, sum_hess, left_sum, right_sum);
904
903
905
904
size_t orig_size = split_categories_.size ();
906
905
this ->split_categories_ .resize (orig_size + split_cat.size ());
907
906
std::copy (split_cat.data (), split_cat.data () + split_cat.size (),
908
907
split_categories_.begin () + orig_size);
909
- this ->split_types_ .at (nid ) = FeatureType::kCategorical ;
910
- this ->split_categories_segments_ .at (nid ).beg = orig_size;
911
- this ->split_categories_segments_ .at (nid ).size = split_cat.size ();
908
+ this ->split_types_ .at (nidx ) = FeatureType::kCategorical ;
909
+ this ->split_categories_segments_ .at (nidx ).beg = orig_size;
910
+ this ->split_categories_segments_ .at (nidx ).size = split_cat.size ();
912
911
}
913
912
914
913
void RegTree::Load (dmlc::Stream* fi) {
0 commit comments