
Commit b0f8c3d

hsheth2 and treff7es authored
refactor(ingest): simplify stateful ingestion provider interface (#8104)
Co-authored-by: Tamas Nemeth <[email protected]>
1 parent: afd65e1

27 files changed: +179 −323 lines

metadata-ingestion/src/datahub/ingestion/api/common.py

Lines changed: 3 additions & 6 deletions
@@ -4,7 +4,7 @@
 
 from datahub.emitter.mce_builder import set_dataset_urn_to_lower
 from datahub.ingestion.api.committable import Committable
-from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
+from datahub.ingestion.graph.client import DataHubGraph
 
 if TYPE_CHECKING:
     from datahub.ingestion.run.pipeline import PipelineConfig
@@ -43,22 +43,19 @@ class PipelineContext:
     def __init__(
         self,
         run_id: str,
-        datahub_api: Optional["DatahubClientConfig"] = None,
+        graph: Optional[DataHubGraph] = None,
         pipeline_name: Optional[str] = None,
         dry_run: bool = False,
         preview_mode: bool = False,
         pipeline_config: Optional["PipelineConfig"] = None,
     ) -> None:
         self.pipeline_config = pipeline_config
+        self.graph = graph
         self.run_id = run_id
         self.pipeline_name = pipeline_name
         self.dry_run_mode = dry_run
         self.preview_mode = preview_mode
         self.checkpointers: Dict[str, Committable] = {}
-        try:
-            self.graph = DataHubGraph(datahub_api) if datahub_api is not None else None
-        except Exception as e:
-            raise Exception(f"Failed to connect to DataHub: {e}") from e
 
         self._set_dataset_urn_to_lower_if_needed()

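With this change, PipelineContext receives an already-constructed DataHubGraph instead of building one from a DatahubClientConfig. A minimal sketch of the new calling convention (the server URL is a placeholder):

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

# Connect first; connection failures now surface at graph construction,
# not inside the PipelineContext constructor.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

# The context simply holds the graph; it no longer builds one itself.
ctx = PipelineContext(run_id="demo-run", graph=graph, pipeline_name="demo")
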
metadata-ingestion/src/datahub/ingestion/run/pipeline.py

Lines changed: 7 additions & 1 deletion
@@ -27,6 +27,7 @@
 from datahub.ingestion.api.source import Extractor, Source
 from datahub.ingestion.api.transform import Transformer
 from datahub.ingestion.extractor.extractor_registry import extractor_registry
+from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.reporting.reporting_provider_registry import (
     reporting_provider_registry,
 )
@@ -183,10 +184,15 @@ def __init__(
         self.last_time_printed = int(time.time())
         self.cli_report = CliReport()
 
+        self.graph = None
+        with _add_init_error_context("connect to DataHub"):
+            if self.config.datahub_api:
+                self.graph = DataHubGraph(self.config.datahub_api)
+
         with _add_init_error_context("set up framework context"):
             self.ctx = PipelineContext(
                 run_id=self.config.run_id,
-                datahub_api=self.config.datahub_api,
+                graph=self.graph,
                 pipeline_name=self.config.pipeline_name,
                 dry_run=dry_run,
                 preview_mode=preview_mode,

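The connection attempt now lives in Pipeline.__init__, wrapped in _add_init_error_context so that a failure is reported against a named init step. A rough sketch of what such a helper looks like (illustrative, not the exact implementation in pipeline.py):

import contextlib
from typing import Iterator

class PipelineInitError(Exception):
    pass

@contextlib.contextmanager
def _add_init_error_context(step: str) -> Iterator[None]:
    # Re-raise any failure with a message naming the init step that broke.
    try:
        yield
    except Exception as e:
        raise PipelineInitError(f"Failed to {step}: {e}") from e
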
metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py

Lines changed: 2 additions & 3 deletions
@@ -7,7 +7,7 @@
 from datetime import timedelta
 from enum import auto
 from threading import BoundedSemaphore
-from typing import Union, cast
+from typing import Union
 
 from datahub.cli.cli_utils import set_env_variables_override_config
 from datahub.configuration.common import (
@@ -126,8 +126,7 @@ def __post_init__(self) -> None:
 
     def handle_work_unit_start(self, workunit: WorkUnit) -> None:
         if isinstance(workunit, MetadataWorkUnit):
-            mwu: MetadataWorkUnit = cast(MetadataWorkUnit, workunit)
-            self.treat_errors_as_warnings = mwu.treat_errors_as_warnings
+            self.treat_errors_as_warnings = workunit.treat_errors_as_warnings
 
     def handle_work_unit_end(self, workunit: WorkUnit) -> None:
         pass

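The dropped cast in datahub_rest.py was dead weight: an isinstance check already narrows the type for checkers like mypy. A self-contained illustration of the principle (stub classes, not the real WorkUnit hierarchy):

class WorkUnit:
    pass

class MetadataWorkUnit(WorkUnit):
    treat_errors_as_warnings: bool = False

def handle_work_unit_start(workunit: WorkUnit) -> None:
    if isinstance(workunit, MetadataWorkUnit):
        # mypy narrows `workunit` to MetadataWorkUnit in this branch,
        # so the attribute access type-checks without an explicit cast.
        print(workunit.treat_errors_as_warnings)
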
metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py

Lines changed: 0 additions & 3 deletions
@@ -1331,6 +1331,3 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
 
     def get_report(self) -> SourceReport:
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()

metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py

Lines changed: 0 additions & 3 deletions
@@ -2163,6 +2163,3 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]:  # noqa: C901
 
     def get_report(self):
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()

metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py

Lines changed: 1 addition & 1 deletion
@@ -1228,7 +1228,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
             # Because job_id is used as dictionary key, we have to set a new job_id
             # Refer to https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py#L390
             self.stale_entity_removal_handler.set_job_id(workspace.id)
-            self.register_stateful_ingestion_usecase_handler(
+            self.state_provider.register_stateful_ingestion_usecase_handler(
                 self.stale_entity_removal_handler
            )

metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py

Lines changed: 0 additions & 3 deletions
@@ -1283,6 +1283,3 @@ def _get_owner_information(self, table: str, label: str) -> Optional[str]:
                 return each["owner_name"]
 
         return None
-
-    def close(self):
-        self.prepare_for_commit()

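The Looker, LookML, and Vertica sources all dropped the same boilerplate close() override. The likely intent is that commit preparation now happens once in shared base-class teardown; a hedged sketch of that pattern (class name illustrative; prepare_for_commit is taken from the deleted code):

class StatefulSourceBase:
    def prepare_for_commit(self) -> None:
        # Flush pending checkpoint state so committers can persist it.
        ...

    def close(self) -> None:
        # Single shared teardown hook: subclasses no longer need to
        # override close() just to call prepare_for_commit().
        self.prepare_for_commit()
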
metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py

Lines changed: 51 additions & 10 deletions
@@ -1,11 +1,9 @@
-from typing import Any, Dict, Iterable, List, Type
+from typing import Any, Dict, Iterable, List, Tuple, Type
 
 import pydantic
 
 from datahub.emitter.mce_builder import make_assertion_urn, make_container_urn
-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityCheckpointStateBase,
-)
+from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
 from datahub.utilities.checkpoint_state_util import CheckpointStateUtil
 from datahub.utilities.dedup_list import deduplicate_list
 from datahub.utilities.urns.urn import guess_entity_type
@@ -56,7 +54,7 @@ def _validate_field_rename(cls: Type, values: dict) -> dict:
     return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)
 
 
-class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
+class GenericCheckpointState(CheckpointStateBase):
     urns: List[str] = pydantic.Field(default_factory=list)
 
     # We store a bit of extra internal-only state so that we can keep the urns list deduplicated.
@@ -85,21 +83,34 @@ def __init__(self, **data: Any):  # type: ignore
         self.urns = deduplicate_list(self.urns)
         self._urns_set = set(self.urns)
 
-    @classmethod
-    def get_supported_types(cls) -> List[str]:
-        return ["*"]
-
     def add_checkpoint_urn(self, type: str, urn: str) -> None:
+        """
+        Adds an urn into the list used for tracking the type.
+
+        :param type: Deprecated parameter, has no effect.
+        :param urn: The urn string
+        """
+
+        # TODO: Deprecate the `type` parameter and remove it.
         if urn not in self._urns_set:
             self.urns.append(urn)
             self._urns_set.add(urn)
 
     def get_urns_not_in(
         self, type: str, other_checkpoint_state: "GenericCheckpointState"
     ) -> Iterable[str]:
+        """
+        Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
+
+        :param type: Deprecated. Set to "*".
+        :param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
+        :return: an iterable to the set of urns present in this checkpoint state but not in the other_checkpoint.
+        """
+
         diff = set(self.urns) - set(other_checkpoint_state.urns)
 
         # To maintain backwards compatibility, we provide this filtering mechanism.
+        # TODO: Deprecate the `type` parameter and remove it.
         if type == "*":
             yield from diff
         elif type == "topic":
@@ -110,6 +121,36 @@ def get_urns_not_in(
     def get_percent_entities_changed(
         self, old_checkpoint_state: "GenericCheckpointState"
     ) -> float:
-        return StaleEntityCheckpointStateBase.compute_percent_entities_changed(
+        """
+        Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
+
+        :param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
+        :return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
+        """
+        return compute_percent_entities_changed(
             [(self.urns, old_checkpoint_state.urns)]
         )
+
+
+def compute_percent_entities_changed(
+    new_old_entity_list: List[Tuple[List[str], List[str]]]
+) -> float:
+    old_count_all = 0
+    overlap_count_all = 0
+    for new_entities, old_entities in new_old_entity_list:
+        (overlap_count, old_count, _,) = get_entity_overlap_and_cardinalities(
+            new_entities=new_entities, old_entities=old_entities
+        )
+        overlap_count_all += overlap_count
+        old_count_all += old_count
+    if old_count_all:
+        return (1 - overlap_count_all / old_count_all) * 100.0
+    return 0.0
+
+
+def get_entity_overlap_and_cardinalities(
+    new_entities: List[str], old_entities: List[str]
+) -> Tuple[int, int, int]:
+    new_set = set(new_entities)
+    old_set = set(old_entities)
+    return len(new_set.intersection(old_set)), len(old_set), len(new_set)

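Since compute_percent_entities_changed is now a module-level function rather than a base-class static method, it can be exercised in isolation. A quick worked example: with three old urns and an overlap of two, (1 − 2/3) × 100 ≈ 33.33:

from datahub.ingestion.source.state.entity_removal_state import (
    compute_percent_entities_changed,
)

old = ["urn:li:dataset:a", "urn:li:dataset:b", "urn:li:dataset:c"]
new = ["urn:li:dataset:a", "urn:li:dataset:b", "urn:li:dataset:d"]

# overlap = {a, b} -> (1 - 2/3) * 100
pct = compute_percent_entities_changed([(new, old)])
print(f"{pct:.2f}")  # 33.33
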
metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py

Lines changed: 7 additions & 5 deletions
@@ -40,15 +40,17 @@ def __init__(
         pipeline_name: Optional[str],
         run_id: str,
     ):
-        self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             ProfilingStatefulIngestionConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)
 
     def _ignore_old_state(self) -> bool:
         if (
@@ -91,7 +93,7 @@ def create_checkpoint(self) -> Optional[Checkpoint[ProfilingCheckpointState]]:
     def get_current_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return None
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(ProfilingCheckpointState, cur_checkpoint.state)
         return cur_state
@@ -108,7 +110,7 @@ def add_to_state(
     def get_last_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_old_state():
             return None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
             self.job_id, ProfilingCheckpointState
         )
         if last_checkpoint and last_checkpoint.state:

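The handler refactor follows one pattern throughout: each handler holds a reference to source.state_provider and calls checkpointing methods on it, rather than on the source itself. A simplified sketch of the composition (stand-in classes; the real ones live in stateful_ingestion_base.py):

from typing import Any, List

class StateProviderWrapper:
    """Stand-in for the provider object that stateful sources compose."""

    def __init__(self) -> None:
        self._handlers: List[Any] = []

    def is_stateful_ingestion_configured(self) -> bool:
        # The real check inspects the source's stateful_ingestion config.
        return False

    def register_stateful_ingestion_usecase_handler(self, handler: Any) -> None:
        self._handlers.append(handler)

class StatefulIngestionSourceBase:
    def __init__(self) -> None:
        # Handlers reach checkpointing through this attribute,
        # keeping the source's own interface small.
        self.state_provider = StateProviderWrapper()
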
metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py

Lines changed: 7 additions & 4 deletions
@@ -46,14 +46,17 @@ def __init__(
         run_id: str,
     ):
         self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             StatefulRedundantRunSkipConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)
 
     def _ignore_old_state(self) -> bool:
         if (
@@ -114,7 +117,7 @@ def update_state(
     ) -> None:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(BaseUsageCheckpointState, cur_checkpoint.state)
         cur_state.begin_timestamp_millis = start_time_millis
@@ -125,7 +128,7 @@ def should_skip_this_run(self, cur_start_time_millis: int) -> bool:
             return False
         # Determine from the last check point state
         last_successful_pipeline_run_end_time_millis: Optional[int] = None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
             self.job_id, BaseUsageCheckpointState
         )
        if last_checkpoint and last_checkpoint.state:

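From a source's point of view the handler API is unchanged; only its internals now route through the state provider. Typical usage might look like the sketch below (the second update_state argument is inferred from context, not confirmed by this diff):

import time
from typing import Any, Iterable

def run_with_skip_check(source: Any) -> Iterable[Any]:
    """Sketch: how a usage-style source consults RedundantRunSkipHandler."""
    start_millis = int(time.time() * 1000)

    # Skip the run entirely if a previous successful run covered this window.
    if source.redundant_run_skip_handler.should_skip_this_run(start_millis):
        return

    yield from source.get_workunits_internal()

    # Record this run's window for the next run's comparison.
    source.redundant_run_skip_handler.update_state(
        start_millis, int(time.time() * 1000)
    )
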