|
| 1 | +local files = require 'files' |
| 2 | +local furi = require 'file-uri' |
| 3 | +local vm = require 'vm' |
| 4 | +local guide = require 'parser.guide' |
| 5 | +local catch = require 'catch' |
| 6 | + |
| 7 | +rawset(_G, 'TEST', true) |
| 8 | + |
| 9 | +local function getSource(uri, pos) |
| 10 | + local state = files.getState(uri) |
| 11 | + if not state then |
| 12 | + return |
| 13 | + end |
| 14 | + local result |
| 15 | + guide.eachSourceContain(state.ast, pos, function (source) |
| 16 | + if source.type == 'local' |
| 17 | + or source.type == 'getlocal' |
| 18 | + or source.type == 'setlocal' |
| 19 | + or source.type == 'setglobal' |
| 20 | + or source.type == 'getglobal' |
| 21 | + or source.type == 'field' |
| 22 | + or source.type == 'method' |
| 23 | + or source.type == 'function' |
| 24 | + or source.type == 'table' |
| 25 | + or source.type == 'doc.type.name' then |
| 26 | + result = source |
| 27 | + end |
| 28 | + end) |
| 29 | + return result |
| 30 | +end |
| 31 | + |
| 32 | +local EXISTS = {} |
| 33 | + |
| 34 | +local function eq(a, b) |
| 35 | + if a == EXISTS and b ~= nil then |
| 36 | + return true |
| 37 | + end |
| 38 | + if b == EXISTS and a ~= nil then |
| 39 | + return true |
| 40 | + end |
| 41 | + local tp1, tp2 = type(a), type(b) |
| 42 | + if tp1 ~= tp2 then |
| 43 | + return false |
| 44 | + end |
| 45 | + if tp1 == 'table' then |
| 46 | + local mark = {} |
| 47 | + for k in pairs(a) do |
| 48 | + if not eq(a[k], b[k]) then |
| 49 | + return false |
| 50 | + end |
| 51 | + mark[k] = true |
| 52 | + end |
| 53 | + for k in pairs(b) do |
| 54 | + if not mark[k] then |
| 55 | + return false |
| 56 | + end |
| 57 | + end |
| 58 | + return true |
| 59 | + end |
| 60 | + return a == b |
| 61 | +end |
| 62 | + |
| 63 | +---@diagnostic disable: await-in-sync |
| 64 | +function TEST(expect) |
| 65 | + local sourcePos, sourceUri |
| 66 | + for _, file in ipairs(expect) do |
| 67 | + local script, list = catch(file.content, '?') |
| 68 | + local uri = furi.encode(file.path) |
| 69 | + files.setText(uri, script) |
| 70 | + files.compileState(uri) |
| 71 | + if #list['?'] > 0 then |
| 72 | + sourceUri = uri |
| 73 | + sourcePos = (list['?'][1][1] + list['?'][1][2]) // 2 |
| 74 | + end |
| 75 | + end |
| 76 | + |
| 77 | + local _ <close> = function () |
| 78 | + for _, info in ipairs(expect) do |
| 79 | + files.remove(furi.encode(info.path)) |
| 80 | + end |
| 81 | + end |
| 82 | + |
| 83 | + local source = getSource(sourceUri, sourcePos) |
| 84 | + assert(source) |
| 85 | + local view = vm.getInfer(source):view(sourceUri) |
| 86 | + assert(eq(view, expect.infer)) |
| 87 | +end |
| 88 | + |
| 89 | +TEST { |
| 90 | + { |
| 91 | + path = 'a.lua', |
| 92 | + content = [[ |
| 93 | +---@class T |
| 94 | +local x |
| 95 | +
|
| 96 | +---@class V |
| 97 | +x.y = 1 |
| 98 | +]], |
| 99 | + }, |
| 100 | + { |
| 101 | + path = 'b.lua', |
| 102 | + content = [[ |
| 103 | +---@type T |
| 104 | +local x |
| 105 | +
|
| 106 | +if x.y then |
| 107 | + print(x.<?y?>) |
| 108 | +end |
| 109 | + ]], |
| 110 | + }, |
| 111 | + infer = 'V', |
| 112 | +} |
0 commit comments