Skip to content

Commit 07a0bb6

Browse files
authored
feat/unreleased: annotate viz nodes using //# not # (#2610)
1 parent ac0ede8 commit 07a0bb6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+507
-279
lines changed

engine/baml-lib/ast/src/parser/datamodel.pest

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
schema = {
2-
SOI ~ (mdx_header | expr_fn | top_level_assignment | value_expression_block | type_expression_block | template_declaration | type_alias | comment_block | raw_string_literal | empty_lines | CATCH_ALL)* ~ EOI
2+
SOI ~ (expr_fn | top_level_assignment | value_expression_block | type_expression_block | template_declaration | type_alias | comment_block | raw_string_literal | empty_lines | CATCH_ALL)* ~ EOI
33
}
44

55
// ######################################
@@ -27,7 +27,7 @@ field_type_with_attr = { field_type ~ (NEWLINE? ~ (field_attribute | trailing_co
2727
value_expression_keyword = { FUNCTION_KEYWORD | TEST_KEYWORD | CLIENT_KEYWORD | RETRY_POLICY_KEYWORD | GENERATOR_KEYWORD }
2828
value_expression_block = { value_expression_keyword ~ identifier ~ named_argument_list? ~ ARROW? ~ field_type_chain? ~ SPACER_TEXT ~ BLOCK_OPEN ~ value_expression_contents ~ BLOCK_CLOSE }
2929
value_expression_contents = {
30-
(mdx_header | stmt | type_builder_block | value_expression | comment_block | block_attribute | empty_lines | BLOCK_LEVEL_CATCH_ALL)*
30+
(stmt | type_builder_block | value_expression | comment_block | block_attribute | empty_lines | BLOCK_LEVEL_CATCH_ALL)*
3131
}
3232
value_expression = { identifier ~ config_expression? ~ (NEWLINE? ~ field_attribute)* ~ trailing_comment? }
3333

@@ -339,7 +339,7 @@ top_level_stmt = {
339339
expr_fn = { "function" ~ identifier ~ named_argument_list ~ ARROW? ~ field_type_chain? ~ expr_block }
340340

341341
// Body of a function (including curly brackets).
342-
expr_block = { BLOCK_OPEN ~ NEWLINE? ~ (mdx_header | expr_body_stmt | stmt | comment_block | empty_lines)* ~ expression? ~ (comment_block | empty_lines)* ~ BLOCK_CLOSE }
342+
expr_block = { BLOCK_OPEN ~ NEWLINE? ~ (expr_body_stmt | stmt | comment_block | empty_lines)* ~ expression? ~ (comment_block | empty_lines)* ~ BLOCK_CLOSE }
343343

344344
// More forgiving statement rule for function bodies - only for statements that commonly miss semicolons
345345
expr_body_stmt = {
@@ -356,8 +356,6 @@ expr_body_stmt = {
356356
~ NEWLINE?
357357
}
358358

359-
// Headers can include any characters except a newline; do not use unquoted_string_literal restrictions
360-
mdx_header = { WHITESPACE* ~ "#"+ ~ WHITESPACE* ~ (!NEWLINE ~ ANY)* ~ NEWLINE? }
361359

362360
// Statement.
363361
stmt = {

engine/baml-lib/ast/src/parser/parse.rs

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use pest::Parser;
77

88
use super::{
99
parse_assignment::parse_assignment,
10-
parse_expr::{parse_expr_fn, parse_header, parse_top_level_assignment},
10+
parse_expr::{parse_comment_header_pair, parse_expr_fn, parse_top_level_assignment},
1111
parse_expression::parse_expression,
1212
parse_template_string::parse_template_string,
1313
parse_type_expression_block::parse_type_expression_block,
@@ -167,12 +167,13 @@ pub fn parse(root_path: &Path, source: &SourceFile) -> Result<(Ast, Diagnostics)
167167
));
168168
break;
169169
}
170-
Rule::mdx_header => {
171-
if let Some(header) = parse_header(current, &mut diagnostics) {
172-
pending_headers.push(header);
173-
}
174-
}
175170
Rule::comment_block => {
171+
let headers =
172+
headers_from_comment_block_top_level(current.clone(), &mut diagnostics);
173+
if !headers.is_empty() {
174+
pending_headers.extend(headers);
175+
continue;
176+
}
176177
match pairs.peek().map(|b| b.as_rule()) {
177178
Some(Rule::empty_lines) => {
178179
// free floating
@@ -226,6 +227,25 @@ pub fn parse(root_path: &Path, source: &SourceFile) -> Result<(Ast, Diagnostics)
226227
}
227228
}
228229

230+
fn headers_from_comment_block_top_level(
231+
token: pest::iterators::Pair<'_, Rule>,
232+
diagnostics: &mut Diagnostics,
233+
) -> Vec<Header> {
234+
if token.as_rule() != Rule::comment_block {
235+
return Vec::new();
236+
}
237+
238+
let mut headers = Vec::new();
239+
for current in token.into_inner() {
240+
if current.as_rule() == Rule::comment {
241+
if let Some(header) = parse_comment_header_pair(&current, diagnostics) {
242+
headers.push(header);
243+
}
244+
}
245+
}
246+
headers
247+
}
248+
229249
fn get_expected_from_error(positives: &[Rule]) -> String {
230250
use std::fmt::Write as _;
231251
let mut out = String::with_capacity(positives.len() * 6);

engine/baml-lib/ast/src/parser/parse_expr.rs

Lines changed: 108 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::collections::HashMap;
2+
13
use internal_baml_diagnostics::{DatamodelError, Diagnostics};
24

35
use super::{
@@ -765,34 +767,56 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
765767
let mut expr = None;
766768
let _open_bracket = tokens.next()?;
767769

768-
// Collect all items first to process headers together
770+
// Collect all items first so we can gather every header before we bind them
771+
// to statements. We need two passes: the first pass collects and normalizes
772+
// the headers (including establishing their relative levels), the second
773+
// pass walks the statements in source order and attaches those normalized
774+
// headers. If we tried to attach while parsing in a single pass, headers
775+
// appearing inside comment blocks would be seen after their statements and
776+
// could not participate in markdown hierarchy normalization.
769777
let mut items: Vec<Pair<'_>> = Vec::new();
770778
for item in tokens {
771779
items.push(item);
772780
}
773781

774782
// Track headers with their hierarchy
783+
// NB(sam): I don't entirely understand why we need to wrap Headers in Arc<>,
784+
// but here are the notes from codex:
785+
// <codex>
786+
// Most AST nodes are owned outright—each node sits in exactly one place in
787+
// the tree—so ordinary struct fields work fine. Header annotations are the
788+
// odd case: the parser needs to attach the same logical header instance to
789+
// multiple spots (statements, trailing expressions, top‑level block etc.)
790+
// while also normalizing them later. To avoid copying or moving those
791+
// structs repeatedly, the parser promotes headers into shared references
792+
// (Arc<Header>). That lets the first pass create and normalize a header
793+
// once, stash it in the lookup map, and then hand out clones of the pointer
794+
// wherever the header appears, without duplication or life‑time juggling.
795+
// Functionally, Arc is central here because headers get reused across many
796+
// nodes, not because other AST structures require special thread‑safety
797+
// treatment.
798+
// </codex>
775799
let mut all_headers_in_block: Vec<std::sync::Arc<Header>> = Vec::new();
776800

777801
// First pass: collect all headers
778802
for item in &items {
779-
if item.as_rule() == Rule::mdx_header {
780-
let header = parse_header(item.clone(), diagnostics);
781-
if let Some(header) = header {
782-
let header_arc = std::sync::Arc::new(header);
783-
all_headers_in_block.push(header_arc.clone());
803+
if item.as_rule() == Rule::comment_block {
804+
let headers = headers_from_comment_block(item.clone(), diagnostics);
805+
if !headers.is_empty() {
806+
all_headers_in_block.extend(headers);
784807
}
785808
}
786809
}
787810

788-
// Normalize all headers in the block together
811+
// normalize_headers adjusts header levels so the shallowest header in the
812+
// scope becomes an h1
789813
normalize_headers(&mut all_headers_in_block);
790814

791-
// Debug: Print normalized headers (disabled)
792-
// println!("PARSER: Normalized headers in block:");
793-
// for (i, header) in all_headers_in_block.iter().enumerate() {
794-
// println!(" [{}] '{}' (Level: {})", i, header.title, header.level);
795-
// }
815+
// Lookup by span so we can reuse normalized headers later.
816+
let mut header_lookup: HashMap<(usize, usize), std::sync::Arc<Header>> = HashMap::new();
817+
for header in &all_headers_in_block {
818+
header_lookup.insert((header.span.start, header.span.end), header.clone());
819+
}
796820

797821
// Second pass: process statements and expressions with normalized headers
798822
let mut current_headers: Vec<std::sync::Arc<Header>> = Vec::new();
@@ -826,26 +850,6 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
826850
continue;
827851
}
828852
}
829-
Rule::mdx_header => {
830-
// Headers are already processed, just update current headers
831-
let header = parse_header(item, diagnostics);
832-
if let Some(header) = header {
833-
let header_arc = std::sync::Arc::new(header);
834-
835-
// Find the corresponding normalized header
836-
if let Some(normalized_header) = all_headers_in_block
837-
.iter()
838-
.find(|h| h.title == header_arc.title)
839-
{
840-
// Implement header hierarchy logic
841-
filter_headers_by_hierarchy(&mut current_headers, normalized_header);
842-
843-
// Add to current headers and headers since last statement
844-
current_headers.push(normalized_header.clone());
845-
headers_since_last_stmt.push(normalized_header.clone());
846-
}
847-
}
848-
}
849853
Rule::BLOCK_CLOSE => {
850854
// Commentend out because we can't have blocks without return
851855
// expressions otherwise. Plus we need functions with no return
@@ -863,8 +867,18 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
863867
continue;
864868
}
865869
Rule::comment_block => {
866-
// Skip comments in function bodies
867-
continue;
870+
let headers = headers_from_comment_block(item, diagnostics);
871+
if headers.is_empty() {
872+
continue;
873+
}
874+
for header in headers {
875+
attach_header_if_known(
876+
&header,
877+
&header_lookup,
878+
&mut current_headers,
879+
&mut headers_since_last_stmt,
880+
);
881+
}
868882
}
869883
Rule::empty_lines => {
870884
// Skip empty lines in function bodies
@@ -942,44 +956,75 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
942956
})
943957
}
944958

945-
/// Parse a single header from an MDX header token
946-
pub fn parse_header(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Option<Header> {
947-
let full_text = token.as_str();
948-
let header_span = diagnostics.span(token.as_span());
949-
950-
// Find the start of the hash sequence
951-
let hash_start = full_text.find('#')?;
952-
let after_whitespace = full_text[hash_start..].trim_start();
953-
954-
// Count consecutive hash characters
955-
let hash_count = after_whitespace.chars().take_while(|&c| c == '#').count();
959+
fn headers_from_comment_block(
960+
token: Pair<'_>,
961+
diagnostics: &mut Diagnostics,
962+
) -> Vec<std::sync::Arc<Header>> {
963+
if token.as_rule() != Rule::comment_block {
964+
return Vec::new();
965+
}
956966

957-
// Extract the title after the hash sequence and whitespace
958-
let after_hashes = &after_whitespace[hash_count..];
959-
let title_text = after_hashes.trim().to_string();
967+
let mut headers = Vec::new();
968+
for current in token.into_inner() {
969+
if current.as_rule() == Rule::comment {
970+
if let Some(header) = parse_comment_header_pair(&current, diagnostics) {
971+
headers.push(std::sync::Arc::new(header));
972+
}
973+
}
974+
}
975+
headers
976+
}
960977

961-
// Remove trailing newline if present
962-
let title_text = title_text
963-
.trim_end_matches('\n')
964-
.trim_end_matches('\r')
965-
.to_string();
978+
pub(crate) fn parse_comment_header_pair(
979+
comment: &Pair<'_>,
980+
diagnostics: &mut Diagnostics,
981+
) -> Option<Header> {
982+
let span = diagnostics.span(comment.as_span());
983+
let mut text = comment.as_str().trim_start();
984+
if !text.starts_with("//") {
985+
return None;
986+
}
987+
text = &text[2..];
988+
let text = text.trim_start();
989+
if !text.starts_with('#') {
990+
return None;
991+
}
966992

967-
let level = hash_count as u8;
993+
let mut level = 0usize;
994+
for ch in text.chars() {
995+
if ch == '#' {
996+
level += 1;
997+
} else {
998+
break;
999+
}
1000+
}
1001+
if level == 0 {
1002+
return None;
1003+
}
9681004

969-
// Print debug information about the header (disabled)
970-
// let indent = " ".repeat(level as usize);
971-
// println!(
972-
// "{}└ HEADER Level {}: '{}' (hash count: {})",
973-
// indent, level, title_text, level
974-
// );
1005+
let rest = text[level..].trim().to_string();
9751006

9761007
Some(Header {
977-
level,
978-
title: title_text,
979-
span: header_span,
1008+
level: level as u8,
1009+
title: rest,
1010+
span,
9801011
})
9811012
}
9821013

1014+
fn attach_header_if_known(
1015+
header: &std::sync::Arc<Header>,
1016+
lookup: &HashMap<(usize, usize), std::sync::Arc<Header>>,
1017+
current_headers: &mut Vec<std::sync::Arc<Header>>,
1018+
headers_since_last_stmt: &mut Vec<std::sync::Arc<Header>>,
1019+
) {
1020+
let key = (header.span.start, header.span.end);
1021+
if let Some(normalized_header) = lookup.get(&key) {
1022+
filter_headers_by_hierarchy(current_headers, normalized_header);
1023+
current_headers.push(normalized_header.clone());
1024+
headers_since_last_stmt.push(normalized_header.clone());
1025+
}
1026+
}
1027+
9831028
/// Filter headers based on hierarchy rules (markdown-style nesting)
9841029
fn filter_headers_by_hierarchy(
9851030
pending_headers: &mut Vec<std::sync::Arc<Header>>,

engine/baml-lib/ast/src/parser/parse_expression.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -711,19 +711,19 @@ mod tests {
711711
}
712712

713713
#[test]
714-
fn test_mdx_header_parsing() {
715-
println!("\n=== Testing MDX Header Parsing ===");
714+
fn test_comment_header_parsing() {
715+
println!("\n=== Testing Comment Header Parsing ===");
716716

717717
let input = r#"{
718-
# Level 1 Header
718+
//# Level 1 Header
719719
let x = "hello";
720720
721-
## Level 2 Header
721+
//## Level 2 Header
722722
let y = "world";
723723
724-
########### Level 11 Header
724+
//########### Level 11 Header
725725
726-
### Level 3 Headers
726+
//### Level 3 Headers
727727
x + y
728728
}"#;
729729

@@ -732,7 +732,7 @@ mod tests {
732732
let mut diagnostics = Diagnostics::new(root_path.into());
733733
diagnostics.set_source(&source);
734734

735-
println!("Parsing expression block with mdx headers...");
735+
println!("Parsing expression block with comment headers...");
736736

737737
let pair_result = BAMLParser::parse(Rule::expr_block, input);
738738
match pair_result {
@@ -762,21 +762,21 @@ mod tests {
762762
fn test_complex_header_hierarchy() {
763763
println!("\n=== Testing Complex Header Hierarchy ===");
764764

765-
let input = r#"# Loop Processing
765+
let input = r#"//# Loop Processing
766766
fn ForLoopWithHeaders() -> int {
767767
let items = [1, 2, 3, 4, 5];
768768
let result = 0;
769769
770-
## Main Loop
770+
//## Main Loop
771771
for (item in items) {
772-
### Item Processing
772+
//### Item Processing
773773
let processed = item * 2;
774774
775-
#### Accumulation
775+
//#### Accumulation
776776
result = result + processed;
777777
}
778778
779-
## Final Result
779+
//## Final Result
780780
result
781781
}"#;
782782

engine/baml-lib/ast/src/parser/parse_value_expression_block.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ pub(crate) fn parse_value_expression_block(
127127
}
128128
}
129129
Rule::empty_lines => {}
130-
Rule::BLOCK_LEVEL_CATCH_ALL | Rule::mdx_header => {
130+
Rule::BLOCK_LEVEL_CATCH_ALL => {
131131
diagnostics.push_error(DatamodelError::new_validation_error(
132132
"This line is not a valid field or attribute definition. A valid property may look like: 'myProperty \"some value\"' for example, with no colons.",
133133
diagnostics.span(item.as_span()),

0 commit comments

Comments
 (0)