Skip to content

Commit 567dbf2

Browse files
committed
fixes #281
1 parent f4466f0 commit 567dbf2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ def convert_unet(pipe, args, model_name = None):
827827
logger.info("Done.")
828828

829829
if args.check_output_correctness:
830-
baseline_out = pipe.unet(**baseline_sample_unet_inputs,
830+
baseline_out = pipe.unet.to(torch.float32)(**baseline_sample_unet_inputs,
831831
return_dict=False)[0].numpy()
832832
reference_out = reference_unet(*sample_unet_inputs.values())[0].numpy()
833833
report_correctness(baseline_out, reference_out,

0 commit comments

Comments
 (0)