diff --git a/docs/reference/model_configuration.md b/docs/reference/model_configuration.md index 9d040fe6db..f5dd0eadf0 100644 --- a/docs/reference/model_configuration.md +++ b/docs/reference/model_configuration.md @@ -178,6 +178,7 @@ The SQLMesh project-level `model_defaults` key supports the following options, d - kind - dialect - cron +- cron_tz - owner - start - table_format diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py index aeefdf2557..41a22d128f 100644 --- a/sqlmesh/core/config/model.py +++ b/sqlmesh/core/config/model.py @@ -14,7 +14,7 @@ OnAdditiveChange, ) from sqlmesh.core.model.meta import FunctionCall -from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.node import IntervalUnit, cron_tz_validator from sqlmesh.utils.date import TimeLike from sqlmesh.utils.pydantic import field_validator @@ -27,6 +27,7 @@ class ModelDefaultsConfig(BaseConfig): dialect: The SQL dialect that the model's query is written in. cron: A cron string specifying how often the model should be refreshed, leveraging the [croniter](https://github.com/kiorky/croniter) library. + cron_tz: The timezone for the cron expression, defaults to UTC. [IANA time zones](https://docs.python.org/3/library/zoneinfo.html). owner: The owner of the model. start: The earliest date that the model will be backfilled for. If this is None, then the date is inferred by taking the most recent start date of its ancestors. @@ -55,6 +56,7 @@ class ModelDefaultsConfig(BaseConfig): kind: t.Optional[ModelKind] = None dialect: t.Optional[str] = None cron: t.Optional[str] = None + cron_tz: t.Any = None owner: t.Optional[str] = None start: t.Optional[TimeLike] = None table_format: t.Optional[str] = None @@ -78,6 +80,7 @@ class ModelDefaultsConfig(BaseConfig): _model_kind_validator = model_kind_validator _on_destructive_change_validator = on_destructive_change_validator _on_additive_change_validator = on_additive_change_validator + _cron_tz_validator = cron_tz_validator @field_validator("audits", mode="before") def _audits_validator(cls, v: t.Any) -> t.Any: diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index 4a3bf2564b..c9dd087a10 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -260,6 +260,30 @@ def dbt_fqn(self) -> t.Optional[str]: } +def _cron_tz_validator(cls: t.Type, v: t.Any) -> t.Optional[zoneinfo.ZoneInfo]: + if not v or v == "UTC": + return None + + v = str_or_exp_to_str(v) + + try: + return zoneinfo.ZoneInfo(v) + except Exception as e: + available_timezones = zoneinfo.available_timezones() + + if available_timezones: + raise ConfigError(f"{e}. {v} must be in {available_timezones}.") + else: + raise ConfigError( + f"{e}. IANA time zone data is not available on your system. `pip install tzdata` to leverage cron time zones or remove this field which will default to UTC." + ) + + return None + + +cron_tz_validator = field_validator("cron_tz", mode="before")(_cron_tz_validator) + + class _Node(DbtInfoMixin, PydanticModel): """ Node is the core abstraction for entity that can be executed within the scheduler. @@ -302,6 +326,8 @@ class _Node(DbtInfoMixin, PydanticModel): _croniter: t.Optional[CroniterCache] = None __inferred_interval_unit: t.Optional[IntervalUnit] = None + _cron_tz_validator = cron_tz_validator + def __str__(self) -> str: path = f": {self._path.name}" if self._path else "" return f"{self.__class__.__name__}<{self.name}{path}>" @@ -328,27 +354,6 @@ def _name_validator(cls, v: t.Any) -> t.Optional[str]: return v.meta["sql"] return str(v) - @field_validator("cron_tz", mode="before") - def _cron_tz_validator(cls, v: t.Any) -> t.Optional[zoneinfo.ZoneInfo]: - if not v or v == "UTC": - return None - - v = str_or_exp_to_str(v) - - try: - return zoneinfo.ZoneInfo(v) - except Exception as e: - available_timezones = zoneinfo.available_timezones() - - if available_timezones: - raise ConfigError(f"{e}. {v} must be in {available_timezones}.") - else: - raise ConfigError( - f"{e}. IANA time zone data is not available on your system. `pip install tzdata` to leverage cron time zones or remove this field which will default to UTC." - ) - - return None - @field_validator("start", "end", mode="before") @classmethod def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]: diff --git a/tests/core/test_config.py b/tests/core/test_config.py index d0fad16e76..270210432b 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -930,6 +930,63 @@ def test_gateway_model_defaults(tmp_path): assert ctx.config.model_defaults == expected +def test_model_defaults_cron_tz(tmp_path): + """Test that cron_tz can be set in model_defaults.""" + import zoneinfo + + config_path = tmp_path / "config_model_defaults_cron_tz.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +model_defaults: + dialect: duckdb + cron: '@daily' + cron_tz: 'America/Los_Angeles' + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + assert config.model_defaults.cron == "@daily" + assert config.model_defaults.cron_tz == zoneinfo.ZoneInfo("America/Los_Angeles") + assert config.model_defaults.cron_tz.key == "America/Los_Angeles" + + +def test_gateway_model_defaults_cron_tz(tmp_path): + """Test that cron_tz can be set in gateway-specific model_defaults.""" + import zoneinfo + + global_defaults = ModelDefaultsConfig( + dialect="snowflake", owner="foo", cron="@daily", cron_tz="UTC" + ) + gateway_defaults = ModelDefaultsConfig(dialect="duckdb", cron_tz="America/New_York") + + config = Config( + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig(database="db.db"), + model_defaults=gateway_defaults, + ) + }, + model_defaults=global_defaults, + default_gateway="duckdb", + ) + + ctx = Context(paths=tmp_path, config=config, gateway="duckdb") + + expected = ModelDefaultsConfig( + dialect="duckdb", owner="foo", cron="@daily", cron_tz="America/New_York" + ) + + assert ctx.config.model_defaults == expected + # Also verify the cron_tz is a ZoneInfo object + assert isinstance(ctx.config.model_defaults.cron_tz, zoneinfo.ZoneInfo) + assert ctx.config.model_defaults.cron_tz.key == "America/New_York" + + def test_redshift_merge_flag(tmp_path, mocker: MockerFixture): config_path = tmp_path / "config_redshift_merge.yaml" with open(config_path, "w", encoding="utf-8") as fd: