55from importlib .resources import files as import_resource_file
66from itertools import islice
77from pathlib import Path
8- from typing import TYPE_CHECKING , Any
8+ from typing import TYPE_CHECKING
99
1010import numpy as np
11- from matminer .utils .io import load_dataframe_from_json
11+ from matminer .utils .io import load_dataframe_from_json # type:ignore[import-untyped]
1212from pymatgen .analysis .prototypes import AflowPrototypeMatcher
13- from pymatgen .core .structure import IStructure
1413from robocrys .condense .fingerprint import (
1514 get_fingerprint_distance ,
1615 get_structure_fingerprint ,
1716)
1817
1918
2019if TYPE_CHECKING :
20+ from typing import Any
2121 import pandas as pd
2222
23- _mineral_db_file = import_resource_file ("robocrys.condense" ) / "mineral_db.json.gz"
23+ from pymatgen .core .structure import Structure
24+
25+
26+ _mineral_db_file = str (import_resource_file ("robocrys.condense" ) / "mineral_db.json.gz" )
2427
2528
2629class MineralMatcher :
@@ -60,9 +63,11 @@ def __init__(
6063 fingerprint_distance_cutoff : float = 0.4 ,
6164 mineral_db : str | Path | pd .DataFrame | None = None ,
6265 ):
63- self .mineral_db = mineral_db if mineral_db is not None else _mineral_db_file
64- if isinstance (self .mineral_db , (str , Path )):
65- self .mineral_db = load_dataframe_from_json (self .mineral_db , pbar = False )
66+ mineral_db = mineral_db or _mineral_db_file
67+ if isinstance (mineral_db , (str , Path )):
68+ self .mineral_db = load_dataframe_from_json (mineral_db , pbar = False )
69+ else :
70+ self .mineral_db = mineral_db
6671
6772 self .initial_ltol = initial_ltol
6873 self .initial_stol = initial_stol
@@ -72,7 +77,7 @@ def __init__(
7277 self ._structure = None
7378 self ._mineral_db = None
7479
75- def get_best_mineral_name (self , structure : IStructure ) -> dict [str , Any ]:
80+ def get_best_mineral_name (self , structure : Structure ) -> dict [str , Any ]:
7681 """Gets the "best" mineral name for a structure.
7782
7883 Uses a combination of AFLOW prototype matching and fingerprinting to
@@ -138,7 +143,7 @@ def get_best_mineral_name(self, structure: IStructure) -> dict[str, Any]:
138143
139144 def get_aflow_matches (
140145 self ,
141- structure : IStructure ,
146+ structure : Structure ,
142147 ) -> list [dict [str , Any ]] | None :
143148 """Gets minerals for a structure by matching to AFLOW prototypes.
144149
@@ -190,7 +195,7 @@ def _match_prototype(structure_matcher, s):
190195
191196 def get_fingerprint_matches (
192197 self ,
193- structure : IStructure ,
198+ structure : Structure ,
194199 max_n_matches : int | None = None ,
195200 match_n_sp : bool = True ,
196201 mineral_name_constraint : str | None = None ,
@@ -240,7 +245,7 @@ def get_fingerprint_matches(
240245
241246 return minerals if minerals else None
242247
243- def _set_distance_matrix (self , structure : IStructure ):
248+ def _set_distance_matrix (self , structure : Structure ):
244249 """Utility func to calculate distance between structure and minerals.
245250
246251 First checks to see if the distances have already been calculated for
0 commit comments