Skip to content

Commit 654615c

Browse files
author
Gerit Wagner
committed
update SearchFile
1 parent da55e69 commit 654615c

File tree

4 files changed

+80
-65
lines changed

4 files changed

+80
-65
lines changed

search_query/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,17 @@
66
from search_query.query import Query
77
from search_query.or_query import OrQuery
88
from search_query.and_query import AndQuery
9+
from search_query.search_file import SearchFile, load_search_file
910
from .__version__ import __version__
1011

11-
__all__ = ["__version__", "Query", "OrQuery", "AndQuery"]
12+
__all__ = [
13+
"__version__",
14+
"Query",
15+
"OrQuery",
16+
"AndQuery",
17+
"SearchFile",
18+
"load_search_file",
19+
]
1220

1321
# Instead of adding elements to __all__,
1422
# prefixing methods/variables with "__" is preferred.

search_query/linter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def pre_commit_hook() -> int:
2828
file_path = sys.argv[1]
2929

3030
try:
31-
search_file = SearchFile(file_path)
31+
search_file = SearchFile(file_path, platform="unknown")
3232
platform = search_query.parser.get_platform(search_file.platform)
3333
except Exception as e: # pylint: disable=broad-except
3434
print(e)

search_query/search_file.py

Lines changed: 68 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,72 +4,62 @@
44

55
import json
66
import re
7-
import typing
7+
from pathlib import Path
8+
from typing import Optional
89

910
# pylint: disable=too-few-public-methods
1011

1112

1213
class SearchFile:
13-
"""SearchFile model."""
14-
15-
record_info: typing.Dict[str, str]
16-
authors: typing.List[dict]
17-
date: dict
18-
platform: str
19-
database: typing.List[str]
20-
search_string: str
21-
22-
# Optionals
23-
parsed: typing.Optional[dict] = None
24-
string_name: typing.Optional[str] = None
25-
keywords: typing.Optional[str] = None
26-
related_records: typing.Optional[str] = None
27-
parent_record: typing.Optional[str] = None
28-
database_time_coverage: typing.Optional[str] = None
29-
search_language: typing.Optional[str] = None
30-
settings: typing.Optional[str] = None
31-
quality_assurance: typing.Optional[str] = None
32-
validation_report: typing.Optional[str] = None
33-
peer_review: typing.Optional[str] = None
34-
description: typing.Optional[str] = None
35-
review_question: typing.Optional[str] = None
36-
review_type: typing.Optional[str] = None
37-
linked_protocol: typing.Optional[str] = None
38-
linked_report: typing.Optional[str] = None
39-
40-
def __init__(self, filepath: str) -> None:
41-
with open(filepath, encoding="utf-8") as file:
42-
data = json.load(file)
43-
44-
self._validate(data)
45-
46-
self.record_info = data["record_info"]
47-
self.authors = data["authors"]
48-
self.date = data["date"]
49-
self.platform = data["platform"]
50-
self.database = data["database"]
51-
self.search_string = data["search_string"]
52-
if "parsed" in data:
53-
self.parsed = data["parsed"]
54-
55-
def _validate(self, data: dict) -> None:
56-
# Note: validate without pydantic to keep zero dependencies
57-
58-
if not isinstance(data, dict):
59-
raise TypeError("Data must be a dictionary.")
60-
61-
self._validate_authors(data)
62-
63-
if "record_info" not in data:
64-
raise ValueError("Data must have a 'record_info' key.")
65-
if "date" not in data:
66-
raise ValueError("Data must have a 'date' key.")
67-
if "platform" not in data:
68-
raise ValueError("Data must have a 'platform' key.")
69-
if "database" not in data:
70-
raise ValueError("Data must have a 'database' key.")
71-
if "search_string" not in data:
72-
raise ValueError("Data must have a 'search_string' key.")
14+
"""SearchFile class."""
15+
16+
# pylint: disable=too-many-arguments
17+
def __init__(
18+
self,
19+
search_string: str,
20+
platform: str,
21+
authors: Optional[list[dict]] = None,
22+
record_info: Optional[dict] = None,
23+
date: Optional[dict] = None,
24+
filepath: Optional[str | Path] = None,
25+
**kwargs: dict,
26+
) -> None:
27+
self.search_string = search_string
28+
self.platform = platform
29+
self.authors = authors or []
30+
self.record_info = record_info or {}
31+
self.date = date or {}
32+
self._filepath = Path(filepath) if filepath else None
33+
34+
for key, value in kwargs.items():
35+
setattr(self, key, value)
36+
37+
self._validate_authors(self.to_dict())
38+
39+
def save(self, filepath: Optional[str | Path] = None) -> None:
40+
"""Save the search file to a JSON file."""
41+
path = Path(filepath) if filepath else self._filepath
42+
if path is None:
43+
raise ValueError("No filepath provided and no previous filepath stored.")
44+
with open(path, "w", encoding="utf-8") as f:
45+
json.dump(self.to_dict(), f, indent=4, ensure_ascii=False)
46+
47+
def to_dict(self) -> dict:
48+
"""Convert the search file to a dictionary."""
49+
data = {
50+
"search_string": self.search_string,
51+
"platform": self.platform,
52+
"authors": self.authors,
53+
"record_info": self.record_info,
54+
"date": self.date,
55+
}
56+
extras = {
57+
k: v
58+
for k, v in self.__dict__.items()
59+
if k not in data and not k.startswith("_") and v is not None
60+
}
61+
data.update(extras)
62+
return data
7363

7464
def _validate_authors(self, data: dict) -> None:
7565
if "authors" not in data:
@@ -95,3 +85,20 @@ def _validate_authors(self, data: dict) -> None:
9585
raise TypeError("Email must be a string.")
9686
if not re.match(r"^\S+@\S+\.\S+$", author["email"]):
9787
raise ValueError("Invalid email.")
88+
89+
90+
def load_search_file(filepath: str | Path) -> SearchFile:
91+
"""Load a search file from a JSON file."""
92+
path = Path(filepath)
93+
with open(path, encoding="utf-8") as f:
94+
data = json.load(f)
95+
96+
if "search_string" not in data or "platform" not in data:
97+
raise ValueError("File must contain at least 'search_string' and 'platform'.")
98+
99+
return SearchFile(
100+
search_string=data.pop("search_string"),
101+
platform=data.pop("platform"),
102+
filepath=path,
103+
**data,
104+
)

test/test_search_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
"""Tests for SearchFile parser."""
33
from __future__ import annotations
44

5-
from search_query.search_file import SearchFile
5+
from search_query.search_file import load_search_file
66

77

88
def test_search_history_file_parser() -> None:
99
"""Test SearchFile parser."""
1010

1111
file_path = "test/search_history_file_1.json"
1212

13-
result = SearchFile(file_path)
13+
result = load_search_file(file_path)
1414

1515
assert hasattr(result, "parsed")

0 commit comments

Comments
 (0)