Commit 9b3dfe4

feat: refactor to match spans and fix bug with header doc comments
1 parent 4578cd0 commit 9b3dfe4

11 files changed: +191 -29 lines
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+pub mod text_span;

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+use std::collections::{HashMap, VecDeque};
+
+use cairo_lang_macro::{quote, TextSpan, Token, TokenStream, TokenTree};
+use cairo_lang_parser::utils::SimpleParserDatabase;
+use cairo_lang_syntax::node::with_db::SyntaxNodeWithDb;
+
+fn make_ident(text: &str, span: TextSpan) -> TokenTree {
+    TokenTree::Ident(Token::new(text, span))
+}
+
+fn tokenize_str(s: &str, db: &SimpleParserDatabase) -> TokenStream {
+    let syntax_node = db.parse_virtual(s).unwrap();
+    let syntax_node_with_db = SyntaxNodeWithDb::new(&syntax_node, db);
+    quote! {#syntax_node_with_db}
+}
+
+/// Merge spans from `initial_tokens` into a tokenized version of `final_output`.
+/// - If a token text in `final_output` matches one from `initial_tokens` (by exact text),
+///   it inherits that token’s span (first-come, first-served).
+/// - Otherwise, it gets `TextSpan::call_site()`.
+pub fn merge_spans_from_initial(
+    initial_tokens: &[TokenTree],
+    final_output: &str,
+    db: &SimpleParserDatabase,
+) -> Vec<TokenTree> {
+    // Build a multimap: text -> queue of spans (to support duplicates, left-to-right)
+    let mut span_index: HashMap<String, VecDeque<TextSpan>> = HashMap::new();
+    for tt in initial_tokens {
+        let TokenTree::Ident(tok) = tt;
+        let text = tok.content.to_string();
+        span_index
+            .entry(text)
+            .or_default()
+            .push_back(tok.span.clone());
+    }
+
+    // Tokenize final output
+    let final_tokens = tokenize_str(final_output, db)
+        .tokens
+        .iter()
+        .map(|tt| {
+            let TokenTree::Ident(tok) = tt;
+            tok.content.to_string()
+        })
+        .collect::<Vec<_>>();
+
+    // Rebuild final tokens, reusing spans when the text matches; else call_site.
+    let mut out = Vec::with_capacity(final_tokens.len());
+    for text in final_tokens {
+        if let Some(queue) = span_index.get_mut(&text) {
+            if let Some(span) = queue.pop_front() {
+                out.push(make_ident(&text, span));
+                if queue.is_empty() {
+                    // keep the map clean
+                    span_index.remove(&text);
+                }
+                continue;
+            }
+        }
+        out.push(make_ident(&text, TextSpan::call_site()));
+    }
+
+    out
+}
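
For context, a minimal sketch of how this helper is meant to be called from an attribute macro, mirroring the call sites changed below. The `respan` wrapper itself is hypothetical (not part of this commit); it only combines APIs that already appear in this diff.

use cairo_lang_macro::{ProcMacroResult, TokenStream};
use cairo_lang_parser::utils::SimpleParserDatabase;

use crate::attribute::common::text_span::merge_spans_from_initial;

// Hypothetical helper, for illustration only: re-span a macro's generated source
// against the original item tokens and wrap the result as a ProcMacroResult.
fn respan(item_stream: &TokenStream, generated_source: &str) -> ProcMacroResult {
    let db = SimpleParserDatabase::default();
    // Tokens whose text matches an original token inherit that token's span;
    // everything else falls back to TextSpan::call_site().
    let tokens = merge_spans_from_initial(&item_stream.tokens, generated_source, &db);
    let token_stream =
        TokenStream::new(tokens).with_metadata(item_stream.metadata().clone());
    ProcMacroResult::new(token_stream)
}
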
Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 //! OpenZeppelin attribute macros.
 
+pub mod common;
 pub mod type_hash;
 pub mod with_components;

packages/macros/src/attribute/type_hash/definition.rs

Lines changed: 15 additions & 6 deletions
@@ -10,6 +10,7 @@ use convert_case::{Case, Casing};
 use indoc::formatdoc;
 use regex::Regex;
 
+use crate::attribute::common::text_span::merge_spans_from_initial;
 use crate::type_hash::parser::TypeHashParser;
 
 use super::diagnostics::errors;
@@ -29,11 +30,13 @@ use super::parser::parse_string_arg;
 /// ```
 #[attribute_macro]
 pub fn type_hash(attr_stream: TokenStream, item_stream: TokenStream) -> ProcMacroResult {
+    let no_op_result = ProcMacroResult::new(item_stream.clone());
+
     // 1. Parse the attribute stream
     let config = match parse_args(&attr_stream.to_string()) {
         Ok(config) => config,
         Err(err) => {
-            return ProcMacroResult::new(TokenStream::empty()).with_diagnostics(err.into());
+            return no_op_result.with_diagnostics(err.into());
         }
     };
 
@@ -44,24 +47,30 @@ pub fn type_hash(attr_stream: TokenStream, item_stream: TokenStream) -> ProcMacr
         Ok(node) => handle_node(&db, node, &config),
         Err(err) => {
             let error = Diagnostic::error(err.format(&db));
-            return ProcMacroResult::new(TokenStream::empty()).with_diagnostics(error.into());
+            return no_op_result.with_diagnostics(error.into());
         }
     };
 
     let generated = match content {
         Ok(generated) => generated,
         Err(err) => {
-            return ProcMacroResult::new(TokenStream::empty()).with_diagnostics(err.into());
+            return no_op_result.with_diagnostics(err.into());
         }
     };
 
-    // 3. Return the result
+    // 3. Merge spans from the item stream into the content
+    // TODO!: This should be refactored when scarb APIs get improved
     let syntax_node = db.parse_virtual(generated).unwrap();
     let content_node = SyntaxNodeWithDb::new(&syntax_node, &db);
 
-    let mut result = item_stream;
+    let mut result = item_stream.clone();
     result.extend(quote! {#content_node});
-    ProcMacroResult::new(result)
+
+    let syntax_node_with_spans =
+        merge_spans_from_initial(&item_stream.tokens, &result.to_string(), &db);
+    let token_stream =
+        TokenStream::new(syntax_node_with_spans).with_metadata(item_stream.metadata().clone());
+    ProcMacroResult::new(token_stream)
 }
 
 /// This attribute macro is used to specify an override for the SNIP-12 type.

packages/macros/src/attribute/with_components/definition.rs

Lines changed: 14 additions & 16 deletions
@@ -1,10 +1,11 @@
-use cairo_lang_formatter::format_string;
-use cairo_lang_macro::{attribute_macro, quote, Diagnostic, ProcMacroResult, TokenStream};
+use cairo_lang_macro::{attribute_macro, Diagnostic, ProcMacroResult, TokenStream};
 use cairo_lang_parser::utils::SimpleParserDatabase;
-use cairo_lang_syntax::node::with_db::SyntaxNodeWithDb;
 use regex::Regex;
 
-use crate::with_components::{components::AllowedComponents, parser::WithComponentsParser};
+use crate::{
+    attribute::common::text_span::merge_spans_from_initial,
+    with_components::{components::AllowedComponents, parser::WithComponentsParser},
+};
 
 /// Inserts multiple component dependencies into a modules codebase.
 #[attribute_macro]
@@ -13,15 +14,15 @@ pub fn with_components(attribute_stream: TokenStream, item_stream: TokenStream)
 
     // 1. Get the components info (if valid)
     let mut components_info = vec![];
-    let empty_result = ProcMacroResult::new(TokenStream::empty());
+    let no_op_result = ProcMacroResult::new(item_stream.clone());
     for arg in args {
         let maybe_component = AllowedComponents::from_str(&arg);
         match maybe_component {
            Ok(component) => {
                components_info.push(component.get_info());
            }
            Err(err) => {
-                return empty_result.with_diagnostics(err.into());
+                return no_op_result.with_diagnostics(err.into());
            }
        }
    }
@@ -32,19 +33,16 @@ pub fn with_components(attribute_stream: TokenStream, item_stream: TokenStream)
         Ok(node) => WithComponentsParser::new(node, &components_info).parse(&db),
         Err(err) => {
             let error = Diagnostic::error(err.format(&db));
-            return empty_result.with_diagnostics(error.into());
+            return no_op_result.with_diagnostics(error.into());
         }
     };
 
-    let formatted_content = if !content.is_empty() {
-        format_string(&db, content)
-    } else {
-        content
-    };
-
-    let syntax_node = db.parse_virtual(formatted_content).unwrap();
-    let formatted_content_node = SyntaxNodeWithDb::new(&syntax_node, &db);
-    ProcMacroResult::new(quote! {#formatted_content_node}).with_diagnostics(diagnostics)
+    // 3. Merge spans from the item stream into the content
+    // TODO!: This should be refactored when scarb APIs get improved
+    let syntax_node_with_spans = merge_spans_from_initial(&item_stream.tokens, &content, &db);
+    let token_stream =
+        TokenStream::new(syntax_node_with_spans).with_metadata(item_stream.metadata().clone());
+    ProcMacroResult::new(token_stream).with_diagnostics(diagnostics)
 }
 
 /// Parses the arguments from the attribute stream.

packages/macros/src/attribute/with_components/parser.rs

Lines changed: 14 additions & 4 deletions
@@ -9,12 +9,12 @@ use crate::{
 };
 use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
 use cairo_lang_macro::{Diagnostic, Diagnostics};
-use cairo_lang_syntax::node::helpers::QueryAttrs;
 use cairo_lang_syntax::node::{
     ast::{self, MaybeModuleBody},
     db::SyntaxGroup,
     SyntaxNode, Terminal, TypedSyntaxNode,
 };
+use cairo_lang_syntax::node::{helpers::QueryAttrs, kind::SyntaxKind};
 use indoc::indoc;
 use regex::Regex;
 
@@ -47,9 +47,19 @@ impl<'a> WithComponentsParser<'a> {
 
         let typed = ast::SyntaxFile::from_syntax_node(db, base_node);
         let mut base_rnode = RewriteNode::from_ast(&typed);
-        let module_rnode = base_rnode
-            .modify_child(db, ast::SyntaxFile::INDEX_ITEMS)
-            .modify_child(db, 0);
+        let module_rnode = base_rnode.modify_child(db, ast::SyntaxFile::INDEX_ITEMS);
+
+        // If the module has a header doc, skip it
+        let module_rnode = if let RewriteNode::Copied(copied) = module_rnode {
+            let children = copied.get_children(db);
+            if !children.is_empty() && children[0].kind(db) == SyntaxKind::ItemHeaderDoc {
+                module_rnode.modify_child(db, 1)
+            } else {
+                module_rnode.modify_child(db, 0)
+            }
+        } else {
+            module_rnode.modify_child(db, 0)
+        };
 
         // Validate the contract module
         let (errors, mut warnings) =
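
The new branch above guards the case the commit message calls out: when the parsed file begins with a `//!` header doc comment, the first child of the file's item list is an ItemHeaderDoc node rather than the contract module, so taking child 0 picked the wrong node. A standalone sketch of that situation follows (illustration only, not part of the commit; it assumes a leading `//!` comment is parsed as an ItemHeaderDoc child of the item list, which is the same assumption the check above relies on).

use cairo_lang_parser::utils::SimpleParserDatabase;
use cairo_lang_syntax::node::{ast, kind::SyntaxKind};

fn main() {
    let db = SimpleParserDatabase::default();
    // A virtual file that starts with a header doc comment before the module.
    let code = "//! Contract-level docs.\nmod my_contract {}\n";
    let file = db.parse_virtual(code).unwrap();

    // Same navigation the parser performs: syntax file -> item list -> children.
    let file_children = file.get_children(&db);
    let items = &file_children[ast::SyntaxFile::INDEX_ITEMS];
    let children = items.get_children(&db);
    if !children.is_empty() && children[0].kind(&db) == SyntaxKind::ItemHeaderDoc {
        // The module is the second child here, hence `modify_child(db, 1)` above.
        println!("header doc detected; the module sits at index 1");
    }
}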

packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_type_hash__invalid_type_hash_attribute.snap

Lines changed: 4 additions & 1 deletion
@@ -5,7 +5,10 @@ snapshot_kind: text
 ---
 TokenStream:
 
-None
+pub struct MyType {
+    pub member: felt252,
+}
+
 
 Diagnostics:

packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_type_hash__with_inner_custom_type.snap

Lines changed: 8 additions & 1 deletion
@@ -5,7 +5,14 @@ snapshot_kind: text
 ---
 TokenStream:
 
-None
+pub struct MyType {
+    pub name: felt252,
+    pub version: felt252,
+    pub chain_id: felt252,
+    pub revision: felt252,
+    pub member: InnerCustomType,
+}
+
 
 Diagnostics:

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+---
+source: src/tests/test_with_components.rs
+expression: result
+snapshot_kind: text
+---
+TokenStream:
+
+#[starknet::contract]
+pub mod MyContract {
+    use starknet::ContractAddress;
+    #[storage]
+    pub struct Storage {
+        #[substorage(v0)]
+        pub governor_votes: GovernorVotesComponent::Storage,
+    }
+    #[constructor]
+    fn constructor(ref self: ContractState, votes_token: ContractAddress) {
+        self.governor_votes.initializer(votes_token);
+    }
+    use openzeppelin_governance::governor::extensions::GovernorVotesComponent;
+
+    component!(path: GovernorVotesComponent, storage: governor_votes, event: GovernorVotesEvent);
+
+    impl GovernorVotesInternalImpl = GovernorVotesComponent::InternalImpl<ContractState>;
+    impl GovernorVotesGovernorVotes = GovernorVotesComponent::GovernorVotes<ContractState>;
+
+    #[event]
+    #[derive(Drop, starknet::Event)]
+    enum Event {
+        #[flat]
+        GovernorVotesEvent: GovernorVotesComponent::Event,
+    }
+}
+
+
+Diagnostics:
+
+None
+
+AuxData:
+
+None

packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_invalid_component.snap

Lines changed: 6 additions & 1 deletion
@@ -5,7 +5,12 @@ snapshot_kind: text
 ---
 TokenStream:
 
-None
+#[starknet::contract]
+pub mod MyContract {
+    #[storage]
+    pub struct Storage {}
+}
+
 
 Diagnostics:
