Skip to content

Commit 3af964c

Browse files
draft search by similarity feature vec
1 parent 52a3c57 commit 3af964c

File tree

1 file changed

+74
-2
lines changed

1 file changed

+74
-2
lines changed

mp_api/client/routes/materials/similarity.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,31 @@
11
from __future__ import annotations
22

3-
from emmet.core.similarity import SimilarityDoc
3+
from typing import TYPE_CHECKING
44

5-
from mp_api.client.core import BaseRester
5+
from emmet.core.mpid import MPID, AlphaID
6+
from emmet.core.similarity import CrystalNNSimilarity, SimilarityDoc, SimilarityEntry
7+
8+
from mp_api.client.core import BaseRester, MPRestError
69
from mp_api.client.core.utils import validate_ids
710

11+
if TYPE_CHECKING:
12+
from emmet.core.similarity import SimilarityScorer
13+
from pymatgen.core import Structure
14+
815

916
class SimilarityRester(BaseRester):
1017
suffix = "materials/similarity"
1118
document_model = SimilarityDoc # type: ignore
1219
primary_key = "material_id"
1320

21+
_fingerprinter: SimilarityScorer | None = None
22+
23+
def fingerprinter(self, structure: Structure) -> list[float]:
    """Compute the similarity feature vector of a structure.

    The scorer is created on first use and stored on
    ``self._fingerprinter`` so that later calls reuse the same instance.

    This is deliberately a regular method, not a property: it requires
    the ``structure`` argument, and a property getter is only ever
    invoked with ``self``.

    Args:
        structure: The pymatgen structure to featurize.

    Returns:
        The feature vector produced by the scorer's
        ``_featurize_structure``, converted with ``.tolist()``.
    """
    if self._fingerprinter is None:
        self._fingerprinter = CrystalNNSimilarity()
    # Use the cached scorer instance directly; it is the scorer itself,
    # not a factory that has to be called first.
    return self._fingerprinter._featurize_structure(structure).tolist()
28+
1429
def search(
1530
self,
1631
material_ids: str | list[str] | None = None,
@@ -53,3 +68,60 @@ def search(
5368
fields=fields,
5469
**query_params,
5570
)
71+
72+
def find_similar(
    self,
    structure_or_mpid: Structure | str | MPID | AlphaID,
    num_chunks: int | None = None,
    chunk_size: int | None = 1000,
) -> list[SimilarityEntry] | list[dict]:
    """Look up the entries most similar to a structure or a known material.

    Args:
        structure_or_mpid: Either a material identifier (``str``, ``MPID``
            or ``AlphaID``) or any other object, expected to be a pymatgen
            ``Structure``. An identifier is normalized through ``AlphaID``
            and its ``feature_vector`` field is fetched with ``search``;
            anything else is handed to ``self.fingerprinter`` to obtain
            the feature vector.
        num_chunks: Forwarded to ``_query_resource`` as ``num_chunks``.
        chunk_size: Forwarded to ``_query_resource`` as ``chunk_size`` and
            also sent as the ``_limit`` criterion.

    Returns:
        One item per returned match carrying ``formula``, ``task_id`` and
        ``dissimilarity`` (computed as ``100 * (1.0 - score)``). Items are
        ``SimilarityEntry`` objects when ``self.use_document_model`` is
        truthy, and plain dicts otherwise.

    Raises:
        MPRestError: If ``search`` returns nothing for the identifier, or
            if the match response contains no ``"data"`` entry.
    """
    if isinstance(structure_or_mpid, (str, MPID, AlphaID)):
        # Identifier path: reuse the feature vector already stored remotely.
        material_id = AlphaID(structure_or_mpid).string
        stored = self.search(material_ids=[material_id], fields=["feature_vector"])
        if not stored:
            raise MPRestError(f"No similarity data available for {material_id}")
        query_vector = stored[0]["feature_vector"]
    else:
        # Structure path: compute the feature vector locally.
        query_vector = self.fingerprinter(structure_or_mpid)

    # The match endpoint's payload is not a SimilarityDoc (it is closer to a
    # SimilarityEntry), so document-model parsing is disabled for this call.
    response = self._query_resource(
        criteria={"feature_vector": query_vector, "_limit": chunk_size},
        suburl="match",
        use_document_model=False,
        chunk_size=chunk_size,
        num_chunks=num_chunks,
    )
    matches = response.get("data", None)
    if matches is None:
        raise MPRestError(
            "Could not find any structures similar to the input structure."
        )

    sim_docs = []
    for match in matches:
        sim_docs.append(
            {
                "formula": match["formula_pretty"],
                "task_id": match["material_id"],
                "dissimilarity": 100 * (1.0 - match["score"]),
            }
        )

    if not self.use_document_model:
        return sim_docs
    return [SimilarityEntry(**fields) for fields in sim_docs]

0 commit comments

Comments
 (0)