Extend IdModel to map DIDs for certain patterns. #3987

@wujingyue

The following patterns came out from DID loop split (#2563).

Case 1: split reshape (before SdpaFwd)

in: logical=[h], loop=[d, h/d]  # d is an outer split of h
out: root=[h], logical=[a, h/a], loop=[d, a/d, h/a]

We may want to do an inner split by d at some point (see case 3). I don't think the choice between inner and outer split will make much difference for IdModel.
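As a rough illustration of why the two `d` IterDomains in case 1 are candidates for mapping, the sketch below checks with plain integer arithmetic that every element of `h` lands on the same device index before and after the split reshape. It does not use any nvFuser API; `outer_split`, `in_device`, `out_device`, and the concrete sizes are hypothetical, and it assumes `a` divides `h` and `d` divides `a`.

```python
def outer_split(extent, factor):
    """Outer split: [extent] -> [factor, extent // factor]."""
    assert extent % factor == 0, "this sketch assumes divisible extents"
    return [factor, extent // factor]


# Hypothetical static sizes: a divides h, d divides a.
h, a, d = 48, 8, 4

in_loop = outer_split(h, d)              # [d, h/d]
out_loop = outer_split(a, d) + [h // a]  # [d, a/d, h/a]


def in_device(i):
    # Input side: h is outer-split by d, so the device index is i // (h/d).
    return i // (h // d)


def out_device(i):
    # Output side: h is first reshaped to [a, h/a], then a is outer-split by d.
    a_index = i // (h // a)
    return a_index // (a // d)


# True when every element stays on the same device across the reshape.
same_device = all(in_device(i) == out_device(i) for i in range(h))
```

With these sizes both sides assign element `i` to device `i // 12`, which is the kind of equivalence an extended IdModel would need to recognize.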

Case 2: merge reshape (after SdpaFwd)

in: logical=[a, h/a], loop=[d, a/d, h/a]
out: root=[a, h/a], logical=[h], loop=[d, h/d]

Case 3: slice (used after the QKV linear in GPT)

There are several ways to represent that slice, as mentioned in http://nv/ezS. One that I think is promising is:

in: logical=[b, s, hq+hk+hv], loop=[b, s, (hq+hk+hv)/d, d]  # d is an inner split of hq+hk+hv
Q: logical=[b, s, hq], loop=[b, s, hq/d, d]
K: logical=[b, s, hk], loop=[b, s, hk/d, d]
V: logical=[b, s, hv], loop=[b, s, hv/d, d]
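One way to see what the inner split buys for the slice is to check the device index on both sides with plain arithmetic. The sketch below is not nvFuser code; `inner_device` and the sizes are hypothetical, and it assumes each of hq, hk, and hv is divisible by d, which the issue does not state explicitly.

```python
def inner_device(index, d):
    # Inner split by d gives loop=[extent/d, d]; the device coordinate is index % d.
    return index % d


d = 4
hq, hk, hv = 8, 12, 16  # hypothetical static sizes, each divisible by d

sizes = {"Q": hq, "K": hk, "V": hv}
offsets = {"Q": 0, "K": hq, "V": hq + hk}  # start of each slice in the input

# Element j of a slice is element offsets[name] + j of the input.
# The slice needs no communication if both sit on the same device.
slice_is_local = all(
    inner_device(offsets[name] + j, d) == inner_device(j, d)
    for name in sizes
    for j in range(sizes[name])
)

# Outer extents of the loop domains: the input, then Q, K, V.
in_outer = (hq + hk + hv) // d
out_outer = {name: size // d for name, size in sizes.items()}
```

Under these assumptions every slice offset is a multiple of d, so `index % d` is unchanged by the slice and the input's `d` lines up with the `d` of Q, K, and V.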

Case 4: cat (backprop of the above slice)

dQ: logical=[b, s, hq], loop=[b, s, hq/d, d]
dK: logical=[b, s, hk], loop=[b, s, hk/d, d]
dV: logical=[b, s, hv], loop=[b, s, hv/d, d]
out: logical=[b, s, hq+hk+hv], loop=[b, s, (hq+hk+hv)/d, d]
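The cat can be checked at the data level: concatenating each device's local shards should give that device's shard of the concatenated tensor. This is a minimal sketch using Python lists in place of the last tensor dimension; `shard_inner` and the sizes are hypothetical, and it again assumes hq, hk, and hv are divisible by d.

```python
def shard_inner(values, d, rank):
    # Inner split by d: element i lives on device i % d,
    # so device `rank` holds elements rank, rank + d, rank + 2d, ...
    return values[rank::d]


d = 4
dq = list(range(0, 8))      # stands in for dQ, hq = 8
dk = list(range(100, 112))  # stands in for dK, hk = 12
dv = list(range(200, 216))  # stands in for dV, hv = 16

full = dq + dk + dv  # the unsharded cat

# For every device, cat-of-shards must equal shard-of-cat.
cat_is_local = all(
    shard_inner(dq, d, r) + shard_inner(dk, d, r) + shard_inner(dv, d, r)
    == shard_inner(full, d, r)
    for r in range(d)
)
```

The local extents add up as well: hq/d + hk/d + hv/d = (hq+hk+hv)/d, matching the outer extent of the output's loop domain.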

In all of the above cases, d, a, and h* are static.

cc @naoyam who requested me to write this down
