Skip to content

Commit 82c004e

Browse files
committed
recode plugin interface
1 parent dd49a6d commit 82c004e

File tree

4 files changed

+120
-60
lines changed

4 files changed

+120
-60
lines changed

script/plugin.lua

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ local scope = require 'workspace.scope'
77
local ws = require 'workspace'
88
local fs = require 'bee.filesystem'
99

10+
---@class pluginInterfaces
11+
local pluginConfigs = {
12+
-- create plugin for vm module
13+
VM = {
14+
OnCompileFunctionParam = function (next, func, source)
15+
end
16+
}
17+
}
18+
1019
---@class plugin
1120
local m = {}
1221

@@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...)
5160
return failed == 0, res1, res2
5261
end
5362

63+
function m.getVmPlugin(uri)
64+
local scp = scope.getScope(uri)
65+
local interfaces = scp:get('pluginInterfaces')
66+
if not interfaces then
67+
return
68+
end
69+
return interfaces.VM
70+
end
71+
5472
---@async
5573
---@param scp scope
5674
local function checkTrustLoad(scp)
@@ -78,6 +96,40 @@ local function checkTrustLoad(scp)
7896
return true
7997
end
8098

99+
local function createMethodGroup(interfaces, key, methods)
100+
local methodGroup = {}
101+
102+
for method in pairs(methods) do
103+
local funcs = setmetatable({}, {
104+
__call = function (t, next, ...)
105+
if #t == 0 then
106+
return next(...)
107+
else
108+
local result
109+
for _, fn in ipairs(t) do
110+
result = fn(next, ...)
111+
end
112+
return result
113+
end
114+
end
115+
})
116+
for _, interface in ipairs(interfaces) do
117+
local func = interface[method]
118+
if not func then
119+
local namespace = interface[key]
120+
if namespace then
121+
func = namespace[method]
122+
end
123+
end
124+
if func then
125+
funcs[#funcs+1] = func
126+
end
127+
end
128+
methodGroup[method] = funcs
129+
end
130+
return #methodGroup>0 and methodGroup or nil
131+
end
132+
81133
---@param uri uri
82134
local function initPlugin(uri)
83135
await.call(function () ---@async
@@ -148,6 +200,11 @@ local function initPlugin(uri)
148200
end
149201
interfaces[#interfaces+1] = interface
150202
end
203+
204+
for key, config in pairs(pluginConfigs) do
205+
interfaces[key] = createMethodGroup(interfaces, key, config)
206+
end
207+
151208
ws.resetFiles(scp)
152209
end)
153210
end

script/vm/compiler.lua

Lines changed: 53 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,56 @@ local function compileForVars(source, target)
10311031
return false
10321032
end
10331033

1034+
---@param source parser.object
1035+
local function compileFunctionParam(func, source)
1036+
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
1037+
local funcNode = vm.compileNode(func)
1038+
local hasDocArg
1039+
for n in funcNode:eachObject() do
1040+
if n.type == 'doc.type.function' then
1041+
for index, arg in ipairs(n.args) do
1042+
if func.args[index] == source then
1043+
local argNode = vm.compileNode(arg)
1044+
for an in argNode:eachObject() do
1045+
if an.type ~= 'doc.generic.name' then
1046+
vm.setNode(source, an)
1047+
end
1048+
end
1049+
return true
1050+
end
1051+
end
1052+
end
1053+
end
1054+
if func.parent.type == 'local' then
1055+
local refs = func.parent.ref
1056+
local findCall
1057+
if refs then
1058+
for i, ref in ipairs(refs) do
1059+
if ref.parent.type == 'call' then
1060+
findCall = ref.parent
1061+
break
1062+
end
1063+
end
1064+
end
1065+
if findCall and findCall.args then
1066+
local index
1067+
for i, arg in ipairs(source.parent) do
1068+
if arg == source then
1069+
index = i
1070+
break
1071+
end
1072+
end
1073+
if index then
1074+
local callerArg = findCall.args[index]
1075+
if callerArg then
1076+
vm.setNode(source, vm.compileNode(callerArg))
1077+
return true
1078+
end
1079+
end
1080+
end
1081+
end
1082+
end
1083+
10341084
---@param source parser.object
10351085
local function compileLocal(source)
10361086
local myNode = vm.setNode(source, source)
@@ -1070,63 +1120,11 @@ local function compileLocal(source)
10701120
vm.setNode(source, vm.compileNode(setfield.node))
10711121
end
10721122
end
1073-
10741123
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
10751124
local func = source.parent.parent
1076-
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
1077-
local funcNode = vm.compileNode(func)
1078-
local hasDocArg
1079-
for n in funcNode:eachObject() do
1080-
if n.type == 'doc.type.function' then
1081-
for index, arg in ipairs(n.args) do
1082-
if func.args[index] == source then
1083-
local argNode = vm.compileNode(arg)
1084-
for an in argNode:eachObject() do
1085-
if an.type ~= 'doc.generic.name' then
1086-
vm.setNode(source, an)
1087-
end
1088-
end
1089-
hasDocArg = true
1090-
end
1091-
end
1092-
end
1093-
end
1094-
if not hasDocArg
1095-
and func.parent.type == 'local' then
1096-
local refs = func.parent.ref
1097-
local findCall
1098-
if refs then
1099-
for i, ref in ipairs(refs) do
1100-
if ref.parent.type == 'call' then
1101-
findCall = ref.parent
1102-
break
1103-
end
1104-
end
1105-
end
1106-
if findCall and findCall.args then
1107-
local index
1108-
for i, arg in ipairs(source.parent) do
1109-
if arg == source then
1110-
index = i
1111-
break
1112-
end
1113-
end
1114-
if index then
1115-
local callerArg = findCall.args[index]
1116-
if callerArg then
1117-
hasDocArg = true
1118-
vm.setNode(source, vm.compileNode(callerArg))
1119-
end
1120-
end
1121-
end
1122-
end
1123-
if not hasDocArg then
1124-
local suc, node = plugin.dispatch("OnNodeCompileFunctionParam", guide.getUri(source), source)
1125-
if suc and node then
1126-
hasDocArg = true
1127-
vm.setNode(source, node)
1128-
end
1129-
end
1125+
local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
1126+
local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
1127+
or compileFunctionParam(func, source)
11301128
if not hasDocArg then
11311129
vm.setNode(source, vm.declareGlobal('type', 'any'))
11321130
end

script/workspace/scope.lua

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ function mt:set(k, v)
125125
return v
126126
end
127127

128-
---@param k string
129-
---@return any
128+
---@generic T
129+
---@param k `T`
130+
---@return T
130131
function mt:get(k)
131132
return self._data[k]
132133
end

test/plugins/node/test.lua

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,19 @@ local pattern, msg = nodeHelper.createFieldPattern("*.components")
99
assert(pattern, msg)
1010

1111
---@param source parser.object
12-
function OnNodeCompileFunctionParam(uri, source)
12+
function OnCompileFunctionParam(next, func, source)
13+
if next(func, source) then
14+
return true
15+
end
1316
--从该参数的使用模式来推导该类型
1417
if nodeHelper.matchPattern(source, pattern) then
1518
local type = vm.declareGlobal('type', 'TestClass', TESTURI)
16-
return vm.createNode(type, source)
19+
vm.setNode(source, vm.createNode(type, source))
20+
return true
1721
end
1822
end
1923

20-
local myplugin = { OnNodeCompileFunctionParam = OnNodeCompileFunctionParam }
24+
local myplugin = { OnCompileFunctionParam = OnCompileFunctionParam }
2125

2226
---@diagnostic disable: await-in-sync
2327
local function TestPlugin(script)

0 commit comments

Comments
 (0)