[FEA] Support recompute for HSTU #6

@shijieliu

Description

Is your feature request related to a problem? Please describe.
The activations in HSTU can be a major limiter when scaling the model. We need to investigate how to implement recomputation and expose it to users.
Reference: Activation memory analysis doc

Describe the solution you'd like

PoC

  • Basic recompute w/o overlap. We will build intra-layer recompute on top of Enhance dynamicemb' example #31. The basic idea is to rewrite HSTULayer, which contains ln-linear_bias-silu-fused_attention-eltwise_mul-dropout-linear_add, and to find/configure and build the fusable patterns with the help of Triton.
  • Intra-layer overlap. While doing backward, there may be chances to overlap 2 fused ops.
  • Inter-layer overlap. While performing the backward of layer[i+1], we can prefetch (recompute) the forward of layer[i].
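To make the basic (no-overlap) recompute idea concrete, here is a minimal plain-Python sketch. It is not the HSTU/Triton implementation; the names `CheckpointedLayer`, `silu`, and `silu_grad` are illustrative assumptions. The key point is the memory/compute trade: forward saves only the layer input and discards the intermediate activation, then backward re-runs the cheap part of forward to rebuild what the gradient needs.

```python
import math


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def silu_grad(z):
    """d/dz silu(z) = sigmoid(z) + z * sigmoid(z) * (1 - sigmoid(z))."""
    s = 1.0 / (1.0 + math.exp(-z))
    return s + z * s * (1.0 - s)


class CheckpointedLayer:
    """Toy layer y = silu(w * x) that recomputes its activation in backward.

    Forward stores only the input x; the intermediate z = w * x is dropped
    to save memory and recomputed on demand during backward.
    """

    def __init__(self, w):
        self.w = w
        self.saved_input = None

    def forward(self, x):
        self.saved_input = x          # keep the input only, drop z = w * x
        return silu(self.w * x)

    def backward(self, grad_out):
        z = self.w * self.saved_input  # recompute the discarded activation
        return grad_out * silu_grad(z) * self.w
```

The intra-/inter-layer overlap items then become a scheduling question on top of this: the recomputation in `backward` is independent work that can, in principle, run concurrently with other backward ops (e.g. on a separate CUDA stream), or be launched one layer ahead as a prefetch.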

Describe alternatives you've considered
N/A

Additional context
N/A
