Skip to content

Commit beca494

Browse files
authored
Merge pull request #276 from rdguha1995/main
First pass at bond order visualizations for molecules [WIP]
2 parents ea6a5a6 + 5a005b1 commit beca494

File tree

5 files changed

+98
-15
lines changed

5 files changed

+98
-15
lines changed

crystal_toolkit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from pathlib import Path
43
from importlib.metadata import PackageNotFoundError, version
4+
from pathlib import Path
55

66
from monty.json import MSONable
77

crystal_toolkit/core/scene.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ class Cylinders(Primitive):
316316

317317
@property
318318
def key(self):
319-
return f"cylinder_{self.color}_{self.radius}_{self.reference}"
319+
return f"cylinder_{self.color}_{self.radius}_{self.reference}_{self.clickable}_{self.tooltip}"
320320

321321
@classmethod
322322
def merge(cls, cylinder_list):
@@ -330,6 +330,8 @@ def merge(cls, cylinder_list):
330330
color=cylinder_list[0].color,
331331
radius=cylinder_list[0].radius,
332332
visible=cylinder_list[0].visible,
333+
clickable=cylinder_list[0].clickable,
334+
tooltip=cylinder_list[0].tooltip,
333335
)
334336

335337
@property

crystal_toolkit/renderables/moleculegraph.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import defaultdict
44

55
from pymatgen.analysis.graphs import MoleculeGraph
6+
from pymatgen.analysis.local_env import OpenBabelNN
67

78
from crystal_toolkit.core.legend import Legend
89
from crystal_toolkit.core.scene import Scene
@@ -16,21 +17,47 @@ def get_molecule_graph_scene(
1617
explicitly_calculate_polyhedra_hull=False,
1718
legend=None,
1819
draw_polyhedra=False,
20+
show_atom_idx=True,
21+
show_atom_coord=True,
22+
show_bond_order=True,
23+
show_bond_length=False,
24+
visualize_bond_orders=False,
1925
) -> Scene:
2026

27+
"""
28+
Args:
29+
show_atom_idx: Defaults to True, shows the site index of each atom in the molecule
30+
show_atom_coord: Defaults to True, shows the 3D coordinates of each atom in the molecule
31+
show_bond_order: Defaults to True, shows the calculated bond order in the chosen local environment strategy
32+
show_bond_length: Defaults to False, shows the calculated length between two connected atoms
33+
visualize_bpnd_orders: Defaults False, will show the 'integral' number of bonds calculated from the OpenBabelNN strategy in the Molecule Graph
34+
Returns:
35+
A Molecule Graph scene
36+
"""
37+
38+
vis_mol_graph = MoleculeGraph.with_local_env_strategy(self.molecule, OpenBabelNN())
2139
legend = legend or Legend(self.molecule)
2240

2341
primitives: dict[str, list] = defaultdict(list)
2442

2543
for idx, site in enumerate(self.molecule):
2644

27-
connected_sites = self.get_connected_sites(idx)
45+
if visualize_bond_orders:
46+
connected_sites = vis_mol_graph.get_connected_sites(idx)
47+
else:
48+
connected_sites = self.get_connected_sites(idx)
2849

2950
site_scene = site.get_scene(
51+
site_idx=idx,
3052
connected_sites=connected_sites,
3153
origin=origin,
3254
explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
3355
legend=legend,
56+
show_atom_idx=show_atom_idx,
57+
show_atom_coord=show_atom_coord,
58+
show_bond_order=show_bond_order,
59+
show_bond_length=show_bond_length,
60+
visualize_bond_orders=visualize_bond_orders,
3461
draw_polyhedra=draw_polyhedra,
3562
)
3663
for scene in site_scene.contents:

crystal_toolkit/renderables/site.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def get_site_scene(
2828
# connected_sites_to_draw,
2929
connected_sites_not_drawn: list[ConnectedSite] = None,
3030
hide_incomplete_edges: bool = False,
31+
site_idx: int | None = 0,
3132
incomplete_edge_length_scale: float | None = 1.0,
3233
connected_sites_colors: list[str] | None = None,
3334
connected_sites_not_drawn_colors: list[str] | None = None,
@@ -36,24 +37,27 @@ def get_site_scene(
3637
explicitly_calculate_polyhedra_hull: bool = False,
3738
bond_radius: float = 0.1,
3839
draw_magmoms: bool = True,
40+
show_atom_idx: bool = False,
41+
show_atom_coord: bool = True,
42+
show_bond_order: bool = True,
43+
show_bond_length: bool = False,
44+
visualize_bond_orders: bool = False,
3945
magmom_scale: float = 1.0,
4046
legend: Legend | None = None,
4147
) -> Scene:
4248
"""
43-
4449
Args:
4550
connected_sites:
4651
connected_sites_not_drawn:
4752
hide_incomplete_edges:
53+
site_idx:
4854
incomplete_edge_length_scale:
4955
connected_sites_colors:
5056
connected_sites_not_drawn_colors:
5157
origin:
5258
explicitly_calculate_polyhedra_hull:
5359
legend:
54-
5560
Returns:
56-
5761
"""
5862

5963
atoms = []
@@ -74,7 +78,6 @@ def get_site_scene(
7478
max_radius = float(min(radii))
7579

7680
for sp, occu in self.species.items():
77-
7881
if isinstance(sp, DummySpecie):
7982

8083
cube = Cubes(
@@ -101,7 +104,12 @@ def get_site_scene(
101104
name = str(sp)
102105
if occu != 1.0:
103106
name += f" ({occu}% occupancy)"
104-
name += f" ({position[0]:.3f}, {position[1]:.3f}, {position[2]:.3f})"
107+
108+
if show_atom_coord:
109+
name += f" ({position[0]:.3f}, {position[1]:.3f}, {position[2]:.3f})"
110+
111+
if show_atom_idx:
112+
name += "\n" + "index:" + str(site_idx)
105113

106114
if self.properties:
107115
for k, v in self.properties.items():
@@ -157,9 +165,20 @@ def get_site_scene(
157165
# TODO: can cause a bug if all vertices almost co-planar
158166
# necessary to include center site in case it's outside polyhedra
159167
all_positions = [self.coords]
168+
name_cyl = " "
160169

161170
for idx, connected_site in enumerate(connected_sites):
162171

172+
if show_bond_order:
173+
if connected_site.weight is not None:
174+
name_cyl = "bond order:" + str(f"{connected_site.weight:.2f}")
175+
176+
if show_bond_length:
177+
if connected_site.dist is not None:
178+
name_cyl += (
179+
"\n" + "bond length:" + str(f"{connected_site.dist:.3f}")
180+
)
181+
163182
connected_position = connected_site.site.coords
164183
bond_midpoint = np.add(position, connected_position) / 2
165184

@@ -168,12 +187,47 @@ def get_site_scene(
168187
else:
169188
color = site_color
170189

171-
cylinder = Cylinders(
172-
positionPairs=[[position, bond_midpoint.tolist()]],
173-
color=color,
174-
radius=bond_radius,
175-
)
176-
bonds.append(cylinder)
190+
if visualize_bond_orders:
191+
cylinders = []
192+
193+
if connected_site.weight is not None:
194+
195+
if connected_site.weight > 1:
196+
trans_vector = 0.0
197+
for _bond in range(connected_site.weight):
198+
pos_r_1 = [i + trans_vector for i in position]
199+
pos_r_2 = [i + trans_vector for i in bond_midpoint.tolist()]
200+
cylinders.append(
201+
Cylinders(
202+
positionPairs=[[pos_r_1, pos_r_2]],
203+
color=color,
204+
radius=bond_radius / 2,
205+
clickable=True,
206+
tooltip=name_cyl,
207+
)
208+
)
209+
trans_vector = trans_vector + 0.25 * max_radius
210+
for cylinder in cylinders:
211+
bonds.append(cylinder)
212+
else:
213+
cylinder = Cylinders(
214+
positionPairs=[[position, bond_midpoint.tolist()]],
215+
color=color,
216+
radius=bond_radius,
217+
clickable=True,
218+
tooltip=name_cyl,
219+
)
220+
bonds.append(cylinder)
221+
222+
else:
223+
cylinder = Cylinders(
224+
positionPairs=[[position, bond_midpoint.tolist()]],
225+
color=color,
226+
radius=bond_radius,
227+
clickable=True,
228+
tooltip=name_cyl,
229+
)
230+
bonds.append(cylinder)
177231
all_positions.append(connected_position.tolist())
178232

179233
if connected_sites_not_drawn and not hide_incomplete_edges:

jupyterlab-extension/tsconfig.tsbuildinfo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1190,4 +1190,4 @@
11901190
]
11911191
},
11921192
"version": "4.1.3"
1193-
}
1193+
}

0 commit comments

Comments
 (0)