Skip to content

Commit 634177d

Browse files
authored
Add dict.keys and dict.values (#2023)
1 parent a183feb commit 634177d

File tree

8 files changed

+451
-2
lines changed

8 files changed

+451
-2
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ RUN(NAME test_dict_12 LABELS cpython llvm c)
508508
RUN(NAME test_dict_13 LABELS cpython llvm c)
509509
RUN(NAME test_dict_bool LABELS cpython llvm)
510510
RUN(NAME test_dict_increment LABELS cpython llvm)
511+
RUN(NAME test_dict_keys_values LABELS cpython llvm)
511512
RUN(NAME test_set_len LABELS cpython llvm)
512513
RUN(NAME test_set_add LABELS cpython llvm)
513514
RUN(NAME test_set_remove LABELS cpython llvm)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from lpython import i32, f64
2+
3+
def test_dict_keys_values():
4+
d1: dict[i32, i32] = {}
5+
k1: list[i32]
6+
k1_copy: list[i32] = []
7+
v1: list[i32]
8+
v1_copy: list[i32] = []
9+
i: i32
10+
j: i32
11+
s: str
12+
key_count: i32
13+
14+
for i in range(105, 115):
15+
d1[i] = i + 1
16+
k1 = d1.keys()
17+
for i in k1:
18+
k1_copy.append(i)
19+
v1 = d1.values()
20+
for i in v1:
21+
v1_copy.append(i)
22+
assert len(k1) == 10
23+
for i in range(105, 115):
24+
key_count = 0
25+
for j in range(len(k1)):
26+
if k1_copy[j] == i:
27+
key_count += 1
28+
assert v1_copy[j] == d1[i]
29+
assert key_count == 1
30+
31+
d2: dict[str, str] = {}
32+
k2: list[str]
33+
k2_copy: list[str] = []
34+
v2: list[str]
35+
v2_copy: list[str] = []
36+
37+
for i in range(105, 115):
38+
d2[str(i)] = str(i + 1)
39+
k2 = d2.keys()
40+
for s in k2:
41+
k2_copy.append(s)
42+
v2 = d2.values()
43+
for s in v2:
44+
v2_copy.append(s)
45+
assert len(k2) == 10
46+
for i in range(105, 115):
47+
key_count = 0
48+
for j in range(len(k2)):
49+
if k2_copy[j] == str(i):
50+
key_count += 1
51+
assert v2_copy[j] == d2[str(i)]
52+
assert key_count == 1
53+
54+
test_dict_keys_values()

src/libasr/asr_utils.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,24 @@ static inline ASR::abiType symbol_abi(const ASR::symbol_t *f)
238238
return ASR::abiType::Source;
239239
}
240240

241-
static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) {
241+
static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type, int overload=0) {
242242
switch( asr_type->type ) {
243243
case ASR::ttypeType::List: {
244244
return ASR::down_cast<ASR::List_t>(asr_type)->m_type;
245245
}
246246
case ASR::ttypeType::Set: {
247247
return ASR::down_cast<ASR::Set_t>(asr_type)->m_type;
248248
}
249+
case ASR::ttypeType::Dict: {
250+
switch( overload ) {
251+
case 0:
252+
return ASR::down_cast<ASR::Dict_t>(asr_type)->m_key_type;
253+
case 1:
254+
return ASR::down_cast<ASR::Dict_t>(asr_type)->m_value_type;
255+
default:
256+
return asr_type;
257+
}
258+
}
249259
case ASR::ttypeType::Enum: {
250260
ASR::Enum_t* enum_asr = ASR::down_cast<ASR::Enum_t>(asr_type);
251261
ASR::EnumType_t* enum_type = ASR::down_cast<ASR::EnumType_t>(enum_asr->m_enum_type);

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,6 +1684,49 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16841684
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
16851685
}
16861686

1687+
void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value) {
1688+
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
1689+
ASRUtils::expr_type(m_arg));
1690+
ASR::ttype_t* el_type = key_or_value == 0 ?
1691+
dict_type->m_key_type : dict_type->m_value_type;
1692+
1693+
int64_t ptr_loads_copy = ptr_loads;
1694+
ptr_loads = 0;
1695+
this->visit_expr(*m_arg);
1696+
llvm::Value* pdict = tmp;
1697+
1698+
ptr_loads = ptr_loads_copy;
1699+
1700+
bool is_array_type_local = false, is_malloc_array_type_local = false;
1701+
bool is_list_local = false;
1702+
ASR::dimension_t* m_dims_local = nullptr;
1703+
int n_dims_local = -1, a_kind_local = -1;
1704+
llvm::Type* llvm_el_type = llvm_utils->get_type_from_ttype_t(el_type, nullptr,
1705+
ASR::storage_typeType::Default, is_array_type_local,
1706+
is_malloc_array_type_local, is_list_local, m_dims_local,
1707+
n_dims_local, a_kind_local, module.get());
1708+
std::string type_code = ASRUtils::get_type_code(el_type);
1709+
int32_t type_size = -1;
1710+
if( ASR::is_a<ASR::Character_t>(*el_type) ||
1711+
LLVM::is_llvm_struct(el_type) ||
1712+
ASR::is_a<ASR::Complex_t>(*el_type) ) {
1713+
llvm::DataLayout data_layout(module.get());
1714+
type_size = data_layout.getTypeAllocSize(llvm_el_type);
1715+
} else {
1716+
type_size = ASRUtils::extract_kind_from_ttype_t(el_type);
1717+
}
1718+
llvm::Type* el_list_type = list_api->get_list_type(llvm_el_type, type_code, type_size);
1719+
llvm::Value* el_list = builder->CreateAlloca(el_list_type, nullptr, key_or_value == 0 ?
1720+
"keys_list" : "values_list");
1721+
list_api->list_init(type_code, el_list, *module, 0, 0);
1722+
1723+
llvm_utils->set_dict_api(dict_type);
1724+
llvm_utils->dict_api->get_elements_list(pdict, el_list, dict_type->m_key_type,
1725+
dict_type->m_value_type, *module,
1726+
name2memidx, key_or_value);
1727+
tmp = el_list;
1728+
}
1729+
16871730
void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
16881731
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
16891732
int64_t ptr_loads_copy = ptr_loads;
@@ -1755,6 +1798,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17551798
}
17561799
break;
17571800
}
1801+
case ASRUtils::IntrinsicFunctions::DictKeys: {
1802+
generate_DictElems(x.m_args[0], 0);
1803+
break;
1804+
}
1805+
case ASRUtils::IntrinsicFunctions::DictValues: {
1806+
generate_DictElems(x.m_args[0], 1);
1807+
break;
1808+
}
17581809
case ASRUtils::IntrinsicFunctions::SetAdd: {
17591810
generate_SetAdd(x.m_args[0], x.m_args[1]);
17601811
break;

src/libasr/codegen/llvm_utils.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,6 +1708,12 @@ namespace LCompilers {
17081708
list_api->list_deepcopy(src, dest, list_type, module, name2memidx);
17091709
break ;
17101710
}
1711+
case ASR::ttypeType::Dict: {
1712+
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(asr_type);
1713+
// set dict api here?
1714+
dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx);
1715+
break ;
1716+
}
17111717
case ASR::ttypeType::Struct: {
17121718
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(asr_type);
17131719
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
@@ -3865,6 +3871,178 @@ namespace LCompilers {
38653871
return LLVM::CreateLoad(*builder, value_ptr);
38663872
}
38673873

3874+
void LLVMDict::get_elements_list(llvm::Value* dict,
3875+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
3876+
ASR::ttype_t* value_asr_type, llvm::Module& module,
3877+
std::map<std::string, std::map<std::string, int>>& name2memidx,
3878+
bool key_or_value) {
3879+
3880+
/**
3881+
* C++ equivalent:
3882+
*
3883+
* // key_or_value = 0 for keys, 1 for values
3884+
*
3885+
* idx = 0;
3886+
*
3887+
* while( capacity > idx ) {
3888+
* el = key_or_value_list[idx];
3889+
* key_mask_value = key_mask[idx];
3890+
*
3891+
* is_key_skip = key_mask_value == 3; // tombstone
3892+
* is_key_set = key_mask_value != 0;
3893+
* add_el = is_key_set && !is_key_skip;
3894+
* if( add_el ) {
3895+
* elements_list.append(el);
3896+
* }
3897+
*
3898+
* idx++;
3899+
* }
3900+
*
3901+
*/
3902+
3903+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
3904+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
3905+
llvm::Value* el_list = key_or_value == 0 ? get_key_list(dict) : get_value_list(dict);
3906+
ASR::ttype_t* el_asr_type = key_or_value == 0 ? key_asr_type : value_asr_type;
3907+
if( !are_iterators_set ) {
3908+
idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
3909+
}
3910+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
3911+
llvm::APInt(32, 0)), idx_ptr);
3912+
3913+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
3914+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
3915+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
3916+
3917+
// head
3918+
llvm_utils->start_new_block(loophead);
3919+
{
3920+
llvm::Value *cond = builder->CreateICmpSGT(capacity, LLVM::CreateLoad(*builder, idx_ptr));
3921+
builder->CreateCondBr(cond, loopbody, loopend);
3922+
}
3923+
3924+
// body
3925+
llvm_utils->start_new_block(loopbody);
3926+
{
3927+
llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr);
3928+
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
3929+
llvm_utils->create_ptr_gep(key_mask, idx));
3930+
llvm::Value* is_key_skip = builder->CreateICmpEQ(key_mask_value,
3931+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)));
3932+
llvm::Value* is_key_set = builder->CreateICmpNE(key_mask_value,
3933+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));
3934+
3935+
llvm::Value* add_el = builder->CreateAnd(is_key_set,
3936+
builder->CreateNot(is_key_skip));
3937+
llvm_utils->create_if_else(add_el, [&]() {
3938+
llvm::Value* el = llvm_utils->list_api->read_item(el_list, idx,
3939+
false, module, LLVM::is_llvm_struct(el_asr_type));
3940+
llvm_utils->list_api->append(elements_list, el,
3941+
el_asr_type, &module, name2memidx);
3942+
}, [=]() {
3943+
});
3944+
3945+
idx = builder->CreateAdd(idx, llvm::ConstantInt::get(
3946+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
3947+
LLVM::CreateStore(*builder, idx, idx_ptr);
3948+
}
3949+
3950+
builder->CreateBr(loophead);
3951+
3952+
// end
3953+
llvm_utils->start_new_block(loopend);
3954+
}
3955+
3956+
void LLVMDictSeparateChaining::get_elements_list(llvm::Value* dict,
3957+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
3958+
ASR::ttype_t* value_asr_type, llvm::Module& module,
3959+
std::map<std::string, std::map<std::string, int>>& name2memidx,
3960+
bool key_or_value) {
3961+
if( !are_iterators_set ) {
3962+
idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
3963+
chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
3964+
}
3965+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
3966+
llvm::APInt(32, 0)), idx_ptr);
3967+
3968+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
3969+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
3970+
llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict));
3971+
llvm::Type* kv_pair_type = get_key_value_pair_type(key_asr_type, value_asr_type);
3972+
ASR::ttype_t* el_asr_type = key_or_value == 0 ? key_asr_type : value_asr_type;
3973+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
3974+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
3975+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
3976+
3977+
// head
3978+
llvm_utils->start_new_block(loophead);
3979+
{
3980+
llvm::Value *cond = builder->CreateICmpSGT(
3981+
capacity,
3982+
LLVM::CreateLoad(*builder, idx_ptr));
3983+
builder->CreateCondBr(cond, loopbody, loopend);
3984+
}
3985+
3986+
// body
3987+
llvm_utils->start_new_block(loopbody);
3988+
{
3989+
llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr);
3990+
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
3991+
llvm_utils->create_ptr_gep(key_mask, idx));
3992+
llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value,
3993+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
3994+
3995+
llvm_utils->create_if_else(is_key_set, [&]() {
3996+
llvm::Value* dict_i = llvm_utils->create_ptr_gep(key_value_pairs, idx);
3997+
llvm::Value* kv_ll_i8 = builder->CreateBitCast(dict_i, llvm::Type::getInt8PtrTy(context));
3998+
LLVM::CreateStore(*builder, kv_ll_i8, chain_itr);
3999+
4000+
llvm::BasicBlock *loop2head = llvm::BasicBlock::Create(context, "loop2.head");
4001+
llvm::BasicBlock *loop2body = llvm::BasicBlock::Create(context, "loop2.body");
4002+
llvm::BasicBlock *loop2end = llvm::BasicBlock::Create(context, "loop2.end");
4003+
4004+
// head
4005+
llvm_utils->start_new_block(loop2head);
4006+
{
4007+
llvm::Value *cond = builder->CreateICmpNE(
4008+
LLVM::CreateLoad(*builder, chain_itr),
4009+
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))
4010+
);
4011+
builder->CreateCondBr(cond, loop2body, loop2end);
4012+
}
4013+
4014+
// body
4015+
llvm_utils->start_new_block(loop2body);
4016+
{
4017+
llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
4018+
llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_pair_type->getPointerTo());
4019+
llvm::Value* kv_el = llvm_utils->create_gep(kv_struct, key_or_value);
4020+
if( !LLVM::is_llvm_struct(el_asr_type) ) {
4021+
kv_el = LLVM::CreateLoad(*builder, kv_el);
4022+
}
4023+
llvm_utils->list_api->append(elements_list, kv_el,
4024+
el_asr_type, &module, name2memidx);
4025+
llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2));
4026+
LLVM::CreateStore(*builder, next_kv_struct, chain_itr);
4027+
}
4028+
4029+
builder->CreateBr(loop2head);
4030+
4031+
// end
4032+
llvm_utils->start_new_block(loop2end);
4033+
}, [=]() {
4034+
});
4035+
llvm::Value* tmp = builder->CreateAdd(idx,
4036+
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
4037+
LLVM::CreateStore(*builder, tmp, idx_ptr);
4038+
}
4039+
4040+
builder->CreateBr(loophead);
4041+
4042+
// end
4043+
llvm_utils->start_new_block(loopend);
4044+
}
4045+
38684046
llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
38694047
bool enable_bounds_checking,
38704048
llvm::Module& module, bool get_pointer) {

src/libasr/codegen/llvm_utils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,13 @@ namespace LCompilers {
621621
virtual
622622
void set_is_dict_present(bool value);
623623

624+
virtual
625+
void get_elements_list(llvm::Value* dict,
626+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
627+
ASR::ttype_t* value_asr_type, llvm::Module& module,
628+
std::map<std::string, std::map<std::string, int>>& name2memidx,
629+
bool key_or_value) = 0;
630+
624631
virtual ~LLVMDictInterface() = 0;
625632

626633
};
@@ -713,6 +720,12 @@ namespace LCompilers {
713720

714721
llvm::Value* len(llvm::Value* dict);
715722

723+
void get_elements_list(llvm::Value* dict,
724+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
725+
ASR::ttype_t* value_asr_type, llvm::Module& module,
726+
std::map<std::string, std::map<std::string, int>>& name2memidx,
727+
bool key_or_value);
728+
716729
virtual ~LLVMDict();
717730
};
718731

@@ -860,6 +873,12 @@ namespace LCompilers {
860873

861874
llvm::Value* len(llvm::Value* dict);
862875

876+
void get_elements_list(llvm::Value* dict,
877+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
878+
ASR::ttype_t* value_asr_type, llvm::Module& module,
879+
std::map<std::string, std::map<std::string, int>>& name2memidx,
880+
bool key_or_value);
881+
863882
virtual ~LLVMDictSeparateChaining();
864883

865884
};

0 commit comments

Comments
 (0)