Skip to content

Commit 3057063

Browse files
committed
infer by x == nil and x ~= nil
1 parent 5381bf4 commit 3057063

File tree

3 files changed

+119
-4
lines changed

3 files changed

+119
-4
lines changed

script/vm/node.lua

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function mt:addOptional()
7171
end
7272

7373
function mt:removeOptional()
74-
self.optional = false
74+
self:remove 'nil'
7575
end
7676

7777
---@return boolean
@@ -187,6 +187,9 @@ end
187187

188188
---@param name string
189189
function mt:remove(name)
190+
if name == 'nil' and self.optional == true then
191+
self.optional = nil
192+
end
190193
local index = 0
191194
while true do
192195
index = index + 1

script/vm/runner.lua

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ mt.index = 1
1616
---@field _hasSorted boolean
1717

1818
---@class vm.runner.step
19-
---@field type 'truly' | 'falsy' | 'add' | 'remove' | 'object' | 'save' | 'load' | 'merge'
19+
---@field type 'truly' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'load' | 'merge'
2020
---@field pos integer
2121
---@field order? integer
2222
---@field node? vm.node
@@ -39,9 +39,13 @@ function mt:_compileNarrowByFilter(filter, pos)
3939
return
4040
end
4141
if filter.type == 'unary' then
42-
if filter.op and filter.op.type == 'not' then
42+
if not filter.op
43+
or not filter[1] then
44+
return
45+
end
46+
if filter.op.type == 'not' then
4347
local exp = filter[1]
44-
if exp and exp.type == 'getlocal' and exp.node == self.loc then
48+
if exp.type == 'getlocal' and exp.node == self.loc then
4549
self.steps[#self.steps+1] = {
4650
type = 'truly',
4751
pos = pos,
@@ -55,6 +59,59 @@ function mt:_compileNarrowByFilter(filter, pos)
5559
end
5660
end
5761
elseif filter.type == 'binary' then
62+
if not filter.op
63+
or not filter[1]
64+
or not filter[2] then
65+
return
66+
end
67+
if filter.op.type == 'and' then
68+
self:_compileNarrowByFilter(filter[1], pos)
69+
self:_compileNarrowByFilter(filter[2], pos)
70+
end
71+
if filter.op.type == '=='
72+
or filter.op.type == '~=' then
73+
local loc, exp
74+
for i = 1, 2 do
75+
loc = filter[i]
76+
if loc.type == 'getlocal' and loc.node == self.loc then
77+
exp = filter[i % 2 + 1]
78+
break
79+
end
80+
end
81+
if not loc then
82+
return
83+
end
84+
if exp.type == 'nil' then
85+
if filter.op.type == '==' then
86+
self.steps[#self.steps+1] = {
87+
type = 'remove',
88+
name = 'nil',
89+
pos = pos,
90+
order = 2,
91+
}
92+
self.steps[#self.steps+1] = {
93+
type = 'as',
94+
name = 'nil',
95+
pos = pos,
96+
order = 4,
97+
}
98+
end
99+
if filter.op.type == '~=' then
100+
self.steps[#self.steps+1] = {
101+
type = 'as',
102+
name = 'nil',
103+
pos = pos,
104+
order = 2,
105+
}
106+
self.steps[#self.steps+1] = {
107+
type = 'remove',
108+
name = 'nil',
109+
pos = pos,
110+
order = 4,
111+
}
112+
end
113+
end
114+
end
58115
else
59116
if filter.type == 'getlocal' and filter.node == self.loc then
60117
self.steps[#self.steps+1] = {
@@ -193,6 +250,8 @@ function mt:launch(callback)
193250
node:setTruly()
194251
elseif step.type == 'falsy' then
195252
node:setFalsy()
253+
elseif step.type == 'as' then
254+
node = vm.createNode(globalMgr.getGlobal('type', step.name))
196255
elseif step.type == 'add' then
197256
node:merge(globalMgr.getGlobal('type', step.name))
198257
elseif step.type == 'remove' then

test/type_inference/init.lua

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,6 +1780,59 @@ end
17801780
print(<?x?>)
17811781
]]
17821782

1783+
TEST 'integer' [[
1784+
---@type integer?
1785+
local x
1786+
1787+
if xxx and x then
1788+
print(<?x?>)
1789+
end
1790+
]]
1791+
1792+
TEST 'integer' [[
1793+
---@type integer?
1794+
local x
1795+
1796+
if x ~= nil then
1797+
print(<?x?>)
1798+
end
1799+
1800+
print(x)
1801+
]]
1802+
1803+
TEST 'integer?' [[
1804+
---@type integer?
1805+
local x
1806+
1807+
if x ~= nil then
1808+
print(x)
1809+
end
1810+
1811+
print(<?x?>)
1812+
]]
1813+
1814+
TEST 'nil' [[
1815+
---@type integer?
1816+
local x
1817+
1818+
if x == nil then
1819+
print(<?x?>)
1820+
end
1821+
1822+
print(x)
1823+
]]
1824+
1825+
TEST 'integer|nil' [[
1826+
---@type integer?
1827+
local x
1828+
1829+
if x == nil then
1830+
print(x)
1831+
end
1832+
1833+
print(<?x?>)
1834+
]]
1835+
17831836
TEST 'integer' [=[
17841837
local x
17851838

0 commit comments

Comments
 (0)