1+ use std:: ops:: Deref ;
2+
13use emmylua_parser:: {
24 LuaAst , LuaAstNode , LuaCallArgList , LuaCallExpr , LuaClosureExpr , LuaFuncStat , LuaVarExpr ,
35} ;
@@ -215,7 +217,7 @@ pub fn analyze_return_point(
215217 match point {
216218 LuaReturnPoint :: Expr ( expr) => {
217219 let expr_type = infer_expr ( db, cache, expr. clone ( ) ) ?;
218- return_type = TypeOps :: Union . apply ( db, & return_type, & expr_type) ;
220+ return_type = union_return_expr ( db, return_type, expr_type) ;
219221 }
220222 LuaReturnPoint :: MuliExpr ( exprs) => {
221223 let mut multi_return = vec ! [ ] ;
@@ -224,10 +226,10 @@ pub fn analyze_return_point(
224226 multi_return. push ( expr_type) ;
225227 }
226228 let typ = LuaType :: Variadic ( VariadicType :: Multi ( multi_return) . into ( ) ) ;
227- return_type = TypeOps :: Union . apply ( db, & return_type, & typ) ;
229+ return_type = union_return_expr ( db, return_type, typ) ;
228230 }
229231 LuaReturnPoint :: Nil => {
230- return_type = TypeOps :: Union . apply ( db, & return_type, & LuaType :: Nil ) ;
232+ return_type = union_return_expr ( db, return_type, LuaType :: Nil ) ;
231233 }
232234 _ => { }
233235 }
@@ -239,3 +241,92 @@ pub fn analyze_return_point(
239241 name: None ,
240242 } ] )
241243}
244+
245+ fn union_return_expr ( db : & DbIndex , left : LuaType , right : LuaType ) -> LuaType {
246+ if left == LuaType :: Unknown {
247+ return right;
248+ }
249+
250+ match ( & left, & right) {
251+ ( LuaType :: Variadic ( left_variadic) , LuaType :: Variadic ( right_variadic) ) => {
252+ match ( & left_variadic. deref ( ) , & right_variadic. deref ( ) ) {
253+ ( VariadicType :: Base ( left_base) , VariadicType :: Base ( right_base) ) => {
254+ let union_base = TypeOps :: Union . apply ( db, left_base, right_base) ;
255+ LuaType :: Variadic ( VariadicType :: Base ( union_base) . into ( ) )
256+ }
257+ ( VariadicType :: Multi ( left_multi) , VariadicType :: Multi ( right_multi) ) => {
258+ let mut new_multi = vec ! [ ] ;
259+ let max_len = left_multi. len ( ) . max ( right_multi. len ( ) ) ;
260+ for i in 0 ..max_len {
261+ let left_type = left_multi. get ( i) . cloned ( ) . unwrap_or ( LuaType :: Nil ) ;
262+ let right_type = right_multi. get ( i) . cloned ( ) . unwrap_or ( LuaType :: Nil ) ;
263+ new_multi. push ( TypeOps :: Union . apply ( db, & left_type, & right_type) ) ;
264+ }
265+ LuaType :: Variadic ( VariadicType :: Multi ( new_multi) . into ( ) )
266+ }
267+ // difficult to merge the type, use let
268+ _ => left. clone ( ) ,
269+ }
270+ }
271+ ( LuaType :: Variadic ( variadic) , _) => {
272+ let first_type = variadic. get_type ( 0 ) . cloned ( ) . unwrap_or ( LuaType :: Unknown ) ;
273+ let first_union_type = TypeOps :: Union . apply ( db, & first_type, & right) ;
274+
275+ match variadic. deref ( ) {
276+ VariadicType :: Base ( base) => {
277+ let union_base = TypeOps :: Union . apply ( db, base, & LuaType :: Nil ) ;
278+ LuaType :: Variadic (
279+ VariadicType :: Multi ( vec ! [
280+ first_union_type,
281+ LuaType :: Variadic ( VariadicType :: Base ( union_base) . into( ) ) ,
282+ ] )
283+ . into ( ) ,
284+ )
285+ }
286+ VariadicType :: Multi ( multi) => {
287+ let mut new_multi = multi. clone ( ) ;
288+ if new_multi. len ( ) > 0 {
289+ new_multi[ 0 ] = first_union_type;
290+ for i in 1 ..new_multi. len ( ) {
291+ new_multi[ i] = TypeOps :: Union . apply ( db, & new_multi[ i] , & LuaType :: Nil ) ;
292+ }
293+ } else {
294+ new_multi. push ( first_union_type) ;
295+ }
296+
297+ LuaType :: Variadic ( VariadicType :: Multi ( new_multi) . into ( ) )
298+ }
299+ }
300+ }
301+ ( _, LuaType :: Variadic ( variadic) ) => {
302+ let first_type = variadic. get_type ( 0 ) . cloned ( ) . unwrap_or ( LuaType :: Unknown ) ;
303+ let first_union_type = TypeOps :: Union . apply ( db, & left, & first_type) ;
304+ match variadic. deref ( ) {
305+ VariadicType :: Base ( base) => {
306+ let union_base = TypeOps :: Union . apply ( db, base, & LuaType :: Nil ) ;
307+ LuaType :: Variadic (
308+ VariadicType :: Multi ( vec ! [
309+ first_union_type,
310+ LuaType :: Variadic ( VariadicType :: Base ( union_base) . into( ) ) ,
311+ ] )
312+ . into ( ) ,
313+ )
314+ }
315+ VariadicType :: Multi ( multi) => {
316+ let mut new_multi = multi. clone ( ) ;
317+ if new_multi. len ( ) > 0 {
318+ new_multi[ 0 ] = first_union_type;
319+ for i in 1 ..new_multi. len ( ) {
320+ new_multi[ i] = TypeOps :: Union . apply ( db, & new_multi[ i] , & LuaType :: Nil ) ;
321+ }
322+ } else {
323+ new_multi. push ( first_union_type) ;
324+ }
325+
326+ LuaType :: Variadic ( VariadicType :: Multi ( new_multi) . into ( ) )
327+ }
328+ }
329+ }
330+ _ => TypeOps :: Union . apply ( db, & left, & right) ,
331+ }
332+ }
0 commit comments