Skip to content

Commit b2bdd5d

Browse files
committed
fixup! fixup! fixup! fixup! tmp
1 parent 206a3cb commit b2bdd5d

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

airflow-core/src/airflow/assets/manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,13 @@ def _queue_partitioned_dags(
390390
timetable = serdag.dag.timetable
391391
if TYPE_CHECKING:
392392
assert isinstance(timetable, PartitionedAssetTimetable)
393-
target_key = timetable.get_partition_mapper(asset_id=asset_id).to_downstream(partition_key)
393+
394+
if asset_model := session.scalar(select(AssetModel).where(AssetModel.id == asset_id)) is None:
395+
raise ValueError()
396+
397+
target_key = timetable.get_partition_mapper(
398+
name=asset_model.name, uri=asset_model.uri
399+
).to_downstream(partition_key)
394400

395401
apdr = cls._get_or_create_apdr(
396402
target_key=target_key,

airflow-core/src/airflow/timetables/simple.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from contextlib import suppress
1920
from typing import TYPE_CHECKING, Any, TypeAlias
2021

2122
from airflow._shared.timezones import timezone
2223
from airflow.serialization.definitions.assets import SerializedAsset, SerializedAssetAll, SerializedAssetBase
24+
from airflow.serialization.encoders import encode_asset_like, encode_partition_mapper
2325
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
2426

2527
try:
@@ -201,8 +203,6 @@ def summary(self) -> str:
201203
return "Asset"
202204

203205
def serialize(self) -> dict[str, Any]:
204-
from airflow.serialization.encoders import encode_asset_like
205-
206206
return {"asset_condition": encode_asset_like(self.asset_condition)}
207207

208208
def generate_run_id(
@@ -245,17 +245,37 @@ def summary(self) -> str:
245245
def __init__(self, assets: SerializedAssetBase, default_partition_mapper: PartitionMapper) -> None:
246246
super().__init__(assets=assets)
247247
self.default_partition_mapper = default_partition_mapper
248-
# TODO: (AIP-76) implement
249-
self._partition_mappers = None
248+
# TODO: (AIP-76) implement, serialized partition mapper?
249+
self._partition_mappers_by_name: dict = {}
250+
self._partition_mappers_by_uri: dict = {}
251+
self._build_partition_mappers_mapping()
252+
253+
def _build_partition_mappers_mapping(self) -> None:
254+
for _, ser_asset in self.asset_condition.iter_assets():
255+
if partition_mapper := ser_asset.partition_mapper:
256+
self._partition_mappers_by_name[ser_asset.name] = partition_mapper
257+
self._partition_mappers_by_uri[ser_asset.uri] = partition_mapper
258+
else:
259+
self._partition_mappers_by_name[ser_asset.name] = self.default_partition_mapper
260+
self._partition_mappers_by_uri[ser_asset.uri] = self.default_partition_mapper
261+
262+
# TODO: (AIP-76) handle asset alias, asset ref
250263

264+
# TODO: (AIP-76) how could we allow user to customize this timetable?
251265
def serialize(self) -> dict[str, Any]:
252266
from airflow.serialization.serialized_objects import encode_asset_like
253267

254268
return {
255269
"asset_condition": encode_asset_like(self.asset_condition),
256270
"partition_mapper": self.default_partition_mapper.serialize(),
257-
# TODO: (AIP-76) implement
258-
"_partition_mappers": None,
271+
"_partition_mappers_by_name": [
272+
(name, encode_partition_mapper(partition_mapper))
273+
for name, partition_mapper in self._partition_mappers_by_name.items()
274+
],
275+
"_partition_mappers_by_uri": [
276+
(uri, encode_partition_mapper(partition_mapper))
277+
for uri, partition_mapper in self._partition_mappers_by_ui.items()
278+
],
259279
}
260280

261281
@classmethod
@@ -267,10 +287,17 @@ def deserialize(cls, data: dict[str, Any]) -> Timetable:
267287
assets=decode_asset_like(data["asset_condition"]),
268288
default_partition_mapper=decode_partition_mapper(data["partition_mapper"]),
269289
)
270-
# TODO: (AIP-76) implement
271-
timetable._partition_mappers = None
290+
timetable._partition_mappers_by_name = {
291+
name: decode_partition_mapper(ser_partition_mapper)
292+
for name, ser_partition_mapper in data["_partition_mappers_by_name"].items()
293+
}
294+
timetable._partition_mappers_by_uri = {
295+
uri: decode_partition_mapper(ser_partition_mapper)
296+
for uri, ser_partition_mapper in data["_partition_mappers_by_uri"].items()
297+
}
272298
return timetable
273299

274-
# TODO: (AIP-76): get partition_mapper implement
275-
def get_partition_mapper(self, asset_id: id) -> PartitionMapper:
276-
pass
300+
def get_partition_mapper(self, *, name: str, uri: str) -> PartitionMapper:
301+
with suppress(KeyError):
302+
return self._partition_mappers_by_name[name]
303+
return self._partition_mappers_by_uri[uri]

0 commit comments

Comments
 (0)