File tree: 1 file changed, +9 −2 lines changed
@@ -38,6 +38,8 @@ def __init__(
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
         _weight: Optional[Tensor] = None,
+        device=None,
+        dtype=None,
     ) -> None:
         super(StableEmbedding, self).__init__(
             num_embeddings,
@@ -48,8 +50,10 @@ def __init__(
             scale_grad_by_freq,
             sparse,
             _weight,
+            device,
+            dtype,
         )
-        self.norm = torch.nn.LayerNorm(embedding_dim)
+        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
         GlobalOptimManager.get_instance().register_module_override(
             self, "weight", {"optim_bits": 32}
         )
@@ -81,7 +85,10 @@ def forward(self, input: Tensor) -> Tensor:
             self.sparse,
         )

-        return self.norm(emb)
+        # always apply layer norm in full precision
+        emb = emb.to(torch.get_default_dtype())
+
+        return self.norm(emb).to(self.weight.dtype)


 class Embedding(torch.nn.Embedding):
You can’t perform that action at this time.
0 commit comments