Skip to content

docs: add docstrings to utility and helper functions #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions gpt_oss/tools/simple_browser/page_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class Tokens:


def get_domain(url: str) -> str:
"""Extracts the domain from a URL."""
if "http" not in url:
# If `get_domain` is called on a domain, add a scheme so that the
# original domain is returned instead of the empty string.
Expand All @@ -72,12 +73,14 @@ def get_domain(url: str) -> str:


def multiple_replace(text: str, replacements: dict[str, str]) -> str:
"""Performs multiple string replacements using regex pass."""
regex = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
return regex.sub(lambda mo: replacements[mo.group(1)], text)


@functools.lru_cache(maxsize=1024)
def mark_lines(text: str) -> str:
"""Adds line numbers (ex: 'L0:') to the beginning of each line in a string."""
# Split the string by newline characters
lines = text.split("\n")

Expand All @@ -88,16 +91,19 @@ def mark_lines(text: str) -> str:

@functools.cache
def _tiktoken_vocabulary_lengths(enc_name: str) -> list[int]:
"""Gets the character lengths of all tokens in the specified TikToken vocabulary."""
encoding = tiktoken.get_encoding(enc_name)
return [len(encoding.decode([i])) for i in range(encoding.n_vocab)]


def warmup_caches(enc_names: list[str]) -> None:
"""Warm up the cache by computing token length lists for the given TikToken encodings."""
for _ in map(_tiktoken_vocabulary_lengths, enc_names):
pass


def _replace_special_chars(text: str) -> str:
"""Replaces specific special characters with visually similar alternatives."""
replacements = {
"【": "〖",
"】": "〗",
Expand All @@ -110,16 +116,19 @@ def _replace_special_chars(text: str) -> str:


def merge_whitespace(text: str) -> str:
"""Replace newlines with spaces and merge consecutive whitespace into a single space."""
text = text.replace("\n", " ")
text = re.sub(r"\s+", " ", text)
return text


def arxiv_to_ar5iv(url: str) -> str:
"""Converts an arxiv.org URL to its ar5iv.org equivalent."""
return re.sub(r"arxiv.org", r"ar5iv.org", url)


def _clean_links(root: lxml.html.HtmlElement, cur_url: str) -> dict[str, str]:
"""Processes all anchor tags in the HTML, replaces them with a custom format and returns an ID-to-URL mapping."""
cur_domain = get_domain(cur_url)
urls: dict[str, str] = {}
urls_rev: dict[str, str] = {}
Expand Down Expand Up @@ -156,10 +165,12 @@ def _clean_links(root: lxml.html.HtmlElement, cur_url: str) -> dict[str, str]:


def _get_text(node: lxml.html.HtmlElement) -> str:
"""Extracts all text from an HTML element and merges it into a whitespace-normalized string."""
return merge_whitespace(" ".join(node.itertext()))


def _remove_node(node: lxml.html.HtmlElement) -> None:
"""Removes a node from its parent in the lxml tree."""
node.getparent().remove(node)


Expand All @@ -172,6 +183,7 @@ def _escape_md_section(text: str, snob: bool = False) -> str:


def html_to_text(html: str) -> str:
"""Converts an HTML string to clean plaintext."""
html = re.sub(HTML_SUP_RE, r"^{\2}", html)
html = re.sub(HTML_SUB_RE, r"_{\2}", html)
# add spaces between tags such as table cells
Expand All @@ -195,6 +207,7 @@ def html_to_text(html: str) -> str:


def _remove_math(root: lxml.html.HtmlElement) -> None:
"""Removes all <math> elements from the lxml tree."""
for node in root.findall(".//math"):
_remove_node(node)

Expand All @@ -209,6 +222,7 @@ def remove_unicode_smp(text: str) -> str:


def replace_node_with_text(node: lxml.html.HtmlElement, text: str) -> None:
"""Replaces an lxml node with a text string while preserving surrounding text."""
previous = node.getprevious()
parent = node.getparent()
tail = node.tail or ""
Expand All @@ -224,6 +238,7 @@ def replace_images(
base_url: str,
session: aiohttp.ClientSession | None,
) -> None:
"""Finds all image tags and replaces them with numbered placeholders (includes alt/title if available)."""
cnt = 0
for img_tag in root.findall(".//img"):
image_name = img_tag.get("alt", img_tag.get("title"))
Expand Down