Commit 7a9d3bd

update related unit test cases
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent: abb8f73

1 file changed: tests/pipelines/test_pipelines.py (54 additions, 45 deletions)

--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -167,9 +167,9 @@ def test_one_request_upon_cached(self):
             download_requests = [r.method for r in m.request_history]
             assert download_requests.count("HEAD") == 15, "15 calls to files"
             assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
-            assert len(download_requests) == 32, (
-                "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
-            )
+            assert (
+                len(download_requests) == 32
+            ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
 
             with requests_mock.mock(real_http=True) as m:
                 DiffusionPipeline.download(
@@ -179,9 +179,9 @@ def test_one_request_upon_cached(self):
             cache_requests = [r.method for r in m.request_history]
             assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
             assert cache_requests.count("GET") == 1, "model info is only GET"
-            assert len(cache_requests) == 2, (
-                "We should call only `model_info` to check for _commit hash and `send_telemetry`"
-            )
+            assert (
+                len(cache_requests) == 2
+            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
 
     def test_less_downloads_passed_object(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -217,9 +217,9 @@ def test_less_downloads_passed_object_calls(self):
             assert download_requests.count("HEAD") == 13, "13 calls to files"
             # 17 - 2 because no call to config or model file for `safety_checker`
             assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
-            assert len(download_requests) == 28, (
-                "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
-            )
+            assert (
+                len(download_requests) == 28
+            ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
 
             with requests_mock.mock(real_http=True) as m:
                 DiffusionPipeline.download(
@@ -229,9 +229,9 @@ def test_less_downloads_passed_object_calls(self):
             cache_requests = [r.method for r in m.request_history]
             assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
             assert cache_requests.count("GET") == 1, "model info is only GET"
-            assert len(cache_requests) == 2, (
-                "We should call only `model_info` to check for _commit hash and `send_telemetry`"
-            )
+            assert (
+                len(cache_requests) == 2
+            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
 
     def test_download_only_pytorch(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -538,26 +538,38 @@ def test_download_variant_partly(self):
             variant = "no_ema"
 
             with tempfile.TemporaryDirectory() as tmpdirname:
-                tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-all-variants",
-                    cache_dir=tmpdirname,
-                    variant=variant,
-                    use_safetensors=use_safetensors,
-                )
-                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
-                files = [item for sublist in all_root_files for item in sublist]
-
-                unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
+                if use_safetensors:
+                    with self.assertRaises(OSError) as error_context:
+                        tmpdirname = StableDiffusionPipeline.download(
+                            "hf-internal-testing/stable-diffusion-all-variants",
+                            cache_dir=tmpdirname,
+                            variant=variant,
+                            use_safetensors=use_safetensors,
+                        )
+                    assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
+                else:
+                    tmpdirname = StableDiffusionPipeline.download(
+                        "hf-internal-testing/stable-diffusion-all-variants",
+                        cache_dir=tmpdirname,
+                        variant=variant,
+                        use_safetensors=use_safetensors,
+                    )
+                    all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+                    files = [item for sublist in all_root_files for item in sublist]
 
-                # Some of the downloaded files should be a non-variant file, check:
-                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
-                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
-                # only unet has "no_ema" variant
-                assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
-                assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
-                # vae, safety_checker and text_encoder should have no variant
-                assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
-                assert not any(f.endswith(other_format) for f in files)
+                    unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
+
+                    # Some of the downloaded files should be a non-variant file, check:
+                    # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
+                    assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+                    # only unet has "no_ema" variant
+                    assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
+                    assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
+                    # vae, safety_checker and text_encoder should have no variant
+                    assert (
+                        sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
+                    )
+                    assert not any(f.endswith(other_format) for f in files)
 
     def test_download_variants_with_sharded_checkpoints(self):
         # Here we test for downloading of "variant" files belonging to the `unet` and
@@ -588,19 +600,16 @@ def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
         logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
         deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
 
-        for is_local in [True, False]:
-            with CaptureLogger(logger) as cap_logger:
-                with tempfile.TemporaryDirectory() as tmpdirname:
-                    local_repo_id = repo_id
-                    if is_local:
-                        local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
+        with CaptureLogger(logger) as cap_logger:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
 
-                    _ = DiffusionPipeline.from_pretrained(
-                        local_repo_id,
-                        safety_checker=None,
-                        variant="fp16",
-                        use_safetensors=True,
-                    )
+                _ = DiffusionPipeline.from_pretrained(
+                    local_repo_id,
+                    safety_checker=None,
+                    variant="fp16",
+                    use_safetensors=True,
+                )
             assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
 
     def test_download_safetensors_only_variant_exists_for_model(self):
@@ -616,7 +625,7 @@ def test_download_safetensors_only_variant_exists_for_model(self):
                     variant=variant,
                     use_safetensors=use_safetensors,
                 )
-            assert "Error no file name" in str(error_context.exception)
+            assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
 
         # text encoder has fp16 variants so we can load it
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -675,7 +684,7 @@ def test_download_safetensors_variant_does_not_exist_for_model(self):
                     use_safetensors=use_safetensors,
                 )
 
-            assert "Error no file name" in str(error_context.exception)
+            assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
 
     def test_download_bin_variant_does_not_exist_for_model(self):
         variant = "no_ema"

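For readers skimming the diff: the reworked assertions in the first hunks sit inside a request-counting harness. Below is a minimal sketch of that pattern, not taken from the repository's test code; the repo id is a placeholder (the real one is truncated out of the hunk), and the exact HEAD/GET counts depend on that repository's file layout.

# Minimal sketch (not the repository's test code) of the counting pattern:
# requests_mock with real_http=True lets real HTTP through while recording every
# call, so the HEAD/GET requests issued by DiffusionPipeline.download can be tallied.
import requests_mock

from diffusers import DiffusionPipeline

repo_id = "some-org/some-tiny-pipeline"  # placeholder; the hunk truncates the real id

with requests_mock.mock(real_http=True) as m:
    DiffusionPipeline.download(repo_id)

methods = [r.method for r in m.request_history]
print(methods.count("HEAD"), methods.count("GET"), len(methods))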
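To exercise just the tests touched by this commit, one option, assuming pytest is the runner for this suite (as it is for diffusers), is to invoke it programmatically; the -k expression simply names tests that appear in the hunks above.

# Hedged usage sketch: run only tests touched by this commit via pytest.
import pytest

pytest.main(
    [
        "tests/pipelines/test_pipelines.py",
        "-k",
        "test_one_request_upon_cached or test_download_variant_partly "
        "or test_download_safetensors_only_variant_exists_for_model "
        "or test_download_safetensors_variant_does_not_exist_for_model",
    ]
)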