Skip to content

Commit be3e06e

Browse files
committed
linting
1 parent 6ddc66d commit be3e06e

File tree

1 file changed

+53
-110
lines changed

1 file changed

+53
-110
lines changed

crystal_toolkit/components/phonon.py

Lines changed: 53 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import itertools
4+
from copy import deepcopy
45
from typing import TYPE_CHECKING, Any
56

67
import numpy as np
@@ -28,9 +29,7 @@
2829
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
2930
from pymatgen.electronic_structure.dos import CompleteDos
3031

31-
# Author: Jason Munro, Janosh Riebesell
32-
33-
32+
DISPLACE_COEF = [0, 1, 0, -1, 0]
3433

3534
# TODOs:
3635
# - look for additional projection methods in phonon DOS (currently only atom
@@ -214,131 +213,75 @@ def _get_eigendisplacement(
214213
precision: int = 15,
215214
magnitude: int = 15,
216215
) -> dict:
217-
if not ph_bs:
216+
if not ph_bs or not json_data:
218217
return {}
219218

220-
# get displacement
221-
min_bond_length = float("inf")
222-
for content_idx in range(len(json_data["contents"][1]["contents"])):
223-
for pair_idx in range(
224-
len(json_data["contents"][1]["contents"][content_idx]["_meta"])
225-
):
226-
u, v = json_data["contents"][1]["contents"][content_idx][
227-
"positionPairs"
228-
][pair_idx]
229-
# Convert to numpy arrays
230-
u = np.array(u)
231-
v = np.array(v)
232-
length = np.linalg.norm(v - u)
233-
min_bond_length = min(min_bond_length, length)
234-
235-
# atom animate
236219
assert json_data["contents"][0]["name"] == "atoms"
237-
for content_idx in range(len(json_data["contents"][0]["contents"])):
238-
atom_idx = json_data["contents"][0]["contents"][content_idx]["_meta"][0]
239-
240-
raw_displacement = ph_bs.eigendisplacements[band][qpoint][atom_idx]
220+
assert json_data["contents"][1]["name"] == "bonds"
221+
rdata = deepcopy(json_data)
241222

242-
displacement = [complex(vec).real * magnitude for vec in raw_displacement]
223+
def calc_displacement(idx: int) -> list:
224+
return [
225+
round(complex(vec).real * magnitude, precision)
226+
for vec in ph_bs.eigendisplacements[band][qpoint][idx]
227+
]
243228

244-
position_animation = []
245-
for displace_coef in [0, 1, 0, -1, 0]:
246-
displace = [
247-
round(displace_coef * magnitude * d, precision)
248-
for d in displacement
249-
]
250-
position_animation.append(displace)
229+
def calc_animation_step(displacement: list, coef: int) -> list:
230+
return [round(coef * magnitude * d, precision) for d in displacement]
251231

252-
json_data["contents"][0]["contents"][content_idx]["animate"] = (
253-
position_animation
254-
)
255-
json_data["contents"][0]["contents"][content_idx]["keyframes"] = [
256-
0,
257-
1,
258-
2,
259-
3,
260-
4,
232+
# atom animate
233+
contents0 = json_data["contents"][0]["contents"]
234+
for cidx, content in enumerate(contents0):
235+
displacement = calc_displacement(content["_meta"][0])
236+
rcontent = rdata["contents"][0]["contents"][cidx]
237+
rcontent["animate"] = [
238+
calc_animation_step(displacement, coef) for coef in DISPLACE_COEF
261239
]
262-
json_data["contents"][0]["contents"][content_idx]["animateType"] = (
263-
"displacement"
264-
)
240+
rcontent["keyframes"] = list(range(5))
241+
rcontent["animateType"] = "displacement"
265242

266-
# bond animate
267-
assert json_data["contents"][1]["name"] == "bonds"
268-
for content_idx in range(len(json_data["contents"][1]["contents"])):
243+
# get displacement and bond animate
244+
min_bond_length = float("inf")
245+
contents1 = json_data["contents"][1]["contents"]
246+
for cidx, content in enumerate(contents1):
269247
bond_animation = []
248+
assert len(content["_meta"]) == len(content["positionPairs"])
270249

271-
assert len(
272-
json_data["contents"][1]["contents"][content_idx]["_meta"]
273-
) == len(json_data["contents"][1]["contents"][content_idx]["positionPairs"])
274-
275-
for pair_idx in range(
276-
len(json_data["contents"][1]["contents"][content_idx]["_meta"])
277-
):
278-
u_idx, v_idx = json_data["contents"][1]["contents"][content_idx][
279-
"_meta"
280-
][pair_idx]
281-
282-
# u
283-
u_raw_displacement = ph_bs.eigendisplacements[band][qpoint][u_idx]
284-
u_displacement = [
285-
round(complex(vec).real * magnitude, precision)
286-
for vec in u_raw_displacement
287-
]
288-
289-
# v
290-
v_raw_displacement = ph_bs.eigendisplacements[band][qpoint][v_idx]
291-
v_displacement = [
292-
round(complex(vec).real * magnitude, precision)
293-
for vec in v_raw_displacement
294-
]
250+
for pair in enumerate(content["_meta"]):
251+
u, v = rdata["contents"][1]["contents"][cidx]["positionPairs"] = list(
252+
map(np.array, pair)
253+
)
254+
length = np.linalg.norm(v - u)
255+
min_bond_length = min(min_bond_length, length)
256+
displacements = list(map(calc_displacement, pair))
257+
u_to_middle_bond_animation = []
295258

296-
# only draw in unit cell
297-
u_to_middle_bond_animation = [] # u to middle
298-
# v_to_middle_bond_animation = [] # v to middle
299-
for displace_coef in [0, 1, 0, -1, 0]:
300-
u_end_displacement = [
301-
round(displace_coef * magnitude * d, precision)
302-
for d in u_displacement
303-
]
304-
v_end_displacement = [
305-
round(displace_coef * magnitude * d, precision)
306-
for d in v_displacement
307-
]
259+
for coef in DISPLACE_COEF:
308260
middle_end_displacement = (
309-
(np.array(u_end_displacement) + np.array(v_end_displacement))
261+
np.add(
262+
np.array(calc_animation_step(displacement, coef))
263+
for displacement in displacements
264+
)
310265
/ 2
311-
).tolist()
312-
middle_end_displacement = [
313-
round(dis, precision) for dis in middle_end_displacement
314-
]
315-
316-
u2middle_animation = [u_end_displacement, middle_end_displacement]
317-
# v2middle_animation = [v_end_displacement, middle_end_displacement]
318-
319-
u_to_middle_bond_animation.append(u2middle_animation)
320-
# v_to_middle_bond_animation.append(v2middle_animation)
266+
)
267+
u_to_middle_bond_animation.append(
268+
[
269+
displacements[0],
270+
[round(dis, precision) for dis in middle_end_displacement],
271+
]
272+
)
321273

322274
bond_animation.append(u_to_middle_bond_animation)
323-
json_data["contents"][1]["contents"][content_idx]["animate"] = (
324-
bond_animation
325-
)
326-
json_data["contents"][1]["contents"][content_idx]["keyframes"] = [
327-
0,
328-
1,
329-
2,
330-
3,
331-
4,
332-
]
333-
json_data["contents"][1]["contents"][content_idx]["animateType"] = (
334-
"displacement"
335-
)
275+
276+
rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation
277+
rdata["contents"][1]["contents"][cidx]["keyframes"] = list(range(5))
278+
rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement"
336279

337280
# remove polyhedra manually
338-
json_data["contents"][2]["visible"] = False
339-
json_data["contents"][3]["visible"] = False
281+
for i in range(2, 4):
282+
rdata["contents"][i]["visible"] = False
340283

341-
return json_data
284+
return rdata
342285

343286
@staticmethod
344287
def _get_ph_bs_dos(

0 commit comments

Comments
 (0)