@@ -2379,110 +2379,65 @@ std::map<double, double> APLRRegressor::get_coefficient_shape_function(size_t pr
23792379        return  coefficient_shape_function;
23802380
23812381    std::vector<double > split_points;
2382-     split_points.reserve (relevant_term_indexes.size ());
2383-     double  linear_term_combined_effect{0.0 };
2382+     split_points.reserve (relevant_term_indexes.size () * 4 );
23842383    for  (auto  &relevant_term_index : relevant_term_indexes)
23852384    {
23862385        bool  split_point_exits{std::isfinite (terms[relevant_term_index].split_point )};
23872386        if  (split_point_exits)
23882387        {
23892388            split_points.push_back (terms[relevant_term_index].split_point );
23902389        }
2391-         else 
2390+         for  ( auto  &given_term : terms[relevant_term_index]. given_terms ) 
23922391        {
2393-             linear_term_combined_effect += terms[relevant_term_index].coefficient ;
2392+             bool  split_point_exits{std::isfinite (given_term.split_point )};
2393+             if  (split_point_exits)
2394+             {
2395+                 split_points.push_back (given_term.split_point );
2396+             }
23942397        }
23952398    }
2396-     split_points = remove_duplicate_elements_from_vector (split_points);
2397-     split_points.shrink_to_fit ();
2398- 
23992399    bool  no_split_points{split_points.size () == 0 };
24002400    if  (no_split_points)
24012401    {
2402-         coefficient_shape_function[ 0.0 ] = linear_term_combined_effect ;
2403-         return  coefficient_shape_function ;
2402+         split_points. push_back ( 0 ) ;
2403+         split_points. push_back ( 1 ) ;
24042404    }
2405-     double  increment_around_split_points ;
2405+     split_points =  remove_duplicate_elements_from_vector (split_points) ;
24062406    bool  one_split_point{split_points.size () == 1 };
24072407    if  (one_split_point)
24082408    {
2409-         increment_around_split_points = split_points[0 ] / DIVISOR_IN_GET_COEFFICIENT_SHAPE_FUNCTION;
2409+         split_points.push_back (split_points[0 ] - 1 );
2410+         split_points = remove_duplicate_elements_from_vector (split_points);
24102411    }
2411-     else 
2412+ 
2413+     VectorXd split_point_increments{VectorXd (split_points.size () - 1 )};
2414+     for  (Eigen::Index i = 0 ; i < split_point_increments.size (); ++i)
24122415    {
2413-         std::sort (split_points.begin (), split_points.end ());
2414-         VectorXd split_point_increments{VectorXd (split_points.size () - 1 )};
2415-         for  (Eigen::Index i = 0 ; i < split_point_increments.size (); ++i)
2416-         {
2417-             split_point_increments[i] = split_points[i + 1 ] - split_points[i];
2418-         }
2419-         double  minimum_split_point_increment{split_point_increments.minCoeff ()};
2420-         increment_around_split_points = minimum_split_point_increment / DIVISOR_IN_GET_COEFFICIENT_SHAPE_FUNCTION;
2416+         split_point_increments[i] = split_points[i + 1 ] - split_points[i];
24212417    }
2418+     double  minimum_split_point_increment{split_point_increments.minCoeff ()};
2419+     double  increment_around_split_points{minimum_split_point_increment / DIVISOR_IN_GET_COEFFICIENT_SHAPE_FUNCTION};
24222420
2423-     for  (size_t  i = 0 ; i < relevant_term_indexes.size (); ++i)
2421+     size_t  num_split_points{split_points.size ()};
2422+     for  (size_t  i = 0 ; i < num_split_points; ++i)
24242423    {
2425-         bool  split_point_exits{std::isfinite (terms[relevant_term_indexes[i]].split_point )};
2426-         if  (split_point_exits)
2427-         {
2428-             coefficient_shape_function[terms[relevant_term_indexes[i]].split_point  - increment_around_split_points] = linear_term_combined_effect;
2429-             coefficient_shape_function[terms[relevant_term_indexes[i]].split_point ] = linear_term_combined_effect;
2430-             coefficient_shape_function[terms[relevant_term_indexes[i]].split_point  + increment_around_split_points] = linear_term_combined_effect;
2431-         }
2424+         split_points.push_back (split_points[i] - increment_around_split_points);
2425+         split_points.push_back (split_points[i] + increment_around_split_points);
24322426    }
2427+     split_points.push_back (split_points[split_points.size () - 1 ] + increment_around_split_points);
2428+     split_points = remove_duplicate_elements_from_vector (split_points);
2429+     split_points.shrink_to_fit ();
24332430
2434-     for  (size_t  i = 0 ; i < relevant_term_indexes.size (); ++i)
2431+     MatrixXd X{MatrixXd::Constant (split_points.size (), number_of_base_terms, 0 )};
2432+     for  (size_t  i = 0 ; i < split_points.size (); ++i)
24352433    {
2436-         bool  split_point_exits{std::isfinite (terms[relevant_term_indexes[i]].split_point )};
2437-         if  (split_point_exits)
2438-         {
2439-             if  (terms[relevant_term_indexes[i]].direction_right )
2440-             {
2441-                 for  (auto  &key : coefficient_shape_function)
2442-                 {
2443-                     bool  key_split_point_is_higher{std::isgreater (key.first , terms[relevant_term_indexes[i]].split_point )};
2444-                     bool  key_split_point_is_not_too_high{true };
2445-                     for  (auto  &given_term : terms[relevant_term_indexes[i]].given_terms )
2446-                     {
2447-                         if  (given_term.direction_right  != terms[relevant_term_indexes[i]].direction_right )
2448-                         {
2449-                             if  (std::isgreater (key.first , given_term.split_point ))
2450-                             {
2451-                                 key_split_point_is_not_too_high = false ;
2452-                                 break ;
2453-                             }
2454-                         }
2455-                     }
2456-                     if  (key_split_point_is_higher && key_split_point_is_not_too_high)
2457-                     {
2458-                         key.second  += terms[relevant_term_indexes[i]].coefficient ;
2459-                     }
2460-                 }
2461-             }
2462-             else 
2463-             {
2464-                 for  (auto  &key : coefficient_shape_function)
2465-                 {
2466-                     bool  key_split_point_is_lower{std::isless (key.first , terms[relevant_term_indexes[i]].split_point )};
2467-                     bool  key_split_point_is_not_too_low{true };
2468-                     for  (auto  &given_term : terms[relevant_term_indexes[i]].given_terms )
2469-                     {
2470-                         if  (given_term.direction_right  != terms[relevant_term_indexes[i]].direction_right )
2471-                         {
2472-                             if  (std::isless (key.first , given_term.split_point ))
2473-                             {
2474-                                 key_split_point_is_not_too_low = false ;
2475-                                 break ;
2476-                             }
2477-                         }
2478-                     }
2479-                     if  (!key_split_point_is_lower)
2480-                         break ;
2481-                     else  if  (key_split_point_is_not_too_low)
2482-                         key.second  += terms[relevant_term_indexes[i]].coefficient ;
2483-                 }
2484-             }
2485-         }
2434+         X.col (predictor_index)[i] = split_points[i];
2435+     }
2436+ 
2437+     VectorXd contribution_to_linear_predictor{calculate_local_contribution_from_selected_terms (X, {predictor_index})};
2438+     for  (size_t  i = 0 ; i < split_points.size () - 1 ; ++i)
2439+     {
2440+         coefficient_shape_function[split_points[i]] = (contribution_to_linear_predictor[i + 1 ] - contribution_to_linear_predictor[i]) / (split_points[i + 1 ] - split_points[i]);
24862441    }
24872442
24882443    return  coefficient_shape_function;
@@ -2494,21 +2449,8 @@ std::vector<size_t> APLRRegressor::compute_relevant_term_indexes(size_t predicto
24942449    relevant_term_indexes.reserve (terms.size ());
24952450    for  (size_t  i = 0 ; i < terms.size (); ++i)
24962451    {
2497-         bool  predictor_index_is_base_term{terms[i].base_term  == predictor_index};
2498-         if  (predictor_index_is_base_term)
2499-         {
2500-             bool  no_interactions_with_other_base_terms{true };
2501-             for  (auto  &given_term : terms[i].given_terms )
2502-             {
2503-                 if  (given_term.base_term  != predictor_index)
2504-                 {
2505-                     no_interactions_with_other_base_terms = false ;
2506-                     break ;
2507-                 }
2508-             }
2509-             if  (no_interactions_with_other_base_terms)
2510-                 relevant_term_indexes.push_back (i);
2511-         }
2452+         if  (terms[i].term_uses_just_these_predictors ({predictor_index}))
2453+             relevant_term_indexes.push_back (i);
25122454    }
25132455    relevant_term_indexes.shrink_to_fit ();
25142456    return  relevant_term_indexes;
0 commit comments