We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 53f3c64 commit dfd0312Copy full SHA for dfd0312
examples/dreambooth/train_dreambooth_flux.py
@@ -1041,7 +1041,8 @@ def main(args):
1041
cur_class_images = len(list(class_images_dir.iterdir()))
1042
1043
if cur_class_images < args.num_class_images:
1044
- has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ has_supported_fp16_accelerator = (torch.cuda.is_available() or torch.backends.mps.is_available()
1045
+ or is_torch_npu_available())
1046
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
1047
if args.prior_generation_precision == "fp32":
1048
torch_dtype = torch.float32
0 commit comments