Commit 76a7451

refactor(scripts): update main function signatures to improve argument handling
1 parent 7477964 commit 76a7451

File tree

poligrapher/scripts/html_crawler.py
poligrapher/scripts/init_document.py
poligrapher/scripts/pdf_parser.py
poligrapher/scripts/run_annotators.py

4 files changed: +95 -58 lines changed

poligrapher/scripts/html_crawler.py
Lines changed: 40 additions & 19 deletions

@@ -17,7 +17,9 @@
 from requests_cache import CachedSession

 READABILITY_JS_COMMIT = "8e8ec27cd2013940bc6f3cc609de10e35a1d9d86"
-READABILITY_JS_URL = f"https://raw.githubusercontent.com/mozilla/readability/{READABILITY_JS_COMMIT}"
+READABILITY_JS_URL = (
+    f"https://raw.githubusercontent.com/mozilla/readability/{READABILITY_JS_COMMIT}"
+)
 REQUESTS_TIMEOUT = 10


@@ -30,7 +32,9 @@ def get_readability_js():
     res.raise_for_status()
     js_code.append(res.text)

-    res = session.get(f"{READABILITY_JS_URL}/Readability-readerable.js", timeout=REQUESTS_TIMEOUT)
+    res = session.get(
+        f"{READABILITY_JS_URL}/Readability-readerable.js", timeout=REQUESTS_TIMEOUT
+    )
     res.raise_for_status()
     js_code.append(res.text)

@@ -51,9 +55,15 @@ def url_arg_handler(url):
         return parsed_path.as_uri()

     # Handle Google Docs URLs
-    if (parsed_url.hostname == "docs.google.com"
-            and not parsed_url.path.endswith("/pub")
-            and (m := re.match(r"/document/d/(1[a-zA-Z0-9_-]{42}[AEIMQUYcgkosw048])", parsed_url.path))):
+    if (
+        parsed_url.hostname == "docs.google.com"
+        and not parsed_url.path.endswith("/pub")
+        and (
+            m := re.match(
+                r"/document/d/(1[a-zA-Z0-9_-]{42}[AEIMQUYcgkosw048])", parsed_url.path
+            )
+        )
+    ):
         logging.info("Exporting HTML from Google Docs URL...")

         export_url = f"https://docs.google.com/feeds/download/documents/export/Export?id={m[1]}&exportFormat=html"
@@ -78,15 +88,12 @@ def url_arg_handler(url):
     return url


-def main():
-    logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s', level=logging.INFO)
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("url", help="Input URL or path")
-    parser.add_argument("output", help="Output dir")
-    parser.add_argument("--no-readability-js", action="store_true", help="Disable readability.js")
-    args = parser.parse_args()
+def main(url, output):
+    logging.basicConfig(
+        format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO
+    )

+    args = argparse.Namespace(url=url, output=output)
     access_url = url_arg_handler(args.url)

     if access_url is None:
@@ -129,7 +136,10 @@ def error_cleanup(msg):
     url_status = dict()
     navigated_urls = []
     page.on("response", lambda r: url_status.update({r.url: r.status}))
-    page.on("framenavigated", lambda f: f.parent_frame is None and navigated_urls.append(f.url))
+    page.on(
+        "framenavigated",
+        lambda f: f.parent_frame is None and navigated_urls.append(f.url),
+    )

     page.goto(access_url)

@@ -146,7 +156,8 @@ def error_cleanup(msg):
     # Apply readability.js
     page.evaluate("window.stop()")
     page.add_script_tag(content=get_readability_js())
-    readability_info = page.evaluate(r"""(no_readability_js) => {
+    readability_info = page.evaluate(
+        r"""(no_readability_js) => {
         window.stop();

         const documentClone = document.cloneNode(true);
@@ -168,11 +179,13 @@ def error_cleanup(msg):
             elem.remove();

         return article;
-    }""", [args.no_readability_js])
+    }""",
+        [args.no_readability_js],
+    )
     cleaned_html = page.content()

     # Check language
-    soup = bs4.BeautifulSoup(cleaned_html, 'lxml')
+    soup = bs4.BeautifulSoup(cleaned_html, "lxml")
     soup_text = soup.body.text if soup.body else ""

     try:
@@ -192,7 +205,9 @@ def error_cleanup(msg):
     output_dir = Path(args.output)
     output_dir.mkdir(exist_ok=True)

-    with open(output_dir / "accessibility_tree.json", "w", encoding="utf-8") as fout:
+    with open(
+        output_dir / "accessibility_tree.json", "w", encoding="utf-8"
+    ) as fout:
         json.dump(snapshot, fout)

     with open(output_dir / "cleaned.html", "w", encoding="utf-8") as fout:
@@ -207,4 +222,10 @@ def error_cleanup(msg):


 if __name__ == "__main__":
-    main()
+    # fallback to original CLI behavior
+    import sys
+
+    if len(sys.argv) != 3:
+        print("usage: html_crawler.py <url_or_path> <output_dir>")
+        sys.exit(1)
+    main(sys.argv[1], sys.argv[2])
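
With the new signature, html_crawler can be driven from Python as well as from the shell. A minimal sketch of the new calling convention, assuming the package is importable; the URL and output directory are illustrative placeholders:

    from poligrapher.scripts import html_crawler

    # Equivalent to the CLI fallback above:
    #   python html_crawler.py <url_or_path> <output_dir>
    html_crawler.main("https://example.com/privacy-policy", "output/example")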

poligrapher/scripts/init_document.py
Lines changed: 23 additions & 15 deletions

@@ -18,27 +18,21 @@
 from poligrapher.utils import setup_nlp_pipeline


-def main():
-    logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s', level=logging.INFO)
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("workdirs", nargs="+", help="Input directories")
-    parser.add_argument("--nlp", default="", help="NLP model directory")
-    parser.add_argument("--debug", action="store_true", help="Show NER results")
-    parser.add_argument("--gpu-memory-threshold", default=0.9, type=float,
-                        help="Max GPU usage to trigger manual cache cleaning")
-    args = parser.parse_args()
+def main(workdirs, nlp_model_dir="", debug=False, gpu_memory_threshold=0.9):
+    logging.basicConfig(
+        format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO
+    )

     use_gpu = spacy.prefer_gpu()
-    nlp = setup_nlp_pipeline(args.nlp)
+    nlp = setup_nlp_pipeline(nlp_model_dir)

-    for d in args.workdirs:
+    for d in workdirs:
         logging.info("Processing %s ...", d)

         document = PolicyDocument.initialize(d, nlp=nlp)
         document.save()

-        if args.debug:
+        if debug:
             with open(os.path.join(d, "document.txt"), "w", encoding="utf-8") as fout:
                 fout.write(document.print_tree())

@@ -47,10 +41,24 @@ def main():
             gmem_total = torch.cuda.get_device_properties(current_device).total_memory
             gmem_reserved = torch.cuda.memory_reserved(current_device)

-            if gmem_reserved / gmem_total > args.gpu_memory_threshold:
+            if gmem_reserved / gmem_total > gpu_memory_threshold:
                 logging.warning("Empty GPU cache...")
                 torch.cuda.empty_cache()


 if __name__ == "__main__":
-    main()
+    import sys
+
+    if len(sys.argv) < 2:
+        print(
+            "usage: init_document.py <workdir1> [<workdir2> ...] [--nlp MODEL_DIR] [--debug] [--gpu-memory-threshold FLOAT]"
+        )
+        sys.exit(1)
+    # parse sys.argv manually or via argparse then call:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("workdirs", nargs="+")
+    parser.add_argument("--nlp", default="")
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--gpu-memory-threshold", default=0.9, type=float)
+    args = parser.parse_args()
+    main(args.workdirs, args.nlp, args.debug, args.gpu_memory_threshold)
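
Here the keyword defaults mirror the removed argparse flags. A hedged sketch of a programmatic call, assuming a workdir already produced by the crawler (paths are placeholders):

    from poligrapher.scripts import init_document

    init_document.main(
        ["output/example"],         # workdirs: one or more input directories
        nlp_model_dir="",           # "" selects the default NLP pipeline
        debug=True,                 # like --debug: also writes document.txt
        gpu_memory_threshold=0.9,   # empty the CUDA cache above 90% usage
    )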

poligrapher/scripts/pdf_parser.py
Lines changed: 9 additions & 6 deletions

@@ -138,11 +138,8 @@ def url_arg_handler(url, args):
     return exported


-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("url", help="Input URL or path")
-    parser.add_argument("output", help="Output dir")
-    args = parser.parse_args()
+def main(url, output):
+    args = argparse.Namespace(url=url, output=output)

     pdf_source = url_arg_handler(args.url, args)
     if pdf_source is None:
@@ -165,4 +162,10 @@ def main():


 if __name__ == "__main__":
-    main()
+    # fallback to original CLI behavior
+    import sys
+
+    if len(sys.argv) != 3:
+        print("usage: pdf_parser.py <url_or_path> <output_dir>")
+        sys.exit(1)
+    main(sys.argv[1], sys.argv[2])
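
The same pattern applies: positional url and output parameters replace the parsed CLI options. An illustrative call with placeholder paths:

    from poligrapher.scripts import pdf_parser

    # Equivalent to the CLI fallback above:
    #   python pdf_parser.py <url_or_path> <output_dir>
    pdf_parser.main("https://example.com/privacy-policy.pdf", "output/example-pdf")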

poligrapher/scripts/run_annotators.py
Lines changed: 23 additions & 18 deletions

@@ -15,30 +15,28 @@
 from poligrapher.utils import setup_nlp_pipeline


-def main():
-    logging.basicConfig(format='%(asctime)s [%(levelname)s] <%(name)s> %(message)s', level=logging.INFO)
+def main(workdirs, nlp_model_dir="", disable=""):
+    logging.basicConfig(
+        format="%(asctime)s [%(levelname)s] <%(name)s> %(message)s", level=logging.INFO
+    )

-    parser = argparse.ArgumentParser()
-    parser.add_argument("--nlp", default="", help="NLP model directory")
-    parser.add_argument("--disable", default="", help="Disable annotators for ablation study")
-    parser.add_argument("workdirs", nargs="+", help="Input directories")
-    args = parser.parse_args()
-
-    nlp = setup_nlp_pipeline(args.nlp)
+    nlp = setup_nlp_pipeline(nlp_model_dir)

-    disabled_annotators = frozenset(args.disable.split(","))
+    disabled_annotators = frozenset(disable.split(",")) if disable else frozenset()
     annotators = []

-    for annotator_class in (SubsumptionAnnotator,
-                            CoreferenceAnnotator,
-                            CollectionAnnotator,
-                            PurposeAnnotator,
-                            ListAnnotator,
-                            SubjectAnnotator):
+    for annotator_class in (
+        SubsumptionAnnotator,
+        CoreferenceAnnotator,
+        CollectionAnnotator,
+        PurposeAnnotator,
+        ListAnnotator,
+        SubjectAnnotator,
+    ):
         if annotator_class.__name__ not in disabled_annotators:
             annotators.append(annotator_class(nlp))

-    for d in args.workdirs:
+    for d in workdirs:
         logging.info("Processing %s ...", d)

         document = PolicyDocument.load(d, nlp)
@@ -51,4 +49,11 @@ def main():


 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--nlp", default="", help="NLP model directory")
+    parser.add_argument(
+        "--disable", default="", help="Disable annotators for ablation study"
+    )
+    parser.add_argument("workdirs", nargs="+", help="Input directories")
+    args = parser.parse_args()
+    main(args.workdirs, args.nlp, args.disable)
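
Since every script now exposes a main that takes its arguments directly, the stages can be chained in-process instead of through separate CLI invocations. A hypothetical driver sketch, assuming the usual crawl, init, annotate order and placeholder paths:

    from poligrapher.scripts import html_crawler, init_document, run_annotators

    workdir = "output/example"
    html_crawler.main("https://example.com/privacy-policy", workdir)
    init_document.main([workdir])
    # disable takes a comma-separated list of annotator class names,
    # e.g. "ListAnnotator,PurposeAnnotator", for ablation runs.
    run_annotators.main([workdir], disable="")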
