Skip to content

Commit e95f24f

Browse files
authored
Fix pytest collection failure for classes decorated with context managers (apache#55915)
Classes decorated with `@conf_vars` and other context managers were disappearing during pytest collection, causing tests to be silently skipped. This affected several test classes including `TestWorkerStart` in the Celery provider tests. Root cause: `ContextDecorator` transforms decorated classes into callable wrappers. Since pytest only collects actual type objects as test classes, these wrapped classes are ignored during collection. Simple reproduction (no Airflow needed): ```py import contextlib import inspect @contextlib.contextmanager def simple_cm(): yield @simple_cm() class TestExample: def test_method(self): pass print(f'Is class? {inspect.isclass(TestExample)}') # False - pytest won't collect ``` and then run ```shell pytest test_example.py --collect-only ``` Airflow reproduction: ```shell breeze run pytest providers/celery/tests/unit/celery/cli/test_celery_command.py --collect-only -v breeze run pytest providers/celery/tests/unit/celery/cli/test_celery_command.py --collect-only -v ``` Solution: 1. Fixed affected test files by replacing class-level `@conf_vars` decorators with pytest fixtures 2. Created pytest fixtures to apply configuration changes 3. Used `@pytest.mark.usefixtures` to apply configuration to test classes 4. Added custom linter to prevent future occurrences and integrated it into pre-commit hooks Files changed: - Fixed 3 test files with problematic class decorators - Added custom linter with pre-commit integration This ensures pytest properly collects all test classes and prevents similar issues in the future through automated detection.
1 parent e6ebf6d commit e95f24f

File tree

5 files changed

+191
-36
lines changed

5 files changed

+191
-36
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,3 +1649,9 @@ repos:
16491649
files: ^airflow-core/src/airflow/serialization/schema\.json$|^airflow-core/src/airflow/serialization/serialized_objects\.py$
16501650
pass_filenames: false
16511651
require_serial: true
1652+
- id: check-contextmanager-class-decorators
1653+
name: Check for problematic context manager class decorators
1654+
entry: ./scripts/ci/prek/check_contextmanager_class_decorators.py
1655+
language: python
1656+
files: .*test.*\.py$
1657+
pass_filenames: true

providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@
2525
from confluent_kafka import Producer
2626

2727
# Import Operator
28+
from airflow.models.connection import Connection
2829
from airflow.providers.apache.kafka.operators.consume import ConsumeFromTopicOperator
2930

30-
from tests_common.test_utils.config import conf_vars
31-
3231
log = logging.getLogger(__name__)
3332

3433

@@ -49,23 +48,29 @@ def _basic_message_tester(message, test=None) -> Any:
4948
assert message.value().decode(encoding="utf-8") == test
5049

5150

51+
@pytest.fixture(autouse=True)
52+
def kafka_consumer_connections(create_connection_without_db):
53+
"""Create Kafka consumer connections for testing purpose."""
54+
connections = [
55+
Connection(
56+
conn_id="operator.consumer.test.integration.test_1",
57+
uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_1&enable.auto.commit=False&auto.offset.reset=beginning",
58+
),
59+
Connection(
60+
conn_id="operator.consumer.test.integration.test_2",
61+
uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_2&enable.auto.commit=False&auto.offset.reset=beginning",
62+
),
63+
Connection(
64+
conn_id="operator.consumer.test.integration.test_3",
65+
uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_3&enable.auto.commit=False&auto.offset.reset=beginning",
66+
),
67+
]
68+
69+
for conn in connections:
70+
create_connection_without_db(conn)
71+
72+
5273
@pytest.mark.integration("kafka")
53-
@conf_vars(
54-
{
55-
(
56-
"connections",
57-
"operator.consumer.test.integration.test_1",
58-
): "kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_1&enable.auto.commit=False&auto.offset.reset=beginning",
59-
(
60-
"connections",
61-
"operator.consumer.test.integration.test_2",
62-
): "kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_2&enable.auto.commit=False&auto.offset.reset=beginning",
63-
(
64-
"connections",
65-
"operator.consumer.test.integration.test_3",
66-
): "kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_3&enable.auto.commit=False&auto.offset.reset=beginning",
67-
}
68-
)
6974
class TestConsumeFromTopic:
7075
"""
7176
test ConsumeFromTopicOperator

providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
import pytest
2323
from confluent_kafka import Consumer
2424

25+
from airflow.models.connection import Connection
2526
from airflow.providers.apache.kafka.operators.produce import ProduceToTopicOperator
2627

27-
from tests_common.test_utils.config import conf_vars
28-
2928
log = logging.getLogger(__name__)
3029

3130

@@ -34,19 +33,25 @@ def _producer_function():
3433
yield (json.dumps(i), json.dumps(i + 1))
3534

3635

36+
@pytest.fixture(autouse=True)
37+
def kafka_connections(create_connection_without_db):
38+
"""Create Kafka producer connections for testing purpose."""
39+
connections = [
40+
Connection(
41+
conn_id="kafka_default_test_1",
42+
uri="kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_1",
43+
),
44+
Connection(
45+
conn_id="kafka_default_test_2",
46+
uri="kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_2",
47+
),
48+
]
49+
50+
for conn in connections:
51+
create_connection_without_db(conn)
52+
53+
3754
@pytest.mark.integration("kafka")
38-
@conf_vars(
39-
{
40-
(
41-
"connections",
42-
"kafka_default_test_1",
43-
): "kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_1",
44-
(
45-
"connections",
46-
"kafka_default_test_2",
47-
): "kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_2",
48-
}
49-
)
5055
class TestProduceToTopic:
5156
"""
5257
test ProduceToTopicOperator

providers/celery/tests/unit/celery/cli/test_celery_command.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@
3737
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
3838

3939

40+
@pytest.fixture(autouse=False)
41+
def conf_stale_bundle_cleanup_disabled():
42+
with conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): "0"}):
43+
yield
44+
45+
4046
@pytest.mark.backend("mysql", "postgres")
41-
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
47+
@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
4248
class TestCeleryStopCommand:
4349
@classmethod
4450
def setup_class(cls):
@@ -120,7 +126,7 @@ def test_custom_pid_file_is_used_in_start_and_stop(
120126

121127

122128
@pytest.mark.backend("mysql", "postgres")
123-
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
129+
@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
124130
class TestWorkerStart:
125131
@classmethod
126132
def setup_class(cls):
@@ -181,7 +187,7 @@ def test_worker_started_with_required_arguments(self, mock_celery_app, mock_pope
181187

182188

183189
@pytest.mark.backend("mysql", "postgres")
184-
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
190+
@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
185191
class TestWorkerFailure:
186192
@classmethod
187193
def setup_class(cls):
@@ -201,7 +207,7 @@ def test_worker_failure_gracefull_shutdown(self, mock_celery_app, mock_popen):
201207

202208

203209
@pytest.mark.backend("mysql", "postgres")
204-
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
210+
@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
205211
class TestFlowerCommand:
206212
@classmethod
207213
def setup_class(cls):
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#!/usr/bin/env python3
2+
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
20+
"""
21+
Check for problematic context manager decorators on test classes.
22+
23+
Context managers (ContextDecorator, @contextlib.contextmanager) when used as class decorators
24+
transform the class into a callable wrapper, which prevents pytest from collecting the class.
25+
"""
26+
27+
from __future__ import annotations
28+
29+
import ast
30+
import sys
31+
from pathlib import Path
32+
33+
34+
class ContextManagerClassDecoratorChecker(ast.NodeVisitor):
35+
"""AST visitor to check for context manager decorators on test classes."""
36+
37+
def __init__(self, filename: str):
38+
self.filename = filename
39+
self.errors: list[str] = []
40+
41+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
42+
"""Check class definitions for problematic decorators."""
43+
if not node.name.startswith("Test"):
44+
self.generic_visit(node)
45+
return
46+
47+
for decorator in node.decorator_list:
48+
decorator_name = self._get_decorator_name(decorator)
49+
if self._is_problematic_decorator(decorator_name):
50+
self.errors.append(
51+
f"{self.filename}:{node.lineno}: Class '{node.name}' uses @{decorator_name} "
52+
f"decorator which prevents pytest collection. Use @pytest.mark.usefixtures instead."
53+
)
54+
55+
self.generic_visit(node)
56+
57+
def _get_decorator_name(self, decorator: ast.expr) -> str:
58+
"""Extract decorator name from AST node."""
59+
if isinstance(decorator, ast.Name):
60+
return decorator.id
61+
if isinstance(decorator, ast.Call):
62+
if isinstance(decorator.func, ast.Name):
63+
return decorator.func.id
64+
if isinstance(decorator.func, ast.Attribute):
65+
return f"{self._get_attr_chain(decorator.func)}"
66+
elif isinstance(decorator, ast.Attribute):
67+
return f"{self._get_attr_chain(decorator)}"
68+
return "unknown"
69+
70+
def _get_attr_chain(self, node: ast.Attribute) -> str:
71+
"""Get the full attribute chain (e.g., 'contextlib.contextmanager')."""
72+
if isinstance(node.value, ast.Name):
73+
return f"{node.value.id}.{node.attr}"
74+
if isinstance(node.value, ast.Attribute):
75+
return f"{self._get_attr_chain(node.value)}.{node.attr}"
76+
return node.attr
77+
78+
def _is_problematic_decorator(self, decorator_name: str) -> bool:
79+
"""Check if decorator is known to break pytest class collection."""
80+
problematic_decorators = {
81+
"conf_vars",
82+
"env_vars",
83+
"contextlib.contextmanager",
84+
"contextmanager",
85+
}
86+
return decorator_name in problematic_decorators
87+
88+
89+
def check_file(filepath: Path) -> list[str]:
90+
"""Check a single file for problematic decorators."""
91+
try:
92+
with open(filepath, encoding="utf-8") as f:
93+
content = f.read()
94+
95+
tree = ast.parse(content, filename=str(filepath))
96+
checker = ContextManagerClassDecoratorChecker(str(filepath))
97+
checker.visit(tree)
98+
return checker.errors
99+
except Exception as e:
100+
return [f"{filepath}: Error parsing file: {e}"]
101+
102+
103+
def main() -> int:
104+
"""Main entry point."""
105+
if len(sys.argv) < 2:
106+
print("Usage: check_contextmanager_class_decorators.py <file_or_directory>...")
107+
return 1
108+
109+
all_errors = []
110+
111+
for arg in sys.argv[1:]:
112+
path = Path(arg)
113+
if path.is_file() and path.suffix == ".py":
114+
if "test" in str(path): # Only check test files
115+
all_errors.extend(check_file(path))
116+
else:
117+
print(f"Skipping non-test file: {path}")
118+
elif path.is_dir():
119+
for py_file in path.rglob("*.py"):
120+
if "test" in str(py_file): # Only check test files
121+
all_errors.extend(check_file(py_file))
122+
123+
if all_errors:
124+
print("Found problematic context manager class decorators:")
125+
for error in all_errors:
126+
print(f" {error}")
127+
return 1
128+
129+
return 0
130+
131+
132+
if __name__ == "__main__":
133+
sys.exit(main())

0 commit comments

Comments
 (0)