
Commit 6de403e

TroyGarden authored and meta-codesync[bot] committed
add pipeline phases in train_pipeline.types (#3443)
Summary:
Pull Request resolved: #3443

# context
* add pipeline phases in types to better describe a pipeline phase
* this can help parameterize training pipeline options

Reviewed By: spmex
Differential Revision: D77062545
fbshipit-source-id: 92ad98235b59544d27a0bb277ee305888258377a
1 parent 283e2f8 commit 6de403e

File tree

1 file changed: +69 −0 lines

  • torchrec/distributed/train_pipeline/types.py


torchrec/distributed/train_pipeline/types.py

Lines changed: 69 additions & 0 deletions
@@ -96,3 +96,72 @@ class PipelineState(Enum):
     IDLE = 0
     CALL_FWD = 1
     CALL_BWD = 2
+
+    def __str__(self) -> str:
+        return self.name
+
+
+@unique
+class PipelinePhase(Enum):
+    """
+    Pipeline phases for the train pipeline.
+
+    Please:
+    1. Order the phases in the execution order of the base pipeline.
+    2. Add notes to explain the phases where needed.
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __eq__(self, obj: "PipelinePhase") -> bool:
+        return self.value == obj.value
+
+    # placeholder for an empty phase
+    NULL = "null"
+
+    # the data is usually first available on CPU when loaded from the dataloader;
+    # the input batch needs to be moved/copied to the device when using GPU training
+    COPY_BATCH_TO_DEVICE = "copy_batch_to_device"
+
+    # input post-processing is needed for the sparse data dist pipeline, where the sparse
+    # features are traced (built) from the ModelInput via fx tracing
+    INPUT_POST_PROC = "input_post_proc"
+
+    # the sparse features (a.k.a. KJTs) are in a jagged format, so their data sizes are
+    # unknown to other ranks; a comms is needed to exchange the size info, i.e., the splits
+    INPUT_SPLITS_DIST = "input_splits_dist"
+
+    # once a rank knows the data sizes from the other ranks (via the splits dist), it can
+    # initialize an all-to-all comms to exchange the actual data of the sparse features
+    # NOTE: the splits have to be available on the host side
+    INPUT_DATA_DIST = "input_data_dist"
+
+    # the embedding lookup is done in FBGEMM TBE on each rank
+    EMBEDDING_LOOKUP = "embedding_lookup"
+
+    # the embedding lookup results (i.e., the embeddings) are needed on each rank; this is
+    # often done with the output dist via an all_to_all comms
+    EMBEDDING_OUTPUT_DIST = "embedding_output_dist"
+
+    # a typical DLRM model arch contains a sparse arch and a dense arch; here we treat the
+    # model excluding the "sparse modules" as the dense part, which also includes the
+    # dense-sharded embedding tables
+    DENSE_FORWARD = "dense_forward"
+
+    # the model's backward usually uses torch.autograd; the embedding modules' backward is
+    # handled by TBE
+    DENSE_BACKWARD = "dense_backward"
+
+    # on each rank, after the dense arch's backward, the gradients for the embedding tables
+    # are available; a backward of the "embedding output dist" is needed to gather the
+    # embedding gradients from all ranks to the rank hosting the embedding table
+    EMBEDDING_GRAD_DIST = "embedding_grad_dist"
+
+    # the TBE backward usually updates the embedding table weights in place
+    EMBEDDING_BACKWARD = "embedding_backward"
+
+    # the embedding update is decoupled from the backward in case the weight update is not
+    # coupled with the backward pass
+    EMBEDDING_UPDATE = "embedding_update"
+
+    # the optimizer step usually only includes the dense module weights
+    DENSE_OPTIMIZER_STEP = "dense_optimizer_step"
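As a rough sketch of how these phases could be used to parameterize training pipeline options (the goal mentioned in the summary), the snippet below walks the phases in their declared order and skips a configurable subset. It assumes PipelinePhase is importable from the module touched in this diff; the skip_phases option and the loop itself are hypothetical and not part of this commit.

    # minimal usage sketch, not part of the commit
    from torchrec.distributed.train_pipeline.types import PipelinePhase

    # hypothetical option: phases this pipeline variant should not execute
    # (a list is used because overriding __eq__ without __hash__ leaves members unhashable)
    skip_phases = [PipelinePhase.INPUT_POST_PROC, PipelinePhase.EMBEDDING_UPDATE]

    # members are declared in the base pipeline's execution order, per the class docstring
    for phase in PipelinePhase:
        if phase is PipelinePhase.NULL or phase in skip_phases:
            continue
        # __str__ returns the enum's string value, e.g. "embedding_lookup"
        print(f"would run phase: {phase}")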

0 commit comments