Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions engine/baml-compiler/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,14 @@ impl<'g> HirCompiler<'g> {
/// A statement is anything that does not produce a value by itself.
fn compile_statement(&mut self, statement: &thir::Statement<(Span, Option<TypeIR>)>) {
match statement {
thir::Statement::AnnotatedStatement { headers, statement } => {
for header in headers {
self.emit_annotated_block(header);
}
if let Some(statement) = statement {
self.compile_statement(statement);
}
}
thir::Statement::Let { name, value, .. } => {
self.compile_expression(value);
self.track_local(name);
Expand Down Expand Up @@ -1513,6 +1521,29 @@ impl<'g> HirCompiler<'g> {
self.emit(Instruction::LoadConst(const_index));
}

fn emit_annotated_block(&mut self, v: &str) {
self.emit_string_literal(v);
let mut function_name: [u8; 1024] = [0; 1024];
function_name.copy_from_slice(v.as_bytes());
// null terminate the vec in case its too long
function_name[std::cmp::min(v.len(), 1023)] = 0;

let mut block_name: [u8; 1024] = [0; 1024];
block_name.copy_from_slice(v.as_bytes());
// null terminate the vec in case its too long
Comment on lines +1524 to +1533

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Badge Avoid panicking when copying annotation labels

The new emit_annotated_block allocates two 1024‑byte buffers and fills them with function_name.copy_from_slice(v.as_bytes()) and block_name.copy_from_slice(v.as_bytes()). copy_from_slice requires that the source and destination slices have exactly the same length, so any annotation label shorter than 1024 bytes will cause an immediate panic during code generation. Since all the added test annotations are much shorter than 1024, any annotated statement now crashes the compiler before bytecode is produced. Use a bounded copy (e.g. slicing the destination to ..v.len()) instead of copy_from_slice on the whole array.

Useful? React with 👍 / 👎.

block_name[std::cmp::min(v.len(), 1023)] = 0;

self.emit(Instruction::NotifyBlock(
baml_vm::bytecode::BlockNotification {
function_name,
block_name,
level: 1,
block_type: baml_vm::bytecode::BlockNotificationType::Statement,
is_enter: true,
},
));
}

/// Emits a single instruction and returns the index of the instruction.
///
/// The return value is useful when we want to modify an instruction that
Expand Down
6 changes: 6 additions & 0 deletions engine/baml-compiler/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ pub enum Statement {
variable: String,
span: Span,
},

/// Annotations that apply to the statement.
AnnotatedStatement {
headers: Vec<String>,
statement: Option<Box<Statement>>,
},
}

#[derive(Clone, Debug)]
Expand Down
12 changes: 12 additions & 0 deletions engine/baml-compiler/src/hir/dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@ impl TypeDocumentRender for TypeIR {
impl Statement {
pub fn to_doc(&self) -> RcDoc<'static, ()> {
match self {
Statement::AnnotatedStatement { headers, statement } => {
let mut doc = RcDoc::text("//#")
.append(RcDoc::intersperse(
headers.iter().map(|h| RcDoc::text(h.clone())),
RcDoc::text(" "),
))
.append(RcDoc::text("#//"));
if let Some(statement) = statement {
doc = doc.append(statement.to_doc().nest(2));
}
doc
}
Statement::Let {
name,
value,
Expand Down
78 changes: 61 additions & 17 deletions engine/baml-compiler/src/hir/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,19 +367,35 @@ impl Block {
}

// Second pass: lower statements, applying watch options to watch specs
let statements: Vec<Statement> = block
let mut statements: Vec<Statement> = block
.stmts
.iter()
.map(|stmt| lower_stmt_with_options(stmt, &watch_options_map))
.collect();

let trailing_expr = block
.expr
.as_deref()
.map(Expression::from_ast)
.map(Box::new);

if !block.expr_headers.is_empty() {
println!(
"Annotated!: {}",
trailing_expr
.as_ref()
.map(|f| f.to_doc().pretty(80).to_string())
.unwrap_or_else(|| "<..>".to_string())
);
statements.push(Statement::AnnotatedStatement {
headers: block.expr_headers.iter().map(|h| h.title.clone()).collect(),
statement: None,
});
}

Block {
statements,
trailing_expr: block
.expr
.as_deref()
.map(Expression::from_ast)
.map(Box::new),
trailing_expr,
}
}
}
Expand All @@ -388,6 +404,24 @@ fn lower_stmt(stmt: &ast::Stmt) -> Statement {
lower_stmt_with_options(stmt, &HashMap::new())
}

/// Wraps `stmt` in a `Statement::AnnotatedStatement` when header annotations
/// are attached to it.
///
/// Returns `stmt` unchanged when `annotated_comments` is empty. Otherwise the
/// result carries each header's `title`, in order, and boxes `stmt` as the
/// annotated statement.
///
/// Lowering performs no output: the earlier debug `println!` was removed so
/// compiling annotated code no longer writes to stdout.
fn maybe_annotated_statement(
    stmt: Statement,
    annotated_comments: &[std::sync::Arc<ast::Header>],
) -> Statement {
    if annotated_comments.is_empty() {
        return stmt;
    }

    Statement::AnnotatedStatement {
        headers: annotated_comments
            .iter()
            .map(|header| header.title.to_string())
            .collect(),
        statement: Some(Box::new(stmt)),
    }
}

fn lower_stmt_with_options(
stmt: &ast::Stmt,
watch_options: &HashMap<String, (Option<String>, Option<String>)>,
Expand Down Expand Up @@ -489,7 +523,7 @@ fn lower_stmt_with_options(
annotation,
expr,
span,
annotations: _,
annotations: annotated_comments,
is_watched,
}) => {
let lifted_expr = Expression::from_ast(expr);
Expand All @@ -504,7 +538,7 @@ fn lower_stmt_with_options(
None
};

if *is_mutable {
let statement = if *is_mutable {
Statement::DeclareAndAssign {
name: identifier.to_string(),
value: lifted_expr,
Expand All @@ -520,31 +554,42 @@ fn lower_stmt_with_options(
watch: watch_spec,
span: span.clone(),
}
}
};

maybe_annotated_statement(statement, annotated_comments)
}
ast::Stmt::ForLoop(ast::ForLoopStmt {
identifier,
iterator,
body,
span,
has_let: _,
annotations: _,
annotations: annotated_comments,
}) => {
// Lower for loop to HIR
let lifted_iterator = Expression::from_ast(iterator);

// Add the for loop statement
Statement::ForLoop {
let statement = Statement::ForLoop {
identifier: identifier.name().to_string(),
iterator: Box::new(lifted_iterator),
block: Block::from_expr_block(body),
span: span.clone(),
}
};

maybe_annotated_statement(statement, &annotated_comments)
}
ast::Stmt::Expression(ast::ExprStmt {
expr,
span,
annotations: annotated_comments,
}) => {
let statement = Statement::Expression {
expr: Expression::from_ast(expr),
span: span.clone(),
};
maybe_annotated_statement(statement, &annotated_comments)
}
ast::Stmt::Expression(expr) => Statement::Expression {
expr: Expression::from_ast(&expr.expr),
span: expr.span.clone(),
},
ast::Stmt::Semicolon(expr) => Statement::Semicolon {
expr: Expression::from_ast(expr),
span: expr.span().clone(),
Expand Down Expand Up @@ -653,7 +698,6 @@ impl Expression {
type_args,
args,
span,
..
}) => {
// Note: AST function calls are always just names next to argument lists.
// Later, we will be able to call any expression that is a function.
Expand Down
23 changes: 23 additions & 0 deletions engine/baml-compiler/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,12 @@ pub enum Statement<T> {
variable: String,
span: Span,
},

/// Annotations that apply to the statement.
AnnotatedStatement {
headers: Vec<String>,
statement: Option<Box<Statement<T>>>,
},
}

impl<T: Clone> Statement<T> {
Expand All @@ -728,6 +734,19 @@ impl<T: Clone> Statement<T> {
T: std::fmt::Debug,
{
match self {
Statement::AnnotatedStatement { headers, statement } => {
let headers_str =
headers
.iter()
.map(|h| format!("//# {h}"))
.chain(std::iter::once(
statement
.as_ref()
.map(|s| s.dump_str())
.unwrap_or_else(String::new),
));
join(headers_str, "\n")
}
Statement::Let {
name,
value,
Expand Down Expand Up @@ -838,6 +857,10 @@ impl<T: Clone> Statement<T> {
T: Clone,
{
match self {
Statement::AnnotatedStatement { statement, .. } => statement
.as_ref()
.map(|s| s.variables())
.unwrap_or_else(HashSet::new),
Statement::Declare { .. } | Statement::Break(_) | Statement::Continue(_) => {
HashSet::new()
}
Expand Down
Loading