Skip to content

Commit b8c3dbf

Browse files
author
jiyuang
committed
add an example for plot-tools
1 parent 36dbead commit b8c3dbf

File tree

10 files changed

+64649
-75
lines changed

10 files changed

+64649
-75
lines changed

tools/plot-tools/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ python setup.py install
2121

2222
## Usage
2323
There are two ways to use this tool:
24-
1. Specify parameters in `band.py` or `dos.py` directly, and then `python band.py` or `python dos.py`. And you can also import module in your own script e.g. `from abacus_plot.band import Band`
24+
1. Specify parameters in `band.py` or `dos.py` directly, and then `python band.py` or `python dos.py`. And you can also import module in your own script e.g. `from abacus_plot.band import Band`. (Recommend)
2525
2. Command-line tools are also supported in this tool. In this way, you need prepare an input file and execute some commands (see below). You can use `abacus-plot -h` to check command-line information
2626

2727
### Band Structure

tools/plot-tools/abacus_plot/band.py

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def __init__(self, bandfile: Union[PathLike, Sequence[PathLike]] = None, kptfile
3737
if self.kptfile:
3838
self.kpt = read_kpt(kptfile)
3939
self.k_index = list(map(int, self.k_index))
40+
if self.kpt:
41+
self._kzip = self.kpt.label_special_k
42+
else:
43+
self._kzip = self.k_index
4044

4145
@classmethod
4246
def read(cls, filename: PathLike):
@@ -168,14 +172,14 @@ def plot_data(cls,
168172
ax: axes.Axes,
169173
x: Sequence,
170174
y: Sequence,
171-
index: Sequence,
175+
kzip: Sequence,
172176
efermi: float = 0,
173177
energy_range: Sequence[float] = [],
174178
**kwargs):
175179
"""Plot band structure
176180
177181
:params x, y: x-axis and y-axis coordinates
178-
:params index: special k-points label and its index in data file
182+
:params kzip: special k-points label and its k-index in data file
179183
:params efermi: Fermi level in unit eV
180184
:params energy_range: range of energy to plot, its length equals to two
181185
"""
@@ -189,7 +193,7 @@ def plot_data(cls,
189193

190194
bandplot.ax.plot(kpoints, energy, lw=bandplot._lw, color=bandplot._color,
191195
label=bandplot._label, linestyle=bandplot._linestyle)
192-
bandplot._set_figure(index, energy_range)
196+
bandplot._set_figure(kzip, energy_range)
193197

194198
def plot(self,
195199
fig: Figure,
@@ -231,11 +235,7 @@ def plot(self,
231235
bandplot.ax.plot(self.k_index, band,
232236
lw=bandplot._lw, color=bandplot._color, label=bandplot._label, linestyle=bandplot._linestyle)
233237

234-
if self.kpt:
235-
index = self.kpt.label_special_k
236-
else:
237-
index = self.k_index
238-
bandplot._set_figure(index, energy_range)
238+
bandplot._set_figure(self._kzip, energy_range)
239239

240240
return bandplot
241241

@@ -253,16 +253,16 @@ def __init__(self, fig: Figure, ax: axes.Axes, **kwargs) -> None:
253253
self._linestyle = kwargs.pop('linestyle', 'solid')
254254
self.plot_params = kwargs
255255

256-
def _set_figure(self, index, range: Sequence):
256+
def _set_figure(self, kzip, range: Sequence):
257257
"""set figure and axes for plotting
258258
259-
:params index: dict of label of points of x-axis and its index in data file. Range of x-axis based on index.value()
259+
:params kzip: dict of label of points of x-axis and its index in data file. Range of x-axis based on kzip.value()
260260
:params range: range of y-axis
261261
"""
262262

263263
keys = []
264264
values = []
265-
for t in index:
265+
for t in kzip:
266266
if isinstance(t, tuple):
267267
keys.append(t[0])
268268
values.append(t[1])
@@ -355,6 +355,10 @@ def __init__(self, bandfile: Union[PathLike, Sequence[PathLike]] = None, kptfile
355355
if self.kptfile:
356356
self.kpt = read_kpt(kptfile)
357357
self.k_index = list(map(int, self.k_index))
358+
if self.kpt:
359+
self._kzip = self.kpt.label_special_k
360+
else:
361+
self._kzip = self.k_index
358362

359363
def _check_energy(self, energy):
360364
assert energy.shape[0] == self.nkpoints, "The dimension of band structure dismatches with the number of k-points."
@@ -421,7 +425,7 @@ def read(cls, filename: PathLike):
421425

422426
return nspin, norbitals, eunit, nbands, nkpoints, k_index, energy, orbitals
423427

424-
def _write(self, species: Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], keyname='', file_dir:PathLike=''):
428+
def _write(self, species: Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], keyname='', file_dir: PathLike = ''):
425429
"""Write parsed projected bands data to files
426430
427431
Args:
@@ -581,10 +585,11 @@ def _plot(self,
581585
BandPlot object: for manually plotting picture with bandplot.ax
582586
"""
583587

584-
def _seg_plot(bandplot, lc, index, file_dir, name):
588+
def _seg_plot(bandplot, lc, file_dir, name):
585589
cbar = bandplot.fig.colorbar(lc, ax=bandplot.ax)
586-
bandplot._set_figure(index, energy_range)
587-
bandplot.fig.savefig(file_dir/f'{keyname}-{bandplot._label}.pdf', dpi=400)
590+
bandplot._set_figure(self._kzip, energy_range)
591+
bandplot.fig.savefig(
592+
file_dir/f'{keyname}-{bandplot._label}.pdf', dpi=400)
588593
cbar.remove()
589594
plt.cla()
590595

@@ -595,15 +600,10 @@ def _seg_plot(bandplot, lc, index, file_dir, name):
595600
file_dir = Path(f"{outdir}", f"PBAND{out_index}_FIG")
596601
file_dir.mkdir(exist_ok=True)
597602

598-
if self.kpt:
599-
index = self.kpt.label_special_k
600-
else:
601-
index = self.k_index
602-
603603
if not species:
604604
bandplot = BandPlot(fig, ax, **kwargs)
605605
bandplot = super().plot(fig, ax, efermi, energy_range, shift, **kwargs)
606-
bandplot._set_figure(index, energy_range)
606+
bandplot._set_figure(self._kzip, energy_range)
607607

608608
return bandplot
609609

@@ -613,15 +613,19 @@ def _seg_plot(bandplot, lc, index, file_dir, name):
613613
bandplot = BandPlot(fig, ax, **kwargs)
614614
bandplot._label = elem
615615
for ib in range(self.nbands):
616-
points = np.array((self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2)
617-
segments = np.concatenate([points[:-1], points[1:]], axis=1)
618-
norm = Normalize(vmin=wei[elem][0:, ib].min(), vmax=wei[elem][0:, ib].max())
619-
lc = LineCollection(segments, cmap=plt.get_cmap(cmap), norm=norm)
616+
points = np.array(
617+
(self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2)
618+
segments = np.concatenate(
619+
[points[:-1], points[1:]], axis=1)
620+
norm = Normalize(
621+
vmin=wei[elem][0:, ib].min(), vmax=wei[elem][0:, ib].max())
622+
lc = LineCollection(
623+
segments, cmap=plt.get_cmap(cmap), norm=norm)
620624
lc.set_array(wei[elem][0:, ib])
621625
lc.set_label(bandplot._label)
622626
bandplot.ax.add_collection(lc)
623-
624-
_seg_plot(bandplot, lc, index, file_dir, name=f'{elem}')
627+
628+
_seg_plot(bandplot, lc, file_dir, name=f'{elem}')
625629
bandplots.append(bandplot)
626630
return bandplots
627631

@@ -638,30 +642,40 @@ def _seg_plot(bandplot, lc, index, file_dir, name):
638642
m_index = int(mag)
639643
bandplot._label = f"{elem}-{get_angular_momentum_name(l_index, m_index)}"
640644
for ib in range(self.nbands):
641-
points = np.array((self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2)
642-
segments = np.concatenate([points[:-1], points[1:]], axis=1)
643-
norm = Normalize(vmin=wei[elem][ang][mag][0:, ib].min(), vmax=wei[elem][ang][mag][0:, ib].max())
644-
lc = LineCollection(segments, cmap=plt.get_cmap(cmap), norm=norm)
645+
points = np.array(
646+
(self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2)
647+
segments = np.concatenate(
648+
[points[:-1], points[1:]], axis=1)
649+
norm = Normalize(vmin=wei[elem][ang][mag][0:, ib].min(
650+
), vmax=wei[elem][ang][mag][0:, ib].max())
651+
lc = LineCollection(
652+
segments, cmap=plt.get_cmap(cmap), norm=norm)
645653
lc.set_array(wei[elem][ang][mag][0:, ib])
646654
lc.set_label(bandplot._label)
647655
bandplot.ax.add_collection(lc)
648-
649-
_seg_plot(bandplot, lc, index, elem_file_dir, name=f'{elem}_{ang}_{mag}')
656+
657+
_seg_plot(bandplot, lc, elem_file_dir,
658+
name=f'{elem}_{ang}_{mag}')
650659
bandplots.append(bandplot)
651660

652661
else:
653662
bandplot = BandPlot(fig, ax, **kwargs)
654663
bandplot._label = f"{elem}-{get_angular_momentum_label(l_index)}"
655664
for ib in range(self.nbands):
656-
points = np.array((self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2)
657-
segments = np.concatenate([points[:-1], points[1:]], axis=1)
658-
norm = Normalize(vmin=wei[elem][ang][0:, ib].min(), vmax=wei[elem][ang][0:, ib].max())
659-
lc = LineCollection(segments, cmap=plt.get_cmap(cmap), norm=norm)
665+
points = np.array(
666+
(self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2)
667+
segments = np.concatenate(
668+
[points[:-1], points[1:]], axis=1)
669+
norm = Normalize(vmin=wei[elem][ang][0:, ib].min(
670+
), vmax=wei[elem][ang][0:, ib].max())
671+
lc = LineCollection(
672+
segments, cmap=plt.get_cmap(cmap), norm=norm)
660673
lc.set_array(wei[elem][ang][0:, ib])
661674
lc.set_label(bandplot._label)
662675
bandplot.ax.add_collection(lc)
663-
664-
_seg_plot(bandplot, lc, index, elem_file_dir, name=f'{elem}_{ang}')
676+
677+
_seg_plot(bandplot, lc, elem_file_dir,
678+
name=f'{elem}_{ang}')
665679
bandplots.append(bandplot)
666680

667681
return bandplots
@@ -742,19 +756,19 @@ def plot(self,
742756
if __name__ == "__main__":
743757
import matplotlib.pyplot as plt
744758
from pathlib import Path
745-
parent = Path(r"C:\Users\YY.Ji\Desktop")
759+
parent = Path(r"C:\Users\YY.Ji\Desktop\Si")
746760
name = "PBANDS_1"
747761
path = parent/name
748762
fig, ax = plt.subplots(figsize=(12, 6))
749-
energy_range = [-5, 6]
750-
efermi = 4.417301755850272
763+
energy_range = [-5, 7]
764+
efermi = 6.585653952007503
751765
shift = False
752766
#species = {"Ag": [2], "Cl": [1], "In": [0]}
753-
atom_index = {8: {2: [1, 2]}, 4: {2: [1, 2]}, 10: [1, 2]}
767+
atom_index = {1: {1: [0, 1]}}
754768
pband = PBand(str(path))
755769

756-
# if you want to specify `species` or `index`, you need to
757-
# set `species=species` or `index=index` in the following two functions
770+
# if you want to specify `species` or `index`, you need to
771+
# set `species=species` or `index=index` in the following two functions
758772
pband.plot(fig, ax, atom_index=atom_index, efermi=efermi,
759-
energy_range=energy_range, shift=shift)
773+
energy_range=energy_range, shift=shift)
760774
pband.write(atom_index=atom_index)

tools/plot-tools/abacus_plot/dos.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _set_figure(self, energy_range: Sequence = [], dos_range: Sequence = [], not
154154
class TDOS(DOS):
155155
"""Parse total DOS data"""
156156

157-
def __init__(self, tdosfile: PathLike=None) -> None:
157+
def __init__(self, tdosfile: PathLike = None) -> None:
158158
super().__init__()
159159
self.tdosfile = tdosfile
160160
self._read()
@@ -202,7 +202,7 @@ def plot(self, fig: Figure, ax: Union[axes.Axes, Sequence[axes.Axes]], efermi: f
202202
class PDOS(DOS):
203203
"""Parse partial DOS data"""
204204

205-
def __init__(self, pdosfile: PathLike=None) -> None:
205+
def __init__(self, pdosfile: PathLike = None) -> None:
206206
super().__init__()
207207
self.pdosfile = pdosfile
208208
self._read()
@@ -391,6 +391,8 @@ def _parial_plot(self,
391391
Returns:
392392
DOSPlot object: for manually plotting picture with dosplot.ax
393393
"""
394+
if not isinstance(ax, list):
395+
ax = [ax]
394396

395397
dos, totnum = parse_projected_data(self.orbitals, species, keyname)
396398
energy_f, tdos = self._shift_energy(efermi, shift, prec)
@@ -403,7 +405,7 @@ def _parial_plot(self,
403405
notes=dosplot.plot_params["notes"])
404406
else:
405407
dosplot._set_figure(energy_range, dos_range)
406-
408+
407409
return dosplot
408410

409411
if isinstance(species, (list, tuple)):
@@ -515,17 +517,18 @@ def plot(self,
515517
# energy_range=energy_range, dos_range=dos_range, notes={'s': '(a)'})
516518
# fig.savefig("tdos.png")
517519

518-
pdosfile = r"C:\Users\YY.Ji\Desktop\PDOS"
520+
pdosfile = r"C:\Users\YY.Ji\Desktop\Si\PDOS"
519521
pdos = PDOS(pdosfile)
520522
#species = {"Ag": [2], "Cl": [1], "In": [0]}
521-
atom_index = {8: {2: [1, 2]}, 4: {2: [1, 2]}, 10: [1, 2]}
522-
fig, ax = plt.subplots(3, 1, sharex=True)
523-
energy_range = [-1.5, 6]
523+
atom_index = {1: {1: [0, 1]}}
524+
fig, ax = plt.subplots(1, 1, sharex=True)
525+
energy_range = [-5, 7]
526+
efermi = 6.585653952007503
524527
dos_range = [0, 5]
525528

526-
# if you want to specify `species` or `index`, you need to
527-
# set `species=species` or `index=index` in the following two functions
528-
dosplots = pdos.plot(fig, ax, atom_index=atom_index, efermi=5, shift=True,
529-
energy_range=energy_range, dos_range=dos_range, notes=[{'s': '(a)'}, {'s': '(b)'}, {'s': '(c)'}])
529+
# if you want to specify `species` or `index`, you need to
530+
# set `species=species` or `index=index` in the following two functions
531+
dosplots = pdos.plot(fig, ax, atom_index=atom_index, efermi=efermi, shift=True,
532+
energy_range=energy_range, dos_range=dos_range, notes=[{'s': '(a)'}])
530533
fig.savefig("pdos.png")
531-
pdos.write(atom_index=atom_index)
534+
pdos.write(atom_index=atom_index)

tools/plot-tools/abacus_plot/main.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from abacus_plot.band import Band, PBand
1212
from abacus_plot.dos import TDOS, PDOS
13-
from abacus_plot.utils import read_json
13+
from abacus_plot.utils import read_json, key2int
1414

1515

1616
class Show:
@@ -41,9 +41,9 @@ def show_cmdline(cls, args):
4141
shift = text.pop("shift", False)
4242
figsize = text.pop("figsize", (12, 10))
4343
fig, ax = plt.subplots(figsize=figsize)
44-
index = text.pop("index", [])
45-
atom_index = text.pop("atom_index", [])
46-
species = text.pop("species", [])
44+
index = key2int(text.pop("index", None)) # keys of json must be dict, I convert them to int for `index` and `atom_index`
45+
atom_index = key2int(text.pop("atom_index", None))
46+
species = text.pop("species", None)
4747
outdir = text.pop("outdir", './')
4848
cmapname = text.pop("cmapname", 'jet')
4949
pband = PBand(bandfile, kptfile)
@@ -53,9 +53,9 @@ def show_cmdline(cls, args):
5353
text = read_json(args.file)
5454
bandfile = text["bandfile"]
5555
kptfile = text["kptfile"]
56-
index = text.pop("index", [])
57-
atom_index = text.pop("atom_index", [])
58-
species = text.pop("species", [])
56+
index = key2int(text.pop("index", None))
57+
atom_index = key2int(text.pop("atom_index", None))
58+
species = text.pop("species", None)
5959
outdir = text.pop("outdir", './')
6060
pdos = PBand(bandfile, kptfile)
6161
pdos.write(index=index, atom_index=atom_index, species=species, outdir=outdir)
@@ -83,12 +83,19 @@ def show_cmdline(cls, args):
8383
energy_range = text.pop("energy_range", [])
8484
dos_range = text.pop("dos_range", [])
8585
shift = text.pop("shift", False)
86-
index = text.pop("index", [])
87-
atom_index = text.pop("atom_index", [])
88-
species = text.pop("species", [])
86+
index = key2int(text.pop("index", None))
87+
atom_index = key2int(text.pop("atom_index", None))
88+
species = text.pop("species", None)
8989
figsize = text.pop("figsize", (12, 10))
90-
fig, ax = plt.subplots(
91-
len(species), 1, sharex=True, figsize=figsize)
90+
if index:
91+
fig, ax = plt.subplots(
92+
len(index), 1, sharex=True, figsize=figsize)
93+
if atom_index:
94+
fig, ax = plt.subplots(
95+
len(atom_index), 1, sharex=True, figsize=figsize)
96+
if species:
97+
fig, ax = plt.subplots(
98+
len(species), 1, sharex=True, figsize=figsize)
9299
prec = text.pop("prec", 0.01)
93100
pdos = PDOS(pdosfile)
94101
pdos.plot(fig=fig, ax=ax, index=index, atom_index=atom_index, species=species, efermi=efermi, shift=shift,
@@ -99,9 +106,9 @@ def show_cmdline(cls, args):
99106
if args.dos and args.out:
100107
text = read_json(args.file)
101108
pdosfile = text.pop("pdosfile", '')
102-
index = text.pop("index", [])
103-
atom_index = text.pop("atom_index", [])
104-
species = text.pop("species", [])
109+
index = key2int(text.pop("index", None))
110+
atom_index = key2int(text.pop("atom_index", None))
111+
species = text.pop("species", None)
105112
outdir = text.pop("outdir", './')
106113
pdos = PDOS(pdosfile)
107114
pdos.write(index=index, atom_index=atom_index, species=species, outdir=outdir)

0 commit comments

Comments
 (0)