Skip to content

Commit 62d39a2

Browse files
committed
add device and dtype parameters to StableEmbedding
1 parent 1efb87d commit 62d39a2

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

bitsandbytes/nn/modules.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@ def __init__(
3838
scale_grad_by_freq: bool = False,
3939
sparse: bool = False,
4040
_weight: Optional[Tensor] = None,
41+
device=None,
42+
dtype=None,
4143
) -> None:
4244
super(StableEmbedding, self).__init__(
4345
num_embeddings,
@@ -48,8 +50,10 @@ def __init__(
4850
scale_grad_by_freq,
4951
sparse,
5052
_weight,
53+
device,
54+
dtype,
5155
)
52-
self.norm = torch.nn.LayerNorm(embedding_dim)
56+
self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
5357
GlobalOptimManager.get_instance().register_module_override(
5458
self, "weight", {"optim_bits": 32}
5559
)
@@ -81,7 +85,10 @@ def forward(self, input: Tensor) -> Tensor:
8185
self.sparse,
8286
)
8387

84-
return self.norm(emb)
88+
# always apply layer norm in full precision
89+
emb = emb.to(torch.get_default_dtype())
90+
91+
return self.norm(emb).to(self.weight.dtype)
8592

8693

8794
class Embedding(torch.nn.Embedding):

0 commit comments

Comments (0)