@@ -891,24 +891,23 @@ void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split
891891 this ->param_ .num_nodes = this ->p_mt_tree_ ->Size ();
892892}
893893
894- void RegTree::ExpandCategorical (bst_node_t nid, bst_feature_t split_index,
895- common::Span<const uint32_t > split_cat, bool default_left,
896- bst_float base_weight, bst_float left_leaf_weight,
897- bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
898- float left_sum, float right_sum) {
894+ void RegTree::ExpandCategorical (bst_node_t nidx, bst_feature_t split_index,
895+ common::Span<common::KCatBitField::value_type> split_cat,
896+ bool default_left, bst_float base_weight,
897+ bst_float left_leaf_weight, bst_float right_leaf_weight,
898+ bst_float loss_change, float sum_hess, float left_sum,
899+ float right_sum) {
899900 CHECK (!IsMultiTarget ());
900- this ->ExpandNode (nid, split_index, std::numeric_limits<float >::quiet_NaN (),
901- default_left, base_weight,
902- left_leaf_weight, right_leaf_weight, loss_change, sum_hess,
903- left_sum, right_sum);
901+ this ->ExpandNode (nidx, split_index, DftBadValue (), default_left, base_weight, left_leaf_weight,
902+ right_leaf_weight, loss_change, sum_hess, left_sum, right_sum);
904903
905904 size_t orig_size = split_categories_.size ();
906905 this ->split_categories_ .resize (orig_size + split_cat.size ());
907906 std::copy (split_cat.data (), split_cat.data () + split_cat.size (),
908907 split_categories_.begin () + orig_size);
909- this ->split_types_ .at (nid ) = FeatureType::kCategorical ;
910- this ->split_categories_segments_ .at (nid ).beg = orig_size;
911- this ->split_categories_segments_ .at (nid ).size = split_cat.size ();
908+ this ->split_types_ .at (nidx ) = FeatureType::kCategorical ;
909+ this ->split_categories_segments_ .at (nidx ).beg = orig_size;
910+ this ->split_categories_segments_ .at (nidx ).size = split_cat.size ();
912911}
913912
914913void RegTree::Load (dmlc::Stream* fi) {
0 commit comments