Skip to content

Commit 0d38c46

Browse files
authored
Flux performance optimization (mindspore-lab#966)
* flux performance optimization
* add comment
* change default value
* change the default value
* fix ci
* fix typo
1 parent 7416d36 commit 0d38c46

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

examples/diffusers/controlnet/train_controlnet_flux.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
set_seed,
5050
)
5151
from mindone.transformers import CLIPTextModel, T5EncoderModel
52+
from mindone.utils.config import str2bool
5253

5354
logger = logging.getLogger(__name__)
5455

@@ -419,6 +420,12 @@ def parse_args(input_args=None):
419420
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
420421
),
421422
)
423+
parser.add_argument(
424+
"--dataset_iterator_no_copy",
425+
default=True,
426+
type=str2bool,
427+
help="dataset iterator optimization strategy. Whether dataset iterator creates a Tensor without copy.",
428+
)
422429
parser.add_argument(
423430
"--dataset_name",
424431
type=str,
@@ -1165,7 +1172,13 @@ def __len__(self):
11651172
# Only show the progress bar once on each machine.
11661173
disable=not is_master(args),
11671174
)
1168-
train_dataloader_iter = train_dataloader.create_tuple_iterator(num_epochs=args.num_train_epochs - first_epoch)
1175+
# do_copy=False enables the dataset iterator to not do copy when creating a tensor which takes less time.
1176+
# Currently the default value of do_copy is True,
1177+
# it is expected that the default value of do_copy will be changed to False in MindSpore 2.7.0.
1178+
train_dataloader_iter = train_dataloader.create_tuple_iterator(
1179+
num_epochs=args.num_train_epochs - first_epoch,
1180+
do_copy=not args.dataset_iterator_no_copy,
1181+
)
11691182

11701183
for epoch in range(first_epoch, args.num_train_epochs):
11711184
flux_controlnet.set_train(True)

examples/diffusers/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,20 @@ def parse_args(input_args=None):
201201
required=False,
202202
help="A folder containing the training data of class images.",
203203
)
204+
parser.add_argument(
205+
"--jit_level",
206+
type=str,
207+
default="O1",
208+
choices=["O0", "O1", "O2"],
209+
help=(
210+
"Used to control the compilation optimization level, supports [O0, O1, O2]. The framework automatically "
211+
"selects the execution method. O0: All optimizations except those necessary for functionality are "
212+
"disabled, using an operator-by-operator execution method. O1: Enables common optimizations and automatic "
213+
"operator fusion optimizations, using an operator-by-operator execution method. This is an experimental "
214+
"optimization level, which is continuously being improved. O2: Enables extreme performance optimization, "
215+
"using a sinking execution method."
216+
),
217+
)
204218
parser.add_argument(
205219
"--instance_prompt",
206220
type=str,
@@ -908,7 +922,11 @@ def encode_prompt(
908922

909923
def main(args):
910924
args = parse_args()
911-
ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT)
925+
ms.set_context(
926+
mode=ms.GRAPH_MODE,
927+
jit_syntax_level=ms.STRICT,
928+
jit_config={"jit_level": args.jit_level},
929+
)
912930
init_distributed_device(args)
913931

914932
logging_dir = Path(args.output_dir, args.logging_dir)

mindone/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,7 @@ def construct(self, ids: ms.Tensor) -> ms.Tensor:
12371237
cos_out = []
12381238
sin_out = []
12391239
pos = ids.float()
1240-
freqs_dtype = ms.float64
1240+
freqs_dtype = ms.float32
12411241
for i in range(n_axes):
12421242
cos, sin = get_1d_rotary_pos_embed(
12431243
self.axes_dim[i],

0 commit comments

Comments (0)