Skip to content

Commit 37da3e0

Browse files
authored
PIM writeback on HITL approval (#390)
* truth-hitl: support enhanced review payload shapes for UI * Add HITL approval-driven PIM writeback flow * Sort imports for truth HITL/export modules * Format truth HITL/export files for lint
1 parent 9bb33c0 commit 37da3e0

File tree

9 files changed

+441
-27
lines changed

9 files changed

+441
-27
lines changed

apps/truth-export/src/truth_export/event_handlers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,25 @@ async def handle_export_job(partition_context, event) -> None: # noqa: ANN001
3939
return
4040

4141
if _is_hitl_writeback_event(payload, data, protocol):
42-
result = await engine.writeback_entity(
42+
approved_fields_raw = data.get("approved_fields")
43+
approved_fields = (
44+
[str(field) for field in approved_fields_raw if field]
45+
if isinstance(approved_fields_raw, list)
46+
else None
47+
)
48+
49+
result = await engine.writeback_to_pim(
4350
adapters.writeback_manager,
51+
adapters.truth_store,
4452
str(entity_id),
53+
approved_attributes=approved_fields,
4554
dry_run=bool(data.get("dry_run", False)),
4655
)
4756
await adapters.truth_store.save_export_result(result)
4857
audit = engine.build_writeback_audit_event(
4958
entity_id=str(entity_id),
5059
result=result,
51-
trigger="event:export-jobs",
60+
trigger="event:export-jobs:hitl-approved",
5261
)
5362
await adapters.truth_store.save_audit_event(audit.model_dump())
5463
logger.info(

apps/truth-export/src/truth_export/export_engine.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from holiday_peak_lib.adapters.acp_mapper import AcpCatalogMapper
1010
from holiday_peak_lib.adapters.ucp_mapper import UcpProtocolMapper
11-
from holiday_peak_lib.integrations import PIMWritebackManager
11+
from holiday_peak_lib.integrations import PIMWritebackManager, WritebackResult, WritebackStatus
1212

1313
from .schemas_compat import AuditAction, AuditEvent, ExportResult
1414

@@ -118,6 +118,63 @@ async def writeback_entity(
118118
summary["pim_response_summary"] = self._build_pim_response_summary(summary)
119119
return summary
120120

121+
async def writeback_to_pim(
122+
self,
123+
manager: PIMWritebackManager,
124+
truth_store: Any,
125+
entity_id: str,
126+
*,
127+
approved_attributes: list[str] | None = None,
128+
dry_run: bool = False,
129+
) -> dict[str, Any]:
130+
"""Write back approved attributes only, preserving manager safety guards."""
131+
if not approved_attributes:
132+
return await self.writeback_entity(manager, entity_id, dry_run=dry_run)
133+
134+
approved_set = {field for field in approved_attributes if field}
135+
if dry_run:
136+
preview = await self.writeback_entity(manager, entity_id, dry_run=True)
137+
return self._filter_writeback_summary(preview, approved_set)
138+
139+
raw_attributes = await truth_store.get_attributes(entity_id)
140+
selected = [
141+
attr
142+
for attr in raw_attributes
143+
if attr.get("field") in approved_set and attr.get("writeback_eligible", False)
144+
]
145+
146+
if not selected:
147+
skipped_summary: dict[str, Any] = {
148+
"entity_id": entity_id,
149+
"total": 0,
150+
"succeeded": 0,
151+
"skipped": 0,
152+
"conflicts": 0,
153+
"errors": 0,
154+
"timestamp": datetime.now(timezone.utc).isoformat(),
155+
"results": [],
156+
"dry_run": False,
157+
"status": "skipped",
158+
}
159+
skipped_summary["pim_response_summary"] = self._build_pim_response_summary(
160+
skipped_summary
161+
)
162+
return skipped_summary
163+
164+
async def _write_one(attr: dict[str, Any]) -> WritebackResult:
165+
return await manager.writeback_attribute(
166+
entity_id,
167+
field=str(attr.get("field")),
168+
value=attr.get("value"),
169+
truth_version=attr.get("version"),
170+
)
171+
172+
results = list(await asyncio.gather(*[_write_one(attr) for attr in selected]))
173+
summary = self._summarize_field_results(entity_id, results, dry_run=False)
174+
summary["status"] = self._resolve_writeback_status(summary)
175+
summary["pim_response_summary"] = self._build_pim_response_summary(summary)
176+
return summary
177+
121178
async def writeback_batch(
122179
self,
123180
manager: PIMWritebackManager,
@@ -185,3 +242,71 @@ def _build_pim_response_summary(result: dict[str, Any]) -> dict[str, Any]:
185242
"errors": result.get("errors", 0),
186243
"messages": messages,
187244
}
245+
246+
def _summarize_field_results(
247+
self,
248+
entity_id: str,
249+
field_results: list[WritebackResult],
250+
*,
251+
dry_run: bool,
252+
) -> dict[str, Any]:
253+
succeeded = 0
254+
skipped = 0
255+
conflicts = 0
256+
errors = 0
257+
result_items: list[dict[str, Any]] = []
258+
259+
for item in field_results:
260+
status = item.status
261+
if status == WritebackStatus.SUCCESS:
262+
succeeded += 1
263+
elif status in (WritebackStatus.SKIPPED, WritebackStatus.DRY_RUN):
264+
skipped += 1
265+
elif status == WritebackStatus.CONFLICT:
266+
conflicts += 1
267+
else:
268+
errors += 1
269+
result_items.append(item.model_dump())
270+
271+
return {
272+
"entity_id": entity_id,
273+
"total": len(field_results),
274+
"succeeded": succeeded,
275+
"skipped": skipped,
276+
"conflicts": conflicts,
277+
"errors": errors,
278+
"timestamp": datetime.now(timezone.utc).isoformat(),
279+
"results": result_items,
280+
"dry_run": dry_run,
281+
}
282+
283+
def _filter_writeback_summary(
284+
self,
285+
summary: dict[str, Any],
286+
approved_set: set[str],
287+
) -> dict[str, Any]:
288+
filtered_results = [
289+
result for result in summary.get("results", []) if result.get("field") in approved_set
290+
]
291+
filtered = dict(summary)
292+
filtered["results"] = filtered_results
293+
filtered["total"] = len(filtered_results)
294+
filtered["succeeded"] = len(
295+
[r for r in filtered_results if r.get("status") == WritebackStatus.SUCCESS.value]
296+
)
297+
filtered["skipped"] = len(
298+
[
299+
r
300+
for r in filtered_results
301+
if r.get("status") in {WritebackStatus.SKIPPED.value, WritebackStatus.DRY_RUN.value}
302+
]
303+
)
304+
filtered["conflicts"] = len(
305+
[r for r in filtered_results if r.get("status") == WritebackStatus.CONFLICT.value]
306+
)
307+
filtered["errors"] = len(
308+
[r for r in filtered_results if r.get("status") == WritebackStatus.ERROR.value]
309+
)
310+
filtered["status"] = self._resolve_writeback_status(filtered)
311+
filtered["pim_response_summary"] = self._build_pim_response_summary(filtered)
312+
return filtered

apps/truth-export/src/truth_export/routes.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
from typing import Any
67

78
from fastapi import APIRouter, Depends, HTTPException
@@ -39,12 +40,14 @@ class BulkExportRequest(BaseModel):
3940
class PIMWritebackRequest(BaseModel):
4041
dry_run: bool = False
4142
trigger: str = "api"
43+
approved_fields: list[str] | None = None
4244

4345

4446
class PIMBatchWritebackRequest(BaseModel):
4547
entity_ids: list[str] = Field(min_length=1, max_length=100)
4648
dry_run: bool = False
4749
max_concurrency: int = Field(default=5, ge=1, le=20)
50+
approved_fields: list[str] | None = None
4851

4952

5053
# ---------------------------------------------------------------------------
@@ -103,12 +106,29 @@ async def export_pim_batch(
103106
if len(request.entity_ids) > 100:
104107
raise HTTPException(status_code=400, detail="Batch size cannot exceed 100 entities")
105108

106-
results = await engine.writeback_batch(
107-
adapters.writeback_manager,
108-
request.entity_ids,
109-
dry_run=request.dry_run,
110-
max_concurrency=request.max_concurrency,
111-
)
109+
if request.approved_fields:
110+
semaphore = asyncio.Semaphore(max(1, request.max_concurrency))
111+
112+
async def _run_one(entity_id: str) -> dict[str, Any]:
113+
async with semaphore:
114+
return await engine.writeback_to_pim(
115+
adapters.writeback_manager,
116+
adapters.truth_store,
117+
entity_id,
118+
approved_attributes=request.approved_fields,
119+
dry_run=request.dry_run,
120+
)
121+
122+
results = list(
123+
await asyncio.gather(*[_run_one(entity_id) for entity_id in request.entity_ids])
124+
)
125+
else:
126+
results = await engine.writeback_batch(
127+
adapters.writeback_manager,
128+
request.entity_ids,
129+
dry_run=request.dry_run,
130+
max_concurrency=request.max_concurrency,
131+
)
112132

113133
for entity_result in results:
114134
await adapters.truth_store.save_export_result(entity_result)
@@ -140,6 +160,14 @@ async def export_pim_single(
140160
entity_id,
141161
dry_run=request.dry_run,
142162
)
163+
if request.approved_fields:
164+
result = await engine.writeback_to_pim(
165+
adapters.writeback_manager,
166+
adapters.truth_store,
167+
entity_id,
168+
approved_attributes=request.approved_fields,
169+
dry_run=request.dry_run,
170+
)
143171
await adapters.truth_store.save_export_result(result)
144172
audit_event = engine.build_writeback_audit_event(
145173
entity_id=entity_id,

apps/truth-export/tests/test_event_handlers.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
)
1414
from truth_export.adapters import build_truth_export_adapters
1515
from truth_export.event_handlers import build_event_handlers
16-
from truth_export.schemas_compat import ProductStyle
16+
from truth_export.schemas_compat import ProductStyle, TruthAttribute
1717

1818

1919
class _StubWritebackManager:
20+
def __init__(self) -> None:
21+
self.called_fields: list[str] = []
22+
2023
async def dry_run(self, entity_id: str) -> ProductWritebackResult:
2124
return await self.writeback_product(entity_id)
2225

@@ -35,6 +38,17 @@ async def writeback_product(self, entity_id: str) -> ProductWritebackResult:
3538
],
3639
)
3740

41+
async def writeback_attribute(self, entity_id: str, field: str, value, *, truth_version=None):
42+
_ = value
43+
_ = truth_version
44+
self.called_fields.append(field)
45+
return WritebackResult(
46+
entity_id=entity_id,
47+
field=field,
48+
status=WritebackStatus.SUCCESS,
49+
message="Writeback succeeded",
50+
)
51+
3852

3953
@pytest.mark.asyncio
4054
async def test_build_event_handlers_includes_export_jobs() -> None:
@@ -45,7 +59,27 @@ async def test_build_event_handlers_includes_export_jobs() -> None:
4559
@pytest.mark.asyncio
4660
async def test_export_jobs_hitl_approval_triggers_writeback_path() -> None:
4761
adapters = build_truth_export_adapters()
48-
adapters.writeback_manager = _StubWritebackManager()
62+
stub_manager = _StubWritebackManager()
63+
adapters.writeback_manager = stub_manager
64+
adapters.truth_store.seed_attributes(
65+
"STYLE-001",
66+
[
67+
TruthAttribute(
68+
entityType="style",
69+
entityId="STYLE-001",
70+
attributeKey="title",
71+
value="Approved title",
72+
source="SYSTEM",
73+
),
74+
TruthAttribute(
75+
entityType="style",
76+
entityId="STYLE-001",
77+
attributeKey="description",
78+
value="Should not be written",
79+
source="SYSTEM",
80+
),
81+
],
82+
)
4983

5084
handlers = build_event_handlers(adapters=adapters)
5185
handler = handlers["export-jobs"]
@@ -59,6 +93,7 @@ async def test_export_jobs_hitl_approval_triggers_writeback_path() -> None:
5993
"entity_id": "STYLE-001",
6094
"protocol": "pim",
6195
"status": "approved",
96+
"approved_fields": ["title"],
6297
},
6398
}
6499
)
@@ -71,6 +106,7 @@ async def test_export_jobs_hitl_approval_triggers_writeback_path() -> None:
71106
adapters.truth_store._audit_events[-1]["details"]["writeback_status"]
72107
== "completed" # pylint: disable=protected-access
73108
)
109+
assert stub_manager.called_fields == ["title"]
74110

75111

76112
@pytest.mark.asyncio

apps/truth-export/tests/test_export.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,3 +339,59 @@ async def get_attributes(self, _entity_id: str):
339339
assert result["status"] == "conflict"
340340
assert result["conflicts"] == 1
341341
pim.push_enrichment.assert_not_called()
342+
343+
344+
@pytest.mark.asyncio
345+
async def test_writeback_to_pim_uses_approved_fields_only(engine, adapters):
346+
adapters.truth_store.seed_attributes(
347+
"STYLE-001",
348+
[
349+
TruthAttribute(
350+
entityType="style",
351+
entityId="STYLE-001",
352+
attributeKey="title",
353+
value="Approved title",
354+
source="SYSTEM",
355+
),
356+
TruthAttribute(
357+
entityType="style",
358+
entityId="STYLE-001",
359+
attributeKey="description",
360+
value="Should not be written",
361+
source="SYSTEM",
362+
),
363+
],
364+
)
365+
366+
class _StubManager:
367+
def __init__(self):
368+
self.called_fields: list[str] = []
369+
370+
async def writeback_attribute(
371+
self, entity_id: str, field: str, value, *, truth_version=None
372+
):
373+
_ = entity_id
374+
_ = value
375+
_ = truth_version
376+
self.called_fields.append(field)
377+
378+
from holiday_peak_lib.integrations import WritebackResult, WritebackStatus
379+
380+
return WritebackResult(
381+
entity_id="STYLE-001",
382+
field=field,
383+
status=WritebackStatus.SUCCESS,
384+
message="Writeback succeeded",
385+
)
386+
387+
manager = _StubManager()
388+
result = await engine.writeback_to_pim(
389+
manager,
390+
adapters.truth_store,
391+
"STYLE-001",
392+
approved_attributes=["title"],
393+
)
394+
395+
assert manager.called_fields == ["title"]
396+
assert result["status"] == "completed"
397+
assert result["total"] == 1

0 commit comments

Comments
 (0)