Skip to content

Commit 0998bd7

Browse files
committed
up
1 parent 5f560d0 commit 0998bd7

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/pipelines/test_modular_pipelines_common.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ def to_np(tensor):
3030
class ModularPipelineTesterMixin:
3131
"""
3232
This mixin is designed to be used with unittest.TestCase classes.
33-
It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline,
34-
equivalence of dict and tuple outputs, etc.
33+
It provides a set of common tests for each modular pipeline,
34+
including:
35+
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
36+
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
37+
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
38+
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
39+
- test_to_device: check if the pipeline's __call__ method can handle different devices
3540
"""
3641

3742
# Canonical parameters that are passed to `__call__` regardless
@@ -45,7 +50,7 @@ class ModularPipelineTesterMixin:
4550
"output_type",
4651
]
4752
)
48-
# generator needs to be a intermediate input because it's mutable
53+
# this is modular specific: generator needs to be a intermediate input because it's mutable
4954
required_intermediate_params = frozenset(
5055
[
5156
"generator",

0 commit comments

Comments
 (0)