Skip to content

Commit b378ea8

Browse files
authored
fix model config (#50)
* fix model config * format
1 parent e1cf81f commit b378ea8

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

mlx_audio/tts/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,14 @@ def load_model(model_path: Path, lazy: bool = False, **kwargs) -> nn.Module:
139139
weights.update(mx.load(wf))
140140

141141
model_class, model_type = get_model_and_args(model_type=model_type)
142-
model_config = model_class.ModelConfig.from_dict(config)
142+
143+
# Get model config from model class if it exists, otherwise use the config
144+
model_config = (
145+
model_class.ModelConfig.from_dict(config)
146+
if hasattr(model_class, "ModelConfig")
147+
else config
148+
)
149+
143150
model = model_class.Model(model_config)
144151
quantization = config.get("quantization", None)
145152
if quantization is None:

0 commit comments

Comments
 (0)