Skip to content

Commit 2e8d9f8

Browse files
revise similarity search to reduce to hex and norm
1 parent 78355e1 commit 2e8d9f8

File tree

6 files changed

+16
-14
lines changed

6 files changed

+16
-14
lines changed

mp_api/client/routes/materials/similarity.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
from __future__ import annotations
22

3-
import zlib
43
from typing import TYPE_CHECKING
54

6-
import numpy as np
75
from emmet.core.mpid import MPID, AlphaID
8-
from emmet.core.similarity import CrystalNNSimilarity, SimilarityDoc, SimilarityEntry
6+
from emmet.core.similarity import (
7+
CrystalNNSimilarity,
8+
SimilarityDoc,
9+
SimilarityEntry,
10+
_vector_to_hex_and_norm,
11+
)
912
from pymatgen.core import Composition, Structure
1013

1114
from mp_api.client.core import BaseRester, MPRestError
1215
from mp_api.client.core.utils import validate_ids
1316

1417
if TYPE_CHECKING:
18+
import numpy as np
1519
from emmet.core.similarity import SimilarityScorer
1620

1721
# This limit seems to be associated with MongoDB vector search
@@ -31,10 +35,6 @@ def fingerprint_structure(self, structure: Structure) -> np.ndarray:
3135
self._fingerprinter = CrystalNNSimilarity()
3236
return self._fingerprinter._featurize_structure(structure)
3337

34-
def _get_hex_fingerprint(self, feature_vetor: np.ndarray) -> str:
35-
"""Convert feature vector fingerprint to compressed hex str."""
36-
return zlib.compress(feature_vetor.tobytes()).hex()
37-
3838
def search(
3939
self,
4040
material_ids: str | list[str] | None = None,
@@ -126,9 +126,11 @@ def find_similar(
126126
"Please specify a positive integer or `None` to return all results."
127127
)
128128

129+
vector_hex, vector_norm = _vector_to_hex_and_norm(feature_vector)
129130
result = self._query_resource(
130131
criteria={
131-
"feature_vector_hex": self._get_hex_fingerprint(feature_vector),
132+
"feature_vector_hex": vector_hex,
133+
"feature_vector_norm": vector_norm,
132134
"_limit": top,
133135
},
134136
suburl="match",

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ dependencies = [
2525
"typing-extensions>=3.7.4.1",
2626
"requests>=2.23.0",
2727
"monty>=2024.12.10",
28-
"emmet-core>=0.85.1rc0",
28+
"emmet-core>=0.86.2rc1",
2929
"smart_open",
3030
"boto3",
3131
"orjson >= 3.10,<4",
3232
]
3333
dynamic = ["version"]
3434

3535
[project.optional-dependencies]
36-
all = ["emmet-core[all]>=0.85.1rc0", "custodian", "mpcontribs-client>=5.10"]
36+
all = ["emmet-core[all]>=0.86.2rc1", "custodian", "mpcontribs-client>=5.10"]
3737
test = [
3838
"pre-commit",
3939
"pytest",

requirements/requirements-ubuntu-latest_py3.11.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ contourpy==1.3.3
2424
# via matplotlib
2525
cycler==0.12.1
2626
# via matplotlib
27-
emmet-core==0.86.2rc0
27+
emmet-core==0.86.2rc1
2828
# via mp-api (pyproject.toml)
2929
fonttools==4.60.1
3030
# via matplotlib

requirements/requirements-ubuntu-latest_py3.11_extras.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ dnspython==2.8.0
6262
# pymongo
6363
docutils==0.21.2
6464
# via sphinx
65-
emmet-core[all]==0.86.2rc0
65+
emmet-core[all]==0.86.2rc1
6666
# via mp-api (pyproject.toml)
6767
execnet==2.1.1
6868
# via pytest-xdist

requirements/requirements-ubuntu-latest_py3.12.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ contourpy==1.3.3
2424
# via matplotlib
2525
cycler==0.12.1
2626
# via matplotlib
27-
emmet-core==0.86.2rc0
27+
emmet-core==0.86.2rc1
2828
# via mp-api (pyproject.toml)
2929
fonttools==4.60.1
3030
# via matplotlib

requirements/requirements-ubuntu-latest_py3.12_extras.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ dnspython==2.8.0
6262
# pymongo
6363
docutils==0.21.2
6464
# via sphinx
65-
emmet-core[all]==0.86.2rc0
65+
emmet-core[all]==0.86.2rc1
6666
# via mp-api (pyproject.toml)
6767
execnet==2.1.1
6868
# via pytest-xdist

0 commit comments

Comments
 (0)