load_model() should accept pre-loaded weights and config #969

@blightbow

Description

Summary

load_model() currently requires weights to exist on the filesystem as .safetensors files. Adding optional config and weights parameters would let callers who already have weights in memory skip the I/O layer while preserving all model construction, sanitization, and quantization logic.

This is a ~15-line backward-compatible change.

Motivation

Several use cases need to construct a model from weights that do not originate on local disk.

Today, anyone in these situations must reimplement the model construction pipeline (class resolution, sanitization, quantization dispatch, QQLinear upgrade) outside of mlx-lm, and that reimplementation drifts out of sync as new quantization paths are added upstream.

Proposed Change

Add two optional parameters to load_model():

def load_model(
    model_path: Path,
    lazy: bool = False,
    strict: bool = True,
    model_config: Optional[Dict[str, Any]] = None,
    get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
    # New parameters:
    config: Optional[dict] = None,
    weights: Optional[Dict[str, mx.array]] = None,
) -> Tuple[nn.Module, dict]:

When config is provided, skip load_config(model_path). When weights is provided, skip the glob + mx.load() loop. Everything downstream — sanitize(), _quantize(), quantize_activations, AWQ/GPTQ transforms, load_weights() — remains unchanged.

# Current lines 300-311:
config = load_config(model_path)
if model_config is not None:
    config.update(model_config)

weight_files = glob.glob(str(model_path / "model*.safetensors"))
if not weight_files and strict:
    raise FileNotFoundError(f"No safetensors found in {model_path}")

weights = {}
for wf in weight_files:
    weights.update(mx.load(wf))

# Proposed:
if config is None:
    config = load_config(model_path)
if model_config is not None:
    config.update(model_config)

if weights is None:
    weight_files = glob.glob(str(model_path / "model*.safetensors"))
    if not weight_files and strict:
        raise FileNotFoundError(f"No safetensors found in {model_path}")
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))
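The fallback pattern above can be sketched in isolation. This is a minimal stand-in, not mlx-lm code: `fake_load_config` and `fake_load_weights` are hypothetical placeholders for `load_config()` and the glob + `mx.load()` loop, and `resolve_inputs` is a hypothetical helper name. The point it illustrates is that callers who pass in-memory objects bypass the filesystem entirely, while callers who pass nothing get the existing disk-loading behavior.

```python
from pathlib import Path
from typing import Any, Dict, Optional

def fake_load_config(model_path: Path) -> dict:
    # Stand-in for load_config(); would normally read config.json from disk.
    return {"model_type": "from_disk"}

def fake_load_weights(model_path: Path) -> Dict[str, Any]:
    # Stand-in for the glob + mx.load() loop over model*.safetensors.
    return {"layer.weight": "from_disk"}

def resolve_inputs(
    model_path: Path,
    model_config: Optional[dict] = None,
    config: Optional[dict] = None,
    weights: Optional[Dict[str, Any]] = None,
):
    # Only touch the filesystem when the caller did not supply the
    # object already; model_config overrides still apply either way.
    if config is None:
        config = fake_load_config(model_path)
    if model_config is not None:
        config.update(model_config)
    if weights is None:
        weights = fake_load_weights(model_path)
    return config, weights

# Existing call sites: nothing passed, disk path taken as before.
c1, w1 = resolve_inputs(Path("model"))

# New call sites: in-memory objects short-circuit both loads.
c2, w2 = resolve_inputs(
    Path("model"),
    config={"model_type": "in_memory"},
    weights={"layer.weight": "in_memory"},
)
```

Because both parameters default to None, every existing call site compiles and behaves identically, which is what makes the change backward compatible.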

Notes

  • model_path is still required even when config and weights are both provided, because custom model file resolution (config.get("model_file")) uses it. A follow-up could make it Optional[Path] with appropriate guards.
  • sharded_load() currently calls load_model() twice — once to discover model structure, once to load weights. This change would let it be simplified in a follow-up.
  • Existing behavior is completely unchanged when the new parameters are not provided.

Contributing

I'm happy to submit a PR for this if the approach looks reasonable. The change is small enough that I can have it ready quickly.
