The following patterns came out of DID loop split (#2563).
Case 1: split reshape (before SdpaFwd)
in: logical=[h], loop=[d, h/d] # d is an outer split of h
out: root=[h], logical=[a, h/a], loop=[d, a/d, h/a]
We may want to do an inner split by d at some point; see Case 3. I don't think the choice between an inner and an outer split will make much of a difference for IdModel.
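As a rough sanity check of Case 1, the sketch below verifies in plain Python that an outer split of h by d and an outer split of a by d (after reshaping h into [a, h/a]) place every element on the same device, so the reshape should not need communication. The sizes and the helper names (`outer_split`, `input_device`, `output_device`) are made up for illustration and are not nvFuser APIs; it assumes d divides a and a divides h.

```python
# Hypothetical static sizes: d divides a, and a divides h.
h, a, d = 24, 6, 2


def outer_split(index, factor, extent):
    """Map index in [0, extent) to (outer, inner), where outer has `factor` iterations."""
    assert extent % factor == 0
    inner_extent = extent // factor
    return index // inner_extent, index % inner_extent


def input_device(i):
    # in: logical=[h], loop=[d, h/d]; the device is the outer d.
    dev, _ = outer_split(i, d, h)
    return dev


def output_device(i):
    # out: logical=[a, h/a], loop=[d, a/d, h/a]; the device is the outer d of a.
    head, _ = outer_split(i, a, h)
    dev, _ = outer_split(head, d, a)
    return dev


# Elements whose owning device would change across the reshape.
mismatches = [i for i in range(h) if input_device(i) != output_device(i)]
print(mismatches)
```

An empty list means both loop domains agree on which device owns each element of h.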
Case 2: merge reshape (after SdpaFwd)
in: logical=[a, h/a], loop=[d, a/d, h/a]
out: root=[a, h/a], logical=[h], loop=[d, h/d]
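The merge direction can be checked the same way. This sketch assumes the merged output is sharded by an outer split of h by d, walks the input loop coordinates (d, a/d, h/a), and checks that the merged index stays on the same device. Sizes and function names are hypothetical.

```python
# Hypothetical static sizes: d divides a, and a divides h.
h, a, d = 24, 6, 2


def merged_index(dev, a_local, k):
    head = dev * (a // d) + a_local  # undo the outer split of a by d
    return head * (h // a) + k       # merge [a, h/a] into h


def output_device(i):
    return i // (h // d)             # assumed outer split of h by d


# Input loop coordinates whose element would land on a different device.
moved = [
    (dev, a_local, k)
    for dev in range(d)
    for a_local in range(a // d)
    for k in range(h // a)
    if output_device(merged_index(dev, a_local, k)) != dev
]
print(moved)
```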
Case 3: slice (used after the QKV linear in GPT)
There are several ways to represent that slice, as mentioned in http://nv/ezS. One that I think is promising:
in: logical=[b, s, hq+hk+hv], loop=[b, s, (hq+hk+hv)/d, d] # d is an inner split of hq+hk+hv
Q: logical=[b, s, hq], loop=[b, s, hq/d, d]
K: logical=[b, s, hk], loop=[b, s, hk/d, d]
V: logical=[b, s, hv], loop=[b, s, hv/d, d]
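To illustrate why the inner split looks promising here, the following sketch compares inner and outer splits on the sliced dimension only. It assumes the device index is the inner d (element i lives on device i % d) and that hq, hk and hv are each divisible by d; the sizes and helper names are hypothetical.

```python
# Hypothetical static sizes, each divisible by d.
hq, hk, hv, d = 8, 4, 4, 2
total = hq + hk + hv
# Each slice as (offset into the fused dimension, size).
slices = {"Q": (0, hq), "K": (hq, hk), "V": (hq + hk, hv)}


def inner_device(i, extent):
    return i % d  # inner split: the device is the inner d


def outer_device(i, extent):
    return i // (extent // d)  # outer split: the device is the outer d


def mismatches(device_of):
    # For each slice, list positions whose device differs between input and output.
    return {
        name: [j for j in range(size)
               if device_of(offset + j, total) != device_of(j, size)]
        for name, (offset, size) in slices.items()
    }


inner_bad = mismatches(inner_device)
outer_bad = mismatches(outer_device)
print(inner_bad)
print(outer_bad)
```

With the inner split every slice keeps its elements on the device that already owns them, while the outer split leaves some elements of Q, K and V on the wrong device for these sizes.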
Case 4: cat (backprop of the above slice)
dQ: logical=[b, s, hq], loop=[b, s, hq/d, d]
dK: logical=[b, s, hk], loop=[b, s, hk/d, d]
dV: logical=[b, s, hv], loop=[b, s, hv/d, d]
out: logical=[b, s, hq+hk+hv], loop=[b, s, (hq+hk+hv)/d, d]
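For the cat, a similar sketch checks that concatenating each device's local shards gives that device's shard of the full concatenation. It models only the last dimension with Python lists, assumes device r holds the elements with index i % d == r (the inner split), and uses made-up sizes and values.

```python
# Hypothetical static sizes, each divisible by d.
hq, hk, hv, d = 8, 4, 4, 2

# Stand-in gradient values along the last dimension.
dq = list(range(0, hq))
dk = list(range(100, 100 + hk))
dv = list(range(200, 200 + hv))


def local_shard(x, rank):
    # Inner split by d: rank r owns x[r], x[r + d], x[r + 2d], ...
    return x[rank::d]


full = dq + dk + dv
per_rank_ok = [
    local_shard(dq, r) + local_shard(dk, r) + local_shard(dv, r)
    == local_shard(full, r)
    for r in range(d)
]
print(per_rank_ok)
```

If every entry is True, each device can build its shard of the output from its own shards of dQ, dK and dV.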
In all of the above cases, d, a, and h* are static.
cc @naoyam, who asked me to write this down