Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions baml_language/crates/baml_compiler2_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,98 @@ pub struct FunctionTypeParam {
pub ty: TypeExpr,
}

/// Recursive spanned type expression kind — mirrors `TypeExpr` but children
/// are `SpannedTypeExpr` so every sub-expression carries its own source span.
///
/// `SpannedTypeExpr::to_type_expr` converts each variant here into the
/// `TypeExpr` variant of the same name, dropping the spans of any children.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SpannedTypeExprKind {
/// A named type, stored as its sequence of path segments.
Path(Vec<Name>),
// Leaf kinds: no child type expressions, so only the enclosing
// `SpannedTypeExpr::span` carries location information.
Int,
Float,
String,
Bool,
Null,
Never,
/// A media type, tagged with its `MediaKind`.
Media(baml_base::MediaKind),
/// Wraps exactly one inner type expression, which keeps its own span.
Optional(Box<SpannedTypeExpr>),
/// A list whose element type keeps its own span
/// (e.g. `int[][]` lowers to `List(List(Int))`).
List(Box<SpannedTypeExpr>),
/// A map type; the key and value type expressions are spanned separately.
Map {
key: Box<SpannedTypeExpr>,
value: Box<SpannedTypeExpr>,
},
/// A union; every member carries its own span so a diagnostic can point at
/// a single member.
Union(Vec<SpannedTypeExpr>),
/// A literal type carrying the literal value.
Literal(baml_base::Literal),
/// A function type: spanned parameters plus a spanned return type.
Function {
params: Vec<SpannedFunctionTypeParam>,
ret: Box<SpannedTypeExpr>,
},
// Remaining leaf kinds. Their meaning is defined by the like-named
// `TypeExpr` variants they convert to.
BuiltinUnknown,
Type,
Rust,
Error,
Unknown,
}

/// A parameter in a spanned function type expression.
///
/// Converted to a `FunctionTypeParam` by `SpannedTypeExpr::to_type_expr`,
/// which clones `name` and strips the spans from `ty`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpannedFunctionTypeParam {
/// Parameter name, when present.
// NOTE(review): presumably `None` for unnamed parameters — confirm against the lowering code.
pub name: Option<Name>,
/// The parameter's type, with a span on every sub-expression.
pub ty: SpannedTypeExpr,
}

/// A type expression with its source span — used in item definitions
/// where we need both the type data and the source location.
///
/// Recursive: children are also `SpannedTypeExpr`, so every sub-expression
/// (e.g. each union member) carries its own span for precise diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpannedTypeExpr {
/// Span-free form of this type expression.
// NOTE(review): `to_type_expr` rebuilds the `TypeExpr` from `kind` and never
// reads this field — confirm whether it is still needed alongside `kind`.
pub expr: TypeExpr,
/// The shape of this node; any children are themselves `SpannedTypeExpr`s.
pub kind: SpannedTypeExprKind,
/// Source range of this node.
pub span: TextRange,
}

impl SpannedTypeExpr {
/// Recursively strip spans to produce a span-free `TypeExpr` for Salsa queries.
pub fn to_type_expr(&self) -> TypeExpr {
match &self.kind {
SpannedTypeExprKind::Path(segments) => TypeExpr::Path(segments.clone()),
SpannedTypeExprKind::Int => TypeExpr::Int,
SpannedTypeExprKind::Float => TypeExpr::Float,
SpannedTypeExprKind::String => TypeExpr::String,
SpannedTypeExprKind::Bool => TypeExpr::Bool,
SpannedTypeExprKind::Null => TypeExpr::Null,
SpannedTypeExprKind::Never => TypeExpr::Never,
SpannedTypeExprKind::Media(kind) => TypeExpr::Media(*kind),
SpannedTypeExprKind::Optional(inner) => {
TypeExpr::Optional(Box::new(inner.to_type_expr()))
}
SpannedTypeExprKind::List(inner) => TypeExpr::List(Box::new(inner.to_type_expr())),
SpannedTypeExprKind::Map { key, value } => TypeExpr::Map {
key: Box::new(key.to_type_expr()),
value: Box::new(value.to_type_expr()),
},
SpannedTypeExprKind::Union(members) => {
TypeExpr::Union(members.iter().map(SpannedTypeExpr::to_type_expr).collect())
}
SpannedTypeExprKind::Literal(lit) => TypeExpr::Literal(lit.clone()),
SpannedTypeExprKind::Function { params, ret } => TypeExpr::Function {
params: params
.iter()
.map(|p| FunctionTypeParam {
name: p.name.clone(),
ty: p.ty.to_type_expr(),
})
.collect(),
ret: Box::new(ret.to_type_expr()),
},
SpannedTypeExprKind::BuiltinUnknown => TypeExpr::BuiltinUnknown,
SpannedTypeExprKind::Type => TypeExpr::Type,
SpannedTypeExprKind::Rust => TypeExpr::Rust,
SpannedTypeExprKind::Error => TypeExpr::Error,
SpannedTypeExprKind::Unknown => TypeExpr::Unknown,
}
}
}

// ── Expression Bodies ───────────────────────────────────────────
//
// Full expression/statement arena — modeled after the existing
Expand Down Expand Up @@ -130,6 +214,9 @@ pub struct AstSourceMap {
pub pattern_spans: Arena<TextRange>,
pub match_arm_spans: Arena<TextRange>,
pub type_annotation_spans: Arena<TextRange>,
/// Parallel to `type_annotation_spans` — full `SpannedTypeExpr` tree for
/// per-node diagnostics (e.g. underline only the bad union member).
pub type_annotation_spanned_exprs: Vec<SpannedTypeExpr>,
pub catch_arm_spans: Arena<TextRange>,
}

Expand All @@ -141,6 +228,7 @@ impl AstSourceMap {
pattern_spans: Arena::new(),
match_arm_spans: Arena::new(),
type_annotation_spans: Arena::new(),
type_annotation_spanned_exprs: Vec::new(),
catch_arm_spans: Arena::new(),
}
}
Expand Down Expand Up @@ -188,6 +276,12 @@ impl AstSourceMap {
.unwrap_or_default()
}

/// Fetch the full `SpannedTypeExpr` tree recorded for a type annotation, for
/// per-node diagnostics.
///
/// The raw index of `id` is used as a position in
/// `type_annotation_spanned_exprs`; returns `None` when no entry exists at
/// that position.
pub fn type_annotation_spanned(&self, id: TypeAnnotId) -> Option<&SpannedTypeExpr> {
    let index = id.into_raw().into_u32() as usize;
    self.type_annotation_spanned_exprs.get(index)
}

/// Look up the source span of a catch arm by its `CatchArmId`.
pub fn catch_arm_span(&self, id: CatchArmId) -> TextRange {
let raw: u32 = id.into_raw().into_u32();
Expand Down Expand Up @@ -314,7 +408,7 @@ pub enum Stmt {
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Pattern {
Binding(Name),
TypedBinding { name: Name, ty: TypeExpr },
TypedBinding { name: Name, ty: SpannedTypeExpr },
Literal(Literal),
Null,
EnumVariant { enum_name: Name, variant: Name },
Expand Down
40 changes: 33 additions & 7 deletions baml_language/crates/baml_compiler2_ast/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod tests {
use baml_compiler_syntax::{SyntaxKind, SyntaxNode};

use crate::{
ast::{BuiltinKind, Expr, FunctionBodyDef, Item, Stmt, TypeExpr},
ast::{BuiltinKind, Expr, FunctionBodyDef, Item, SpannedTypeExprKind, Stmt, TypeExpr},
lower_cst::lower_file,
};

Expand Down Expand Up @@ -336,7 +336,7 @@ class Media {
.expect("expected _data field");

match &field.type_expr {
Some(spanned) => match &spanned.expr {
Some(spanned) => match spanned.to_type_expr() {
TypeExpr::Rust => {}
other => panic!("expected TypeExpr::Rust, got {other:?}"),
},
Expand Down Expand Up @@ -446,7 +446,11 @@ class Media {
assert!(data_field.is_some(), "expected _data field");
assert!(
matches!(
data_field.unwrap().type_expr.as_ref().map(|te| &te.expr),
data_field
.unwrap()
.type_expr
.as_ref()
.map(super::ast::SpannedTypeExpr::to_type_expr),
Some(TypeExpr::Rust)
),
"_data field should have TypeExpr::Rust"
Expand All @@ -467,10 +471,32 @@ function f() -> int throws never {
let throws = func
.throws
.expect("expected throws clause to be lowered into FunctionDef.throws");
assert!(
matches!(throws.expr, TypeExpr::Never),
"expected throws type to lower as TypeExpr::Never, got {:?}",
throws.expr
match throws.to_type_expr() {
TypeExpr::Never => {}
other => panic!("expected throws type to lower as TypeExpr::Never, got {other:?}"),
}
}

#[test]
fn nested_container_type_has_distinct_spans_per_level() {
    let items = parse_and_lower("function f(x: int[][]) -> int { return 0 }");
    let func = first_function(items);
    let annotation = func
        .params
        .first()
        .expect("one param")
        .type_expr
        .as_ref()
        .expect("param has type annotation");

    // int[][] lowers to List(List(Int)); peel off one container level at a time.
    let middle = match &annotation.kind {
        SpannedTypeExprKind::List(element) => element,
        other => panic!("expected outer List, got {:?}", other),
    };
    let innermost = match &middle.kind {
        SpannedTypeExprKind::List(element) => element,
        other => panic!("expected inner List, got {:?}", other),
    };
    if !matches!(innermost.kind, SpannedTypeExprKind::Int) {
        panic!("expected Int, got {:?}", innermost.kind);
    }

    // The outer and inner List levels must not share a span.
    assert_ne!(
        annotation.span, middle.span,
        "nested container levels should have distinct spans for precise diagnostics"
    );
}

Expand Down
34 changes: 12 additions & 22 deletions baml_language/crates/baml_compiler2_ast/src/lower_cst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ use crate::{
ast::{
BuiltinKind, ClientDef, ConfigItemDef, EnumDef, FieldDef, FunctionBodyDef, FunctionDef,
GeneratorDef, Interpolation, Item, LlmBodyDef, Param, RawAttribute, RawAttributeArg,
RawPrompt, RetryPolicyDef, SpannedTypeExpr, TemplateStringDef, TestDef, TypeAliasDef,
VariantDef,
RawPrompt, RetryPolicyDef, TemplateStringDef, TestDef, TypeAliasDef, VariantDef,
},
lower_expr_body, lower_type_expr,
};
Expand Down Expand Up @@ -102,18 +101,14 @@ fn lower_function(node: &SyntaxNode) -> Option<FunctionDef> {
.map(|pl| lower_params(&pl))
.unwrap_or_default();

let return_type = func.return_type().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
});
let return_type = func
.return_type()
.map(|te| lower_type_expr::lower_type_expr_node(&te));

let throws = func
.throws_clause()
.and_then(|tc| tc.type_expr())
.map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
});
.map(|te| lower_type_expr::lower_type_expr_node(&te));

let body = if let Some(llm) = func.llm_body() {
Some(FunctionBodyDef::Llm(lower_llm_body(&llm)))
Expand Down Expand Up @@ -182,10 +177,9 @@ fn lower_param(param: &ast::Parameter) -> Option<Param> {
let name_token = param.name()?;
Some(Param {
name: Name::new(name_token.text()),
type_expr: param.ty().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
}),
type_expr: param
.ty()
.map(|te| lower_type_expr::lower_type_expr_node(&te)),
span: param.syntax().text_range(),
name_span: name_token.text_range(),
})
Expand Down Expand Up @@ -268,10 +262,7 @@ fn lower_class(node: &SyntaxNode) -> Option<crate::ast::ClassDef> {
let fname = f.name()?;
Some(FieldDef {
name: Name::new(fname.text()),
type_expr: f.ty().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
}),
type_expr: f.ty().map(|te| lower_type_expr::lower_type_expr_node(&te)),
attributes: lower_field_attributes(&f),
span: f.syntax().text_range(),
name_span: fname.text_range(),
Expand Down Expand Up @@ -357,10 +348,9 @@ fn lower_type_alias(node: &SyntaxNode) -> Option<TypeAliasDef> {

Some(TypeAliasDef {
name: Name::new(name_token.text()),
type_expr: alias.ty().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
}),
type_expr: alias
.ty()
.map(|te| lower_type_expr::lower_type_expr_node(&te)),
span: node.text_range(),
name_span: name_token.text_range(),
})
Expand Down
34 changes: 19 additions & 15 deletions baml_language/crates/baml_compiler2_ast/src/lower_expr_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,10 @@ impl LoweringContext {
id
}

fn alloc_type_annot(&mut self, ty: TypeExpr, range: TextRange) -> TypeAnnotId {
let id = self.type_annotations.alloc(ty);
self.source_map.type_annotation_spans.alloc(range);
/// Record a type annotation and return its id.
///
/// Stores the span-free `TypeExpr` in `type_annotations`, the top-level span
/// in `source_map.type_annotation_spans`, and the whole spanned tree in
/// `source_map.type_annotation_spanned_exprs`, one entry each per call.
fn alloc_type_annot(&mut self, spanned: crate::ast::SpannedTypeExpr) -> TypeAnnotId {
    // Take what is needed from `spanned` before it is moved into the source map.
    let plain = spanned.to_type_expr();
    let range = spanned.span;

    let annot_id = self.type_annotations.alloc(plain);
    self.source_map.type_annotation_spans.alloc(range);
    self.source_map.type_annotation_spanned_exprs.push(spanned);
    annot_id
}

Expand Down Expand Up @@ -585,9 +586,8 @@ impl LoweringContext {
if let Some(type_expr) =
baml_compiler_syntax::ast::TypeExpr::cast(child.clone())
{
let span = child.text_range();
let ty = crate::lower_type_expr::lower_type_expr_node(&type_expr);
scrutinee_type = Some(self.alloc_type_annot(ty, span));
let spanned = crate::lower_type_expr::lower_type_expr_node(&type_expr);
scrutinee_type = Some(self.alloc_type_annot(spanned));
}
}
_ => {
Expand Down Expand Up @@ -780,11 +780,15 @@ impl LoweringContext {
// After `name:`, we expect the type to be a node child (TYPE_EXPR),
// but sometimes parser emits it as a WORD token directly.
// Treat it as a named type.
let pat = Pattern::TypedBinding {
name,
ty: crate::ast::TypeExpr::Path(vec![Name::new(&text)]),
let span = token.text_range();
let ty = crate::ast::SpannedTypeExpr {
kind: crate::ast::SpannedTypeExprKind::Path(vec![Name::new(
&text,
)]),
span,
};
elements.push(self.alloc_pattern(pat, token.text_range()));
let pat = Pattern::TypedBinding { name, ty };
elements.push(self.alloc_pattern(pat, span));
continue;
}

Expand Down Expand Up @@ -892,9 +896,9 @@ impl LoweringContext {
if let Some(type_expr) =
baml_compiler_syntax::ast::TypeExpr::cast(child.clone())
{
let ty =
let spanned =
crate::lower_type_expr::lower_type_expr_node(&type_expr);
let pat = Pattern::TypedBinding { name, ty };
let pat = Pattern::TypedBinding { name, ty: spanned };
elements.push(self.alloc_pattern(pat, child.text_range()));
}
}
Expand Down Expand Up @@ -1747,9 +1751,9 @@ impl LoweringContext {
if let Some(type_expr) =
baml_compiler_syntax::ast::TypeExpr::cast(child.clone())
{
let span = child.text_range();
let ty = crate::lower_type_expr::lower_type_expr_node(&type_expr);
type_annotation = Some(self.alloc_type_annot(ty, span));
let spanned =
crate::lower_type_expr::lower_type_expr_node(&type_expr);
type_annotation = Some(self.alloc_type_annot(spanned));
seen_colon = false;
}
} else if pattern_id.is_none() {
Expand Down
Loading
Loading