@@ -642,7 +642,7 @@ def plotk(self, kvals, output_dir):
642642 """
643643
644644 # Get number of plots (confirm the kvals are valid before)
645- nplots = len ([x for x in kvals if int (x ) in self .kvals if x >= 1 ])
645+ nplots = len ([x for x in kvals if int (x ) in self .kvals ])
646646
647647 # If no valid plots, issue an error
648648 if not nplots :
@@ -681,8 +681,15 @@ def plotk(self, kvals, output_dir):
681681 # Fetch PlotK object that will be plotted
682682 kobj = self .kvals [k ]
683683
684+ # Transforms the qvals matrix when K = 1. If K > 2, use the
685+ # original matrix
686+ if len (kobj .qvals .shape ) == 1 :
687+ qvals = [kobj .qvals .T ]
688+ else :
689+ qvals = kobj .qvals .T
690+
684691 # Iterate over each meanQ column (corresponding to each cluster)
685- for p , i in enumerate (kobj . qvals . T ):
692+ for p , i in enumerate (qvals ):
686693
687694 # Create Bar trace for each cluster
688695 current_bar = go .Bar (
@@ -819,8 +826,6 @@ def plotk_static(self, kval, output_dir, bw=False, use_ind=False):
819826 with the --ind option, use those labels instead of population labels
820827 """
821828
822- qvalues = self .kvals [kval ].qvals
823-
824829 plt .style .use ("ggplot" )
825830
826831 numinds = self .number_indv
@@ -833,7 +838,16 @@ def plotk_static(self, kval, output_dir, bw=False, use_ind=False):
833838 fig = plt .figure ()
834839 axe = fig .add_subplot (111 , xlim = (- .5 , numinds - .5 ), ylim = (0 , 1 ))
835840
836- for i in range (qvalues .shape [1 ]):
841+ # Transforms the qvals matrix when K = 1. If K > 2, use the
842+ # original matrix
843+ if len (self .kvals [kval ].qvals .shape ) == 1 :
844+ # This list comprehension ensures that the shape of the array
845+ # is (i, 1), where i is the number of samples
846+ qvalues = np .array ([[x ] for x in self .kvals [kval ].qvals ])
847+ else :
848+ qvalues = self .kvals [kval ].qvals
849+
850+ for i in range (kval ):
837851
838852 # Determine color/pattern arguments
839853 kwargs = {}
@@ -894,10 +908,6 @@ def plotk_static(self, kval, output_dir, bw=False, use_ind=False):
894908 plt .yticks ([])
895909 plt .xticks ([])
896910
897- # Add k legend
898- legend = plt .legend (bbox_to_anchor = (1.2 , .5 ), loc = 7 , borderaxespad = 0. )
899- legend .get_frame ().set_facecolor ("white" )
900-
901911 kfile = self .kvals [kval ].file_path
902912 filename = splitext (basename (kfile ))[0 ]
903913 filepath = join (output_dir , filename )
@@ -930,12 +940,11 @@ def main(result_files, fmt, outdir, bestk=None, popfile=None, indfile=None,
930940 # Plot all K files individually
931941 for k , kobj in klist :
932942
933- if k >= 1 and k in filter_k :
943+ if k in filter_k :
934944 klist .plotk ([k ], outdir )
935945 klist .plotk_static (k , outdir , bw = bw , use_ind = use_ind )
936946
937947 # If a sequence of multiple bestk is provided, plot all files in a single
938948 # plot
939949 if bestk :
940- bestk = [x for x in bestk if x >= 1 ]
941950 klist .plotk (bestk , outdir )
0 commit comments