Skip to content

Commit 48af045

Browse files
Add modules support into resolve_imports (#47)
Co-authored-by: Alexandre Pasmantier <alex.pasmant@gmail.com>
1 parent cdffffe commit 48af045

File tree

8 files changed

+154
-20
lines changed

8 files changed

+154
-20
lines changed

lua/pymple/api.lua

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,17 @@ M.resolve_import_under_cursor = function()
5858
)
5959
return
6060
elseif #candidates == 1 then
61-
local final_import = candidates[1]
61+
local final_import = utils.to_import_statement(candidates[1])
6262
utils.add_import_to_buffer(
63-
"from " .. final_import .. " import " .. symbol,
63+
final_import,
6464
0,
6565
config.user_config.add_import_to_buf.autosave
6666
)
67-
log.debug("Added import for " .. symbol .. ": " .. final_import)
6867
else
6968
local candidate_statements = {}
7069
for _, candidate in ipairs(candidates) do
71-
table.insert(
72-
candidate_statements,
73-
"from " .. candidate .. " import " .. symbol
74-
)
70+
local import_statement = utils.to_import_statement(candidate)
71+
table.insert(candidate_statements, import_statement)
7572
end
7673
local longest_candidate = utils.longest_string_in_list(candidate_statements)
7774
local telescope_opts = {}

lua/pymple/build.lua

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,32 @@ local function install_sed(os_name)
6464
end
6565
end
6666

67+
local function install_fd(os_name)
68+
local out, code = nil, nil
69+
if os_name == "Darwin" then
70+
log.debug("MacOS detected, installing fd...")
71+
local job = Job:new({
72+
command = "brew",
73+
args = { "install", "fd" },
74+
})
75+
out, code = job:sync(JOB_TIMEOUT)
76+
elseif os_name == "Linux" then
77+
log.debug("Linux detected, installing sed...")
78+
-- make this depend on the package manager
79+
local job = Job:new({
80+
command = "sudo",
81+
args = { "apt-get", "install", "fd-find" },
82+
})
83+
out, code = job:sync(JOB_TIMEOUT)
84+
end
85+
if is_outcode_success(code) then
86+
return true
87+
else
88+
log.warning(out)
89+
return false
90+
end
91+
end
92+
6793
function M.build()
6894
-- check if cargo is installed
6995
if not is_cargo_installed() then
@@ -98,6 +124,16 @@ function M.build()
98124
utils.print_err(message)
99125
return
100126
end
127+
128+
-- install fd depending on OS
129+
utils.print_info("Installing sed/gsed...")
130+
if install_fd(utils.OS_NAME) then
131+
utils.print_info("fd installed successfully")
132+
else
133+
local message = "Failed to install fd. Try installing it manually."
134+
utils.print_err(message)
135+
return
136+
end
101137
end
102138

103139
return M

lua/pymple/health.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ local required_binaries = {
2626
binaries = { "gg" },
2727
min_version = "0.2.23",
2828
},
29+
{
30+
name = "fd",
31+
url = "[sharkdp/fd](https://github.com/sharkdp/fd)",
32+
optional = false,
33+
binaries = { "fd" },
34+
},
2935
{
3036
name = "sed",
3137
url = "[https://www.gnu.org/software/sed](https://www.gnu.org/software/sed/manual/sed.html)",

lua/pymple/resolve_imports/init.lua

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,43 @@ function M.resolve_python_import(symbol, current_file_path)
8888
.. current_file_path
8989
.. " "
9090
.. build_symbol_regexes(symbol, IMPORTABLE_SYMBOLS_PATTERNS)
91+
--- fd -H -I -p "/sqlalchemy/__init__.py$|/sqlalchemy.py$" ~.venv/lib/python3.11/site-package
92+
local fd_args = '-H -I -p "/'
93+
.. symbol
94+
.. "/__init__.py$|/"
95+
.. symbol
96+
.. ".py$"
97+
.. '"'
98+
.. " "
9199
local candidate_paths = jobs.find_import_candidates(gg_args, target_paths)
100+
local modules_paths =
101+
jobs.find_import_modules_candidates(fd_args, target_paths)
102+
log.debug("Modules: " .. vim.inspect(modules_paths))
92103
log.debug("Candidates: " .. vim.inspect(candidate_paths))
93104
candidate_paths = filter_candidates(candidate_paths)
105+
modules_paths = filter_candidates(modules_paths)
106+
local result = {}
94107

95-
local import_candidates = {}
96108
for _, path in ipairs(candidate_paths) do
97109
local _path = utils.make_relative_to(path, root)
98-
local import_path = utils.to_import_path(_path)
99-
table.insert(import_candidates, import_path)
110+
local reference_path = utils.to_python_reference_path(_path, symbol)
111+
table.insert(result, reference_path)
112+
end
113+
114+
for _, path in ipairs(modules_paths) do
115+
local _path = utils.make_relative_to(path, root)
116+
local reference_path = utils.to_python_reference_path(_path)
117+
table.insert(result, reference_path)
100118
end
101-
-- sort imports by length
102-
table.sort(import_candidates, function(a, b)
119+
120+
result = utils.deduplicate_list(result)
121+
122+
table.sort(result, function(a, b)
103123
-- TODO: add other sorting rules (e.g. alphabetical, private modules, etc.)
104124
return #a < #b
105125
end)
106126

107-
return import_candidates
127+
return result
108128
end
109129

110130
return M

lua/pymple/resolve_imports/jobs.lua

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,26 @@ function M.find_import_candidates(gg_args, sys_paths)
2626
return candidates
2727
end
2828

29+
function M.find_import_modules_candidates(fd_args, sys_paths)
30+
local candidates = {}
31+
for _, sys_path in ipairs(sys_paths) do
32+
local fd_command = "fd " .. fd_args .. sys_path
33+
log.debug("FD command: " .. fd_command)
34+
local job = vim.system({ utils.SHELL, "-c", fd_command }):wait()
35+
local result_lines = vim.split(job.stdout, "\n")
36+
for i, line in ipairs(result_lines) do
37+
if line == "" then
38+
table.remove(result_lines, i)
39+
end
40+
end
41+
if #result_lines ~= 0 then
42+
for _, line in ipairs(result_lines) do
43+
local path = utils.make_relative_to(line, sys_path)
44+
table.insert(candidates, path)
45+
end
46+
end
47+
end
48+
return candidates
49+
end
50+
2951
return M

lua/pymple/update_imports/init.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ M.make_monolithic_imports_job = make_monolithic_imports_job
5555
function M.prepare_jobs(source, destination, filetypes, python_root)
5656
local s, d = unpack(
5757
utils.map(
58-
utils.to_import_path,
58+
utils.to_python_reference_path,
5959
utils.make_files_relative({ source, destination }, python_root),
6060
{}
6161
)

lua/pymple/utils.lua

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,37 @@ end
9797

9898
M.find_project_root = find_project_root
9999

100-
---Converts a path to an import path
100+
---Converts a path to a python reference path
101101
---@param module_path string: The path to a python module
102-
---@return string: The import path for the module
103-
function M.to_import_path(module_path)
104-
local result, _ = module_path:gsub("/", "."):gsub("%.py$", "")
102+
---@param python_symbol string | nil: The symbol to reference
103+
---@return string: The python reference path
104+
function M.to_python_reference_path(module_path, python_symbol)
105+
local result, _ =
106+
module_path:gsub("/", "."):gsub("%.py$", ""):gsub(".__init__", "")
107+
if python_symbol then
108+
result = result .. "." .. python_symbol
109+
end
105110
return result
106111
end
107112

113+
---Generate import string from reference path
114+
---@param reference_path string: The reference path
115+
---@return string: The import string
116+
function M.to_import_statement(reference_path)
117+
-- Split, but keep "__init__" as part of the path
118+
local parts = vim.split(reference_path, ".", { plain = true })
119+
-- If there's only one part, this is a simple import
120+
if #parts == 1 then
121+
return "import " .. parts[1]
122+
end
123+
local module_name = parts[#parts]
124+
table.remove(parts, #parts)
125+
126+
local module_path = table.concat(parts, ".")
127+
128+
return "from " .. module_path .. " import " .. module_name
129+
end
130+
108131
---Splits an import path on the last separator
109132
---@param import_path string: The import path to be split
110133
---@return string, string: The base path and the last part of the import path

lua/tests/utils_spec.lua

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,41 @@ describe("find_project_root", function()
2929
end)
3030
end)
3131

32-
describe("to_import_path", function()
32+
describe("to_python_reference_path", function()
3333
it("std", function()
34-
local result = utils.to_import_path("foo/bar/baz.py")
34+
local result = utils.to_python_reference_path("foo/bar/baz.py")
3535
assert.equals("foo.bar.baz", result)
3636
end)
37+
38+
it("__init__.py", function()
39+
local result = utils.to_python_reference_path("foo/bar/__init__.py")
40+
assert.equals("foo.bar", result)
41+
end)
42+
43+
it("python_symbol", function()
44+
local result = utils.to_python_reference_path("foo/bar/baz.py", "qux")
45+
assert.equals("foo.bar.baz.qux", result)
46+
end)
47+
48+
it("__init__ and python_symbol", function()
49+
local result = utils.to_python_reference_path("foo/bar/__init__.py", "qux")
50+
assert.equals("foo.bar.qux", result)
51+
end)
52+
end)
53+
54+
describe("to_import_statement", function()
55+
it("std", function()
56+
local result = utils.to_import_statement("foo.bar.baz")
57+
assert.equals("from foo.bar import baz", result)
58+
end)
59+
it("python_symbol", function()
60+
local result = utils.to_import_statement("foo.bar.baz.qux")
61+
assert.equals("from foo.bar.baz import qux", result)
62+
end)
63+
it("single_symbol", function()
64+
local result = utils.to_import_statement("foo")
65+
assert.equals("import foo", result)
66+
end)
3767
end)
3868

3969
describe("split_import_on_last_separator", function()

0 commit comments

Comments
 (0)