
Commit 838cd07

use atomics for embedding backward (tinygrad#14400)
* embedding is slow
* failing
* float is fine
* null
* it fails
* simplify embedding with broadcasting
* ATOMIC_ADD incoming
* min change
* simpler test
* better test
* fix test
* real test
* simpler
* cleanups
* types and names
* _zero_kernel
* grad multi
* hack
* none
* multi unshard
* more for call
* don't tag in call
* good
* call_multi
* call_multi wow claude is useless
* embedding backward multi test
* test passes
* fix as_param
* shape_to_shape_arg
* add clip
* before cast
* fix spec=2, use atomics

1 parent 1998e0b commit 838cd07

File tree

10 files changed: +134 -23 lines

examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
 export DEBUG=${DEBUG:-0}
 export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}
+export USE_ATOMICS=${USE_ATOMICS:-1}

 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 export DP=${DP:-8} BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2

test/test_arange.py

Lines changed: 25 additions & 0 deletions
@@ -163,5 +163,30 @@ def test_llama_embedding(self, noopt=1, op_limit=65536):
   # at least the arange is being fused
   def test_llama_embedding_opt(self): self.test_llama_embedding(0, 1_736_704_000)

+  # NOTE: call doesn't work with SPEC=2
+  @unittest.skipIf(Device.DEFAULT not in ("CPU", "AMD"), "atomics only on AMD/CPU")
+  @Context(USE_ATOMICS=1, SPEC=1)
+  def test_llama_8b_embedding_backward(self):
+    from tinygrad.renderer.cstyle import CStyleLanguage
+    if Device.DEFAULT == "CPU" and not isinstance(Device["CPU"].renderer, CStyleLanguage): self.skipTest("CPU needs Clang renderer")
+    vocab_size, embed_size = 1000, 128
+    bs, seqlen = 4, 256
+    idx = Tensor.randint(bs, seqlen, high=vocab_size)
+    emb = nn.Embedding(vocab_size, embed_size)
+    emb.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True)
+    gt = Tensor.zeros(bs, seqlen, embed_size)
+    Tensor.realize(idx, emb.weight, gt)
+    GlobalCounters.reset()
+    loss = (emb(idx)-gt).square().sum()
+    loss.backward()
+    emb.weight.grad.realize()
+    bwd_ops = GlobalCounters.global_ops
+    print(f"embedding bwd: {GlobalCounters.kernel_count} kernels, {bwd_ops:,} ops")
+    self.assertLess(bwd_ops, bs*seqlen*embed_size*20, f"backward ops {bwd_ops:,} should be less than 20 per element with atomic scatter-add")
+    # correctness check
+    expected_grad = np.zeros((vocab_size, embed_size), dtype=np.float32)
+    for i in idx.flatten().numpy(): expected_grad[i] += 2
+    np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5)
+
 if __name__ == "__main__":
   unittest.main()
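
A note on the correctness check in this test: the weights are all ones and the target is zero, so each occurrence of a token contributes d/dw (w - 0)^2 = 2w = 2 to that token's gradient row, which is exactly what the expected_grad loop accumulates. A minimal numpy sanity check of that expectation (toy index values, not taken from the test):

import numpy as np

# loss = sum((emb(idx) - 0)**2) with weight == 1 everywhere, so
# d loss / d weight[v, :] = 2 * (number of times token v appears in idx)
idx = np.array([0, 2, 2, 1])            # toy index values
counts = np.bincount(idx, minlength=4)  # [1, 1, 2, 0]
print(2.0 * counts)                     # [2. 2. 4. 0.]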

test/test_multitensor.py

Lines changed: 22 additions & 0 deletions
@@ -409,6 +409,28 @@ def test_embedding(self):

     np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)

+  def test_embedding_backward(self, shard_weight_axis=None):
+    B, T, embed_size, vocab_size = 4, 10, 20, 28
+
+    layer = nn.Embedding(vocab_size, embed_size)
+    layer.weight.requires_grad = True
+    x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
+    z = layer(x)
+    z.sum().backward()
+    grad = layer.weight.grad.numpy()
+
+    layer_sharded = nn.Embedding(vocab_size, embed_size)
+    layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=shard_weight_axis)).realize()
+    layer_sharded.weight.requires_grad = True
+    x_sharded = x.shard(devices_2, axis=None)
+    z_shard = layer_sharded(x_sharded)
+    z_shard.sum().backward()
+    grad_shard = layer_sharded.weight.grad.numpy()
+
+    np.testing.assert_allclose(grad, grad_shard, atol=1e-6, rtol=1e-6)
+
+  def test_embedding_backward_shard_weight(self): self.test_embedding_backward(shard_weight_axis=1)
+
   def test_rmsnorm(self):
     B, T, embed_size = 4, 10, 20
test/unit/test_call.py

Lines changed: 10 additions & 2 deletions
@@ -30,18 +30,26 @@ def grad_fxn(grad:UOp, call:UOp): return (grad, grad)

     # we define a plus function
     plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
-    c = Tensor.call(a, b, fxn=plus_fxn, arg=grad_fxn)
+    c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
     c.mean().backward()

     np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
     np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

-  @unittest.skip("needs GEMM on mixins")
   def test_call_gemm(self):
     M, K, N = 4, 8, 4
     a = Tensor.randn(M, K)
     b = Tensor.randn(K, N)
     Tensor.realize(a, b)
+    c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
+    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
+
+  @unittest.skip("needs GEMM on mixins")
+  def test_call_gemm_uop(self):
+    M, K, N = 4, 8, 4
+    a = Tensor.randn(M, K)
+    b = Tensor.randn(K, N)
+    Tensor.realize(a, b)

     # we define a gemm function
     x = UOp.param(0, dtypes.float, shape=(M, K))

tinygrad/helpers.py

Lines changed: 2 additions & 0 deletions
@@ -204,6 +204,8 @@ def tolist(self, obj=None):
 ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
 # set to 0 to disable the scheduler cache
 SCACHE = ContextVar("SCACHE", 1)
+# allow use of atomics for embedding backward
+USE_ATOMICS = ContextVar("USE_ATOMICS", 0)

 @dataclass(frozen=True)
 class Metadata:
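
USE_ATOMICS is an ordinary ContextVar, so it can be enabled per process via the environment (USE_ATOMICS=1) or scoped with the Context helper, as the new test in test_arange.py does. A minimal sketch, assuming a backend with atomics support (CPU with the Clang renderer, or AMD):

from tinygrad import Tensor, nn
from tinygrad.helpers import Context

# inside this block, Embedding backward takes the atomic scatter-add path
with Context(USE_ATOMICS=1):
  emb = nn.Embedding(100, 16)
  emb.weight.requires_grad = True
  out = emb(Tensor.randint(4, 8, high=100))
  out.sum().backward()
  print(emb.weight.grad.shape)  # (100, 16)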

tinygrad/nn/__init__.py

Lines changed: 44 additions & 4 deletions
@@ -3,7 +3,7 @@
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.device import is_dtype_supported
-from tinygrad.helpers import prod, make_tuple, flatten
+from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
 from tinygrad.nn import optim, state, datasets # noqa: F401

 class BatchNorm:
@@ -304,6 +304,46 @@ def __call__(self, x:Tensor) -> Tensor:
     x = self._norm(x.float()).cast(x.dtype)
     return x if self.weight is None else x * self.weight

+from tinygrad.uop.ops import UOp, KernelInfo, Ops
+def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
+  weight, idx = call.src[1:]
+  # for multi-device: unshard inputs to one device
+  if isinstance(weight.device, tuple):
+    assert weight.axis is None, "sharded weights on Embedding not supported with USE_ATOMICS"
+    grad_emb = grad_emb.copy_to_device(weight.device)
+    idx = idx.copy_to_device(weight.device)
+  # weight is replicated, grad_weight should match
+  grad_weight_uop = Tensor.empty(weight.shape, dtype=weight.dtype, device=weight.device).uop
+
+  # TODO: how do we remove this dumb kernel and use Tensor.zeros?
+  def _zero_kernel(out:UOp) -> UOp:
+    i = UOp.range(out.size, 0)
+    return out.flatten()[i].store(0).end(i).sink(arg=KernelInfo(name="zero"))
+  grad_weight_uop = grad_weight_uop.custom_kernel(fxn=_zero_kernel)[0]
+
+  # TODO: do we have a universal helper for this?
+  device = call.device.split(":")[0] if not isinstance(call.device, tuple) else call.device[0].split(":")[0]
+
+  # this is the real atomic kernel
+  def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
+    idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
+    i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length
+    j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size
+    token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
+    # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
+    if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
+    elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
+    else: raise NotImplementedError(f"no atomics for device {device}")
+    atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j]), arg=atomic_arg)
+    return atomic.end(i, j).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
+  grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
+
+  return (grad_weight_uop, None)
+
+def _embedding_fwd(weight:Tensor, idx:Tensor) -> Tensor:
+  arange = Tensor.arange(weight.shape[0], requires_grad=False, device=weight.device)
+  return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(weight, 0).sum(-2, dtype=weight.dtype)
+
 class Embedding:
   """
   A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -316,12 +356,12 @@ class Embedding:
   ```
   """
   def __init__(self, vocab_size:int, embed_size:int):
-    self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
+    self.weight = Tensor.glorot_uniform(vocab_size, embed_size)

   def __call__(self, idx:Tensor) -> Tensor:
     if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
-    arange = Tensor.arange(self.weight.shape[0], requires_grad=False, device=self.weight.device)
-    return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(self.weight, 0).sum(-2, dtype=self.weight.dtype)
+    if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
+    return _embedding_fwd(self.weight, idx)

 class LSTMCell:
   """

tinygrad/schedule/multi.py

Lines changed: 2 additions & 1 deletion
@@ -202,7 +202,7 @@ def assign_multi(dest:UOp, src:UOp):
   return dest.src[0].assign(src.src[0]).multi(src.axis)

 def passthrough_multi(root:UOp, multi:UOp):
-  return UOp(root.op, root.dtype, (multi.src[0],), root.arg).multi(multi.axis)
+  return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)

 # NOTE: this is the same pattern as Ops.UNROLL
 multi_pm = PatternMatcher([
@@ -218,6 +218,7 @@ def passthrough_multi(root:UOp, multi:UOp):
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
   (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
    lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
+  (UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
     src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
   # multi supports custom kernels with CUSTOM_KERNEL + AFTER

tinygrad/schedule/rangeify.py

Lines changed: 10 additions & 6 deletions
@@ -71,12 +71,13 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
 def resolve_call(c:UOp) -> UOp:
   params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
   args = c.src[1:]
+  # TODO: this check belongs in spec, not here
   if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
   if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
   for i, (p, a) in enumerate(zip(params, args)):
     if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
     if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
-  return c.src[0].substitute(dict(zip(params, args)))
+  return c.src[0].substitute(dict(zip(params, args))).rtag(c.tag)

 earliest_rewrites = mop_cleanup+PatternMatcher([
   # just removing it works...
@@ -533,11 +534,14 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   (UPat((Ops.STORE, Ops.END), name="x"), split_store),
 ])

-def tag_uop(ctx:list[UOp], x:UOp):
-  if x.tag is not None: return None
+def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
+  if x.tag is not None or x in ctx[1]: return None
+  if x.tag is None and x.op is Ops.CALL:
+    # don't tag anything in a CALL
+    for u in x.src[0].toposort(): ctx[1].add(u)
   if x.dtype.scalar() == dtypes.index: return None
-  ctx.append(x)
-  return x.replace(tag=(len(ctx)-1,))
+  ctx[0].append(x)
+  return x.replace(tag=(len(ctx[0])-1,))
 add_tags = PatternMatcher([
   # don't tag BUFFERs, they are global
   (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
@@ -563,7 +567,7 @@ def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
 def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
   uop_list: list[UOp] = []
-  tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
+  tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")

   tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
tinygrad/tensor.py

Lines changed: 9 additions & 2 deletions
@@ -232,8 +232,15 @@ def dtype(self) -> DType: return self.uop.dtype

   # ***** data handlers ****

-  def call(self, *lst:Tensor, fxn:UOp, arg:Any=None) -> Tensor:
-    return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn, arg=arg))
+  def as_param(self, slot:int):
+    if self.uop.axis is not None:
+      multi_shape = tuple([s//len(self.device) if i==self.uop.axis else s for i,s in enumerate(self.shape)])
+      param = UOp.param(slot, self.dtype, multi_shape, self.device).multi(self.uop.axis)
+    else:
+      param = UOp.param(slot, self.dtype, self.shape, self.device)
+    return Tensor(param, device=self.device)
+  def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
+    return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)

   def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
     """

tinygrad/uop/ops.py

Lines changed: 9 additions & 8 deletions
@@ -58,6 +58,11 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
   if pad is not None: ret += " " * (pad-ansilen(ret))
   return ret

+def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
+  if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0))
+  elif all(isinstance(x, int) for x in arg): return UOp.const(dtypes.index.vec(len(arg)), cast(tuple[int, ...], arg))
+  else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg))
+
 def consumer_map_from_toposort(lst:Iterable[UOp]):
   ret: dict[UOp, dict[UOp, None]] = {}
   for u in lst:
@@ -222,7 +227,7 @@ def _shape(self) -> tuple[sint, ...]|None:
       case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
       case Ops.PARAM:
         # NOTE: copied from marg
-        if len(self.src) == 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
+        if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
         return None

       # passthrough ops
@@ -558,11 +563,7 @@ def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
       case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
       case Ops.PERMUTE | Ops.FLIP: src_args = []
       case _: raise RuntimeError(f"{op} is not a MovementOp")
-    usrcs = []
-    for arg in src_args:
-      if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
-      elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
-      else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
+    usrcs = [shape_to_shape_arg(arg) for arg in src_args]
     if len(usrcs) == 0: ret = UOp(op, self.dtype, (self,), arg)
     else: ret = UOp(op, self.dtype, (self,)+UOp.sink(*usrcs).simplify().src)
     # for all movement ops, we check shape property to validity check the movement op
@@ -826,8 +827,8 @@ def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:

   # TODO: this should replace placeholder
   @staticmethod
-  def param(slot:int, dtype:DType, shape:tuple[int, ...]|None=None):
-    src = () if shape is None else (UOp.const(dtypes.index.vec(len(shape)), shape),)
+  def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None):
+    src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
     return UOp(Ops.PARAM, dtype, src, arg=slot)

   def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
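
A small sketch of what this refactor does, assuming these internals are importable exactly as shown in the diff: shape_to_shape_arg folds an all-static shape into a single vectorized index constant (and otherwise VECTORIZEs per dimension), and PARAM can now carry an optional DEVICE source, which is what Tensor.as_param relies on:

from tinygrad.uop.ops import UOp, Ops, shape_to_shape_arg
from tinygrad.dtype import dtypes

# an all-int shape becomes one index-vector constant
print(shape_to_shape_arg((4, 8)).op)   # Ops.CONST with dtype dtypes.index.vec(2)

# PARAM with the new optional device src, as built by Tensor.as_param
p = UOp.param(0, dtypes.float, (4, 8), device="CPU")
print([s.op for s in p.src])           # [Ops.CONST, Ops.DEVICE]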
