
Commit 655d881

dfalbel and claude authored
Add Ministral/Mistral model implementation (#5)
Implements Ministral-style models with:

- YaRN RoPE (Yet another RoPE extension) for extended context
- GQA (Grouped Query Attention) with configurable num_key_value_heads
- SwiGLU MLP with SiLU activation
- RMSNorm

Verified against HuggingFace transformers MistralForCausalLM with max diff ~6e-7 (floating point precision).

Includes tests for:

- Loading pretrained Mistral-7B and comparing logits
- Creating models with custom config
- Text generation with streaming output

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 741cee0 commit 655d881
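
As a rough sketch of the intended workflow (not part of the commit: the Hub identifier is an assumption based on the test description, the package providing these functions must be loaded, and the weight download is several GB), the exported constructors can be exercised like this:

library(torch)
# Assumed identifier for the Mistral-7B checkpoint mentioned in the tests
model <- ministral_from_pretrained("mistralai/Mistral-7B-v0.1")
model$eval()

# Dummy token ids; real ids come from a tokenizer (HF ids are 0-based, so they
# may need a +1 offset for torch for R's 1-based embeddings)
idx <- torch_randint(1, 1000, size = c(1, 8), dtype = torch_long())
logits <- with_no_grad(model(idx))
logits$shape  # (1, 8, vocab_size) -- the tests compare logits like these against HF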

File tree

2 files changed: +434 -0 lines changed

R/ministral.R

Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
# References:
# - https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512
# - https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py

#' @noRd
#' @importFrom zeallot %<-%
#' @importFrom purrr map
#' @import torch
NULL

# YaRN RoPE helper functions
yarn_find_correction_dim <- function(num_rotations, dim, base, max_position_embeddings) {
  (dim * log(max_position_embeddings / (num_rotations * 2 * pi))) / (2 * log(base))
}

yarn_find_correction_range <- function(low_rot, high_rot, dim, base, max_position_embeddings) {
  low <- floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
  high <- ceiling(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
  c(max(low, 0), min(high, dim - 1))
}

yarn_linear_ramp_mask <- function(min_val, max_val, dim, dtype = torch_float32()) {
  if (min_val == max_val) min_val <- min_val - 0.001
  linear_func <- (torch_arange(0, dim - 1, dtype = dtype) - min_val) / (max_val - min_val)
  torch_clamp(linear_func, 0, 1)
}

ministral_rotate_half <- function(x) {
  c(x1, x2) %<-% torch_split(x, x$size(-1) / 2, -1)
  torch_cat(list(-x2, x1), dim = -1)
}

# GQA helper: repeat each key/value head n_rep times so K/V match the number of query heads
repeat_kv <- function(hidden_states, n_rep) {
  if (n_rep == 1) return(hidden_states)
  c(batch, num_kv_heads, seq_len, head_dim) %<-% hidden_states$shape
  hidden_states$unsqueeze(3)$
    expand(c(batch, num_kv_heads, n_rep, seq_len, head_dim))$
    reshape(c(batch, num_kv_heads * n_rep, seq_len, head_dim))
}
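
# Illustration (not part of the commit): with the package defaults of 32 query
# heads and 8 key/value heads, repeat_kv() tiles each KV head 4 times so the
# grouped-query attention matmuls line up head for head (assumes torch is attached).
kv <- torch_randn(1, 8, 16, 128)  # (batch, n_kv_head, seq_len, head_dim)
repeat_kv(kv, 4)$shape            # (1, 32, 16, 128)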

nn_ministral_rmsnorm <- nn_module(
  initialize = function(hidden_size, eps = 1e-5) {
    self$weight <- nn_parameter(torch_ones(hidden_size))
    self$eps <- eps
  },
  forward = function(x) {
    dtype <- x$dtype
    variance <- x$to(dtype = "float32")$pow(2)$mean(-1, keepdim = TRUE)
    x <- x * torch_rsqrt(variance + self$eps)
    (self$weight * x)$to(dtype = dtype)
  }
)

nn_ministral_yarn_rotary_embedding <- nn_module(
  initialize = function(head_dim, max_pos, base, factor, beta_fast, beta_slow,
                        original_max_pos, mscale, mscale_all_dim) {
    self$head_dim <- head_dim
    self$max_pos <- max_pos
    self$base <- base
    self$factor <- factor
    self$beta_fast <- beta_fast
    self$beta_slow <- beta_slow
    self$original_max_pos <- original_max_pos
    self$mscale <- mscale
    self$mscale_all_dim <- mscale_all_dim
    self$cached_embeddings()
  },
  .load_from_state_dict = function(...) {
    super$.load_from_state_dict(...)
    self$cached_embeddings(invalidate = TRUE)
  },
  get_mscale = function(scale, mscale) {
    if (mscale <= 0) return(1.0)
    0.1 * mscale * log(scale) + 1.0
  },
  cached_embeddings = function(t = 1, invalidate = FALSE) {
    invalidate <- invalidate || is.null(self$cos)
    if (invalidate) {
      dim <- self$head_dim
      pos_freqs <- self$base ^ (torch_arange(0, dim - 1, step = 2) / dim)
      inv_freq_extrapolation <- 1.0 / pos_freqs
      inv_freq_interpolation <- 1.0 / (self$factor * pos_freqs)

      c(low, high) %<-% yarn_find_correction_range(
        self$beta_slow, self$beta_fast, dim, self$base, self$original_max_pos
      )
      inv_freq_extrapolation_factor <- yarn_linear_ramp_mask(low, high, dim / 2)

      # Blend interpolated (scaled) and extrapolated (original) inverse frequencies per dimension
      inv_freq <- inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) +
        inv_freq_extrapolation * inv_freq_extrapolation_factor
      self$inv_freq <- nn_buffer(inv_freq, persistent = FALSE)

      self$attention_scale <- self$get_mscale(self$factor, self$mscale) /
        self$get_mscale(self$factor, self$mscale_all_dim)

      freqs <- torch_arange(start = 0, end = self$max_pos - 1)$
        float()$outer(self$inv_freq)$view(c(1, 1, self$max_pos, dim / 2))
      emb <- torch_cat(list(freqs, freqs), dim = -1)
      self$cos <- nn_buffer(emb$cos(), persistent = FALSE)
      self$sin <- nn_buffer(emb$sin(), persistent = FALSE)
    }
    list(self$cos[,,1:t,], self$sin[,,1:t,], self$attention_scale)
  },
  forward = function(x) {
    c(b, nh, t, ed) %<-% x$shape
    c(cos, sin, attn_scale) %<-% self$cached_embeddings(t)
    (x * cos + ministral_rotate_half(x) * sin) * attn_scale
  }
)
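
# Illustration (not part of the commit): the YaRN rotary module is shape-preserving.
# The values below are the package defaults for Ministral-style checkpoints, with
# max_pos reduced so the cached cos/sin buffers stay small.
rope <- nn_ministral_yarn_rotary_embedding(
  head_dim = 128, max_pos = 4096, base = 1e9, factor = 16,
  beta_fast = 32, beta_slow = 1, original_max_pos = 16384,
  mscale = 1, mscale_all_dim = 1
)
q <- torch_randn(1, 32, 16, 128)  # (batch, n_head, seq_len, head_dim)
rope(q)$shape                     # still (1, 32, 16, 128), rotated and scaled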

nn_ministral_attention <- nn_module(
  initialize = function(n_embd, n_head, n_kv_head, head_dim, max_pos,
                        rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
                        rope_original_max_pos, rope_mscale, rope_mscale_all_dim) {
    self$n_head <- n_head
    self$n_kv_head <- n_kv_head
    self$head_dim <- head_dim
    self$n_kv_groups <- n_head %/% n_kv_head
    self$max_pos <- max_pos

    self$rotary <- nn_ministral_yarn_rotary_embedding(
      head_dim, max_pos, rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
      rope_original_max_pos, rope_mscale, rope_mscale_all_dim
    )

    self$q_proj <- nn_linear(n_embd, n_head * head_dim, bias = FALSE)
    self$k_proj <- nn_linear(n_embd, n_kv_head * head_dim, bias = FALSE)
    self$v_proj <- nn_linear(n_embd, n_kv_head * head_dim, bias = FALSE)
    self$o_proj <- nn_linear(n_head * head_dim, n_embd, bias = FALSE)
    self$cached_bias()
  },
  forward = function(x) {
    c(b, t, h) %<-% x$shape

    q <- self$q_proj(x)$view(c(b, t, self$n_head, self$head_dim))$transpose(2, 3)
    k <- self$k_proj(x)$view(c(b, t, self$n_kv_head, self$head_dim))$transpose(2, 3)
    v <- self$v_proj(x)$view(c(b, t, self$n_kv_head, self$head_dim))$transpose(2, 3)

    q <- self$rotary(q)$to(dtype = "float")
    k <- self$rotary(k)$to(dtype = "float")

    # GQA: repeat K/V heads so they line up with the query heads
    k <- repeat_kv(k, self$n_kv_groups)
    v <- repeat_kv(v, self$n_kv_groups)

    # Scaled dot-product attention with a causal mask
    att <- torch_matmul(q, k$transpose(-2, -1)) * (1 / sqrt(self$head_dim))
    att <- att$masked_fill(self$bias[,,1:t, 1:t] == 0, self$masked_bias)
    att <- nnf_softmax(att, dim = -1)$to(dtype = v$dtype)

    y <- torch_matmul(att, v)$transpose(2, 3)$contiguous()$
      view(c(b, t, self$n_head * self$head_dim))
    self$o_proj(y)
  },
  .load_from_state_dict = function(...) {
    super$.load_from_state_dict(...)
    self$cached_bias()
  },
  cached_bias = function() {
    self$bias <- torch_ones(self$max_pos, self$max_pos)$bool()$tril()$
      view(c(1, 1, self$max_pos, self$max_pos)) |> nn_buffer(persistent = FALSE)
    self$masked_bias <- nn_buffer(torch_scalar_tensor(-Inf), persistent = FALSE)
  }
)

nn_ministral_mlp <- nn_module(
  initialize = function(n_embd, n_inter) {
    self$gate_proj <- nn_linear(n_embd, n_inter, bias = FALSE)
    self$down_proj <- nn_linear(n_inter, n_embd, bias = FALSE)
    self$up_proj <- nn_linear(n_embd, n_inter, bias = FALSE)
    self$act <- nn_silu()
  },
  forward = function(x) {
    # SwiGLU: down(silu(gate(x)) * up(x))
    self$down_proj(self$act(self$gate_proj(x)) * self$up_proj(x))
  }
)

nn_ministral_layer <- nn_module(
  initialize = function(n_embd, n_inter, n_head, n_kv_head, head_dim, max_pos,
                        rmsnorm_eps, rope_base, rope_factor, rope_beta_fast,
                        rope_beta_slow, rope_original_max_pos, rope_mscale,
                        rope_mscale_all_dim) {
    self$ln_1 <- nn_ministral_rmsnorm(n_embd, rmsnorm_eps)
    self$ln_2 <- nn_ministral_rmsnorm(n_embd, rmsnorm_eps)
    self$attn <- nn_ministral_attention(
      n_embd, n_head, n_kv_head, head_dim, max_pos, rope_base, rope_factor,
      rope_beta_fast, rope_beta_slow, rope_original_max_pos, rope_mscale,
      rope_mscale_all_dim
    )
    self$mlp <- nn_ministral_mlp(n_embd, n_inter)
  },
  forward = function(x) {
    x <- x + self$attn(self$ln_1(x))
    x + self$mlp(self$ln_2(x))
  }
)

nn_ministral_model <- nn_module(
  initialize = function(vocab_size, n_embd, n_inter, n_head, n_kv_head, head_dim,
                        n_layer, max_pos, rmsnorm_eps, rope_base, rope_factor,
                        rope_beta_fast, rope_beta_slow, rope_original_max_pos,
                        rope_mscale, rope_mscale_all_dim) {
    self$transformer <- nn_module_dict(list(
      wte = nn_embedding(vocab_size, n_embd),
      h = nn_sequential(!!!map(
        1:n_layer,
        \(x) nn_ministral_layer(
          n_embd, n_inter, n_head, n_kv_head, head_dim, max_pos, rmsnorm_eps,
          rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
          rope_original_max_pos, rope_mscale, rope_mscale_all_dim
        )
      )),
      ln_f = nn_ministral_rmsnorm(n_embd, rmsnorm_eps)
    ))
    self$lm_head <- nn_linear(n_embd, vocab_size, bias = FALSE)
  },
  forward = function(idx) {
    x <- self$transformer$wte(idx)
    x <- self$transformer$h(x)
    x <- self$transformer$ln_f(x)
    self$lm_head(x)
  }
)

#' ministral
#'
#' Initializes a Ministral-like model with YaRN RoPE and GQA.
#'
#' @param vocab_size Vocabulary size.
#' @param n_embd Embedding dimension.
#' @param n_inter Intermediate size of the MLP.
#' @param n_head Number of attention heads.
#' @param n_kv_head Number of key/value heads (for GQA).
#' @param head_dim Dimension of each attention head.
#' @param n_layer Number of transformer layers.
#' @param max_pos Maximum position embeddings.
#' @param rmsnorm_eps Epsilon for RMSNorm.
#' @param rope_base Base for rotary embeddings.
#' @param rope_factor YaRN scaling factor.
#' @param rope_beta_fast YaRN beta_fast parameter.
#' @param rope_beta_slow YaRN beta_slow parameter.
#' @param rope_original_max_pos Original max position embeddings for YaRN.
#' @param rope_mscale YaRN mscale parameter.
#' @param rope_mscale_all_dim YaRN mscale_all_dim parameter.
#' @param identifier HuggingFace model identifier.
#' @param revision HuggingFace model revision.
#' @returns An initialized [torch::nn_module()].
#' @export
ministral <- function(vocab_size = 131072, n_embd = 5120, n_inter = 16384,
                      n_head = 32, n_kv_head = 8, head_dim = 128, n_layer = 40,
                      max_pos = 262144, rmsnorm_eps = 1e-5, rope_base = 1e9,
                      rope_factor = 16, rope_beta_fast = 32, rope_beta_slow = 1,
                      rope_original_max_pos = 16384, rope_mscale = 1,
                      rope_mscale_all_dim = 1) {
  nn_ministral_model(
    vocab_size, n_embd, n_inter, n_head, n_kv_head, head_dim, n_layer, max_pos,
    rmsnorm_eps, rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
    rope_original_max_pos, rope_mscale, rope_mscale_all_dim
  )
}
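
# Illustration (not part of the commit): the tests create models with a custom
# config; the toy sizes below are arbitrary and chosen only to be cheap.
tiny <- ministral(
  vocab_size = 1000, n_embd = 64, n_inter = 128, n_head = 4, n_kv_head = 2,
  head_dim = 16, n_layer = 2, max_pos = 256, rope_original_max_pos = 128
)
idx <- torch_randint(1, 1000, size = c(1, 12), dtype = torch_long())
tiny(idx)$shape  # (1, 12, 1000) logits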

#' @describeIn ministral Initializes from HuggingFace config
#' @export
ministral_from_config <- function(identifier, revision = "main") {
  path <- hfhub::hub_download(identifier, "config.json", revision = revision)
  config <- jsonlite::fromJSON(path)

  # Handle multimodal config (text_config nested)
  if (!is.null(config$text_config)) {
    rope_params <- config$text_config$rope_parameters
    config <- config$text_config
  } else {
    rope_params <- config$rope_parameters %||% config
  }

  if (!config$model_type %in% c("ministral3", "mistral"))
    cli::cli_abort("Unsupported model_type: {.val {config$model_type}}")

  if (config$hidden_act != "silu")
    cli::cli_abort("Unsupported hidden_act: {.val {config$hidden_act}}")

  ministral(
    vocab_size = config$vocab_size,
    n_embd = config$hidden_size,
    n_inter = config$intermediate_size,
    n_head = config$num_attention_heads,
    n_kv_head = config$num_key_value_heads,
    head_dim = config$head_dim %||% (config$hidden_size %/% config$num_attention_heads),
    n_layer = config$num_hidden_layers,
    max_pos = config$max_position_embeddings,
    rmsnorm_eps = config$rms_norm_eps %||% 1e-5,
    rope_base = rope_params$rope_theta %||% 1e9,
    rope_factor = rope_params$factor %||% 16,
    rope_beta_fast = rope_params$beta_fast %||% 32,
    rope_beta_slow = rope_params$beta_slow %||% 1,
    rope_original_max_pos = rope_params$original_max_position_embeddings %||% 16384,
    rope_mscale = rope_params$mscale %||% 1,
    rope_mscale_all_dim = rope_params$mscale_all_dim %||% 1
  )
}

#' @describeIn ministral Initializes and loads pretrained weights from HF Hub
#' @export
ministral_from_pretrained <- function(identifier, revision = "main") {
  with_device(device = "meta", {
    model <- ministral_from_config(identifier, revision)
  })
  state_dict <- hf_state_dict(identifier, revision)
  state_dict <- ministral_hf_weights_remap(state_dict)
  model$load_state_dict(state_dict, .refer_to_state_dict = TRUE)
  model
}

ministral_hf_weights_remap <- function(state_dict) {
  nms <- names(state_dict)

  # Handle multimodal models (language_model prefix)
  nms <- gsub("^language_model\\.", "", nms)

  # Standard remapping
  nms <- gsub("model.embed_tokens.weight", "transformer.wte.weight", nms, fixed = TRUE)
  nms <- gsub("model.layers", "transformer.h", nms, fixed = TRUE)
  nms <- gsub("self_attn", "attn", nms, fixed = TRUE)
  nms <- gsub("input_layernorm", "ln_1", nms, fixed = TRUE)
  nms <- gsub("post_attention_layernorm", "ln_2", nms, fixed = TRUE)
  nms <- gsub("model.norm", "transformer.ln_f", nms, fixed = TRUE)

  names(state_dict) <- nms

  # Filter out non-language model weights
  keep <- !grepl("vision_tower|multi_modal_projector|scale", names(state_dict))
  state_dict[keep]
}
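
# Illustration (not part of the commit): how a typical HuggingFace weight name
# (hypothetical single-entry state dict) is rewritten before load_state_dict().
example_sd <- list("model.layers.0.self_attn.q_proj.weight" = torch_zeros(2, 2))
names(ministral_hf_weights_remap(example_sd))
# [1] "transformer.h.0.attn.q_proj.weight"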

0 commit comments
