Skip to content

Commit 942c324

Browse files
committed
Fix #318
1 parent 0c22c20 commit 942c324

File tree

4 files changed

+148
-23
lines changed

4 files changed

+148
-23
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#[cfg(test)]
2+
mod test {
3+
use smol_str::SmolStr;
4+
5+
use crate::{LuaType, LuaUnionType, VirtualWorkspace};
6+
7+
#[test]
8+
fn test_issue_318() {
9+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
10+
11+
ws.def(
12+
r#"
13+
local map = {
14+
a = 'a',
15+
b = 'b',
16+
c = 'c',
17+
}
18+
local key --- @type string
19+
c = map[key] -- type should be ('a'|'b'|'c'|nil)
20+
21+
"#,
22+
);
23+
24+
let c_ty = ws.expr_ty("c");
25+
26+
let union_type = LuaType::Union(
27+
LuaUnionType::new(vec![
28+
LuaType::StringConst(SmolStr::new("a").into()),
29+
LuaType::StringConst(SmolStr::new("b").into()),
30+
LuaType::StringConst(SmolStr::new("c").into()),
31+
LuaType::Nil,
32+
])
33+
.into(),
34+
);
35+
36+
assert_eq!(c_ty, union_type);
37+
}
38+
}

crates/emmylua_code_analysis/src/compilation/test/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ mod for_range_var_infer_test;
99
mod infer_str_tpl_test;
1010
mod inherit_type;
1111
mod mathlib_test;
12+
mod memebr_infer_test;
1213
mod metatable_test;
1314
mod module_annotation;
1415
mod multi_return;

crates/emmylua_code_analysis/src/db_index/type/types.rs

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use std::{collections::HashMap, hash::Hash, sync::Arc};
1+
use std::{
2+
collections::HashMap,
3+
hash::{Hash, Hasher},
4+
sync::Arc,
5+
};
26

37
use internment::ArcIntern;
48
use rowan::TextRange;
@@ -631,8 +635,7 @@ impl From<LuaObjectType> for LuaType {
631635
LuaType::Object(t.into())
632636
}
633637
}
634-
635-
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
638+
#[derive(Debug, Clone)]
636639
pub struct LuaUnionType {
637640
types: Vec<LuaType>,
638641
}
@@ -655,6 +658,50 @@ impl LuaUnionType {
655658
}
656659
}
657660

661+
impl PartialEq for LuaUnionType {
662+
fn eq(&self, other: &Self) -> bool {
663+
if self.types.len() != other.types.len() {
664+
return false;
665+
}
666+
let mut counts = HashMap::new();
667+
// Count occurrences in self.types
668+
for t in &self.types {
669+
*counts.entry(t).or_insert(0) += 1;
670+
}
671+
// Decrease counts for other.types
672+
for t in &other.types {
673+
match counts.get_mut(t) {
674+
Some(count) if *count > 0 => *count -= 1,
675+
_ => return false,
676+
}
677+
}
678+
true
679+
}
680+
}
681+
682+
impl Eq for LuaUnionType {}
683+
684+
impl std::hash::Hash for LuaUnionType {
685+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
686+
// To get an order-insensitive hash, combine:
687+
// - the number of elements
688+
// - the sum and product of the hashes of individual elements.
689+
// This is a simple and fast commutative hash.
690+
let mut sum: u64 = 0;
691+
let mut prod: u64 = 1;
692+
for t in &self.types {
693+
let mut hasher = std::collections::hash_map::DefaultHasher::new();
694+
t.hash(&mut hasher);
695+
let h = hasher.finish();
696+
sum = sum.wrapping_add(h);
697+
prod = prod.wrapping_mul(h.wrapping_add(1));
698+
}
699+
self.types.len().hash(state);
700+
sum.hash(state);
701+
prod.hash(state);
702+
}
703+
}
704+
658705
impl From<LuaUnionType> for LuaType {
659706
fn from(t: LuaUnionType) -> Self {
660707
LuaType::Union(t.into())

crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -453,28 +453,67 @@ fn infer_member_by_index_table(
453453
table_range: &InFiled<TextRange>,
454454
index_expr: LuaIndexMemberExpr,
455455
) -> InferResult {
456-
let metatable = db
457-
.get_metatable_index()
458-
.get(table_range)
459-
.ok_or(InferFailReason::FieldDotFound)?;
460-
461-
let meta_owner = LuaOperatorOwner::Table(metatable.clone());
462-
let operator_ids = db
463-
.get_operator_index()
464-
.get_operators(&meta_owner, LuaOperatorMetaMethod::Index)
465-
.ok_or(InferFailReason::FieldDotFound)?;
456+
let metatable = db.get_metatable_index().get(table_range);
457+
match metatable {
458+
Some(metatable) => {
459+
let meta_owner = LuaOperatorOwner::Table(metatable.clone());
460+
let operator_ids = db
461+
.get_operator_index()
462+
.get_operators(&meta_owner, LuaOperatorMetaMethod::Index)
463+
.ok_or(InferFailReason::FieldDotFound)?;
464+
465+
let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?;
466+
467+
for operator_id in operator_ids {
468+
let operator = db
469+
.get_operator_index()
470+
.get_operator(operator_id)
471+
.ok_or(InferFailReason::None)?;
472+
let operand = operator.get_operand(db);
473+
let return_type = operator.get_result(db)?;
474+
let typ = infer_index_metamethod(db, cache, &index_key, &operand, &return_type)?;
475+
return Ok(typ);
476+
}
477+
}
478+
None => {
479+
let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?;
480+
if let LuaIndexKey::Expr(expr) = index_key {
481+
let key_type = infer_expr(db, cache, expr.clone())?;
482+
let members = db
483+
.get_member_index()
484+
.get_members(&LuaMemberOwner::Element(table_range.clone()));
485+
if let Some(members) = members {
486+
let mut result_type = LuaType::Unknown;
487+
for member in members {
488+
let member_key_type = match member.get_key() {
489+
LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()),
490+
LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i),
491+
_ => continue,
492+
};
493+
if check_type_compact(db, &key_type, &member_key_type).is_ok() {
494+
let member_type = db
495+
.get_type_index()
496+
.get_type_cache(&member.get_id().into())
497+
.map(|it| it.as_type())
498+
.unwrap_or(&LuaType::Unknown);
499+
500+
result_type = TypeOps::Union.apply(&result_type, member_type);
501+
}
502+
}
466503

467-
let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?;
504+
if !result_type.is_unknown() {
505+
if matches!(
506+
key_type,
507+
LuaType::String | LuaType::Number | LuaType::Integer
508+
) {
509+
result_type = TypeOps::Union.apply(&result_type, &LuaType::Nil);
510+
}
468511

469-
for operator_id in operator_ids {
470-
let operator = db
471-
.get_operator_index()
472-
.get_operator(operator_id)
473-
.ok_or(InferFailReason::None)?;
474-
let operand = operator.get_operand(db);
475-
let return_type = operator.get_result(db)?;
476-
let typ = infer_index_metamethod(db, cache, &index_key, &operand, &return_type)?;
477-
return Ok(typ);
512+
return Ok(result_type);
513+
}
514+
}
515+
}
516+
}
478517
}
479518

480519
Err(InferFailReason::FieldDotFound)

0 commit comments

Comments
 (0)