Skip to content

Commit ed283a9

Browse files
committed
Merge branch 'master' into develop
2 parents 3605bb1 + 1058d97 commit ed283a9

File tree

7 files changed

+92
-61
lines changed

7 files changed

+92
-61
lines changed

CONTRIBUTING.rst

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -219,19 +219,19 @@ To achieve this, there are 5 points there are 5 points to follow:
219219
"""Defines Foo embedding technique."""
220220
221221
def fit(self, walks: List[List[SWalk]], is_updated: bool = False) -> Embedder:
222-
"""Fits the model based on provided walks.
222+
"""Fits the model based on provided walks.
223223
224-
Args:
225-
walks: The walks to create the corpus to to fit the model.
226-
is_update: True if the new walks should be added to old model's
227-
walks, False otherwise.
228-
Defaults to False.
224+
Args:
225+
walks: The walks to create the corpus to to fit the model.
226+
is_update: True if the new walks should be added to old model's
227+
walks, False otherwise.
228+
Defaults to False.
229229
230-
Returns:
231-
The fitted Foo model.
230+
Returns:
231+
The fitted Foo model.
232232
233-
"""
234-
# TODO: to be implemented
233+
"""
234+
# TODO: to be implemented
235235
236236
def transform(self, entities: Entities) -> Embeddings:
237237
"""The features vector of the provided entities.
@@ -244,14 +244,14 @@ To achieve this, there are 5 points there are 5 points to follow:
244244
Returns:
245245
The features vector of the provided entities.
246246
247-
"""
248-
# TODO: to be implemented
247+
"""
248+
# TODO: to be implemented
249249
250250
4. **Create unit tests of your embedding technique:**
251251

252252
Create a ``tests/embedders/foo.py`` file and see `how the tests are done for
253253
Word2Vec
254-
<https://github.com/IBCNServices/pyRDF2Vec/blob/master/tests/embedders/word2vec.py>`__
254+
<https://github.com/IBCNServices/pyRDF2Vec/blob/master/tests/embedders/test_word2vec.py>`__
255255
as an example.
256256

257257
Once this is done, run your tests:

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ embeddings and get literals from a given Knowledge Graph (KG) and entities:
100100
# ]
101101
102102
transformer = RDF2VecTransformer(
103-
Word2Vec(iter=10),
103+
Word2Vec(epochs=10),
104104
walkers=[RandomWalker(4, 10, with_reverse=False, n_jobs=2)],
105105
# verbose=1
106106
)

pyrdf2vec/connectors.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import requests
1212
from cachetools import Cache, TTLCache, cachedmethod
1313
from cachetools.keys import hashkey
14-
from requests.adapters import HTTPAdapter
15-
from requests.packages.urllib3.util import Retry
1614

1715
from pyrdf2vec.typings import Literal, Response
1816

@@ -24,8 +22,6 @@ class Connector(ABC):
2422
Attributes:
2523
_asession: The aiohttp session to use for asynchrone requests.
2624
Defaults to None.
27-
_session: The requests session to use for synchrone requests.
28-
Defaults to requests.Session.
2925
_headers: The HTTP headers to use.
3026
Defaults to {"Accept": "application/sparql-results+json"}.
3127
cache: The policy and size cache to use.
@@ -50,16 +46,10 @@ class Connector(ABC):
5046
init=False,
5147
type=Dict[str, str],
5248
repr=False,
53-
default={
54-
"Accept": "application/sparql-results+json",
55-
},
49+
default={"Accept": "application/sparql-results+json"},
5650
)
5751

5852
_asession = attr.ib(init=False, default=None)
59-
_session = attr.ib(
60-
init=False,
61-
factory=lambda: requests.Session(),
62-
)
6353

6454
async def close(self) -> None:
6555
"""Closes the aiohttp session."""
@@ -73,7 +63,7 @@ def fetch(self, query: str):
7363
query: The query to fetch the result
7464
7565
Returns:
76-
The generated dictionary from the ['results']['bindings'] json.
66+
The JSON response.
7767
7868
Raises:
7969
NotImplementedError: If this method is called, without having
@@ -90,8 +80,6 @@ class SPARQLConnector(Connector):
9080
Attributes:
9181
_asession: The aiohttp session to use for asynchrone requests.
9282
Defaults to None.
93-
_session: The requests session to use for synchrone requests.
94-
Defaults to requests.Session.
9583
_headers: The HTTP headers to use.
9684
Defaults to {"Accept": "application/sparql-results+json"}.
9785
cache: The policy and size cache to use.
@@ -100,17 +88,6 @@ class SPARQLConnector(Connector):
10088
10189
"""
10290

103-
def __attrs_post_init__(self):
104-
adapter = HTTPAdapter(
105-
Retry(
106-
total=3,
107-
status_forcelist=[429, 500, 502, 503, 504],
108-
method_whitelist=["HEAD", "GET", "OPTIONS"],
109-
)
110-
)
111-
self._session.mount("http", adapter)
112-
self._session.mount("https", adapter)
113-
11491
async def afetch(self, queries: List[str]) -> List[List[Response]]:
11592
"""Fetchs the result of SPARQL queries asynchronously.
11693
@@ -141,8 +118,7 @@ async def _fetch(self, query) -> Response:
141118
"""
142119
url = f"{self.endpoint}/query?query={parse.quote(query)}"
143120
async with self._asession.get(url, headers=self._headers) as res:
144-
res = await res.json()
145-
return res["results"]["bindings"]
121+
return await res.json()
146122

147123
@cachedmethod(operator.attrgetter("cache"), key=partial(hashkey, "fetch"))
148124
def fetch(self, query: str) -> Response:
@@ -156,8 +132,8 @@ def fetch(self, query: str) -> Response:
156132
157133
"""
158134
url = f"{self.endpoint}/query?query={parse.quote(query)}"
159-
res = self._session.get(url, headers=self._headers).json()
160-
return res["results"]["bindings"]
135+
with requests.get(url, headers=self._headers) as res:
136+
return res.json()
161137

162138
def get_query(self, entity: str, preds: Optional[List[str]] = None) -> str:
163139
"""Gets the SPARQL query for an entity.

pyrdf2vec/graphs/kg.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ class KG:
5353
Defaults to False.
5454
skip_predicates: The label predicates to skip from the KG.
5555
Defaults to set.
56+
skip_verify: To skip or not the verification of existing entities in a
57+
Knowledge Graph. Its deactivation can improve HTTP latency for KG
58+
remotes.
59+
Defaults to False.
5660
5761
"""
5862

@@ -95,6 +99,13 @@ class KG:
9599
validator=attr.validators.instance_of(bool),
96100
)
97101

102+
skip_verify = attr.ib(
103+
kw_only=True,
104+
type=bool,
105+
default=False,
106+
validator=attr.validators.instance_of(bool),
107+
)
108+
98109
cache = attr.ib(
99110
kw_only=True,
100111
type=Cache,
@@ -206,9 +217,6 @@ def add_walk(self, subj: Vertex, pred: Vertex, obj: Vertex) -> bool:
206217
return True
207218
return False
208219

209-
@cachedmethod(
210-
operator.attrgetter("cache"), key=partial(hashkey, "fetch_hops")
211-
)
212220
def fetch_hops(self, vertex: Vertex) -> List[Hop]:
213221
"""Fetchs the hops of the vertex from a SPARQL endpoint server and
214222
add the hops for this vertex in a cache dictionary.
@@ -229,7 +237,7 @@ def fetch_hops(self, vertex: Vertex) -> List[Hop]:
229237
"https://"
230238
):
231239
res = self.connector.fetch(self.connector.get_query(vertex.name))
232-
hops = self._res2hops(vertex, res)
240+
hops = self._res2hops(vertex, res["results"]["bindings"])
233241
return hops
234242

235243
def get_hops(self, vertex: Vertex, is_reverse: bool = False) -> List[Hop]:
@@ -283,7 +291,10 @@ def get_literals(self, entities: Entities, verbose: int = 0) -> Literals:
283291
responses = [self.connector.fetch(query) for query in queries]
284292

285293
literals_responses = [
286-
self.connector.res2literals(res) for res in responses
294+
self.connector.res2literals(
295+
res["results"]["bindings"] # type: ignore
296+
)
297+
for res in responses
287298
]
288299
return [
289300
literals_responses[
@@ -340,6 +351,31 @@ def get_pliterals(self, entity: str, preds: List[str]) -> List[str]:
340351
frontier = new_frontier
341352
return list(frontier)
342353

354+
def is_exist(self, entities: Entities) -> bool:
355+
"""Checks that all provided entities exists in the Knowledge Graph.
356+
357+
Args:
358+
entities: The entities to check the existence
359+
360+
Returns:
361+
True if all the entities exists, False otherwise.
362+
363+
"""
364+
if self._is_remote:
365+
queries = [
366+
f"ASK WHERE {{ <{entity}> ?p ?o . }}" for entity in entities
367+
]
368+
if self.mul_req:
369+
responses = [
370+
res["boolean"] # type: ignore
371+
for res in asyncio.run(self.connector.afetch(queries))
372+
]
373+
else:
374+
responses = [self.connector.fetch(query) for query in queries]
375+
responses = [res["boolean"] for res in responses]
376+
return False not in responses
377+
return all([Vertex(entity) in self._vertices for entity in entities])
378+
343379
def remove_edge(self, v1: Vertex, v2: Vertex) -> bool:
344380
"""Removes the edge (v1 -> v2) if present.
345381
@@ -403,7 +439,9 @@ def _fill_hops(self, entities: Entities) -> None:
403439
entities,
404440
asyncio.run(self.connector.afetch(queries)),
405441
):
406-
hops = self._res2hops(Vertex(entity), res)
442+
hops = self._res2hops(
443+
Vertex(entity), res["results"]["bindings"] # type: ignore
444+
)
407445
self._entity_hops.update({entity: hops})
408446

409447
@cachedmethod(
@@ -418,8 +456,8 @@ def _get_hops(self, vertex: Vertex, is_reverse: bool = False) -> List[Hop]:
418456
vertex. Otherwise, get the child nodes for this vertex.
419457
Defaults to False.
420458
421-
Returns:
422-
The hops of a vertex in a (predicate, object) form.
459+
Returns:
460+
The hops of a vertex in a (predicate, object) form.
423461
424462
"""
425463
matrix = self._transition_matrix
@@ -452,6 +490,6 @@ def _res2hops(self, vertex: Vertex, res) -> List[Hop]:
452490
vprev=vertex,
453491
vnext=obj,
454492
)
455-
if self.add_walk(vertex, pred, obj):
493+
if pred.name not in self.skip_predicates:
456494
hops.append((pred, obj))
457495
return hops

pyrdf2vec/rdf2vec.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import attr
99

1010
from pyrdf2vec.embedders import Embedder, Word2Vec
11-
from pyrdf2vec.graphs import KG, Vertex
11+
from pyrdf2vec.graphs import KG
1212
from pyrdf2vec.typings import Embeddings, Entities, Literals, SWalk
1313
from pyrdf2vec.walkers import RandomWalker, Walker
1414

@@ -160,15 +160,15 @@ def get_walks(self, kg: KG, entities: Entities) -> List[List[SWalk]]:
160160
ValueError: If the provided entities aren't in the Knowledge Graph.
161161
162162
"""
163-
if not kg._is_remote and not all(
164-
[Vertex(entity) in kg._vertices for entity in entities]
165-
):
166-
raise ValueError(
167-
"The provided entities must be in the Knowledge Graph."
168-
)
169-
170163
# Avoids duplicate entities for unnecessary walk extractions.
171164
entities = list(set(entities))
165+
if kg.skip_verify is False and not kg.is_exist(entities):
166+
if kg.mul_req:
167+
asyncio.run(kg.connector.close())
168+
raise ValueError(
169+
"At least one provided entity does not exist in the "
170+
+ "Knowledge Graph."
171+
)
172172

173173
if self.verbose == 2:
174174
print(kg)

pyrdf2vec/utils/validation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def is_valid_url(url: str) -> bool:
9191
9292
"""
9393
try:
94-
return requests.get(url).status_code == requests.codes.ok
94+
return (
95+
requests.head(url, headers={"Accept": "text/html"}).status_code
96+
== requests.codes.ok
97+
)
9598
except Exception:
9699
return False

tests/test_graph.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,20 @@ def test_invalid_url(self):
219219
with pytest.raises(ValueError):
220220
KG("http://foo")
221221

222+
def test_is_exist(self, setup):
223+
assert LOCAL_KG.is_exist([f"{URL}#Alice", "foo"]) is False
224+
assert (
225+
LOCAL_KG.is_exist(
226+
[
227+
f"{URL}#Alice",
228+
f"{URL}#Bob",
229+
f"{URL}#Casper",
230+
f"{URL}#Dean",
231+
]
232+
)
233+
is True
234+
)
235+
222236
def test_remove_edge(self, setup):
223237
vtx_alice = Vertex(f"{URL}#Alice")
224238

0 commit comments

Comments
 (0)