diff --git a/nemo/collections/diffusion/models/dit/dit_layer_spec.py b/nemo/collections/diffusion/models/dit/dit_layer_spec.py
index 9caa208205a2..38bfe7c9297e 100644
--- a/nemo/collections/diffusion/models/dit/dit_layer_spec.py
+++ b/nemo/collections/diffusion/models/dit/dit_layer_spec.py
@@ -122,27 +122,26 @@ def __init__(
         else:
             self.ln = norm(config.hidden_size, elementwise_affine=False, eps=self.config.layernorm_epsilon)
         self.n_adaln_chunks = n_adaln_chunks
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            ColumnParallelLinear(
-                config.hidden_size,
-                self.n_adaln_chunks * config.hidden_size,
-                config=config,
-                init_method=nn.init.normal_,
-                bias=modulation_bias,
-                gather_output=True,
-            ),
+        self.activation = nn.SiLU()
+        self.linear = ColumnParallelLinear(
+            config.hidden_size,
+            self.n_adaln_chunks * config.hidden_size,
+            config=config,
+            init_method=nn.init.normal_,
+            bias=modulation_bias,
+            gather_output=True,
         )
         self.use_second_norm = use_second_norm
         if self.use_second_norm:
             self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)

-        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.linear.weight, 0)

-        setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel)
+        setattr(self.linear.weight, "sequence_parallel", config.sequence_parallel)

     @jit_fuser
     def forward(self, timestep_emb):
-        output, bias = self.adaLN_modulation(timestep_emb)
+        timestep_emb = self.activation(timestep_emb)
+        output, bias = self.linear(timestep_emb)
         output = output + bias if bias else output
         return output.chunk(self.n_adaln_chunks, dim=-1)
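The refactor above replaces the anonymous `nn.Sequential` with named `activation` and `linear` submodules, which changes the state-dict keys from `adaln.adaLN_modulation.1.*` to `adaln.linear.*` — hence the mapping updates in the files below. A minimal sketch of the resulting module is shown here, with a plain `nn.Linear` standing in for Megatron's `ColumnParallelLinear` (an assumption made to keep the sketch self-contained; the class name `AdaLNSketch` is hypothetical). Note that the unchanged `output + bias if bias else output` context line works only because `bias` is `None` on this path; truthiness on a multi-element tensor raises, so the sketch uses an explicit `is not None` guard.

```python
import torch
import torch.nn as nn


class AdaLNSketch(nn.Module):
    """Named submodules, so checkpoint keys are 'linear.*' rather than 'adaLN_modulation.1.*'."""

    def __init__(self, hidden_size: int, n_adaln_chunks: int = 6, modulation_bias: bool = True):
        super().__init__()
        self.activation = nn.SiLU()
        # nn.Linear stands in for ColumnParallelLinear(..., gather_output=True).
        self.linear = nn.Linear(hidden_size, n_adaln_chunks * hidden_size, bias=modulation_bias)
        nn.init.constant_(self.linear.weight, 0)  # zero-init, as in the diff
        self.n_adaln_chunks = n_adaln_chunks

    def forward(self, timestep_emb: torch.Tensor):
        output = self.linear(self.activation(timestep_emb))
        bias = None  # ColumnParallelLinear returns (output, bias); nn.Linear folds the bias in
        output = output + bias if bias is not None else output  # explicit None check, not truthiness
        return output.chunk(self.n_adaln_chunks, dim=-1)


# Usage: six modulation tensors (shift/scale/gate pairs) from one timestep embedding.
shift, scale, gate, shift2, scale2, gate2 = AdaLNSketch(hidden_size=64)(torch.randn(2, 64))
```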
diff --git a/nemo/collections/diffusion/models/flux/model.py b/nemo/collections/diffusion/models/flux/model.py
index d90c9e74ef4e..a7e617d458a8 100644
--- a/nemo/collections/diffusion/models/flux/model.py
+++ b/nemo/collections/diffusion/models/flux/model.py
@@ -816,10 +816,10 @@ def config(self) -> FluxConfig:
     def convert_state(self, source, target):
         # pylint: disable=C0301
         mapping = {
-            'transformer_blocks.*.norm1.linear.weight': 'double_blocks.*.adaln.adaLN_modulation.1.weight',
-            'transformer_blocks.*.norm1.linear.bias': 'double_blocks.*.adaln.adaLN_modulation.1.bias',
-            'transformer_blocks.*.norm1_context.linear.weight': 'double_blocks.*.adaln_context.adaLN_modulation.1.weight',
-            'transformer_blocks.*.norm1_context.linear.bias': 'double_blocks.*.adaln_context.adaLN_modulation.1.bias',
+            'transformer_blocks.*.norm1.linear.weight': 'double_blocks.*.adaln.linear.weight',
+            'transformer_blocks.*.norm1.linear.bias': 'double_blocks.*.adaln.linear.bias',
+            'transformer_blocks.*.norm1_context.linear.weight': 'double_blocks.*.adaln_context.linear.weight',
+            'transformer_blocks.*.norm1_context.linear.bias': 'double_blocks.*.adaln_context.linear.bias',
             'transformer_blocks.*.attn.norm_q.weight': 'double_blocks.*.self_attention.q_layernorm.weight',
             'transformer_blocks.*.attn.norm_k.weight': 'double_blocks.*.self_attention.k_layernorm.weight',
             'transformer_blocks.*.attn.norm_added_q.weight': 'double_blocks.*.self_attention.added_q_layernorm.weight',
@@ -836,8 +836,8 @@ def convert_state(self, source, target):
             'transformer_blocks.*.ff_context.net.0.proj.bias': 'double_blocks.*.context_mlp.linear_fc1.bias',
             'transformer_blocks.*.ff_context.net.2.weight': 'double_blocks.*.context_mlp.linear_fc2.weight',
             'transformer_blocks.*.ff_context.net.2.bias': 'double_blocks.*.context_mlp.linear_fc2.bias',
-            'single_transformer_blocks.*.norm.linear.weight': 'single_blocks.*.adaln.adaLN_modulation.1.weight',
-            'single_transformer_blocks.*.norm.linear.bias': 'single_blocks.*.adaln.adaLN_modulation.1.bias',
+            'single_transformer_blocks.*.norm.linear.weight': 'single_blocks.*.adaln.linear.weight',
+            'single_transformer_blocks.*.norm.linear.bias': 'single_blocks.*.adaln.linear.bias',
             'single_transformer_blocks.*.proj_mlp.weight': 'single_blocks.*.mlp.linear_fc1.weight',
             'single_transformer_blocks.*.proj_mlp.bias': 'single_blocks.*.mlp.linear_fc1.bias',
             'single_transformer_blocks.*.attn.norm_q.weight': 'single_blocks.*.self_attention.q_layernorm.weight',
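The `*` in these mapping keys is a wildcard over the transformer-block index. A hypothetical sketch of how such a table can be applied to rename Hugging Face keys follows; `rename_key` and the two-entry `mapping` subset are illustrative, not the converter's actual API.

```python
import re
from typing import Optional

# Illustrative subset of the wildcard table above.
mapping = {
    'transformer_blocks.*.norm1.linear.weight': 'double_blocks.*.adaln.linear.weight',
    'transformer_blocks.*.norm1.linear.bias': 'double_blocks.*.adaln.linear.bias',
}


def rename_key(hf_key: str) -> Optional[str]:
    for src, dst in mapping.items():
        # Escape the pattern, then let '*' capture the block index.
        pattern = '^' + re.escape(src).replace(r'\*', r'(\d+)') + '$'
        match = re.match(pattern, hf_key)
        if match:
            return dst.replace('*', match.group(1))
    return None  # key is handled elsewhere (e.g. fused qkv) or dropped


assert rename_key('transformer_blocks.3.norm1.linear.weight') == 'double_blocks.3.adaln.linear.weight'
```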
""" - if do_convert_from_hf: + assert not (do_convert_from_hf and load_dist_ckpt), 'do_convert_from_hf and load_dist_ckpt cannot both be true' + + if load_dist_ckpt: + from megatron.core import dist_checkpointing + + sharded_state_dict = dict(state_dict=self.transformer.sharded_state_dict(prefix="module.")) + loaded_state_dict = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, checkpoint_dir=ckpt_path + ) + ckpt = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()} + elif do_convert_from_hf: ckpt = flux_transformer_converter(ckpt_path, self.transformer.config) if save_converted_model_to is not None: save_path = os.path.join(save_converted_model_to, 'nemo_flux_transformer.safetensors') @@ -196,6 +210,11 @@ def load_from_pretrained(self, ckpt_path, do_convert_from_hf=True, save_converte f"please check the ckpt provided or the image quality may be compromised.\n {missing}" ) logging.info(f"Found unexepected keys: \n {unexpected}") + if len(unexpected) > 0: + logging.info( + f"The following keys are unexpected during checkpoint loading, " + f"please check the ckpt provided or the image quality may be compromised.\n {unexpected}" + ) def encoder_prompt( self, @@ -685,12 +704,26 @@ def __init__( self.flux_controlnet = FluxControlNet(contorlnet_config) if flux_controlnet is None else flux_controlnet def load_from_pretrained( - self, flux_ckpt_path, controlnet_ckpt_path, do_convert_from_hf=True, save_converted_model_to=None + self, + flux_ckpt_path, + controlnet_ckpt_path, + do_convert_from_hf=True, + save_converted_model_to=None, + load_dist_ckpt=False, ): ''' Converts both flux base model and flux controlnet ckpt into NeMo format. ''' - if do_convert_from_hf: + assert not (do_convert_from_hf and load_dist_ckpt), 'do_convert_from_hf and load_dist_ckpt cannot both be true' + if load_dist_ckpt: + from megatron.core import dist_checkpointing + + sharded_state_dict = dict(state_dict=self.transformer.sharded_state_dict(prefix="module.")) + loaded_state_dict = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, checkpoint_dir=flux_ckpt_path + ) + flux_ckpt = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()} + elif do_convert_from_hf: flux_ckpt = flux_transformer_converter(flux_ckpt_path, self.transformer.config) flux_controlnet_ckpt = flux_transformer_converter(controlnet_ckpt_path, self.flux_controlnet.config) diff --git a/nemo/collections/diffusion/utils/flux_ckpt_converter.py b/nemo/collections/diffusion/utils/flux_ckpt_converter.py index c70fc58cc491..1aeae5824b58 100644 --- a/nemo/collections/diffusion/utils/flux_ckpt_converter.py +++ b/nemo/collections/diffusion/utils/flux_ckpt_converter.py @@ -80,10 +80,10 @@ def _import_qkv(transformer_config, q, k, v): flux_key_mapping = { 'double_blocks': { - 'norm1.linear.weight': 'adaln.adaLN_modulation.1.weight', - 'norm1.linear.bias': 'adaln.adaLN_modulation.1.bias', - 'norm1_context.linear.weight': 'adaln_context.adaLN_modulation.1.weight', - 'norm1_context.linear.bias': 'adaln_context.adaLN_modulation.1.bias', + 'norm1.linear.weight': 'adaln.linear.weight', + 'norm1.linear.bias': 'adaln.linear.bias', + 'norm1_context.linear.weight': 'adaln_context.linear.weight', + 'norm1_context.linear.bias': 'adaln_context.linear.bias', 'attn.norm_q.weight': 'self_attention.q_layernorm.weight', 'attn.norm_k.weight': 'self_attention.k_layernorm.weight', 'attn.norm_added_q.weight': 'self_attention.added_q_layernorm.weight', @@ -102,8 +102,8 @@ def 
diff --git a/nemo/collections/diffusion/utils/flux_ckpt_converter.py b/nemo/collections/diffusion/utils/flux_ckpt_converter.py
index c70fc58cc491..1aeae5824b58 100644
--- a/nemo/collections/diffusion/utils/flux_ckpt_converter.py
+++ b/nemo/collections/diffusion/utils/flux_ckpt_converter.py
@@ -80,10 +80,10 @@ def _import_qkv(transformer_config, q, k, v):

 flux_key_mapping = {
     'double_blocks': {
-        'norm1.linear.weight': 'adaln.adaLN_modulation.1.weight',
-        'norm1.linear.bias': 'adaln.adaLN_modulation.1.bias',
-        'norm1_context.linear.weight': 'adaln_context.adaLN_modulation.1.weight',
-        'norm1_context.linear.bias': 'adaln_context.adaLN_modulation.1.bias',
+        'norm1.linear.weight': 'adaln.linear.weight',
+        'norm1.linear.bias': 'adaln.linear.bias',
+        'norm1_context.linear.weight': 'adaln_context.linear.weight',
+        'norm1_context.linear.bias': 'adaln_context.linear.bias',
         'attn.norm_q.weight': 'self_attention.q_layernorm.weight',
         'attn.norm_k.weight': 'self_attention.k_layernorm.weight',
         'attn.norm_added_q.weight': 'self_attention.added_q_layernorm.weight',
@@ -102,8 +102,8 @@ def _import_qkv(transformer_config, q, k, v):
         'ff_context.net.2.bias': 'context_mlp.linear_fc2.bias',
     },
     'single_blocks': {
-        'norm.linear.weight': 'adaln.adaLN_modulation.1.weight',
-        'norm.linear.bias': 'adaln.adaLN_modulation.1.bias',
+        'norm.linear.weight': 'adaln.linear.weight',
+        'norm.linear.bias': 'adaln.linear.bias',
         'proj_mlp.weight': 'mlp.linear_fc1.weight',
         'proj_mlp.bias': 'mlp.linear_fc1.bias',
         # 'proj_out.weight': 'proj_out.weight',
@@ -219,8 +219,5 @@ def flux_transformer_converter(ckpt_path=None, transformer_config=None):
         new_state_dict[f'single_blocks.{str(i)}.mlp.linear_fc2.bias'] = (
             diffuser_state_dict[f'single_transformer_blocks.{str(i)}.proj_out.bias'].detach().clone()
         )
-        new_state_dict[f'single_blocks.{str(i)}.self_attention.linear_proj.bias'] = (
-            diffuser_state_dict[f'single_transformer_blocks.{str(i)}.proj_out.bias'].detach().clone()
-        )

     return new_state_dict
diff --git a/scripts/flux/flux_controlnet_infer.py b/scripts/flux/flux_controlnet_infer.py
index d8325c62b78d..156a2f1895d7 100644
--- a/scripts/flux/flux_controlnet_infer.py
+++ b/scripts/flux/flux_controlnet_infer.py
@@ -64,6 +64,12 @@ def parse_args():
         default=False,
         help="Must be true if provided checkpoint is not already converted to NeMo version",
     )
+    parser.add_argument(
+        "--load_dist_ckpt",
+        action='store_true',
+        default=False,
+        help="Load distributed checkpoint for Flux",
+    )
     parser.add_argument(
         "--save_converted_model_to",
         type=str,
@@ -143,6 +149,7 @@ def parse_args():
         args.controlnet_ckpt,
         do_convert_from_hf=args.do_convert_from_hf,
         save_converted_model_to=args.save_converted_model_to,
+        load_dist_ckpt=args.load_dist_ckpt,
     )
     dtype = torch.float32
     text = args.prompts.split(',')
diff --git a/scripts/flux/flux_infer.py b/scripts/flux/flux_infer.py
index 603c82836d7e..5517b464f51f 100644
--- a/scripts/flux/flux_infer.py
+++ b/scripts/flux/flux_infer.py
@@ -55,6 +55,12 @@ def parse_args():
         default=False,
         help="Must be true if provided checkpoint is not already converted to NeMo version",
     )
+    parser.add_argument(
+        "--load_dist_ckpt",
+        action='store_true',
+        default=False,
+        help="Load distributed checkpoint for Flux",
+    )
     parser.add_argument(
         "--save_converted_model_to",
         type=str,
@@ -118,6 +124,7 @@ def parse_args():
         args.flux_ckpt,
         do_convert_from_hf=args.do_convert_from_hf,
         save_converted_model_to=args.save_converted_model_to,
+        load_dist_ckpt=args.load_dist_ckpt,
     )
     dtype = torch.float32
     text = args.prompts.split(',')
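With the flag wired through both scripts, a NeMo 2 distributed checkpoint can be used directly for inference. A hypothetical invocation follows: the `--flux_ckpt` and `--prompts` spellings are inferred from `args.flux_ckpt` / `args.prompts` in the diff, the checkpoint path is a placeholder, and whether plain `python` or a `torchrun` launcher is required depends on the environment setup.

```bash
python scripts/flux/flux_infer.py \
    --flux_ckpt /path/to/nemo2_dist_ckpt \
    --load_dist_ckpt \
    --prompts "A cat holding a sign that says hello world"
```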