Summary
`load_model()` currently requires weights to exist on the filesystem as `.safetensors` files. Adding optional `config` and `weights` parameters would let callers who already have weights in memory skip the I/O layer while preserving all model construction, sanitization, and quantization logic.
This is a ~15-line backward-compatible change.
Motivation
Several use cases need to construct a model from weights that don't originate from local disk:
- Distributed inference — weights received over the network (see [Proposal] RDMA weight streaming — eliminate per-node model storage for distributed inference mlx#3208)
- In-memory safetensors — the safetensors library cannot round-trip bf16 through `numpy.load(bytes)`, but `mx.load(BytesIO(data))` works. Callers who parse safetensors in memory currently have no way to feed the resulting dict into `load_model()` without writing to a tmpfile. (see [Feature Request] Cannot create tensor from raw bytes + dtypes mlx#1296)
- Testing — constructing models from synthetic weights without touching disk
- Model merging/editing — composing weights in memory, then building a model
Today, anyone in these situations must reimplement the model construction pipeline (class resolution, sanitization, quantization dispatch, QQLinear upgrade) outside of mlx-lm, which drifts as new quantization paths are added upstream.
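For the in-memory case, the only workaround today is a disk round-trip. A minimal sketch of that pattern (the helper name is illustrative, not an mlx-lm function; the `mx.load(tmp_path)` call is replaced here by a plain file read so the sketch runs without mlx installed):

```python
import os
import tempfile

def load_from_bytes_via_tmpfile(raw: bytes) -> dict:
    """Current workaround: spill in-memory safetensors bytes to disk
    just so a path-based loader can read them back."""
    with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f:
        f.write(raw)
        tmp_path = f.name
    try:
        # Real code would call mx.load(tmp_path) here; a stub read
        # stands in so the example is self-contained.
        with open(tmp_path, "rb") as f:
            return {"raw": f.read()}
    finally:
        os.unlink(tmp_path)
```

The proposal below removes exactly this detour: the bytes are already a parsed dict in memory, and only the loader's insistence on a path forces them back through the filesystem.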
Proposed Change
Add two optional parameters to `load_model()`:
```python
def load_model(
    model_path: Path,
    lazy: bool = False,
    strict: bool = True,
    model_config: Optional[Dict[str, Any]] = None,
    get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
    # New parameters:
    config: Optional[dict] = None,
    weights: Optional[Dict[str, mx.array]] = None,
) -> Tuple[nn.Module, dict]:
```

When `config` is provided, skip `load_config(model_path)`. When `weights` is provided, skip the glob + `mx.load()` loop. Everything downstream — `sanitize()`, `_quantize()`, `quantize_activations`, AWQ/GPTQ transforms, `load_weights()` — remains unchanged.
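The guard logic is essentially the whole change. A self-contained sketch of the dispatch (the stub loaders stand in for `load_config()` and the `mx.load()` loop; none of these stub names exist in mlx-lm):

```python
from typing import Any, Dict, Optional, Tuple

def _load_config_stub(model_path: str) -> Dict[str, Any]:
    # Stand-in for load_config(model_path): pretend config.json was read.
    return {"model_type": "llama", "loaded_from": "disk"}

def _read_weights_stub(model_path: str) -> Dict[str, Any]:
    # Stand-in for the glob + mx.load() loop over model*.safetensors.
    return {"layers.0.weight": [0.0], "loaded_from": "disk"}

def load_model_sketch(
    model_path: str,
    config: Optional[Dict[str, Any]] = None,
    weights: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # The new parameters only short-circuit the I/O layer; everything
    # downstream sees the same (config, weights) pair either way.
    if config is None:
        config = _load_config_stub(model_path)
    if weights is None:
        weights = _read_weights_stub(model_path)
    return config, weights

# Existing call sites are untouched:
cfg, w = load_model_sketch("/models/llama")

# In-memory callers skip disk entirely:
cfg2, w2 = load_model_sketch(
    "/models/llama",
    config={"model_type": "llama"},
    weights={"layers.0.weight": [0.1]},
)
```

Because the fallbacks fire only when an argument is `None`, the default path is byte-for-byte the current behavior.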
```python
# Current lines 300-311:
config = load_config(model_path)
if model_config is not None:
    config.update(model_config)

weight_files = glob.glob(str(model_path / "model*.safetensors"))
if not weight_files and strict:
    raise FileNotFoundError(f"No safetensors found in {model_path}")

weights = {}
for wf in weight_files:
    weights.update(mx.load(wf))
```
```python
# Proposed:
if config is None:
    config = load_config(model_path)
if model_config is not None:
    config.update(model_config)

if weights is None:
    weight_files = glob.glob(str(model_path / "model*.safetensors"))
    if not weight_files and strict:
        raise FileNotFoundError(f"No safetensors found in {model_path}")
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))
```

Notes
- `model_path` is still required even when `config` and `weights` are both provided, because custom model file resolution (`config.get("model_file")`) uses it. A follow-up could make it `Optional[Path]` with appropriate guards.
- `sharded_load()` currently calls `load_model()` twice — once to discover model structure, once to load weights. This change would let it be simplified in a follow-up.
- Existing behavior is completely unchanged when the new parameters are not provided.
Contributing
I'm happy to submit a PR for this if the approach looks reasonable. The change is small enough that I can have it ready quickly.