 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.device import is_dtype_supported
-from tinygrad.helpers import prod, make_tuple, flatten
+from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
 from tinygrad.nn import optim, state, datasets  # noqa: F401

 class BatchNorm:
@@ -304,6 +304,46 @@ def __call__(self, x:Tensor) -> Tensor:
     x = self._norm(x.float()).cast(x.dtype)
     return x if self.weight is None else x * self.weight

+from tinygrad.uop.ops import UOp, KernelInfo, Ops
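+# backward for the USE_ATOMICS embedding path: scatter-add each row of grad_emb into grad_weight
+# at its token id; the integer index input gets no gradient, so the second return value is None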
+def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
+  weight, idx = call.src[1:]
+  # for multi-device: unshard inputs to one device
+  if isinstance(weight.device, tuple):
+    assert weight.axis is None, "sharded weights on Embedding not supported with USE_ATOMICS"
+    grad_emb = grad_emb.copy_to_device(weight.device)
+    idx = idx.copy_to_device(weight.device)
+  # weight is replicated, grad_weight should match
+  grad_weight_uop = Tensor.empty(weight.shape, dtype=weight.dtype, device=weight.device).uop
+
+  # TODO: how do we remove this dumb kernel and use Tensor.zeros?
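+  # zero-initialize grad_weight in its own kernel so the atomic adds below accumulate into a clean buffer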
+  def _zero_kernel(out:UOp) -> UOp:
+    i = UOp.range(out.size, 0)
+    return out.flatten()[i].store(0).end(i).sink(arg=KernelInfo(name="zero"))
+  grad_weight_uop = grad_weight_uop.custom_kernel(fxn=_zero_kernel)[0]
+
+  # TODO: do we have a universal helper for this?
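+  # strip any ":<n>" suffix to get the backend name ("CPU", "AMD", ...) that picks the atomic intrinsic below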
+  device = call.device.split(":")[0] if not isinstance(call.device, tuple) else call.device[0].split(":")[0]
+
+  # this is the real atomic kernel
+  def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
+    idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
+    i = UOp.range(grad_emb_flat.shape[0], 0)  # batch_size * sequence_length
+    j = UOp.range(grad_emb_flat.shape[1], 1)  # embed_size
+    token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
+    # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
+    if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
+    elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
+    else: raise NotImplementedError(f"no atomics for device {device}")
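+    # the CUSTOM op renders its srcs into the format string: {0} is a pointer to grad_weight[token_id, j], {1} is the value to add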
+    atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j]), arg=atomic_arg)
+    return atomic.end(i, j).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
+  grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
+
+  return (grad_weight_uop, None)
+
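+# forward lookup shared by both paths: build a one-hot mask (arange == idx), broadcast it against
+# weight, and reduce over the vocab dimension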
+def _embedding_fwd(weight:Tensor, idx:Tensor) -> Tensor:
+  arange = Tensor.arange(weight.shape[0], requires_grad=False, device=weight.device)
+  return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(weight, 0).sum(-2, dtype=weight.dtype)
+
 class Embedding:
   """
   A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -316,12 +356,12 @@ class Embedding:
   ```
   """
   def __init__(self, vocab_size:int, embed_size:int):
-    self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
+    self.weight = Tensor.glorot_uniform(vocab_size, embed_size)

   def __call__(self, idx:Tensor) -> Tensor:
     if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
-    arange = Tensor.arange(self.weight.shape[0], requires_grad=False, device=self.weight.device)
-    return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(self.weight, 0).sum(-2, dtype=self.weight.dtype)
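+    # with USE_ATOMICS, route through Tensor.call so the backward uses the atomic scatter-add kernel above;
+    # otherwise fall back to the one-hot _embedding_fwd, whose gradient is derived by autograd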
+    if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
+    return _embedding_fwd(self.weight, idx)

 class LSTMCell:
   """