Skip to content

Commit b7514db

Browse files
committed
Use more accurate spans for errors.
Now it looks inside groups too, instead of treating groups as a whole.
1 parent 5541ce0 commit b7514db

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

macros/src/error.rs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use proc_macro::Span;
1+
use proc_macro::{TokenTree, Span, TokenStream as TokenStream1};
22
use proc_macro2::TokenStream;
33
use pyo3::{prelude::*, types::PyTraceback, Bound, IntoPyObject, PyErr, PyResult, PyTypeInfo, Python};
44
use quote::{quote, quote_spanned};
@@ -16,7 +16,7 @@ pub fn compile_error_msg(py: Python, error: PyErr, tokens: TokenStream) -> Token
1616
let line: Option<usize> = value.getattr("lineno").ok().and_then(|x| x.extract().ok());
1717
let msg: Option<String> = value.getattr("msg").ok().and_then(|x| x.extract().ok());
1818
if let (Some(line), Some(msg)) = (line, msg) {
19-
if let Some(spans) = spans_for_line(tokens.clone(), line) {
19+
if let Some(spans) = spans_for_line(tokens.clone().into(), line) {
2020
return compile_error(spans, format!("python: {msg}"));
2121
}
2222
}
@@ -26,7 +26,7 @@ pub fn compile_error_msg(py: Python, error: PyErr, tokens: TokenStream) -> Token
2626
if let Ok((file, line)) = get_traceback_info(tb) {
2727
if file == Span::call_site().file() {
2828
if let Ok(msg) = value.str() {
29-
if let Some(spans) = spans_for_line(tokens, line) {
29+
if let Some(spans) = spans_for_line(tokens.into(), line) {
3030
return compile_error(spans, format!("python: {msg}"));
3131
}
3232
}
@@ -46,18 +46,28 @@ fn get_traceback_info(tb: &Bound<'_, PyTraceback>) -> PyResult<(String, usize)>
4646
Ok((file, line))
4747
}
4848

49-
/// Get the first and last span for a specific line of input from a TokenStream.
50-
fn spans_for_line(input: TokenStream, line: usize) -> Option<(Span, Span)> {
51-
let mut spans = input
52-
.into_iter()
53-
.map(|x| x.span().unwrap())
54-
.skip_while(|span| span.start().line() < line)
55-
.take_while(|span| span.start().line() == line);
56-
57-
let first = spans.next()?;
58-
let last = spans.last().unwrap_or(first);
49+
fn for_all_spans(input: TokenStream1, f: &mut impl FnMut(Span)) {
50+
for token in input {
51+
match token {
52+
TokenTree::Group(group) => {
53+
f(group.span_open());
54+
for_all_spans(group.stream(), f);
55+
f(group.span_close());
56+
}
57+
_ => f(token.span()),
58+
}
59+
}
60+
}
5961

60-
Some((first, last))
62+
/// Get the first and last span for a specific line of input from a TokenStream.
63+
fn spans_for_line(input: TokenStream1, line: usize) -> Option<(Span, Span)> {
64+
let mut spans = None;
65+
for_all_spans(input, &mut |span| {
66+
if span.start().line() == line {
67+
spans.get_or_insert((span, span)).1 = span;
68+
}
69+
});
70+
spans
6171
}
6272

6373
/// Create a compile_error!{} using two spans that mark the start and end of the error.

0 commit comments

Comments
 (0)