Skip to content

Commit f5eb759

Browse files
committed
refactor(tests): improve command execution and output handling in DagsterTestContext
- Refactored `_run_command` method to utilize a context manager for safe directory changes. - Introduced `_create_subprocess` and `_stream_process_output` methods for better subprocess management and output streaming. - Enhanced `_start_output_threads` method to streamline output capture from subprocesses. - Updated test case to modify the SQL model file directly, improving clarity and functionality. - Removed commented-out code to clean up the test file.
1 parent 9e4d084 commit f5eb759

File tree

2 files changed

+193
-150
lines changed

2 files changed

+193
-150
lines changed

dagster_sqlmesh/conftest.py

Lines changed: 167 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import tempfile
1111
import threading
1212
import typing as t
13+
from contextlib import contextmanager
1314
from dataclasses import dataclass, field
1415

1516
import duckdb
@@ -544,125 +545,6 @@ class DagsterTestContext:
544545
dagster_project_path: str
545546
sqlmesh_project_path: str
546547

547-
def _stream_output(
548-
self,
549-
pipe: t.IO[str],
550-
output_queue: queue.Queue[tuple[str, str | None]],
551-
process_stdout: t.IO[str],
552-
) -> None:
553-
"""Stream output from a pipe to a queue.
554-
555-
Args:
556-
pipe: The pipe to read from (stdout or stderr)
557-
output_queue: Queue to write output to, as (stream_type, line) tuples
558-
process_stdout: The stdout pipe from the process, used to determine stream type
559-
"""
560-
# Use a StringIO buffer to accumulate characters into lines
561-
buffer = io.StringIO()
562-
stream_type = "stdout" if pipe is process_stdout else "stderr"
563-
564-
try:
565-
while True:
566-
char = pipe.read(1)
567-
if not char:
568-
# Flush any remaining content in buffer
569-
remaining = buffer.getvalue()
570-
if remaining:
571-
output_queue.put((stream_type, remaining))
572-
break
573-
574-
buffer.write(char)
575-
576-
# If we hit a newline, flush the buffer
577-
if char == "\n":
578-
output_queue.put((stream_type, buffer.getvalue()))
579-
buffer = io.StringIO()
580-
finally:
581-
buffer.close()
582-
output_queue.put((stream_type, None)) # Signal EOF
583-
584-
def _run_command(self, cmd: list[str], cwd: str | None = None) -> None:
585-
"""Execute a command and stream its output in real-time.
586-
587-
Args:
588-
cmd: List of command parts to execute
589-
cwd: Optional directory to change to before running the command.
590-
591-
Raises:
592-
subprocess.CalledProcessError: If the command returns non-zero exit code
593-
"""
594-
original_cwd = os.getcwd()
595-
596-
print(f"Running command: {' '.join(cmd)}")
597-
print(f"Original working directory: {original_cwd}")
598-
599-
process = None
600-
try:
601-
if cwd:
602-
print(f"Changing to directory: {cwd}")
603-
os.chdir(cwd)
604-
else:
605-
print(f"Running in current directory: {original_cwd}")
606-
607-
process = subprocess.Popen(
608-
cmd,
609-
stdout=subprocess.PIPE,
610-
stderr=subprocess.PIPE,
611-
text=True,
612-
universal_newlines=True,
613-
encoding="utf-8",
614-
errors="replace",
615-
)
616-
617-
if not process.stdout or not process.stderr:
618-
raise RuntimeError("Failed to open subprocess pipes")
619-
620-
# Create a single queue for all output
621-
output_queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
622-
623-
# Start threads to read from pipes
624-
stdout_thread = threading.Thread(
625-
target=self._stream_output,
626-
args=(process.stdout, output_queue, process.stdout),
627-
)
628-
stderr_thread = threading.Thread(
629-
target=self._stream_output,
630-
args=(process.stderr, output_queue, process.stdout),
631-
)
632-
633-
stdout_thread.daemon = True
634-
stderr_thread.daemon = True
635-
stdout_thread.start()
636-
stderr_thread.start()
637-
638-
# Track which streams are still active
639-
active_streams = {"stdout", "stderr"}
640-
641-
# Read from queue and print output
642-
while active_streams:
643-
try:
644-
stream_type, content = output_queue.get(timeout=0.1)
645-
if content is None:
646-
active_streams.remove(stream_type)
647-
else:
648-
print(content, end="", flush=True)
649-
except queue.Empty:
650-
continue
651-
652-
stdout_thread.join()
653-
stderr_thread.join()
654-
process.wait()
655-
656-
if process.returncode != 0:
657-
raise subprocess.CalledProcessError(process.returncode, cmd)
658-
finally:
659-
# Ensure we change back to the original directory
660-
if os.getcwd() != original_cwd:
661-
print(f"Changing back to original directory: {original_cwd}")
662-
os.chdir(original_cwd)
663-
else:
664-
print(f"Remained in original directory: {original_cwd}")
665-
666548
def asset_materialisation(
667549
self,
668550
assets: list[str],
@@ -717,9 +599,174 @@ def asset_materialisation(
717599
config_json,
718600
]
719601

720-
# Change to the sqlmesh project directory before running the command (for some reason asset materialize needs to be run from the dirctory you want the db.db file to be in - feel free to investigate)
721602
self._run_command(cmd=cmd, cwd=self.sqlmesh_project_path)
722603

604+
def _run_command(self, cmd: list[str], cwd: str | None = None) -> None:
605+
"""Execute a command and stream its output in real-time.
606+
607+
Args:
608+
cmd: List of command parts to execute
609+
cwd: Optional directory to change to before running the command.
610+
611+
Raises:
612+
subprocess.CalledProcessError: If the command returns non-zero exit code
613+
RuntimeError: If subprocess pipes cannot be opened
614+
"""
615+
with self._manage_working_directory(cwd):
616+
process = self._create_subprocess(cmd)
617+
self._stream_process_output(process, cmd)
618+
619+
def _manage_working_directory(
620+
self, cwd: str | None = None
621+
) -> t.ContextManager[None]:
622+
"""Context manager to handle directory changes safely.
623+
624+
Args:
625+
cwd: Optional directory to change to before running the command.
626+
"""
627+
628+
@contextmanager
629+
def _directory_context():
630+
original_cwd = os.getcwd()
631+
try:
632+
if cwd:
633+
print(f"Changing to directory: {cwd}")
634+
os.chdir(cwd)
635+
else:
636+
print(f"Running in current directory: {original_cwd}")
637+
yield
638+
finally:
639+
if os.getcwd() != original_cwd:
640+
print(f"Changing back to original directory: {original_cwd}")
641+
os.chdir(original_cwd)
642+
643+
return _directory_context()
644+
645+
def _create_subprocess(self, cmd: list[str]) -> "subprocess.Popen[str]":
646+
"""Create and return a subprocess with proper pipe configuration.
647+
648+
Args:
649+
cmd: List of command parts to execute
650+
651+
Returns:
652+
subprocess.Popen: The created subprocess with stdout/stderr pipes
653+
654+
Raises:
655+
RuntimeError: If subprocess pipes cannot be opened
656+
"""
657+
print(f"Running command: {' '.join(cmd)}")
658+
process = subprocess.Popen(
659+
cmd,
660+
stdout=subprocess.PIPE,
661+
stderr=subprocess.PIPE,
662+
text=True,
663+
universal_newlines=True,
664+
encoding="utf-8",
665+
errors="replace",
666+
)
667+
if not process.stdout or not process.stderr:
668+
raise RuntimeError("Failed to open subprocess pipes")
669+
return process
670+
671+
def _stream_process_output(
672+
self, process: "subprocess.Popen[str]", cmd: list[str]
673+
) -> None:
674+
"""Handle the streaming of process output from both stdout and stderr.
675+
676+
Args:
677+
process: The subprocess whose output to stream
678+
cmd: The original command (for error reporting)
679+
680+
Raises:
681+
subprocess.CalledProcessError: If the process returns non-zero exit code
682+
"""
683+
output_queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
684+
685+
# Start output capture threads
686+
threads = self._start_output_threads(process, output_queue)
687+
688+
# Process output until both streams are done
689+
active_streams = {"stdout", "stderr"}
690+
while active_streams:
691+
try:
692+
stream_type, content = output_queue.get(timeout=0.1)
693+
if content is None:
694+
active_streams.remove(stream_type)
695+
else:
696+
print(content, end="", flush=True)
697+
except queue.Empty:
698+
continue
699+
700+
# Wait for completion
701+
for thread in threads:
702+
thread.join()
703+
process.wait()
704+
705+
if process.returncode != 0:
706+
raise subprocess.CalledProcessError(process.returncode, cmd)
707+
708+
def _start_output_threads(
709+
self,
710+
process: "subprocess.Popen[str]",
711+
output_queue: queue.Queue[tuple[str, str | None]],
712+
) -> list[threading.Thread]:
713+
"""Start and return the stdout/stderr capture threads.
714+
715+
Args:
716+
process: The subprocess whose output to capture
717+
output_queue: Queue to write captured output to
718+
719+
Returns:
720+
list[threading.Thread]: List of started capture threads
721+
"""
722+
threads = []
723+
for pipe in [process.stdout, process.stderr]:
724+
thread = threading.Thread(
725+
target=self._stream_output,
726+
args=(pipe, output_queue, process.stdout),
727+
)
728+
thread.daemon = True
729+
thread.start()
730+
threads.append(thread)
731+
return threads
732+
733+
def _stream_output(
734+
self,
735+
pipe: t.IO[str],
736+
output_queue: queue.Queue[tuple[str, str | None]],
737+
process_stdout: t.IO[str],
738+
) -> None:
739+
"""Stream output from a pipe to a queue.
740+
741+
Args:
742+
pipe: The pipe to read from (stdout or stderr)
743+
output_queue: Queue to write output to, as (stream_type, line) tuples
744+
process_stdout: The stdout pipe from the process, used to determine stream type
745+
"""
746+
# Use a StringIO buffer to accumulate characters into lines
747+
buffer = io.StringIO()
748+
stream_type = "stdout" if pipe is process_stdout else "stderr"
749+
750+
try:
751+
while True:
752+
char = pipe.read(1)
753+
if not char:
754+
# Flush any remaining content in buffer
755+
remaining = buffer.getvalue()
756+
if remaining:
757+
output_queue.put((stream_type, remaining))
758+
break
759+
760+
buffer.write(char)
761+
762+
# If we hit a newline, flush the buffer
763+
if char == "\n":
764+
output_queue.put((stream_type, buffer.getvalue()))
765+
buffer = io.StringIO()
766+
finally:
767+
buffer.close()
768+
output_queue.put((stream_type, None)) # Signal EOF
769+
723770
def reset_assets(self) -> None:
724771
"""Resets the assets to the original state"""
725772
self.asset_materialisation(assets=["reset_asset"])

dagster_sqlmesh/controller/tests_plan_and_run/test_model_code_change.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import time
32

43
import pytest
54

@@ -196,7 +195,6 @@ def test_given_model_chain_when_running_with_different_flags_then_behaves_as_exp
196195

197196
sample_dagster_test_context.asset_materialisation(
198197
assets=[
199-
"test_source",
200198
"seed_model_1",
201199
"seed_model_2",
202200
"staging_model_1",
@@ -210,35 +208,33 @@ def test_given_model_chain_when_running_with_different_flags_then_behaves_as_exp
210208
)
211209

212210
# # # Modify intermediate_model_1 sql to cause breaking change
213-
# sample_sqlmesh_test_context.modify_model_file(
214-
# "intermediate_model_1.sql",
215-
# """
216-
# MODEL (
217-
# name sqlmesh_example.intermediate_model_1,
218-
# kind INCREMENTAL_BY_TIME_RANGE (
219-
# time_column event_date
220-
# ),
221-
# start '2020-01-01',
222-
# cron '@daily',
223-
# grain (id, event_date)
224-
# );
225-
226-
# SELECT
227-
# main.id,
228-
# main.item_id,
229-
# main.event_date,
230-
# CONCAT('item - ', main.item_id) as item_name
231-
# FROM sqlmesh_example.staging_model_1 AS main
232-
# INNER JOIN sqlmesh_example.staging_model_2 as sub
233-
# ON main.id = sub.id
234-
# WHERE
235-
# event_date BETWEEN @start_date AND @end_date
236-
# """,
237-
# )
238-
239-
# sample_dagster_test_context.asset_materialisation(assets=["intermediate_model_1"], plan_options=PlanOptions(skip_backfill=True, enable_preview=True, skip_tests=True))
211+
sample_sqlmesh_test_context.modify_model_file(
212+
"intermediate_model_1.sql",
213+
"""
214+
MODEL (
215+
name sqlmesh_example.intermediate_model_1,
216+
kind INCREMENTAL_BY_TIME_RANGE (
217+
time_column event_date
218+
),
219+
start '2020-01-01',
220+
cron '@daily',
221+
grain (id, event_date)
222+
);
223+
224+
SELECT
225+
main.id,
226+
main.item_id,
227+
main.event_date,
228+
CONCAT('item - ', main.item_id) as item_name
229+
FROM sqlmesh_example.staging_model_1 AS main
230+
INNER JOIN sqlmesh_example.staging_model_2 as sub
231+
ON main.id = sub.id
232+
WHERE
233+
event_date BETWEEN @start_date AND @end_date
234+
""",
235+
)
240236

241-
time.sleep(5)
237+
sample_dagster_test_context.asset_materialisation(assets=["intermediate_model_1"], plan_options=PlanOptions(skip_backfill=True, enable_preview=True, skip_tests=True))
242238

243239
intermediate_model_1_df = (
244240
sample_sqlmesh_test_context.query(

0 commit comments

Comments
 (0)