# References:
# - https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512
# - https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py

#' @noRd
#' @importFrom zeallot %<-%
#' @importFrom purrr map
#' @import torch
NULL

# YaRN RoPE helper functions
yarn_find_correction_dim <- function(num_rotations, dim, base, max_position_embeddings) {
  (dim * log(max_position_embeddings / (num_rotations * 2 * pi))) / (2 * log(base))
}

yarn_find_correction_range <- function(low_rot, high_rot, dim, base, max_position_embeddings) {
  low <- floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
  high <- ceiling(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
  c(max(low, 0), min(high, dim - 1))
}

yarn_linear_ramp_mask <- function(min_val, max_val, dim, dtype = torch_float32()) {
  if (min_val == max_val) min_val <- min_val - 0.001
  linear_func <- (torch_arange(0, dim - 1, dtype = dtype) - min_val) / (max_val - min_val)
  torch_clamp(linear_func, 0, 1)
}
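
# Worked example (not executed at package load): a ramp from index 2 to 6 over 8
# frequency dimensions yields 0, 0, 0, 0.25, 0.5, 0.75, 1, 1 -- dimensions at or
# below `min_val` get 0, dimensions at or above `max_val` get 1, with a linear
# blend in between.
#   yarn_linear_ramp_mask(2, 6, 8)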

ministral_rotate_half <- function(x) {
  c(x1, x2) %<-% torch_split(x, x$size(-1) / 2, -1)
  torch_cat(list(-x2, x1), dim = -1)
}

repeat_kv <- function(hidden_states, n_rep) {
  if (n_rep == 1) return(hidden_states)
  c(batch, num_kv_heads, seq_len, head_dim) %<-% hidden_states$shape
  hidden_states$unsqueeze(3)$
    expand(c(batch, num_kv_heads, n_rep, seq_len, head_dim))$
    reshape(c(batch, num_kv_heads * n_rep, seq_len, head_dim))
}
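
# Illustrative shape check for the GQA helper (not executed at package load):
# with 4 query heads per key/value head, a (1, 2, 5, 8) key tensor expands to
# (1, 8, 5, 8) so it lines up head-for-head with the queries.
#   k <- torch_randn(1, 2, 5, 8)
#   repeat_kv(k, 4)$shape  # 1 8 5 8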

nn_ministral_rmsnorm <- nn_module(
  initialize = function(hidden_size, eps = 1e-5) {
    self$weight <- nn_parameter(torch_ones(hidden_size))
    self$eps <- eps
  },
  forward = function(x) {
    # RMSNorm: weight * x / sqrt(mean(x^2) + eps), computed in float32 for
    # numerical stability and cast back to the input dtype.
    dtype <- x$dtype
    variance <- x$to(dtype = torch_float())$pow(2)$mean(-1, keepdim = TRUE)
    x <- x * torch_rsqrt(variance + self$eps)
    (self$weight * x)$to(dtype = dtype)
  }
)

nn_ministral_yarn_rotary_embedding <- nn_module(
  initialize = function(head_dim, max_pos, base, factor, beta_fast, beta_slow,
                        original_max_pos, mscale, mscale_all_dim) {
    self$head_dim <- head_dim
    self$max_pos <- max_pos
    self$base <- base
    self$factor <- factor
    self$beta_fast <- beta_fast
    self$beta_slow <- beta_slow
    self$original_max_pos <- original_max_pos
    self$mscale <- mscale
    self$mscale_all_dim <- mscale_all_dim
    self$cached_embeddings()
  },
  .load_from_state_dict = function(...) {
    super$.load_from_state_dict(...)
    self$cached_embeddings(invalidate = TRUE)
  },
  get_mscale = function(scale, mscale) {
    # No attention scaling when the context window is not extended, matching the
    # referenced HF YaRN implementation.
    if (scale <= 1) return(1.0)
    0.1 * mscale * log(scale) + 1.0
  },
  cached_embeddings = function(t = 1, invalidate = FALSE) {
    invalidate <- invalidate || is.null(self$cos)
    if (invalidate) {
      dim <- self$head_dim
      pos_freqs <- self$base ^ (torch_arange(0, dim - 1, step = 2) / dim)
      inv_freq_extrapolation <- 1.0 / pos_freqs
      inv_freq_interpolation <- 1.0 / (self$factor * pos_freqs)

      # Blend extrapolated (high-frequency) and interpolated (low-frequency)
      # bands following the YaRN formulation in the referenced HF implementation:
      # the correction range goes from beta_fast to beta_slow, and the
      # extrapolation factor is 1 minus the linear ramp.
      c(low, high) %<-% yarn_find_correction_range(
        self$beta_fast, self$beta_slow, dim, self$base, self$original_max_pos
      )
      inv_freq_extrapolation_factor <- 1 - yarn_linear_ramp_mask(low, high, dim / 2)

      inv_freq <- inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) +
        inv_freq_extrapolation * inv_freq_extrapolation_factor
      self$inv_freq <- nn_buffer(inv_freq, persistent = FALSE)

      self$attention_scale <- self$get_mscale(self$factor, self$mscale) /
        self$get_mscale(self$factor, self$mscale_all_dim)

      freqs <- torch_arange(start = 0, end = self$max_pos - 1)$
        float()$outer(self$inv_freq)$view(c(1, 1, self$max_pos, dim / 2))
      emb <- torch_cat(list(freqs, freqs), dim = -1)
      self$cos <- nn_buffer(emb$cos(), persistent = FALSE)
      self$sin <- nn_buffer(emb$sin(), persistent = FALSE)
    }
    list(self$cos[,,1:t,], self$sin[,,1:t,], self$attention_scale)
  },
  forward = function(x) {
    c(b, nh, t, ed) %<-% x$shape
    c(cos, sin, attn_scale) %<-% self$cached_embeddings(t)
    (x * cos + ministral_rotate_half(x) * sin) * attn_scale
  }
)
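
# Shape note: for an input of shape (batch, n_head, seq_len, head_dim) the module
# returns a tensor of the same shape with YaRN-scaled rotary positions applied;
# the cached cos/sin buffers have shape (1, 1, max_pos, head_dim) and broadcast
# over batch and heads.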

nn_ministral_attention <- nn_module(
  initialize = function(n_embd, n_head, n_kv_head, head_dim, max_pos,
                        rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
                        rope_original_max_pos, rope_mscale, rope_mscale_all_dim) {
    self$n_head <- n_head
    self$n_kv_head <- n_kv_head
    self$head_dim <- head_dim
    self$n_kv_groups <- n_head %/% n_kv_head
    self$max_pos <- max_pos

    self$rotary <- nn_ministral_yarn_rotary_embedding(
      head_dim, max_pos, rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
      rope_original_max_pos, rope_mscale, rope_mscale_all_dim
    )

    self$q_proj <- nn_linear(n_embd, n_head * head_dim, bias = FALSE)
    self$k_proj <- nn_linear(n_embd, n_kv_head * head_dim, bias = FALSE)
    self$v_proj <- nn_linear(n_embd, n_kv_head * head_dim, bias = FALSE)
    self$o_proj <- nn_linear(n_head * head_dim, n_embd, bias = FALSE)
    self$cached_bias()
  },
  forward = function(x) {
    c(b, t, h) %<-% x$shape

    q <- self$q_proj(x)$view(c(b, t, self$n_head, self$head_dim))$transpose(2, 3)
    k <- self$k_proj(x)$view(c(b, t, self$n_kv_head, self$head_dim))$transpose(2, 3)
    v <- self$v_proj(x)$view(c(b, t, self$n_kv_head, self$head_dim))$transpose(2, 3)

    # Cast to float32 after applying RoPE so the score matmul, -Inf masking, and
    # softmax run in full precision; scores are cast back to v's dtype below.
    q <- self$rotary(q)$to(dtype = torch_float())
    k <- self$rotary(k)$to(dtype = torch_float())

    k <- repeat_kv(k, self$n_kv_groups)
    v <- repeat_kv(v, self$n_kv_groups)

    att <- torch_matmul(q, k$transpose(-2, -1)) * (1 / sqrt(self$head_dim))
    att <- att$masked_fill(self$bias[,,1:t, 1:t] == 0, self$masked_bias)
    att <- nnf_softmax(att, dim = -1)$to(dtype = v$dtype)

    y <- torch_matmul(att, v)$transpose(2, 3)$contiguous()$
      view(c(b, t, self$n_head * self$head_dim))
    self$o_proj(y)
  },
  .load_from_state_dict = function(...) {
    super$.load_from_state_dict(...)
    self$cached_bias()
  },
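  # Note (assumption about intended usage): cached_bias() materializes the full
  # (max_pos x max_pos) causal mask as a boolean buffer. At the default
  # max_pos = 262144 that buffer alone needs roughly 64 GiB, so this eager
  # attention is only practical with a much smaller max_pos.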
  cached_bias = function() {
    self$bias <- torch_ones(self$max_pos, self$max_pos)$bool()$tril()$
      view(c(1, 1, self$max_pos, self$max_pos)) |> nn_buffer(persistent = FALSE)
    self$masked_bias <- nn_buffer(torch_scalar_tensor(-Inf), persistent = FALSE)
  }
)

nn_ministral_mlp <- nn_module(
  initialize = function(n_embd, n_inter) {
    self$gate_proj <- nn_linear(n_embd, n_inter, bias = FALSE)
    self$down_proj <- nn_linear(n_inter, n_embd, bias = FALSE)
    self$up_proj <- nn_linear(n_embd, n_inter, bias = FALSE)
    self$act <- nn_silu()
  },
  forward = function(x) {
    self$down_proj(self$act(self$gate_proj(x)) * self$up_proj(x))
  }
)
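
# The MLP uses the gated SiLU ("SwiGLU"-style) formulation from the referenced
# Mistral implementation: down_proj(silu(gate_proj(x)) * up_proj(x)).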

nn_ministral_layer <- nn_module(
  initialize = function(n_embd, n_inter, n_head, n_kv_head, head_dim, max_pos,
                        rmsnorm_eps, rope_base, rope_factor, rope_beta_fast,
                        rope_beta_slow, rope_original_max_pos, rope_mscale,
                        rope_mscale_all_dim) {
    self$ln_1 <- nn_ministral_rmsnorm(n_embd, rmsnorm_eps)
    self$ln_2 <- nn_ministral_rmsnorm(n_embd, rmsnorm_eps)
    self$attn <- nn_ministral_attention(
      n_embd, n_head, n_kv_head, head_dim, max_pos, rope_base, rope_factor,
      rope_beta_fast, rope_beta_slow, rope_original_max_pos, rope_mscale,
      rope_mscale_all_dim
    )
    self$mlp <- nn_ministral_mlp(n_embd, n_inter)
  },
  forward = function(x) {
    x <- x + self$attn(self$ln_1(x))
    x + self$mlp(self$ln_2(x))
  }
)

nn_ministral_model <- nn_module(
  initialize = function(vocab_size, n_embd, n_inter, n_head, n_kv_head, head_dim,
                        n_layer, max_pos, rmsnorm_eps, rope_base, rope_factor,
                        rope_beta_fast, rope_beta_slow, rope_original_max_pos,
                        rope_mscale, rope_mscale_all_dim) {
    self$transformer <- nn_module_dict(list(
      wte = nn_embedding(vocab_size, n_embd),
      h = nn_sequential(!!!map(
        1:n_layer,
        \(x) nn_ministral_layer(
          n_embd, n_inter, n_head, n_kv_head, head_dim, max_pos, rmsnorm_eps,
          rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
          rope_original_max_pos, rope_mscale, rope_mscale_all_dim
        )
      )),
      ln_f = nn_ministral_rmsnorm(n_embd, rmsnorm_eps)
    ))
    self$lm_head <- nn_linear(n_embd, vocab_size, bias = FALSE)
  },
  forward = function(idx) {
    x <- self$transformer$wte(idx)
    x <- self$transformer$h(x)
    x <- self$transformer$ln_f(x)
    self$lm_head(x)
  }
)

#' ministral
#'
#' Initializes a Ministral-like model with YaRN rotary position embeddings (RoPE)
#' and grouped-query attention (GQA).
#'
#' @param vocab_size Vocabulary size.
#' @param n_embd Embedding dimension.
#' @param n_inter Intermediate (hidden) size of the MLP.
#' @param n_head Number of attention heads.
#' @param n_kv_head Number of key/value heads (for GQA).
#' @param head_dim Dimension of each attention head.
#' @param n_layer Number of transformer layers.
#' @param max_pos Maximum position embeddings.
#' @param rmsnorm_eps Epsilon for RMSNorm.
#' @param rope_base Base for rotary embeddings.
#' @param rope_factor YaRN scaling factor.
#' @param rope_beta_fast YaRN beta_fast parameter.
#' @param rope_beta_slow YaRN beta_slow parameter.
#' @param rope_original_max_pos Original max position embeddings for YaRN.
#' @param rope_mscale YaRN mscale parameter.
#' @param rope_mscale_all_dim YaRN mscale_all_dim parameter.
#' @param identifier HuggingFace model identifier.
#' @param revision HuggingFace model revision.
#' @returns An initialized [torch::nn_module()].
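#' @examples
#' \dontrun{
#' # A minimal sketch: a deliberately tiny, randomly initialized configuration
#' # (not the real Ministral shapes), just to exercise a forward pass.
#' library(torch)
#' model <- ministral(
#'   vocab_size = 128, n_embd = 64, n_inter = 128, n_head = 4,
#'   n_kv_head = 2, head_dim = 16, n_layer = 2, max_pos = 128
#' )
#' idx <- torch_randint(1, 129, size = c(1, 8), dtype = torch_long())
#' logits <- model(idx)  # (1, 8, 128)
#'
#' # Loading the pretrained weights is a large download:
#' # model <- ministral_from_pretrained("mistralai/Ministral-3-14B-Instruct-2512")
#' }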
#' @export
ministral <- function(vocab_size = 131072, n_embd = 5120, n_inter = 16384,
                      n_head = 32, n_kv_head = 8, head_dim = 128, n_layer = 40,
                      max_pos = 262144, rmsnorm_eps = 1e-5, rope_base = 1e9,
                      rope_factor = 16, rope_beta_fast = 32, rope_beta_slow = 1,
                      rope_original_max_pos = 16384, rope_mscale = 1,
                      rope_mscale_all_dim = 1) {
  nn_ministral_model(
    vocab_size, n_embd, n_inter, n_head, n_kv_head, head_dim, n_layer, max_pos,
    rmsnorm_eps, rope_base, rope_factor, rope_beta_fast, rope_beta_slow,
    rope_original_max_pos, rope_mscale, rope_mscale_all_dim
  )
}

#' @describeIn ministral Initializes from a HuggingFace config
#' @export
ministral_from_config <- function(identifier, revision = "main") {
  path <- hfhub::hub_download(identifier, "config.json", revision = revision)
  config <- jsonlite::fromJSON(path)

  # Handle multimodal configs (language model settings nested under text_config)
  if (!is.null(config$text_config)) {
    rope_params <- config$text_config$rope_parameters
    config <- config$text_config
  } else {
    rope_params <- config$rope_parameters %||% config
  }

  if (!config$model_type %in% c("ministral3", "mistral"))
    cli::cli_abort("Unsupported model_type: {.val {config$model_type}}")

  if (config$hidden_act != "silu")
    cli::cli_abort("Unsupported hidden_act: {.val {config$hidden_act}}")

  ministral(
    vocab_size = config$vocab_size,
    n_embd = config$hidden_size,
    n_inter = config$intermediate_size,
    n_head = config$num_attention_heads,
    n_kv_head = config$num_key_value_heads,
    head_dim = config$head_dim %||% (config$hidden_size %/% config$num_attention_heads),
    n_layer = config$num_hidden_layers,
    max_pos = config$max_position_embeddings,
    rmsnorm_eps = config$rms_norm_eps %||% 1e-5,
    rope_base = rope_params$rope_theta %||% 1e9,
    rope_factor = rope_params$factor %||% 16,
    rope_beta_fast = rope_params$beta_fast %||% 32,
    rope_beta_slow = rope_params$beta_slow %||% 1,
    rope_original_max_pos = rope_params$original_max_position_embeddings %||% 16384,
    rope_mscale = rope_params$mscale %||% 1,
    rope_mscale_all_dim = rope_params$mscale_all_dim %||% 1
  )
}

#' @describeIn ministral Initializes and loads pretrained weights from HF Hub
#' @export
ministral_from_pretrained <- function(identifier, revision = "main") {
  with_device(device = "meta", {
    model <- ministral_from_config(identifier, revision)
  })
  state_dict <- hf_state_dict(identifier, revision)
  state_dict <- ministral_hf_weights_remap(state_dict)
  model$load_state_dict(state_dict, .refer_to_state_dict = TRUE)
  model
}

ministral_hf_weights_remap <- function(state_dict) {
  nms <- names(state_dict)

  # Handle multimodal checkpoints (weights prefixed with language_model.)
  nms <- gsub("^language_model\\.", "", nms)

  # Standard remapping
  nms <- gsub("model.embed_tokens.weight", "transformer.wte.weight", nms, fixed = TRUE)
  nms <- gsub("model.layers", "transformer.h", nms, fixed = TRUE)
  nms <- gsub("self_attn", "attn", nms, fixed = TRUE)
  nms <- gsub("input_layernorm", "ln_1", nms, fixed = TRUE)
  nms <- gsub("post_attention_layernorm", "ln_2", nms, fixed = TRUE)
  nms <- gsub("model.norm", "transformer.ln_f", nms, fixed = TRUE)

  names(state_dict) <- nms

  # Filter out non-language-model weights
  keep <- !grepl("vision_tower|multi_modal_projector|scale", names(state_dict))
  state_dict[keep]
}
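
# For reference, a typical attention weight name is remapped as:
#   "model.layers.0.self_attn.q_proj.weight" -> "transformer.h.0.attn.q_proj.weight"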