Skip to content

Commit 1d3a81f

Browse files
authored
Fix Arrows.bounding_box() (#321)
* fix Arrows.bounding_box() - return [[min(x), min(y), min(z)], [min(x), min(y), min(z)]] + return [[min(x), min(y), min(z)], [max(x), max(y), max(z)]] * add doc strings
1 parent 0954935 commit 1d3a81f

File tree

6 files changed

+34
-20
lines changed

6 files changed

+34
-20
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ exclude: ^(docs/.+|.*lock.*|jupyterlab-extension/.+|.*\.(svg|js|css))$
1212

1313
repos:
1414
- repo: https://github.com/charliermarsh/ruff-pre-commit
15-
rev: v0.0.252
15+
rev: v0.0.255
1616
hooks:
1717
- id: ruff
1818
args: [--fix, --ignore, D]

crystal_toolkit/core/legend.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,13 @@ def generate_accessible_color_scheme_on_the_fly(
208208

209209
@staticmethod
210210
def generate_categorical_color_scheme_on_the_fly(
211-
site_collection: SiteCollection, site_prop_types
211+
site_collection: SiteCollection, site_prop_types: dict[str, list[str]]
212212
) -> dict[str, dict[str, tuple[int, int, int]]]:
213-
"""e.g. for Wyckoff.
213+
"""E.g. for Wyckoff.
214214
215215
Args:
216-
site_collection: SiteCollection
216+
site_collection (SiteCollection): The sites to generate a color scheme for.
217+
site_prop_types (dict[str, list[str]]): The categorical site property types.
217218
218219
Returns: A dictionary in similar format to EL_COLORS
219220
"""
@@ -225,9 +226,9 @@ def generate_categorical_color_scheme_on_the_fly(
225226
props = np.array(site_collection.site_properties[site_prop_name])
226227
props[props is None] = "None"
227228

228-
le = LabelEncoder()
229-
le.fit(props)
230-
transformed_props = le.transform(props)
229+
label_enc = LabelEncoder()
230+
label_enc.fit(props)
231+
transformed_props = label_enc.transform(props)
231232

232233
# if we have more categories than available colors,
233234
# arbitrarily group some categories together

crystal_toolkit/core/mpcomponent.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ def id(
224224
self,
225225
name: str = "default",
226226
is_kwarg: bool = False,
227-
idx=False,
228-
hint=None,
227+
idx: bool | int = False,
228+
hint: str = None,
229229
is_store: bool = False,
230230
) -> str | dict[str, str]:
231231
"""Generate an id from a name combined with the base id of the MPComponent itself, useful
@@ -241,9 +241,13 @@ def id(
241241
to parse a boolean value. In future iterations, we may be able to replace this with native
242242
Python type hints. The problem here is being able to specify array shape where appropriate.
243243
244-
245244
Args:
246-
name: e.g. "default"
245+
name (str): The name of the element, e.g. "graph", "structure". Defaults to "default".
246+
is_kwarg (bool): If True, return a dict with information necessary to reconstruct
247+
the keyword argument for a specific class.
248+
idx (bool | int): The index to return if is_kwarg is True. Defaults to False.
249+
hint (str): The type hint to return if is_kwarg is True. Defaults to None.
250+
is_store (bool): If True, return the id of the store, otherwise return a dict
247251
248252
Returns: e.g. "MPComponent_default"
249253
"""

crystal_toolkit/core/scene.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def bounding_box(self) -> list[list[float]]:
133133
"""Returns the bounding box coordinates."""
134134
if len(self.contents) > 0:
135135
min_list, max_list = zip(*[p.bounding_box for p in self.contents])
136-
min_x, min_y, min_z = map(min, list(zip(*min_list)))
137-
max_x, max_y, max_z = map(max, list(zip(*max_list)))
136+
min_x, min_y, min_z = map(min, zip(*min_list))
137+
max_x, max_y, max_z = map(max, zip(*max_list))
138138

139139
return [[min_x, min_y, min_z], [max_x, max_y, max_z]]
140140
else:
@@ -511,7 +511,15 @@ def key(self):
511511
return f"arrow_{self.color}_{self.radius}_{self.headLength}_{self.headWidth}_{self.reference}"
512512

513513
@classmethod
514-
def merge(cls, arrow_list):
514+
def merge(cls, arrow_list: list[Arrows]) -> Arrows:
515+
"""Merge a list of arrows into a new Arrows instance.
516+
517+
Args:
518+
arrow_list (list[Arrows]): Arrows to merge
519+
520+
Returns:
521+
Arrows: Merged arrows
522+
"""
515523
new_positionPairs = list(
516524
chain.from_iterable([arrow.positionPairs for arrow in arrow_list])
517525
)
@@ -527,7 +535,7 @@ def merge(cls, arrow_list):
527535
@property
528536
def bounding_box(self) -> list[list[float]]:
529537
x, y, z = zip(*chain.from_iterable(self.positionPairs))
530-
return [[min(x), min(y), min(z)], [min(x), min(y), min(z)]]
538+
return [[min(x), min(y), min(z)], [max(x), max(y), max(z)]]
531539

532540

533541
@dataclass

crystal_toolkit/helpers/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def pretty_frac_format(x: float) -> str:
515515
return x_str
516516

517517

518-
def hook_up_fig_with_ctk_struct_viewer(
518+
def hook_up_fig_with_struct_viewer(
519519
fig: go.Figure,
520520
df: pd.DataFrame,
521521
struct_col: str = "structure",
@@ -534,16 +534,17 @@ def hook_up_fig_with_ctk_struct_viewer(
534534
from pymatgen.ext.matproj import MPRester
535535
536536
# Get random structures from the Materials Project
537-
mp_ids = [f"mp-{random.randint(1, 10000)}" for _ in range(100)]
538-
structures = MPRester(use_document_model=False).summary.search(material_ids=mp_ids)
537+
mp_ids = [f"mp-{random.randint(1, 10_000)}" for _ in range(100)]
538+
docs = MPRester(use_document_model=False).summary.search(material_ids=mp_ids)
539539
540-
df = pd.DataFrame(structures)
540+
df = pd.DataFrame(docs)
541541
id_col = "material_id"
542542
543543
fig = px.scatter(df, x="nsites", y="volume", hover_name=id_col, template="plotly_white")
544544
app = hook_up_fig_with_ctk_struct_viewer(fig, df.set_index(id_col))
545545
app.run_server(port=8000)
546546
547+
547548
Args:
548549
fig (Figure): Plotly figure to be hooked up with the structure component. The
549550
figure must have hover_name set to the index of the data frame to identify

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ build-backend = "setuptools.build_meta"
6767
target-version = "py38"
6868
select = [
6969
"B", # flake8-bugbear
70-
"C4", # flake8-comprehensions
70+
"C40", # flake8-comprehensions
7171
"D", # pydocstyle
7272
"E", # pycodestyle
7373
"F", # pyflakes

0 commit comments

Comments
 (0)