@@ -138,19 +138,21 @@ def get_all_split_indices(self):
138138 [self .cont_splits_train_indices_combined , self .cont_splits_test_indices_combined ]
139139 ]
140140
141-
142141 def plot_distributions (self ):
143142 fig , axs = plt .subplots (
144- nrows = 4 , ncols = self .n_cv ,
145- figsize = ((self .max_pos - self .min_pos ) * 0.1 * self .n_cv , 30 ),
143+ nrows = 4 , ncols = self .n_cv ,
146144 constrained_layout = True
147145 )
146+ fig .set_figwidth (30 )
147+ fig .set_figheight (10 )
148+
148149 poses , counts = self ._get_distribution (sorted (list (self .df .index )))
149150 for i in range (self .n_cv ):
150151 if i == self .n_cv // 2 :
151152 axs [0 , i ].set_title ("All data" )
152153 axs [0 , i ].plot (poses , counts , color = 'black' )
153154 axs [0 , i ].set_ylim (0 , 20 )
155+ axs [0 , i ].set_xlim (self .min_pos - 4 , self .max_pos + 4 )
154156 axs [0 , i ].set_ylabel (f"# Amino acids" )
155157 else :
156158 fig .delaxes (axs [0 , i ])
@@ -163,10 +165,14 @@ def plot_distributions(self):
163165 axs [i_category + 1 , i_split ].plot (pos_test , counts_test )
164166
165167 xticks = list (axs [i_category + 1 , i_split ].get_xticks ())
166- if self .min_pos != 1 and not self .min_pos in xticks :
167- xticks .append (self .min_pos )
168- xticks .append (self .max_pos )
168+ xticks = xticks [1 :- 1 ]
169+ if 0 in xticks :
170+ xticks .remove (0 )
171+ xticks .append (self .min_pos )
172+ xticks .append (self .max_pos )
169173 xticks = sorted (xticks )
174+ if (xticks [- 1 ] - xticks [- 2 ]) < 0.5 * (xticks [2 ] - xticks [1 ]):
175+ xticks .pop ()
170176 axs [i_category + 1 , i_split ].set_xticks (xticks )
171177 if i_category == 0 :
172178 axs [i_category + 1 , i_split ].set_title (f"Split { i_split + 1 } " )
@@ -177,11 +183,12 @@ def plot_distributions(self):
177183 if i_split == self .n_cv // 2 :
178184 axs [i_category + 1 , i_split ].set_title (category )
179185 axs [i_category + 1 , i_split ].set_ylim (0 , 20 )
186+ axs [i_category + 1 , i_split ].set_xlim (self .min_pos - 4 , self .max_pos + 4 )
180187 axs [0 , self .n_cv // 2 ].set_xticks (xticks )
181- #plt.tight_layout()
182188 fig_path = path .abspath (path .splitext (path .basename (self .csv_file ))[0 ] + '_pos_aa_distr.png' )
183189 plt .savefig (fig_path , dpi = 300 )
184190 print (f"Saved figure as { fig_path } " )
191+ plt .close (fig )
185192
186193
187194if __name__ == '__main__' :
0 commit comments