diff --git a/nodestream/pipeline/pipeline.py b/nodestream/pipeline/pipeline.py index 4b0ebb50..468853d2 100644 --- a/nodestream/pipeline/pipeline.py +++ b/nodestream/pipeline/pipeline.py @@ -1,131 +1,436 @@ +from abc import ABC, abstractmethod from asyncio import create_task, gather -from logging import getLogger -from typing import Iterable, List, Tuple - -from ..metrics import ( - RECORDS, - STEPS_RUNNING, - Metrics, -) +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Any, Iterable, List, Optional, Tuple, Type + +from ..metrics import RECORDS, STEPS_RUNNING, Metrics from ..schema import ExpandsSchema, ExpandsSchemaFromChildren from .channel import StepInput, StepOutput, channel from .object_storage import ObjectStore -from .progress_reporter import PipelineProgressReporter, no_op +from .progress_reporter import PipelineProgressReporter from .step import Step, StepContext -class StepExecutor: - """`StepExecutor` is a utility that is used to run a step in a pipeline. +async def no_op(_): + pass - The `StepExecutor` is responsible for starting, stopping, and running a - step in a pipeline. It is used to execute a step by passing records - between the input and output channels of the step. - """ - __slots__ = ("step", "input", "output", "context") +@dataclass(slots=True) +class RecordContext: + """A `Record` is a unit of data that is processed by a pipeline.""" + + record: Any + originating_step: Step + callback_token: Any + originated_from: Optional["RecordContext"] = field(default=None) + child_record_count: int = field(default=0) + + @staticmethod + def from_step_emission( + step: Step, + emission: Any, + originated_from: Optional["RecordContext"] = None, + ): + """Create a record from a step's emission of data. + + The `emission` can either be a single value or a tuple of two values. + If it is a single value, then it is assumed to be the data for the + record. If it is a tuple of two values, then the first value is + assumed to be the data for the record and the second value is assumed + to be the callback token for the record. If any other value is + provided, the data and callback token are both set to the value + provided. + + Args: + step (Step): The step that emitted the record. + emission (Any): The emission from the step. + originated_from (Optional[Record], optional): The record that + this record was emitted from. Defaults to None. + + Returns: + Record: The record created from the emission. + """ + record = callback_token = emission + if isinstance(emission, tuple): + record, callback_token = emission + + return RecordContext(record, step, callback_token, originated_from) + + async def child_dropped(self): + # If we have no children after this child has reported itself as + # having been dropped, then we can consider ourselves dropped as + # well since this must mean that we are not responsible for any more + # work and are not a resultant record in the pipeline and instead + # had to have been created as an intermediate step so our usefulness + # is simply if our children are useful. + self.child_record_count -= 1 + if self.child_record_count == 0: + await self.drop() + + async def drop(self): + # If we are being told to drop, then we need to run our callback so + # that the step that created us can clean up any resources it has + # allocated for this record if it opts into the feature + # (by implementing the `finalize_record` method). + await self.originating_step.finalize_record(self.callback_token) + + # If _we_ are being dropped, then there is a chance that our parent is + # done as well. So we can propagate the drop up the chain and ensure + # that all records are properly cleaned up. + if self.originated_from is not None: + await self.originated_from.child_dropped() + + +class ExecutionState(ABC): + @abstractmethod + async def execute_until_state_change(self) -> Optional["ExecutionState"]: + pass + + +class EmitResult(Enum): + EMITTED_RECORDS = auto() + CLOSED_DOWNSTREAM = auto() + NO_OP = auto() + + @property + def should_continue(self) -> bool: + return self != EmitResult.CLOSED_DOWNSTREAM + + @property + def did_emit_records(self) -> bool: + return self == EmitResult.EMITTED_RECORDS + + +class StepExecutionState(ExecutionState): + """State that a step is in when it is executing. + + This is the base class for all states that a step can be in. It provides + the basic functionality for executing a step and transitioning between + states. It also provides the basic functionality for emitting records + and handling errors. + """ def __init__( self, step: Step, + context: StepContext, input: StepInput, output: StepOutput, - context: StepContext, - ) -> None: + ): self.step = step + self.context = context self.input = input self.output = output - self.context = context - async def start_step(self): + def make_state(self, next_state: Type["StepExecutionState"]): + """Make the next state for the step. + + This method is used to create the next state for the step. It is + responsible for creating the next state and returning it. The next + state is created with the same step, context, input, and output as the + current state. + """ + return next_state(self.step, self.context, self.input, self.output) + + async def emit_record(self, record: RecordContext) -> EmitResult: + """Emit a record to the output channel. + + This method is used to emit a record to the output channel. It will + block until the record is put in the channel. If the channel is full, + it will block until there is space in the channel unless the channel is + closed on the other end. + + Returns: + EmitResult: The result of the emit operation. If the downstream is + not accepting more records, it will return + `EmitResult.CLOSED_DOWNSTREAM`. Otherwise, it will return + `EmitResult.EMITTED_RECORDS`. + """ + if not await self.output.put(record): + self.context.debug("Downstream is not accepting records. Stopping") + return EmitResult.CLOSED_DOWNSTREAM + + return EmitResult.EMITTED_RECORDS + + async def emit_from_generator( + self, + generator, + origin: Optional[RecordContext] = None, + ) -> EmitResult: + """Emit records from a generator. + + This method is used to emit records from a generator. It will block + until the record is put in the channel. If the channel is full, it will + block until there is space in the channel unless the channel is closed + on the other end. + + Returns: + EmitResult: The result of the emit operation. If the downstream + is not accepting more records, it will return + `EmitResult.CLOSED_DOWNSTREAM`. If no records were emitted, it + will return `EmitResult.NO_OP`. Otherwise, it will return + `EmitResult.EMITTED_RECORDS`. + """ + emitted = False + async for emission in generator: + # We can create a record for this data and attempt to submit it + # downstream. If we _failed_ to emit a record, then we need to + # stop processing any more records and we will report this to + # whatever state is calling us. Depending on that state, it can + # decide what to do. + record = RecordContext.from_step_emission(self.step, emission, origin) + result = await self.emit_record(record) + if result == EmitResult.CLOSED_DOWNSTREAM: + return result + + # If we got here, then we successfully emitted at least 1 record. + emitted = True + + # If we succesfully left the loop, then we either emitted something or + # or the generator was empty. Depending that, we can return the correct + # status to the caller. + return EmitResult.EMITTED_RECORDS if emitted else EmitResult.NO_OP + + +class StartStepState(StepExecutionState): + """State that a step is in when it is starting. + + This is the first state that a step is in when it is executed. The step + is in this state when it is first created and before it has started + processing records. Once the step has started, it will transition to the + `ProcessRecordsState`. If the step fails to start, it will transition to + the `StopStepExecution` state. + """ + + async def execute_until_state_change(self) -> Optional[StepExecutionState]: try: Metrics.get().increment(STEPS_RUNNING) await self.step.start(self.context) + return self.make_state(ProcessRecordsState) except Exception as e: - self.context.report_error("Error starting step", e) + self.context.report_error("Error starting step", e, fatal=True) + return None + + +class ProcessRecordsState(StepExecutionState): + """State that a step is in when it is processing records. + + This is the state that a step is in when it is actively processing records. + The step will remain in this state until it has processed all of its input + records and emitted all of its output records. Once the step has finished + processing records, it will transition to the + `EmitOutstandingRecordsState`. - async def stop_step(self): + If the step fails to process a record, it will transition to the + `StopStepExecution` state. If the step downstream is not accepting more + records, it will transition to the `StopStepExecution` state. + """ + + async def execute_until_state_change(self) -> Optional[StepExecutionState]: try: - Metrics.get().decrement(STEPS_RUNNING) - await self.step.finish(self.context) + while (next := await self.input.get()) is not None: + # Process the record and emit any resulting records downstream. + emissions = self.step.process_record(next.record, self.context) + result = await self.emit_from_generator(emissions, next) + + # If we didn't emit any records by processing this record, + # then we need to drop the record since it is not + # going to participate in the pipeline any further. + if not result.did_emit_records: + await next.drop() + + # If the downstream is not accepting more records, then we need + # to stop processing records by transitioning to a stop state. + if not result.should_continue: + return self.make_state(StopStepExecution) + + # If we get an exception, we need to stop processing records by + # transitioning to the stop state. Because this is part of the core + # execution of the step, we consider this a fatal error. except Exception as e: - self.context.report_error("Error stopping step", e) + self.context.report_error("Error processing record", e, fatal=True) + return self.make_state(StopStepExecution) - async def emit_record(self, record): - can_continue = await self.output.put(record) - if not can_continue: - self.context.debug( - "Downstream is not accepting more records. Gracefully stopping." - ) + # If we have gotten here, then we have processed all of our input + # records and we need to transition to the next state which is to emit + # any outstanding records. + return self.make_state(EmitOutstandingRecordsState) - return can_continue - async def drive_step(self): - try: - while (next_record := await self.input.get()) is not None: - results = self.step.process_record(next_record, self.context) - async for record in results: - if not await self.emit_record(record): - return +class EmitOutstandingRecordsState(StepExecutionState): + """State that a step is in when it is emitting outstanding records. - async for record in self.step.emit_outstanding_records(self.context): - if not await self.emit_record(record): - return + This is the state that a step is in when it is emitting any outstanding + records. This is done after all records have been processed. Regardless of + success or failure, the step will transition to the `StopStepExecution` + state. + """ - self.context.debug("Step finished emitting") + async def execute_until_state_change(self) -> Optional[StepExecutionState]: + try: + # Emit any outstanding records. If we get an error, we will still + # transition to the stop state. Unlike the processing state, this + # state does not really care about the result of the emit + # operation because we are transitioning to the stop state + # regardless of success or failure and there is no originating + # record to drop. + # + # NOTE: This is not quite true, as in theory records can be + # outstanding that were some how originated from a record. + # However, there is no nice way I've thought of to account for + # this without breaking the step interface to track the + # originating record for each outstanding record. For now, we + # will just leave it out of scope. + outstanding = self.step.emit_outstanding_records(self.context) + await self.emit_from_generator(outstanding) + + # If we get an exception, we will report it as fatal as steps + # processing outstanding records is part of the core execution + # of the step. except Exception as e: - self.context.report_error("Error running step", e, fatal=True) + self.context.report_error( + "Error emitting outstanding records", + e, + fatal=True, + ) - async def run(self): - self.context.debug("Starting step") - await self.start_step() - await self.drive_step() - await self.output.done() - self.input.done() - await self.stop_step() - self.context.debug("Finished step") + # We are doing to transition to the stop state regardless of what + # happened to get us here. + return self.make_state(StopStepExecution) -class PipelineOutput: - """`PipelineOutput` is an output channel for a pipeline. +class StopStepExecution(StepExecutionState): + """State that a step is in when it is stopping. - A `PipelineOutput` is used to consume records from the last step in a - pipeline and report the progress of the pipeline. + This is the state that a step is in when it is stopping. This is the final + state that a step will be in. Once it transitions to this state, it will + not transition to any other state and end execution. """ - __slots__ = ("input", "reporter", "observe_results") + async def execute_until_state_change(self) -> Optional[StepExecutionState]: + try: + # Closing the output channel will signal to any downstream steps + # that we are done processing records and that there nothing left + # to wait for. Similarly, we mark the input as done to signal + # to any upstream steps that we are done processing records and + # that producing more records is futile. + await self.output.done() + self.input.done() + + # Steps may need to do some finalization work when they are done. + await self.step.finish(self.context) + + # In the event of a failure closing out a step, we will report it as a + # non-fatal error because all core work has been accomplished. Resource + # cleanup, while messy, is not fatal to the pipeline as a whole. + except Exception as e: + self.context.report_error("Error stopping step", e) + + # We are done regarless of what happens. There is no next state. + Metrics.get().decrement(STEPS_RUNNING) + return None + - def __init__(self, input: StepInput, reporter: PipelineProgressReporter): +class PipelineOutputState(ExecutionState): + __slots__ = ("input", "reporter", "metrics") + + def __init__( + self, + input: StepInput, + reporter: PipelineProgressReporter, + metrics: Metrics, + ): self.input = input self.reporter = reporter - self.observe_results = reporter.observability_callback is not no_op + self.metrics = metrics + + def make_state(self, next_state: Type["ExecutionState"]): + return next_state(self.input, self.reporter, self.metrics) - def call_handling_errors(self, f, *args): + def call_ignoring_errors(self, f, *args): try: f(*args) except Exception: self.reporter.logger.exception(f"Error running {f.__name__}") - async def run(self): - """Run the pipeline output. - This method is used to run the pipeline output. It will consume records - from the last step in the pipeline and report the progress of the - pipeline using the `PipelineProgressReporter`. The pipeline output will - block until all records have been consumed from the last step in the - pipeline. - """ - metrics = Metrics.get() - self.call_handling_errors(self.reporter.on_start_callback) +class PipelineOutputStartState(PipelineOutputState): + """State that the pipeline output is in when it is starting. + + This is the first state that the pipeline output is in when it is + executed. The pipeline output is in this state when it is first created + and before it has started processing records. Once the pipeline output + has started, it will transition to the `PipelineOutputProcessRecordsState`. + """ + + async def execute_until_state_change(self) -> Optional[ExecutionState]: + self.call_ignoring_errors(self.reporter.on_start_callback) + return self.make_state(PipelineOutputProcessRecordsState) + +class PipelineOutputProcessRecordsState(PipelineOutputState): + """State that the pipeline output is in when it is processing records. + + This is the state that the pipeline output is in when it is processing + records. The pipeline output is in this state after it has started and + before it has finished processing all records. Once the pipeline output + has finished processing all records, it will transition to the + `PipelineOutputStopState`. + """ + + async def execute_until_state_change(self) -> Optional[ExecutionState]: index = 0 - while (record := await self.input.get()) is not None: - metrics.increment(RECORDS) - self.call_handling_errors(self.reporter.report, index, metrics) - if self.observe_results: - self.call_handling_errors(self.reporter.observe, record) + while (next := await self.input.get()) is not None: + self.metrics.increment(RECORDS) + self.call_ignoring_errors(self.reporter.report, index, self.metrics) + self.call_ignoring_errors(self.reporter.observe, next.record) + await next.drop() index += 1 + return self.make_state(PipelineOutputStopState) + + +class PipelineOutputStopState(PipelineOutputState): + """State that the pipeline output is in when it is stopping. + + This is the state that the pipeline output is in when it is stopping. This + is the final state that the pipeline output will be in. Once it transitions + to this state, it will not transition to any other state and end execution. + """ + + async def execute_until_state_change(self) -> Optional[ExecutionState]: + # For cases where the reporter _wants_ to have the exception thrown + # (e.g to have a status code in the CLI) we need to make sure we call + # on_finish_callback without swallowing exceptions (because thats the point). + self.reporter.on_finish_callback(self.metrics) + return None - self.call_handling_errors(self.reporter.on_finish_callback, metrics) + +class Executor: + __slots__ = ("state",) + + def __init__(self, state: ExecutionState) -> None: + self.state = state + + @classmethod + def for_step( + cls, + step: Step, + input: StepInput, + output: StepOutput, + context: StepContext, + ) -> "Executor": + return cls(StartStepState(step, context, input, output)) + + @classmethod + def pipeline_output( + cls, input: StepInput, reporter: PipelineProgressReporter + ) -> "Executor": + return cls(PipelineOutputStartState(input, reporter, Metrics.get())) + + async def run(self): + while self.state is not None: + self.state = await self.state.execute_until_state_change() class Pipeline(ExpandsSchemaFromChildren): @@ -137,7 +442,7 @@ class Pipeline(ExpandsSchemaFromChildren): and running the steps in the pipeline. """ - __slots__ = ("steps", "step_outbox_size", "logger", "object_store") + __slots__ = ("steps", "step_outbox_size", "object_store") def __init__( self, @@ -147,7 +452,6 @@ def __init__( ) -> None: self.steps = steps self.step_outbox_size = step_outbox_size - self.logger = getLogger(self.__class__.__name__) self.object_store = object_store def get_child_expanders(self) -> Iterable[ExpandsSchema]: @@ -177,14 +481,20 @@ async def run(self, reporter: PipelineProgressReporter): # step to the next step. The channels are used to pass records between # the steps in the pipeline. The channels have a fixed size to control # the flow of records between the steps. - executors: List[StepExecutor] = [] + executors: List[Executor] = [] current_input_name = None current_output_name = self.steps[-1].__class__.__name__ + f"_{len(self.steps)}" + # Here lies a footgun. DO NOT MOVE this executor from the first + # position in the list that makes its way to the gather call. It will + # break the pipeline because we need to ensure the on_start_callback + # is called before any operations occur in the actual steps in the + # pipeline. To do this, we need to make sure that the first coroutine + # scheduled is the one that calls the on_start_callback. current_input, current_output = channel( self.step_outbox_size, current_output_name, current_input_name ) - pipeline_output = PipelineOutput(current_input, reporter) + executors.append(Executor.pipeline_output(current_input, reporter)) # Create the executors for the steps in the pipeline. The executors # will be used to run the steps concurrently. The steps are created in @@ -204,7 +514,7 @@ async def run(self, reporter: PipelineProgressReporter): current_input, next_output = channel( self.step_outbox_size, current_output_name, current_input_name ) - exec = StepExecutor(step, current_input, current_output, context) + exec = Executor.for_step(step, current_input, current_output, context) current_output = next_output executors.append(exec) @@ -216,6 +526,5 @@ async def run(self, reporter: PipelineProgressReporter): # Run the pipeline by running all the steps and the pipeline output # concurrently. This will block until all steps are finished. - running_steps = (create_task(executor.run()) for executor in executors) - - await gather(*running_steps, create_task(pipeline_output.run())) + # Wait for all the executors to finish. + await gather(*(create_task(executor.run()) for executor in executors)) diff --git a/nodestream/pipeline/step.py b/nodestream/pipeline/step.py index b9937d97..df5627c3 100644 --- a/nodestream/pipeline/step.py +++ b/nodestream/pipeline/step.py @@ -1,10 +1,6 @@ from typing import AsyncGenerator, Optional -from ..metrics import ( - FATAL_ERRORS, - NON_FATAL_ERRORS, - Metrics, -) +from ..metrics import FATAL_ERRORS, NON_FATAL_ERRORS, Metrics from .object_storage import ObjectStore from .progress_reporter import PipelineProgressReporter @@ -149,6 +145,15 @@ async def finish(self, context: StepContext): """ pass + async def finalize_record(self, record_or_token: object): + """Finalize a record. + + This method is called when a record produced by this step has been + fully processed by all downstream steps. It is not called for records + that are not produced by this step. + """ + pass + class PassStep(Step): """A `PassStep` passes records through.""" diff --git a/tests/integration/test_pipeline_and_data_interpretation.py b/tests/integration/test_pipeline_and_data_interpretation.py index 73a2271f..10468857 100644 --- a/tests/integration/test_pipeline_and_data_interpretation.py +++ b/tests/integration/test_pipeline_and_data_interpretation.py @@ -63,8 +63,9 @@ async def test_pipeline_interpretation_snapshot( snapshot.snapshot_dir = "tests/integration/snapshots" pipeline_file = get_pipeline_fixture_file_by_name(pipeline_name) definition = PipelineDefinition.from_path(pipeline_file) + results = await drive_definition_to_completion(definition) results_as_json = json.dumps( - [asdict(r) for r in (await drive_definition_to_completion(definition))], + [asdict(r) for r in results], default=set_default, indent=4, sort_keys=True, diff --git a/tests/integration/test_pipeline_cleanup_flow.py b/tests/integration/test_pipeline_cleanup_flow.py new file mode 100644 index 00000000..32dfc71e --- /dev/null +++ b/tests/integration/test_pipeline_cleanup_flow.py @@ -0,0 +1,303 @@ +from unittest.mock import Mock + +import pytest + +from nodestream.pipeline.extractors import Extractor +from nodestream.pipeline.object_storage import ObjectStore +from nodestream.pipeline.pipeline import Pipeline +from nodestream.pipeline.progress_reporter import PipelineProgressReporter +from nodestream.pipeline.transformers import Transformer +from nodestream.pipeline.writers import Writer + + +class ResourceTrackingExtractor(Extractor): + """Extractor that tracks resource allocation and cleanup.""" + + def __init__(self, data_items): + self.data_items = data_items + self.allocated_resources = {} + self.finalized_tokens = [] + + async def extract_records(self): + for i, item in enumerate(self.data_items): + # Simulate resource allocation + token = f"extractor_resource_{i}" + self.allocated_resources[token] = f"resource_for_{item}" + yield (item, token) # Emit tuple with callback token + + async def finalize_record(self, record_token): + """Clean up resources allocated for this record.""" + if record_token in self.allocated_resources: + del self.allocated_resources[record_token] + self.finalized_tokens.append(record_token) + + +class ResourceTrackingTransformer(Transformer): + """Transformer that tracks resource allocation and cleanup.""" + + def __init__(self): + self.allocated_resources = {} + self.finalized_tokens = [] + self.processed_records = [] + + async def process_record(self, record, context): + # Track the record we processed + self.processed_records.append(record) + + # Simulate resource allocation for transformation + token = f"transformer_resource_{id(record)}" + self.allocated_resources[token] = f"transform_resource_for_{record}" + + # Transform the record + transformed = f"transformed_{record}" + yield (transformed, token) # Emit with callback token + + async def finalize_record(self, record_token): + """Clean up transformation resources.""" + if record_token in self.allocated_resources: + del self.allocated_resources[record_token] + self.finalized_tokens.append(record_token) + + +class ResourceTrackingWriter(Writer): + """Writer that tracks resource allocation and cleanup.""" + + def __init__(self): + self.allocated_resources = {} + self.finalized_tokens = [] + self.written_records = [] + + async def write_record(self, record): + # Track what we wrote + self.written_records.append(record) + + # Simulate resource allocation for writing + token = f"writer_resource_{id(record)}" + self.allocated_resources[token] = f"write_resource_for_{record}" + return token # Return token for cleanup + + async def process_record(self, record, context): + # Write the record and get cleanup token + token = await self.write_record(record) + yield (record, token) # Pass through with cleanup token + + async def finalize_record(self, record_token): + """Clean up writing resources.""" + if record_token in self.allocated_resources: + del self.allocated_resources[record_token] + self.finalized_tokens.append(record_token) + + +@pytest.mark.asyncio +async def test_end_to_end_cleanup_flow(): + """Test complete cleanup flow through extractor -> transformer -> writer.""" + # Create steps with resource tracking + extractor = ResourceTrackingExtractor(["item1", "item2", "item3"]) + transformer = ResourceTrackingTransformer() + writer = ResourceTrackingWriter() + + # Create pipeline + steps = (extractor, transformer, writer) + object_store = Mock(spec=ObjectStore) + object_store.namespaced = Mock(return_value=Mock()) + + pipeline = Pipeline(steps, step_outbox_size=10, object_store=object_store) + + # Create progress reporter + reporter = PipelineProgressReporter.for_testing([]) + + # Run pipeline + await pipeline.run(reporter) + + # Verify all records were processed + assert len(transformer.processed_records) == 3 + assert len(writer.written_records) == 3 + + # Verify finalize_record was called for writer (final step) + # Note: In multi-step pipelines, only the final step gets cleanup calls + # because intermediate records are transformed, not dropped + assert len(writer.finalized_tokens) == 3 + + # Writer resources should be cleaned up + assert len(writer.allocated_resources) == 0 + + +@pytest.mark.asyncio +async def test_cleanup_flow_with_filtering(): + """Test cleanup flow when some records are filtered out.""" + + class FilteringTransformer(Transformer): + def __init__(self): + self.allocated_resources = {} + self.finalized_tokens = [] + + async def process_record(self, record, context): + # Allocate resource for processing + token = f"filter_resource_{id(record)}" + self.allocated_resources[token] = f"resource_for_{record}" + + # Only pass through records containing "keep" + if "keep" in str(record): + yield (f"filtered_{record}", token) + # If we don't yield, the record will be dropped and finalized + + async def finalize_record(self, record_token): + if record_token in self.allocated_resources: + del self.allocated_resources[record_token] + self.finalized_tokens.append(record_token) + + # Create steps + extractor = ResourceTrackingExtractor(["keep1", "drop1", "keep2", "drop2"]) + filter_transformer = FilteringTransformer() + writer = ResourceTrackingWriter() + + steps = (extractor, filter_transformer, writer) + object_store = Mock(spec=ObjectStore) + object_store.namespaced = Mock(return_value=Mock()) + + pipeline = Pipeline(steps, step_outbox_size=10, object_store=object_store) + reporter = PipelineProgressReporter.for_testing([]) + + await pipeline.run(reporter) + + # Verify only "keep" records made it to writer + assert len(writer.written_records) == 2 + assert all("keep" in str(record) for record in writer.written_records) + + # Verify writer resources were cleaned up + assert len(writer.allocated_resources) == 0 + + # Verify finalize_record was called for writer (final step) + assert len(writer.finalized_tokens) == 2 # Only 2 kept records + + +@pytest.mark.asyncio +async def test_cleanup_flow_with_record_multiplication(): + """Test cleanup flow when one record generates multiple records.""" + + class MultiplyingTransformer(Transformer): + def __init__(self): + self.allocated_resources = {} + self.finalized_tokens = [] + + async def process_record(self, record, context): + # Allocate resource for processing + token = f"multiply_resource_{id(record)}" + self.allocated_resources[token] = f"resource_for_{record}" + + # Generate multiple records from one input + for i in range(3): + yield (f"{record}_copy_{i}", token) + + async def finalize_record(self, record_token): + if record_token in self.allocated_resources: + del self.allocated_resources[record_token] + self.finalized_tokens.append(record_token) + + # Create steps + extractor = ResourceTrackingExtractor(["item1", "item2"]) + multiplier = MultiplyingTransformer() + writer = ResourceTrackingWriter() + + steps = (extractor, multiplier, writer) + object_store = Mock(spec=ObjectStore) + object_store.namespaced = Mock(return_value=Mock()) + + pipeline = Pipeline(steps, step_outbox_size=10, object_store=object_store) + reporter = PipelineProgressReporter.for_testing([]) + + await pipeline.run(reporter) + + # Verify multiplication worked + assert len(writer.written_records) == 6 # 2 input * 3 copies each + + # Verify writer resources were cleaned up + assert len(writer.allocated_resources) == 0 + + # Verify finalize_record calls for writer (final step) + assert len(writer.finalized_tokens) == 6 # 6 output records + + +@pytest.mark.asyncio +async def test_cleanup_flow_with_exception(): + """Test cleanup flow when an exception occurs during processing.""" + + class FailingTransformer(Transformer): + def __init__(self): + self.allocated_resources = {} + self.finalized_tokens = [] + + async def process_record(self, record, context): + # Allocate resource + token = f"failing_resource_{id(record)}" + self.allocated_resources[token] = f"resource_for_{record}" + + if "fail" in str(record): + raise ValueError(f"Processing failed for {record}") + + yield (f"processed_{record}", token) + + async def finalize_record(self, record_token): + if record_token in self.allocated_resources: + del self.allocated_resources[record_token] + self.finalized_tokens.append(record_token) + + # Create steps + extractor = ResourceTrackingExtractor(["good1", "fail1", "good2"]) + failing_transformer = FailingTransformer() + writer = ResourceTrackingWriter() + + steps = (extractor, failing_transformer, writer) + object_store = Mock(spec=ObjectStore) + object_store.namespaced = Mock(return_value=Mock()) + + pipeline = Pipeline(steps, step_outbox_size=10, object_store=object_store) + + # Use a reporter that doesn't raise on fatal errors for this test + reporter = PipelineProgressReporter( + on_fatal_error_callback=lambda ex: None # Don't raise + ) + + await pipeline.run(reporter) + + # The pipeline should handle the exception and stop processing + # Writer should have processed at least one successful record before failure + assert len(writer.written_records) >= 1 # At least one successful record + + +@pytest.mark.asyncio +async def test_cleanup_flow_performance(): + """Test cleanup flow performance with many records.""" + # Create a large number of records to test performance + large_dataset = [f"item_{i}" for i in range(100)] + + extractor = ResourceTrackingExtractor(large_dataset) + transformer = ResourceTrackingTransformer() + writer = ResourceTrackingWriter() + + steps = (extractor, transformer, writer) + object_store = Mock(spec=ObjectStore) + object_store.namespaced = Mock(return_value=Mock()) + + pipeline = Pipeline(steps, step_outbox_size=10, object_store=object_store) + reporter = PipelineProgressReporter.for_testing([]) + + # Measure execution time + import time + + start_time = time.time() + + await pipeline.run(reporter) + + end_time = time.time() + execution_time = end_time - start_time + + # Verify all records were processed and cleaned up + assert len(writer.written_records) == 100 + assert len(writer.allocated_resources) == 0 + + # Verify cleanup calls were made for writer (final step) + assert len(writer.finalized_tokens) == 100 + + # Performance should be reasonable (adjust threshold as needed) + assert execution_time < 5.0 # Should complete within 5 seconds diff --git a/tests/unit/pipeline/test_pipeline.py b/tests/unit/pipeline/test_pipeline.py index 2dab0b17..e70a9b26 100644 --- a/tests/unit/pipeline/test_pipeline.py +++ b/tests/unit/pipeline/test_pipeline.py @@ -1,138 +1,833 @@ +from unittest.mock import AsyncMock, Mock + import pytest +from nodestream.pipeline.channel import channel from nodestream.pipeline.pipeline import ( - PipelineOutput, - PipelineProgressReporter, - Step, - StepContext, - StepExecutor, + EmitOutstandingRecordsState, + EmitResult, + Executor, + Pipeline, + PipelineOutputProcessRecordsState, + PipelineOutputStartState, + PipelineOutputStopState, + ProcessRecordsState, + RecordContext, + StartStepState, + StepExecutionState, StepInput, - StepOutput, + StopStepExecution, ) +from nodestream.pipeline.progress_reporter import PipelineProgressReporter +from nodestream.pipeline.step import Step, StepContext + + +@pytest.fixture +def mock_step(): + step = Mock(spec=Step) + step.start = AsyncMock() + step.finish = AsyncMock() + step.process_record = AsyncMock() + step.emit_outstanding_records = AsyncMock() + step.finalize_record = AsyncMock() + return step @pytest.fixture -def step_executor(mocker): - step = mocker.Mock(Step) - input = mocker.Mock(StepInput) - output = mocker.Mock(StepOutput) - context = mocker.Mock(StepContext) +def mock_context(): + context = Mock(spec=StepContext) + context.report_error = Mock() + context.debug = Mock() + return context + + +@pytest.fixture +def step_execution_state(mock_step, mock_context): + input_channel, output_channel = channel(10) + return StepExecutionState(mock_step, mock_context, input_channel, output_channel) + - return StepExecutor(step, input, output, context) +# Tests for Record class +@pytest.mark.asyncio +async def test_record_from_step_emission_simple(): + step = Mock(spec=Step) + data = {"test": "data"} + + record = RecordContext.from_step_emission(step, data) + + assert record.record == data + assert record.callback_token == data + assert record.originating_step == step + assert record.originated_from is None + assert record.child_record_count == 0 @pytest.mark.asyncio -async def test_start_step(step_executor): - await step_executor.start_step() - step_executor.step.start.assert_called_once_with(step_executor.context) +async def test_record_from_step_emission_with_tuple(): + step = Mock(spec=Step) + data = {"test": "data"} + token = "callback_token" + + record = RecordContext.from_step_emission(step, (data, token)) + + assert record.record == data + assert record.callback_token == token + assert record.originating_step == step @pytest.mark.asyncio -async def test_start_step_error(step_executor, mocker): - step_executor.step.start.side_effect = Exception("Boom") - await step_executor.start_step() - step_executor.context.report_error.assert_called_once_with( - "Error starting step", mocker.ANY +async def test_record_drop_calls_finalize(): + step = Mock(spec=Step) + step.finalize_record = AsyncMock() + token = "test_token" + + record = RecordContext(record="test", originating_step=step, callback_token=token) + await record.drop() + + step.finalize_record.assert_called_once_with(token) + + +@pytest.mark.asyncio +async def test_record_drop_propagates_to_parent(): + parent_step = Mock(spec=Step) + parent_step.finalize_record = AsyncMock() + child_step = Mock(spec=Step) + child_step.finalize_record = AsyncMock() + + parent_record = RecordContext( + record="parent", originating_step=parent_step, callback_token="parent_token" + ) + child_record = RecordContext( + record="child", + originating_step=child_step, + callback_token="child_token", + originated_from=parent_record, ) + await child_record.drop() + + child_step.finalize_record.assert_called_once_with("child_token") + # Parent should have child_dropped called + assert parent_record.child_record_count == -1 + @pytest.mark.asyncio -async def test_stop_step(step_executor): - await step_executor.stop_step() - step_executor.step.finish.assert_called_once_with(step_executor.context) +async def test_record_child_dropped(): + step = Mock(spec=Step) + step.finalize_record = AsyncMock() + + record = RecordContext( + record="test", + originating_step=step, + callback_token="token", + child_record_count=1, + ) + await record.child_dropped() + assert record.child_record_count == 0 + step.finalize_record.assert_called_once_with("token") + +# Tests for EmitResult enum +def test_emit_result_should_continue(): + assert EmitResult.EMITTED_RECORDS.should_continue is True + assert EmitResult.NO_OP.should_continue is True + assert EmitResult.CLOSED_DOWNSTREAM.should_continue is False + + +def test_emit_result_did_emit_records(): + assert EmitResult.EMITTED_RECORDS.did_emit_records is True + assert EmitResult.NO_OP.did_emit_records is False + assert EmitResult.CLOSED_DOWNSTREAM.did_emit_records is False + + +# Tests for StartStepState @pytest.mark.asyncio -async def test_stop_step_error(step_executor, mocker): - step_executor.step.finish.side_effect = Exception("Boom") - await step_executor.stop_step() - step_executor.context.report_error.assert_called_once_with( - "Error stopping step", mocker.ANY +async def test_start_step_state_success(mock_step, mock_context): + input_channel, output_channel = channel(10) + state = StartStepState(mock_step, mock_context, input_channel, output_channel) + + next_state = await state.execute_until_state_change() + + mock_step.start.assert_called_once_with(mock_context) + assert isinstance(next_state, ProcessRecordsState) + + +@pytest.mark.asyncio +async def test_start_step_state_error(mock_step, mock_context): + input_channel, output_channel = channel(10) + state = StartStepState(mock_step, mock_context, input_channel, output_channel) + mock_step.start.side_effect = Exception("Start failed") + + next_state = await state.execute_until_state_change() + + mock_context.report_error.assert_called_once() + assert next_state is None + + +# Tests for ProcessRecordsState +@pytest.mark.asyncio +async def test_process_records_state_transitions_to_emit_outstanding(): + """Test that ProcessRecordsState transitions to EmitOutstandingRecordsState.""" + mock_step = Mock(spec=Step) + mock_context = Mock(spec=StepContext) + + # Create mock channels + mock_input = Mock() + mock_input.get = AsyncMock(side_effect=[None]) # No records, just end signal + + mock_output = Mock() + + state = ProcessRecordsState(mock_step, mock_context, mock_input, mock_output) + + next_state = await state.execute_until_state_change() + + assert isinstance(next_state, EmitOutstandingRecordsState) + + +@pytest.mark.asyncio +async def test_process_records_state_drops_record_when_no_emission(): + mock_step = Mock(spec=Step) + mock_context = Mock(spec=StepContext) + input_channel, output_channel = channel(10) + + state = ProcessRecordsState(mock_step, mock_context, input_channel, output_channel) + + # Create a record with finalize_record mock + originating_step = Mock(spec=Step) + originating_step.finalize_record = AsyncMock() + test_record = RecordContext( + record="test_data", originating_step=originating_step, callback_token="token" ) + # Mock process_record to return empty generator + async def empty_generator(data, context): + return + yield # pragma: no cover + + mock_step.process_record.return_value = empty_generator("test_data", mock_context) + + # Put record and end signal + await input_channel.channel.put(test_record) + await input_channel.channel.put(None) + + await state.execute_until_state_change() + + # Verify record was dropped (finalize_record called) + originating_step.finalize_record.assert_called_once_with("token") + + +# Tests for StopStepExecution +@pytest.mark.asyncio +async def test_stop_step_execution(mock_step, mock_context): + input_channel, output_channel = channel(10) + state = StopStepExecution(mock_step, mock_context, input_channel, output_channel) + + next_state = await state.execute_until_state_change() + + mock_step.finish.assert_called_once_with(mock_context) + assert next_state is None + + +# Tests for Executor +@pytest.mark.asyncio +async def test_executor_for_step(): + mock_step = Mock(spec=Step) + mock_step.start = AsyncMock() + mock_step.finish = AsyncMock() + input_channel, output_channel = channel(10) + mock_context = Mock(spec=StepContext) + + executor = Executor.for_step(mock_step, input_channel, output_channel, mock_context) + + assert isinstance(executor.state, StartStepState) + + +@pytest.mark.asyncio +async def test_executor_pipeline_output(): + input_channel, _ = channel(10) + reporter = PipelineProgressReporter() + + executor = Executor.pipeline_output(input_channel, reporter) + + assert isinstance(executor.state, PipelineOutputStartState) + + +@pytest.mark.asyncio +async def test_pipeline_output_observes_data_not_record(): + """Test that observability callback receives data, not Record objects.""" + observed_items = [] + + def observe_callback(item): + observed_items.append(item) + + input_channel, _ = channel(10) + reporter = PipelineProgressReporter(observability_callback=observe_callback) + + # Create pipeline output state + from nodestream.metrics import Metrics + + metrics = Metrics() + state = PipelineOutputProcessRecordsState(input_channel, reporter, metrics) + + # Create test records with different data types + step = Mock(spec=Step) + test_data = ["simple_string", {"key": "value"}, ["list", "item"], 42] + + # Put records in input channel + for data in test_data: + record = RecordContext(record=data, originating_step=step, callback_token=data) + await input_channel.channel.put(record) + + # Signal end + await input_channel.channel.put(None) + + # Execute the state + await state.execute_until_state_change() + + # Verify that observed items are the actual data, not Record objects + assert len(observed_items) == len(test_data) + for i, observed in enumerate(observed_items): + assert observed == test_data[i] + assert not hasattr(observed, "data") # Should not be a Record + assert not hasattr(observed, "originating_step") # Should not be a Record + + +# Tests for StepExecutionState emit_record method +@pytest.mark.asyncio +async def test_step_execution_state_emit_record_success(mock_step, mock_context): + input_channel, output_channel = channel(10) + state = StartStepState(mock_step, mock_context, input_channel, output_channel) + + test_record = RecordContext( + record="test", originating_step=mock_step, callback_token="token" + ) + result = await state.emit_record(test_record) + + assert result == EmitResult.EMITTED_RECORDS + + +@pytest.mark.asyncio +async def test_step_execution_state_emit_record_closed_downstream( + mock_step, mock_context +): + input_channel, output_channel = channel(10) + state = StartStepState(mock_step, mock_context, input_channel, output_channel) + + # Set input_dropped to simulate downstream not accepting records + output_channel.channel.input_dropped = True + + test_record = RecordContext( + record="test", originating_step=mock_step, callback_token="token" + ) + result = await state.emit_record(test_record) + + assert result == EmitResult.CLOSED_DOWNSTREAM + mock_context.debug.assert_called_once_with( + "Downstream is not accepting records. Stopping" + ) + + +# Tests for StepExecutionState emit_from_generator method +@pytest.mark.asyncio +async def test_step_execution_state_emit_from_generator_with_records( + mock_step, mock_context +): + input_channel, output_channel = channel(10) + state = StartStepState(mock_step, mock_context, input_channel, output_channel) + + async def test_generator(): + yield "record1" + yield "record2" + + result = await state.emit_from_generator(test_generator()) + + assert result == EmitResult.EMITTED_RECORDS + + +@pytest.mark.asyncio +async def test_step_execution_state_emit_from_generator_closed_downstream( + mock_step, mock_context +): + input_channel, output_channel = channel(10) + state = StartStepState(mock_step, mock_context, input_channel, output_channel) + + # Set input_dropped to simulate downstream not accepting records + output_channel.channel.input_dropped = True + + async def test_generator(): + yield "record1" + yield "record2" + + result = await state.emit_from_generator(test_generator()) + + assert result == EmitResult.CLOSED_DOWNSTREAM + + +# Tests for ProcessRecordsState exception handling +@pytest.mark.asyncio +async def test_process_records_state_exception_handling(mock_step, mock_context): + input_channel, output_channel = channel(10) + state = ProcessRecordsState(mock_step, mock_context, input_channel, output_channel) + + # Create a record + originating_step = Mock(spec=Step) + originating_step.finalize_record = AsyncMock() + test_record = RecordContext( + record="test_data", originating_step=originating_step, callback_token="token" + ) + + # Mock process_record to raise an exception directly (not return a coroutine) + def failing_process_record(data, context): + raise Exception("Processing failed") + + mock_step.process_record = failing_process_record + + # Put record in input channel + await input_channel.channel.put(test_record) + await input_channel.channel.put(None) + + next_state = await state.execute_until_state_change() + + # Check that report_error was called with the right message and fatal=True + # The exact exception may vary, so we just check the call was made + mock_context.report_error.assert_called_once() + args, kwargs = mock_context.report_error.call_args + assert args[0] == "Error processing record" + assert kwargs.get("fatal") is True + assert isinstance(next_state, StopStepExecution) + @pytest.mark.asyncio -async def test_emit_record(step_executor, mocker): - record = mocker.Mock() - step_executor.output.put.return_value = True - await step_executor.emit_record(record) - step_executor.output.put.assert_called_once_with(record) +async def test_process_records_state_closed_downstream_transition( + mock_step, mock_context +): + input_channel, output_channel = channel(10) + state = ProcessRecordsState(mock_step, mock_context, input_channel, output_channel) + + # Set input_dropped to simulate downstream not accepting records + output_channel.channel.input_dropped = True + + # Create a record + originating_step = Mock(spec=Step) + originating_step.finalize_record = AsyncMock() + test_record = RecordContext( + record="test_data", originating_step=originating_step, callback_token="token" + ) + + # Mock process_record to return a generator with records + async def test_generator(data, context): + yield "output_record" + + mock_step.process_record.return_value = test_generator("test_data", mock_context) + + # Put record in input channel + await input_channel.channel.put(test_record) + await input_channel.channel.put(None) + + next_state = await state.execute_until_state_change() + assert isinstance(next_state, StopStepExecution) + +# Tests for EmitOutstandingRecordsState @pytest.mark.asyncio -async def test_emit_record_full(step_executor, mocker): - record = mocker.Mock() - step_executor.output.put.return_value = False - await step_executor.emit_record(record) - step_executor.output.put.assert_called_once_with(record) - step_executor.context.debug.assert_called_once_with( - "Downstream is not accepting more records. Gracefully stopping." +async def test_emit_outstanding_records_state_success(mock_step, mock_context): + input_channel, output_channel = channel(10) + state = EmitOutstandingRecordsState( + mock_step, mock_context, input_channel, output_channel + ) + + # Mock emit_outstanding_records to return a generator + async def outstanding_generator(context): + yield "outstanding_record1" + yield "outstanding_record2" + + mock_step.emit_outstanding_records.return_value = outstanding_generator( + mock_context ) + next_state = await state.execute_until_state_change() + + mock_step.emit_outstanding_records.assert_called_once_with(mock_context) + assert isinstance(next_state, StopStepExecution) + @pytest.mark.asyncio -async def test_drive_step(step_executor, mocker): - record = mocker.Mock() - step = Step() - step_executor.step = step - step_executor.input.get = mocker.AsyncMock(side_effect=[record, None]) - await step_executor.drive_step() - step_executor.output.put.assert_called_once_with(record) - step_executor.context.debug.assert_called_once_with("Step finished emitting") +async def test_emit_outstanding_records_state_exception_handling( + mock_step, mock_context +): + input_channel, output_channel = channel(10) + state = EmitOutstandingRecordsState( + mock_step, mock_context, input_channel, output_channel + ) + # Mock emit_outstanding_records to raise an exception directly + def failing_emit_outstanding(context): + raise Exception("Outstanding records failed") + mock_step.emit_outstanding_records = failing_emit_outstanding + + next_state = await state.execute_until_state_change() + + # Check that report_error was called with the right message and fatal=True + mock_context.report_error.assert_called_once() + args, kwargs = mock_context.report_error.call_args + assert args[0] == "Error emitting outstanding records" + assert kwargs.get("fatal") is True + assert isinstance(next_state, StopStepExecution) + + +# Tests for StopStepExecution error handling @pytest.mark.asyncio -async def test_drive_step_error(step_executor, mocker): - step_executor.input.get.side_effect = Exception("Boom") - await step_executor.drive_step() - step_executor.context.report_error.assert_called_once_with( - "Error running step", - mocker.ANY, - fatal=True, +async def test_stop_step_execution_finish_exception(mock_step, mock_context): + input_channel, output_channel = channel(10) + state = StopStepExecution(mock_step, mock_context, input_channel, output_channel) + + # Mock finish to raise an exception + mock_step.finish.side_effect = Exception("Finish failed") + + next_state = await state.execute_until_state_change() + + mock_context.report_error.assert_called_once_with( + "Error stopping step", mock_step.finish.side_effect ) + assert next_state is None +# Tests for PipelineOutputState call_ignoring_errors method @pytest.mark.asyncio -async def test_drive_step_cannot_continue(step_executor, mocker): - record = mocker.Mock() - step = Step() - step_executor.step = step - step_executor.input.get = mocker.AsyncMock(side_effect=[record, None]) - step_executor.output.put.return_value = False - await step_executor.drive_step() - step_executor.output.put.assert_called_once_with(record) +async def test_pipeline_output_state_call_ignoring_errors_success(): + input_channel, _ = channel(10) + reporter = PipelineProgressReporter() + from nodestream.metrics import Metrics + + metrics = Metrics() + + state = PipelineOutputProcessRecordsState(input_channel, reporter, metrics) + + # Test successful callback + callback_called = False + + def test_callback(): + nonlocal callback_called + callback_called = True + + state.call_ignoring_errors(test_callback) + assert callback_called @pytest.mark.asyncio -async def test_drive_step_cannot_continue_in_emit_outstanding_records( - step_executor, mocker +async def test_pipeline_output_state_call_ignoring_errors_exception(): + input_channel, _ = channel(10) + reporter = PipelineProgressReporter() + from nodestream.metrics import Metrics + + metrics = Metrics() + + state = PipelineOutputProcessRecordsState(input_channel, reporter, metrics) + + # Test callback that raises exception + def failing_callback(): + raise Exception("Callback failed") + + # Should not raise exception, just log it + state.call_ignoring_errors(failing_callback) + + +# Tests for PipelineOutputStartState +@pytest.mark.asyncio +async def test_pipeline_output_start_state(): + input_channel, _ = channel(10) + reporter = PipelineProgressReporter() + from nodestream.metrics import Metrics + + metrics = Metrics() + + state = PipelineOutputStartState(input_channel, reporter, metrics) + + next_state = await state.execute_until_state_change() + + assert isinstance(next_state, PipelineOutputProcessRecordsState) + + +# Tests for PipelineOutputStopState +@pytest.mark.asyncio +async def test_pipeline_output_stop_state(): + input_channel, _ = channel(10) + reporter = PipelineProgressReporter() + from nodestream.metrics import Metrics + + metrics = Metrics() + + state = PipelineOutputStopState(input_channel, reporter, metrics) + + next_state = await state.execute_until_state_change() + + assert next_state is None + + +# Tests for Executor run method +@pytest.mark.asyncio +async def test_executor_run(): + mock_step = Mock(spec=Step) + mock_step.start = AsyncMock() + mock_step.finish = AsyncMock() + mock_step.process_record = AsyncMock() + mock_step.emit_outstanding_records = AsyncMock() + mock_step.finalize_record = AsyncMock() + + input_channel, output_channel = channel(10) + mock_context = Mock(spec=StepContext) + mock_context.report_error = Mock() + mock_context.debug = Mock() + + executor = Executor.for_step(mock_step, input_channel, output_channel, mock_context) + + # Mock emit_outstanding_records to return empty generator + async def empty_generator(context): + return + yield # pragma: no cover + + mock_step.emit_outstanding_records.return_value = empty_generator(mock_context) + + # Close input to end processing + await input_channel.channel.put(None) + + await executor.run() + + mock_step.start.assert_called_once_with(mock_context) + mock_step.finish.assert_called_once_with(mock_context) + + +# Tests for Pipeline class +@pytest.mark.asyncio +async def test_pipeline_init(): + from nodestream.pipeline.object_storage import NullObjectStore + from nodestream.pipeline.step import PassStep + + steps = (PassStep(), PassStep()) + step_outbox_size = 100 + object_store = NullObjectStore() + + pipeline = Pipeline(steps, step_outbox_size, object_store) + + assert pipeline.steps == steps + assert pipeline.step_outbox_size == step_outbox_size + assert pipeline.object_store == object_store + + +@pytest.mark.asyncio +async def test_pipeline_get_child_expanders(): + from nodestream.pipeline.object_storage import NullObjectStore + from nodestream.pipeline.step import PassStep + from nodestream.schema import ExpandsSchema + + # Create a mock step that implements both Step and ExpandsSchema + class MockExpanderStep(Step, ExpandsSchema): + def process_record(self, record, context): + yield record + + def expand_schema(self, schema): + pass + + mock_expander_step = MockExpanderStep() + regular_step = PassStep() + + steps = (regular_step, mock_expander_step) + pipeline = Pipeline(steps, 100, NullObjectStore()) + + expanders = list(pipeline.get_child_expanders()) + + assert len(expanders) == 1 + assert expanders[0] == mock_expander_step + + +@pytest.mark.asyncio +async def test_pipeline_run(): + from nodestream.pipeline.object_storage import NullObjectStore + from nodestream.pipeline.step import PassStep + + # Create simple steps + step1 = PassStep() + step2 = PassStep() + steps = (step1, step2) + + pipeline = Pipeline(steps, 10, NullObjectStore()) + reporter = PipelineProgressReporter() + + # Run the pipeline + await pipeline.run(reporter) + + # If we get here without exception, the pipeline ran successfully + + +# Test to cover the remaining missing lines (150-151, 184, 244) +@pytest.mark.asyncio +async def test_emit_from_generator_early_return_on_closed_downstream( + mock_step, mock_context ): - record = mocker.Mock() + """Test that emit_from_generator returns early when downstream closes.""" + input_channel, output_channel = channel(10) + state = StartStepState(mock_step, mock_context, input_channel, output_channel) - async def emit_outstanding_records(): - yield record + # Create a custom generator that will set input_dropped after first yield + call_count = 0 + + async def test_generator(): + nonlocal call_count + call_count += 1 + yield "record1" + # After first record is yielded, simulate downstream closing + output_channel.channel.input_dropped = True + call_count += 1 + yield "record2" # This should not be processed due to early return + call_count += 1 + yield "record3" + + result = await state.emit_from_generator(test_generator()) - step_executor.input.get = mocker.AsyncMock(side_effect=[None]) - step_executor.step.emit_outstanding_records.return_value = ( - emit_outstanding_records() + # Should return CLOSED_DOWNSTREAM and only process first record + assert result == EmitResult.CLOSED_DOWNSTREAM + # Only first record should be processed, generator should stop after that + assert call_count == 2 # First yield + attempt at second yield + + +# Test to cover line 246 - the specific case where should_continue is False +@pytest.mark.asyncio +async def test_process_records_state_should_continue_false(mock_step, mock_context): + """Test ProcessRecordsState when emit result should_continue is False.""" + input_channel, output_channel = channel(10) + state = ProcessRecordsState(mock_step, mock_context, input_channel, output_channel) + + # Create a record + originating_step = Mock(spec=Step) + originating_step.finalize_record = AsyncMock() + test_record = RecordContext( + record="test_data", originating_step=originating_step, callback_token="token" ) - step_executor.output.put.return_value = False - await step_executor.drive_step() - step_executor.output.put.assert_called_once_with(record) + + # Mock process_record to return a generator that yields records + async def test_generator(data, context): + yield "output_record" + + mock_step.process_record.return_value = test_generator("test_data", mock_context) + + # Mock emit_from_generator to return CLOSED_DOWNSTREAM to trigger should_continue = False + from unittest.mock import patch + + with patch.object( + state, "emit_from_generator", return_value=EmitResult.CLOSED_DOWNSTREAM + ): + # Put record in input channel + await input_channel.channel.put(test_record) + # Put end signal + await input_channel.channel.put(None) + + next_state = await state.execute_until_state_change() + + # Should transition to StopStepExecution due to should_continue being False + assert isinstance(next_state, StopStepExecution) @pytest.mark.asyncio -async def test_pipeline_output_call_handling_errors(mocker): - def on_start_callback(): - raise Exception("Boom") +async def test_pipeline_output_on_finish_callback_exceptions_not_swallowed( + mocker, +): + """Test that on_finish_callback exceptions are not swallowed. + + This test covers the case described in the diff comment: + 'For cases where the reporter _wants_ to have the exception thrown + (e.g to have a status code in the CLI) we need to make sure we call + on_finish_callback without swallowing exceptions (because thats the + point).' + """ + + def on_finish_callback(metrics): + raise ValueError("CLI status code exception") + + input_mock = mocker.Mock(StepInput) + # No records to process + input_mock.get = mocker.AsyncMock(return_value=None) - output = PipelineOutput( - mocker.Mock(StepInput), + executor = Executor.pipeline_output( + input_mock, PipelineProgressReporter( - on_start_callback=on_start_callback, + on_finish_callback=on_finish_callback, logger=mocker.Mock(), ), ) - output.call_handling_errors(output.reporter.on_start_callback) - output.reporter.logger.exception.assert_called_once() + # The exception should propagate up and not be caught + with pytest.raises(ValueError, match="CLI status code exception"): + await executor.run() + + +@pytest.mark.asyncio +async def test_pipeline_on_start_callback_called_before_step_operations( + mocker, +): + """Test that on_start_callback is called before any step operations begin. + + This test covers the case described in the diff comment: + 'Here lies a footgun. DO NOT MOVE the `create_task` call after the + running_steps generator. It will break the pipeline because we need + to ensure the on_start_callback is called before any operations + occur in the actual steps in the pipeline.' + """ + from nodestream.pipeline.object_storage import ObjectStore + from nodestream.pipeline.pipeline import Pipeline + + # Track the order of operations + call_order = [] + + def on_start_callback(): + call_order.append("on_start_callback") + + def on_finish_callback(metrics): + call_order.append("on_finish_callback") + + # Create a mock step that records when it starts + mock_step = mocker.Mock(Step) + mock_step.__class__.__name__ = "MockStep" + + async def mock_start(context): + call_order.append("step_start") + + async def mock_finish(context): + call_order.append("step_finish") + + async def mock_process_record(record, context): + call_order.append("step_process") + yield record + + async def mock_emit_outstanding_records(context): + call_order.append("step_emit_outstanding") + for record in (): + yield record + + mock_step.start = mock_start + mock_step.finish = mock_finish + mock_step.process_record = mock_process_record + mock_step.emit_outstanding_records = mock_emit_outstanding_records + + # Create pipeline with the mock step + object_store = mocker.Mock(spec=ObjectStore) + object_store.namespaced.return_value = object_store + + pipeline = Pipeline((mock_step,), 10, object_store) + + reporter = PipelineProgressReporter( + on_start_callback=on_start_callback, + on_finish_callback=on_finish_callback, + logger=mocker.Mock(), + ) + + await pipeline.run(reporter) + + # Verify that on_start_callback was called before any step operations + assert call_order[0] == "on_start_callback" + assert "step_start" in call_order + start_idx = call_order.index("on_start_callback") + step_idx = call_order.index("step_start") + assert start_idx < step_idx diff --git a/tests/unit/pipeline/test_step_finalize_record.py b/tests/unit/pipeline/test_step_finalize_record.py new file mode 100644 index 00000000..5b3044a0 --- /dev/null +++ b/tests/unit/pipeline/test_step_finalize_record.py @@ -0,0 +1,227 @@ +import pytest + +from nodestream.pipeline.pipeline import RecordContext +from nodestream.pipeline.step import Step + + +class TestStep(Step): + """Test step that tracks finalize_record calls.""" + + def __init__(self): + self.finalize_calls = [] + + async def finalize_record(self, record_token): + self.finalize_calls.append(record_token) + + +class ResourceTrackingStep(Step): + """Step that tracks resources allocated per record.""" + + def __init__(self): + self.allocated_resources = {} + self.finalized_tokens = [] + + async def process_record(self, record, context): + # Simulate allocating resources for this record + token = f"resource_for_{id(record)}" + self.allocated_resources[token] = f"resource_data_{id(record)}" + yield (record, token) # Emit tuple with callback token + + async def finalize_record(self, record_token): + """Clean up resources allocated for this record.""" + if record_token in self.allocated_resources: + del self.allocated_resources[record_token] + self.finalized_tokens.append(record_token) + + +@pytest.mark.asyncio +async def test_step_finalize_record_custom_implementation(): + """Test that custom finalize_record implementation is called.""" + step = TestStep() + + tokens = ["token1", "token2", {"complex": "token"}] + + for token in tokens: + await step.finalize_record(token) + + assert step.finalize_calls == tokens + + +@pytest.mark.asyncio +async def test_finalize_record_called_on_record_drop(): + """Test that finalize_record is called when a record is dropped.""" + step = TestStep() + token = "test_token" + + record = RecordContext( + record="test_data", originating_step=step, callback_token=token + ) + + await record.drop() + + assert step.finalize_calls == [token] + + +@pytest.mark.asyncio +async def test_finalize_record_called_with_correct_token(): + """Test that finalize_record is called with the correct callback token.""" + step = TestStep() + + # Test with simple data (token == data) + record1 = RecordContext.from_step_emission(step, "simple_data") + await record1.drop() + + # Test with tuple (data, token) + record2 = RecordContext.from_step_emission(step, ("data", "custom_token")) + await record2.drop() + + assert step.finalize_calls == ["simple_data", "custom_token"] + + +@pytest.mark.asyncio +async def test_finalize_record_resource_cleanup(): + """Test finalize_record for resource cleanup scenario.""" + step = ResourceTrackingStep() + + # Simulate processing records + record1_data = "record1" + record2_data = "record2" + + # Process first record + async for emission in step.process_record(record1_data, None): + record1 = RecordContext.from_step_emission(step, emission) + break + + # Process second record + async for emission in step.process_record(record2_data, None): + record2 = RecordContext.from_step_emission(step, emission) + break + + # Verify resources were allocated + assert len(step.allocated_resources) == 2 + + # Drop first record + await record1.drop() + + # Verify first record's resources were cleaned up + assert len(step.allocated_resources) == 1 + assert len(step.finalized_tokens) == 1 + + # Drop second record + await record2.drop() + + # Verify all resources were cleaned up + assert len(step.allocated_resources) == 0 + assert len(step.finalized_tokens) == 2 + + +@pytest.mark.asyncio +async def test_finalize_record_exception_handling(): + """Test behavior when finalize_record raises an exception.""" + + class FailingStep(Step): + async def finalize_record(self, record_token): + raise ValueError(f"Failed to finalize {record_token}") + + step = FailingStep() + record = RecordContext( + record="test", originating_step=step, callback_token="failing_token" + ) + + # Exception should propagate + with pytest.raises(ValueError, match="Failed to finalize failing_token"): + await record.drop() + + +@pytest.mark.asyncio +async def test_finalize_record_with_child_records(): + """Test finalize_record behavior with parent-child record relationships.""" + parent_step = TestStep() + child_step = TestStep() + + parent_record = RecordContext( + record="parent", + originating_step=parent_step, + callback_token="parent_token", + child_record_count=2, + ) + + child1 = RecordContext( + record="child1", + originating_step=child_step, + callback_token="child1_token", + originated_from=parent_record, + ) + + child2 = RecordContext( + record="child2", + originating_step=child_step, + callback_token="child2_token", + originated_from=parent_record, + ) + + # Drop first child + await child1.drop() + + # Only child should be finalized, not parent yet + assert child_step.finalize_calls == ["child1_token"] + assert parent_step.finalize_calls == [] + + # Drop second child + await child2.drop() + + # Now both children and parent should be finalized + assert child_step.finalize_calls == ["child1_token", "child2_token"] + assert parent_step.finalize_calls == ["parent_token"] + + +@pytest.mark.asyncio +async def test_finalize_record_multiple_tokens_same_step(): + """Test finalize_record called multiple times on same step.""" + step = TestStep() + + records = [] + for i in range(5): + record = RecordContext( + record=f"data_{i}", originating_step=step, callback_token=f"token_{i}" + ) + records.append(record) + + # Drop all records + for record in records: + await record.drop() + + expected_tokens = [f"token_{i}" for i in range(5)] + assert step.finalize_calls == expected_tokens + + +@pytest.mark.asyncio +async def test_finalize_record_with_none_token(): + """Test finalize_record behavior with None token.""" + step = TestStep() + + record = RecordContext(record="test", originating_step=step, callback_token=None) + + await record.drop() + + assert step.finalize_calls == [None] + + +@pytest.mark.asyncio +async def test_finalize_record_with_complex_token(): + """Test finalize_record with complex token objects.""" + step = TestStep() + + complex_token = { + "id": "record_123", + "metadata": {"type": "test", "priority": 1}, + "resources": ["file1.txt", "file2.txt"], + } + + record = RecordContext( + record="test", originating_step=step, callback_token=complex_token + ) + + await record.drop() + + assert step.finalize_calls == [complex_token]