diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 76c07ed02..10dfe93ea 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -945,6 +945,7 @@ def _export( ip_adapter_args=ip_adapter_args, output_hidden_states=output_hidden_states, torch_dtype=torch_dtype, + tensor_parallel_size=tensor_parallel_size, controlnet_ids=controlnet_ids, **input_shapes_copy, ) @@ -954,7 +955,7 @@ def _export( for name, (model, neuron_config) in models_and_neuron_configs.items(): if "vae" in name: # vae configs are not cached. continue - model_config = model.config + model_config = getattr(model, "config", None) or neuron_config._config if isinstance(model_config, FrozenDict): model_config = OrderedDict(model_config) model_config = DiffusersPretrainedConfig.from_dict(model_config) @@ -968,7 +969,7 @@ def _export( input_names=neuron_config.inputs, output_names=neuron_config.outputs, dynamic_batch_size=neuron_config.dynamic_batch_size, - tensor_parallel_size=tensor_parallel_size, + tensor_parallel_size=neuron_config.tensor_parallel_size, compiler_type=NEURON_COMPILER_TYPE, compiler_version=NEURON_COMPILER_VERSION, inline_weights_to_neff=inline_weights_to_neff, @@ -990,6 +991,9 @@ def _export( if cache_exist: # load cache + logger.info( + f"Neuron cache found at {model_cache_dir}. If you want to recompile the model, please set `disable_neuron_cache=True`." + ) neuron_model = cls.from_pretrained(model_cache_dir, data_parallel_mode=data_parallel_mode) # replace weights if not inline_weights_to_neff: diff --git a/optimum/neuron/models/inference/flux/modeling_flux.py b/optimum/neuron/models/inference/flux/modeling_flux.py index ee6026942..52c336666 100644 --- a/optimum/neuron/models/inference/flux/modeling_flux.py +++ b/optimum/neuron/models/inference/flux/modeling_flux.py @@ -117,7 +117,7 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, - axes_dims_rope: tuple[int] = (16, 56, 56), + axes_dims_rope: list[int] = [16, 56, 56], reduce_dtype: torch.dtype = torch.bfloat16, ): super().__init__() diff --git a/pyproject.toml b/pyproject.toml index 0624bfb38..0c779c6d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ tests = [ "mediapipe", "timm >= 1.0.0", "hf_transfer", - "torchcodec", + "torchcodec < 0.6.0", ] quality = [ "ruff",