Skip to content

Commit dfd0312

Browse files
author
蒋硕
committed
NPU implementation for FLUX
1 parent 53f3c64 commit dfd0312

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,8 @@ def main(args):
10411041
cur_class_images = len(list(class_images_dir.iterdir()))
10421042

10431043
if cur_class_images < args.num_class_images:
1044-
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
1044+
has_supported_fp16_accelerator = (torch.cuda.is_available() or torch.backends.mps.is_available()
1045+
or is_torch_npu_available())
10451046
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
10461047
if args.prior_generation_precision == "fp32":
10471048
torch_dtype = torch.float32

0 commit comments

Comments
 (0)