Skip to content

Commit d884783

Browse files
committed
More fixes
1 parent ac7e116 commit d884783

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/python_tests/test_llm_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -704,11 +704,11 @@ def run_perf_metrics_collection(model_id, generation_config_dict: dict, prompt:
704704
]
705705
@pytest.mark.parametrize("generation_config,prompt", test_cases)
706706
@pytest.mark.precommit
707-
def test_perf_metrics(generation_config, prompt):
707+
def test_perf_metrics(generation_config, prompt, model_downloader):
708708
import time
709709
start_time = time.perf_counter()
710710
model_id = 'katuni4ka/tiny-random-gemma2'
711-
perf_metrics = run_perf_metrics_collection(model_id, generation_config, prompt)
711+
perf_metrics = run_perf_metrics_collection(model_id, generation_config, prompt, model_downloader)
712712
total_time = (time.perf_counter() - start_time) * 1000
713713

714714
# Check that load time is adequate.
@@ -816,7 +816,7 @@ class Person(BaseModel):
816816
@pytest.mark.parametrize("pipeline_type", get_main_pipeline_types())
817817
@pytest.mark.parametrize("stop_str", {True, False})
818818
@pytest.mark.precommit
819-
def test_pipelines_generate_with_streaming(pipeline_type, stop_str):
819+
def test_pipelines_generate_with_streaming(pipeline_type, stop_str, model_downloader):
820820
# streamer
821821
it_cnt = 0
822822
def py_streamer(py_str: str):

0 commit comments

Comments
 (0)