Skip to content

Commit 7e778b3

Browse files
committed
new diag return-type-mismatch
1 parent 09cd988 commit 7e778b3

File tree

7 files changed

+161
-17
lines changed

7 files changed

+161
-17
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
* `missing-return-value`
1111
* `redundant-return-value`
1212
* `missing-return`
13+
* `return-type-mismatch`
1314
* `NEW` settings:
1415
* `diagnostics.groupSeverity`
1516
* `diagnostics.groupFileStatus`

locale/zh-cn/script.lua

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ DIAG_REDUNDANT_RETURN_VALUE_RANGE =
138138
'最多只有 {max} 个返回值,但此处返回了第 {rmin} 到第 {rmax} 个值。'
139139
DIAG_MISSING_RETURN =
140140
'此处需要返回值。'
141+
DIAG_RETURN_TYPE_MISMATCH =
142+
'第 {index} 个返回值的类型为 `{def}` ,但实际返回的是 `{ref}`。'
141143

142144
MWS_NOT_SUPPORT =
143145
'{} 目前还不支持多工作目录,我可能需要重启才能支持新的工作目录...'
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
local files = require 'files'
2+
local lang = require 'language'
3+
local guide = require 'parser.guide'
4+
local vm = require 'vm'
5+
local await = require 'await'
6+
7+
---@param func parser.object
8+
---@return vm.node[]?
9+
local function getDocReturns(func)
10+
if not func.bindDocs then
11+
return nil
12+
end
13+
local returns = {}
14+
for _, doc in ipairs(func.bindDocs) do
15+
if doc.type == 'doc.return' then
16+
for _, ret in ipairs(doc.returns) do
17+
returns[ret.returnIndex] = vm.compileNode(ret)
18+
end
19+
end
20+
end
21+
if #returns == 0 then
22+
return nil
23+
end
24+
return returns
25+
end
26+
---@async
27+
return function (uri, callback)
28+
local state = files.getState(uri)
29+
if not state then
30+
return
31+
end
32+
33+
---@param docReturns vm.node[]
34+
---@param rets parser.object
35+
local function checkReturn(docReturns, rets)
36+
for i, docRet in ipairs(docReturns) do
37+
local retNode, exp = vm.selectNode(rets, i)
38+
if not exp then
39+
break
40+
end
41+
if not vm.canCastType(uri, docRet, retNode) then
42+
callback {
43+
start = exp.start,
44+
finish = exp.finish,
45+
message = lang.script('DIAG_RETURN_TYPE_MISMATCH', {
46+
def = vm.getInfer(docRet):view(uri),
47+
ref = vm.getInfer(retNode):view(uri),
48+
index = i,
49+
}),
50+
}
51+
end
52+
end
53+
end
54+
55+
---@async
56+
guide.eachSourceType(state.ast, 'function', function (source)
57+
if not source.returns then
58+
return
59+
end
60+
local docReturns = getDocReturns(source)
61+
if not docReturns then
62+
return
63+
end
64+
await.delay()
65+
for _, ret in ipairs(source.returns) do
66+
checkReturn(docReturns, ret)
67+
await.delay()
68+
end
69+
end)
70+
end

script/parser/guide.lua

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,12 +1090,12 @@ end
10901090
---@param a table
10911091
---@param b table
10921092
---@return string|false mode
1093-
---@return table pathA?
1094-
---@return table pathB?
1093+
---@return table? pathA
1094+
---@return table? pathB
10951095
function m.getPath(a, b, sameFunction)
10961096
--- 首先测试双方在同一个函数内
10971097
if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
1098-
return false, nil, nil
1098+
return false
10991099
end
11001100
local mode
11011101
local objA
@@ -1139,7 +1139,7 @@ function m.getPath(a, b, sameFunction)
11391139
end
11401140
end
11411141
if not start then
1142-
return false, nil, nil
1142+
return false
11431143
end
11441144
-- pathA: { 1, 2, 3}
11451145
-- pathB: {5, 6, 2, 3}

script/proto/diagnostic.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ m.register {
7474
'assign-type-mismatch',
7575
'param-type-mismatch',
7676
'cast-type-mismatch',
77+
'return-type-mismatch',
7778
} {
7879
group = 'type-check',
7980
severity = 'Warning',

script/vm/compiler.lua

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,9 @@ function vm.getReturnOfFunction(func, index)
469469
end
470470
if not func._returns[index] then
471471
func._returns[index] = {
472-
type = 'function.return',
473-
parent = func,
474-
index = index,
472+
type = 'function.return',
473+
parent = func,
474+
returnIndex = index,
475475
}
476476
end
477477
return func._returns[index]
@@ -736,14 +736,15 @@ function vm.compileByParentNode(source, key, ref, pushResult)
736736
end
737737
end
738738

739-
---@return vm.node?
740-
local function selectNode(source, list, index)
741-
if not list then
742-
return nil
743-
end
739+
---@param list parser.object[]
740+
---@param index integer
741+
---@return vm.node
742+
---@return parser.object?
743+
function vm.selectNode(list, index)
744744
local exp
745745
if list[index] then
746746
exp = list[index]
747+
index = 1
747748
else
748749
for i = index, 1, -1 do
749750
if list[i] then
@@ -758,16 +759,14 @@ local function selectNode(source, list, index)
758759
end
759760
end
760761
if not exp then
761-
vm.setNode(source, vm.declareGlobal('type', 'nil'))
762-
return vm.getNode(source)
762+
return vm.createNode(vm.declareGlobal('type', 'nil')), nil
763763
end
764764
---@type vm.node?
765765
local result
766766
if exp.type == 'call' then
767767
result = getReturn(exp.node, index, exp.args)
768768
if not result then
769-
vm.setNode(source, vm.declareGlobal('type', 'unknown'))
770-
return vm.getNode(source)
769+
return vm.createNode(vm.declareGlobal('type', 'unknown')), exp
771770
end
772771
else
773772
---@type vm.node
@@ -776,6 +775,15 @@ local function selectNode(source, list, index)
776775
result:merge(vm.declareGlobal('type', 'unknown'))
777776
end
778777
end
778+
return result, exp
779+
end
780+
781+
---@param source parser.object
782+
---@param list parser.object[]
783+
---@param index integer
784+
---@return vm.node
785+
local function selectNode(source, list, index)
786+
local result = vm.selectNode(list, index)
779787
if source.type == 'function.return' then
780788
-- remove any for returns
781789
local rtnNode = vm.createNode()
@@ -1513,9 +1521,10 @@ local compilerSwitch = util.switch()
15131521
vm.setNode(source, vm.compileNode(source.value))
15141522
end)
15151523
: case 'function.return'
1524+
---@param source parser.object
15161525
: call(function (source)
15171526
local func = source.parent
1518-
local index = source.index
1527+
local index = source.returnIndex
15191528
local hasMarkDoc
15201529
if func.bindDocs then
15211530
local sign = getObjectSign(func)

test/diagnostics/type-check.lua

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,5 +496,66 @@ TEST [[
496496
local <!x!> = 'aaa'
497497
]]
498498

499+
TEST [[
500+
---@return number
501+
function F()
502+
return <!true!>
503+
end
504+
]]
505+
506+
TEST [[
507+
---@return number?
508+
function F()
509+
return 1
510+
end
511+
]]
512+
513+
TEST [[
514+
---@return number?
515+
function F()
516+
return nil
517+
end
518+
]]
519+
520+
TEST [[
521+
---@return number, number
522+
local function f() end
523+
524+
---@return number, boolean
525+
function F()
526+
return <!f()!>
527+
end
528+
]]
529+
530+
TEST [[
531+
---@return boolean, number
532+
local function f() end
533+
534+
---@return number, boolean
535+
function F()
536+
return <!f()!>
537+
end
538+
]]
539+
540+
TEST [[
541+
---@return boolean, number?
542+
local function f() end
543+
544+
---@return number, boolean
545+
function F()
546+
return 1, f()
547+
end
548+
]]
549+
550+
TEST [[
551+
---@return number, number?
552+
local function f() end
553+
554+
---@return number, boolean, number
555+
function F()
556+
return 1, <!f()!>
557+
end
558+
]]
559+
499560
config.remove(nil, 'Lua.diagnostics.disable', 'unused-local')
500561
config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global')

0 commit comments

Comments
 (0)