Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion baml_language/crates/baml_compiler2_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,98 @@ pub struct FunctionTypeParam {
pub ty: TypeExpr,
}

/// Recursive spanned type expression kind — mirrors `TypeExpr` but children
/// are `SpannedTypeExpr` so every sub-expression carries its own source span.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SpannedTypeExprKind {
Path(Vec<Name>),
Int,
Float,
String,
Bool,
Null,
Never,
Media(baml_base::MediaKind),
Optional(Box<SpannedTypeExpr>),
List(Box<SpannedTypeExpr>),
Map {
key: Box<SpannedTypeExpr>,
value: Box<SpannedTypeExpr>,
},
Union(Vec<SpannedTypeExpr>),
Literal(baml_base::Literal),
Function {
params: Vec<SpannedFunctionTypeParam>,
ret: Box<SpannedTypeExpr>,
},
BuiltinUnknown,
Type,
Rust,
Error,
Unknown,
}

/// A parameter in a spanned function type expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpannedFunctionTypeParam {
pub name: Option<Name>,
pub ty: SpannedTypeExpr,
}

/// A type expression with its source span — used in item definitions
/// where we need both the type data and the source location.
///
/// Recursive: children are also `SpannedTypeExpr`, so every sub-expression
/// (e.g. each union member) carries its own span for precise diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpannedTypeExpr {
pub expr: TypeExpr,
pub kind: SpannedTypeExprKind,
pub span: TextRange,
}

impl SpannedTypeExpr {
/// Recursively strip spans to produce a span-free `TypeExpr` for Salsa queries.
pub fn to_type_expr(&self) -> TypeExpr {
match &self.kind {
SpannedTypeExprKind::Path(segments) => TypeExpr::Path(segments.clone()),
SpannedTypeExprKind::Int => TypeExpr::Int,
SpannedTypeExprKind::Float => TypeExpr::Float,
SpannedTypeExprKind::String => TypeExpr::String,
SpannedTypeExprKind::Bool => TypeExpr::Bool,
SpannedTypeExprKind::Null => TypeExpr::Null,
SpannedTypeExprKind::Never => TypeExpr::Never,
SpannedTypeExprKind::Media(kind) => TypeExpr::Media(*kind),
SpannedTypeExprKind::Optional(inner) => {
TypeExpr::Optional(Box::new(inner.to_type_expr()))
}
SpannedTypeExprKind::List(inner) => TypeExpr::List(Box::new(inner.to_type_expr())),
SpannedTypeExprKind::Map { key, value } => TypeExpr::Map {
key: Box::new(key.to_type_expr()),
value: Box::new(value.to_type_expr()),
},
SpannedTypeExprKind::Union(members) => {
TypeExpr::Union(members.iter().map(SpannedTypeExpr::to_type_expr).collect())
}
SpannedTypeExprKind::Literal(lit) => TypeExpr::Literal(lit.clone()),
SpannedTypeExprKind::Function { params, ret } => TypeExpr::Function {
params: params
.iter()
.map(|p| FunctionTypeParam {
name: p.name.clone(),
ty: p.ty.to_type_expr(),
})
.collect(),
ret: Box::new(ret.to_type_expr()),
},
SpannedTypeExprKind::BuiltinUnknown => TypeExpr::BuiltinUnknown,
SpannedTypeExprKind::Type => TypeExpr::Type,
SpannedTypeExprKind::Rust => TypeExpr::Rust,
SpannedTypeExprKind::Error => TypeExpr::Error,
SpannedTypeExprKind::Unknown => TypeExpr::Unknown,
}
}
}

// ── Expression Bodies ───────────────────────────────────────────
//
// Full expression/statement arena — modeled after the existing
Expand Down
14 changes: 9 additions & 5 deletions baml_language/crates/baml_compiler2_ast/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ class Media {
.expect("expected _data field");

match &field.type_expr {
Some(spanned) => match &spanned.expr {
Some(spanned) => match spanned.to_type_expr() {
TypeExpr::Rust => {}
other => panic!("expected TypeExpr::Rust, got {other:?}"),
},
Expand Down Expand Up @@ -446,7 +446,11 @@ class Media {
assert!(data_field.is_some(), "expected _data field");
assert!(
matches!(
data_field.unwrap().type_expr.as_ref().map(|te| &te.expr),
data_field
.unwrap()
.type_expr
.as_ref()
.map(super::ast::SpannedTypeExpr::to_type_expr),
Some(TypeExpr::Rust)
),
"_data field should have TypeExpr::Rust"
Expand All @@ -467,10 +471,10 @@ function f() -> int throws never {
let throws = func
.throws
.expect("expected throws clause to be lowered into FunctionDef.throws");
let throws_te = throws.to_type_expr();
assert!(
matches!(throws.expr, TypeExpr::Never),
"expected throws type to lower as TypeExpr::Never, got {:?}",
throws.expr
matches!(throws_te, TypeExpr::Never),
"expected throws type to lower as TypeExpr::Never, got {throws_te:?}"
);
}

Expand Down
34 changes: 12 additions & 22 deletions baml_language/crates/baml_compiler2_ast/src/lower_cst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ use crate::{
ast::{
BuiltinKind, ClientDef, ConfigItemDef, EnumDef, FieldDef, FunctionBodyDef, FunctionDef,
GeneratorDef, Interpolation, Item, LlmBodyDef, Param, RawAttribute, RawAttributeArg,
RawPrompt, RetryPolicyDef, SpannedTypeExpr, TemplateStringDef, TestDef, TypeAliasDef,
VariantDef,
RawPrompt, RetryPolicyDef, TemplateStringDef, TestDef, TypeAliasDef, VariantDef,
},
lower_expr_body, lower_type_expr,
};
Expand Down Expand Up @@ -102,18 +101,14 @@ fn lower_function(node: &SyntaxNode) -> Option<FunctionDef> {
.map(|pl| lower_params(&pl))
.unwrap_or_default();

let return_type = func.return_type().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
});
let return_type = func
.return_type()
.map(|te| lower_type_expr::lower_type_expr_node(&te));

let throws = func
.throws_clause()
.and_then(|tc| tc.type_expr())
.map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
});
.map(|te| lower_type_expr::lower_type_expr_node(&te));

let body = if let Some(llm) = func.llm_body() {
Some(FunctionBodyDef::Llm(lower_llm_body(&llm)))
Expand Down Expand Up @@ -182,10 +177,9 @@ fn lower_param(param: &ast::Parameter) -> Option<Param> {
let name_token = param.name()?;
Some(Param {
name: Name::new(name_token.text()),
type_expr: param.ty().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
}),
type_expr: param
.ty()
.map(|te| lower_type_expr::lower_type_expr_node(&te)),
span: param.syntax().text_range(),
name_span: name_token.text_range(),
})
Expand Down Expand Up @@ -268,10 +262,7 @@ fn lower_class(node: &SyntaxNode) -> Option<crate::ast::ClassDef> {
let fname = f.name()?;
Some(FieldDef {
name: Name::new(fname.text()),
type_expr: f.ty().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
}),
type_expr: f.ty().map(|te| lower_type_expr::lower_type_expr_node(&te)),
attributes: lower_field_attributes(&f),
span: f.syntax().text_range(),
name_span: fname.text_range(),
Expand Down Expand Up @@ -357,10 +348,9 @@ fn lower_type_alias(node: &SyntaxNode) -> Option<TypeAliasDef> {

Some(TypeAliasDef {
name: Name::new(name_token.text()),
type_expr: alias.ty().map(|te| SpannedTypeExpr {
expr: lower_type_expr::lower_type_expr_node(&te),
span: te.syntax().text_range(),
}),
type_expr: alias
.ty()
.map(|te| lower_type_expr::lower_type_expr_node(&te)),
span: node.text_range(),
name_span: name_token.text_range(),
})
Expand Down
18 changes: 12 additions & 6 deletions baml_language/crates/baml_compiler2_ast/src/lower_expr_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,9 @@ impl LoweringContext {
baml_compiler_syntax::ast::TypeExpr::cast(child.clone())
{
let span = child.text_range();
let ty = crate::lower_type_expr::lower_type_expr_node(&type_expr);
scrutinee_type = Some(self.alloc_type_annot(ty, span));
let spanned = crate::lower_type_expr::lower_type_expr_node(&type_expr);
scrutinee_type =
Some(self.alloc_type_annot(spanned.to_type_expr(), span));
}
}
_ => {
Expand Down Expand Up @@ -892,9 +893,12 @@ impl LoweringContext {
if let Some(type_expr) =
baml_compiler_syntax::ast::TypeExpr::cast(child.clone())
{
let ty =
let spanned =
crate::lower_type_expr::lower_type_expr_node(&type_expr);
let pat = Pattern::TypedBinding { name, ty };
let pat = Pattern::TypedBinding {
name,
ty: spanned.to_type_expr(),
};
elements.push(self.alloc_pattern(pat, child.text_range()));
}
}
Expand Down Expand Up @@ -1748,8 +1752,10 @@ impl LoweringContext {
baml_compiler_syntax::ast::TypeExpr::cast(child.clone())
{
let span = child.text_range();
let ty = crate::lower_type_expr::lower_type_expr_node(&type_expr);
type_annotation = Some(self.alloc_type_annot(ty, span));
let spanned =
crate::lower_type_expr::lower_type_expr_node(&type_expr);
type_annotation =
Some(self.alloc_type_annot(spanned.to_type_expr(), span));
seen_colon = false;
}
} else if pattern_id.is_none() {
Expand Down
Loading
Loading