
Commit 61196d8

Add options to run inference on the diffusion model in fp32 and fp64.

Parent: b4526d3

2 files changed: 9 additions, 3 deletions

comfy/cli_args.py

Lines changed: 4 additions & 2 deletions
@@ -60,8 +60,10 @@ def __call__(self, parser, namespace, values, option_string=None):
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
 fpunet_group = parser.add_mutually_exclusive_group()
-fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
-fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
+fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
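All the UNet precision flags live in a single mutually exclusive argparse group, so argparse itself rejects conflicting combinations at parse time. A minimal standalone sketch of that behavior (a simplified stand-in parser, not the full cli_args.py setup):

import argparse

# Simplified sketch: only the precision group, mirroring the flags above.
parser = argparse.ArgumentParser()
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16.")

args = parser.parse_args(["--fp64-unet"])
assert args.fp64_unet and not args.fp32_unet

# Two flags from the same group are rejected at parse time, e.g.
# parser.parse_args(["--fp32-unet", "--fp64-unet"]) exits with
# "argument --fp64-unet: not allowed with argument --fp32-unet".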

comfy/model_management.py

Lines changed: 5 additions & 1 deletion
@@ -628,6 +628,10 @@ def maximum_vram_for_weights(device=None):
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if model_params < 0:
         model_params = 1000000000000000000000
+    if args.fp32_unet:
+        return torch.float32
+    if args.fp64_unet:
+        return torch.float64
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
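Because these checks sit at the top of unet_dtype, an explicit --fp32-unet or --fp64-unet wins over every automatic heuristic further down. The returned dtype then behaves like any other torch dtype: weights and inputs are cast to it before inference. A minimal illustration of fp64 inference in plain PyTorch (the Linear layer is a stand-in for the diffusion model, not ComfyUI's actual code path):

import torch

# Stand-in module; the real code applies the dtype to the diffusion model.
model = torch.nn.Linear(4, 4).to(torch.float64)
x = torch.randn(1, 4, dtype=torch.float64)

with torch.no_grad():
    y = model(x)
assert y.dtype == torch.float64  # inference runs entirely in fp64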
@@ -674,7 +678,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
 
 # None means no manual cast
 def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
+    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
         return None
 
     fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
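Returning None here means fp64 weights, like fp32, skip the manual-cast path entirely: they are computed in their storage dtype. For lower-precision storage dtypes, manual casting converts weights to a compute dtype on the fly. A rough sketch of that idea (an assumption about the intent, not ComfyUI's exact implementation):

import torch

# Storage dtype differs from compute dtype, so weights are cast per use.
weight = torch.randn(4, 4, dtype=torch.float16)  # compact storage
x = torch.randn(1, 4, dtype=torch.bfloat16)      # compute dtype

y = x @ weight.to(x.dtype)  # the "manual cast" happens at inference time
assert y.dtype == torch.bfloat16

# With fp32/fp64 weights there is nothing to cast to, hence the early None.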
