fix: @ModuleInfo for pooler + attention mask dtype in Bert/NomicBert#153

Merged
davidkoski merged 2 commits into ml-explore:main from jowharshamshiri:main
Mar 19, 2026
Conversation

@jowharshamshiri
Contributor

@jowharshamshiri jowharshamshiri commented Mar 18, 2026

Summary

Two fixes for BertModel and NomicBertModel in MLXEmbedders:

1. Add @ModuleInfo to pooler property

`BertModel.pooler` and `NomicBertModel.pooler` were declared as `let` without `@ModuleInfo`. During quantized model loading, `quantize()` calls `model.update(modules:)`; the non-throwing wrapper uses `try!` on the thrown `needModuleInfo` error, so loading aborts with SIGABRT:

MLXNN/Module.swift:570: Fatal error: 'try!' expression unexpectedly raised an error:
MLXNN.UpdateError.needModuleInfo("Unable to get @ModuleInfo for BertModel.pooler
-- must be wrapped to receive updates")

Fix: Changed `let pooler: Linear?` to `@ModuleInfo var pooler: Linear?` and updated `init` to assign through `_pooler.wrappedValue`.
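
As a sketch, the change looks roughly like this (the class below is a simplified stand-in, not the exact code in Bert.swift; the initializer parameters are illustrative):

```swift
import MLXNN

class PoolerExample: Module {
    // Before: `let pooler: Linear?` -- a `let` can never be replaced
    // with a QuantizedLinear via Module.update(modules:).
    @ModuleInfo var pooler: Linear?

    init(hiddenSize: Int, usePooler: Bool) {
        // A wrapped property is assigned through its underlying storage.
        self._pooler.wrappedValue =
            usePooler ? Linear(hiddenSize, hiddenSize) : nil
        super.init()
    }
}
```

With the wrapper in place, quantization can swap the `Linear` for a `QuantizedLinear` at runtime instead of crashing.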

2. Cast attention mask to embeddings output dtype

The attention mask was cast to embedder.wordEmbeddings.weight.dtype, which remains float32 since Embedding layers are not quantized. When Linear layers (Q/K/V projections) are quantized to float16, MLXFast.scaledDotProductAttention requires the mask dtype to promote to the output type (float16). A float32 mask cannot promote down to float16:

[scaled_dot_product_attention] Mask type must promote to output type float16.

Fix: Compute embeddings first, then cast the mask to embeddings.dtype instead of the embedding weight's storage dtype.
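
A sketch of the reordering, with assumed call-site names (`embedder`, `encoder`, `attentionMask` follow the PR description; the exact signatures live in Bert.swift):

```swift
// Before: dtype came from the unquantized embedding table, always float32.
// let mask = attentionMask.asType(embedder.wordEmbeddings.weight.dtype)

// After: run the embedder first, then match the mask to the tensor that
// actually flows into the (possibly float16) quantized attention layers.
let embeddings = embedder(inputIds)
let mask = attentionMask.asType(embeddings.dtype)
let output = encoder(embeddings, mask: mask)
```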

Files changed

  • Libraries/MLXEmbedders/Models/Bert.swift
  • Libraries/MLXEmbedders/Models/NomicBert.swift

@davidkoski
Collaborator

This may be overthinking the problem. The error is correct, but perhaps being read more generally than intended:

MLXNN.UpdateError.needModuleInfo("Unable to get @ModuleInfo for BertModel.pooler -- must be wrapped to receive updates")

if we look at BertModel:

public class BertModel: Module, EmbeddingModel {
    let pooler: Linear?

The problem is primarily that pooler is a let -- you can't update it once set. I suspect that if that were:

    var pooler: Linear?

it would have worked fine. @ModuleInfo requires var and just wraps it in a holder -- this would certainly be enough for it to work, but a plain var should be as well.

Changes like this one in Qwen3.swift:

private class TransformerBlock: Module {
    /// The feed-forward network (SwiGLU).
    let mlp: MLP // -> @ModuleInfo var mlp: MLP

are not required -- there is no QuantizedMLP that it is going to be replaced with and the MLP has var properties inside it.

Now these changes are not wrong but they are rather large and mostly unnecessary. I think all let x: Linear should be changed to either var x: Linear or @ModuleInfo var x: Linear. The other references to modules don't need it (typically only things that you allow to be replaced at runtime and those are the ones that support quantization).
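
The distinction can be sketched like this (API names as used in the error message above; `quantize` call shape is an assumption about MLXNN and may differ):

```swift
import MLXNN

class Block: Module {
    // Replaceable at runtime (Linear -> QuantizedLinear), so it needs
    // @ModuleInfo (or at minimum a plain `var`).
    @ModuleInfo var proj: Linear

    // Never swapped wholesale; its own inner properties are vars, so no
    // wrapper is needed here.
    let norm: LayerNorm

    init(dims: Int) {
        self._proj.wrappedValue = Linear(dims, dims)
        self.norm = LayerNorm(dimensions: dims)
        super.init()
    }
}

// quantize walks the module tree and uses update(modules:) to replace
// each Linear with a QuantizedLinear; a `let` Linear makes that throw
// needModuleInfo.
let model = Block(dims: 4)
quantize(model: model)
```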

@jowharshamshiri
Contributor Author

Thanks for the detailed explanation. You're right that the core issue is the let preventing mutation, and @ModuleInfo is only strictly needed for properties that get replaced at runtime (like Linear → QuantizedLinear during quantization). The norm, mlp, etc. properties don't get replaced wholesale, so they don't need it.

I'll trim this PR down to just the Linear (and Linear?) properties that can be replaced by QuantizedLinear:

  • BertModel.pooler: Linear?
  • NomicBertModel.pooler: Linear?

Should I use plain var or @ModuleInfo var for these?
Happy to match whichever convention you prefer.

@davidkoski
Collaborator

I think either is fine, but if you are going to touch the lines might as well add @ModuleInfo. Thanks!

These Linear? properties are replaced by QuantizedLinear during
quantization via Module.update(modules:). Without @ModuleInfo,
the non-throwing update(modules:) wrapper hits try! on the thrown
needModuleInfo error, causing SIGABRT when loading quantized
embedding models.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…Bert

The attention mask was cast to `embedder.wordEmbeddings.weight.dtype`,
which can be float32 since Embedding layers are not quantized. When
Linear layers (Q/K/V projections) are quantized to float16,
`MLXFast.scaledDotProductAttention` requires the mask to promote to the
output type (float16). A float32 mask cannot promote *down* to float16,
causing a fatal error:

  [scaled_dot_product_attention] Mask type must promote to output type
  float16.

Fix: compute embeddings first, then cast the mask to `embeddings.dtype`,
which reflects the actual compute dtype flowing into the encoder and is
consistent with the attention output type.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@jowharshamshiri jowharshamshiri changed the title fix: add missing @ModuleInfo to Module-typed properties fix: @ModuleInfo for pooler + attention mask dtype in Bert/NomicBert Mar 19, 2026
Collaborator

@davidkoski davidkoski left a comment

Looks good, thank you!

@davidkoski davidkoski merged commit 3a7503d into ml-explore:main Mar 19, 2026
2 checks passed
viktike pushed a commit to viktike/mlx-swift-lm that referenced this pull request Mar 23, 2026
…l-explore#153)

* fix: add @ModuleInfo to BertModel.pooler and NomicBertModel.pooler

These Linear? properties are replaced by QuantizedLinear during
quantization via Module.update(modules:). Without @ModuleInfo,
the non-throwing update(modules:) wrapper hits try! on the thrown
needModuleInfo error, causing SIGABRT when loading quantized
embedding models.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: cast attention mask to embeddings output dtype in Bert and NomicBert

The attention mask was cast to `embedder.wordEmbeddings.weight.dtype`,
which can be float32 since Embedding layers are not quantized. When
Linear layers (Q/K/V projections) are quantized to float16,
`MLXFast.scaledDotProductAttention` requires the mask to promote to the
output type (float16). A float32 mask cannot promote *down* to float16,
causing a fatal error:

  [scaled_dot_product_attention] Mask type must promote to output type
  float16.

Fix: compute embeddings first, then cast the mask to `embeddings.dtype`,
which reflects the actual compute dtype flowing into the encoder and is
consistent with the attention output type.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>