Skip to content

Commit 9dccc6f

Browse files
committed
Updates for adx ingestion speed and on error rollback
1 parent e632e97 commit 9dccc6f

File tree

10 files changed

+108
-445
lines changed

10 files changed

+108
-445
lines changed

cosmotech/coal/azure/adx/ingestion.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def check_ingestion_status(
135135
client: QueuedIngestClient,
136136
source_ids: List[str],
137137
timeout: Optional[int] = None,
138-
logs: bool = False,
139138
) -> Iterator[Tuple[str, IngestionStatus]]:
140139
"""
141140
Check the status of ingestion operations.
@@ -144,7 +143,6 @@ def check_ingestion_status(
144143
client: The QueuedIngestClient to use
145144
source_ids: List of source IDs to check
146145
timeout: Timeout in seconds (default: 900)
147-
logs: Whether to log detailed information
148146
149147
Returns:
150148
Iterator of (source_id, status) tuples
@@ -185,7 +183,7 @@ def get_messages(queues):
185183

186184
LOGGER.debug(T("coal.logs.adx.status_messages").format(success=len(successes), failure=len(failures)))
187185

188-
non_sent_ids = remaining_ids[:]
186+
queued_ids = list(remaining_ids)
189187
# Process success and failure messages
190188
for messages, cast_func, status, log_function in [
191189
(successes, SuccessMessage, IngestionStatus.SUCCESS, LOGGER.debug),
@@ -207,11 +205,9 @@ def get_messages(queues):
207205
else:
208206
# The message did not correspond to a known ID
209207
continue
210-
break
211208
else:
212209
# No message was found on the current list of messages for the given IDs
213210
continue
214-
break
215211

216212
# Check for timeouts
217213
actual_timeout = timeout if timeout is not None else default_timeout
@@ -221,7 +217,7 @@ def get_messages(queues):
221217
LOGGER.warning(T("coal.logs.adx.ingestion_timeout").format(source_id=source_id))
222218

223219
# Yield results for remaining IDs
224-
for source_id in non_sent_ids:
220+
for source_id in queued_ids:
225221
yield source_id, _ingest_status[source_id]
226222

227223

cosmotech/coal/azure/adx/store.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def send_pyarrow_table_to_adx(
4040
data_format=DataFormat.CSV,
4141
drop_by_tags=drop_by_tags,
4242
report_level=ReportLevel.FailuresAndSuccesses,
43+
flush_immediately=True,
4344
)
4445

4546
file_name = f"adx_{database}_{table_name}_{int(time.time())}_{uuid.uuid4()}.csv"

cosmotech/coal/azure/adx/wrapper.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,19 @@ def ingest_dataframe(self, table_name: str, dataframe: Any, drop_by_tag: str = N
112112
return ingest_dataframe(self.ingest_client, self.database, table_name, dataframe, drop_by_tag)
113113

114114
def check_ingestion_status(
115-
self, source_ids: List[str], timeout: int = None, logs: bool = False
115+
self, source_ids: List[str], timeout: int = None
116116
) -> Iterator[Tuple[str, IngestionStatus]]:
117117
"""
118118
Check the status of ingestion operations.
119119
120120
Args:
121121
source_ids: List of source IDs to check
122122
timeout: Timeout in seconds (default: self.timeout)
123-
logs: Whether to log detailed information
124123
125124
Returns:
126125
Iterator of (source_id, status) tuples
127126
"""
128-
return check_ingestion_status(self.ingest_client, source_ids, timeout or self.timeout, logs)
127+
return check_ingestion_status(self.ingest_client, source_ids, timeout or self.timeout)
129128

130129
def _clear_ingestion_status_queues(self, confirmation: bool = False):
131130
"""

cosmotech/coal/csm/engine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ def apply_simple_csv_parameter_to_simulator(
4444
raise ValueError(f"Parameter {parameter_name} does not exists.")
4545

4646

47-
__all__ = [apply_simple_csv_parameter_to_simulator]
47+
__all__ = ["apply_simple_csv_parameter_to_simulator"]

cosmotech/coal/postgresql/runner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def send_runner_metadata_to_postgresql(
5454

5555
# Generate PostgreSQL URI
5656
postgresql_full_uri = generate_postgresql_full_uri(
57-
postgres_host, postgres_port, postgres_db, postgres_user, postgres_password
57+
postgres_host, str(postgres_port), postgres_db, postgres_user, postgres_password
5858
)
5959

6060
# Connect to PostgreSQL and update runner metadata
@@ -76,10 +76,10 @@ def send_runner_metadata_to_postgresql(
7676
DO
7777
UPDATE SET name = EXCLUDED.name, last_run_id = EXCLUDED.last_run_id;
7878
"""
79-
LOGGER.info(f"creating table {schema_table}")
79+
LOGGER.info(T("coal.logs.postgreql.runner.creating_table").format(schema_table=schema_table))
8080
curs.execute(sql_create_table)
8181
conn.commit()
82-
LOGGER.info(f"adding/updating runner metadata")
82+
LOGGER.info(T("coal.logs.postgreql.runner.metadata"))
8383
curs.execute(
8484
sql_upsert,
8585
(
@@ -90,4 +90,4 @@ def send_runner_metadata_to_postgresql(
9090
),
9191
)
9292
conn.commit()
93-
LOGGER.info("Runner metadata table has been updated")
93+
LOGGER.info(T("coal.logs.postgreql.runner.metadata_updated"))

cosmotech/csm_data/commands/adx_send_data.py

Lines changed: 94 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,19 @@
4545
show_default=True,
4646
help="Wait for ingestion to complete",
4747
)
48+
@click.option(
49+
"--tag",
50+
envvar="CSM_DATA_ADX_TAG",
51+
show_envvar=True,
52+
default=None,
53+
help="Optional tag to use for tracking and potential rollback of this ingestion operation",
54+
)
4855
def adx_send_data(
4956
adx_uri: str,
5057
adx_ingest_uri: str,
5158
database_name: str,
5259
wait: bool,
60+
tag: str = None,
5361
):
5462
# Import the function at the start of the command
5563
from cosmotech.coal.azure.adx.auth import create_ingest_client, create_kusto_client
@@ -61,8 +69,13 @@ def adx_send_data(
6169
from cosmotech.coal.azure.adx import type_mapping
6270

6371
import time
72+
import uuid
6473
from cosmotech.coal.azure.adx import IngestionStatus
6574

75+
# Generate operation tag if not provided
76+
operation_tag = tag or f"op-{str(uuid.uuid4())}"
77+
LOGGER.debug(f"Starting ingestion operation with tag: {operation_tag}")
78+
6679
LOGGER.debug("Initializing clients")
6780
kusto_client = create_kusto_client(adx_uri)
6881
ingest_client = create_ingest_client(adx_ingest_uri)
@@ -79,7 +92,7 @@ def adx_send_data(
7992
data = s.get_table(target_table_name)
8093

8194
if data.num_rows < 1:
82-
LOGGER.warn(f"Table {target_table_name} has no rows - skipping it")
95+
LOGGER.warning(f"Table {target_table_name} has no rows - skipping it")
8396
continue
8497

8598
LOGGER.debug(" - Checking if table exists")
@@ -99,38 +112,91 @@ def adx_send_data(
99112
create_table(kusto_client, database, target_table_name, mapping)
100113

101114
LOGGER.debug(f"Sending data to the table {target_table_name}")
102-
result = send_pyarrow_table_to_adx(ingest_client, database, target_table_name, data, None)
115+
# Use the operation_tag as the drop_by_tag parameter
116+
result = send_pyarrow_table_to_adx(ingest_client, database, target_table_name, data, operation_tag)
103117
source_ids.append(result.source_id)
104118
table_ingestion_id_mapping[result.source_id] = target_table_name
105119

120+
# Track if any failures occur
121+
has_failures = False
122+
106123
LOGGER.info("Store data was sent for ADX ingestion")
107-
if wait:
108-
LOGGER.info("Waiting for ingestion of data to finish")
109-
import tqdm
110-
111-
with tqdm.tqdm(desc="Ingestion status", total=len(source_ids)) as pbar:
112-
while any(
113-
map(
114-
lambda _status: _status[1] in (IngestionStatus.QUEUED, IngestionStatus.UNKNOWN),
115-
results := list(check_ingestion_status(ingest_client, source_ids)),
116-
)
117-
):
118-
cleared_ids = list(
119-
result for result in results if result[1] not in (IngestionStatus.QUEUED, IngestionStatus.UNKNOWN)
120-
)
121-
122-
for ingestion_id, ingestion_status in cleared_ids:
123-
pbar.update(1)
124-
source_ids.remove(ingestion_id)
125-
126-
if os.environ.get("CSM_USE_RICH", "False").lower() in ("true", "1", "yes", "t", "y"):
127-
for _ in range(10):
128-
time.sleep(1)
129-
pbar.update(0)
124+
try:
125+
if wait:
126+
LOGGER.info("Waiting for ingestion of data to finish")
127+
import tqdm
128+
129+
with tqdm.tqdm(desc="Ingestion status", total=len(source_ids)) as pbar:
130+
while any(
131+
list(
132+
map(
133+
lambda _status: _status[1] in (IngestionStatus.QUEUED, IngestionStatus.UNKNOWN),
134+
results := list(check_ingestion_status(ingest_client, source_ids)),
135+
)
136+
)
137+
):
138+
# Check for failures
139+
for ingestion_id, ingestion_status in results:
140+
if ingestion_status == IngestionStatus.FAILURE:
141+
LOGGER.error(
142+
f"Ingestion {ingestion_id} failed for table {table_ingestion_id_mapping.get(ingestion_id)}"
143+
)
144+
has_failures = True
145+
146+
cleared_ids = list(
147+
result
148+
for result in results
149+
if result[1] not in (IngestionStatus.QUEUED, IngestionStatus.UNKNOWN)
150+
)
151+
152+
for ingestion_id, ingestion_status in cleared_ids:
153+
pbar.update(1)
154+
source_ids.remove(ingestion_id)
155+
156+
time.sleep(1)
157+
if os.environ.get("CSM_USE_RICH", "False").lower() in ("true", "1", "yes", "t", "y"):
158+
pbar.refresh()
130159
else:
131-
time.sleep(10)
132-
pbar.update(len(source_ids))
133-
LOGGER.info("All data got ingested")
160+
for ingestion_id, ingestion_status in results:
161+
if ingestion_status == IngestionStatus.FAILURE:
162+
LOGGER.error(
163+
f"Ingestion {ingestion_id} failed for table {table_ingestion_id_mapping.get(ingestion_id)}"
164+
)
165+
has_failures = True
166+
pbar.update(len(source_ids))
167+
LOGGER.info("All data ingestion attempts completed")
168+
has_failures = True
169+
170+
# If any ingestion failed, perform rollback
171+
if has_failures:
172+
LOGGER.warning(f"Failures detected during ingestion - dropping data with tag: {operation_tag}")
173+
_drop_by_tag(kusto_client, database, operation_tag)
174+
175+
except Exception as e:
176+
LOGGER.exception("Error during ingestion process")
177+
# Perform rollback using the tag
178+
LOGGER.warning(f"Dropping data with tag: {operation_tag}")
179+
_drop_by_tag(kusto_client, database, operation_tag)
180+
raise e
181+
182+
if has_failures:
183+
click.Abort()
184+
185+
186+
def _drop_by_tag(kusto_client, database, tag):
187+
"""
188+
Drop all data with the specified tag
189+
"""
190+
LOGGER.info(f"Dropping data with tag: {tag}")
191+
192+
try:
193+
# Execute the drop by tag command
194+
drop_command = f'.drop extents <| .show database extents where tags has "drop-by:{tag}"'
195+
kusto_client.execute_mgmt(database, drop_command)
196+
LOGGER.info("Drop by tag operation completed")
197+
except Exception as e:
198+
LOGGER.error(f"Error during drop by tag operation: {str(e)}")
199+
LOGGER.exception("Drop by tag details")
134200

135201

136202
if __name__ == "__main__":

tests/unit/coal/test_azure/test_adx/test_adx_ingestion.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -357,31 +357,6 @@ def test_check_ingestion_status_with_timeout(
357357
assert len(result) == 1
358358
assert result[0] == (source_id, IngestionStatus.TIMEOUT)
359359

360-
@patch("cosmotech.coal.azure.adx.ingestion.KustoIngestStatusQueues")
361-
def test_check_ingestion_status_with_logs(self, mock_status_queues_class, mock_ingest_client, mock_status_queues):
362-
"""Test the check_ingestion_status function with logs enabled."""
363-
# Arrange
364-
source_id = "source-id-logs"
365-
_ingest_status[source_id] = IngestionStatus.QUEUED
366-
_ingest_times[source_id] = time.time()
367-
368-
# Set up mock status queues with empty queues
369-
mock_status_queues_class.return_value = mock_status_queues
370-
mock_success_queue = MagicMock()
371-
mock_success_queue.receive_messages.return_value = []
372-
mock_status_queues.success._get_queues.return_value = [mock_success_queue]
373-
mock_failure_queue = MagicMock()
374-
mock_failure_queue.receive_messages.return_value = []
375-
mock_status_queues.failure._get_queues.return_value = [mock_failure_queue]
376-
377-
# Act
378-
result = list(check_ingestion_status(mock_ingest_client, [source_id], logs=True))
379-
380-
# Assert
381-
assert len(result) == 1
382-
# The status should still be QUEUED since no messages were found and no timeout occurred
383-
assert result[0] == (source_id, IngestionStatus.QUEUED)
384-
385360
@patch("cosmotech.coal.azure.adx.ingestion.KustoIngestStatusQueues")
386361
def test_check_ingestion_status_unknown_id(self, mock_status_queues_class, mock_ingest_client, mock_status_queues):
387362
"""Test the check_ingestion_status function with an unknown source ID."""

0 commit comments

Comments
 (0)