Skip to content

Commit 716b7b1

Browse files
authored
fix literal return values (#2663)
Fixes several typechecker subsumption bugs
1 parent 1a4bcc5 commit 716b7b1

File tree

4 files changed

+30
-21
lines changed

4 files changed

+30
-21
lines changed

engine/baml-compiler/src/thir/typecheck.rs

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ pub fn typecheck_returning_context<'a>(
316316
.as_ref()
317317
.and_then(|e| Some((e, e.meta().1.as_ref()?)))
318318
{
319-
if !types_compatible(expr_return_type, &func.return_type) {
319+
if !expr_return_type.is_subtype(&func.return_type) {
320320
diagnostics.push_error(DatamodelError::new_validation_error(
321321
&format!(
322322
"Return type mismatch: function return type is {} but got {}",
@@ -1681,7 +1681,7 @@ pub fn typecheck_expression(
16811681

16821682
// Check if argument type matches expected type
16831683
if let Some(arg_type) = typed_arg.meta().1.as_ref() {
1684-
if !types_compatible(arg_type, expected_type) {
1684+
if !arg_type.is_subtype(expected_type) {
16851685
diagnostics.push_error(DatamodelError::new_validation_error(
16861686
"Type mismatch in argument",
16871687
arg.span(),
@@ -2238,7 +2238,7 @@ pub fn typecheck_expression(
22382238
}
22392239
}
22402240
_ => {
2241-
if !types_compatible(arg_type, expected_type) {
2241+
if !arg_type.is_subtype(expected_type) {
22422242
diagnostics.push_error(
22432243
DatamodelError::new_validation_error(
22442244
&format!(
@@ -3031,22 +3031,6 @@ fn typecheck_emit(
30313031
}
30323032
}
30333033

3034-
/// Check if two types are compatible (for now, just equality)
3035-
fn types_compatible(actual: &TypeIR, expected: &TypeIR) -> bool {
3036-
match (actual, expected) {
3037-
(TypeIR::Top(_), _) | (_, TypeIR::Top(_)) => true,
3038-
(TypeIR::Primitive(a, _), TypeIR::Primitive(b, _)) => a == b,
3039-
(TypeIR::List(a, _), TypeIR::List(b, _)) => types_compatible(a, b),
3040-
(TypeIR::Map(k1, v1, _), TypeIR::Map(k2, v2, _)) => {
3041-
types_compatible(k1, k2) && types_compatible(v1, v2)
3042-
}
3043-
(TypeIR::Class { name: a, .. }, TypeIR::Class { name: b, .. }) => a == b,
3044-
(TypeIR::Enum { name: a, .. }, TypeIR::Enum { name: b, .. }) => a == b,
3045-
// TODO: Handle union types, subtyping, etc.
3046-
_ => false,
3047-
}
3048-
}
3049-
30503034
pub trait TypeCompatibility {
30513035
fn is_optional(&self) -> bool;
30523036
fn is_subtype(&self, expected: &TypeIR) -> bool;
@@ -3069,6 +3053,8 @@ impl TypeCompatibility for TypeIR {
30693053
}
30703054

30713055
/// Return true if `self` is a subtype of `expected`.
3056+
/// TODO: Remove wildcard match
3057+
/// TODO: This needs to account for type aliases.
30723058
fn is_subtype(&self, expected: &TypeIR) -> bool {
30733059
// Semantics similar to IR's `IntermediateRepr::is_subtype`:
30743060
// - Unions on the right: self <: (e1 | e2 | ...) if exists ei s.t. self <: ei
@@ -3097,6 +3083,10 @@ impl TypeCompatibility for TypeIR {
30973083
TypeIR::Primitive(baml_types::TypeValue::Null, _),
30983084
TypeIR::Primitive(baml_types::TypeValue::Null, _),
30993085
) => true,
3086+
(
3087+
TypeIR::Primitive(baml_types::TypeValue::Media(x), _),
3088+
TypeIR::Primitive(baml_types::TypeValue::Media(y), _),
3089+
) => x == y,
31003090

31013091
// Arrays: covariant element
31023092
(TypeIR::List(a_item, _), TypeIR::List(e_item, _)) => a_item.is_subtype(e_item),

engine/baml-compiler/tests/executor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ fn run_test(test: &InterpreterTest) -> Result<()> {
242242
/// Main test function that discovers and runs all interpreter tests
243243
#[test]
244244
fn executor_tests() -> Result<()> {
245-
let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/interpreter_tests");
245+
let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/executor_tests");
246246

247247
if !test_dir.exists() {
248248
// No tests yet, skip

engine/baml-runtime/tests/executor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ fn update_test_expectation(test: &InterpreterTest, new_output: &str) -> Result<(
340340
/// Main test function that discovers and runs all interpreter tests
341341
#[tokio::test]
342342
async fn executor_tests() -> Result<()> {
343-
let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/interpreter_tests");
343+
let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/executor_tests");
344344

345345
if !test_dir.exists() {
346346
// No tests yet, skip
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class Bar {
2+
bar_inner int
3+
}
4+
5+
class Baz {
6+
baz_inner string
7+
}
8+
9+
function Foo() -> (Bar | Baz)[] {
10+
[Bar{ bar_inner: 1}]
11+
}
12+
13+
//>Foo()
14+
//
15+
// [
16+
// {
17+
// "bar_inner": 1
18+
// }
19+
// ]

0 commit comments

Comments
 (0)