File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -198,13 +198,13 @@ def main(job_config: JobConfig):
198198 )
199199 model_config .fuse_norm = False
200200 if parallel_dims .loss_parallel_enabled :
201- if model_config .fuse_cross_entropy :
201+ if model_config .fuse_linear_cross_entropy :
202202 logger .warning (
203203 f"{ color .red } "
204204 f"Loss parallel enabled. Disabling fused cross entropy for now."
205205 f"{ color .reset } "
206206 )
207- model_config .fuse_cross_entropy = False
207+ model_config .fuse_linear_cross_entropy = False
208208 model_config .vocab_size = max (tokenizer .vocab_size , model_config .vocab_size )
209209
210210 logger .info (
@@ -213,7 +213,7 @@ def main(job_config: JobConfig):
213213 with torch .device ("meta" ):
214214 model = AutoModelForCausalLM .from_config (model_config )
215215 if (
216- getattr (model_config , "fuse_cross_entropy " , False )
216+ getattr (model_config , "fuse_linear_cross_entropy " , False )
217217 and FusedLinearCrossEntropyLoss is not None
218218 ):
219219 model .criterion = FusedLinearCrossEntropyLoss (
You can’t perform that action at this time.
0 commit comments