diff --git a/src/datachain/hash_utils.py b/src/datachain/hash_utils.py
index aae94e1b5..537c14275 100644
--- a/src/datachain/hash_utils.py
+++ b/src/datachain/hash_utils.py
@@ -112,16 +112,46 @@ def hash_column_elements(columns: Sequence[str | Function | T]) -> str:
     return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
 
 
+def _code_fingerprint(code):
+    """
+    Build a deterministic, printable fingerprint of a code object.
+
+    Nested code objects (inner lambdas, comprehensions) are expanded
+    recursively: their default repr embeds a memory address, file name and
+    line number, which would leak into the hash and make it unstable.
+    """
+    consts = tuple(
+        _code_fingerprint(const) if inspect.iscode(const) else const
+        for const in code.co_consts
+    )
+    return (code.co_code, consts, code.co_names, code.co_varnames)
+
+
-def hash_callable(func):
+def hash_callable(func, _visited=None):
     """
-    Calculate a hash from a callable.
+    Calculate a hash from a callable, including its dependencies.
     Rules:
     - Named functions (def) → use source code for stable, cross-version hashing
     - Lambdas → use bytecode (deterministic in same Python runtime)
+    - Recursively hashes helper functions from the same module, whether they
+      are referenced as module globals or captured from an enclosing scope
+
+    Values of referenced global variables are NOT part of the hash.
+    `_visited` is internal recursion state; callers should leave it unset.
     """
     if not callable(func):
         raise TypeError("Expected a callable")
 
+    # Track visited functions to avoid infinite recursion
+    if _visited is None:
+        _visited = set()
+
+    # Use id(func) to track which functions we've visited
+    func_id = id(func)
+    if func_id in _visited:
+        return hashlib.sha256(f"recursive:{func.__name__}".encode()).hexdigest()
+    _visited.add(func_id)
+
     # Determine if it is a lambda
     is_lambda = func.__name__ == "<lambda>"
 
@@ -131,11 +161,11 @@ def hash_callable(func):
             lines, _ = inspect.getsourcelines(func)
             payload = textwrap.dedent("".join(lines)).strip()
         except (OSError, TypeError):
-            # Fallback: bytecode if source not available
-            payload = func.__code__.co_code
+            # Fallback: bytecode + constants if source not available
+            payload = _code_fingerprint(func.__code__)
     else:
-        # For lambdas, fall back directly to bytecode
-        payload = func.__code__.co_code
+        # For lambdas, use bytecode + constants
+        payload = _code_fingerprint(func.__code__)
 
     # Normalize annotations
     annotations = {
@@ -149,8 +179,47 @@ def hash_callable(func):
         "annotations": annotations,
     }
 
+    # Find helper functions that this function depends on
+    dependencies = {}
+    if hasattr(func, "__code__") and hasattr(func, "__globals__"):
+        code = func.__code__
+        func_module = inspect.getmodule(func)
+
+        # Global names referenced by the bytecode
+        candidates = [
+            (name, func.__globals__[name])
+            for name in code.co_names
+            if name in func.__globals__
+        ]
+        # Helpers captured from an enclosing scope live in closure cells,
+        # not in co_names/__globals__
+        closure = getattr(func, "__closure__", None) or ()
+        for index, cell in enumerate(closure):
+            try:
+                candidates.append((code.co_freevars[index], cell.cell_contents))
+            except ValueError:
+                # Empty cell: the captured name is not bound yet
+                continue
+
+        for name, obj in candidates:
+            # Only hash plain Python functions from the same module.
+            # This skips built-ins, classes, callable instances and anything
+            # imported from other modules.
+            if (
+                inspect.isfunction(obj)
+                and func_module is not None
+                and obj.__module__ == func_module.__name__
+            ):
+                # Recursively hash the dependency
+                dependencies[name] = hash_callable(obj, _visited)
+
     # Compute SHA256
     h = hashlib.sha256()
-    h.update(str(payload).encode() if isinstance(payload, str) else payload)
+    # payload is either source text or a code fingerprint tuple
+    h.update(str(payload).encode())
     h.update(str(extras).encode())
+    # Include dependency hashes in sorted order for determinism
+    if dependencies:
+        deps_str = json.dumps(dependencies, sort_keys=True)
+        h.update(deps_str.encode())
     return h.hexdigest()
diff --git a/tests/unit/test_hash_utils.py b/tests/unit/test_hash_utils.py
index 2a5e442a4..bd90362d2 100644
--- a/tests/unit/test_hash_utils.py
+++ b/tests/unit/test_hash_utils.py
@@ -107,3 +107,82 @@ def test_lambda_different_hashes():
 
     # Ensure hashes are all different
     assert len({h1, h2, h3}) == 3
+
+
+def test_hash_callable_with_dependencies():
+    # Define helper and function that uses it
+    def helper(x):
+        return x + 1
+
+    def func_with_helper(x):
+        return helper(x) * 2
+
+    hash1 = hash_callable(func_with_helper)
+    # Hashing the same function again must be deterministic
+    assert hash1 == hash_callable(func_with_helper)
+
+    # Redefine helper with different implementation (same name, different code)
+    def helper(x):  # noqa: F811
+        return x + 10
+
+    def func_with_helper(x):
+        return helper(x) * 2
+
+    hash2 = hash_callable(func_with_helper)
+
+    # Hashes should be different because helper changed
+    assert hash1 != hash2
+
+
+def test_hash_callable_nested_code_objects():
+    # Identical lambdas must hash the same even though each one holds a
+    # nested code object with its own address and line number
+    f1 = lambda x: (lambda y: y + 1)(x)  # noqa: E731
+    f2 = lambda x: (lambda y: y + 1)(x)  # noqa: E731
+
+    assert hash_callable(f1) == hash_callable(f2)
+
+
+def test_hash_callable_recursive():
+    def factorial(n):
+        if n <= 1:
+            return 1
+        return n * factorial(n - 1)
+
+    assert hash_callable(factorial) is not None
+
+
+def test_hash_callable_mutual_recursion():
+    def func_a(n):
+        return func_b(n - 1) if n > 0 else 0
+
+    def func_b(n):
+        return func_a(n - 1) if n > 0 else 1
+
+    hash_a = hash_callable(func_a)
+    hash_b = hash_callable(func_b)
+
+    assert hash_a is not None
+    assert hash_b is not None
+    # Hashes should be different since functions are different
+    assert hash_a != hash_b
+
+
+def test_hash_callable_global_variable_limitation():
+    # This test documents the current limitation - global variables don't affect hash
+
+    global THRESHOLD  # noqa: PLW0603
+    THRESHOLD = 100
+
+    def filter_data(x):
+        return x > THRESHOLD
+
+    hash1 = hash_callable(filter_data)
+
+    # Change global variable
+    THRESHOLD = 200
+
+    hash2 = hash_callable(filter_data)
+
+    # Hash is the same even though behavior changed (limitation)
+    assert hash1 == hash2