Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,19 @@ def create_model(
# Instantiate the model
logging.info(f"Instantiating model architecture: {model_class.__name__}")
model = model_class(**final_model_cfg, cast_dtype=cast_dtype)
_set_model_device_and_precision(model, device, precision, is_timm_model)

# The model may already be on the meta device if it was created inside a context manager
# such as `accelerate.init_empty_weights`,
# or inside a `transformers.PreTrainedModel.from_pretrained` call.
model_is_in_meta_device = next(model.parameters()).device.type == "meta"

if not model_is_in_meta_device:
_set_model_device_and_precision(model, device, precision, is_timm_model)
model_is_in_meta_device = device.type == 'meta'

# Load Full Pretrained CLIP Weights (if path exists)
pretrained_loaded = False
if checkpoint_path:
if checkpoint_path and not model_is_in_meta_device:
logging.info(f'Loading full pretrained weights from: {checkpoint_path}')
# Use the load_checkpoint helper which handles state dict loading, conversions, etc.
# Use strict=True by default for full model loading to catch mismatches.
Expand All @@ -518,7 +526,7 @@ def create_model(

# Load tower-specific weights (image and text), after the full CLIP checkpoint, potentially overwriting parts.
pretrained_image_loaded = False # Track if specific image weights loaded
if pretrained_image_path:
if pretrained_image_path and not model_is_in_meta_device:
if os.path.isfile(pretrained_image_path):
logging.info(f"Attempting to load image tower weights from: {pretrained_image_path}")
try:
Expand Down Expand Up @@ -547,7 +555,7 @@ def create_model(
logging.warning(f"Invalid file path specified for pretrained_image_path: {pretrained_image_path}")

pretrained_text_loaded = False # Track if specific text weights loaded
if pretrained_text_path:
if pretrained_text_path and not model_is_in_meta_device:
if os.path.isfile(pretrained_text_path):
logging.info(f"Attempting to load text tower weights from: {pretrained_text_path}")
try:
Expand Down Expand Up @@ -585,6 +593,8 @@ def create_model(
elif not pretrained_loaded and partially_loaded:
# Some tower weights loaded
logging.warning(f"Model {model_name} initialized partially.")
elif model_is_in_meta_device:
logging.info("The model is in the 'meta' device and thus it was not initialized.")
elif not pretrained_loaded and not partially_loaded:
# Absolutely no weights were loaded from any source
logging.warning(f"No pretrained weights loaded for model '{model_name}'. Model initialized randomly.")
Expand Down
Loading