@@ -946,6 +946,7 @@ def _export(
946946 ip_adapter_args = ip_adapter_args ,
947947 output_hidden_states = output_hidden_states ,
948948 torch_dtype = torch_dtype ,
949+ tensor_parallel_size = tensor_parallel_size ,
949950 controlnet_ids = controlnet_ids ,
950951 ** input_shapes_copy ,
951952 )
@@ -955,7 +956,7 @@ def _export(
955956 for name , (model , neuron_config ) in models_and_neuron_configs .items ():
956957 if "vae" in name : # vae configs are not cached.
957958 continue
958- model_config = model . config
959+ model_config = getattr ( model , "config" , None ) or neuron_config . _config
959960 if isinstance (model_config , FrozenDict ):
960961 model_config = OrderedDict (model_config )
961962 model_config = DiffusersPretrainedConfig .from_dict (model_config )
@@ -969,7 +970,7 @@ def _export(
969970 input_names = neuron_config .inputs ,
970971 output_names = neuron_config .outputs ,
971972 dynamic_batch_size = neuron_config .dynamic_batch_size ,
972- tensor_parallel_size = tensor_parallel_size ,
973+ tensor_parallel_size = neuron_config . tensor_parallel_size ,
973974 compiler_type = NEURON_COMPILER_TYPE ,
974975 compiler_version = NEURON_COMPILER_VERSION ,
975976 inline_weights_to_neff = inline_weights_to_neff ,
@@ -991,6 +992,7 @@ def _export(
991992
992993 if cache_exist :
993994 # load cache
995+ logger .info (f"Neuron cache found at { model_cache_dir } . If you want to recompile the model, please set `disable_neuron_cache=True`." )
994996 neuron_model = cls .from_pretrained (model_cache_dir , data_parallel_mode = data_parallel_mode )
995997 # replace weights
996998 if not inline_weights_to_neff :
0 commit comments