33from spacy .vectors import Vectors
44from spacy .strings import StringStore
55from spacy .util import SimpleFrozenDict
6+ from thinc .api import NumpyOps
67import numpy
78import srsly
89
@@ -247,7 +248,11 @@ def get_other_senses(
247248 result = []
248249 key = key if isinstance (key , str ) else self .strings [key ]
249250 word , orig_sense = self .split_key (key )
250- versions = set ([word , word .lower (), word .upper (), word .title ()]) if ignore_case else [word ]
251+ versions = (
252+ set ([word , word .lower (), word .upper (), word .title ()])
253+ if ignore_case
254+ else [word ]
255+ )
251256 for text in versions :
252257 for sense in self .senses :
253258 new_key = self .make_key (text , sense )
@@ -270,7 +275,11 @@ def get_best_sense(
270275 sense_options = senses or self .senses
271276 if not sense_options :
272277 return None
273- versions = set ([word , word .lower (), word .upper (), word .title ()]) if ignore_case else [word ]
278+ versions = (
279+ set ([word , word .lower (), word .upper (), word .title ()])
280+ if ignore_case
281+ else [word ]
282+ )
274283 freqs = []
275284 for text in versions :
276285 for sense in sense_options :
@@ -304,6 +313,9 @@ def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
304313 """
305314 data = srsly .msgpack_loads (bytes_data )
306315 self .vectors = Vectors ().from_bytes (data ["vectors" ])
316+ # Pin vectors to the CPU so that we don't end up comparing
317+ # numpy and cupy arrays.
318+ self .vectors .to_ops (NumpyOps ())
307319 self .freqs = dict (data .get ("freqs" , []))
308320 self .cfg .update (data .get ("cfg" , {}))
309321 if "strings" not in exclude and "strings" in data :
@@ -340,6 +352,9 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
340352 freqs_path = path / "freqs.json"
341353 cache_path = path / "cache"
342354 self .vectors = Vectors ().from_disk (path )
355+ # Pin vectors to the CPU so that we don't end up comparing
356+ # numpy and cupy arrays.
357+ self .vectors .to_ops (NumpyOps ())
343358 self .cfg .update (srsly .read_json (path / "cfg" ))
344359 if freqs_path .exists ():
345360 self .freqs = dict (srsly .read_json (freqs_path ))
0 commit comments