Skip to content

Commit 999c5dd

Browse files
committed
infer by assert(x)
1 parent 3cadbfc commit 999c5dd

File tree

3 files changed

+71
-5
lines changed

3 files changed

+71
-5
lines changed

script/parser/newparser.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ local Specials = {
117117
['xpcall'] = true,
118118
['pairs'] = true,
119119
['ipairs'] = true,
120+
['assert'] = true,
120121
}
121122

122123
local UnarySymbol = {

script/vm/runner.lua

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,31 +81,31 @@ function mt:_compileNarrowByFilter(filter, pos)
8181
if not loc or not exp then
8282
return
8383
end
84-
if exp.type == 'nil' then
84+
if guide.isLiteral(exp) then
8585
if filter.op.type == '==' then
8686
self.steps[#self.steps+1] = {
8787
type = 'remove',
88-
name = 'nil',
88+
name = exp.type,
8989
pos = pos,
9090
order = 2,
9191
}
9292
self.steps[#self.steps+1] = {
9393
type = 'as',
94-
name = 'nil',
94+
name = exp.type,
9595
pos = pos,
9696
order = 4,
9797
}
9898
end
9999
if filter.op.type == '~=' then
100100
self.steps[#self.steps+1] = {
101101
type = 'as',
102-
name = 'nil',
102+
name = exp.type,
103103
pos = pos,
104104
order = 2,
105105
}
106106
self.steps[#self.steps+1] = {
107107
type = 'remove',
108-
name = 'nil',
108+
name = exp.type,
109109
pos = pos,
110110
order = 4,
111111
}
@@ -248,6 +248,42 @@ function mt:_preCompile()
248248
end)
249249
end
250250

251+
---@param loc parser.object
252+
---@param node vm.node
253+
---@return vm.node
254+
local function checkAssert(loc, node)
255+
local parent = loc.parent
256+
if parent.type == 'binary' then
257+
if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then
258+
local exp
259+
for i = 1, 2 do
260+
if parent[i] == loc then
261+
exp = parent[i % 2 + 1]
262+
end
263+
end
264+
if exp and guide.isLiteral(exp) then
265+
local callargs = parent.parent
266+
if callargs.type == 'callargs'
267+
and callargs.parent.node.special == 'assert'
268+
and callargs[1] == parent then
269+
if parent.op.type == '~=' then
270+
node:remove(exp.type)
271+
end
272+
if parent.op.type == '==' then
273+
node = vm.compileNode(exp)
274+
end
275+
end
276+
end
277+
end
278+
end
279+
if parent.type == 'callargs'
280+
and parent.parent.node.special == 'assert'
281+
and parent[1] == loc then
282+
node:setTruly()
283+
end
284+
return node
285+
end
286+
251287
---@param callback fun(src: parser.object, node: vm.node)
252288
function mt:launch(callback)
253289
local node = vm.getNode(self.loc):copy()
@@ -267,6 +303,9 @@ function mt:launch(callback)
267303
node:remove(step.name)
268304
elseif step.type == 'object' then
269305
node = callback(step.object, node) or node
306+
if step.object.type == 'getlocal' then
307+
node = checkAssert(step.object, node)
308+
end
270309
elseif step.type == 'save' then
271310
-- nothing to do
272311
elseif step.type == 'load' then

test/type_inference/init.lua

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,3 +1919,29 @@ local <?x?> = t[1]
19191919
TEST 'unknown' [[
19201920
local <?x?> = y and z
19211921
]]
1922+
1923+
TEST 'integer' [[
1924+
---@type integer?
1925+
local x
1926+
1927+
assert(x)
1928+
1929+
print(<?x?>)
1930+
]]
1931+
1932+
TEST 'integer' [[
1933+
---@type integer?
1934+
local x
1935+
1936+
assert(x ~= nil)
1937+
1938+
print(<?x?>)
1939+
]]
1940+
1941+
TEST 'integer' [[
1942+
local x
1943+
1944+
assert(x == 1)
1945+
1946+
print(<?x?>)
1947+
]]

0 commit comments

Comments
 (0)