|
| 1 | +local lang = require 'language' |
| 2 | +local parser = require 'parser' |
| 3 | +local guide = require 'parser.guide' |
| 4 | + |
| 5 | +local function nodeId(node) |
| 6 | + return node.type .. ':' .. node.start .. ':' .. node.finish |
| 7 | +end |
| 8 | + |
| 9 | +local function shorten(str) |
| 10 | + if type(str) ~= 'string' then |
| 11 | + return str |
| 12 | + end |
| 13 | + str = str:gsub('\n', '\\\\n') |
| 14 | + if #str <= 20 then |
| 15 | + return str |
| 16 | + else |
| 17 | + return str:sub(1, 17) .. '...' |
| 18 | + end |
| 19 | +end |
| 20 | + |
| 21 | +local function getTooltipLine(k, v) |
| 22 | + if type(v) == 'table' then |
| 23 | + if v.type then |
| 24 | + v = '<node ' .. v.type .. '>' |
| 25 | + else |
| 26 | + v = '<table>' |
| 27 | + end |
| 28 | + end |
| 29 | + v = tostring(v) |
| 30 | + v = v:gsub('"', '\\"') |
| 31 | + return k .. ': ' .. shorten(v) .. '\\n' |
| 32 | +end |
| 33 | + |
| 34 | +local function getTooltip(node) |
| 35 | + local str = '' |
| 36 | + local skipNodes = {parent = true, start = true, finish = true, type = true} |
| 37 | + str = str .. getTooltipLine('start', node.start) |
| 38 | + str = str .. getTooltipLine('finish', node.finish) |
| 39 | + for k, v in pairs(node) do |
| 40 | + if type(k) ~= 'number' and not skipNodes[k] then |
| 41 | + str = str .. getTooltipLine(k, v) |
| 42 | + end |
| 43 | + end |
| 44 | + for i = 1, math.min(#node, 15) do |
| 45 | + str = str .. getTooltipLine(i, node[i]) |
| 46 | + end |
| 47 | + if #node > 15 then |
| 48 | + str = str .. getTooltipLine('15..' .. #node, '(...)') |
| 49 | + end |
| 50 | + return str |
| 51 | +end |
| 52 | + |
| 53 | +local nodeEntry = '\t"%s" [\n\t\tlabel="%s\\l%s\\l"\n\t\ttooltip="%s"\n\t]' |
| 54 | +local function getNodeLabel(node) |
| 55 | + local keyName = guide.getKeyName(node) |
| 56 | + if node.type == 'binary' or node.type == 'unary' then |
| 57 | + keyName = node.op.type |
| 58 | + elseif node.type == 'label' or node.type == 'goto' then |
| 59 | + keyName = node[1] |
| 60 | + end |
| 61 | + return nodeEntry:format(nodeId(node), node.type, shorten(keyName) or '', getTooltip(node)) |
| 62 | +end |
| 63 | + |
| 64 | +local function getVisualizeVisitor(writer) |
| 65 | + local function visitNode(node, parent) |
| 66 | + if node == nil then return end |
| 67 | + writer:write(getNodeLabel(node)) |
| 68 | + writer:write('\n') |
| 69 | + if parent then |
| 70 | + writer:write(('\t"%s" -> "%s"'):format(nodeId(parent), nodeId(node))) |
| 71 | + writer:write('\n') |
| 72 | + end |
| 73 | + guide.eachChild(node, function(child) |
| 74 | + visitNode(child, node) |
| 75 | + end) |
| 76 | + end |
| 77 | + return visitNode |
| 78 | +end |
| 79 | + |
| 80 | + |
| 81 | +local export = {} |
| 82 | + |
| 83 | +function export.visualizeAst(code, writer) |
| 84 | + local state = parser.compile(code, 'Lua', _G['LUA_VER'] or 'Lua 5.4') |
| 85 | + writer:write('digraph AST {\n') |
| 86 | + writer:write('\tnode [shape = rect]\n') |
| 87 | + getVisualizeVisitor(writer)(state.ast) |
| 88 | + writer:write('}\n') |
| 89 | +end |
| 90 | + |
| 91 | +function export.runCLI() |
| 92 | + lang(LOCALE) |
| 93 | + local file = _G['VISUALIZE'] |
| 94 | + local code, err = io.open(file) |
| 95 | + if not code then |
| 96 | + io.stderr:write('failed to open ' .. file .. ': ' .. err) |
| 97 | + return 1 |
| 98 | + end |
| 99 | + code = code:read('a') |
| 100 | + return export.visualizeAst(code, io.stdout) |
| 101 | +end |
| 102 | + |
| 103 | +return export |
0 commit comments