
MoE training optimization #2438

@ptrendx

Description

This is a tracking issue for the multiple efforts needed to optimize the performance of MoE training, with a focus on D2H (device-to-host) sync-free MoE: all problem sizes should be supplied from device buffers.
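
For context, the D2H sync this targets typically shows up when per-expert token counts computed on the GPU have to be turned into host-side integers. A minimal PyTorch sketch of that sync (the router output, names, and sizes are illustrative assumptions, not TE code):

```python
import torch

# Assume a router has already assigned each token to an expert on the GPU.
num_tokens, num_experts = 4096, 8
expert_ids = torch.randint(num_experts, (num_tokens,), device="cuda")

# The per-expert token counts (the grouped-GEMM problem sizes) live on device.
tokens_per_expert = torch.bincount(expert_ids, minlength=num_experts)

# Any API that wants host-side split sizes forces a D2H copy plus a stream
# synchronization right here, stalling the CPU until the GPU catches up.
m_splits = tokens_per_expert.tolist()

# Sync-free goal: keep tokens_per_expert on the device and hand it to the
# grouped GEMM directly, so kernel launches never wait on the counts.
```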

TE/common:

TE/PyTorch:

  • Expose the grouped tensor type internally in PyTorch modules
  • Expose the grouped tensor type externally [@timmoon10 doubts the feasibility]
    • Expose the grouped tensor type as a PyTorch tensor
    • Enable grouped tensor input to GroupedLinear (see the sketch after this list)
  • Enable a single grouped tensor weight option in GroupedLinear
  • Utilize preswizzled inputs in the GEMM
  • Changes to te.Sequential to enable grouped tensors
  • End-to-end MoE support in TransformerLayer
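
To make the GroupedLinear items concrete: today the module takes host-side split sizes. The sketch below is a rough reading of the current usage, hedged (argument names and defaults may differ across TE versions, and the sizes are made up):

```python
import torch
import transformer_engine.pytorch as te

num_experts, hidden, ffn = 8, 1024, 4096
layer = te.GroupedLinear(num_experts, hidden, ffn)

tokens = torch.randn(4096, hidden, device="cuda")
tokens_per_expert = torch.full((num_experts,), 512, device="cuda")

# Today the per-expert sizes must arrive as Python ints, so device-resident
# counts have to be copied to the host and synchronize the stream first:
out = layer(tokens, m_splits=tokens_per_expert.tolist())

# The grouped-tensor items above point toward keeping the activations, a
# single grouped weight, and the group sizes all in device memory, so the
# .tolist() sync disappears and GroupedLinear consumes device-side sizes.
```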

TE/JAX:
