Skip to content

Commit a814f4c

Browse files
authored
Fix logical comparison (#2598)
* Fix logical comparison * Handle logical comparisons on non-logical expressions * Handle logical comparison on strings * Fix type comparison * Fix nullptr error * Fix string comparison * Evaluate logical operations at compile time * Tests: Add error tests * Tests: Add tests * Tests: Update test references * Move `tmp` * Tests: Update test references * Add check for empty function call * Tests: Update tests and test references * Tests: Fix errors * Delete tests/reference/asr-test_logical_assignment-6b36ea2.json * Delete tests/reference/asr-test_logical_assignment-6b36ea2.stderr * Delete tests/reference/asr-test_logical_compare-d8a2a03.json * Delete tests/reference/asr-test_logical_compare_02-878af11.stderr * Delete tests/reference/asr-test_logical_compare_02-878af11.json * Delete tests/reference/asr-test_logical_compare-d8a2a03.stderr * Tests: Add error test to tests.toml
1 parent 031ced0 commit a814f4c

File tree

8 files changed

+286
-19
lines changed

8 files changed

+286
-19
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,8 @@ RUN(NAME test_platform LABELS cpython llvm c)
755755
RUN(NAME test_vars_01 LABELS cpython llvm)
756756
RUN(NAME test_version LABELS cpython llvm)
757757
RUN(NAME logical_binop1 LABELS cpython llvm)
758+
RUN(NAME test_logical_compare LABELS cpython llvm)
759+
RUN(NAME test_logical_assignment LABELS cpython llvm)
758760
RUN(NAME vec_01 LABELS cpython llvm c NOFAST)
759761
RUN(NAME test_str_comparison LABELS cpython llvm c wasm)
760762
RUN(NAME test_bit_length LABELS cpython llvm c)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from lpython import i32, f64
2+
3+
4+
def test_logical_assignment():
5+
# Can be uncommented after fixing the segfault
6+
# _LPYTHON: str = "LPython"
7+
# s_var: str = "" or _LPYTHON
8+
# assert s_var == "LPython"
9+
# print(s_var)
10+
11+
_MAX_VAL: i32 = 100
12+
i_var: i32 = 0 and 100
13+
assert i_var == 0
14+
print(i_var)
15+
16+
_PI: f64 = 3.14
17+
f_var: f64 = 2.0 * _PI or _PI**2.0
18+
assert f_var == 6.28
19+
print(f_var)
20+
21+
22+
test_logical_assignment()
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from lpython import i32, f64
2+
3+
4+
def test_logical_compare_literal():
5+
# Integers
6+
print(1 or 3)
7+
assert (1 or 3) == 1
8+
9+
print(1 and 3)
10+
assert (1 and 3) == 3
11+
12+
print(2 or 3 or 5 or 6)
13+
assert (2 or 3 or 5 or 6) == 2
14+
15+
print(1 and 3 or 2 and 4)
16+
assert (1 and 3 or 2 and 4) == 3
17+
18+
print(1 or 3 and 0 or 4)
19+
assert (1 or 3 and 0 or 4) == 1
20+
21+
print(1 and 3 or 2 and 0)
22+
assert (1 and 3 or 2 and 0) == 3
23+
24+
print(1 and 0 or 3 and 4)
25+
assert (1 and 0 or 3 and 4) == 4
26+
27+
# Floating-point numbers
28+
print(1.33 or 6.67)
29+
assert (1.33 or 6.67) == 1.33
30+
31+
print(1.33 and 6.67)
32+
assert (1.33 and 6.67) == 6.67
33+
34+
print(1.33 or 6.67 and 3.33 or 0.0)
35+
assert (1.33 or 6.67 and 3.33 or 0.0) == 1.33
36+
37+
print(1.33 and 6.67 or 3.33 and 0.0)
38+
assert (1.33 and 6.67 or 3.33 and 0.0) == 6.67
39+
40+
print(1.33 and 0.0 and 3.33 and 6.67)
41+
assert (1.33 and 0.0 and 3.33 and 6.67) == 0.0
42+
43+
# Strings
44+
print("a" or "b")
45+
assert ("a" or "b") == "a"
46+
47+
print("abc" or "b")
48+
assert ("abc" or "b") == "abc"
49+
50+
print("a" and "b")
51+
assert ("a" and "b") == "b"
52+
53+
print("a" or "b" and "c" or "d")
54+
assert ("a" or "b" and "c" or "d") == "a"
55+
56+
print("" or " ")
57+
assert ("" or " ") == " "
58+
59+
print("" and " " or "a" and "b" and "c")
60+
assert ("" and " " or "a" and "b" and "c") == "c"
61+
62+
print("" and " " and "a" and "b" and "c")
63+
assert ("" and " " and "a" and "b" and "c") == ""
64+
65+
66+
def test_logical_compare_variable():
67+
# Integers
68+
i_a: i32 = 1
69+
i_b: i32 = 3
70+
71+
print(i_a and i_b)
72+
assert (i_a and i_b) == 3
73+
74+
print(i_a or i_b or 2 or 4)
75+
assert (i_a or i_b or 2 or 4) == 1
76+
77+
print(i_a and i_b or 2 and 4)
78+
assert (i_a and i_b or 2 and 4) == 3
79+
80+
print(i_a or i_b and 0 or 4)
81+
assert (i_a or i_b and 0 or 4) == i_a
82+
83+
print(i_a and i_b or 2 and 0)
84+
assert (i_a and i_b or 2 and 0) == i_b
85+
86+
print(i_a and 0 or i_b and 4)
87+
assert (i_a and 0 or i_b and 4) == 4
88+
89+
print(i_a + i_b or 0 - 4)
90+
assert (i_a + i_b or 0 - 4) == 4
91+
92+
# Floating-point numbers
93+
f_a: f64 = 1.67
94+
f_b: f64 = 3.33
95+
96+
print(f_a // f_b and f_a - f_b)
97+
assert (f_a // f_b and f_a - f_b) == 0.0
98+
99+
print(f_a**3.0 or 3.0**f_a)
100+
assert (f_a**3.0 or 3.0**f_a) == 4.657462999999999
101+
102+
print(f_a - 3.0 and f_a + 3.0 or f_b - 3.0 and f_b + 3.0)
103+
assert (f_a - 3.0 and f_a + 3.0 or f_b - 3.0 and f_b + 3.0) == 4.67
104+
105+
# Can be uncommented after fixing the segfault
106+
# Strings
107+
# s_a: str = "a"
108+
# s_b: str = "b"
109+
110+
# print(s_a or s_b)
111+
# assert (s_a or s_b) == s_a
112+
113+
# print(s_a and s_b)
114+
# assert (s_a and s_b) == s_b
115+
116+
# print(s_a + s_b or s_b + s_a)
117+
# assert (s_a + s_b or s_b + s_a) == "ab"
118+
119+
# print(s_a[0] or s_b[-1])
120+
# assert (s_a[0] or s_b[-1]) == "a"
121+
122+
# print(s_a[0] and s_b[-1])
123+
# assert (s_a[0] and s_b[-1]) == "b"
124+
125+
# print(s_a + s_b or s_b + s_a + s_a[0] and s_b[-1])
126+
# assert (s_a + s_b or s_b + s_a + s_a[0] and s_b[-1]) == "ab"
127+
128+
129+
test_logical_compare_literal()
130+
test_logical_compare_variable()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 105 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3326,29 +3326,116 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33263326
x.base.base.loc);
33273327
}
33283328
}
3329-
LCOMPILERS_ASSERT(
3330-
ASRUtils::check_equal_type(ASRUtils::expr_type(lhs), ASRUtils::expr_type(rhs)));
3329+
ASR::ttype_t *left_operand_type = ASRUtils::expr_type(lhs);
3330+
ASR::ttype_t *right_operand_type = ASRUtils::expr_type(rhs);
3331+
33313332
ASR::expr_t *value = nullptr;
3332-
ASR::ttype_t *dest_type = ASRUtils::expr_type(lhs);
3333+
ASR::ttype_t *dest_type = left_operand_type;
33333334

3335+
if (!ASRUtils::check_equal_type(left_operand_type, right_operand_type)) {
3336+
throw SemanticError("Type mismatch: '" + ASRUtils::type_to_str_python(left_operand_type)
3337+
+ "' and '" + ASRUtils::type_to_str_python(right_operand_type)
3338+
+ "'. Both operands must be of the same type.", x.base.base.loc);
3339+
}
3340+
// Reference: https://docs.python.org/3/reference/expressions.html#boolean-operations
33343341
if (ASRUtils::expr_value(lhs) != nullptr && ASRUtils::expr_value(rhs) != nullptr) {
3335-
3336-
LCOMPILERS_ASSERT(ASR::is_a<ASR::Logical_t>(*dest_type));
3337-
bool left_value = ASR::down_cast<ASR::LogicalConstant_t>(
3338-
ASRUtils::expr_value(lhs))->m_value;
3339-
bool right_value = ASR::down_cast<ASR::LogicalConstant_t>(
3340-
ASRUtils::expr_value(rhs))->m_value;
3341-
bool result;
3342-
switch (op) {
3343-
case (ASR::logicalbinopType::And): { result = left_value && right_value; break; }
3344-
case (ASR::logicalbinopType::Or): { result = left_value || right_value; break; }
3345-
default : {
3346-
throw SemanticError("Boolean operator type not supported",
3347-
x.base.base.loc);
3342+
switch (dest_type->type) {
3343+
case ASR::ttypeType::Logical: {
3344+
bool left_value = ASR::down_cast<ASR::LogicalConstant_t>(
3345+
ASRUtils::expr_value(lhs))->m_value;
3346+
bool right_value = ASR::down_cast<ASR::LogicalConstant_t>(
3347+
ASRUtils::expr_value(rhs))->m_value;
3348+
bool result;
3349+
switch (op) {
3350+
case (ASR::logicalbinopType::And): { result = left_value && right_value; break; }
3351+
case (ASR::logicalbinopType::Or): { result = left_value || right_value; break; }
3352+
default : {
3353+
throw SemanticError("Boolean operator type not supported",
3354+
x.base.base.loc);
3355+
}
3356+
}
3357+
value = ASRUtils::EXPR(ASR::make_LogicalConstant_t(
3358+
al, x.base.base.loc, result, dest_type));
3359+
break;
3360+
}
3361+
case ASR::ttypeType::Integer: {
3362+
int64_t left_value = ASR::down_cast<ASR::IntegerConstant_t>(
3363+
ASRUtils::expr_value(lhs))->m_n;
3364+
int64_t right_value = ASR::down_cast<ASR::IntegerConstant_t>(
3365+
ASRUtils::expr_value(rhs))->m_n;
3366+
int64_t result;
3367+
switch (op) {
3368+
case (ASR::logicalbinopType::And): {
3369+
result = left_value == 0 ? left_value : right_value;
3370+
break;
3371+
}
3372+
case (ASR::logicalbinopType::Or): {
3373+
result = left_value != 0 ? left_value : right_value;
3374+
break;
3375+
}
3376+
default : {
3377+
throw SemanticError("Boolean operator type not supported",
3378+
x.base.base.loc);
3379+
}
3380+
}
3381+
value = ASRUtils::EXPR(ASR::make_IntegerConstant_t(
3382+
al, x.base.base.loc, result, dest_type));
3383+
break;
3384+
}
3385+
case ASR::ttypeType::Real: {
3386+
double left_value = ASR::down_cast<ASR::RealConstant_t>(
3387+
ASRUtils::expr_value(lhs))->m_r;
3388+
double right_value = ASR::down_cast<ASR::RealConstant_t>(
3389+
ASRUtils::expr_value(rhs))->m_r;
3390+
double result;
3391+
switch (op) {
3392+
case (ASR::logicalbinopType::And): {
3393+
result = left_value == 0 ? left_value : right_value;
3394+
break;
3395+
}
3396+
case (ASR::logicalbinopType::Or): {
3397+
result = left_value != 0 ? left_value : right_value;
3398+
break;
3399+
}
3400+
default : {
3401+
throw SemanticError("Boolean operator type not supported",
3402+
x.base.base.loc);
3403+
}
3404+
}
3405+
value = ASRUtils::EXPR(ASR::make_RealConstant_t(
3406+
al, x.base.base.loc, result, dest_type));
3407+
break;
33483408
}
3409+
case ASR::ttypeType::Character: {
3410+
char* left_value = ASR::down_cast<ASR::StringConstant_t>(
3411+
ASRUtils::expr_value(lhs))->m_s;
3412+
char* right_value = ASR::down_cast<ASR::StringConstant_t>(
3413+
ASRUtils::expr_value(rhs))->m_s;
3414+
char* result;
3415+
switch (op) {
3416+
case (ASR::logicalbinopType::And): {
3417+
result = std::strcmp(left_value, "") == 0 ? left_value : right_value;
3418+
break;
3419+
}
3420+
case (ASR::logicalbinopType::Or): {
3421+
result = std::strcmp(left_value, "") != 0 ? left_value : right_value;
3422+
break;
3423+
}
3424+
default : {
3425+
throw SemanticError("Boolean operator type not supported",
3426+
x.base.base.loc);
3427+
}
3428+
}
3429+
value = ASRUtils::EXPR(ASR::make_StringConstant_t(
3430+
al, x.base.base.loc, result, dest_type));
3431+
break;
3432+
}
3433+
3434+
default:
3435+
throw SemanticError("Boolean operation not supported on objects of type '"
3436+
+ ASRUtils::type_to_str_python(dest_type) + "'",
3437+
x.base.base.loc);
33493438
}
3350-
value = ASR::down_cast<ASR::expr_t>(ASR::make_LogicalConstant_t(
3351-
al, x.base.base.loc, result, dest_type));
33523439
}
33533440
tmp = ASR::make_LogicalBinOp_t(al, x.base.base.loc, lhs, op, rhs, dest_type, value);
33543441
}
@@ -7586,7 +7673,6 @@ we will have to use something else.
75867673
}
75877674
Vec<ASR::expr_t*> args_; args_.reserve(al, x.n_args);
75887675
visit_expr_list(x.m_args, x.n_args, args_);
7589-
75907676
if (x.n_args > 0 && ASRUtils::is_array(ASRUtils::expr_type(args_[0])) &&
75917677
imported_functions[call_name] == "math" ) {
75927678
throw SemanticError("Function '" + call_name + "' does not accept vector values",
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def f():
2+
print("hello" or 10)
3+
4+
5+
f()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-test_logical_compare_01-5db0e2e",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/test_logical_compare_01.py",
5+
"infile_hash": "467dc216d8ce90f4b3a1ec06610cea226ae96152763cfa42d5ab8f33",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-test_logical_compare_01-5db0e2e.stderr",
11+
"stderr_hash": "d10cac68687315b5d29828e0acb5170f44bd91dd30784f8bd4943bb0",
12+
"returncode": 2
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
semantic error: Type mismatch: 'str' and 'i32'. Both operands must be of the same type.
2+
--> tests/errors/test_logical_compare_01.py:2:11
3+
|
4+
2 | print("hello" or 10)
5+
| ^^^^^^^^^^^^^

tests/tests.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,6 +1259,10 @@ asr = true
12591259
filename = "errors/loop_03.py"
12601260
asr = true
12611261

1262+
[[test]]
1263+
filename = "errors/test_logical_compare_01.py"
1264+
asr = true
1265+
12621266
[[test]]
12631267
filename = "errors/bindc_01.py"
12641268
asr = true

0 commit comments

Comments
 (0)