1
+ use std:: ops:: Deref ;
2
+
1
3
use emmylua_parser:: {
2
4
LuaAst , LuaAstNode , LuaCallArgList , LuaCallExpr , LuaClosureExpr , LuaFuncStat , LuaVarExpr ,
3
5
} ;
@@ -215,7 +217,7 @@ pub fn analyze_return_point(
215
217
match point {
216
218
LuaReturnPoint :: Expr ( expr) => {
217
219
let expr_type = infer_expr ( db, cache, expr. clone ( ) ) ?;
218
- return_type = TypeOps :: Union . apply ( db, & return_type, & expr_type) ;
220
+ return_type = union_return_expr ( db, return_type, expr_type) ;
219
221
}
220
222
LuaReturnPoint :: MuliExpr ( exprs) => {
221
223
let mut multi_return = vec ! [ ] ;
@@ -224,10 +226,10 @@ pub fn analyze_return_point(
224
226
multi_return. push ( expr_type) ;
225
227
}
226
228
let typ = LuaType :: Variadic ( VariadicType :: Multi ( multi_return) . into ( ) ) ;
227
- return_type = TypeOps :: Union . apply ( db, & return_type, & typ) ;
229
+ return_type = union_return_expr ( db, return_type, typ) ;
228
230
}
229
231
LuaReturnPoint :: Nil => {
230
- return_type = TypeOps :: Union . apply ( db, & return_type, & LuaType :: Nil ) ;
232
+ return_type = union_return_expr ( db, return_type, LuaType :: Nil ) ;
231
233
}
232
234
_ => { }
233
235
}
@@ -239,3 +241,92 @@ pub fn analyze_return_point(
239
241
name: None ,
240
242
} ] )
241
243
}
244
+
245
+ fn union_return_expr ( db : & DbIndex , left : LuaType , right : LuaType ) -> LuaType {
246
+ if left == LuaType :: Unknown {
247
+ return right;
248
+ }
249
+
250
+ match ( & left, & right) {
251
+ ( LuaType :: Variadic ( left_variadic) , LuaType :: Variadic ( right_variadic) ) => {
252
+ match ( & left_variadic. deref ( ) , & right_variadic. deref ( ) ) {
253
+ ( VariadicType :: Base ( left_base) , VariadicType :: Base ( right_base) ) => {
254
+ let union_base = TypeOps :: Union . apply ( db, left_base, right_base) ;
255
+ LuaType :: Variadic ( VariadicType :: Base ( union_base) . into ( ) )
256
+ }
257
+ ( VariadicType :: Multi ( left_multi) , VariadicType :: Multi ( right_multi) ) => {
258
+ let mut new_multi = vec ! [ ] ;
259
+ let max_len = left_multi. len ( ) . max ( right_multi. len ( ) ) ;
260
+ for i in 0 ..max_len {
261
+ let left_type = left_multi. get ( i) . cloned ( ) . unwrap_or ( LuaType :: Nil ) ;
262
+ let right_type = right_multi. get ( i) . cloned ( ) . unwrap_or ( LuaType :: Nil ) ;
263
+ new_multi. push ( TypeOps :: Union . apply ( db, & left_type, & right_type) ) ;
264
+ }
265
+ LuaType :: Variadic ( VariadicType :: Multi ( new_multi) . into ( ) )
266
+ }
267
+ // difficult to merge the type, use let
268
+ _ => left. clone ( ) ,
269
+ }
270
+ }
271
+ ( LuaType :: Variadic ( variadic) , _) => {
272
+ let first_type = variadic. get_type ( 0 ) . cloned ( ) . unwrap_or ( LuaType :: Unknown ) ;
273
+ let first_union_type = TypeOps :: Union . apply ( db, & first_type, & right) ;
274
+
275
+ match variadic. deref ( ) {
276
+ VariadicType :: Base ( base) => {
277
+ let union_base = TypeOps :: Union . apply ( db, base, & LuaType :: Nil ) ;
278
+ LuaType :: Variadic (
279
+ VariadicType :: Multi ( vec ! [
280
+ first_union_type,
281
+ LuaType :: Variadic ( VariadicType :: Base ( union_base) . into( ) ) ,
282
+ ] )
283
+ . into ( ) ,
284
+ )
285
+ }
286
+ VariadicType :: Multi ( multi) => {
287
+ let mut new_multi = multi. clone ( ) ;
288
+ if new_multi. len ( ) > 0 {
289
+ new_multi[ 0 ] = first_union_type;
290
+ for i in 1 ..new_multi. len ( ) {
291
+ new_multi[ i] = TypeOps :: Union . apply ( db, & new_multi[ i] , & LuaType :: Nil ) ;
292
+ }
293
+ } else {
294
+ new_multi. push ( first_union_type) ;
295
+ }
296
+
297
+ LuaType :: Variadic ( VariadicType :: Multi ( new_multi) . into ( ) )
298
+ }
299
+ }
300
+ }
301
+ ( _, LuaType :: Variadic ( variadic) ) => {
302
+ let first_type = variadic. get_type ( 0 ) . cloned ( ) . unwrap_or ( LuaType :: Unknown ) ;
303
+ let first_union_type = TypeOps :: Union . apply ( db, & left, & first_type) ;
304
+ match variadic. deref ( ) {
305
+ VariadicType :: Base ( base) => {
306
+ let union_base = TypeOps :: Union . apply ( db, base, & LuaType :: Nil ) ;
307
+ LuaType :: Variadic (
308
+ VariadicType :: Multi ( vec ! [
309
+ first_union_type,
310
+ LuaType :: Variadic ( VariadicType :: Base ( union_base) . into( ) ) ,
311
+ ] )
312
+ . into ( ) ,
313
+ )
314
+ }
315
+ VariadicType :: Multi ( multi) => {
316
+ let mut new_multi = multi. clone ( ) ;
317
+ if new_multi. len ( ) > 0 {
318
+ new_multi[ 0 ] = first_union_type;
319
+ for i in 1 ..new_multi. len ( ) {
320
+ new_multi[ i] = TypeOps :: Union . apply ( db, & new_multi[ i] , & LuaType :: Nil ) ;
321
+ }
322
+ } else {
323
+ new_multi. push ( first_union_type) ;
324
+ }
325
+
326
+ LuaType :: Variadic ( VariadicType :: Multi ( new_multi) . into ( ) )
327
+ }
328
+ }
329
+ }
330
+ _ => TypeOps :: Union . apply ( db, & left, & right) ,
331
+ }
332
+ }
0 commit comments