Skip to content

Commit 7037810

Browse files
[Fix] Fix from_structural_tag when the parameter is a StructuralTag. (#423)
This PR fixes the behavior of `from_structural_tag` method when the parameter is a `StructuralTag`. --------- Signed-off-by: Yuchuan <[email protected]>
1 parent d2e27d9 commit 7037810

File tree

2 files changed

+180
-5
lines changed

2 files changed

+180
-5
lines changed

python/xgrammar/grammar.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,43 @@
1010
from .structural_tag import StructuralTag, StructuralTagItem
1111

1212

13+
def _convert_instance_to_str(instance: Union[str, Dict[str, Any], StructuralTag]) -> str:
14+
"""Convert a instance to a string representation. It returns the schema in string format because
15+
it's faster to send to C++.
16+
17+
This function handles different instance input types and converts them to a JSON string:
18+
- StructuralTag.
19+
- String inputs are returned as-is (assumed to be valid JSON)
20+
- Dictionary inputs are converted to JSON strings
21+
22+
Parameters
23+
----------
24+
instance : Union[str, StructuralTag, Dict[str, Any]]
25+
The instance to convert, which can be a StructuralTag,
26+
a JSON schema string, or a dictionary representing a JSON schema.
27+
28+
Returns
29+
-------
30+
str
31+
The JSON schema as a string.
32+
33+
Raises
34+
------
35+
ValueError
36+
When the instance type is not supported.
37+
TypeError
38+
When he dictionary is not serializable.
39+
"""
40+
if isinstance(instance, dict):
41+
return json.dumps(instance)
42+
elif isinstance(instance, str):
43+
return instance
44+
elif isinstance(instance, StructuralTag):
45+
return instance.model_dump_json()
46+
else:
47+
raise ValueError("Invalid instance type")
48+
49+
1350
def _convert_schema_to_str(schema: Union[str, Type[BaseModel], Dict[str, Any]]) -> str:
1451
"""Convert a schema to a string representation. It returns the schema in string format because
1552
it's faster to send to C++.
@@ -32,8 +69,10 @@ def _convert_schema_to_str(schema: Union[str, Type[BaseModel], Dict[str, Any]])
3269
3370
Raises
3471
------
35-
ValueError, TypeError
36-
If the schema type is not supported, or the dictionary is not serializable.
72+
ValueError
73+
When the schema type is not supported.
74+
TypeError
75+
When the dictionary is not serializable.
3776
"""
3877
if isinstance(schema, type) and issubclass(schema, BaseModel):
3978
if hasattr(schema, "model_json_schema"):
@@ -71,14 +110,17 @@ def _get_structural_tag_str_from_args(args: List[Any], kwargs: Dict[str, Any]) -
71110
TypeError
72111
When the arguments are invalid.
73112
"""
74-
if len(args) == 1 and isinstance(args[0], (StructuralTag, str, dict)):
75-
return _convert_schema_to_str(args[0])
113+
if len(args) == 1:
114+
if isinstance(args[0], (str, dict, StructuralTag)):
115+
return _convert_instance_to_str(args[0])
116+
else:
117+
raise TypeError("Invalid argument type for from_structural_tag")
76118
elif len(args) == 2 and isinstance(args[0], list) and isinstance(args[1], list):
77119
return StructuralTag.from_legacy_structural_tag(args[0], args[1]).model_dump_json(
78120
indent=None
79121
)
80122
elif "structural_tag" in kwargs:
81-
return _convert_schema_to_str(kwargs["structural_tag"])
123+
return _convert_instance_to_str(kwargs["structural_tag"])
82124
elif "tags" in kwargs and "triggers" in kwargs:
83125
return StructuralTag.from_legacy_structural_tag(
84126
kwargs["tags"], kwargs["triggers"]

tests/python/test_structural_tag_converter.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,5 +1655,138 @@ def test_basic_structural_tag_utf8(stag_format: Dict[str, Any], instance: str, i
16551655
check_stag_with_instance(stag_format, instance, is_accepted)
16561656

16571657

1658+
basic_structural_tags_instance_is_accepted = [
1659+
# ConstStringFormat
1660+
(xgr.structural_tag.ConstStringFormat(value="hello"), "hello", True),
1661+
(xgr.structural_tag.ConstStringFormat(value="hello"), "hello world", False),
1662+
# JSONSchemaFormat
1663+
(xgr.structural_tag.JSONSchemaFormat(json_schema={"type": "object"}), '{"key": "value"}', True),
1664+
(xgr.structural_tag.JSONSchemaFormat(json_schema={"type": "string"}), '"abc"', True),
1665+
(xgr.structural_tag.JSONSchemaFormat(json_schema={"type": "integer"}), "123", True),
1666+
(xgr.structural_tag.JSONSchemaFormat(json_schema={"type": "integer"}), "abc", False),
1667+
# AnyTextFormat
1668+
(xgr.structural_tag.AnyTextFormat(), "", True),
1669+
(xgr.structural_tag.AnyTextFormat(), "any text here", True),
1670+
# SequenceFormat
1671+
(
1672+
xgr.structural_tag.SequenceFormat(
1673+
elements=[
1674+
xgr.structural_tag.ConstStringFormat(value="A"),
1675+
xgr.structural_tag.ConstStringFormat(value="B"),
1676+
]
1677+
),
1678+
"AB",
1679+
True,
1680+
),
1681+
(
1682+
xgr.structural_tag.SequenceFormat(
1683+
elements=[
1684+
xgr.structural_tag.ConstStringFormat(value="A"),
1685+
xgr.structural_tag.ConstStringFormat(value="B"),
1686+
]
1687+
),
1688+
"A",
1689+
False,
1690+
),
1691+
# OrFormat
1692+
(
1693+
xgr.structural_tag.OrFormat(
1694+
elements=[
1695+
xgr.structural_tag.ConstStringFormat(value="A"),
1696+
xgr.structural_tag.ConstStringFormat(value="B"),
1697+
]
1698+
),
1699+
"A",
1700+
True,
1701+
),
1702+
(
1703+
xgr.structural_tag.OrFormat(
1704+
elements=[
1705+
xgr.structural_tag.ConstStringFormat(value="A"),
1706+
xgr.structural_tag.ConstStringFormat(value="B"),
1707+
]
1708+
),
1709+
"B",
1710+
True,
1711+
),
1712+
(
1713+
xgr.structural_tag.OrFormat(
1714+
elements=[
1715+
xgr.structural_tag.ConstStringFormat(value="A"),
1716+
xgr.structural_tag.ConstStringFormat(value="B"),
1717+
]
1718+
),
1719+
"C",
1720+
False,
1721+
),
1722+
# TagFormat
1723+
(
1724+
xgr.structural_tag.TagFormat(
1725+
begin="<b>", content=xgr.structural_tag.AnyTextFormat(), end="</b>"
1726+
),
1727+
"<b>text</b>",
1728+
True,
1729+
),
1730+
(
1731+
xgr.structural_tag.TagFormat(
1732+
begin="<b>", content=xgr.structural_tag.AnyTextFormat(), end="</b>"
1733+
),
1734+
"<b>text</b",
1735+
False,
1736+
),
1737+
# TagsWithSeparatorFormat
1738+
(
1739+
xgr.structural_tag.TagsWithSeparatorFormat(
1740+
tags=[
1741+
xgr.structural_tag.TagFormat(
1742+
begin="<b>", content=xgr.structural_tag.AnyTextFormat(), end="</b>"
1743+
)
1744+
],
1745+
separator=",",
1746+
),
1747+
'<b>"1"</b>,<b>"2"</b>',
1748+
True,
1749+
),
1750+
(
1751+
xgr.structural_tag.TagsWithSeparatorFormat(
1752+
tags=[
1753+
xgr.structural_tag.TagFormat(
1754+
begin="<b>", content=xgr.structural_tag.AnyTextFormat(), end="</b>"
1755+
)
1756+
],
1757+
separator=",",
1758+
),
1759+
'<b>"1"</b><b>"2"</b>',
1760+
False,
1761+
),
1762+
# QwenXMLParameterFormat
1763+
(
1764+
xgr.structural_tag.QwenXMLParameterFormat(
1765+
json_schema={"type": "object", "properties": {"name": {"type": "string"}}}
1766+
),
1767+
"<parameter=name>value</parameter>",
1768+
True,
1769+
),
1770+
(
1771+
xgr.structural_tag.QwenXMLParameterFormat(
1772+
json_schema={"type": "object", "properties": {"name": {"type": "string"}}}
1773+
),
1774+
"<parameter=name>value</param>",
1775+
False,
1776+
),
1777+
]
1778+
1779+
1780+
@pytest.mark.parametrize(
1781+
"stag_format, instance, is_accepted", basic_structural_tags_instance_is_accepted
1782+
)
1783+
def test_from_structural_tag_with_structural_tag_instance(
1784+
stag_format: xgr.structural_tag.Format, instance: str, is_accepted: bool
1785+
):
1786+
stag = xgr.structural_tag.StructuralTag(format=stag_format)
1787+
grammar = xgr.Grammar.from_structural_tag(stag)
1788+
assert _is_grammar_accept_string(grammar, instance) == is_accepted
1789+
1790+
16581791
if __name__ == "__main__":
16591792
pytest.main(sys.argv)

0 commit comments

Comments
 (0)