|
6 | 6 | import abc |
7 | 7 | import inspect |
8 | 8 | import sys |
| 9 | +from collections.abc import Callable |
9 | 10 | from pathlib import Path |
10 | | -from typing import Any, Literal |
| 11 | +from typing import Any, Literal, cast |
11 | 12 |
|
12 | 13 | import pytest |
13 | 14 | import yaml |
@@ -90,23 +91,34 @@ class TestMyConnector(ConnectorTestSuiteBase): |
90 | 91 | class ConnectorTestSuiteBase(abc.ABC): |
91 | 92 | """Base class for connector test suites.""" |
92 | 93 |
|
| 94 | + connector: type[IConnector] | Callable[[], IConnector] | None = None |
| 95 | + """The connector class or a factory function that returns an instance of IConnector.""" |
| 96 | + |
93 | 97 | @classmethod |
94 | 98 | def get_test_class_dir(cls) -> Path: |
95 | 99 | """Get the file path that contains the class.""" |
96 | 100 | module = sys.modules[cls.__module__] |
97 | 101 | # Get the directory containing the test file |
98 | 102 | return Path(inspect.getfile(module)).parent |
99 | 103 |
|
100 | | - connector: type[Connector] | Path | JavaClass | DockerImage | None = None |
101 | | - """The connector class or path to the connector to test.""" |
102 | | - |
103 | 104 | @classmethod |
104 | 105 | def create_connector( |
105 | 106 | cls, |
106 | 107 | scenario: ConnectorTestScenario, |
107 | 108 | ) -> IConnector: |
108 | 109 | """Instantiate the connector class.""" |
109 | | - raise NotImplementedError("Subclasses must implement this method.") |
| 110 | + connector = cls.connector # type: ignore |
| 111 | + if connector: |
| 112 | + if callable(connector) or isinstance(connector, type): |
| 113 | + # If the connector is a class or factory function, instantiate it: |
| 114 | + return cast(IConnector, connector()) |
| 115 | + |
| 116 | + # Otherwise, we can't instantiate the connector. Fail with a clear error message. |
| 117 | + raise NotImplementedError( |
| 118 | + "No connector class or connector factory function provided. " |
| 119 | + "Please provide a class or factory function in `cls.connector`, or " |
| 120 | + "override `cls.create_connector()` to define a custom initialization process." |
| 121 | + ) |
110 | 122 |
|
111 | 123 | def run_test_scenario( |
112 | 124 | self, |
|
0 commit comments