Skip to content

Commit 6e49e20

Browse files
authored
Disable fused linear CE if necessary
1 parent 8684660 commit 6e49e20

File tree

1 file changed

+3
-3
lines changed

flame/train.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -198,13 +198,13 @@ def main(job_config: JobConfig):
198198
)
199199
model_config.fuse_norm = False
200200
if parallel_dims.loss_parallel_enabled:
201-
if model_config.fuse_cross_entropy:
201+
if model_config.fuse_linear_cross_entropy:
202202
logger.warning(
203203
f"{color.red}"
204204
f"Loss parallel enabled. Disabling fused cross entropy for now."
205205
f"{color.reset}"
206206
)
207-
model_config.fuse_cross_entropy = False
207+
model_config.fuse_linear_cross_entropy = False
208208
model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
209209

210210
logger.info(
@@ -213,7 +213,7 @@ def main(job_config: JobConfig):
213213
with torch.device("meta"):
214214
model = AutoModelForCausalLM.from_config(model_config)
215215
if (
216-
getattr(model_config, "fuse_cross_entropy", False)
216+
getattr(model_config, "fuse_linear_cross_entropy", False)
217217
and FusedLinearCrossEntropyLoss is not None
218218
):
219219
model.criterion = FusedLinearCrossEntropyLoss(

0 commit comments

Comments
 (0)