Skip to content

Commit 0a9381e

Browse files
JingyaHuangdacorvo
authored andcommitted
fix: tp size mismatched in the config
1 parent 891ac13 commit 0a9381e

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

optimum/neuron/modeling_diffusion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,7 @@ def _export(
946946
ip_adapter_args=ip_adapter_args,
947947
output_hidden_states=output_hidden_states,
948948
torch_dtype=torch_dtype,
949+
tensor_parallel_size=tensor_parallel_size,
949950
controlnet_ids=controlnet_ids,
950951
**input_shapes_copy,
951952
)
@@ -955,7 +956,7 @@ def _export(
955956
for name, (model, neuron_config) in models_and_neuron_configs.items():
956957
if "vae" in name: # vae configs are not cached.
957958
continue
958-
model_config = model.config
959+
model_config = getattr(model, "config", None) or neuron_config._config
959960
if isinstance(model_config, FrozenDict):
960961
model_config = OrderedDict(model_config)
961962
model_config = DiffusersPretrainedConfig.from_dict(model_config)
@@ -969,7 +970,7 @@ def _export(
969970
input_names=neuron_config.inputs,
970971
output_names=neuron_config.outputs,
971972
dynamic_batch_size=neuron_config.dynamic_batch_size,
972-
tensor_parallel_size=tensor_parallel_size,
973+
tensor_parallel_size=neuron_config.tensor_parallel_size,
973974
compiler_type=NEURON_COMPILER_TYPE,
974975
compiler_version=NEURON_COMPILER_VERSION,
975976
inline_weights_to_neff=inline_weights_to_neff,
@@ -991,6 +992,7 @@ def _export(
991992

992993
if cache_exist:
993994
# load cache
995+
logger.info(f"Neuron cache found at {model_cache_dir}. If you want to recompile the model, please set `disable_neuron_cache=True`.")
994996
neuron_model = cls.from_pretrained(model_cache_dir, data_parallel_mode=data_parallel_mode)
995997
# replace weights
996998
if not inline_weights_to_neff:

optimum/neuron/models/inference/flux/modeling_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(
117117
joint_attention_dim: int = 4096,
118118
pooled_projection_dim: int = 768,
119119
guidance_embeds: bool = False,
120-
axes_dims_rope: tuple[int] = (16, 56, 56),
120+
axes_dims_rope: list[int] = [16, 56, 56],
121121
reduce_dtype: torch.dtype = torch.bfloat16,
122122
):
123123
super().__init__()

0 commit comments

Comments
 (0)