Skip to content

Commit 0a962fc

Browse files
authored
Merge pull request #2486 from fesily/plugin-OnNodeCompileFunctionParam
Plugin on node compile function param
2 parents ea3aed4 + 155f831 commit 0a962fc

File tree

6 files changed

+245
-50
lines changed

6 files changed

+245
-50
lines changed

script/plugin.lua

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ local scope = require 'workspace.scope'
77
local ws = require 'workspace'
88
local fs = require 'bee.filesystem'
99

10+
---@class pluginInterfaces
11+
local pluginConfigs = {
12+
-- create plugin for vm module
13+
VM = {
14+
OnCompileFunctionParam = function (next, func, source)
15+
end
16+
}
17+
}
18+
1019
---@class plugin
1120
local m = {}
1221

@@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...)
5160
return failed == 0, res1, res2
5261
end
5362

63+
function m.getVmPlugin(uri)
64+
local scp = scope.getScope(uri)
65+
local interfaces = scp:get('pluginInterfaces')
66+
if not interfaces then
67+
return
68+
end
69+
return interfaces.VM
70+
end
71+
5472
---@async
5573
---@param scp scope
5674
local function checkTrustLoad(scp)
@@ -78,6 +96,40 @@ local function checkTrustLoad(scp)
7896
return true
7997
end
8098

99+
local function createMethodGroup(interfaces, key, methods)
100+
local methodGroup = {}
101+
102+
for method in pairs(methods) do
103+
local funcs = setmetatable({}, {
104+
__call = function (t, next, ...)
105+
if #t == 0 then
106+
return next(...)
107+
else
108+
local result
109+
for _, fn in ipairs(t) do
110+
result = fn(next, ...)
111+
end
112+
return result
113+
end
114+
end
115+
})
116+
for _, interface in ipairs(interfaces) do
117+
local func = interface[method]
118+
if not func then
119+
local namespace = interface[key]
120+
if namespace then
121+
func = namespace[method]
122+
end
123+
end
124+
if func then
125+
funcs[#funcs+1] = func
126+
end
127+
end
128+
methodGroup[method] = funcs
129+
end
130+
return #methodGroup>0 and methodGroup or nil
131+
end
132+
81133
---@param uri uri
82134
local function initPlugin(uri)
83135
await.call(function () ---@async
@@ -148,6 +200,11 @@ local function initPlugin(uri)
148200
end
149201
interfaces[#interfaces+1] = interface
150202
end
203+
204+
for key, config in pairs(pluginConfigs) do
205+
interfaces[key] = createMethodGroup(interfaces, key, config)
206+
end
207+
151208
ws.resetFiles(scp)
152209
end)
153210
end

script/plugins/nodeHelper.lua

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
local vm = require 'vm'
2+
local guide = require 'parser.guide'
3+
4+
local _M = {}
5+
6+
---@class node.match.pattern
7+
---@field next node.match.pattern?
8+
9+
local function deepCompare(source, pattern)
10+
local type1, type2 = type(source), type(pattern)
11+
if type1 ~= type2 then
12+
return false
13+
end
14+
15+
if type1 ~= "table" then
16+
return source == pattern
17+
end
18+
19+
for key2, value2 in pairs(pattern) do
20+
local value1 = source[key2]
21+
if value1 == nil or not deepCompare(value1, value2) then
22+
return false
23+
end
24+
end
25+
26+
return true
27+
end
28+
29+
---@param source parser.object
30+
---@param pattern node.match.pattern
31+
---@return boolean
32+
function _M.matchPattern(source, pattern)
33+
if source.type == 'local' then
34+
if source.parent.type == 'funcargs' and source.parent.parent.type == 'function' then
35+
for i, ref in ipairs(source.ref) do
36+
if deepCompare(ref, pattern) then
37+
return true
38+
end
39+
end
40+
end
41+
end
42+
return false
43+
end
44+
45+
local vaildVarRegex = "()([a-zA-Z][a-zA-Z0-9_]*)()"
46+
---创建类型 *.field.field形式的 pattern
47+
---@param pattern string
48+
---@return node.match.pattern?, string?
49+
function _M.createFieldPattern(pattern)
50+
local ret = { next = nil }
51+
local next = ret
52+
local init = 1
53+
while true do
54+
local startpos, matched, endpos
55+
if pattern:sub(1, 1) == "*" then
56+
startpos, matched, endpos = init, "*", init + 1
57+
else
58+
startpos, matched, endpos = vaildVarRegex:match(pattern, init)
59+
end
60+
if not startpos then
61+
break
62+
end
63+
if startpos ~= init then
64+
return nil, "invalid pattern"
65+
end
66+
local field = matched == "*" and { next = nil }
67+
or { field = { type = 'field', matched }, type = 'getfield', next = nil }
68+
next.next = field
69+
next = field
70+
pattern = pattern:sub(endpos)
71+
end
72+
return ret
73+
end
74+
75+
return _M

script/vm/compiler.lua

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ local rpath = require 'workspace.require-path'
55
local files = require 'files'
66
---@class vm
77
local vm = require 'vm.vm'
8+
local plugin = require 'plugin'
89

910
---@class parser.object
1011
---@field _compiledNodes boolean
@@ -1030,6 +1031,55 @@ local function compileForVars(source, target)
10301031
return false
10311032
end
10321033

1034+
---@param source parser.object
1035+
local function compileFunctionParam(func, source)
1036+
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
1037+
local funcNode = vm.compileNode(func)
1038+
for n in funcNode:eachObject() do
1039+
if n.type == 'doc.type.function' then
1040+
for index, arg in ipairs(n.args) do
1041+
if func.args[index] == source then
1042+
local argNode = vm.compileNode(arg)
1043+
for an in argNode:eachObject() do
1044+
if an.type ~= 'doc.generic.name' then
1045+
vm.setNode(source, an)
1046+
end
1047+
end
1048+
return true
1049+
end
1050+
end
1051+
end
1052+
end
1053+
if func.parent.type == 'local' then
1054+
local refs = func.parent.ref
1055+
local findCall
1056+
if refs then
1057+
for i, ref in ipairs(refs) do
1058+
if ref.parent.type == 'call' then
1059+
findCall = ref.parent
1060+
break
1061+
end
1062+
end
1063+
end
1064+
if findCall and findCall.args then
1065+
local index
1066+
for i, arg in ipairs(source.parent) do
1067+
if arg == source then
1068+
index = i
1069+
break
1070+
end
1071+
end
1072+
if index then
1073+
local callerArg = findCall.args[index]
1074+
if callerArg then
1075+
vm.setNode(source, vm.compileNode(callerArg))
1076+
return true
1077+
end
1078+
end
1079+
end
1080+
end
1081+
end
1082+
10331083
---@param source parser.object
10341084
local function compileLocal(source)
10351085
local myNode = vm.setNode(source, source)
@@ -1069,56 +1119,11 @@ local function compileLocal(source)
10691119
vm.setNode(source, vm.compileNode(setfield.node))
10701120
end
10711121
end
1072-
10731122
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
10741123
local func = source.parent.parent
1075-
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
1076-
local funcNode = vm.compileNode(func)
1077-
local hasDocArg
1078-
for n in funcNode:eachObject() do
1079-
if n.type == 'doc.type.function' then
1080-
for index, arg in ipairs(n.args) do
1081-
if func.args[index] == source then
1082-
local argNode = vm.compileNode(arg)
1083-
for an in argNode:eachObject() do
1084-
if an.type ~= 'doc.generic.name' then
1085-
vm.setNode(source, an)
1086-
end
1087-
end
1088-
hasDocArg = true
1089-
end
1090-
end
1091-
end
1092-
end
1093-
if not hasDocArg
1094-
and func.parent.type == 'local' then
1095-
local refs = func.parent.ref
1096-
local findCall
1097-
if refs then
1098-
for i, ref in ipairs(refs) do
1099-
if ref.parent.type == 'call' then
1100-
findCall = ref.parent
1101-
break
1102-
end
1103-
end
1104-
end
1105-
if findCall and findCall.args then
1106-
local index
1107-
for i, arg in ipairs(source.parent) do
1108-
if arg == source then
1109-
index = i
1110-
break
1111-
end
1112-
end
1113-
if index then
1114-
local callerArg = findCall.args[index]
1115-
if callerArg then
1116-
hasDocArg = true
1117-
vm.setNode(source, vm.compileNode(callerArg))
1118-
end
1119-
end
1120-
end
1121-
end
1124+
local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
1125+
local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
1126+
or compileFunctionParam(func, source)
11221127
if not hasDocArg then
11231128
vm.setNode(source, vm.declareGlobal('type', 'any'))
11241129
end

script/workspace/scope.lua

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ function mt:set(k, v)
125125
return v
126126
end
127127

128-
---@param k string
129-
---@return any
128+
---@generic T
129+
---@param k `T`
130+
---@return T
130131
function mt:get(k)
131132
return self._data[k]
132133
end

test/plugins/node/test.lua

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
local files = require 'files'
2+
local scope = require 'workspace.scope'
3+
local nodeHelper = require 'plugins.nodeHelper'
4+
local vm = require 'vm'
5+
local guide = require 'parser.guide'
6+
7+
8+
local pattern, msg = nodeHelper.createFieldPattern("*.components")
9+
assert(pattern, msg)
10+
11+
---@param source parser.object
12+
function OnCompileFunctionParam(next, func, source)
13+
if next(func, source) then
14+
return true
15+
end
16+
--从该参数的使用模式来推导该类型
17+
if nodeHelper.matchPattern(source, pattern) then
18+
local type = vm.declareGlobal('type', 'TestClass', TESTURI)
19+
vm.setNode(source, vm.createNode(type, source))
20+
return true
21+
end
22+
end
23+
24+
local myplugin = { OnCompileFunctionParam = OnCompileFunctionParam }
25+
26+
---@diagnostic disable: await-in-sync
27+
local function TestPlugin(script)
28+
local prefix = [[
29+
---@class TestClass
30+
---@field b string
31+
]]
32+
---@param checker fun(state:parser.state)
33+
return function (plugin, checker)
34+
files.open(TESTURI)
35+
files.setText(TESTURI, prefix .. script, true)
36+
scope.getScope(TESTURI):set('pluginInterfaces', plugin)
37+
local state = files.getState(TESTURI)
38+
assert(state)
39+
checker(state)
40+
files.remove(TESTURI)
41+
end
42+
end
43+
44+
TestPlugin [[
45+
local function t(a)
46+
a.components:test()
47+
end
48+
]](myplugin, function (state)
49+
guide.eachSourceType(state.ast, 'local', function (src)
50+
if guide.getKeyName(src) == 'a' then
51+
local node = vm.compileNode(src)
52+
assert(node)
53+
assert(not vm.isUnknown(node))
54+
end
55+
end)
56+
end)

test/plugins/test.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
require 'plugins.ast.test'
22
require 'plugins.ffi.test'
3+
require 'plugins.node.test'

0 commit comments

Comments
 (0)