Skip to content

Commit 8d69c93

Browse files
authored
Add set data structure (#2122)
1 parent 47bd6d5 commit 8d69c93

21 files changed

+1333
-58
lines changed

integration_tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,9 @@ RUN(NAME test_dict_12 LABELS cpython llvm c)
484484
RUN(NAME test_dict_13 LABELS cpython llvm c)
485485
RUN(NAME test_dict_bool LABELS cpython llvm)
486486
RUN(NAME test_dict_increment LABELS cpython llvm)
487+
RUN(NAME test_set_len LABELS cpython llvm)
488+
RUN(NAME test_set_add LABELS cpython llvm)
489+
RUN(NAME test_set_remove LABELS cpython llvm)
487490
RUN(NAME test_for_loop LABELS cpython llvm c)
488491
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
489492
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)

integration_tests/test_set_add.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from lpython import i32
2+
3+
def test_set_add():
    # Exercises set.add() on an integer set, a nested-tuple set and a string
    # set, checking after every insertion that re-adding an existing member
    # leaves the length unchanged.
    ints: set[i32]
    tuples: set[tuple[i32, tuple[i32, i32], str]]
    words: set[str]
    word: str
    step: i32
    key: i32

    ints = {0}
    tuples = {(0, (1, 2), 'a')}
    for step in range(20):
        # Only ten distinct keys (0..9) are ever inserted.
        key = step % 10
        ints.add(key)
        tuples.add((key, (key + 1, key + 2), 'a'))
        assert len(ints) == len(tuples)
        if step >= 10:
            # Second pass over the same keys: nothing new is added.
            assert len(ints) == 10
        else:
            assert len(ints) == step + 1

    word = 'a'
    words = {word}
    for step in range(20):
        words.add(word)
        if step >= 10:
            assert len(words) == 10
        elif step > 0:
            # The first iteration re-adds the initial 'a'; afterwards each
            # iteration inserts a string that is one character longer.
            assert len(words) == step
            word += 'a'
33+
34+
test_set_add()

integration_tests/test_set_len.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from lpython import i32
2+
3+
def test_set():
    # A set literal with repeated members keeps only the distinct values:
    # 1, 2, 22 and -1, so the length is 4.
    values: set[i32]
    values = {1, 2, 22, 2, -1, 1}
    assert len(values) == 4
7+
8+
test_set()
integration_tests/test_set_remove.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from lpython import i32
2+
3+
def test_set_add():
    # NOTE(review): despite its name (apparently carried over from the
    # set.add test), this function exercises set.remove(). The name is kept
    # because the module-level call below the function uses it.
    s1: set[i32]
    s2: set[tuple[i32, tuple[i32, i32], str]]
    s3: set[str]
    st1: str
    i: i32
    j: i32
    k: i32

    # The scenario runs twice, so on the second pass each set is re-assigned
    # from a fresh literal after having been used.
    for k in range(2):
        s1 = {0}
        s2 = {(0, (1, 2), 'a')}
        for i in range(20):
            j = i % 10
            s1.add(j)
            s2.add((j, (j + 1, j + 2), 'a'))

        # Remove every member that was inserted above (keys 0..9).
        for i in range(10):
            s1.remove(i)
            s2.remove((i, (i + 1, i + 2), 'a'))
            # NOTE(review): these checks are disabled in the original; the
            # reason is not visible here, so they are left disabled.
            # assert len(s1) == 10 - 1 - i
            # assert len(s1) == len(s2)

        # Fill s3 with the ten strings 'a', 'aa', ..., 'a' * 10.
        st1 = 'a'
        s3 = {st1}
        for i in range(20):
            s3.add(st1)
            if i > 0 and i < 10:
                st1 += 'a'

        # Remove them again, shortest first. The former `if i < 10` guard was
        # always true inside range(10), so the string is extended directly.
        st1 = 'a'
        for i in range(10):
            s3.remove(st1)
            assert len(s3) == 10 - 1 - i
            st1 += 'a'

        # s1 is empty here; even values are removed right after insertion,
        # so only the odd values seen so far remain.
        for i in range(20):
            s1.add(i)
            if i % 2 == 0:
                s1.remove(i)
            assert len(s1) == (i + 1) // 2
46+
47+
test_set_add()

src/libasr/ASR.asdl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,6 @@ stmt
221221
| SelectType(expr selector, type_stmt* body, stmt* default)
222222
| CPtrToPointer(expr cptr, expr ptr, expr? shape, expr? lower_bounds)
223223
| BlockCall(int label, symbol m)
224-
| SetInsert(expr a, expr ele)
225-
| SetRemove(expr a, expr ele)
226224
| ListInsert(expr a, expr pos, expr ele)
227225
| ListRemove(expr a, expr ele)
228226
| ListClear(expr a)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
175175
std::unique_ptr<LLVMTuple> tuple_api;
176176
std::unique_ptr<LLVMDictInterface> dict_api_lp;
177177
std::unique_ptr<LLVMDictInterface> dict_api_sc;
178+
std::unique_ptr<LLVMSetInterface> set_api; // linear probing
178179
std::unique_ptr<LLVMArrUtils::Descriptor> arr_descr;
179180

180181
ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile,
@@ -199,13 +200,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
199200
tuple_api(std::make_unique<LLVMTuple>(context, llvm_utils.get(), builder.get())),
200201
dict_api_lp(std::make_unique<LLVMDictOptimizedLinearProbing>(context, llvm_utils.get(), builder.get())),
201202
dict_api_sc(std::make_unique<LLVMDictSeparateChaining>(context, llvm_utils.get(), builder.get())),
203+
set_api(std::make_unique<LLVMSetLinearProbing>(context, llvm_utils.get(), builder.get())),
202204
arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context,
203205
builder.get(), llvm_utils.get(),
204206
LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor))
205207
{
206208
llvm_utils->tuple_api = tuple_api.get();
207209
llvm_utils->list_api = list_api.get();
208210
llvm_utils->dict_api = nullptr;
211+
llvm_utils->set_api = set_api.get();
209212
llvm_utils->arr_api = arr_descr.get();
210213
llvm_utils->dict_api_lp = dict_api_lp.get();
211214
llvm_utils->dict_api_sc = dict_api_sc.get();
@@ -1149,6 +1152,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
11491152
tmp = const_dict;
11501153
}
11511154

1155+
// Lowers a set literal. Storage of the set's LLVM type is created with an
// alloca, initialised through set_api->set_init, and every literal element is
// visited and stored with set_api->write_item. On return `tmp` holds the
// pointer to the freshly built set and `ptr_loads` has its incoming value.
void visit_SetConstant(const ASR::SetConstant_t& x) {
    llvm::Type* set_llvm_type = llvm_utils->get_set_type(x.m_type, module.get());
    llvm::Value* set_ptr = builder->CreateAlloca(set_llvm_type, nullptr, "const_set");

    ASR::Set_t* set_asr_type = ASR::down_cast<ASR::Set_t>(x.m_type);
    ASR::ttype_t* element_asr_type = set_asr_type->m_type;
    std::string element_type_code = ASRUtils::get_type_code(element_asr_type);
    // NOTE(review): x.n_elements is presumably used as the initial size hint;
    // confirm against set_init.
    llvm_utils->set_api->set_init(element_type_code, set_ptr, module.get(), x.n_elements);

    // Struct-typed elements are visited with no pointer load, everything
    // else with exactly one.
    const int64_t element_ptr_loads = LLVM::is_llvm_struct(element_asr_type) ? 0 : 1;
    const int64_t saved_ptr_loads = ptr_loads;
    for (size_t idx = 0; idx < x.n_elements; ++idx) {
        // Re-armed on every iteration, before each element visit.
        ptr_loads = element_ptr_loads;
        visit_expr_wrapper(x.m_elements[idx], true);
        llvm::Value* element_value = tmp;
        llvm_utils->set_api->write_item(set_ptr, element_value, module.get(),
                                        element_asr_type, name2memidx);
    }
    ptr_loads = saved_ptr_loads;

    tmp = set_ptr;
}
1173+
11521174
void visit_TupleConstant(const ASR::TupleConstant_t& x) {
11531175
ASR::Tuple_t* tuple_type = ASR::down_cast<ASR::Tuple_t>(x.m_type);
11541176
std::string type_code = ASRUtils::get_type_code(tuple_type->m_type,
@@ -1487,6 +1509,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14871509
tmp = llvm_utils->dict_api->len(pdict);
14881510
}
14891511

1512+
// Lowers len(<set>). When the node carries `m_value` (presumably the
// already-known result; confirm against the ASR definition) that expression
// is emitted instead. Otherwise the set argument is visited with
// ptr_loads = 0 and its length is obtained from set_api->len. The result is
// left in `tmp`.
void visit_SetLen(const ASR::SetLen_t& x) {
    if (x.m_value != nullptr) {
        this->visit_expr(*x.m_value);
    } else {
        const int64_t saved_ptr_loads = ptr_loads;
        ptr_loads = 0;
        this->visit_expr(*x.m_arg);
        ptr_loads = saved_ptr_loads;

        llvm::Value* set_ptr = tmp;
        tmp = llvm_utils->set_api->len(set_ptr);
    }
}
1525+
14901526
void visit_ListInsert(const ASR::ListInsert_t& x) {
14911527
ASR::List_t* asr_list = ASR::down_cast<ASR::List_t>(
14921528
ASRUtils::expr_type(x.m_a));
@@ -1648,6 +1684,34 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16481684
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
16491685
}
16501686

1687+
// Emits code for the SetAdd intrinsic: stores `m_ele` into the set `m_arg`.
// The set expression is visited with ptr_loads = 0 and the element with
// ptr_loads = 2; the incoming ptr_loads value is restored before the element
// is written through set_api->write_item.
void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
    ASR::ttype_t* set_asr_type = ASRUtils::expr_type(m_arg);
    ASR::ttype_t* element_asr_type = ASRUtils::get_contained_type(set_asr_type);

    const int64_t saved_ptr_loads = ptr_loads;

    ptr_loads = 0;
    this->visit_expr(*m_arg);
    llvm::Value* set_ptr = tmp;

    ptr_loads = 2;
    this->visit_expr_wrapper(m_ele, true);
    llvm::Value* element_value = tmp;

    ptr_loads = saved_ptr_loads;
    set_api->write_item(set_ptr, element_value, module.get(),
                        element_asr_type, name2memidx);
}
1700+
1701+
// Emits code for the SetRemove intrinsic: removes `m_ele` from the set
// `m_arg`. The set expression is visited with ptr_loads = 0 and the element
// with ptr_loads = 2; the incoming ptr_loads value is restored before
// set_api->remove_item is called.
void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
    ASR::ttype_t* set_asr_type = ASRUtils::expr_type(m_arg);
    ASR::ttype_t* element_asr_type = ASRUtils::get_contained_type(set_asr_type);

    const int64_t saved_ptr_loads = ptr_loads;

    ptr_loads = 0;
    this->visit_expr(*m_arg);
    llvm::Value* set_ptr = tmp;

    ptr_loads = 2;
    this->visit_expr_wrapper(m_ele, true);
    llvm::Value* element_value = tmp;

    ptr_loads = saved_ptr_loads;
    set_api->remove_item(set_ptr, element_value, *module, element_asr_type);
}
1714+
16511715
void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
16521716
switch (static_cast<ASRUtils::IntrinsicFunctions>(x.m_intrinsic_id)) {
16531717
case ASRUtils::IntrinsicFunctions::ListIndex: {
@@ -1691,6 +1755,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16911755
}
16921756
break;
16931757
}
1758+
case ASRUtils::IntrinsicFunctions::SetAdd: {
1759+
generate_SetAdd(x.m_args[0], x.m_args[1]);
1760+
break;
1761+
}
1762+
case ASRUtils::IntrinsicFunctions::SetRemove: {
1763+
generate_SetRemove(x.m_args[0], x.m_args[1]);
1764+
break;
1765+
}
16941766
case ASRUtils::IntrinsicFunctions::Exp: {
16951767
switch (x.m_overload_id) {
16961768
case 0: {
@@ -3945,6 +4017,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
39454017
bool is_value_tuple = ASR::is_a<ASR::Tuple_t>(*asr_value_type);
39464018
bool is_target_dict = ASR::is_a<ASR::Dict_t>(*asr_target_type);
39474019
bool is_value_dict = ASR::is_a<ASR::Dict_t>(*asr_value_type);
4020+
bool is_target_set = ASR::is_a<ASR::Set_t>(*asr_target_type);
4021+
bool is_value_set = ASR::is_a<ASR::Set_t>(*asr_value_type);
39484022
bool is_target_struct = ASR::is_a<ASR::Struct_t>(*asr_target_type);
39494023
bool is_value_struct = ASR::is_a<ASR::Struct_t>(*asr_value_type);
39504024
if (ASR::is_a<ASR::StringSection_t>(*x.m_target)) {
@@ -4034,6 +4108,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
40344108
llvm_utils->dict_api->dict_deepcopy(value_dict, target_dict,
40354109
value_dict_type, module.get(), name2memidx);
40364110
return ;
4111+
} else if( is_target_set && is_value_set ) {
4112+
int64_t ptr_loads_copy = ptr_loads;
4113+
ptr_loads = 0;
4114+
this->visit_expr(*x.m_value);
4115+
llvm::Value* value_set = tmp;
4116+
this->visit_expr(*x.m_target);
4117+
llvm::Value* target_set = tmp;
4118+
ptr_loads = ptr_loads_copy;
4119+
ASR::Set_t* value_set_type = ASR::down_cast<ASR::Set_t>(asr_value_type);
4120+
llvm_utils->set_api->set_deepcopy(value_set, target_set,
4121+
value_set_type, module.get(), name2memidx);
4122+
return ;
40374123
} else if( is_target_struct && is_value_struct ) {
40384124
int64_t ptr_loads_copy = ptr_loads;
40394125
ptr_loads = 0;

0 commit comments

Comments
 (0)