From a00af54371d254a0802b12061bcc9f47e67236ad Mon Sep 17 00:00:00 2001 From: Alan Kessler Date: Fri, 27 Mar 2026 15:03:58 -0700 Subject: [PATCH] Guard embedder inputs against maxPositionEmbeddings BERT models crash when input exceeds maxPositionEmbeddings because the position embedding table is fixed-size. Truncate with a warning rather than crashing. Expose maxPositionEmbeddings on the EmbeddingModel protocol (default nil, non-breaking) so callers can check the limit and pre-truncate or chunk as needed. Fixes #62. --- Libraries/MLXEmbedders/EmbeddingModel.swift | 10 ++++++++ Libraries/MLXEmbedders/Models/Bert.swift | 20 +++++++++++++++- Libraries/MLXEmbedders/Models/NomicBert.swift | 24 +++++++++++++++++-- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/Libraries/MLXEmbedders/EmbeddingModel.swift b/Libraries/MLXEmbedders/EmbeddingModel.swift index d55b2c48..9062eaef 100644 --- a/Libraries/MLXEmbedders/EmbeddingModel.swift +++ b/Libraries/MLXEmbedders/EmbeddingModel.swift @@ -101,6 +101,14 @@ public struct EmbeddingModelOutput { public protocol EmbeddingModel: Module { var vocabularySize: Int { get } + /// The maximum number of position embeddings supported by this model, or `nil` + /// if the model uses a position encoding (e.g. RoPE) that handles arbitrary lengths. + /// + /// Inputs exceeding this length are truncated internally with a warning. + /// Callers may pre-truncate for efficiency or to implement custom strategies + /// (e.g. chunking with pooling). + var maxPositionEmbeddings: Int? { get } + func callAsFunction( _ inputs: MLXArray, positionIds: MLXArray?, @@ -113,6 +121,8 @@ public protocol EmbeddingModel: Module { } extension EmbeddingModel { + public var maxPositionEmbeddings: Int? { nil } + func callAsFunction( _ inputs: MLXArray, positionIds: MLXArray? = nil, diff --git a/Libraries/MLXEmbedders/Models/Bert.swift b/Libraries/MLXEmbedders/Models/Bert.swift index 2af0c758..ad3bae91 100644 --- a/Libraries/MLXEmbedders/Models/Bert.swift +++ b/Libraries/MLXEmbedders/Models/Bert.swift @@ -261,6 +261,11 @@ public class BertModel: Module, EmbeddingModel { /// The total count of tokens in the model's vocabulary. public var vocabularySize: Int + public var maxPositionEmbeddings: Int? { + _maxPositionEmbeddings > 0 ? _maxPositionEmbeddings : nil + } + private let _maxPositionEmbeddings: Int + /// Initializes a BERT model. /// - Parameters: /// - config: The architecture settings (layers, heads, dimensions). @@ -269,6 +274,7 @@ public class BertModel: Module, EmbeddingModel { public init(_ config: BertConfiguration, lmHead: Bool = false) { precondition(config.vocabularySize > 0) vocabularySize = config.vocabularySize + _maxPositionEmbeddings = config.maxPositionEmbeddings encoder = Encoder(config) _embedder.wrappedValue = BertEmbedding(config) @@ -300,8 +306,20 @@ public class BertModel: Module, EmbeddingModel { if inp.ndim == 1 { inp = inp.reshaped(1, -1) } - let embeddings = embedder(inp, positionIds: positionIds, tokenTypeIds: tokenTypeIds) var mask = attentionMask + var typeIds = tokenTypeIds + var posIds = positionIds + if _maxPositionEmbeddings > 0, inp.dim(1) > _maxPositionEmbeddings { + print( + "Warning: Input length \(inp.dim(1)) exceeds maxPositionEmbeddings" + + " (\(_maxPositionEmbeddings)), truncating." + ) + inp = inp[0..., ..<_maxPositionEmbeddings] + mask = mask?[0..., ..<_maxPositionEmbeddings] + typeIds = typeIds?[0..., ..<_maxPositionEmbeddings] + posIds = posIds?[0..., ..<_maxPositionEmbeddings] + } + let embeddings = embedder(inp, positionIds: posIds, tokenTypeIds: typeIds) if mask != nil { // Cast mask to the same dtype as the embeddings output so it is // compatible with scaled_dot_product_attention's type promotion diff --git a/Libraries/MLXEmbedders/Models/NomicBert.swift b/Libraries/MLXEmbedders/Models/NomicBert.swift index 417145b6..3a716790 100644 --- a/Libraries/MLXEmbedders/Models/NomicBert.swift +++ b/Libraries/MLXEmbedders/Models/NomicBert.swift @@ -679,6 +679,11 @@ public class NomicBertModel: Module, EmbeddingModel { /// The size of the vocabulary. public var vocabularySize: Int + public var maxPositionEmbeddings: Int? { + _maxPositionEmbeddings > 0 ? _maxPositionEmbeddings : nil + } + private let _maxPositionEmbeddings: Int + /// Initializes the Nomic BERT model. /// /// - Parameters: @@ -691,6 +696,7 @@ public class NomicBertModel: Module, EmbeddingModel { ) { precondition(config.vocabularySize > 0) vocabularySize = config.vocabularySize + _maxPositionEmbeddings = config.maxPositionEmbeddings encoder = Encoder(config) _embedder.wrappedValue = NomicEmbedding(config) @@ -734,14 +740,28 @@ public class NomicBertModel: Module, EmbeddingModel { inp = inp.reshaped(1, -1) } + // Truncate to max position embeddings when using absolute position embeddings + var mask = attentionMask + var typeIds = tokenTypeIds + var posIds = positionIds + if _maxPositionEmbeddings > 0, inp.dim(1) > _maxPositionEmbeddings { + print( + "Warning: Input length \(inp.dim(1)) exceeds maxPositionEmbeddings" + + " (\(_maxPositionEmbeddings)), truncating." + ) + inp = inp[0..., ..<_maxPositionEmbeddings] + mask = mask?[0..., ..<_maxPositionEmbeddings] + typeIds = typeIds?[0..., ..<_maxPositionEmbeddings] + posIds = posIds?[0..., ..<_maxPositionEmbeddings] + } + // 2. Process Attention Mask // Input: Binary mask (1 = valid, 0 = mask). // Operation: .log(). // log(1) = 0 (Add 0 to attention score -> No change) // log(0) = -inf (Add -inf to attention score -> Zero probability after Softmax) let embeddings = embedder( - inp, positionIds: positionIds, tokenTypeIds: tokenTypeIds) - var mask = attentionMask + inp, positionIds: posIds, tokenTypeIds: typeIds) if mask != nil { // Cast mask to the same dtype as the embeddings output so it is // compatible with scaled_dot_product_attention's type promotion