@@ -96,3 +96,72 @@ class PipelineState(Enum):
96
96
IDLE = 0
97
97
CALL_FWD = 1
98
98
CALL_BWD = 2
99
+
100
+ def __str__ (self ) -> str :
101
+ return self .name
102
+
103
+
104
@unique
class PipelinePhase(Enum):
    """
    Pipeline phase for the train pipeline.

    Conventions for contributors:
    1. Keep the phases declared in the execution order of the base pipeline.
    2. Add a note explaining each phase where needed.
    """

    def __str__(self) -> str:
        """Render as the phase's string value (e.g. "copy_batch_to_device")."""
        return self.value

    def __eq__(self, obj: object) -> bool:
        """Compare phases by their string value.

        Returns ``NotImplemented`` for non-Enum operands so that comparing a
        phase against, say, a raw string safely falls back to Python's default
        handling instead of raising ``AttributeError`` on ``obj.value``.
        """
        if isinstance(obj, Enum):
            return self.value == obj.value
        return NotImplemented

    def __hash__(self) -> int:
        # BUG FIX: defining __eq__ implicitly sets __hash__ = None, which made
        # every member unhashable (unusable as dict keys or in sets). Restore
        # hashing, keyed on the value to stay consistent with __eq__ above.
        return hash(self.value)

    # placeholder for empty
    NULL = "null"

    # usually the data is first available on CPU when loading from dataloader
    # need to move/copy the input batch to device if using GPU training
    COPY_BATCH_TO_DEVICE = "copy_batch_to_device"

    # input post processing is needed for sparse data dist pipeline, where the sparse features
    # are traced (built) from the ModelInput via fx tracing
    INPUT_POST_PROC = "input_post_proc"

    # the sparse features (AKA, KJTs) are in a jagged format so the data sizes are unknown to
    # other ranks. so a comms is needed to exchange the data size info, i.e., the splits
    INPUT_SPLITS_DIST = "input_splits_dist"

    # once a rank knows the data size from other ranks (via splits dist), it can initialize
    # an all-to-all comms to exchange the actual data of the sparse features
    # NOTE: the splits have to be available on the host side
    INPUT_DATA_DIST = "input_data_dist"

    # embedding lookup is done in FBGEMM.TBE on each rank
    EMBEDDING_LOOKUP = "embedding_lookup"

    # the embedding lookup results (i.e., the embeddings) are needed in each rank, it's often done
    # with the output dist with an all_to_all comms
    EMBEDDING_OUTPUT_DIST = "embedding_output_dist"

    # A typical DLRM model arch contains sparse arch and dense arch; here we treat the model
    # excluding "sparse modules" as the dense part. It actually also includes the dense-sharded
    # embedding tables.
    DENSE_FORWARD = "dense_forward"

    # model's backward usually uses torch.autograd; the embedding modules' backward is handled by TBE
    DENSE_BACKWARD = "dense_backward"

    # on each rank, after dense arch's backward, the gradients are available for the embedding tables
    # a backward of "embedding output dist" is needed to gather the embedding gradients from all ranks
    # to the rank where the embedding table is hosted.
    EMBEDDING_GRAD_DIST = "embedding_grad_dist"

    # TBE backward usually updates the embedding table weights in place
    EMBEDDING_BACKWARD = "embedding_backward"

    # we decouple the embedding update from backward just in case the change is not coupled
    EMBEDDING_UPDATE = "embedding_update"

    # the optimizer step usually only includes the dense module weights
    DENSE_OPTIMIZER_STEP = "dense_optimizer_step"