Skip to content

Commit e4e719b

Browse files
committed
fix: add _apply for pipeleines
1 parent e43dda4 commit e4e719b

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/pruna/algorithms/red_noe.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
111111
Any
112112
The model with the reduced number of experts per token.
113113
"""
114+
if is_transformers_pipeline_with_moe_lm(model):
115+
return self._apply_to_model_within_transformers_pipeline(model, smash_config)
116+
114117
device_map = get_device_map(model)
115118
# we need to save and reload with the new config, because immutable object.
116119
with tempfile.TemporaryDirectory() as temp_dir:
@@ -122,7 +125,12 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
122125
else:
123126
with config_path.open("r", encoding="utf-8") as f:
124127
config_json = json.load(f)
125-
config_json[smash_config["target_name"]["include"][0]] = smash_config["num_experts_per_token"]
128+
target_names = smash_config["target_name"]["include"]
129+
if not target_names:
130+
raise ValueError(
131+
"The 'include' list in 'target_name' is empty. Please provide at least one config parameter name to modify."
132+
)
133+
config_json[target_names[0]] = smash_config["num_experts_per_token"]
126134
with config_path.open("w", encoding="utf-8") as f:
127135
json.dump(config_json, f, indent=2)
128136
safe_memory_cleanup()

0 commit comments

Comments
 (0)