Skip to content

Commit d1ea0e0

Browse files
committed
Allow attr values to be HTML
1 parent c2cd993 commit d1ea0e0

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

htmltools/_core.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class MetadataNode:
9797

9898
TagT = TypeVar("TagT", bound="Tag")
9999

100-
TagAttrValue = Union[str, float, bool, None]
100+
TagAttrValue = Union[str, float, bool, "HTML", None]
101101
"""
102102
Types that can be passed in as attributes to `Tag` functions. These values will be
103103
converted to strings before being stored as tag attributes.
@@ -225,7 +225,7 @@ class Tagifiable(Protocol):
225225
returns a `TagList`, the children of the `TagList` must also be tagified.
226226
"""
227227

228-
def tagify(self) -> "TagList | Tag | MetadataNode | str": ...
228+
def tagify(self) -> "TagList | Tag | MetadataNode | str | HTML": ...
229229

230230

231231
@runtime_checkable
@@ -486,7 +486,7 @@ def _repr_html_(self) -> str:
486486
# =============================================================================
487487
# TagAttrDict class
488488
# =============================================================================
489-
class TagAttrDict(Dict[str, str]):
489+
class TagAttrDict(Dict[str, "str | HTML"]):
490490
"""
491491
A dictionary-like object that can be used to store attributes for a tag. All
492492
attribute values will be stored as strings.
@@ -521,7 +521,7 @@ def update( # type: ignore[reportIncompatibleMethodOverride] # TODO-future: fix
521521
if kwargs:
522522
args = args + (kwargs,)
523523

524-
attrz: dict[str, str] = {}
524+
attrz: dict[str, str | HTML] = {}
525525
for arg in args:
526526
for k, v in arg.items():
527527
val = self._normalize_attr_value(v)
@@ -544,12 +544,14 @@ def _normalize_attr_name(x: str) -> str:
544544
return x.replace("_", "-")
545545

546546
@staticmethod
547-
def _normalize_attr_value(x: TagAttrValue) -> str | None:
547+
def _normalize_attr_value(x: TagAttrValue) -> str | HTML | None:
548548
if x is None or x is False:
549549
return None
550550
if x is True:
551551
return ""
552-
if isinstance(x, str):
552+
# Return both str and HTML objects as is.
553+
# HTML objects will handle value escaping when added to other values
554+
if isinstance(x, (str, HTML)):
553555
return x
554556
if isinstance(x, (int, float)): # pyright: ignore[reportUnnecessaryIsInstance]
555557
return str(x)
@@ -671,7 +673,7 @@ def __enter__(self) -> None:
671673
sys.displayhook = wrap_displayhook_handler(
672674
# self.append takes a TagChild, but the wrapper expects a function that
673675
# takes a object.
674-
self.append # pyright: ignore[reportArgumentType,reportGeneralTypeIssues]
676+
self.append # pyright: ignore[reportArgumentType]
675677
)
676678

677679
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
@@ -840,7 +842,8 @@ def get_html_string(self, indent: int = 0, eol: str = "\n") -> str:
840842

841843
# Write attributes
842844
for key, val in self.attrs.items():
843-
val = html_escape(val, attr=True)
845+
if not isinstance(val, HTML):
846+
val = html_escape(val, attr=True)
844847
html_ += f' {key}="{val}"'
845848

846849
# Dependencies are ignored in the HTML output
@@ -1383,7 +1386,6 @@ def __radd__(self, other: object) -> HTML:
13831386
# Case: `str + HTML()`
13841387
return HTML(html_escape(str(other)) + self.as_string())
13851388

1386-
13871389
def __repr__(self) -> str:
13881390
return self.as_string()
13891391

@@ -1923,7 +1925,7 @@ def _tag_show(
19231925
import IPython # pyright: ignore[reportUnknownVariableType]
19241926

19251927
ipy = ( # pyright: ignore[reportUnknownVariableType]
1926-
IPython.get_ipython() # pyright: ignore[reportUnknownMemberType, reportPrivateImportUsage]
1928+
IPython.get_ipython() # pyright: ignore[reportUnknownMemberType, reportPrivateImportUsage, reportAttributeAccessIssue]
19271929
)
19281930
renderer = "ipython" if ipy else "browser"
19291931
except ImportError:

tests/test_tags.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,25 @@ def test_basic_tag_api():
108108
x4.add_style("color: blue;", prepend=True)
109109
assert x4.attrs["style"] == "color: blue; color: red; color: green;"
110110

111+
x5 = div()
112+
x5.add_style("color: &purple;")
113+
assert isinstance(x5.attrs["style"], str)
114+
assert x5.attrs["style"] == "color: &purple;"
115+
x5.add_style(HTML("color: &green;"))
116+
assert isinstance(x5.attrs["style"], HTML)
117+
assert x5.attrs["style"] == HTML("color: &purple; color: &green;")
118+
119+
x6 = div()
120+
x6.add_style("color: &red;")
121+
assert isinstance(x6.attrs["style"], str)
122+
assert x6.attrs["style"] == "color: &red;"
123+
x6.add_style(HTML("color: &green;"), prepend=True)
124+
assert isinstance(x6.attrs["style"], HTML)
125+
assert x6.attrs["style"] == HTML("color: &green; color: &red;")
126+
assert isinstance(x6.attrs["style"], HTML)
127+
x6.add_style(HTML("color: &blue;"))
128+
assert x6.attrs["style"] == HTML("color: &green; color: &red; color: &blue;")
129+
111130

112131
def test_tag_list_dict():
113132
# Dictionaries allowed at top level
@@ -542,14 +561,8 @@ def test_tag_escaping():
542561
# Attributes are HTML escaped
543562
expect_html(div("text", class_="<a&b>"), '<div class="&lt;a&amp;b&gt;">text</div>')
544563
expect_html(div("text", class_="'ab'"), '<div class="&apos;ab&apos;">text</div>')
545-
# Sending in HTML causes an error
546-
with pytest.raises(TypeError):
547-
div(
548-
"text",
549-
class_=HTML(
550-
"<a&b>"
551-
), # pyright: ignore[reportArgumentType,reportGeneralTypeIssues]
552-
)
564+
# Attributes support `HTML()` values
565+
expect_html(div("text", class_=HTML("<a&b>")), '<div class="<a&b>">text</div>')
553566

554567
# script and style tags are not escaped
555568
assert str(tags.script("a && b > 3")) == "<script>a && b > 3</script>"
@@ -654,7 +667,7 @@ def _walk_mutate(x: TagNode, fn: Callable[[TagNode], TagNode]) -> TagNode:
654667
x.children[i] = _walk_mutate(child, fn)
655668
elif isinstance(x, list):
656669
for i, child in enumerate(x):
657-
x[i] = _walk_mutate(child, fn)
670+
x[i] = _walk_mutate(child, fn) # pyright: ignore[reportArgumentType]
658671
return x
659672

660673

0 commit comments

Comments
 (0)