|
25 | 25 |
|
26 | 26 | import pymatviz as pmv |
27 | 27 | import pymatviz.cluster.composition as pcc |
| 28 | +from pymatviz.cluster.composition import EmbeddingMethod as Embed |
| 29 | +from pymatviz.cluster.composition import ProjectionMethod as Project |
28 | 30 | from pymatviz.enums import Key |
29 | 31 |
|
30 | 32 |
|
31 | 33 | if TYPE_CHECKING: |
32 | 34 | import plotly.graph_objects as go |
33 | 35 |
|
34 | | - from pymatviz.cluster.composition import ProjectionMethod |
35 | | - |
36 | 36 |
|
37 | 37 | pmv.set_plotly_template("pymatviz_white") |
38 | 38 | module_dir = os.path.dirname(__file__) |
@@ -61,7 +61,7 @@ def process_dataset( |
61 | 61 | target_label: str, |
62 | 62 | target_symbol: str, |
63 | 63 | embed_method: pcc.EmbeddingMethod, |
64 | | - projection: ProjectionMethod, |
| 64 | + projection: pcc.ProjectionMethod, |
65 | 65 | n_components: int, |
66 | 66 | **kwargs: Any, |
67 | 67 | ) -> go.Figure: |
@@ -110,7 +110,7 @@ def process_dataset( |
110 | 110 | # Create embeddings |
111 | 111 | if embed_method == "one-hot": |
112 | 112 | embeddings = pcc.one_hot_encode(compositions) |
113 | | - elif embed_method in ["magpie", "matscholar_el"]: |
| 113 | + elif embed_method in (Embed.magpie, Embed.matscholar_el): |
114 | 114 | embeddings = pcc.matminer_featurize(compositions, preset=embed_method) |
115 | 115 | else: |
116 | 116 | raise ValueError(f"Unknown {embed_method=}") |
@@ -144,14 +144,15 @@ def annotate_top_points(row: pd.Series) -> dict[str, Any] | None: |
144 | 144 | text = f"{comp_str}<br>{prop_val}" |
145 | 145 | return dict(text=text, font_size=11, bgcolor="rgba(240, 240, 240, 0.5)") |
146 | 146 |
|
147 | | - if "embeddings" not in df_plot: |
148 | | - df_plot["embeddings"] = [embeddings_dict.get(comp) for comp in compositions] |
| 147 | + embed_col = "embeddings" |
| 148 | + if embed_col not in df_plot: |
| 149 | + df_plot[embed_col] = [embeddings_dict.get(comp) for comp in compositions] |
149 | 150 |
|
150 | 151 | fig = pmv.cluster_compositions( |
151 | 152 | df_in=df_plot, |
152 | 153 | composition_col="composition", |
153 | 154 | prop_name=target_label, |
154 | | - embedding_method="embeddings", |
| 155 | + embedding_method=embed_col, |
155 | 156 | projection=projection, |
156 | 157 | n_components=n_components, |
157 | 158 | marker_size=8, |
@@ -185,68 +186,72 @@ def annotate_top_points(row: pd.Series) -> dict[str, Any] | None: |
185 | 186 | return fig |
186 | 187 |
|
187 | 188 |
|
188 | | -mb_jdft2d = ( |
| 189 | +mb_jdft2d: tuple[str, str, str, str] = ( |
189 | 190 | "matbench_jdft2d", |
190 | 191 | "exfoliation_en", |
191 | 192 | "Exfoliation Energy (meV/atom)", |
192 | 193 | "E<sub>ex</sub>", |
193 | 194 | ) |
194 | | -mb_steels = ( |
| 195 | +mb_steels: tuple[str, str, str, str] = ( |
195 | 196 | "matbench_steels", |
196 | 197 | "yield strength", |
197 | 198 | "Yield Strength (MPa)", |
198 | 199 | "σ", |
199 | 200 | ) |
200 | | -mb_dielectric = ( |
| 201 | +mb_dielectric: tuple[str, str, str, str] = ( |
201 | 202 | "matbench_dielectric", |
202 | 203 | "n", |
203 | 204 | "Refractive index", |
204 | 205 | "n", |
205 | 206 | ) |
206 | | -mb_perovskites = ( |
| 207 | +mb_perovskites: tuple[str, str, str, str] = ( |
207 | 208 | "matbench_perovskites", |
208 | 209 | "e_form", |
209 | 210 | "Formation energy (eV/atom)", |
210 | 211 | "E<sub>f</sub>", |
211 | 212 | ) |
212 | | -mb_phonons = ( |
| 213 | +mb_phonons: tuple[str, str, str, str] = ( |
213 | 214 | "matbench_phonons", |
214 | 215 | "last phdos peak", |
215 | 216 | "Max Phonon Peak (cm⁻¹)", |
216 | 217 | "ν<sub>max</sub>", |
217 | 218 | ) |
218 | | -mb_bulk_modulus = ( |
| 219 | +mb_bulk_modulus: tuple[str, str, str, str] = ( |
219 | 220 | "matbench_log_kvrh", |
220 | 221 | "log10(K_VRH)", |
221 | 222 | "Bulk Modulus (GPa)", |
222 | 223 | "K<sub>VRH</sub>", |
223 | 224 | ) |
224 | 225 | plot_combinations: list[ # type: ignore[invalid-assignment] |
225 | | - tuple[ |
226 | | - str, str, str, str, pcc.EmbeddingMethod, ProjectionMethod, int, dict[str, Any] |
227 | | - ] |
| 226 | + tuple[str, str, str, str, Embed, Project, int, dict[str, Any]] |
228 | 227 | ] = [ |
229 | 228 | # 1. Steels with PCA (2D) - shows clear linear trends |
230 | | - (*mb_steels, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 229 | + (*mb_steels, Embed.magpie, Project.pca, 2, dict(x=0.01, xanchor="left")), |
231 | 230 | # 2. Steels with t-SNE (2D) - shows non-linear clustering |
232 | | - (*mb_steels, "magpie", "tsne", 2, dict(x=0.01, xanchor="left")), |
| 231 | + (*mb_steels, Embed.magpie, Project.tsne, 2, dict(x=0.01, xanchor="left")), |
233 | 232 | # TODO umap-learn seemingly not installed by uv run in CI, fix later |
234 | 233 | # 3. JDFT2D with UMAP (2D) - shows modern non-linear projection |
235 | | - # (*mb_jdft2d, "magpie", "umap", 2, dict(x=0.01, xanchor="left")), |
| 234 | + # (*mb_jdft2d, Embed.magpie, Project.umap, 2, dict(x=0.01, xanchor="left")), |
236 | 235 | # 4. JDFT2D with one-hot encoding and PCA (3D) - shows raw element relationships |
237 | | - (*mb_jdft2d, "one-hot", "pca", 3, dict()), |
| 236 | + (*mb_jdft2d, Embed.one_hot, Project.pca, 3, dict()), |
238 | 237 | # 5. Steels with Matscholar embedding and t-SNE (3D) - shows advanced embedding |
239 | | - (*mb_steels, "matscholar_el", "tsne", 3, dict(x=0.5, y=0.8)), |
| 238 | + (*mb_steels, Embed.matscholar_el, Project.tsne, 3, dict(x=0.5, y=0.8)), |
240 | 239 | # 6. Dielectric with PCA (2D) - shows clear linear trends |
241 | | - (*mb_dielectric, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 240 | + (*mb_dielectric, Embed.magpie, Project.pca, 2, dict(x=0.01, xanchor="left")), |
242 | 241 | # 7. Perovskites with PCA (2D) - shows clear linear trends |
243 | | - (*mb_perovskites, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 242 | + (*mb_perovskites, Embed.magpie, Project.pca, 2, dict(x=0.01, xanchor="left")), |
244 | 243 | # 8. Phonons with PCA (2D) - shows clear linear trends |
245 | | - (*mb_phonons, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 244 | + (*mb_phonons, Embed.magpie, Project.pca, 2, dict(x=0.01, xanchor="left")), |
246 | 245 | # 9. Bulk Modulus with PCA (2D) - shows clear linear trends |
247 | | - (*mb_bulk_modulus, "magpie", "pca", 2, dict(x=0.99, y=0.96, yanchor="top")), |
| 246 | + ( |
| 247 | + *mb_bulk_modulus, |
| 248 | + Embed.magpie, |
| 249 | + Project.pca, |
| 250 | + 2, |
| 251 | + dict(x=0.99, y=0.96, yanchor="top"), |
| 252 | + ), |
248 | 253 | # 10. Perovskites with t-SNE (3D) - shows raw element relationships |
249 | | - (*mb_perovskites, "magpie", "tsne", 3, dict()), |
| 254 | + (*mb_perovskites, Embed.magpie, Project.tsne, 3, dict()), |
250 | 255 | ] |
251 | 256 |
|
252 | 257 | for ( |
|
0 commit comments