mergekit/architecture/__init__.py (9 additions, 0 deletions)

@@ -86,6 +86,15 @@ def get_architecture_info(
         raise RuntimeError(
             "Must specify --allow-crimes to attempt to mix different architectures"
         )
+        # Prefer using the base_model's architecture when mixing different families
+        # to ensure the output layout matches the base.
+        if config.base_model is not None:
+            try:
+                idx = models.index(config.base_model)
+                return model_arch_info[idx]
+            except ValueError:
+                # base_model not in referenced models; fall back to first
+                pass
         return model_arch_info[0]

     # try to infer from all models
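To make the new resolution order concrete, here is a minimal, self-contained sketch of the same fallback chain; pick_arch_info, the plain strings, and the list types are illustrative stand-ins for mergekit's ModelReference and architecture-info objects, not the project's actual API:

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")

def pick_arch_info(
    models: List[str], model_arch_info: List[T], base_model: Optional[str]
) -> T:
    # Sketch of the diff's fallback chain: prefer the base model's
    # architecture info when it is among the referenced models, otherwise
    # fall back to the first model's info (the pre-patch behavior).
    if base_model is not None:
        try:
            return model_arch_info[models.index(base_model)]
        except ValueError:
            pass  # base_model not referenced; fall back to first
    return model_arch_info[0]

# With a base model present, its architecture wins even if it is listed second:
assert pick_arch_info(["a", "b"], ["llama-info", "mistral-info"], "b") == "mistral-info"
# Without one, the first model's architecture is used, as before the patch:
assert pick_arch_info(["a", "b"], ["llama-info", "mistral-info"], None) == "llama-info"
```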
mergekit/io/tensor_writer.py (1 addition, 0 deletions)

@@ -39,6 +39,7 @@ def __init__(
         use_async: bool = False,
         max_write_threads: int = 1,
     ) -> None:
+        out_path = os.path.abspath(os.path.expanduser(out_path))
         os.makedirs(out_path, exist_ok=True)

         self.out_path = out_path
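As a quick illustration (the tilde path below is hypothetical), the added line expands "~" and absolutizes the user-supplied output directory before anything is created, so later shard writes resolve the same way regardless of the process's working directory:

```python
import os

out_path = "~/merges/output"  # hypothetical user-supplied path
out_path = os.path.abspath(os.path.expanduser(out_path))
# e.g. "/home/user/merges/output": "~" expanded and the path made absolute,
# so a later os.chdir() elsewhere in the process cannot redirect writes.
os.makedirs(out_path, exist_ok=True)
print(out_path)
```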
mergekit/merge_methods/slerp.py (9 additions, 3 deletions)

@@ -21,7 +21,7 @@
 class SlerpTask(Task[torch.Tensor]):
     gather_tensors: MergeTensorInput
     base_model: ModelReference
-    t: float
+    t: Optional[float]
     weight_info: WeightInfo

     def uses_accelerator(self) -> bool:
@@ -38,6 +38,12 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
         elif self.base_model not in tensors:
             raise RuntimeError("Base model not in input tensors")

+        # If no interpolation parameter was provided for this tensor, do not attempt to merge;
+        # simply return the base model's weight unchanged. This avoids shape/broadcast errors
+        # when the secondary model has incompatible tensor shapes.
+        if self.t is None:
+            return tensors[self.base_model]
+
         [a, b] = list(tensors.items())
         if a[0] != self.base_model:
             [a, b] = [b, a]
@@ -72,7 +78,7 @@ def reference_url(self):
return "https://en.wikipedia.org/wiki/Slerp"

def parameters(self) -> List[ConfigParameterDef]:
return [ConfigParameterDef(name="t", required=True)]
return [ConfigParameterDef(name="t", required=False, default_value=None)]

def make_task(
self,
@@ -92,7 +98,7 @@


 def lerp(
-    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
+    t: Optional[float], v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
Review comment on this line:

Bug: lerp breaks when t is None despite the Optional[float] signature

The lerp function signature was changed to accept t: Optional[float], but the implementation, return (1 - t) * v0 + t * v1, does not handle the case where t is None. This will cause a TypeError at runtime if None is passed, since Python cannot perform arithmetic with None. While the current code path in SlerpTask.execute() returns early when t is None (lines 44-45), accepting Optional[float] in the signature without handling None creates a type-implementation mismatch and leaves the function fragile to future refactoring or external callers.


 ) -> Union[np.ndarray, torch.Tensor]:
     return (1 - t) * v0 + t * v1

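A guard along the lines the review suggests might look like the sketch below. It mirrors the early return in SlerpTask.execute(), treating a missing t as "keep v0" (assuming v0 is the base-model tensor, as in execute()'s ordering); this is one possible fix, not the patch as submitted:

```python
from typing import Optional, Union

import numpy as np
import torch

def lerp(
    t: Optional[float],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
) -> Union[np.ndarray, torch.Tensor]:
    # Mirror SlerpTask.execute(): a missing interpolation parameter means
    # "keep the base tensor" instead of attempting arithmetic with None.
    if t is None:
        return v0
    return (1 - t) * v0 + t * v1

base, other = torch.ones(4), torch.zeros(4)
assert torch.equal(lerp(None, base, other), base)        # missing t: keep base
assert torch.equal(lerp(0.25, base, other), 0.75 * base)  # normal interpolation
```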
mergekit/tokenizer/embed.py (3 additions, 2 deletions)

@@ -130,7 +130,7 @@ def assign_embedding_sources(
             continue

         if num_present == 0:
-            token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding())
+            token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding(kind="zero"))
             logging.warning(f"Token {repr(token)} not found in any model")
             continue

@@ -152,7 +152,8 @@ def compute_default_embedding(
     cfg: TokenEmbeddingConfig,
 ) -> torch.Tensor:
     if isinstance(cfg.source, ZeroEmbedding):
-        pass
+        any_tensor = next(iter(tensors.values()))
+        embed = torch.zeros(any_tensor.shape[1], dtype=any_tensor.dtype, device=any_tensor.device)
     elif isinstance(cfg.source, ModelTokenEmbedding):
         model = cfg.source.model
         assert (
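To see what the replacement of pass buys in isolation, here is a small self-contained check of the same construction; the single [vocab, hidden] donor matrix standing in for tensors is an assumed shape, implied by the diff's any_tensor.shape[1] indexing:

```python
import torch

# Assumed stand-in for `tensors`: one donor embedding matrix of shape
# [vocab_size, hidden_size].
tensors = {"donor": torch.randn(32, 8, dtype=torch.float16)}

any_tensor = next(iter(tensors.values()))
embed = torch.zeros(
    any_tensor.shape[1], dtype=any_tensor.dtype, device=any_tensor.device
)
# The zero embedding matches the donor's hidden size, dtype, and device,
# so it can be stacked alongside real rows instead of leaving the token's
# embedding undefined (the old `pass` produced no tensor at all).
assert embed.shape == (8,) and embed.dtype == torch.float16
```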