@@ -234,6 +234,8 @@ size_t TreeLikelihood_initialize_gradient(Model *self, int flags){
234234 int prepare_branch_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_BRANCH_MODEL ;
235235 int prepare_substitution_model_unconstrained = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED ;
236236 int prepare_substitution_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL ;
237+ int prepare_substitution_model_rates = prepare_substitution_model | (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_RATES );
238+ int prepare_substitution_model_frequencies = prepare_substitution_model | (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_FREQUENCIES );
237239
238240 if (prepare_substitution_model_unconstrained && prepare_substitution_model ){
239241 fprintf (stderr , "Can only request unconstrained and constrained gradient at the same time\n" );
@@ -280,12 +282,18 @@ size_t TreeLikelihood_initialize_gradient(Model *self, int flags){
280282 gradient_length += tlk -> m -> simplex -> K - 1 ;
281283 tlk -> include_root_freqs = false;
282284 }
283- else if (prepare_substitution_model ) {
284- gradient_length += tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
285- gradient_length += tlk -> m -> simplex -> K ;
286- tlk -> m -> grad_wrt_reparam = false;
287- tlk -> include_root_freqs = false;
288- }
285+ else {
286+ if (prepare_substitution_model_rates ) {
287+ gradient_length += tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
288+ tlk -> m -> grad_wrt_reparam = false;
289+ tlk -> include_root_freqs = false;
290+ }
291+ if (prepare_substitution_model_frequencies ) {
292+ gradient_length += tlk -> m -> simplex -> K ;
293+ tlk -> m -> grad_wrt_reparam = false;
294+ tlk -> include_root_freqs = false;
295+ }
296+ }
289297
290298 if (tlk -> gradient == NULL ){
291299 tlk -> gradient = calloc (gradient_length , sizeof (double ));
@@ -3002,6 +3010,29 @@ void gradient_PMatrix(SingleTreeLikelihood* tlk, const double* pattern_likelihoo
30023010 }
30033011}
30043012
3013+ void gradient_PMatrix_rates (SingleTreeLikelihood * tlk , const double * pattern_likelihoods , double * gradient ){
3014+ size_t parameter_count = tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
3015+ if (tlk -> m -> rates_simplex != NULL && tlk -> m -> grad_wrt_reparam ){
3016+ parameter_count -- ;
3017+ }
3018+ for (size_t i = 0 ; i < parameter_count ; i ++ ){
3019+ gradient [i ] = calculate_dlnl_dQ (tlk , i , pattern_likelihoods );
3020+ }
3021+ }
3022+
3023+ void gradient_PMatrix_frequencies (SingleTreeLikelihood * tlk , const double * pattern_likelihoods , double * gradient ){
3024+ size_t start = tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
3025+ size_t parameter_count = tlk -> m -> simplex -> K ;
3026+ if (tlk -> m -> grad_wrt_reparam ) parameter_count -- ;
3027+
3028+ if (tlk -> m -> rates_simplex != NULL && tlk -> m -> grad_wrt_reparam ){
3029+ start -- ;
3030+ }
3031+ for (size_t i = 0 ; i < parameter_count ; i ++ ){
3032+ gradient [i ] = calculate_dlnl_dQ (tlk , i + start , pattern_likelihoods );
3033+ }
3034+ }
3035+
30053036void gradient_branch_length_from_cat (SingleTreeLikelihood * tlk , const double * cat_branch_gradient , double * gradient ){
30063037 size_t nodeCount = Tree_node_count (tlk -> tree );
30073038 size_t catCount = tlk -> sm -> cat_count ;
@@ -3073,7 +3104,9 @@ void SingleTreeLikelihood_gradient( SingleTreeLikelihood *tlk, double* grads ){
30733104 bool prepare_tree = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_TREE ;
30743105 bool prepare_site_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SITE_MODEL ;
30753106 bool prepare_branch_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_BRANCH_MODEL ;
3076- bool prepare_subsitution_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED || tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL ;
3107+ bool prepare_substitution_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED || tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL ;
3108+ bool prepare_substitution_model_rates = prepare_substitution_model || (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_RATES );
3109+ bool prepare_substitution_model_frequencies = prepare_substitution_model || (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_FREQUENCIES );
30773110
30783111 size_t nodeCount = Tree_node_count (tlk -> tree );
30793112 size_t catCount = tlk -> sm -> cat_count ;
@@ -3167,9 +3200,19 @@ void SingleTreeLikelihood_gradient( SingleTreeLikelihood *tlk, double* grads ){
31673200 offset += Parameters_count (tlk -> bm -> rates );
31683201 }
31693202
3170- if (prepare_subsitution_model ){
3203+ if (prepare_substitution_model ){
31713204 gradient_PMatrix (tlk , pattern_likelihoods , grads + offset );
31723205 }
3206+ else {
3207+ if (prepare_substitution_model_rates ){
3208+ gradient_PMatrix_rates (tlk , pattern_likelihoods , grads + offset );
3209+ offset += tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
3210+ if (tlk -> m -> grad_wrt_reparam ) offset -- ;
3211+ }
3212+ if (prepare_substitution_model_frequencies ){
3213+ gradient_PMatrix_frequencies (tlk , pattern_likelihoods , grads + offset );
3214+ }
3215+ }
31733216
31743217 free (branch_lengths );
31753218 free (cat_branch_gradient );
0 commit comments