Skip to content

Commit b2a3520

Browse files
authored
Merge pull request #2061 from Shaikh-Ubaid/fix_not
Fix unsigned integer bitnot
2 parents e642072 + 116f344 commit b2a3520

File tree

9 files changed

+238
-13
lines changed

9 files changed

+238
-13
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ RUN(NAME test_max_min LABELS cpython llvm c)
551551
RUN(NAME test_global LABELS cpython llvm c)
552552
RUN(NAME test_global_decl LABELS cpython llvm c)
553553
RUN(NAME test_integer_bitnot LABELS cpython llvm c wasm)
554+
RUN(NAME test_unsign_int_bitnot LABELS cpython llvm c)
554555
RUN(NAME test_ifexp LABELS cpython llvm c)
555556
RUN(NAME test_unary_minus LABELS cpython llvm c)
556557
RUN(NAME test_unary_plus LABELS cpython llvm c)

integration_tests/bindpy_01.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from lpython import i32, i64, u32, u64, f32, f64, pythoncall
22

33
@pythoncall(module = "bindpy_01_module")
4-
def add_ints(a: i32, b: i64, c: u32, d: u64) -> i64:
4+
def add_ints(a: i32, b: i32, c: i32, d: i32) -> i64:
55
pass
66

77
@pythoncall(module = "bindpy_01_module")
8-
def multiply_ints(a: i32, b: i64, c: u32, d: u64) -> i64:
8+
def multiply_ints(a: i32, b: i32, c: i32, d: i32) -> i64:
9+
pass
10+
11+
@pythoncall(module = "bindpy_01_module")
12+
def add_unsigned_ints(a: u32, b: u32, c: u32, d: u32) -> u64:
13+
pass
14+
15+
@pythoncall(module = "bindpy_01_module")
16+
def multiply_unsigned_ints(a: u32, b: u32, c: u32, d: u32) -> u64:
917
pass
1018

1119
@pythoncall(module = "bindpy_01_module")
@@ -31,17 +39,31 @@ def get_cpython_version() -> str:
3139
# Integers:
3240
def test_ints():
3341
i: i32
34-
j: i64
35-
k: u32
36-
l: u64
42+
j: i32
43+
k: i32
44+
l: i32
3745
i = -5
38-
j = i64(24)
39-
k = u32(20)
40-
l = u64(92)
46+
j = 24
47+
k = 20
48+
l = 92
4149

4250
assert add_ints(i, j, k, l) == i64(131)
4351
assert multiply_ints(i, j, k, l) == i64(-220800)
4452

53+
# Unsigned Integers:
54+
def test_unsigned_ints():
55+
i: u32
56+
j: u32
57+
k: u32
58+
l: u32
59+
i = u32(5)
60+
j = u32(24)
61+
k = u32(20)
62+
l = u32(92)
63+
64+
assert add_unsigned_ints(i, j, k, l) == u64(141)
65+
assert multiply_unsigned_ints(i, j, k, l) == u64(220800)
66+
4567
# Floats
4668
def test_floats():
4769
a: f32

integration_tests/bindpy_01_module.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ def multiply_ints(a, b, c, d):
1010
e = a * b * c * d
1111
return e
1212

13+
def add_unsigned_ints(a, b, c, d):
14+
e = a + b + c + d
15+
return e
16+
17+
def multiply_unsigned_ints(a, b, c, d):
18+
e = a * b * c * d
19+
return e
20+
1321
def add_floats(a, b):
1422
return a + b
1523

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from lpython import u16
2+
3+
def foo(grp: u16) -> u16:
4+
i: u16 = ~(u16(grp))
5+
6+
return i
7+
8+
9+
def foo2() -> u16:
10+
i: u16 = ~(u16(0xffff))
11+
12+
return i
13+
14+
def foo3() -> u16:
15+
i: u16 = ~(u16(0xffff))
16+
17+
return ~i
18+
19+
assert foo(u16(0)) == u16(0xffff)
20+
assert foo(u16(0xffff)) == u16(0)
21+
assert foo2() == u16(0)
22+
assert foo3() == u16(0xffff)

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ expr
250250
| IntegerBinOp(expr left, binop op, expr right, ttype type, expr? value)
251251
| UnsignedIntegerConstant(int n, ttype type)
252252
| UnsignedIntegerUnaryMinus(expr arg, ttype type, expr? value)
253+
| UnsignedIntegerBitNot(expr arg, ttype type, expr? value)
253254
| UnsignedIntegerCompare(expr left, cmpop op, expr right, ttype type, expr? value)
254255
| UnsignedIntegerBinOp(expr left, binop op, expr right, ttype type, expr? value)
255256
| RealConstant(float r, ttype type)

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2036,7 +2036,8 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
20362036
}
20372037
}
20382038

2039-
void visit_IntegerBitNot(const ASR::IntegerBitNot_t& x) {
2039+
template<typename T>
2040+
void handle_SU_IntegerBitNot(const T& x) {
20402041
CHECK_FAST_C_CPP(compiler_options, x)
20412042
self().visit_expr(*x.m_arg);
20422043
int expr_precedence = last_expr_precedence;
@@ -2048,6 +2049,14 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
20482049
}
20492050
}
20502051

2052+
void visit_IntegerBitNot(const ASR::IntegerBitNot_t& x) {
2053+
handle_SU_IntegerBitNot(x);
2054+
}
2055+
2056+
void visit_UnsignedIntegerBitNot(const ASR::UnsignedIntegerBitNot_t& x) {
2057+
handle_SU_IntegerBitNot(x);
2058+
}
2059+
20512060
void visit_IntegerUnaryMinus(const ASR::IntegerUnaryMinus_t &x) {
20522061
handle_UnaryMinus(x);
20532062
}

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6193,6 +6193,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
61936193
tmp = builder->CreateNot(tmp);
61946194
}
61956195

6196+
void visit_UnsignedIntegerBitNot(const ASR::UnsignedIntegerBitNot_t &x) {
6197+
if (x.m_value) {
6198+
this->visit_expr_wrapper(x.m_value, true);
6199+
return;
6200+
}
6201+
this->visit_expr_wrapper(x.m_arg, true);
6202+
tmp = builder->CreateNot(tmp);
6203+
}
6204+
61966205
void visit_IntegerUnaryMinus(const ASR::IntegerUnaryMinus_t &x) {
61976206
if (x.m_value) {
61986207
this->visit_expr_wrapper(x.m_value, true);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3410,6 +3410,17 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
34103410
tmp = ASR::make_IntegerBitNot_t(al, x.base.base.loc, operand, dest_type, value);
34113411
return;
34123412
}
3413+
else if (ASRUtils::is_unsigned_integer(*operand_type)) {
3414+
if (ASRUtils::expr_value(operand) != nullptr) {
3415+
int64_t op_value = ASR::down_cast<ASR::UnsignedIntegerConstant_t>(
3416+
ASRUtils::expr_value(operand))->m_n;
3417+
uint64_t val = ~uint64_t(op_value);
3418+
value = ASR::down_cast<ASR::expr_t>(ASR::make_UnsignedIntegerConstant_t(
3419+
al, x.base.base.loc, val, operand_type));
3420+
}
3421+
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, dest_type, value);
3422+
return;
3423+
}
34133424
else if (ASRUtils::is_real(*operand_type)) {
34143425
throw SemanticError("Unary operator '~' not supported for floats",
34153426
x.base.base.loc);

src/runtime/lpython/lpython.py

Lines changed: 146 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,158 @@
1414

1515
# data-types
1616

17+
class UnsignedInteger:
18+
def __init__(self, bit_width, value):
19+
if isinstance(value, UnsignedInteger):
20+
if bit_width != value.bit_width:
21+
raise ValueError(f"Bit width mismatch: {bit_width} vs {value.bit_width}")
22+
value = value.value
23+
24+
if not (0 <= value < 2**bit_width):
25+
raise ValueError(f"Value should be in range 0 to {2**bit_width-1} for a {bit_width}-bit unsigned integer.")
26+
self.bit_width = bit_width
27+
self.value = value
28+
29+
def __add__(self, other):
30+
if isinstance(other, self.__class__):
31+
return UnsignedInteger(self.bit_width, (self.value + other.value) % (2**self.bit_width))
32+
else:
33+
raise TypeError("Unsupported operand type")
34+
35+
def __sub__(self, other):
36+
if isinstance(other, self.__class__):
37+
# if self.value < other.value:
38+
# raise ValueError("Result of subtraction cannot be negative")
39+
return UnsignedInteger(self.bit_width, (self.value - other.value) % (2**self.bit_width))
40+
else:
41+
raise TypeError("Unsupported operand type")
42+
43+
def __mul__(self, other):
44+
if isinstance(other, self.__class__):
45+
return UnsignedInteger(self.bit_width, (self.value * other.value) % (2**self.bit_width))
46+
else:
47+
raise TypeError("Unsupported operand type")
48+
49+
def __div__(self, other):
50+
if isinstance(other, self.__class__):
51+
if other.value == 0:
52+
raise ValueError("Division by zero")
53+
return UnsignedInteger(self.bit_width, self.value / other.value)
54+
else:
55+
raise TypeError("Unsupported operand type")
56+
57+
def __floordiv__(self, other):
58+
if isinstance(other, self.__class__):
59+
if other.value == 0:
60+
raise ValueError("Division by zero")
61+
return UnsignedInteger(self.bit_width, self.value // other.value)
62+
else:
63+
raise TypeError("Unsupported operand type")
64+
65+
def __mod__(self, other):
66+
if isinstance(other, self.__class__):
67+
if other.value == 0:
68+
raise ValueError("Modulo by zero")
69+
return UnsignedInteger(self.bit_width, self.value % other.value)
70+
else:
71+
raise TypeError("Unsupported operand type")
72+
73+
def __pow__(self, other):
74+
if isinstance(other, self.__class__):
75+
return UnsignedInteger(self.bit_width, (self.value ** other.value) % (2**self.bit_width))
76+
else:
77+
raise TypeError("Unsupported operand type")
78+
79+
def __and__(self, other):
80+
if isinstance(other, self.__class__):
81+
return UnsignedInteger(self.bit_width, self.value & other.value)
82+
else:
83+
raise TypeError("Unsupported operand type")
84+
85+
def __or__(self, other):
86+
if isinstance(other, self.__class__):
87+
return UnsignedInteger(self.bit_width, self.value | other.value)
88+
else:
89+
raise TypeError("Unsupported operand type")
90+
91+
# unary operators
92+
def __neg__(self):
93+
return UnsignedInteger(self.bit_width, -self.value % (2**self.bit_width))
94+
95+
def __pos__(self):
96+
return UnsignedInteger(self.bit_width, self.value)
97+
98+
def __abs__(self):
99+
return UnsignedInteger(self.bit_width, abs(self.value))
100+
101+
def __invert__(self):
102+
return UnsignedInteger(self.bit_width, ~self.value % (2**self.bit_width))
103+
104+
# comparator operators
105+
def __eq__(self, other):
106+
if isinstance(other, self.__class__):
107+
return self.value == other.value
108+
else:
109+
try:
110+
return self.value == other
111+
except:
112+
raise TypeError("Unsupported operand type")
113+
114+
def __ne__(self, other):
115+
if isinstance(other, self.__class__):
116+
return self.value != other.value
117+
else:
118+
raise TypeError("Unsupported operand type")
119+
120+
def __lt__(self, other):
121+
if isinstance(other, self.__class__):
122+
return self.value < other.value
123+
else:
124+
raise TypeError("Unsupported operand type")
125+
126+
def __le__(self, other):
127+
if isinstance(other, self.__class__):
128+
return self.value <= other.value
129+
else:
130+
raise TypeError("Unsupported operand type")
131+
132+
def __gt__(self, other):
133+
if isinstance(other, self.__class__):
134+
return self.value > other.value
135+
else:
136+
raise TypeError("Unsupported operand type")
137+
138+
def __ge__(self, other):
139+
if isinstance(other, self.__class__):
140+
return self.value >= other.value
141+
else:
142+
raise TypeError("Unsupported operand type")
143+
144+
# conversion to integer
145+
def __int__(self):
146+
return self.value
147+
148+
def __str__(self):
149+
return str(self.value)
150+
151+
def __repr__(self):
152+
return f'UnsignedInteger({self.bit_width}, {str(self)})'
153+
154+
def __index__(self):
155+
return self.value
156+
157+
158+
17159
type_to_convert_func = {
18160
"i1": bool,
19161
"i8": int,
20162
"i16": int,
21163
"i32": int,
22164
"i64": int,
23-
"u8": lambda x: x,
24-
"u16": lambda x: x,
25-
"u32": lambda x: x,
26-
"u64": lambda x: x,
165+
"u8": lambda x: UnsignedInteger(8, x),
166+
"u16": lambda x: UnsignedInteger(16, x),
167+
"u32": lambda x: UnsignedInteger(32, x),
168+
"u64": lambda x: UnsignedInteger(64, x),
27169
"f32": float,
28170
"f64": float,
29171
"c32": complex,

0 commit comments

Comments
 (0)