11use emmylua_parser:: {
2- BinaryOperator , LuaBinaryExpr , LuaCallExpr , LuaChunk , LuaExpr , LuaLiteralToken ,
2+ BinaryOperator , LuaBinaryExpr , LuaCallExpr , LuaChunk , LuaExpr , LuaIndexMemberExpr ,
3+ LuaLiteralToken ,
34} ;
45
56use crate :: {
67 infer_expr,
78 semantic:: infer:: {
9+ infer_index:: infer_member_by_member_key,
810 narrow:: {
911 condition_flow:: { call_flow:: get_type_at_call_expr, InferConditionFlow } ,
1012 get_single_antecedent,
@@ -15,7 +17,8 @@ use crate::{
1517 } ,
1618 VarRefId ,
1719 } ,
18- DbIndex , FlowNode , FlowTree , InferFailReason , LuaInferCache , LuaType , TypeOps ,
20+ DbIndex , FlowNode , FlowTree , InferFailReason , InferGuard , LuaInferCache , LuaType , LuaUnionType ,
21+ TypeOps ,
1922} ;
2023
2124pub fn get_type_at_binary_expr (
@@ -36,73 +39,56 @@ pub fn get_type_at_binary_expr(
3639 return Ok ( ResultTypeOrContinue :: Continue ) ;
3740 } ;
3841
39- match op_token. get_op ( ) {
40- BinaryOperator :: OpLt
41- | BinaryOperator :: OpLe
42- | BinaryOperator :: OpGt
43- | BinaryOperator :: OpGe => {
44- // todo check number range
42+ let condition_flow = match op_token. get_op ( ) {
43+ BinaryOperator :: OpEq => condition_flow,
44+ BinaryOperator :: OpNe => condition_flow. get_negated ( ) ,
45+ _ => {
46+ return Ok ( ResultTypeOrContinue :: Continue ) ;
4547 }
46- BinaryOperator :: OpEq => {
47- let result_type = maybe_type_guard_binary (
48- db,
49- tree,
50- cache,
51- root,
52- var_ref_id,
53- flow_node,
54- left_expr. clone ( ) ,
55- right_expr. clone ( ) ,
56- condition_flow,
57- ) ?;
58- if let ResultTypeOrContinue :: Result ( result_type) = result_type {
59- return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
60- }
48+ } ;
6149
62- return maybe_var_eq_narrow (
63- db,
64- tree,
65- cache,
66- root,
67- var_ref_id,
68- flow_node,
69- left_expr,
70- right_expr,
71- condition_flow,
72- ) ;
73- }
74- BinaryOperator :: OpNe => {
75- let result_type = maybe_type_guard_binary (
76- db,
77- tree,
78- cache,
79- root,
80- var_ref_id,
81- flow_node,
82- left_expr. clone ( ) ,
83- right_expr. clone ( ) ,
84- condition_flow. get_negated ( ) ,
85- ) ?;
86- if let ResultTypeOrContinue :: Result ( result_type) = result_type {
87- return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
88- }
50+ let mut result_type = maybe_type_guard_binary (
51+ db,
52+ tree,
53+ cache,
54+ root,
55+ var_ref_id,
56+ flow_node,
57+ left_expr. clone ( ) ,
58+ right_expr. clone ( ) ,
59+ condition_flow,
60+ ) ?;
61+ if let ResultTypeOrContinue :: Result ( result_type) = result_type {
62+ return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
63+ }
8964
90- return maybe_var_eq_narrow (
91- db,
92- tree,
93- cache,
94- root,
95- var_ref_id,
96- flow_node,
97- left_expr,
98- right_expr,
99- condition_flow. get_negated ( ) ,
100- ) ;
101- }
102- _ => { }
65+ result_type = maybe_field_literal_eq_narrow (
66+ db,
67+ tree,
68+ cache,
69+ root,
70+ var_ref_id,
71+ flow_node,
72+ left_expr. clone ( ) ,
73+ right_expr. clone ( ) ,
74+ condition_flow,
75+ ) ?;
76+
77+ if let ResultTypeOrContinue :: Result ( result_type) = result_type {
78+ return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
10379 }
10480
105- Ok ( ResultTypeOrContinue :: Continue )
81+ return maybe_var_eq_narrow (
82+ db,
83+ tree,
84+ cache,
85+ root,
86+ var_ref_id,
87+ flow_node,
88+ left_expr,
89+ right_expr,
90+ condition_flow,
91+ ) ;
10692}
10793
10894fn maybe_type_guard_binary (
@@ -296,3 +282,107 @@ fn maybe_var_eq_narrow(
296282 }
297283 }
298284}
285+
286+ fn maybe_field_literal_eq_narrow (
287+ db : & DbIndex ,
288+ tree : & FlowTree ,
289+ cache : & mut LuaInferCache ,
290+ root : & LuaChunk ,
291+ var_ref_id : & VarRefId ,
292+ flow_node : & FlowNode ,
293+ left_expr : LuaExpr ,
294+ right_expr : LuaExpr ,
295+ condition_flow : InferConditionFlow ,
296+ ) -> Result < ResultTypeOrContinue , InferFailReason > {
297+ // only check left as need narrow
298+ let ( index_expr, literal_expr) = match ( left_expr, right_expr) {
299+ ( LuaExpr :: IndexExpr ( index_expr) , LuaExpr :: LiteralExpr ( literal_expr) ) => {
300+ ( index_expr, literal_expr)
301+ }
302+ ( LuaExpr :: LiteralExpr ( literal_expr) , LuaExpr :: IndexExpr ( index_expr) ) => {
303+ ( index_expr, literal_expr)
304+ }
305+ _ => return Ok ( ResultTypeOrContinue :: Continue ) ,
306+ } ;
307+
308+ let Some ( prefix_expr) = index_expr. get_prefix_expr ( ) else {
309+ return Ok ( ResultTypeOrContinue :: Continue ) ;
310+ } ;
311+
312+ let Some ( maybe_var_ref_id) = get_var_expr_var_ref_id ( db, cache, prefix_expr. clone ( ) ) else {
313+ // If we cannot find a reference declaration ID, we cannot narrow it
314+ return Ok ( ResultTypeOrContinue :: Continue ) ;
315+ } ;
316+
317+ if maybe_var_ref_id != * var_ref_id {
318+ return Ok ( ResultTypeOrContinue :: Continue ) ;
319+ }
320+
321+ let antecedent_flow_id = get_single_antecedent ( tree, flow_node) ?;
322+ let left_type = get_type_at_flow ( db, tree, cache, root, & var_ref_id, antecedent_flow_id) ?;
323+ let LuaType :: Union ( union_type) = left_type else {
324+ return Ok ( ResultTypeOrContinue :: Continue ) ;
325+ } ;
326+
327+ let right_type = infer_expr ( db, cache, LuaExpr :: LiteralExpr ( literal_expr) ) ?;
328+ let mut guard = InferGuard :: new ( ) ;
329+ let index = LuaIndexMemberExpr :: IndexExpr ( index_expr) ;
330+ let mut opt_result = None ;
331+ let mut union_types = union_type. get_types ( ) ;
332+ for ( i, sub_type) in union_types. iter ( ) . enumerate ( ) {
333+ let member_type =
334+ match infer_member_by_member_key ( db, cache, & sub_type, index. clone ( ) , & mut guard) {
335+ Ok ( member_type) => member_type,
336+ Err ( _) => continue , // If we cannot infer the member type, skip this type
337+ } ;
338+ if const_type_eq ( & member_type, & right_type) {
339+ // If the right type matches the member type, we can narrow it
340+ opt_result = Some ( i) ;
341+ }
342+ }
343+
344+ match condition_flow {
345+ InferConditionFlow :: TrueCondition => {
346+ if let Some ( i) = opt_result {
347+ return Ok ( ResultTypeOrContinue :: Result ( union_types[ i] . clone ( ) ) ) ;
348+ }
349+ }
350+ InferConditionFlow :: FalseCondition => {
351+ if let Some ( i) = opt_result {
352+ union_types. remove ( i) ;
353+ match union_types. len ( ) {
354+ 0 => return Ok ( ResultTypeOrContinue :: Result ( LuaType :: Unknown ) ) ,
355+ 1 => return Ok ( ResultTypeOrContinue :: Result ( union_types[ 0 ] . clone ( ) ) ) ,
356+ _ => {
357+ let union_type = LuaUnionType :: new ( union_types) ;
358+ return Ok ( ResultTypeOrContinue :: Result ( LuaType :: Union (
359+ union_type. into ( ) ,
360+ ) ) ) ;
361+ }
362+ }
363+ }
364+ }
365+ }
366+
367+ Ok ( ResultTypeOrContinue :: Continue )
368+ }
369+
370+ fn const_type_eq ( left_type : & LuaType , right_type : & LuaType ) -> bool {
371+ if left_type == right_type {
372+ return true ;
373+ }
374+
375+ match ( left_type, right_type) {
376+ (
377+ LuaType :: StringConst ( l) | LuaType :: DocStringConst ( l) ,
378+ LuaType :: StringConst ( r) | LuaType :: DocStringConst ( r) ,
379+ ) => l == r,
380+ ( LuaType :: FloatConst ( l) , LuaType :: FloatConst ( r) ) => l == r,
381+ ( LuaType :: BooleanConst ( l) , LuaType :: BooleanConst ( r) ) => l == r,
382+ (
383+ LuaType :: IntegerConst ( l) | LuaType :: DocIntegerConst ( l) ,
384+ LuaType :: IntegerConst ( r) | LuaType :: DocIntegerConst ( r) ,
385+ ) => l == r,
386+ _ => false ,
387+ }
388+ }
0 commit comments