Skip to content

Commit b3e3bc5

Browse files
authored
Introducing prek hook to detect airflow imports in shared libraries (#61350)
1 parent 8a4f38f commit b3e3bc5

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

.pre-commit-config.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,18 @@ repos:
356356
language: python
357357
pass_filenames: false
358358
files: ^shared/.*$|^.*/pyproject.toml$|^.*/_shared/.*$
359+
- id: check-airflow-imports-in-shared
360+
name: Check for core/sdk imports in shared libraries
361+
entry: ./scripts/ci/prek/check_airflow_imports_in_shared.py
362+
language: python
363+
pass_filenames: true
364+
files: ^shared/.*/src/.*\.py$
365+
exclude: |
366+
(?x)
367+
^shared/listeners/src/airflow_shared/listeners/spec/taskinstance\.py$|
368+
^shared/logging/src/airflow_shared/logging/remote\.py$|
369+
^shared/observability/src/airflow_shared/observability/metrics/stats\.py$|
370+
^shared/secrets_backend/src/airflow_shared/secrets_backend/base\.py$
359371
- id: check-secrets-search-path-sync
360372
name: Check sync between sdk and core
361373
entry: ./scripts/ci/prek/check_secrets_search_path_sync.py
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
# /// script
20+
# requires-python = ">=3.10,<3.11"
21+
# dependencies = [
22+
# "rich>=13.6.0",
23+
# ]
24+
# ///
25+
from __future__ import annotations
26+
27+
import argparse
28+
import ast
29+
import sys
30+
from pathlib import Path
31+
32+
sys.path.insert(0, str(Path(__file__).parent.resolve()))
33+
from common_prek_utils import console
34+
35+
36+
def check_file_for_prohibited_imports(file_path: Path) -> list[tuple[int, str]]:
37+
try:
38+
source = file_path.read_text(encoding="utf-8")
39+
tree = ast.parse(source, filename=str(file_path))
40+
except (OSError, UnicodeDecodeError, SyntaxError):
41+
return []
42+
43+
violations = []
44+
45+
for node in ast.walk(tree):
46+
# Check `from airflow.x import y` statements
47+
if isinstance(node, ast.ImportFrom):
48+
if node.module and node.module.startswith("airflow."):
49+
# Allow airflow_shared imports (which show as airflow._shared at runtime)
50+
if not (
51+
node.module.startswith("airflow_shared") or node.module.startswith("airflow._shared")
52+
):
53+
import_names = ", ".join(alias.name for alias in node.names)
54+
statement = f"from {node.module} import {import_names}"
55+
violations.append((node.lineno, statement))
56+
57+
# Check `import airflow.x` statements
58+
elif isinstance(node, ast.Import):
59+
for alias in node.names:
60+
if alias.name.startswith("airflow."):
61+
# Allow airflow_shared imports (which show as airflow._shared at runtime)
62+
if not (
63+
alias.name.startswith("airflow_shared") or alias.name.startswith("airflow._shared")
64+
):
65+
statement = f"import {alias.name}"
66+
if alias.asname:
67+
statement += f" as {alias.asname}"
68+
violations.append((node.lineno, statement))
69+
70+
return violations
71+
72+
73+
def main():
74+
parser = argparse.ArgumentParser(description="Check for airflow imports in shared library files")
75+
parser.add_argument("files", nargs="*", help="Files to check")
76+
args = parser.parse_args()
77+
78+
if not args.files:
79+
return
80+
81+
total_violations = 0
82+
83+
for file_path in [Path(f) for f in args.files]:
84+
violations = check_file_for_prohibited_imports(file_path)
85+
if violations:
86+
console.print(f"[red]{file_path}[/red]:")
87+
for line_num, statement in violations:
88+
console.print(f" [yellow]Line {line_num}[/yellow]: {statement}")
89+
total_violations += len(violations)
90+
91+
if total_violations:
92+
console.print()
93+
console.print(f"[red]Found {total_violations} prohibited import(s) in shared library files[/red]")
94+
console.print("[yellow]Shared libraries must not import from airflow-core or task sdk[/yellow]")
95+
sys.exit(1)
96+
97+
98+
if __name__ == "__main__":
99+
main()
100+
sys.exit(0)

0 commit comments

Comments
 (0)