diff --git a/crates/dbt-jinja/minijinja-contrib/src/modules/py_datetime/datetime.rs b/crates/dbt-jinja/minijinja-contrib/src/modules/py_datetime/datetime.rs index 0c6086f3..71e43976 100644 --- a/crates/dbt-jinja/minijinja-contrib/src/modules/py_datetime/datetime.rs +++ b/crates/dbt-jinja/minijinja-contrib/src/modules/py_datetime/datetime.rs @@ -1,10 +1,11 @@ use std::fmt; use std::str::FromStr; use std::sync::Arc; +use std::cmp::Ordering; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc}; use chrono_tz::Tz; -use minijinja::{arg_utils::ArgParser, value::Object, Error, ErrorKind, Value}; +use minijinja::{arg_utils::ArgParser, value::{Object, ObjectRepr}, Error, ErrorKind, Value}; use crate::modules::py_datetime::date::PyDate; // your date use crate::modules::py_datetime::time::PyTime; @@ -824,10 +825,24 @@ impl PyDateTime { // Implement the `Object` trait for PyDateTime so Jinja can call methods // impl Object for PyDateTime { + fn repr(self: &Arc) -> ObjectRepr { + ObjectRepr::Plain + } + fn is_true(self: &Arc) -> bool { true } + fn custom_cmp(self: &Arc, other: &minijinja::value::DynObject) -> Option { + // try to downcast the other object to PyDateTime + if let Some(other_dt) = other.downcast_ref::() { + // compare using timestamps, wrap in Some() to satisfy the trait bounds + Some(self.timestamp().total_cmp(&other_dt.timestamp())) + } else { + None + } + } + fn call_method( self: &Arc, _state: &minijinja::State<'_, '_>, diff --git a/crates/dbt-jinja/minijinja-contrib/tests/datetime.rs b/crates/dbt-jinja/minijinja-contrib/tests/datetime.rs index da3e540e..96d19285 100644 --- a/crates/dbt-jinja/minijinja-contrib/tests/datetime.rs +++ b/crates/dbt-jinja/minijinja-contrib/tests/datetime.rs @@ -285,3 +285,199 @@ fn test_timeformat() { .unwrap(); assert_eq!(expr.eval((), &[]).unwrap().to_string(), "19:37"); } + +#[test] +fn test_datetime_direct_comparison() { + let mut env = minijinja::Environment::new(); + minijinja_contrib::add_to_environment(&mut env); + + // Test direct comparison operators now work! + let tmpl = env.template_from_str( + r#" + {%- set dt1 = modules.datetime.datetime(2025, 1, 1, 12, 0, 0) -%} + {%- set dt2 = modules.datetime.datetime(2025, 1, 2, 12, 0, 0) -%} + {%- set dt3 = modules.datetime.datetime(2025, 1, 1, 12, 0, 0) -%} + dt1 < dt2: {{ dt1 < dt2 }} + dt2 > dt1: {{ dt2 > dt1 }} + dt1 <= dt2: {{ dt1 <= dt2 }} + dt2 >= dt1: {{ dt2 >= dt1 }} + dt1 == dt3: {{ dt1 == dt3 }} + dt1 != dt2: {{ dt1 != dt2 }} + dt1 == dt1: {{ dt1 == dt1 }} + "#, + &[] + ).unwrap(); + + let output = tmpl.render(context!{}, &[]).unwrap(); + assert!(output.contains("dt1 < dt2: true")); + assert!(output.contains("dt2 > dt1: true")); + assert!(output.contains("dt1 <= dt2: true")); + assert!(output.contains("dt2 >= dt1: true")); + assert!(output.contains("dt1 == dt3: true")); + assert!(output.contains("dt1 != dt2: true")); + assert!(output.contains("dt1 == dt1: true")); +} + +#[test] +fn test_datetime_direct_comparison_timezone() { + let mut env = minijinja::Environment::new(); + minijinja_contrib::add_to_environment(&mut env); + + // Test timezone-aware comparisons with direct operators + let tmpl = env.template_from_str( + r#" + {%- set utc = modules.pytz.timezone('UTC') -%} + {%- set eastern = modules.pytz.timezone('US/Eastern') -%} + {%- set dt1 = modules.datetime.datetime(2025, 1, 1, 12, 0, 0, tzinfo=utc) -%} + {%- set dt2 = modules.datetime.datetime(2025, 1, 1, 7, 0, 0, tzinfo=eastern) -%} + {%- set dt3 = modules.datetime.datetime(2025, 1, 1, 8, 0, 0, tzinfo=eastern) -%} + same_instant: {{ dt1 == dt2 }} + dt1_before_dt3: {{ dt1 < dt3 }} + dt3_after_dt1: {{ dt3 > dt1 }} + "#, + &[] + ).unwrap(); + + let output = tmpl.render(context!{}, &[]).unwrap(); + assert!(output.contains("same_instant: true")); + assert!(output.contains("dt1_before_dt3: true")); + assert!(output.contains("dt3_after_dt1: true")); +} + +#[test] +fn test_datetime_comparison_with_incompatible_types() { + let mut env = minijinja::Environment::new(); + minijinja_contrib::add_to_environment(&mut env); + + // Test that comparing datetime with non-datetime types returns false + let tmpl = env.template_from_str( + r#" + {%- set dt1 = modules.datetime.datetime(2025, 1, 1, 12, 0, 0) -%} + {%- set num = 42 -%} + {%- set str = "2025-01-01" -%} + {%- set none_val = none -%} + {%- set date_obj = modules.datetime.date(2025, 1, 1) -%} + dt_vs_num: {{ dt1 == num }} + dt_vs_str: {{ dt1 == str }} + dt_vs_none: {{ dt1 == none_val }} + dt_vs_date: {{ dt1 == date_obj }} + "#, + &[] + ).unwrap(); + + let output = tmpl.render(context!{}, &[]).unwrap(); + assert!(output.contains("dt_vs_num: false")); + assert!(output.contains("dt_vs_str: false")); + assert!(output.contains("dt_vs_none: false")); + assert!(output.contains("dt_vs_date: false")); +} + +#[test] +fn test_datetime_invalid_construction_errors() { + let mut env = minijinja::Environment::new(); + minijinja_contrib::add_to_environment(&mut env); + + // Test invalid month + let tmpl = env.template_from_str( + r#" + {%- set dt = modules.datetime.datetime(2025, 13, 1) -%} + "#, + &[] + ).unwrap(); + + let result = tmpl.render(context!{}, &[]); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Invalid date components")); + + // Test invalid day + let tmpl = env.template_from_str( + r#" + {%- set dt = modules.datetime.datetime(2025, 2, 30) -%} + "#, + &[] + ).unwrap(); + + let result = tmpl.render(context!{}, &[]); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Invalid date components")); + + // Test invalid hour + let tmpl = env.template_from_str( + r#" + {%- set dt = modules.datetime.datetime(2025, 1, 1, 25) -%} + "#, + &[] + ).unwrap(); + + let result = tmpl.render(context!{}, &[]); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Invalid time components")); + + // Test invalid minute + let tmpl = env.template_from_str( + r#" + {%- set dt = modules.datetime.datetime(2025, 1, 1, 12, 60) -%} + "#, + &[] + ).unwrap(); + + let result = tmpl.render(context!{}, &[]); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Invalid time components")); +} + +#[test] +fn test_datetime_comparison_edge_cases() { + let mut env = minijinja::Environment::new(); + minijinja_contrib::add_to_environment(&mut env); + + // Test edge cases like same datetime, microsecond differences, etc. + let tmpl = env.template_from_str( + r#" + {%- set dt1 = modules.datetime.datetime(2025, 1, 1, 12, 0, 0, 0) -%} + {%- set dt2 = modules.datetime.datetime(2025, 1, 1, 12, 0, 0, 0) -%} + {%- set dt3 = modules.datetime.datetime(2025, 1, 1, 12, 0, 0, 1) -%} + same_datetime: {{ dt1 == dt2 }} + same_datetime_le: {{ dt1 <= dt2 }} + same_datetime_ge: {{ dt1 >= dt2 }} + microsecond_diff: {{ dt1 < dt3 }} + microsecond_ne: {{ dt1 != dt3 }} + "#, + &[] + ).unwrap(); + + let output = tmpl.render(context!{}, &[]).unwrap(); + assert!(output.contains("same_datetime: true")); + assert!(output.contains("same_datetime_le: true")); + assert!(output.contains("same_datetime_ge: true")); + assert!(output.contains("microsecond_diff: true")); + assert!(output.contains("microsecond_ne: true")); +} + +#[test] +fn test_datetime_sorting() { + let mut env = minijinja::Environment::new(); + minijinja_contrib::add_to_environment(&mut env); + + // Test that datetimes can be sorted correctly + let tmpl = env.template_from_str( + r#" + {%- set dates = [ + modules.datetime.datetime(2025, 3, 15), + modules.datetime.datetime(2025, 1, 1), + modules.datetime.datetime(2025, 12, 31), + modules.datetime.datetime(2025, 6, 15), + modules.datetime.datetime(2025, 1, 1) + ] -%} + {%- set sorted_dates = dates | sort %} + {%- for date in sorted_dates %} + {{ date.strftime("%Y-%m-%d") }} + {%- endfor %} + "#, + &[] + ).unwrap(); + + let output = tmpl.render(context!{}, &[]).unwrap(); + let dates: Vec<&str> = output.trim().split('\n').map(|s| s.trim()).collect(); + assert_eq!(dates, vec!["2025-01-01", "2025-01-01", "2025-03-15", "2025-06-15", "2025-12-31"]); +} diff --git a/crates/dbt-jinja/minijinja-py/tests/test_basic.py b/crates/dbt-jinja/minijinja-py/tests/test_basic.py index 9bd6163c..cd01b324 100644 --- a/crates/dbt-jinja/minijinja-py/tests/test_basic.py +++ b/crates/dbt-jinja/minijinja-py/tests/test_basic.py @@ -2,6 +2,7 @@ import pytest import posixpath import types +from functools import total_ordering from _pytest.unraisableexception import catch_unraisable_exception from minijinja import ( @@ -349,3 +350,27 @@ def test_custom_delimiters(): ) rv = env.render_str("<% if true %>${ value }<% endif %>", value=42) assert rv == "42" + +def test_pass_through_sort(): + @total_ordering + class X(object): + def __init__(self, value): + self.value = value + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return self.value == other.value + + def __lt__(self, other): + if type(self) is not type(other): + return NotImplemented + return self.value < other.value + + def __str__(self): + return str(self.value) + + values = [X(4), X(23), X(42), X(-1)] + env = Environment() + rv = env.render_str("{{ values|sort|join(',') }}", values=values) + assert rv == "-1,4,23,42" \ No newline at end of file diff --git a/crates/dbt-jinja/minijinja/src/value/mod.rs b/crates/dbt-jinja/minijinja/src/value/mod.rs index 8d46716b..5e2faa2b 100644 --- a/crates/dbt-jinja/minijinja/src/value/mod.rs +++ b/crates/dbt-jinja/minijinja/src/value/mod.rs @@ -542,6 +542,10 @@ impl PartialEq for Value { if let (Some(a), Some(b)) = (self.as_object(), other.as_object()) { if a.is_same_object(b) { return true; + } else if a.is_same_object_type(b) { + if let Some(rv) = a.custom_cmp(b) { + return rv == Ordering::Equal; + } } match (a.repr(), b.repr()) { (ObjectRepr::Map, ObjectRepr::Map) => { @@ -629,17 +633,26 @@ impl Ord for Value { Some(ops::CoerceResult::I128(a, b)) => a.cmp(&b), Some(ops::CoerceResult::Str(a, b)) => a.cmp(b), None => { - if let (Some(a), Some(b)) = (self.as_object(), other.as_object()) { - if a.is_same_object(b) { - Ordering::Equal - } else { - match (a.repr(), b.repr()) { - (ObjectRepr::Map, ObjectRepr::Map) => { - // This is not really correct. Because the keys can be in arbitrary - // order this could just sort really weirdly as a result. However - // we don't want to pay the cost of actually sorting the keys for - // ordering so we just accept this for now. - match (a.try_iter_pairs(), b.try_iter_pairs()) { + // if coerce fails and kinds match, both must be objects + let a = self.as_object().unwrap(); + let b = other.as_object().unwrap(); + + if a.is_same_object(b) { + Ordering::Equal + } else { + // if there is a custom comparison, run it. + if a.is_same_object_type(b) { + if let Some(rv) = a.custom_cmp(b) { + return rv; + } + } + match (a.repr(), b.repr()) { + (ObjectRepr::Map, ObjectRepr::Map) => { + // This is not really correct. Because the keys can be in arbitrary + // order this could just sort really weirdly as a result. However + // we don't want to pay the cost of actually sorting the keys for + // ordering so we just accept this for now. + match (a.try_iter_pairs(), b.try_iter_pairs()) { (Some(a), Some(b)) => a.cmp(b), _ => unreachable!(), } @@ -651,11 +664,13 @@ impl Ord for Value { (Some(a), Some(b)) => a.cmp(b), _ => unreachable!(), }, + // terrible fallback for plain objects + (ObjectRepr::Plain, ObjectRepr::Plain) => { + a.to_string().cmp(&b.to_string()) + } + // should not happen (_, _) => unreachable!(), - } } - } else { - unreachable!() } } }, diff --git a/crates/dbt-jinja/minijinja/src/value/object.rs b/crates/dbt-jinja/minijinja/src/value/object.rs index d89cedeb..fbc316fd 100644 --- a/crates/dbt-jinja/minijinja/src/value/object.rs +++ b/crates/dbt-jinja/minijinja/src/value/object.rs @@ -1,9 +1,11 @@ +use std::any::Any; use std::borrow::Cow; use std::collections::BTreeMap; use std::fmt; use std::hash::Hash; use std::rc::Rc; use std::sync::Arc; +use std::cmp::Ordering; use crate::error::{Error, ErrorKind}; use crate::listener::RenderingEventListener; @@ -265,6 +267,35 @@ pub trait Object: fmt::Debug + Send + Sync { Err(Error::from(ErrorKind::UnknownMethod)) } + /// Custom comparison of this object against another object of the same type. + /// + /// This must return either `None` or `Some(Ordering)`. When implemented this + /// must guarantee a total ordering as otherwise sort functions will crash. + /// This will only compare against other objects of the same type, not + /// anything else. Objects of different types are given an absolute + /// ordering outside the scope of this method. + /// + /// The requirement is that an implementer downcasts the other [`DynObject`] + /// to itself, and it that cannot be accomplished `None` must be returned. + /// + /// ```rust + /// # use std::sync::Arc; + /// # use std::cmp::Ordering; + /// # use minijinja::value::{DynObject, Object}; + /// # #[derive(Debug)] + /// # struct Thing { num: u32 }; + /// impl Object for Thing { + /// fn custom_cmp(self: &Arc, other: &DynObject) -> Option { + /// let other = other.downcast_ref::()?; + /// Some(self.num.cmp(&other.num)) + /// } + /// } + /// ``` + fn custom_cmp(self: &Arc, other: &DynObject) -> Option { + let _ = other; + None + } + /// Formats the object for stringification. /// /// The default implementation is specific to the behavior of @@ -681,6 +712,8 @@ type_erase! { listeners: &[Rc] ) -> Result; + fn custom_cmp(&self, other: &DynObject) -> Option; + fn render(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result; impl fmt::Debug { @@ -699,6 +732,11 @@ impl DynObject { pub(crate) fn is_same_object(&self, other: &DynObject) -> bool { self.ptr == other.ptr && self.vtable == other.vtable } + + /// Checks if the two dyn objects are of the same type. + pub(crate) fn is_same_object_type(&self, other: &DynObject) -> bool { + self.type_id() == other.type_id() + } } impl Hash for DynObject { diff --git a/crates/dbt-jinja/minijinja/tests/test_value.rs b/crates/dbt-jinja/minijinja/tests/test_value.rs index e5f991d8..f34b70f0 100644 --- a/crates/dbt-jinja/minijinja/tests/test_value.rs +++ b/crates/dbt-jinja/minijinja/tests/test_value.rs @@ -1,5 +1,7 @@ +use std::cmp::Ordering; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, LinkedList, VecDeque}; use std::sync::Arc; +use std::fmt; use insta::{assert_debug_snapshot, assert_snapshot}; use similar_asserts::assert_eq; @@ -1458,3 +1460,37 @@ fn test_bytes() { "b'\\'foo\"'" ); } + +#[test] +fn test_custom_object_compare() { + #[derive(Debug)] + struct X(i32); + + impl Object for X { + fn repr(self: &Arc) -> ObjectRepr { + ObjectRepr::Plain + } + + fn custom_cmp(self: &Arc, other: &DynObject) -> Option { + let other = other.downcast_ref::()?; + Some(self.0.cmp(&other.0)) + } + + fn render(self: &Arc, f: &mut fmt::Formatter<'_>) -> fmt::Result + where + Self: Sized + 'static, + { + write!(f, "{}", self.0) + } + } + + let nums = (0..5) + .map(X) + .map(Value::from_object) + .rev() + .collect::>(); + let seq = Value::from_object(nums); + + let rv = render!("{{ seq|sort|join('|') }}", seq); + assert_eq!(rv, "0|1|2|3|4"); +}