Skip to content

Commit 7dbc5ac

Browse files
committed
linting + adding unit tests for one file
1 parent b4fe85d commit 7dbc5ac

File tree

4 files changed

+69
-12
lines changed

4 files changed

+69
-12
lines changed

nodestream/cli/commands/print_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ..operations import InitializeProject, PrintProjectSchema
44
from .nodestream_command import NodestreamCommand
5-
from .shared_options import PROJECT_FILE_OPTION, MANY_PIPELINES_ARGUMENT
5+
from .shared_options import MANY_PIPELINES_ARGUMENT, PROJECT_FILE_OPTION
66

77

88
class PrintSchema(NodestreamCommand):

nodestream/cli/operations/print_project_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ async def perform(self, command: NodestreamCommand):
2929
if self.pipeline_names:
3030
schema = self.project.get_pipelines_schema(
3131
pipeline_names=self.pipeline_names,
32-
type_overrides_file=type_overrides_file
32+
type_overrides_file=type_overrides_file,
3333
)
3434
else:
3535
schema = self.project.get_schema(type_overrides_file=type_overrides_file)
36-
36+
3737
# Import all schema printers so that they can register themselves
3838
SchemaPrinter.import_all()
3939
printer = SchemaPrinter.from_name(self.format_string)

nodestream/project/project.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ def get_schema(self, type_overrides_file: Optional[Path] = None) -> Schema:
291291
schema.merge(overrides_schema)
292292
return schema
293293

294-
def get_pipelines_schema(self, pipeline_names: List[str], type_overrides_file: Optional[Path] = None) -> Schema:
294+
def get_pipelines_schema(
295+
self, pipeline_names: List[str], type_overrides_file: Optional[Path] = None
296+
) -> Schema:
295297
"""Returns a `GraphSchema` representing only the specified pipelines.
296298
297299
This method generates a schema from only the specified pipelines,
@@ -313,33 +315,36 @@ def get_pipelines_schema(self, pipeline_names: List[str], type_overrides_file: O
313315
targets_by_name=self.targets_by_name,
314316
storage_configuration=self.storage_configuration,
315317
)
316-
318+
317319
pipelines_found = []
318-
320+
319321
for scope in self.scopes_by_name.values():
320322
filtered_scope = PipelineScope(
321323
name=scope.name,
322324
config=scope.config,
323325
pipeline_configuration=scope.pipeline_configuration,
324326
)
325-
327+
326328
for pipeline_name in pipeline_names:
327329
if pipeline_name in scope.pipelines_by_name:
328330
filtered_scope.add_pipeline_definition(
329331
scope.pipelines_by_name[pipeline_name]
330332
)
331333
pipelines_found.append(pipeline_name)
332-
334+
333335
if filtered_scope.pipelines_by_name:
334336
filtered_project.add_scope(filtered_scope)
335-
337+
336338
if not pipelines_found:
337339
available_pipelines = [
338-
name for scope in self.scopes_by_name.values()
340+
name
341+
for scope in self.scopes_by_name.values()
339342
for name in scope.pipelines_by_name.keys()
340343
]
341-
raise ValueError(f"None of the specified pipelines {pipeline_names} were found. Available pipelines: {available_pipelines}")
342-
344+
raise ValueError(
345+
f"None of the specified pipelines {pipeline_names} were found. Available pipelines: {available_pipelines}"
346+
)
347+
343348
schema = filtered_project.make_schema()
344349
if type_overrides_file is not None:
345350
overrides_schema = Schema.read_from_file(type_overrides_file)

tests/unit/cli/operations/test_print_project_schema.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
import pytest
24

35
from nodestream.cli.operations import PrintProjectSchema
@@ -42,3 +44,53 @@ async def test_print_project_schema_prints_schema_to_file(
4244
_, stdout, fileout = await ran_print_project_schema_operation("some/path")
4345
assert not stdout.called
4446
assert fileout.called
47+
48+
49+
@pytest.mark.asyncio
50+
async def test_print_project_schema_with_pipeline_filtering(
51+
mocker, project_with_default_scope
52+
):
53+
std_out = mocker.patch(
54+
"nodestream.schema.printers.SchemaPrinter.print_schema_to_stdout"
55+
)
56+
project_with_default_scope.get_schema = mocker.Mock(return_value="full_schema")
57+
project_with_default_scope.get_pipelines_schema = mocker.Mock(
58+
return_value="filtered_schema"
59+
)
60+
61+
operation = PrintProjectSchema(
62+
project=project_with_default_scope,
63+
format_string="plain",
64+
pipeline_names=["test_pipeline"],
65+
)
66+
await operation.perform(mocker.Mock())
67+
68+
# Should call get_pipelines_schema, not get_schema
69+
project_with_default_scope.get_pipelines_schema.assert_called_once_with(
70+
pipeline_names=["test_pipeline"], type_overrides_file=None
71+
)
72+
project_with_default_scope.get_schema.assert_not_called()
73+
std_out.assert_called_once()
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_print_project_schema_with_pipeline_filtering_and_overrides(
78+
mocker, project_with_default_scope
79+
):
80+
project_with_default_scope.get_pipelines_schema = mocker.Mock(
81+
return_value="filtered_schema"
82+
)
83+
84+
operation = PrintProjectSchema(
85+
project=project_with_default_scope,
86+
format_string="plain",
87+
type_overrides_file="overrides.yaml",
88+
pipeline_names=["test_pipeline", "another_pipeline"],
89+
)
90+
await operation.perform(mocker.Mock())
91+
92+
# Should call get_pipelines_schema with correct arguments
93+
project_with_default_scope.get_pipelines_schema.assert_called_once_with(
94+
pipeline_names=["test_pipeline", "another_pipeline"],
95+
type_overrides_file=Path("overrides.yaml"),
96+
)

0 commit comments

Comments
 (0)