66import numpy as np
77from emmet .core .mpid import MPID , AlphaID
88from emmet .core .similarity import CrystalNNSimilarity , SimilarityDoc , SimilarityEntry
9- from pymatgen .core import Composition
9+ from pymatgen .core import Composition , Structure
1010
1111from mp_api .client .core import BaseRester , MPRestError
1212from mp_api .client .core .utils import validate_ids
1313
1414if TYPE_CHECKING :
1515 from emmet .core .similarity import SimilarityScorer
16- from pymatgen .core import Structure
16+
17+ # Upper bound on the number of matches requested; this limit seems to be imposed by MongoDB vector search
18+ MAX_VECTOR_SEARCH_RESULTS = 10_000
1719
1820
1921class SimilarityRester (BaseRester ):
@@ -23,11 +25,15 @@ class SimilarityRester(BaseRester):
2325
2426 _fingerprinter : SimilarityScorer | None = None
2527
def fingerprint_structure(self, structure: Structure) -> np.ndarray:
    """Compute the fingerprint of a user-submitted structure.

    The scorer is created on first use and cached on the instance as
    ``_fingerprinter`` so later calls reuse it.

    Args:
        structure: The pymatgen Structure to featurize.

    Returns:
        Whatever the scorer's ``_featurize_structure`` returns for
        ``structure`` (annotated as a numpy array).
    """
    scorer = self._fingerprinter
    if scorer is None:
        # First call: build the default scorer and remember it.
        scorer = CrystalNNSimilarity()
        self._fingerprinter = scorer
    return scorer._featurize_structure(structure)
33+
34+ def _get_hex_fingerprint (self , feature_vetor : np .ndarray ) -> str :
35+ """Convert feature vector fingerprint to compressed hex str."""
36+ return zlib .compress (feature_vetor .tobytes ()).hex ()
3137
3238 def search (
3339 self ,
@@ -75,18 +81,24 @@ def search(
7581 def find_similar (
7682 self ,
7783 structure_or_mpid : Structure | str | MPID | AlphaID ,
84+ top : int | None = 50 ,
7885 num_chunks : int | None = None ,
7986 chunk_size : int | None = 1000 ,
8087 ) -> list [SimilarityEntry ] | list [dict ]:
81- """Find structures similar to a user-submitted structure.
88+ """Find structures most similar to a user-submitted structure.
8289
8390 Arguments:
8491 structure_or_mpid : pymatgen .Structure, or str, MPID, AlphaID
8592 If a .Structure, the feature vector is computed on the fly
8693 If a str, MPID, or AlphaID, attempts to retrieve a pre-computed
8794 feature vector using the input as a material ID
95+ top : int or None
96+ The number of most similar materials to return, defaults to 50.
97+ Setting to None will return the maximum possible number of
98+ most similar materials.
8899 num_chunks (int or None): Maximum number of chunks of data to yield. None will yield all possible.
89100 chunk_size (int or None): Number of data entries per chunk.
101+ The chunk_size is also used to limit the number of responses returned.
90102
91103 Returns:
92104 ([SimilarityEntry] | [dict]) List of SimilarityEntry documents
@@ -100,15 +112,24 @@ def find_similar(
100112 if not docs :
101113 raise MPRestError (f"No similarity data available for { fmt_idx } " )
102114 feature_vector = docs [0 ]["feature_vector" ]
115+
116+ elif isinstance (structure_or_mpid , Structure ):
117+ feature_vector = self .fingerprint_structure (structure_or_mpid )
118+
103119 else :
104- feature_vector = self .fingerprinter (structure_or_mpid )
120+ raise ValueError ("Please submit a pymatgen Structure or MP ID." )
121+
122+ top = top or MAX_VECTOR_SEARCH_RESULTS
123+ if not isinstance (top , int ) or top < 1 :
124+ raise ValueError (
125+ f"Invalid number of possible top matches specified = { top } ."
126+ "Please specify a positive integer or `None` to return all results."
127+ )
105128
106129 result = self ._query_resource (
107130 criteria = {
108- "feature_vector_hex" : zlib .compress (
109- np .array (feature_vector ).tobytes ()
110- ).hex (),
111- "_limit" : chunk_size ,
131+ "feature_vector_hex" : self ._get_hex_fingerprint (feature_vector ),
132+ "_limit" : top ,
112133 },
113134 suburl = "match" ,
114135 use_document_model = False , # Return type is not exactly a SimilarityDoc, closer to SimilarityEntry
0 commit comments