@@ -797,9 +797,13 @@ def plotk(self, kvals, output_dir):
797797 with open (filepath , "w" ) as fh :
798798 fh .write (ploty_html (pdiv ))
799799
800- def plotk_static (self , kval , output_dir ):
800+ def plotk_static (self , kval , output_dir , bw = False ):
801801 """
802-
802+ Generates a structure plot in svg format.
803+ :param kval: (int) Must match the K value from self.kvals
804+ :param output_dir: (string) Path of the plot file
805+ :param bw: If True, plots will be generated with patterns instead of
806+ colors to distinguish k groups.
803807 """
804808
805809 qvalues = self .kvals [kval ].qvals
@@ -817,19 +821,36 @@ def plotk_static(self, kval, output_dir):
817821 axe = fig .add_subplot (111 , xlim = (- .5 , numinds - .5 ), ylim = (0 , 1 ))
818822
819823 for i in range (qvalues .shape [1 ]):
820- # Get bar color. If K exceeds the 12 colors, generate random color
821- try :
822- clr = clist [i ]
823- except IndexError :
824- clr = np .random .rand (3 , 1 )
824+
825+ # Determine color/pattern arguments
826+ kwargs = {}
827+
828+ # Use colors todistinguish k groups
829+ if not bw :
830+ # Get bar color. If K exceeds the 12 colors, generate random
831+ # color
832+ try :
833+ clr = clist [i ]
834+ except IndexError :
835+ clr = np .random .rand (3 , 1 )
836+
837+ kwargs ["facecolor" ] = clr
838+ kwargs ["edgecolor" ] = "grey"
839+
840+ else :
841+ grey_rgb = [(float ((i + 1 )) / (float (qvalues .shape [1 ]) + 1 ))
842+ for _ in range (3 )]
843+
844+ kwargs ["facecolor" ] = grey_rgb
845+ kwargs ["edgecolor" ] = "white"
825846
826847 if i == 0 :
827- axe .bar (range (numinds ), qvalues [:, i ], facecolor = clr ,
828- edgecolor = "grey" , width = 1 )
848+ axe .bar (range (numinds ), qvalues [:, i ],
849+ width = 1 , label = "K {}" . format ( i + 1 ), ** kwargs )
829850 former_q = qvalues [:, i ]
830851 else :
831852 axe .bar (range (numinds ), qvalues [:, i ], bottom = former_q ,
832- facecolor = clr , edgecolor = "grey" , width = 1 )
853+ width = 1 , label = "K {}" . format ( i + 1 ), ** kwargs )
833854 former_q = former_q + qvalues [:, i ]
834855
835856 # Annotate population info
@@ -860,15 +881,23 @@ def plotk_static(self, kval, output_dir):
860881 plt .yticks ([])
861882 plt .xticks ([])
862883
884+ # Add k legend
885+ legend = plt .legend (bbox_to_anchor = (1.2 , .5 ), loc = 7 , borderaxespad = 0. )
886+ legend .get_frame ().set_facecolor ("white" )
887+
863888 kfile = self .kvals [kval ].file_path
864889 filename = splitext (basename (kfile ))[0 ]
865890 filepath = join (output_dir , filename )
866891
867892 plt .savefig ("{}.svg" .format (filepath ), bbox_inches = "tight" )
868893
894+ # Clear plot object
895+ plt .clf ()
896+ plt .close ()
897+
869898
870899def main (result_files , fmt , outdir , bestk = None , popfile = None , indfile = None ,
871- filter_k = None ):
900+ filter_k = None , bw = False ):
872901 """
873902 Wrapper function that generates one plot for each K value.
874903 :return:
@@ -890,7 +919,7 @@ def main(result_files, fmt, outdir, bestk=None, popfile=None, indfile=None,
890919
891920 if k >= 1 and k in filter_k :
892921 klist .plotk ([k ], outdir )
893- klist .plotk_static (k , outdir )
922+ klist .plotk_static (k , outdir , bw = bw )
894923
895924 # If a sequence of multiple bestk is provided, plot all files in a single
896925 # plot
0 commit comments