Skip to content

Commit 49b34ef

Browse files
committed
Fixes #56. K=1 are now plotted.
1 parent e216d81 commit 49b34ef

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

structure_threader/plotter/structplot.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ def plotk(self, kvals, output_dir):
642642
"""
643643

644644
# Get number of plots (confirm the kvals are valid before)
645-
nplots = len([x for x in kvals if int(x) in self.kvals if x >= 1])
645+
nplots = len([x for x in kvals if int(x) in self.kvals])
646646

647647
# If no valid plots, issue an error
648648
if not nplots:
@@ -681,8 +681,15 @@ def plotk(self, kvals, output_dir):
681681
# Fetch PlotK object that will be plotted
682682
kobj = self.kvals[k]
683683

684+
# Transforms the qvals matrix when K = 1. If K > 2, use the
685+
# original matrix
686+
if len(kobj.qvals.shape) == 1:
687+
qvals = [kobj.qvals.T]
688+
else:
689+
qvals = kobj.qvals.T
690+
684691
# Iterate over each meanQ column (corresponding to each cluster)
685-
for p, i in enumerate(kobj.qvals.T):
692+
for p, i in enumerate(qvals):
686693

687694
# Create Bar trace for each cluster
688695
current_bar = go.Bar(
@@ -819,8 +826,6 @@ def plotk_static(self, kval, output_dir, bw=False, use_ind=False):
819826
with the --ind option, use those labels instead of population labels
820827
"""
821828

822-
qvalues = self.kvals[kval].qvals
823-
824829
plt.style.use("ggplot")
825830

826831
numinds = self.number_indv
@@ -833,7 +838,16 @@ def plotk_static(self, kval, output_dir, bw=False, use_ind=False):
833838
fig = plt.figure()
834839
axe = fig.add_subplot(111, xlim=(-.5, numinds - .5), ylim=(0, 1))
835840

836-
for i in range(qvalues.shape[1]):
841+
# Transforms the qvals matrix when K = 1. If K > 2, use the
842+
# original matrix
843+
if len(self.kvals[kval].qvals.shape) == 1:
844+
# This list comprehension ensures that the shape of the array
845+
# is (i, 1), where i is the number of samples
846+
qvalues = np.array([[x] for x in self.kvals[kval].qvals])
847+
else:
848+
qvalues = self.kvals[kval].qvals
849+
850+
for i in range(kval):
837851

838852
# Determine color/pattern arguments
839853
kwargs = {}
@@ -894,10 +908,6 @@ def plotk_static(self, kval, output_dir, bw=False, use_ind=False):
894908
plt.yticks([])
895909
plt.xticks([])
896910

897-
# Add k legend
898-
legend = plt.legend(bbox_to_anchor=(1.2, .5), loc=7, borderaxespad=0.)
899-
legend.get_frame().set_facecolor("white")
900-
901911
kfile = self.kvals[kval].file_path
902912
filename = splitext(basename(kfile))[0]
903913
filepath = join(output_dir, filename)
@@ -930,12 +940,11 @@ def main(result_files, fmt, outdir, bestk=None, popfile=None, indfile=None,
930940
# Plot all K files individually
931941
for k, kobj in klist:
932942

933-
if k >= 1 and k in filter_k:
943+
if k in filter_k:
934944
klist.plotk([k], outdir)
935945
klist.plotk_static(k, outdir, bw=bw, use_ind=use_ind)
936946

937947
# If a sequence of multiple bestk is provided, plot all files in a single
938948
# plot
939949
if bestk:
940-
bestk = [x for x in bestk if x >= 1]
941950
klist.plotk(bestk, outdir)

structure_threader/structure_threader.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,6 @@ def create_plts(resultsdir, wrapped_prog, Ks, bestk, arg):
228228
:param resultsdir: path to results directory
229229
"""
230230

231-
plt_list = [x for x in Ks if x != 1] # Don't plot K=1
232-
233231
outdir = os.path.join(resultsdir, "plots")
234232
if not os.path.exists(outdir):
235233
os.mkdir(outdir)
@@ -243,15 +241,15 @@ def create_plts(resultsdir, wrapped_prog, Ks, bestk, arg):
243241
file_to_plot = str(randrange(1, arg.replicates + 1))
244242
plt_files = [os.path.join(resultsdir, "str_K") + str(i) + "_rep" +
245243
file_to_plot + "_f"
246-
for i in plt_list]
244+
for i in Ks]
247245
elif wrapped_prog == "maverick":
248246
plt_files = [os.path.join(os.path.join(resultsdir, "mav_K" + str(i)),
249247
"outputQmatrix_ind_K" + str(i) + ".csv")
250-
for i in plt_list]
248+
for i in Ks]
251249

252250
else:
253251
plt_files = [os.path.join(resultsdir, "fS_run_K.") + str(i) + ".meanQ"
254-
for i in plt_list]
252+
for i in Ks]
255253

256254
sp.main(plt_files, wrapped_prog, outdir, bestk=bestk, popfile=arg.popfile,
257255
indfile=arg.indfile, bw=arg.blacknwhite, use_ind=arg.use_ind)

0 commit comments

Comments
 (0)