Skip to content

Commit e33c545

Browse files
giacomocavalierilpil
authored andcommitted
allow further matching on varaibles inside a pattern
fixes #4807
1 parent 12f8d70 commit e33c545

6 files changed

+236
-27
lines changed

compiler-core/src/language_server/code_action.rs

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4453,6 +4453,7 @@ pub struct PatternMatchOnValue<'a, A> {
44534453
/// print and format the corresponding pattern matching code; that's why you'll
44544454
/// see `Range`s and `SrcSpan` besides the type of the thing being matched.
44554455
///
4456+
#[derive(Clone)]
44564457
pub enum PatternMatchedValue<'a> {
44574458
FunctionArgument {
44584459
/// The argument being pattern matched on.
@@ -4467,16 +4468,16 @@ pub enum PatternMatchedValue<'a> {
44674468
function_range: Range,
44684469
},
44694470
LetVariable {
4470-
variable_name: &'a EcoString,
4471-
variable_type: &'a Arc<Type>,
4471+
variable_name: EcoString,
4472+
variable_type: Arc<Type>,
44724473
/// The location of the entire let assignment the variable is part of,
44734474
/// so that we can add the pattern matching _after_ it.
44744475
///
44754476
assignment_location: SrcSpan,
44764477
},
44774478
UseVariable {
4478-
variable_name: &'a EcoString,
4479-
variable_type: &'a Arc<Type>,
4479+
variable_name: EcoString,
4480+
variable_type: Arc<Type>,
44804481
/// The location of the entire use expression the variable is part of,
44814482
/// so that we can add the pattern matching _after_ it.
44824483
///
@@ -4506,7 +4507,7 @@ where
45064507
pub fn code_actions(mut self) -> Vec<CodeAction> {
45074508
self.visit_typed_module(&self.module.ast);
45084509

4509-
let action_title = match self.selected_value {
4510+
let action_title = match self.selected_value.clone() {
45104511
Some(PatternMatchedValue::FunctionArgument {
45114512
arg,
45124513
first_statement: function_body,
@@ -4614,8 +4615,8 @@ where
46144615

46154616
fn match_on_let_variable(
46164617
&mut self,
4617-
variable_name: &EcoString,
4618-
variable_type: &Arc<Type>,
4618+
variable_name: EcoString,
4619+
variable_type: Arc<Type>,
46194620
assignment_location: SrcSpan,
46204621
) {
46214622
let Some(patterns) = self.type_to_destructure_patterns(variable_type.as_ref()) else {
@@ -4780,6 +4781,94 @@ where
47804781
pattern.push(')');
47814782
Some(pattern)
47824783
}
4784+
4785+
fn pattern_variable_under_cursor(
4786+
&self,
4787+
pattern: &'a TypedPattern,
4788+
) -> Option<(&'a EcoString, Arc<Type>)> {
4789+
match pattern {
4790+
Pattern::Int { .. }
4791+
| Pattern::Float { .. }
4792+
| Pattern::String { .. }
4793+
| Pattern::BitArraySize(_)
4794+
| Pattern::Discard { .. }
4795+
| Pattern::Invalid { .. } => None,
4796+
4797+
Pattern::Variable {
4798+
location,
4799+
name,
4800+
type_,
4801+
..
4802+
} => {
4803+
if within(
4804+
self.params.range,
4805+
self.edits.src_span_to_lsp_range(*location),
4806+
) {
4807+
Some((name, type_.clone()))
4808+
} else {
4809+
None
4810+
}
4811+
}
4812+
4813+
Pattern::Assign { pattern, .. } => self.pattern_variable_under_cursor(pattern),
4814+
4815+
Pattern::Constructor { arguments, .. } => arguments
4816+
.iter()
4817+
.find_map(|argument| self.pattern_variable_under_cursor(&argument.value)),
4818+
4819+
Pattern::Tuple { elements, .. } => elements
4820+
.iter()
4821+
.find_map(|element| self.pattern_variable_under_cursor(element)),
4822+
4823+
Pattern::List { elements, tail, .. } => elements
4824+
.iter()
4825+
.find_map(|element| self.pattern_variable_under_cursor(element))
4826+
.or_else(|| {
4827+
tail.as_ref()
4828+
.and_then(|tail| self.pattern_variable_under_cursor(tail))
4829+
}),
4830+
4831+
Pattern::BitArray { segments, .. } => segments
4832+
.iter()
4833+
.flat_map(|segment| &segment.options)
4834+
.find_map(|option| {
4835+
option
4836+
.value()
4837+
.and_then(|pattern| self.pattern_variable_under_cursor(pattern))
4838+
}),
4839+
4840+
Pattern::StringPrefix {
4841+
left_side_assignment: Some((name, location)),
4842+
..
4843+
} => {
4844+
if within(
4845+
self.params.range,
4846+
self.edits.src_span_to_lsp_range(*location),
4847+
) {
4848+
Some((name, type_::string()))
4849+
} else {
4850+
None
4851+
}
4852+
}
4853+
4854+
Pattern::StringPrefix {
4855+
right_side_assignment: AssignName::Variable(name),
4856+
right_location,
4857+
..
4858+
} => {
4859+
if within(
4860+
self.params.range,
4861+
self.edits.src_span_to_lsp_range(*right_location),
4862+
) {
4863+
Some((name, type_::string()))
4864+
} else {
4865+
None
4866+
}
4867+
}
4868+
4869+
Pattern::StringPrefix { .. } => None,
4870+
}
4871+
}
47834872
}
47844873

47854874
impl<'ast, IO> ast::visit::Visit<'ast> for PatternMatchOnValue<'ast, IO>
@@ -4863,24 +4952,15 @@ where
48634952
}
48644953

48654954
fn visit_typed_assignment(&mut self, assignment: &'ast TypedAssignment) {
4866-
if let Pattern::Variable {
4867-
name,
4868-
location,
4869-
type_,
4870-
..
4871-
} = &assignment.pattern
4872-
{
4873-
let variable_range = self.edits.src_span_to_lsp_range(*location);
4874-
if within(self.params.range, variable_range) {
4875-
self.selected_value = Some(PatternMatchedValue::LetVariable {
4876-
variable_name: name,
4877-
variable_type: type_,
4878-
assignment_location: assignment.location,
4879-
});
4880-
// If we've found the variable to pattern match on, there's no
4881-
// point in keeping traversing the AST.
4882-
return;
4883-
}
4955+
if let Some((name, type_)) = self.pattern_variable_under_cursor(&assignment.pattern) {
4956+
self.selected_value = Some(PatternMatchedValue::LetVariable {
4957+
variable_name: name.clone(),
4958+
variable_type: type_,
4959+
assignment_location: assignment.location,
4960+
});
4961+
// If we've found the variable to pattern match on, there's no
4962+
// point in keeping traversing the AST.
4963+
return;
48844964
}
48854965

48864966
ast::visit::visit_typed_assignment(self, assignment);
@@ -4909,8 +4989,8 @@ where
49094989
let variable_range = self.edits.src_span_to_lsp_range(*variable_location);
49104990
if within(self.params.range, variable_range) {
49114991
self.selected_value = Some(PatternMatchedValue::UseVariable {
4912-
variable_name: name,
4913-
variable_type: type_,
4992+
variable_name: name.clone(),
4993+
variable_type: type_.clone(),
49144994
use_location: use_.location,
49154995
});
49164996
// If we've found the variable to pattern match on, there's no

compiler-core/src/language_server/tests/action.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9163,6 +9163,57 @@ fn remove_opaque_from_private_type() {
91639163
);
91649164
}
91659165

9166+
#[test]
9167+
fn allow_further_pattern_matching_on_let_tuple_destructuring() {
9168+
assert_code_action!(
9169+
PATTERN_MATCH_ON_VARIABLE,
9170+
"pub fn main(x) {
9171+
let #(one, other) = #(Ok(1), Error(Nil))
9172+
}
9173+
",
9174+
find_position_of("one").to_selection()
9175+
);
9176+
}
9177+
9178+
#[test]
9179+
fn allow_further_pattern_matching_on_let_record_destructuring() {
9180+
assert_code_action!(
9181+
PATTERN_MATCH_ON_VARIABLE,
9182+
"pub fn main(x) {
9183+
let Wibble(field:) = Wibble(Ok(Nil))
9184+
}
9185+
9186+
pub type Wibble { Wibble(field: Result(Nil, String)) }
9187+
",
9188+
find_position_of("one").to_selection()
9189+
);
9190+
}
9191+
9192+
#[test]
9193+
fn allow_further_pattern_matching_on_asserted_result() {
9194+
assert_code_action!(
9195+
PATTERN_MATCH_ON_VARIABLE,
9196+
"pub fn main(x) {
9197+
let assert Ok(one) = Ok(Error(Nil))
9198+
}
9199+
",
9200+
find_position_of("one").to_selection()
9201+
);
9202+
}
9203+
9204+
#[test]
9205+
fn allow_further_pattern_matching_on_asserted_list() {
9206+
assert_code_action!(
9207+
PATTERN_MATCH_ON_VARIABLE,
9208+
"pub fn main(x) {
9209+
let assert [first, ..] = [Ok(Nil), ..todo]
9210+
todo
9211+
}
9212+
",
9213+
find_position_of("first").to_selection()
9214+
);
9215+
}
9216+
91669217
#[test]
91679218
fn pattern_match_on_list_variable() {
91689219
assert_code_action!(
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
---
2+
source: compiler-core/src/language_server/tests/action.rs
3+
expression: "pub fn main(x) {\n let assert Ok(one) = Ok(Error(Nil))\n}\n"
4+
---
5+
----- BEFORE ACTION
6+
pub fn main(x) {
7+
let assert Ok(one) = Ok(Error(Nil))
8+
9+
}
10+
11+
12+
----- AFTER ACTION
13+
pub fn main(x) {
14+
let assert Ok(one) = Ok(Error(Nil))
15+
case one {
16+
Ok(value) -> todo
17+
Error(value) -> todo
18+
}
19+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
---
2+
source: compiler-core/src/language_server/tests/action.rs
3+
expression: "pub fn main(x) {\n let assert [first, ..] = [Ok(Nil), ..todo]\n todo\n}\n"
4+
---
5+
----- BEFORE ACTION
6+
pub fn main(x) {
7+
let assert [first, ..] = [Ok(Nil), ..todo]
8+
9+
todo
10+
}
11+
12+
13+
----- AFTER ACTION
14+
pub fn main(x) {
15+
let assert [first, ..] = [Ok(Nil), ..todo]
16+
case first {
17+
Ok(value) -> todo
18+
Error(value) -> todo
19+
}
20+
todo
21+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
---
2+
source: compiler-core/src/language_server/tests/action.rs
3+
expression: "pub fn main(x) {\n let assert Ok(one) = Ok(Error(Nil))\n}\n"
4+
---
5+
----- BEFORE ACTION
6+
pub fn main(x) {
7+
let assert Ok(one) = Ok(Error(Nil))
8+
9+
}
10+
11+
12+
----- AFTER ACTION
13+
pub fn main(x) {
14+
let assert Ok(one) = Ok(Error(Nil))
15+
case one {
16+
Ok(value) -> todo
17+
Error(value) -> todo
18+
}
19+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
---
2+
source: compiler-core/src/language_server/tests/action.rs
3+
expression: "pub fn main(x) {\n let #(one, other) = #(Ok(1), Error(Nil))\n}\n"
4+
---
5+
----- BEFORE ACTION
6+
pub fn main(x) {
7+
let #(one, other) = #(Ok(1), Error(Nil))
8+
9+
}
10+
11+
12+
----- AFTER ACTION
13+
pub fn main(x) {
14+
let #(one, other) = #(Ok(1), Error(Nil))
15+
case one {
16+
Ok(value) -> todo
17+
Error(value) -> todo
18+
}
19+
}

0 commit comments

Comments
 (0)