+use std::collections::HashMap;
+
 use internal_baml_diagnostics::{DatamodelError, Diagnostics};

 use super::{
@@ -765,34 +767,56 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
     let mut expr = None;
     let _open_bracket = tokens.next()?;

-    // Collect all items first to process headers together
+    // Collect all items first so we can gather every header before we bind them
+    // to statements. We need two passes: the first pass collects and normalizes
+    // the headers (including establishing their relative levels), the second
+    // pass walks the statements in source order and attaches those normalized
+    // headers. If we tried to attach while parsing in a single pass, headers
+    // appearing inside comment blocks would be seen after their statements and
+    // could not participate in markdown hierarchy normalization.
     let mut items: Vec<Pair<'_>> = Vec::new();
     for item in tokens {
         items.push(item);
     }

     // Track headers with their hierarchy
+    // NB(sam): I don't entirely understand why we need to wrap Headers in Arc<>,
+    // but here are the notes from codex:
+    // <codex>
+    // Most AST nodes are owned outright (each node sits in exactly one place in
+    // the tree), so ordinary struct fields work fine. Header annotations are the
+    // odd case: the parser needs to attach the same logical header instance to
+    // multiple spots (statements, trailing expressions, the top-level block, etc.)
+    // while also normalizing them later. To avoid copying or moving those
+    // structs repeatedly, the parser promotes headers into shared references
+    // (Arc<Header>). That lets the first pass create and normalize a header
+    // once, stash it in the lookup map, and then hand out clones of the pointer
+    // wherever the header appears, without duplication or lifetime juggling.
+    // Functionally, Arc is central here because headers get reused across many
+    // nodes, not because other AST structures require special thread-safety
+    // treatment.
+    // </codex>
     let mut all_headers_in_block: Vec<std::sync::Arc<Header>> = Vec::new();

     // First pass: collect all headers
     for item in &items {
-        if item.as_rule() == Rule::mdx_header {
-            let header = parse_header(item.clone(), diagnostics);
-            if let Some(header) = header {
-                let header_arc = std::sync::Arc::new(header);
-                all_headers_in_block.push(header_arc.clone());
+        if item.as_rule() == Rule::comment_block {
+            let headers = headers_from_comment_block(item.clone(), diagnostics);
+            if !headers.is_empty() {
+                all_headers_in_block.extend(headers);
             }
         }
     }

-    // Normalize all headers in the block together
+    // normalize_headers adjusts header levels so the shallowest header in the
+    // scope becomes an h1
     normalize_headers(&mut all_headers_in_block);

-    // Debug: Print normalized headers (disabled)
-    // println!("PARSER: Normalized headers in block:");
-    // for (i, header) in all_headers_in_block.iter().enumerate() {
-    //     println!("  [{}] '{}' (Level: {})", i, header.title, header.level);
-    // }
+    // Lookup by span so we can reuse normalized headers later.
+    let mut header_lookup: HashMap<(usize, usize), std::sync::Arc<Header>> = HashMap::new();
+    for header in &all_headers_in_block {
+        header_lookup.insert((header.span.start, header.span.end), header.clone());
+    }

     // Second pass: process statements and expressions with normalized headers
     let mut current_headers: Vec<std::sync::Arc<Header>> = Vec::new();
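
The net effect of the first pass: every header found in a comment block is wrapped in an `Arc` and its level is rebased so the shallowest header in the block becomes an h1, and later passes only ever clone the pointer. A minimal, self-contained sketch of that behaviour; `Header` and `normalize_levels` below are simplified stand-ins, not the crate's real types or its `normalize_headers`:

```rust
// Illustrative sketch only: `Header` and `normalize_levels` are simplified
// stand-ins for the crate's real types and normalization routine.
use std::sync::Arc;

#[derive(Debug)]
struct Header {
    level: u8,
    title: String,
}

// Assumed behaviour, per the comment above: shift every level down so the
// shallowest header in the block becomes an h1.
fn normalize_levels(headers: &mut [Header]) {
    if let Some(min) = headers.iter().map(|h| h.level).min() {
        for h in headers.iter_mut() {
            h.level = h.level - min + 1;
        }
    }
}

fn main() {
    // Headers as written in the source block: "## Setup", "### Details", "## Run".
    let mut headers = vec![
        Header { level: 2, title: "Setup".into() },
        Header { level: 3, title: "Details".into() },
        Header { level: 2, title: "Run".into() },
    ];
    normalize_levels(&mut headers);
    assert_eq!(
        headers.iter().map(|h| h.level).collect::<Vec<_>>(),
        vec![1u8, 2, 1]
    );

    // Wrapping in Arc afterwards lets the parser hand the same header to several
    // AST nodes: cloning the Arc copies a pointer, not the Header itself.
    let shared: Vec<Arc<Header>> = headers.into_iter().map(Arc::new).collect();
    let reused = shared[0].clone();
    assert!(Arc::ptr_eq(&reused, &shared[0]));
}
```
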
@@ -826,26 +850,6 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
                     continue;
                 }
             }
-            Rule::mdx_header => {
-                // Headers are already processed, just update current headers
-                let header = parse_header(item, diagnostics);
-                if let Some(header) = header {
-                    let header_arc = std::sync::Arc::new(header);
-
-                    // Find the corresponding normalized header
-                    if let Some(normalized_header) = all_headers_in_block
-                        .iter()
-                        .find(|h| h.title == header_arc.title)
-                    {
-                        // Implement header hierarchy logic
-                        filter_headers_by_hierarchy(&mut current_headers, normalized_header);
-
-                        // Add to current headers and headers since last statement
-                        current_headers.push(normalized_header.clone());
-                        headers_since_last_stmt.push(normalized_header.clone());
-                    }
-                }
-            }
             Rule::BLOCK_CLOSE => {
                 // Commented out because we can't have blocks without return
                 // expressions otherwise. Plus we need functions with no return
@@ -863,8 +867,18 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
                 continue;
             }
             Rule::comment_block => {
-                // Skip comments in function bodies
-                continue;
+                let headers = headers_from_comment_block(item, diagnostics);
+                if headers.is_empty() {
+                    continue;
+                }
+                for header in headers {
+                    attach_header_if_known(
+                        &header,
+                        &header_lookup,
+                        &mut current_headers,
+                        &mut headers_since_last_stmt,
+                    );
+                }
             }
             Rule::empty_lines => {
                 // Skip empty lines in function bodies
@@ -942,44 +956,75 @@ pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Optio
     })
 }

-/// Parse a single header from an MDX header token
-pub fn parse_header(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Option<Header> {
-    let full_text = token.as_str();
-    let header_span = diagnostics.span(token.as_span());
-
-    // Find the start of the hash sequence
-    let hash_start = full_text.find('#')?;
-    let after_whitespace = full_text[hash_start..].trim_start();
-
-    // Count consecutive hash characters
-    let hash_count = after_whitespace.chars().take_while(|&c| c == '#').count();
+fn headers_from_comment_block(
+    token: Pair<'_>,
+    diagnostics: &mut Diagnostics,
+) -> Vec<std::sync::Arc<Header>> {
+    if token.as_rule() != Rule::comment_block {
+        return Vec::new();
+    }

-    // Extract the title after the hash sequence and whitespace
-    let after_hashes = &after_whitespace[hash_count..];
-    let title_text = after_hashes.trim().to_string();
+    let mut headers = Vec::new();
+    for current in token.into_inner() {
+        if current.as_rule() == Rule::comment {
+            if let Some(header) = parse_comment_header_pair(&current, diagnostics) {
+                headers.push(std::sync::Arc::new(header));
+            }
+        }
+    }
+    headers
+}

-    // Remove trailing newline if present
-    let title_text = title_text
-        .trim_end_matches('\n')
-        .trim_end_matches('\r')
-        .to_string();
+pub(crate) fn parse_comment_header_pair(
+    comment: &Pair<'_>,
+    diagnostics: &mut Diagnostics,
+) -> Option<Header> {
+    let span = diagnostics.span(comment.as_span());
+    let mut text = comment.as_str().trim_start();
+    if !text.starts_with("//") {
+        return None;
+    }
+    text = &text[2..];
+    let text = text.trim_start();
+    if !text.starts_with('#') {
+        return None;
+    }

-    let level = hash_count as u8;
+    let mut level = 0usize;
+    for ch in text.chars() {
+        if ch == '#' {
+            level += 1;
+        } else {
+            break;
+        }
+    }
+    if level == 0 {
+        return None;
+    }

-    // Print debug information about the header (disabled)
-    // let indent = " ".repeat(level as usize);
-    // println!(
-    //     "{}└ HEADER Level {}: '{}' (hash count: {})",
-    //     indent, level, title_text, level
-    // );
+    let rest = text[level..].trim().to_string();

     Some(Header {
-        level,
-        title: title_text,
-        span: header_span,
+        level: level as u8,
+        title: rest,
+        span,
     })
 }

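Concretely, a comment such as `// ## Inputs` yields a level-2 header titled `Inputs`, while an ordinary comment yields nothing. A sketch of the same string handling on a plain `&str` rather than a pest `Pair` (the helper name is illustrative, not part of the crate):

```rust
// Mirrors the string handling in parse_comment_header_pair, applied to a raw
// comment line; this helper is an illustration only.
fn header_from_comment_text(raw: &str) -> Option<(u8, String)> {
    // Only `//` comments can carry a header.
    let text = raw.trim_start().strip_prefix("//")?.trim_start();
    // Count the leading run of hashes; zero hashes means an ordinary comment.
    let level = text.chars().take_while(|&c| c == '#').count();
    if level == 0 {
        return None;
    }
    // Everything after the hashes, trimmed, becomes the title.
    let title = text[level..].trim().to_string();
    Some((level as u8, title))
}

fn main() {
    assert_eq!(
        header_from_comment_text("// ## Inputs"),
        Some((2, "Inputs".to_string()))
    );
    // A plain comment yields no header.
    assert_eq!(header_from_comment_text("// just a note"), None);
}
```
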
+fn attach_header_if_known(
+    header: &std::sync::Arc<Header>,
+    lookup: &HashMap<(usize, usize), std::sync::Arc<Header>>,
+    current_headers: &mut Vec<std::sync::Arc<Header>>,
+    headers_since_last_stmt: &mut Vec<std::sync::Arc<Header>>,
+) {
+    let key = (header.span.start, header.span.end);
+    if let Some(normalized_header) = lookup.get(&key) {
+        filter_headers_by_hierarchy(current_headers, normalized_header);
+        current_headers.push(normalized_header.clone());
+        headers_since_last_stmt.push(normalized_header.clone());
+    }
+}
+
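Keying the lookup by `(span.start, span.end)` lets the second pass trade a re-encountered header for the normalized `Arc` built in the first pass, while the hierarchy filter keeps `current_headers` nested markdown-style. A sketch of that flow, assuming (per the doc comment below) that the filter drops pending headers at the same or deeper level; `Header`, its span tuple, and `filter_by_hierarchy` are simplified stand-ins:

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Simplified stand-in for the crate's Header/Span types.
#[derive(Debug)]
struct Header {
    level: u8,
    title: String,
    span: (usize, usize), // (start, end) byte offsets
}

// Assumed markdown-style nesting: a new header closes any pending headers at
// the same or deeper level before it is pushed.
fn filter_by_hierarchy(pending: &mut Vec<Arc<Header>>, new_header: &Arc<Header>) {
    pending.retain(|h| h.level < new_header.level);
}

fn main() {
    // First pass: normalized headers, keyed by span for later reuse.
    let h1 = Arc::new(Header { level: 1, title: "Setup".into(), span: (10, 20) });
    let h2 = Arc::new(Header { level: 2, title: "Details".into(), span: (42, 55) });
    let lookup: HashMap<(usize, usize), Arc<Header>> =
        [(h1.span, h1.clone()), (h2.span, h2.clone())].into_iter().collect();

    // Second pass: a header re-encountered at span (42, 55) resolves to the
    // same normalized Arc, and "Setup" stays active because it is shallower
    // than the incoming level-2 header.
    let mut current: Vec<Arc<Header>> = vec![h1.clone()];
    if let Some(normalized) = lookup.get(&(42, 55)) {
        filter_by_hierarchy(&mut current, normalized);
        current.push(normalized.clone());
    }
    assert_eq!(current.len(), 2);
    assert!(Arc::ptr_eq(&current[1], &h2));
}
```
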
 /// Filter headers based on hierarchy rules (markdown-style nesting)
 fn filter_headers_by_hierarchy(
     pending_headers: &mut Vec<std::sync::Arc<Header>>,