Skip to content

Commit fbc4330

Browse files
fix
1 parent ec3e611 commit fbc4330

File tree

4 files changed

+96
-28
lines changed

4 files changed

+96
-28
lines changed

pyrefly/lib/commands/infer.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use crate::commands::files::get_project_config_for_current_dir;
2525
use crate::commands::util::CommandExitStatus;
2626
use crate::config::error_kind::ErrorKind;
2727
use crate::state::ide::insert_import_edit_with_forced_import_format;
28+
use crate::state::ide::ImportEdit;
2829
use crate::state::lsp::AnnotationKind;
2930
use crate::state::lsp::ParameterAnnotation;
3031
use crate::state::require::Require;
@@ -289,7 +290,7 @@ impl InferArgs {
289290
if let Some(ast) = transaction.get_ast(&handle) {
290291
let error_range = error.range();
291292
let unknown_name = module_info.code_at(error_range);
292-
let imports: Vec<(TextSize, String)> = transaction
293+
let imports: Vec<ImportEdit> = transaction
293294
.search_exports_exact(unknown_name)
294295
.into_iter()
295296
.map(|handle_to_import_from| {
@@ -336,14 +337,17 @@ impl InferArgs {
336337

337338
fn add_imports_to_file(
338339
file_path: &Path,
339-
imports: Vec<(TextSize, String)>,
340+
imports: Vec<ImportEdit>,
340341
) -> anyhow::Result<()> {
341342
let file_content = fs_anyhow::read_to_string(file_path)?;
342343
let mut result = file_content;
343-
for (position, import) in imports {
344-
let offset = (position).into();
345-
if !result.contains(&import) {
346-
result.insert_str(offset, &import);
344+
for import_edit in imports {
345+
if import_edit.insert_text.is_empty() {
346+
continue;
347+
}
348+
let offset = (import_edit.position).into();
349+
if offset <= result.len() && !result.contains(&import_edit.insert_text) {
350+
result.insert_str(offset, &import_edit.insert_text);
347351
}
348352
}
349353
fs_anyhow::write(file_path, result)

pyrefly/lib/state/ide.rs

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use ruff_python_ast::Expr;
1616
use ruff_python_ast::ModModule;
1717
use ruff_python_ast::helpers::is_docstring_stmt;
1818
use ruff_python_ast::name::Name;
19+
use ruff_python_ast::Stmt;
20+
use ruff_python_ast::StmtImportFrom;
1921
use ruff_text_size::Ranged;
2022
use ruff_text_size::TextRange;
2123
use ruff_text_size::TextSize;
@@ -32,6 +34,13 @@ use crate::state::lsp::ImportFormat;
3234

3335
const KEY_TO_DEFINITION_INITIAL_GAS: Gas = Gas::new(100);
3436

37+
#[derive(Clone, Debug, PartialEq, Eq)]
38+
pub struct ImportEdit {
39+
pub position: TextSize,
40+
pub insert_text: String,
41+
pub display_text: String,
42+
}
43+
3544
pub enum IntermediateDefinition {
3645
Local(Export),
3746
NamedImport(TextRange, ModuleName, Name, Option<TextRange>),
@@ -189,7 +198,7 @@ pub fn insert_import_edit(
189198
handle_to_import_from: Handle,
190199
export_name: &str,
191200
import_format: ImportFormat,
192-
) -> (TextSize, String) {
201+
) -> ImportEdit {
193202
let use_absolute_import = match import_format {
194203
ImportFormat::Absolute => true,
195204
ImportFormat::Relative => {
@@ -226,12 +235,7 @@ pub fn insert_import_edit_with_forced_import_format(
226235
handle_to_import_from: Handle,
227236
export_name: &str,
228237
use_absolute_import: bool,
229-
) -> (TextSize, String) {
230-
let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) {
231-
first_stmt.range().start()
232-
} else {
233-
ast.range.end()
234-
};
238+
) -> ImportEdit {
235239
let module_name_to_import = if use_absolute_import {
236240
handle_to_import_from.module()
237241
} else if let Some(relative_module) = ModuleName::relative_module_name_between(
@@ -242,12 +246,65 @@ pub fn insert_import_edit_with_forced_import_format(
242246
} else {
243247
handle_to_import_from.module()
244248
};
249+
let display_text = format!(
250+
"from {} import {}",
251+
module_name_to_import.as_str(),
252+
export_name
253+
);
254+
if let Some((position, insert_text)) =
255+
try_extend_existing_from_import(ast, module_name_to_import.as_str(), export_name)
256+
{
257+
return ImportEdit {
258+
position,
259+
insert_text,
260+
display_text,
261+
};
262+
}
263+
let position = if let Some(first_stmt) = ast.body.iter().find(|stmt| !is_docstring_stmt(stmt)) {
264+
first_stmt.range().start()
265+
} else {
266+
ast.range.end()
267+
};
245268
let insert_text = format!(
246269
"from {} import {}\n",
247270
module_name_to_import.as_str(),
248271
export_name
249272
);
250-
(position, insert_text)
273+
ImportEdit {
274+
position,
275+
insert_text,
276+
display_text,
277+
}
278+
}
279+
280+
fn try_extend_existing_from_import(
281+
ast: &ModModule,
282+
target_module_name: &str,
283+
export_name: &str,
284+
) -> Option<(TextSize, String)> {
285+
for stmt in &ast.body {
286+
if let Stmt::ImportFrom(import_from) = stmt {
287+
if import_from_module_name(import_from) == target_module_name {
288+
if let Some(last_alias) = import_from.names.last() {
289+
let position = last_alias.range.end();
290+
let insert_text = format!(", {}", export_name);
291+
return Some((position, insert_text));
292+
}
293+
}
294+
}
295+
}
296+
None
297+
}
298+
299+
fn import_from_module_name(import_from: &StmtImportFrom) -> String {
300+
let mut module_name = String::new();
301+
if import_from.level > 0 {
302+
module_name.push_str(&".".repeat(import_from.level as usize));
303+
}
304+
if let Some(module) = &import_from.module {
305+
module_name.push_str(module.as_str());
306+
}
307+
module_name
251308
}
252309

253310
/// Some handles must be imported in absolute style,

pyrefly/lib/state/lsp.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,17 +1859,23 @@ impl<'a> Transaction<'a> {
18591859
if error_range.contains_range(range) {
18601860
let unknown_name = module_info.code_at(error_range);
18611861
for handle_to_import_from in self.search_exports_exact(unknown_name) {
1862-
let (position, insert_text) = insert_import_edit(
1862+
let import_edit = insert_import_edit(
18631863
&ast,
18641864
self.config_finder(),
18651865
handle.dupe(),
18661866
handle_to_import_from,
18671867
unknown_name,
18681868
import_format,
18691869
);
1870-
let range = TextRange::at(position, TextSize::new(0));
1871-
let title = format!("Insert import: `{}`", insert_text.trim());
1872-
code_actions.push((title, module_info.dupe(), range, insert_text));
1870+
let range = TextRange::at(import_edit.position, TextSize::new(0));
1871+
let title =
1872+
format!("Insert import: `{}`", import_edit.display_text);
1873+
code_actions.push((
1874+
title,
1875+
module_info.dupe(),
1876+
range,
1877+
import_edit.insert_text,
1878+
));
18731879
}
18741880

18751881
for module_name in self.search_modules_fuzzy(unknown_name) {
@@ -2391,8 +2397,8 @@ impl<'a> Transaction<'a> {
23912397
{
23922398
continue;
23932399
}
2394-
let (insert_text, additional_text_edits) = {
2395-
let (position, insert_text) = insert_import_edit(
2400+
let (detail_text, additional_text_edits) = {
2401+
let import_edit = insert_import_edit(
23962402
&ast,
23972403
self.config_finder(),
23982404
handle.dupe(),
@@ -2401,14 +2407,17 @@ impl<'a> Transaction<'a> {
24012407
import_format,
24022408
);
24032409
let import_text_edit = TextEdit {
2404-
range: module_info.to_lsp_range(TextRange::at(position, TextSize::new(0))),
2405-
new_text: insert_text.clone(),
2410+
range: module_info.to_lsp_range(TextRange::at(
2411+
import_edit.position,
2412+
TextSize::new(0),
2413+
)),
2414+
new_text: import_edit.insert_text.clone(),
24062415
};
2407-
(Some(insert_text), Some(vec![import_text_edit]))
2416+
(Some(import_edit.display_text), Some(vec![import_text_edit]))
24082417
};
24092418
completions.push(CompletionItem {
24102419
label: name,
2411-
detail: insert_text,
2420+
detail: detail_text,
24122421
kind: export
24132422
.symbol_kind
24142423
.map_or(Some(CompletionItemKind::VARIABLE), |k| {

pyrefly/lib/test/lsp/code_actions.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,7 @@ fn insertion_test_duplicate_imports() {
210210
],
211211
get_test_report,
212212
);
213-
// The insertion won't attempt to merge imports from the same module.
214-
// It's not illegal, but it would be nice if we do merge.
213+
// When another import from the same module already exists, we should append to it.
215214
assert_eq!(
216215
r#"
217216
# a.py
@@ -227,8 +226,7 @@ from a import another_thing
227226
my_export
228227
# ^
229228
## After:
230-
from a import my_export
231-
from a import another_thing
229+
from a import another_thing, my_export
232230
my_export
233231
# ^
234232
"#

0 commit comments

Comments
 (0)