Skip to content

Commit f5cda97

Browse files
authored
feat: Add import path and tree refactor (#48)
Rename some confusing tree functions and add a tree method to get all import paths of a file.
1 parent 7709f3c commit f5cda97

18 files changed

+123
-86
lines changed

src/lsp.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,6 @@ impl LanguageServer for ProtoLanguageServer {
5151
},
5252
}];
5353

54-
let worktoken = params.work_done_progress_params.work_done_token;
55-
let (tx, rx) = mpsc::channel();
56-
let mut socket = self.client.clone();
57-
58-
thread::spawn(move || {
59-
let Some(token) = worktoken else {
60-
return;
61-
};
62-
63-
while let Ok(value) = rx.recv() {
64-
if let Err(e) = socket.progress(ProgressParams {
65-
token: token.clone(),
66-
value,
67-
}) {
68-
error!(error=%e, "failed to report parse progress");
69-
}
70-
}
71-
});
72-
7354
let file_registration_option = FileOperationRegistrationOptions {
7455
filters: file_operation_filers.clone(),
7556
};
@@ -80,7 +61,6 @@ impl LanguageServer for ProtoLanguageServer {
8061
for workspace in folders {
8162
info!("Workspace folder: {workspace:?}");
8263
self.configs.add_workspace(&workspace);
83-
self.state.add_workspace_folder_async(workspace, tx.clone());
8464
}
8565
workspace_capabilities = Some(WorkspaceServerCapabilities {
8666
workspace_folders: Some(WorkspaceFoldersServerCapabilities {

src/nodekind.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub enum NodeKind {
1111
ServiceName,
1212
RpcName,
1313
PackageName,
14+
PackageImport,
1415
}
1516

1617
#[allow(unused)]
@@ -26,6 +27,7 @@ impl NodeKind {
2627
NodeKind::ServiceName => "service_name",
2728
NodeKind::RpcName => "rpc_name",
2829
NodeKind::PackageName => "full_ident",
30+
NodeKind::PackageImport => "import",
2931
}
3032
}
3133

@@ -37,6 +39,10 @@ impl NodeKind {
3739
n.kind() == Self::Error.as_str()
3840
}
3941

42+
pub fn is_import_path(n: &Node) -> bool {
43+
n.kind() == Self::PackageImport.as_str()
44+
}
45+
4046
pub fn is_package_name(n: &Node) -> bool {
4147
n.kind() == Self::PackageName.as_str()
4248
}

src/parser/definition.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ impl ParsedTree {
2525
match identifier.split_once('.') {
2626
Some((parent_identifier, remaining)) => {
2727
let child_node = self
28-
.filter_nodes_from(n, NodeKind::is_userdefined)
28+
.find_all_nodes_from(n, NodeKind::is_userdefined)
2929
.into_iter()
3030
.find(|n| {
3131
n.utf8_text(content.as_ref()).expect("utf8-parse error")
@@ -39,7 +39,7 @@ impl ParsedTree {
3939
}
4040
None => {
4141
let locations: Vec<Location> = self
42-
.filter_nodes_from(n, NodeKind::is_userdefined)
42+
.find_all_nodes_from(n, NodeKind::is_userdefined)
4343
.into_iter()
4444
.filter(|n| {
4545
n.utf8_text(content.as_ref()).expect("utf-8 parse error") == identifier

src/parser/diagnostics.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use super::ParsedTree;
77
impl ParsedTree {
88
pub fn collect_parse_errors(&self) -> PublishDiagnosticsParams {
99
let diagnostics = self
10-
.filter_nodes(NodeKind::is_error)
10+
.find_all_nodes(NodeKind::is_error)
1111
.into_iter()
1212
.map(|n| Diagnostic {
1313
range: Range {

src/parser/hover.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ impl ParsedTree {
6666
match identifier.split_once('.') {
6767
Some((parent, child)) => {
6868
let child_node = self
69-
.filter_nodes_from(n, NodeKind::is_userdefined)
69+
.find_all_nodes_from(n, NodeKind::is_userdefined)
7070
.into_iter()
7171
.find(|n| n.utf8_text(content.as_ref()).expect("utf8-parse error") == parent)
7272
.and_then(|n| n.parent());
@@ -77,7 +77,7 @@ impl ParsedTree {
7777
}
7878
None => {
7979
let comments: Vec<String> = self
80-
.filter_nodes_from(n, NodeKind::is_userdefined)
80+
.find_all_nodes_from(n, NodeKind::is_userdefined)
8181
.into_iter()
8282
.filter(|n| {
8383
n.utf8_text(content.as_ref()).expect("utf-8 parse error") == identifier

src/parser/input/test_filter.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ syntax = "proto3";
22

33
package com.parser;
44

5+
import "foo/bar.proto";
6+
import "baz/bar.proto";
7+
58
message Book {
69

710
message Author {

src/parser/rename.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl ParsedTree {
2828
content: impl AsRef<[u8]>,
2929
) -> Option<Vec<Node<'a>>> {
3030
n.parent().map(|p| {
31-
self.filter_nodes_from(p, NodeKind::is_field_name)
31+
self.find_all_nodes_from(p, NodeKind::is_field_name)
3232
.into_iter()
3333
.filter(|i| i.utf8_text(content.as_ref()).expect("utf-8 parse error") == identifier)
3434
.collect()
@@ -114,7 +114,7 @@ impl ParsedTree {
114114
new_identifier: &str,
115115
content: impl AsRef<[u8]>,
116116
) -> Vec<TextEdit> {
117-
self.filter_nodes(NodeKind::is_field_name)
117+
self.find_all_nodes(NodeKind::is_field_name)
118118
.into_iter()
119119
.filter(|n| {
120120
let ntext = n.utf8_text(content.as_ref()).expect("utf-8 parse error");
@@ -135,7 +135,7 @@ impl ParsedTree {
135135
}
136136

137137
pub fn reference_field(&self, id: &str, content: impl AsRef<[u8]>) -> Vec<Location> {
138-
self.filter_nodes(NodeKind::is_field_name)
138+
self.find_all_nodes(NodeKind::is_field_name)
139139
.into_iter()
140140
.filter(|n| n.utf8_text(content.as_ref()).expect("utf-8 parse error") == id)
141141
.map(|n| Location {
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
source: src/parser/tree.rs
3+
expression: imports
4+
---
5+
- foo/bar.proto
6+
- baz/bar.proto

src/parser/tree.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{nodekind::NodeKind, utils::lsp_to_ts_point};
66
use super::ParsedTree;
77

88
impl ParsedTree {
9-
pub(super) fn walk_and_collect_filter<'a>(
9+
pub(super) fn walk_and_filter<'a>(
1010
cursor: &mut TreeCursor<'a>,
1111
f: fn(&Node) -> bool,
1212
early: bool,
@@ -24,7 +24,7 @@ impl ParsedTree {
2424
}
2525

2626
if cursor.goto_first_child() {
27-
v.extend(Self::walk_and_collect_filter(cursor, f, early));
27+
v.extend(Self::walk_and_filter(cursor, f, early));
2828
cursor.goto_parent();
2929
}
3030

@@ -110,29 +110,41 @@ impl ParsedTree {
110110
self.tree.root_node().descendant_for_point_range(pos, pos)
111111
}
112112

113-
pub fn filter_nodes(&self, f: fn(&Node) -> bool) -> Vec<Node> {
114-
self.filter_nodes_from(self.tree.root_node(), f)
113+
pub fn find_all_nodes(&self, f: fn(&Node) -> bool) -> Vec<Node> {
114+
self.find_all_nodes_from(self.tree.root_node(), f)
115115
}
116116

117-
pub fn filter_nodes_from<'a>(&self, n: Node<'a>, f: fn(&Node) -> bool) -> Vec<Node<'a>> {
117+
pub fn find_all_nodes_from<'a>(&self, n: Node<'a>, f: fn(&Node) -> bool) -> Vec<Node<'a>> {
118118
let mut cursor = n.walk();
119-
Self::walk_and_collect_filter(&mut cursor, f, false)
119+
Self::walk_and_filter(&mut cursor, f, false)
120120
}
121121

122-
pub fn find_node(&self, f: fn(&Node) -> bool) -> Vec<Node> {
122+
pub fn find_first_node(&self, f: fn(&Node) -> bool) -> Vec<Node> {
123123
self.find_node_from(self.tree.root_node(), f)
124124
}
125125

126126
pub fn find_node_from<'a>(&self, n: Node<'a>, f: fn(&Node) -> bool) -> Vec<Node<'a>> {
127127
let mut cursor = n.walk();
128-
Self::walk_and_collect_filter(&mut cursor, f, true)
128+
Self::walk_and_filter(&mut cursor, f, true)
129129
}
130130

131131
pub fn get_package_name<'a>(&self, content: &'a [u8]) -> Option<&'a str> {
132-
self.find_node(NodeKind::is_package_name)
132+
self.find_first_node(NodeKind::is_package_name)
133133
.first()
134134
.map(|n| n.utf8_text(content).expect("utf-8 parse error"))
135135
}
136+
pub fn get_import_path<'a>(&self, content: &'a [u8]) -> Vec<&'a str> {
137+
self.find_all_nodes(NodeKind::is_import_path)
138+
.into_iter()
139+
.filter_map(|n| {
140+
n.child_by_field_name("path").map(|c| {
141+
c.utf8_text(content)
142+
.expect("utf-8 parse error")
143+
.trim_matches('"')
144+
})
145+
})
146+
.collect()
147+
}
136148
}
137149

138150
#[cfg(test)]
@@ -150,7 +162,7 @@ mod test {
150162

151163
assert!(parsed.is_some());
152164
let tree = parsed.unwrap();
153-
let nodes = tree.filter_nodes(NodeKind::is_message_name);
165+
let nodes = tree.find_all_nodes(NodeKind::is_message_name);
154166

155167
assert_eq!(nodes.len(), 2);
156168

@@ -163,5 +175,7 @@ mod test {
163175

164176
let package_name = tree.get_package_name(contents.as_ref());
165177
assert_yaml_snapshot!(package_name);
178+
let imports = tree.get_import_path(contents.as_ref());
179+
assert_yaml_snapshot!(imports);
166180
}
167181
}

src/server.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
use async_lsp::{router::Router, ClientSocket};
2-
use std::ops::ControlFlow;
1+
use async_lsp::{
2+
lsp_types::{NumberOrString, ProgressParams, ProgressParamsValue},
3+
router::Router,
4+
ClientSocket, LanguageClient,
5+
};
6+
use std::{
7+
ops::ControlFlow,
8+
sync::{mpsc, mpsc::Sender},
9+
thread,
10+
};
311

412
use crate::{config::workspace::WorkspaceProtoConfigs, state::ProtoLanguageState};
513

@@ -27,4 +35,22 @@ impl ProtoLanguageServer {
2735
self.counter += 1;
2836
ControlFlow::Continue(())
2937
}
38+
39+
fn with_report_progress(&self, token: NumberOrString) -> Sender<ProgressParamsValue> {
40+
let (tx, rx) = mpsc::channel();
41+
let mut socket = self.client.clone();
42+
43+
thread::spawn(move || {
44+
while let Ok(value) = rx.recv() {
45+
if let Err(e) = socket.progress(ProgressParams {
46+
token: token.clone(),
47+
value,
48+
}) {
49+
tracing::error!(error=%e, "failed to report parse progress");
50+
}
51+
}
52+
});
53+
54+
tx
55+
}
3056
}

0 commit comments

Comments
 (0)