Skip to content

Commit abdac0c

Browse files
authored
Support Mistral (#133)
1 parent 550644a commit abdac0c

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tuned_lens/model_surgery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def get_final_norm(model: Model) -> Norm:
119119
final_layer_norm = base_model.ln_f
120120
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
121121
final_layer_norm = base_model.norm
122+
elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
123+
final_layer_norm = base_model.norm
122124
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
123125
final_layer_norm = base_model.norm
124126
else:
@@ -166,6 +168,8 @@ def get_transformer_layers(model: Model) -> tuple[str, th.nn.ModuleList]:
166168
path_to_layers += ["h"]
167169
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
168170
path_to_layers += ["layers"]
171+
elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
172+
path_to_layers += ["layers"]
169173
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
170174
path_to_layers += ["layers"]
171175
else:

0 commit comments

Comments
 (0)