diff --git a/.changes/unreleased/Features-20250731-162638.yaml b/.changes/unreleased/Features-20250731-162638.yaml new file mode 100644 index 00000000000..28713045b91 --- /dev/null +++ b/.changes/unreleased/Features-20250731-162638.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Hook prioritization (on-run-start & on-run-end) +time: 2025-07-31T16:26:38.632262+03:00 +custom: + Author: ammarchalifah-bolt + Issue: "10592" diff --git a/.flake8 b/.flake8 index 084d3c0163a..063c1116ae5 100644 --- a/.flake8 +++ b/.flake8 @@ -4,12 +4,12 @@ select = W F ignore = - W503 # makes Flake8 work like black + W503 W504 - E203 # makes Flake8 work like black - E704 # makes Flake8 work like black + E203 + E704 E741 - E501 # long line checking is done in black + E501 exclude = test/ per-file-ignores = */__init__.py: F401 diff --git a/core/dbt/artifacts/resources/v1/config.py b/core/dbt/artifacts/resources/v1/config.py index 903fbcb53cd..7a06553409b 100644 --- a/core/dbt/artifacts/resources/v1/config.py +++ b/core/dbt/artifacts/resources/v1/config.py @@ -41,6 +41,7 @@ class Hook(dbtClassMixin): sql: str transaction: bool = True index: Optional[int] = None + priority: Optional[int] = None @dataclass diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 5917aeb42f8..851ab5e9146 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -19,7 +19,6 @@ from dbt.mp_context import get_mp_context from dbt_common.events.base_types import EventMsg - @dataclass class dbtRunnerResult: """Contains the result of an invocation of the dbtRunner""" diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index fe21cdc9a80..92b067a8cfd 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -477,8 +477,8 @@ def create_project(self, rendered: RenderComponents) -> "Project": vars_value = VarProvider(vars_dict) # There will never be any project_env_vars when it's first created project_env_vars: Dict[str, Any] = {} - on_run_start: List[str] = value_or(cfg.on_run_start, []) - on_run_end: List[str] = value_or(cfg.on_run_end, []) + on_run_start: List[Union[str, Dict[str, int]]] = value_or(cfg.on_run_start, []) + on_run_end: List[Union[str, Dict[str, int]]] = value_or(cfg.on_run_end, []) query_comment = _query_comment_from_cfg(cfg.query_comment) packages: PackageConfig = package_config_from_data( @@ -632,8 +632,8 @@ class Project: packages_specified_path: str quoting: Dict[str, Any] models: Dict[str, Any] - on_run_start: List[str] - on_run_end: List[str] + on_run_start: List[Union[str, Dict[str, int]]] + on_run_end: List[Union[str, Dict[str, int]]] dispatch: List[Dict[str, Any]] seeds: Dict[str, Any] snapshots: Dict[str, Any] diff --git a/core/dbt/constants.py b/core/dbt/constants.py index 827c243a170..f8e7c41eb85 100644 --- a/core/dbt/constants.py +++ b/core/dbt/constants.py @@ -26,3 +26,7 @@ RUN_RESULTS_FILE_NAME = "run_results.json" CATALOG_FILENAME = "catalog.json" SOURCE_RESULT_FILE_NAME = "sources.json" + +# Hook priority constants +DEFAULT_HOOK_PRIORITY = 50 +PROJECT_HOOK_PRIORITY = 100 diff --git a/core/dbt/include/jsonschemas/project/0.0.110.json b/core/dbt/include/jsonschemas/project/0.0.110.json index adb59ba7958..4e2f1a0eadb 100644 --- a/core/dbt/include/jsonschemas/project/0.0.110.json +++ b/core/dbt/include/jsonschemas/project/0.0.110.json @@ -145,6 +145,36 @@ { "$ref": "#/definitions/StringOrArrayOfStrings" }, + { + "anyOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + }, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "sql": { + "type": "string" + }, + "priority": { + "type": "integer" + } + }, + "required": [ + "sql" + ] + } + } + ] + }, { "type": "null" } @@ -155,6 +185,36 @@ { "$ref": "#/definitions/StringOrArrayOfStrings" }, + { + "anyOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + }, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "sql": { + "type": "string" + }, + "priority": { + "type": "integer" + } + }, + "required": [ + "sql" + ] + } + } + ] + }, { "type": "null" } diff --git a/core/dbt/include/jsonschemas/project/0.0.85.json b/core/dbt/include/jsonschemas/project/0.0.85.json index 215ce89ec96..79d4ed97db2 100644 --- a/core/dbt/include/jsonschemas/project/0.0.85.json +++ b/core/dbt/include/jsonschemas/project/0.0.85.json @@ -124,6 +124,23 @@ }, { "type": "null" + }, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "sql": { + "type": "string" + }, + "priority": { + "type": "integer" + } + }, + "required": [ + "sql" + ] + } } ] }, @@ -134,6 +151,23 @@ }, { "type": "null" + }, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "sql": { + "type": "string" + }, + "priority": { + "type": "integer" + } + }, + "required": [ + "sql" + ] + } } ] }, diff --git a/core/dbt/include/jsonschemas/resources/0.0.110.json b/core/dbt/include/jsonschemas/resources/0.0.110.json index cfeb483c408..6da73ebcb66 100644 --- a/core/dbt/include/jsonschemas/resources/0.0.110.json +++ b/core/dbt/include/jsonschemas/resources/0.0.110.json @@ -550,8 +550,16 @@ "boolean", "null" ] + }, + "priority": { + "type": [ + "integer", + "null" + ] } - } + }, + "additionalProperties": false, + "required": ["sql", "transaction"] }, "Hooks": { "anyOf": [ diff --git a/core/dbt/include/jsonschemas/resources/0.0.85.json b/core/dbt/include/jsonschemas/resources/0.0.85.json index 047dd1d3d36..1078124e983 100644 --- a/core/dbt/include/jsonschemas/resources/0.0.85.json +++ b/core/dbt/include/jsonschemas/resources/0.0.85.json @@ -497,8 +497,16 @@ "boolean", "null" ] + }, + "priority": { + "type": [ + "integer", + "null" + ] } - } + }, + "additionalProperties": false, + "required": ["sql", "transaction"] }, "Hooks": { "anyOf": [ diff --git a/core/dbt/include/jsonschemas/resources/latest.json b/core/dbt/include/jsonschemas/resources/latest.json index 689cb84579a..355744da368 100644 --- a/core/dbt/include/jsonschemas/resources/latest.json +++ b/core/dbt/include/jsonschemas/resources/latest.json @@ -2076,9 +2076,16 @@ "boolean", "null" ] + }, + "priority": { + "type": [ + "integer", + "null" + ] } }, - "additionalProperties": false + "additionalProperties": false, + "required": ["sql", "transaction"] }, "Hooks": { "anyOf": [ diff --git a/core/dbt/parser/hooks.py b/core/dbt/parser/hooks.py index bcc25c0d937..d44c29c601b 100644 --- a/core/dbt/parser/hooks.py +++ b/core/dbt/parser/hooks.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from typing import Iterable, Iterator, List, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Tuple, Union +from dbt.constants import DEFAULT_HOOK_PRIORITY from dbt.context.context_config import ContextConfig from dbt.contracts.files import FilePath from dbt.contracts.graph.nodes import HookNode @@ -17,6 +18,7 @@ class HookBlock(FileBlock): value: str index: int hook_type: RunHookType + priority: int = DEFAULT_HOOK_PRIORITY @property def contents(self): @@ -33,14 +35,35 @@ def __init__(self, project, source_file, hook_type) -> None: self.source_file = source_file self.hook_type = hook_type - def _hook_list(self, hooks: Union[str, List[str], Tuple[str, ...]]) -> List[str]: + def _hook_list( + self, hooks: Union[str, List[Union[str, Dict]], Tuple[Union[str, Dict], ...]] + ) -> List[Dict[str, Any]]: + """Convert hook definitions to a standardized list format.""" if isinstance(hooks, tuple): hooks = list(hooks) elif not isinstance(hooks, list): hooks = [hooks] - return hooks - def get_hook_defs(self) -> List[str]: + # Standardize format - ensure all hooks have a consistent structure + result = [] + for hook in hooks: + if isinstance(hook, str): + # Convert string to dict format with default priority + result.append({"sql": hook, "priority": DEFAULT_HOOK_PRIORITY}) + elif isinstance(hook, dict) and "sql" in hook: + # Ensure required keys exist with defaults + hook_dict = { + "sql": hook["sql"], + "priority": hook.get("priority", DEFAULT_HOOK_PRIORITY), + } + result.append(hook_dict) + elif isinstance(hook, dict): + # Backward compatibility for any other dict format + result.append({"sql": str(hook), "priority": DEFAULT_HOOK_PRIORITY}) + + return result + + def get_hook_defs(self) -> List[Dict[str, Any]]: if self.hook_type == RunHookType.Start: hooks = self.project.on_run_start elif self.hook_type == RunHookType.End: @@ -56,12 +79,22 @@ def get_hook_defs(self) -> List[str]: def __iter__(self) -> Iterator[HookBlock]: hooks = self.get_hook_defs() for index, hook in enumerate(hooks): + # Extract SQL and priority + if isinstance(hook, dict): + sql = hook.get("sql", "") + priority = hook.get("priority", DEFAULT_HOOK_PRIORITY) + else: + # Fallback for any unexpected format + sql = str(hook) + priority = DEFAULT_HOOK_PRIORITY + yield HookBlock( file=self.source_file, project=self.project.project_name, - value=hook, + value=sql, index=index, hook_type=self.hook_type, + priority=priority, ) @@ -98,7 +131,8 @@ def _create_parsetime_node( **kwargs, ) -> HookNode: - return super()._create_parsetime_node( + # Create the node using the parent method + node = super()._create_parsetime_node( block=block, path=path, config=config, @@ -108,6 +142,13 @@ def _create_parsetime_node( tags=[str(block.hook_type)], ) + # Store the priority in the node's meta + if not node.config.meta: + node.config.meta = {} + node.config.meta["hook_priority"] = block.priority + + return node + @property def resource_type(self) -> NodeType: return NodeType.Operation diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index d1351ad4b3c..11a5cb9dedc 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -26,6 +26,7 @@ from dbt.cli.flags import Flags from dbt.clients.jinja import MacroGenerator from dbt.config import RuntimeConfig +from dbt.constants import DEFAULT_HOOK_PRIORITY, PROJECT_HOOK_PRIORITY from dbt.context.providers import generate_runtime_model_context from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode @@ -929,11 +930,24 @@ def _submit_batch( return relation_exists - def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]: + def _hook_keyfunc(self, hook: HookNode) -> Tuple[int, str, Optional[int]]: + """Sort hooks by priority, then package name, then index""" package_name = hook.package_name - if package_name == self.config.project_name: + + # Default priority + priority = DEFAULT_HOOK_PRIORITY + + # Get priority from node meta if available + if hook.config and hook.config.meta and "hook_priority" in hook.config.meta: + priority = hook.config.meta["hook_priority"] + + # Special case for project hooks - if no explicit priority, + # make them run last (preserving backward compatibility) + if package_name == self.config.project_name and priority == DEFAULT_HOOK_PRIORITY: package_name = BiggestName("") - return package_name, hook.index + priority = PROJECT_HOOK_PRIORITY + + return priority, package_name, hook.index def get_hooks_by_type(self, hook_type: RunHookType) -> List[HookNode]: diff --git a/tests/unit/parser/test_parser.py b/tests/unit/parser/test_parser.py index 990d5f6a1fa..c972cc21016 100644 --- a/tests/unit/parser/test_parser.py +++ b/tests/unit/parser/test_parser.py @@ -13,6 +13,7 @@ ModelBuildAfter, ModelFreshnessUpdatesOnOptions, ) +from dbt.constants import DEFAULT_HOOK_PRIORITY from dbt.context.context_config import ContextConfig from dbt.contracts.files import FileHash, FilePath, SchemaSourceFile, SourceFile from dbt.contracts.graph.manifest import Manifest @@ -28,7 +29,7 @@ ) from dbt.exceptions import CompilationError, ParsingError, SchemaConfigError from dbt.flags import set_from_args -from dbt.node_types import NodeType +from dbt.node_types import NodeType, RunHookType from dbt.parser import ( AnalysisParser, GenericTestParser, @@ -39,6 +40,7 @@ SnapshotParser, ) from dbt.parser.common import YamlBlock +from dbt.parser.hooks import HookBlock, HookSearcher from dbt.parser.models import ( _get_config_call_dict, _get_exp_sample_result, @@ -2004,3 +2006,80 @@ def test_basic(self): self.assertEqual( self.parser.manifest.files[file_id].nodes, ["analysis.snowplow.analysis_1"] ) + + +class TestHookPriorityParsing(unittest.TestCase): + + def setUp(self): + # Mock project and source_file + self.mock_project = mock.MagicMock() + self.mock_project.project_name = "test_project" + self.mock_source_file = mock.MagicMock() + + def test_hook_string_to_dict_with_priority(self): + """Test that string hooks are parsed correctly with default priority""" + self.mock_project.on_run_start = ["SELECT 1"] + + searcher = HookSearcher(self.mock_project, self.mock_source_file, RunHookType.Start) + hooks = searcher.get_hook_defs() + + self.assertEqual(len(hooks), 1) + self.assertIsInstance(hooks[0], dict) + self.assertEqual(hooks[0]["sql"], "SELECT 1") + self.assertEqual(hooks[0]["priority"], DEFAULT_HOOK_PRIORITY) + + def test_hook_dict_with_custom_priority(self): + """Test that dict hooks with priority are parsed correctly""" + custom_priority = 10 + self.mock_project.on_run_end = [{"sql": "SELECT 2", "priority": custom_priority}] + + searcher = HookSearcher(self.mock_project, self.mock_source_file, RunHookType.End) + hooks = searcher.get_hook_defs() + + self.assertEqual(len(hooks), 1) + self.assertIsInstance(hooks[0], dict) + self.assertEqual(hooks[0]["sql"], "SELECT 2") + self.assertEqual(hooks[0]["priority"], custom_priority) + + def test_hook_mixed_formats(self): + """Test that mixed hook formats are all parsed correctly""" + custom_priority = 20 + self.mock_project.on_run_end = [ + "SELECT 4", + {"sql": "SELECT 5", "priority": custom_priority}, + {"sql": "SELECT 6"}, + ] + + searcher = HookSearcher(self.mock_project, self.mock_source_file, RunHookType.End) + hooks = searcher.get_hook_defs() + + self.assertEqual(len(hooks), 3) + self.assertEqual(hooks[0]["sql"], "SELECT 4") + self.assertEqual(hooks[0]["priority"], DEFAULT_HOOK_PRIORITY) + self.assertEqual(hooks[1]["sql"], "SELECT 5") + self.assertEqual(hooks[1]["priority"], custom_priority) + self.assertEqual(hooks[2]["sql"], "SELECT 6") + self.assertEqual(hooks[2]["priority"], DEFAULT_HOOK_PRIORITY) + + def test_hook_block_creation_with_priority(self): + """Test that HookBlock objects are created with correct priority values""" + custom_priority = 30 + self.mock_project.on_run_start = [ + "SELECT 1", + {"sql": "SELECT 2", "priority": custom_priority}, + ] + + searcher = HookSearcher(self.mock_project, self.mock_source_file, RunHookType.Start) + blocks = list(searcher) + + self.assertEqual(len(blocks), 2) + self.assertIsInstance(blocks[0], HookBlock) + self.assertEqual(blocks[0].value, "SELECT 1") + self.assertEqual(blocks[0].priority, DEFAULT_HOOK_PRIORITY) + + self.assertIsInstance(blocks[1], HookBlock) + self.assertEqual(blocks[1].value, "SELECT 2") + self.assertEqual(blocks[1].priority, custom_priority) + self.assertEqual(blocks[1].project, "test_project") + self.assertEqual(blocks[1].index, 1) + self.assertEqual(blocks[1].hook_type, RunHookType.Start) diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index 3378009fade..1e167ea5449 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -1,3 +1,4 @@ +import unittest from argparse import Namespace from dataclasses import dataclass from importlib import import_module @@ -20,6 +21,7 @@ from dbt.artifacts.schemas.results import RunStatus from dbt.artifacts.schemas.run import RunResult from dbt.config.runtime import RuntimeConfig +from dbt.constants import DEFAULT_HOOK_PRIORITY, PROJECT_HOOK_PRIORITY from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import HookNode, ModelNode from dbt.events.types import LogModelResult @@ -385,3 +387,139 @@ def test_safe_run_hooks( assert not isinstance(expected_result, RunStatus) assert issubclass(expected_result, BaseException) assert type(e) == expected_result + + +class TestHookPrioritization(unittest.TestCase): + def test_hook_keyfunc_ordering(self): + """Test that _hook_keyfunc correctly sorts hooks based on priority""" + # Create a minimal RunTask instance + config = mock.MagicMock() + config.project_name = "test_project" + task = RunTask(args=mock.MagicMock(), config=config, manifest=mock.MagicMock()) + + # Create hook nodes with different priorities + # High priority package hook + high_priority_hook = mock.MagicMock() + high_priority_hook.name = "high_priority_hook" + high_priority_hook.package_name = "package_a" + high_priority_hook.config.meta = {"hook_priority": 10} + high_priority_hook.index = 0 + + # Medium priority project hook + medium_priority_hook = mock.MagicMock() + medium_priority_hook.name = "medium_priority_hook" + medium_priority_hook.package_name = "test_project" + medium_priority_hook.config.meta = {"hook_priority": 30} + medium_priority_hook.index = 0 + + # Default priority package hook + default_package_hook = mock.MagicMock() + default_package_hook.name = "default_package_hook" + default_package_hook.package_name = "package_b" + default_package_hook.config.meta = {} + default_package_hook.index = 0 + + # Default priority project hook (should be moved to end) + default_project_hook = mock.MagicMock() + default_project_hook.name = "default_project_hook" + default_project_hook.package_name = "test_project" + default_project_hook.config.meta = {} + default_project_hook.index = 1 + + # Create a list of hooks in a random order + hooks = [ + default_project_hook, + high_priority_hook, + default_package_hook, + medium_priority_hook, + ] + + # Sort hooks using _hook_keyfunc + sorted_hooks = sorted(hooks, key=task._hook_keyfunc) + + # Verify the sorting order + self.assertEqual(sorted_hooks[0].name, "high_priority_hook") # Priority 10 + self.assertEqual(sorted_hooks[1].name, "medium_priority_hook") # Priority 30 + self.assertEqual(sorted_hooks[2].name, "default_package_hook") # Priority DEFAULT (50) + self.assertEqual(sorted_hooks[3].name, "default_project_hook") # Priority PROJECT (100) + + def test_hook_keyfunc_same_priority(self): + """Test that hooks with same priority are sorted by package name and index""" + # Create a minimal RunTask instance + config = mock.MagicMock() + config.project_name = "test_project" + task = RunTask(args=mock.MagicMock(), config=config, manifest=mock.MagicMock()) + + # Create hook nodes with the same priority + # Same priority (20), package_b, index 0 + hook1 = mock.MagicMock() + hook1.name = "hook_b_0" + hook1.package_name = "package_b" + hook1.config.meta = {"hook_priority": 20} + hook1.index = 0 + + # Same priority (20), package_a, index 1 + hook2 = mock.MagicMock() + hook2.name = "hook_a_1" + hook2.package_name = "package_a" + hook2.config.meta = {"hook_priority": 20} + hook2.index = 1 + + # Same priority (20), package_a, index 0 + hook3 = mock.MagicMock() + hook3.name = "hook_a_0" + hook3.package_name = "package_a" + hook3.config.meta = {"hook_priority": 20} + hook3.index = 0 + + # Create a list of hooks in a random order + hooks = [hook1, hook2, hook3] + + # Sort hooks using _hook_keyfunc + sorted_hooks = sorted(hooks, key=task._hook_keyfunc) + + # Expected order based on package name (alphabetical) and then index: + # 1. hook_a_0 (package_a, index 0) + # 2. hook_a_1 (package_a, index 1) + # 3. hook_b_0 (package_b, index 0) + self.assertEqual(sorted_hooks[0].name, "hook_a_0") + self.assertEqual(sorted_hooks[1].name, "hook_a_1") + self.assertEqual(sorted_hooks[2].name, "hook_b_0") + + def test_project_hooks_default_priority(self): + """Test that project hooks with default priority are moved to the end""" + # Create a minimal RunTask instance + config = mock.MagicMock() + config.project_name = "test_project" + task = RunTask(args=mock.MagicMock(), config=config, manifest=mock.MagicMock()) + + # Create a package hook with default priority + package_hook = mock.MagicMock() + package_hook.name = "package_hook" + package_hook.package_name = "other_package" + package_hook.config.meta = {} # Default priority + package_hook.index = 0 + + # Create a project hook with default priority (should be moved to end) + project_hook = mock.MagicMock() + project_hook.name = "project_hook" + project_hook.package_name = "test_project" + project_hook.config.meta = {} # Default priority + project_hook.index = 0 + + # Sort hooks using _hook_keyfunc + sorted_hooks = sorted([project_hook, package_hook], key=task._hook_keyfunc) + + # Project hook should be after package hook + self.assertEqual(sorted_hooks[0].name, "package_hook") + self.assertEqual(sorted_hooks[1].name, "project_hook") + + # Verify the priority values from the key function + package_key = task._hook_keyfunc(package_hook) + project_key = task._hook_keyfunc(project_hook) + + # Package hook should have DEFAULT_HOOK_PRIORITY + self.assertEqual(package_key[0], DEFAULT_HOOK_PRIORITY) + + # Project hook should have PROJECT_HOOK_PRIORITY + self.assertEqual(project_key[0], PROJECT_HOOK_PRIORITY)