fix: @ModuleInfo for pooler + attention mask dtype in Bert/NomicBert #153
davidkoski merged 2 commits into ml-explore:main
Conversation
This may be overthinking the problem. The error is correct, but perhaps being read more generally than intended. If we look at BertModel:

```swift
public class BertModel: Module, EmbeddingModel {
    let pooler: Linear?
```

The problem is primarily the `let` -- if this had been `var pooler: Linear?` it would have worked fine. Changes like this one:

```swift
private class TransformerBlock: Module {
    /// The feed-forward network (SwiGLU).
    let mlp: MLP // -> @ModuleInfo var mlp: MLP
```

are not required -- there is no replacement of these modules during quantization. Now these changes are not wrong, but they are rather large and mostly unnecessary. I think all that is needed is the change to the properties that actually get replaced.
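The distinction being drawn here can be sketched roughly as follows (a hypothetical `Example` class, simplified from the discussion; not the actual MLXEmbedders source):

```swift
import MLX
import MLXNN

// Hypothetical illustration of which properties need @ModuleInfo.
class Example: Module {
    // Replaced wholesale during quantization (Linear -> QuantizedLinear
    // via Module.update(modules:)), so it must be mutable and visible
    // to the module-update machinery through @ModuleInfo:
    @ModuleInfo var pooler: Linear?

    // Never swapped out as a unit -- only Linear layers *inside* such
    // containers are -- so a plain `let` is fine here:
    let norm: LayerNorm

    init() {
        // Property-wrapped members are assigned via their backing storage.
        self._pooler.wrappedValue = Linear(768, 768)
        self.norm = LayerNorm(dimensions: 768)
    }
}
```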
Thanks for the detailed explanation. You're right that the core issue is the `let` preventing mutation, and `@ModuleInfo` is only strictly needed for properties that get replaced at runtime (like `Linear` → `QuantizedLinear` during quantization). The `norm`, `mlp`, etc. properties don't get replaced wholesale, so they don't need it. I'll trim this PR down to just the `Linear` (and `Linear?`) properties that can be replaced by `QuantizedLinear`.
Should I use plain `var` or `@ModuleInfo var` for these?
I think either is fine, but if you are going to touch the lines you might as well add `@ModuleInfo`.
fix: add @ModuleInfo to BertModel.pooler and NomicBertModel.pooler

These `Linear?` properties are replaced by `QuantizedLinear` during quantization via `Module.update(modules:)`. Without `@ModuleInfo`, the non-throwing `update(modules:)` wrapper hits `try!` on the thrown `needModuleInfo` error, causing SIGABRT when loading quantized embedding models.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
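The before/after for the pooler declaration, per this commit (a sketch; surrounding code and exact dimensions elided):

```swift
// Before: `let` with no @ModuleInfo -- Module.update(modules:) cannot
// replace this with a QuantizedLinear, and the non-throwing wrapper's
// `try!` turns the needModuleInfo error into a SIGABRT.
// let pooler: Linear?

// After: mutable and registered with the module-update machinery.
@ModuleInfo var pooler: Linear?

// In init, the wrapped property is assigned through its backing storage
// (dimensions here are illustrative):
// self._pooler.wrappedValue = Linear(hiddenSize, hiddenSize)
```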
fix: cast attention mask to embeddings output dtype in Bert and NomicBert

The attention mask was cast to `embedder.wordEmbeddings.weight.dtype`, which can be float32 since Embedding layers are not quantized. When Linear layers (Q/K/V projections) are quantized to float16, `MLXFast.scaledDotProductAttention` requires the mask to promote to the output type (float16). A float32 mask cannot promote *down* to float16, causing a fatal error:

    [scaled_dot_product_attention] Mask type must promote to output type float16.

Fix: compute embeddings first, then cast the mask to `embeddings.dtype`, which reflects the actual compute dtype flowing into the encoder and is consistent with the attention output type.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
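A sketch of the reordering this commit describes (identifiers such as `createAttentionMask`, `tokens`, and `encoder` are illustrative, not the exact source):

```swift
// Before: mask cast to the embedding *storage* dtype, which stays
// float32 because Embedding layers are not quantized.
// let mask = createAttentionMask(tokens)
//     .asType(embedder.wordEmbeddings.weight.dtype)

// After: compute the embeddings first, then cast the mask to the
// actual compute dtype flowing into the encoder (float16 when the
// Q/K/V Linear projections are quantized).
let embeddings = embedder(tokens)
let mask = createAttentionMask(tokens).asType(embeddings.dtype)
let output = encoder(embeddings, mask: mask)
```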
davidkoski left a comment:
Looks good, thank you!
fix: @ModuleInfo for pooler + attention mask dtype in Bert/NomicBert (ml-explore#153)

* fix: add @ModuleInfo to BertModel.pooler and NomicBertModel.pooler

These Linear? properties are replaced by QuantizedLinear during quantization via Module.update(modules:). Without @ModuleInfo, the non-throwing update(modules:) wrapper hits try! on the thrown needModuleInfo error, causing SIGABRT when loading quantized embedding models.

* fix: cast attention mask to embeddings output dtype in Bert and NomicBert

The attention mask was cast to `embedder.wordEmbeddings.weight.dtype`, which can be float32 since Embedding layers are not quantized. When Linear layers (Q/K/V projections) are quantized to float16, `MLXFast.scaledDotProductAttention` requires the mask to promote to the output type (float16). A float32 mask cannot promote *down* to float16, causing a fatal error: [scaled_dot_product_attention] Mask type must promote to output type float16. Fix: compute embeddings first, then cast the mask to `embeddings.dtype`, which reflects the actual compute dtype flowing into the encoder and is consistent with the attention output type.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Summary

Two fixes for `BertModel` and `NomicBertModel` in MLXEmbedders:

1. Add `@ModuleInfo` to `pooler` property

`BertModel.pooler` and `NomicBertModel.pooler` were declared as `let` without `@ModuleInfo`. During quantized model loading, `quantize()` → `model.update(modules:)` → `try!` on `needModuleInfo` error → SIGABRT.

Fix: Changed `let pooler: Linear?` → `@ModuleInfo var pooler: Linear?` and updated the init to use `_pooler.wrappedValue`.

2. Cast attention mask to embeddings output dtype

The attention mask was cast to `embedder.wordEmbeddings.weight.dtype`, which remains `float32` since `Embedding` layers are not quantized. When `Linear` layers (Q/K/V projections) are quantized to `float16`, `MLXFast.scaledDotProductAttention` requires the mask dtype to promote to the output type (`float16`). A `float32` mask cannot promote down to `float16`.

Fix: Compute embeddings first, then cast the mask to `embeddings.dtype` instead of the embedding weight's storage dtype.

Files changed

- `Libraries/MLXEmbedders/Models/Bert.swift`
- `Libraries/MLXEmbedders/Models/NomicBert.swift`