Skip to content

Commit e9dece6

Browse files
feat: add logged_assert and logged_err functions
- add functions to stubs (util namespace) - add log_error field to AssertExpression AWST node - add arc65_error function builders - add the proper ebs to the type registry - backend changes (IR) to support the logged error compilation path - ToCodeVisitor changes to show logged errors
1 parent ff99e92 commit e9dece6

File tree

6 files changed

+251
-9
lines changed

6 files changed

+251
-9
lines changed

src/puya/awst/nodes.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ class AssertExpression(Expression):
338338
"""An error message to be associated with the assertion failure"""
339339
wtype: WType = attrs.field(default=wtypes.void_wtype, init=False)
340340
explicit: bool = True
341+
log_error: bool = False
341342

342343
def accept(self, visitor: ExpressionVisitor[T]) -> T:
343344
return visitor.visit_assert_expression(self)
@@ -1211,13 +1212,13 @@ def _validate_wtype(self, _attr: object, wtype: WType) -> None:
12111212
match self.base.wtype, wtype:
12121213
case wtypes.BytesWType(), wtypes.BytesWType(length=1 | None):
12131214
pass
1214-
case wtypes.ARC4Array(
1215-
element_type=array_element_type
1216-
), _ if array_element_type == wtype:
1215+
case wtypes.ARC4Array(element_type=array_element_type), _ if (
1216+
array_element_type == wtype
1217+
):
12171218
pass
1218-
case wtypes.ReferenceArray(
1219-
element_type=array_element_type
1220-
), _ if array_element_type == wtype:
1219+
case wtypes.ReferenceArray(element_type=array_element_type), _ if (
1220+
array_element_type == wtype
1221+
):
12211222
pass
12221223
case _:
12231224
raise InternalError(

src/puya/awst/to_code_visitor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -654,12 +654,13 @@ def visit_return_statement(self, statement: nodes.ReturnStatement) -> list[str]:
654654
def visit_assert_expression(self, statement: nodes.AssertExpression) -> str:
655655
error_message = "" if statement.error_message is None else f'"{statement.error_message}"'
656656
if not statement.condition:
657-
result = "err("
657+
result = "err(" if not statement.log_error else "logged_err("
658658
if error_message:
659659
result += error_message
660660
result += ")"
661661
else:
662-
result = f"assert({statement.condition.accept(self)}"
662+
result = "assert(" if not statement.log_error else "logged_assert("
663+
result += f"{statement.condition.accept(self)}"
663664
if error_message:
664665
result += f", comment={error_message}"
665666
result += ")"

src/puya/ir/builder/main.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
method_signature_to_abi_signature,
4141
)
4242
from puya.ir.builder.encoding_validation import validate_encoding
43-
from puya.ir.context import IRBuildContext
43+
from puya.ir.context import IRBuildContext, IRFunctionBuildContext
4444
from puya.ir.encodings import wtype_to_encoding
4545
from puya.ir.op_utils import OpFactory, assert_value, assign_intrinsic_op, assign_targets, mktemp
4646
from puya.ir.types_ import wtype_to_encoded_ir_type, wtype_to_ir_type
@@ -1120,6 +1120,13 @@ def visit_assert_expression(self, expr: awst_nodes.AssertExpression) -> TStateme
11201120
else: # false constant, treat as fail/err
11211121
condition_value = None
11221122

1123+
if expr.log_error:
1124+
if expr.error_message:
1125+
_build_logged_error(self.context, condition_value, expr.error_message, loc)
1126+
else:
1127+
raise InternalError("a logged error should have some kind of string to log")
1128+
return None
1129+
11231130
if condition_value is None:
11241131
self.context.block_builder.terminate(
11251132
ir.Fail(
@@ -1590,6 +1597,50 @@ def materialise_value_provider(
15901597
return multi_value_to_values(value_or_tuple)
15911598

15921599

1600+
def _build_logged_error(
1601+
context: IRFunctionBuildContext,
1602+
condition: ir.Value | None,
1603+
msg: str,
1604+
loc: SourceLocation,
1605+
):
1606+
# model assert/err behavior as a conditional jump into a pushbytes X; log; err; pattern
1607+
if condition:
1608+
log_and_fail, after_assert = context.block_builder.mkblocks(
1609+
"logged_error_handling", "after_assert", source_location=loc
1610+
)
1611+
context.block_builder.terminate(
1612+
ir.ConditionalBranch(
1613+
condition=condition,
1614+
non_zero=after_assert,
1615+
zero=log_and_fail,
1616+
source_location=loc,
1617+
)
1618+
)
1619+
context.block_builder.activate_block(log_and_fail)
1620+
context.block_builder.add(
1621+
ir.Intrinsic(
1622+
op=AVMOp("log"),
1623+
args=[
1624+
ir.BytesConstant(
1625+
value=msg.encode("utf8"),
1626+
encoding=types.AVMBytesEncoding.utf8,
1627+
source_location=loc,
1628+
)
1629+
],
1630+
source_location=loc,
1631+
)
1632+
)
1633+
context.block_builder.terminate(
1634+
ir.Fail(
1635+
error_message=msg,
1636+
explicit=True, # logged errors are always explicit
1637+
source_location=loc,
1638+
)
1639+
)
1640+
if condition:
1641+
context.block_builder.try_activate_block(after_assert)
1642+
1643+
15931644
def create_uint64_binary_op(
15941645
op: UInt64BinaryOperator,
15951646
left: ir.Value,

src/puyapy/awst_build/eb/_type_registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from puyapy.awst_build import constants, intrinsic_data, pytypes
1010
from puyapy.awst_build.eb import (
1111
arc4,
12+
arc65_error,
1213
biguint,
1314
bool as bool_,
1415
bytes as bytes_,
@@ -58,6 +59,8 @@
5859
template_variables.GenericTemplateVariableExpressionBuilder
5960
),
6061
"algopy.op.err": intrinsics.ErrFunctionBuilder,
62+
"algopy._util.logged_assert": arc65_error.LoggedAssertFunctionBuilder,
63+
"algopy._util.logged_err": arc65_error.LoggedErrFunctionBuilder,
6164
**{
6265
(fullname := "".join((constants.ALGOPY_OP_PREFIX, name))): functools.partial(
6366
intrinsics.IntrinsicFunctionExpressionBuilder, fullname, mappings
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import typing
2+
from collections.abc import Sequence
3+
4+
from puya import log
5+
from puya.awst.nodes import AssertExpression
6+
from puya.parse import SourceLocation
7+
from puyapy import models
8+
from puyapy.awst_build import pytypes
9+
from puyapy.awst_build.eb import _expect as expect
10+
from puyapy.awst_build.eb._base import FunctionBuilder
11+
from puyapy.awst_build.eb.factories import builder_for_instance
12+
from puyapy.awst_build.eb.interface import InstanceBuilder, NodeBuilder
13+
from puyapy.awst_build.utils import get_arg_mapping
14+
15+
logger = log.get_logger(__name__)
16+
17+
_VALID_PREFIXES = frozenset(("AER", "ERR"))
18+
_LONG_ERROR_MESSAGE = 64 # long error message (in bytes)
19+
20+
21+
class LoggedAssertFunctionBuilder(FunctionBuilder):
22+
@typing.override
23+
def call(
24+
self,
25+
args: Sequence[NodeBuilder],
26+
arg_kinds: list[models.ArgKind],
27+
arg_names: list[str | None],
28+
location: SourceLocation,
29+
) -> InstanceBuilder:
30+
arg_mapping, _ = get_arg_mapping(
31+
required_positional_names=["condition", "error_code"],
32+
optional_positional_names=["error_message", "prefix"],
33+
args=args,
34+
arg_names=arg_names,
35+
call_location=location,
36+
raise_on_missing=False,
37+
)
38+
39+
condition_arg = arg_mapping.get("condition")
40+
if condition_arg is not None:
41+
condition_eb = expect.instance_builder(condition_arg, default=expect.default_none)
42+
condition_expr = condition_eb.bool_eval(location).resolve() if condition_eb else None
43+
else:
44+
condition_expr = None
45+
46+
error_message = _resolve_error_message(arg_mapping, location)
47+
return builder_for_instance(
48+
pytypes.NoneType,
49+
AssertExpression(
50+
condition=condition_expr,
51+
error_message=error_message,
52+
source_location=location,
53+
log_error=True,
54+
),
55+
)
56+
57+
58+
class LoggedErrFunctionBuilder(FunctionBuilder):
59+
@typing.override
60+
def call(
61+
self,
62+
args: Sequence[NodeBuilder],
63+
arg_kinds: list[models.ArgKind],
64+
arg_names: list[str | None],
65+
location: SourceLocation,
66+
) -> InstanceBuilder:
67+
arg_mapping, _ = get_arg_mapping(
68+
required_positional_names=["error_code"],
69+
optional_positional_names=["error_message", "prefix"],
70+
args=args,
71+
arg_names=arg_names,
72+
call_location=location,
73+
raise_on_missing=False,
74+
)
75+
76+
error_message = _resolve_error_message(arg_mapping, location)
77+
return builder_for_instance(
78+
pytypes.NoneType,
79+
AssertExpression(
80+
condition=None,
81+
error_message=error_message,
82+
source_location=location,
83+
log_error=True,
84+
),
85+
)
86+
87+
88+
def _resolve_error_message(
89+
arg_mapping: dict[str, NodeBuilder], location: SourceLocation
90+
) -> str | None:
91+
code = (
92+
expect.simple_string_literal(arg_mapping["error_code"], default=expect.default_none)
93+
if "error_code" in arg_mapping
94+
else None
95+
)
96+
message = (
97+
expect.simple_string_literal(arg_mapping["error_message"], default=expect.default_none)
98+
if "error_message" in arg_mapping
99+
else None
100+
)
101+
prefix = (
102+
expect.simple_string_literal(arg_mapping["prefix"], default=expect.default_none)
103+
if "prefix" in arg_mapping
104+
else "ERR"
105+
)
106+
107+
# code validation
108+
if code is None:
109+
logger.error("error code is mandatory in logged errors", location=location)
110+
return None
111+
elif ":" in code:
112+
logger.error("error code must not contain domain separator ':'", location=location)
113+
elif not code.isalnum():
114+
logger.warning("error code should be alphanumeric", location=location)
115+
116+
# message validation
117+
if message is not None and ":" in message:
118+
logger.error("error message must not contain domain separator ':'", location=location)
119+
120+
# prefix validation (note: prefix should already be validated by mypy typing check)
121+
if prefix not in _VALID_PREFIXES:
122+
logger.error(
123+
"error prefix must be one of AER, ERR",
124+
location=location,
125+
)
126+
127+
arc65_msg = f"{prefix}:{code}:{message}" if message else f"{prefix}:{code}"
128+
129+
# arc65 recommendations
130+
msglen = len(arc65_msg)
131+
if msglen >= _LONG_ERROR_MESSAGE:
132+
logger.warning(
133+
f"error message is {msglen} bytes long, consider making it shorter",
134+
location=location,
135+
)
136+
elif msglen in (8, 32):
137+
logger.warning(
138+
f"your final error message is {msglen} bytes long. "
139+
"Error messages exactly 8 or 32 bytes long are discouraged",
140+
location=location,
141+
)
142+
143+
return arc65_msg

stubs/algopy-stubs/_util.pyi

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,46 @@ def size_of(type_or_expression: type | object, /) -> UInt64:
3030
Returns the number of bytes required to store a statically sized type,
3131
given as a type object or an expression of that type.
3232
"""
33+
34+
def logged_assert(
35+
condition: bool,
36+
/,
37+
error_code: typing.LiteralString,
38+
error_message: typing.LiteralString | None = None,
39+
prefix: typing.Literal["AER", "ERR"] = "ERR",
40+
) -> None:
41+
"""
42+
Asserts that a condition is true, logging a formatted error message before failing
43+
if the condition is false.
44+
45+
The logged output follows the format ``{prefix}:{error_code}`` or
46+
``{prefix}:{error_code}:{error_message}`` and is compatible with ARC-56 and ARC-32 clients.
47+
48+
Note this increases the generated bytecode, so keeping ``error_code`` and ``error_message``
49+
short is recommended.
50+
51+
:arg condition: The condition to assert; if false, logs an error and fails.
52+
:arg error_code: An alphanumeric error code. Must not contain ``:``.
53+
:arg error_message: Optional message appended after the code. Must not contain ``:``.
54+
:arg prefix: Error prefix, either ``"AER"`` or ``"ERR"``.
55+
"""
56+
57+
def logged_err(
58+
error_code: typing.LiteralString,
59+
error_message: typing.LiteralString | None = None,
60+
prefix: typing.Literal["AER", "ERR"] = "ERR",
61+
) -> None:
62+
"""
63+
Logs a formatted error message and immediately fails the transaction.
64+
65+
Note this is equivalent to ``logged_assert(False, error_code, error_message, prefix)``.
66+
This function increases the generated bytecode, so keeping ``error_code`` and ``error_message``
67+
short is recommended.
68+
69+
The logged output follows the ARC-65 format ``{prefix}:{error_code}`` or
70+
``{prefix}:{error_code}:{error_message}`` and is compatible with ARC-56 and ARC-32 clients.
71+
72+
:arg error_code: An alphanumeric error code. Must not contain ``:``.
73+
:arg error_message: Optional message appended after the code. Must not contain ``:``.
74+
:arg prefix: Error prefix, either ``"AER"`` or ``"ERR"``.
75+
"""

0 commit comments

Comments
 (0)