Skip to content

Commit 494263f

Browse files
authored
fix: Update validation logic for Orbax DAGs (#1197)
This change updates the logs query pattern to use `textPayload` instead of `jsonPayload`, resolving validation failures.
1 parent 366c45b commit 494263f

File tree

5 files changed

+16
-21
lines changed

5 files changed

+16
-21
lines changed

dags/orbax/maxtext_emc_restore_gcs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,8 @@
150150
)
151151

152152
log_filters = [
153-
"jsonPayload.message:\"'event_type': 'emergency_restore'\"",
154-
"jsonPayload.message:\"'is_restoring_slice': True\"",
155-
"jsonPayload.message:\"'directory': 'gs://\"",
153+
"textPayload:\"'event_type': 'restore'\"",
154+
"textPayload:\"'directory': 'gs://\"",
156155
]
157156
validate_restored_source = validation_util.validate_log_exist.override(
158157
task_id="validate_emc_restored_from_gcs"

dags/orbax/maxtext_emc_restore_local.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,8 @@
146146
)
147147

148148
log_filters = [
149-
"jsonPayload.message:\"'event_type': 'emergency_restore'\"",
150-
"jsonPayload.message:\"'is_restoring_slice': True\"",
151-
"jsonPayload.message:\"'directory': '/local/\"",
149+
"textPayload:\"'event_type': 'restore'\"",
150+
"textPayload:\"'directory': '/local/\"",
152151
]
153152
validate_restored_source = validation_util.validate_log_exist.override(
154153
task_id="validate_emc_restored_from_local"

dags/orbax/maxtext_emc_resume_gcs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,8 @@
213213
)
214214

215215
log_filters = [
216-
"jsonPayload.message:\"'event_type': 'emergency_restore'\"",
217-
"jsonPayload.message:\"'is_restoring_slice': True\"",
218-
"jsonPayload.message:\"'directory': 'gs://\"",
216+
"textPayload:\"'event_type': 'restore'\"",
217+
"textPayload:\"'directory': 'gs://\"",
219218
]
220219
validate_restore_source = validation_util.validate_log_exist.override(
221220
task_id="validate_emc_restored_from_gcs"

dags/orbax/maxtext_mtc_restore_local.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@
149149
)
150150

151151
log_filters = [
152-
"jsonPayload.message:\"'event_type': 'restore'\"",
153-
"jsonPayload.message:\"'directory': '/local\"",
152+
"textPayload:\"'event_type': 'restore'\"",
153+
"textPayload:\"'directory': '/local\"",
154154
]
155155
validate_restored_source = validation_util.validate_log_exist.override(
156156
task_id="validate_restore_copy_from_peer"

dags/orbax/util/validation_util.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,19 @@ def validate_checkpoint_at_steps_are_saved(
6868
location=location,
6969
cluster_name=cluster_name,
7070
pod_pattern=pod_pattern,
71-
text_filter=f'jsonPayload.message=~"{log_pattern}"',
71+
text_filter=f'textPayload=~"{log_pattern}"',
7272
start_time=start_time,
7373
end_time=end_time,
7474
)
7575

7676
steps_are_saved: set[int] = set() # Use a set for faster lookup.
7777
for entry in entries:
78-
if not isinstance(entry, logging_api.StructEntry):
78+
if not isinstance(entry, logging_api.TextEntry):
7979
raise AirflowFailException(
80-
"Log entry must be contain a jsonPayload attribute."
80+
"Log entry must be contain a textPayload attribute."
8181
)
82-
message = entry.payload.get("message")
83-
if not message:
84-
raise AirflowFailException(f"Failed to parse entry {entry}")
8582

83+
message = entry.payload
8684
m = complied_pattern.search(message)
8785
if m:
8886
steps_are_saved.add(int(m.group(1)))
@@ -468,7 +466,7 @@ def validate_restored_correct_checkpoint(
468466
cluster_name=cluster_name,
469467
namespace="default",
470468
pod_pattern=pod_pattern,
471-
text_filter="jsonPayload.message:\"'event_type'\"",
469+
text_filter="textPayload:\"'event_type'\"",
472470
start_time=start_time,
473471
end_time=end_time,
474472
)
@@ -478,12 +476,12 @@ def validate_restored_correct_checkpoint(
478476

479477
local_saved_steps_before_restore = []
480478
for entry in entries:
481-
if not isinstance(entry, logging_api.StructEntry):
479+
if not isinstance(entry, logging_api.TextEntry):
482480
raise AirflowFailException(
483-
"Log entry must be contain a jsonPayload attribute."
481+
"Log entry must be contain a textPayload attribute."
484482
)
485483

486-
message = entry.payload.get("message")
484+
message = entry.payload
487485

488486
if re.search(r"'event_type': 'save'", message):
489487
saved_step_match = re.search(r"'step': (\d+)", message)

0 commit comments

Comments
 (0)