Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Libraries/MLXEmbedders/EmbeddingModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ public struct EmbeddingModelOutput {
public protocol EmbeddingModel: Module {
var vocabularySize: Int { get }

/// The maximum number of position embeddings supported by this model, or `nil`
/// if the model uses a position encoding (e.g. RoPE) that handles arbitrary lengths.
///
/// Inputs exceeding this length are truncated internally with a warning.
/// Callers may pre-truncate for efficiency or to implement custom strategies
/// (e.g. chunking with pooling).
var maxPositionEmbeddings: Int? { get }

func callAsFunction(
_ inputs: MLXArray,
positionIds: MLXArray?,
Expand All @@ -113,6 +121,8 @@ public protocol EmbeddingModel: Module {
}

extension EmbeddingModel {
public var maxPositionEmbeddings: Int? { nil }

func callAsFunction(
_ inputs: MLXArray,
positionIds: MLXArray? = nil,
Expand Down
20 changes: 19 additions & 1 deletion Libraries/MLXEmbedders/Models/Bert.swift
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ public class BertModel: Module, EmbeddingModel {
/// The total count of tokens in the model's vocabulary.
public var vocabularySize: Int

public var maxPositionEmbeddings: Int? {
_maxPositionEmbeddings > 0 ? _maxPositionEmbeddings : nil
}
private let _maxPositionEmbeddings: Int

/// Initializes a BERT model.
/// - Parameters:
/// - config: The architecture settings (layers, heads, dimensions).
Expand All @@ -269,6 +274,7 @@ public class BertModel: Module, EmbeddingModel {
public init(_ config: BertConfiguration, lmHead: Bool = false) {
precondition(config.vocabularySize > 0)
vocabularySize = config.vocabularySize
_maxPositionEmbeddings = config.maxPositionEmbeddings
encoder = Encoder(config)
_embedder.wrappedValue = BertEmbedding(config)

Expand Down Expand Up @@ -300,8 +306,20 @@ public class BertModel: Module, EmbeddingModel {
if inp.ndim == 1 {
inp = inp.reshaped(1, -1)
}
let embeddings = embedder(inp, positionIds: positionIds, tokenTypeIds: tokenTypeIds)
var mask = attentionMask
var typeIds = tokenTypeIds
var posIds = positionIds
if _maxPositionEmbeddings > 0, inp.dim(1) > _maxPositionEmbeddings {
print(
"Warning: Input length \(inp.dim(1)) exceeds maxPositionEmbeddings"
+ " (\(_maxPositionEmbeddings)), truncating."
)
inp = inp[0..., ..<_maxPositionEmbeddings]
mask = mask?[0..., ..<_maxPositionEmbeddings]
typeIds = typeIds?[0..., ..<_maxPositionEmbeddings]
posIds = posIds?[0..., ..<_maxPositionEmbeddings]
}
let embeddings = embedder(inp, positionIds: posIds, tokenTypeIds: typeIds)
if mask != nil {
// Cast mask to the same dtype as the embeddings output so it is
// compatible with scaled_dot_product_attention's type promotion
Expand Down
24 changes: 22 additions & 2 deletions Libraries/MLXEmbedders/Models/NomicBert.swift
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,11 @@ public class NomicBertModel: Module, EmbeddingModel {
/// The size of the vocabulary.
public var vocabularySize: Int

public var maxPositionEmbeddings: Int? {
_maxPositionEmbeddings > 0 ? _maxPositionEmbeddings : nil
}
private let _maxPositionEmbeddings: Int

/// Initializes the Nomic BERT model.
///
/// - Parameters:
Expand All @@ -691,6 +696,7 @@ public class NomicBertModel: Module, EmbeddingModel {
) {
precondition(config.vocabularySize > 0)
vocabularySize = config.vocabularySize
_maxPositionEmbeddings = config.maxPositionEmbeddings
encoder = Encoder(config)
_embedder.wrappedValue = NomicEmbedding(config)

Expand Down Expand Up @@ -734,14 +740,28 @@ public class NomicBertModel: Module, EmbeddingModel {
inp = inp.reshaped(1, -1)
}

// Truncate to max position embeddings when using absolute position embeddings
var mask = attentionMask
var typeIds = tokenTypeIds
var posIds = positionIds
if _maxPositionEmbeddings > 0, inp.dim(1) > _maxPositionEmbeddings {
print(
"Warning: Input length \(inp.dim(1)) exceeds maxPositionEmbeddings"
+ " (\(_maxPositionEmbeddings)), truncating."
)
inp = inp[0..., ..<_maxPositionEmbeddings]
mask = mask?[0..., ..<_maxPositionEmbeddings]
typeIds = typeIds?[0..., ..<_maxPositionEmbeddings]
posIds = posIds?[0..., ..<_maxPositionEmbeddings]
}

// 2. Process Attention Mask
// Input: Binary mask (1 = valid, 0 = mask).
// Operation: .log().
// log(1) = 0 (Add 0 to attention score -> No change)
// log(0) = -inf (Add -inf to attention score -> Zero probability after Softmax)
let embeddings = embedder(
inp, positionIds: positionIds, tokenTypeIds: tokenTypeIds)
var mask = attentionMask
inp, positionIds: posIds, tokenTypeIds: typeIds)
if mask != nil {
// Cast mask to the same dtype as the embeddings output so it is
// compatible with scaled_dot_product_attention's type promotion
Expand Down