Skip to content

Commit 4b5158e

Browse files
authored
feat: Add augmented assignment feature (#208)
1 parent e384edc commit 4b5158e

File tree

6 files changed

+104
-3
lines changed

6 files changed

+104
-3
lines changed

docs/contributing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ In order to be able to contribute, it is important that you understand the
44
project layout. This project uses the _src layout_, which means that the package
55
code is located at `./src/astx`.
66

7-
For my information, check the official documentation:
7+
For more information, check the official documentation:
88
https://packaging.python.org/en/latest/discussions/src-layout-vs-flat-layout/
99

1010
In addition, you should know that to build our package we use

src/astx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
)
116116
from astx.operators import (
117117
AssignmentExpr,
118+
AugAssign,
118119
VariableAssignment,
119120
WalrusOp,
120121
)
@@ -204,6 +205,7 @@ def get_version() -> str:
204205
"AssignmentExpr",
205206
"AsyncForRangeLoopExpr",
206207
"AsyncForRangeLoopStmt",
208+
"AugAssign",
207209
"AwaitExpr",
208210
"BinaryOp",
209211
"Block",

src/astx/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class ASTKind(Enum):
103103
BinaryOpKind = -301
104104
WalrusOpKind = -302
105105
AssignmentExprKind = -303
106+
AugmentedAssignKind = -304
106107

107108
# functions
108109
PrototypeKind = -400

src/astx/operators.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from __future__ import annotations
44

5-
from typing import Iterable, Optional, cast
5+
from typing import Iterable, Literal, Optional, cast
66

77
from public import public
8+
from typing_extensions import TypeAlias
89

910
from astx.base import (
1011
NO_SOURCE_LOCATION,
@@ -13,6 +14,7 @@
1314
DataType,
1415
DictDataTypesStruct,
1516
Expr,
17+
Identifier,
1618
ReprStruct,
1719
SourceLocation,
1820
StatementType,
@@ -128,3 +130,55 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
128130
key = str(self)
129131
value = self.value.get_struct(simplified)
130132
return self._prepare_struct(key, value, simplified)
133+
134+
135+
OpCodeAugAssign: TypeAlias = Literal[
136+
"+=",
137+
"-=",
138+
"*=",
139+
"/=",
140+
"//=",
141+
"%=",
142+
"**=",
143+
"&=",
144+
"|=",
145+
"^=",
146+
"<<=",
147+
">>=",
148+
]
149+
150+
151+
@public
152+
@typechecked
153+
class AugAssign(DataType):
154+
"""AST class for augmented assignment."""
155+
156+
target: Identifier
157+
op_code: OpCodeAugAssign
158+
value: DataType
159+
160+
def __init__(
161+
self,
162+
target: Identifier,
163+
op_code: OpCodeAugAssign,
164+
value: DataType,
165+
loc: SourceLocation = NO_SOURCE_LOCATION,
166+
) -> None:
167+
super().__init__(loc=loc)
168+
self.target = target
169+
self.op_code = op_code
170+
self.value = value
171+
self.kind = ASTKind.AugmentedAssignKind
172+
173+
def __str__(self) -> str:
174+
"""Return a string that represents the augmented assignment object."""
175+
return f"AugAssign[{self.op_code}]"
176+
177+
def get_struct(self, simplified: bool = False) -> ReprStruct:
178+
"""Return the AST structure of the object."""
179+
key = str(self)
180+
value: ReprStruct = {
181+
"target": self.target.get_struct(simplified),
182+
"value": self.value.get_struct(simplified),
183+
}
184+
return self._prepare_struct(key, value, simplified)

src/astx/tools/transpilers/python.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from plum import dispatch
66

77
import astx
8+
import astx.operators
89

910
from astx.tools.typing import typechecked
1011

@@ -520,6 +521,13 @@ def visit(self, node: astx.WalrusOp) -> str:
520521
"""Handle Walrus operator."""
521522
return f"({self.visit(node.lhs)} := {self.visit(node.rhs)})"
522523

524+
@dispatch # type: ignore[no-redef]
525+
def visit(self, node: astx.AugAssign) -> str:
526+
"""Handle Augmented assign operator."""
527+
target = self.visit(node.target)
528+
value = self.visit(node.value)
529+
return f"{target} {node.op_code} {value}"
530+
523531
@dispatch # type: ignore[no-redef]
524532
def visit(self, node: astx.WhileExpr) -> str:
525533
"""Handle WhileExpr nodes."""

tests/test_operators.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
"""Tests for operators."""
22

3+
from typing import cast
4+
35
import astx
6+
import pytest
47

58
from astx.literals.numeric import LiteralInt32
6-
from astx.operators import AssignmentExpr, VariableAssignment
9+
from astx.operators import (
10+
AssignmentExpr,
11+
AugAssign,
12+
OpCodeAugAssign,
13+
VariableAssignment,
14+
)
715
from astx.variables import Variable
816
from astx.viz import visualize
917

@@ -116,3 +124,31 @@ def test_not_op() -> None:
116124
assert op.get_struct()
117125
assert op.get_struct(simplified=True)
118126
visualize(op.get_struct())
127+
128+
129+
@pytest.mark.parametrize(
130+
"operator, value",
131+
[
132+
(cast(OpCodeAugAssign, "+="), 10),
133+
(cast(OpCodeAugAssign, "-="), 5),
134+
(cast(OpCodeAugAssign, "*="), 3),
135+
(cast(OpCodeAugAssign, "/="), 2),
136+
(cast(OpCodeAugAssign, "//="), 2),
137+
(cast(OpCodeAugAssign, "%="), 4),
138+
(cast(OpCodeAugAssign, "**="), 2),
139+
(cast(OpCodeAugAssign, "&="), 6),
140+
(cast(OpCodeAugAssign, "|="), 3),
141+
(cast(OpCodeAugAssign, "^="), 1),
142+
(cast(OpCodeAugAssign, "<<="), 1),
143+
(cast(OpCodeAugAssign, ">>="), 2),
144+
],
145+
)
146+
def test_aug_assign_operations(operator: OpCodeAugAssign, value: int) -> None:
147+
"""Test all augmented assignment operators using parametrize."""
148+
var_x = astx.Identifier(value="x")
149+
literal_value = LiteralInt32(value)
150+
aug_assign = AugAssign(var_x, operator, literal_value)
151+
152+
assert str(aug_assign) == f"AugAssign[{operator}]"
153+
assert aug_assign.get_struct()
154+
assert aug_assign.get_struct(simplified=True)

0 commit comments

Comments
 (0)