@@ -756,7 +756,7 @@ def forward(self, L, *args, **kwargs): # noqa: N803
"""
return self.filter(L, *args, **kwargs)

@torch.compile(mode="max-autotune")
@torch.compile(mode="default")
def filter(self, L, *args, **kwargs): # noqa: N803
"""Compute the filter as a function of h and decay for the requested sequence length."""
h = self.h[:, :L]
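[Note: this hunk relaxes the compilation mode on the filter method from "max-autotune" to "default". A minimal sketch of what that argument controls follows; the class below is a toy stand-in, not the repository's filter module.]

```python
import torch


class ToyFilter(torch.nn.Module):
    """Toy stand-in for a Hyena-style filter module; shapes are illustrative."""

    def __init__(self, d_model: int = 8, max_len: int = 1024):
        super().__init__()
        self.h = torch.nn.Parameter(torch.randn(d_model, max_len))

    # mode="default" uses torch.compile's standard Inductor path and compiles
    # quickly; mode="max-autotune" additionally benchmarks alternative kernel
    # configurations, which can produce faster kernels at the cost of much
    # longer warm-up/compile time.
    @torch.compile(mode="default")
    def filter(self, L: int):  # noqa: N803
        return self.h[:, :L]


out = ToyFilter().filter(16)  # compilation is triggered on the first call
```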
@@ -291,6 +291,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
     parser.add_argument(
         "--grad-reduce-in-fp32", action="store_true", default=False, help="Gradient reduce in FP32."
     ) # DONE
+    parser.add_argument(
+        "--fsdp",
+        action="store_true",
+        default=False,
+        help="Enable FSDP training.",
+    )
     parser.add_argument("--use-megatron-comm-overlap-llama3-8k", action="store_true", default=False) # DONE
     parser.add_argument(
         "--tp-comm-overlap-backend",
@@ -710,9 +716,9 @@ def train(args: argparse.Namespace) -> None:
recipe_kwargs["stride"] = args.stride
recipe_kwargs["window_min_length_threshold"] = args.window_min_length_threshold
recipe_kwargs["rc_aug"] = args.rc_aug
elif args.dataset_config_path:
elif args.dataset_config:
recipe_kwargs["dataset_dir"] = args.dataset_dir
recipe_kwargs["dataset_config_path"] = args.dataset_config_path
recipe_kwargs["dataset_config_path"] = args.dataset_config
[Review comment from the Collaborator Author on lines -713 to +721: "Need to merge this block at a minimum."]

recipe_kwargs["pad_eod_loss_mask"] = args.eod_pad_in_loss_mask

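[Note: only the CLI-facing name changes in this hunk; the recipe keyword stays dataset_config_path. A hedged sketch of the mapping, assuming the flag is declared as --dataset-config elsewhere in parse_args (its exact type and default are not shown in this diff, and the paths below are placeholders):]

```python
import argparse

parser = argparse.ArgumentParser()
# Assumed declaration of the renamed flag; argparse exposes "--dataset-config"
# as the attribute args.dataset_config.
parser.add_argument("--dataset-config", type=str, default=None)
parser.add_argument("--dataset-dir", type=str, default=None)

args = parser.parse_args(["--dataset-config", "configs/example.yaml", "--dataset-dir", "/data"])

recipe_kwargs = {}
if args.dataset_config:
    recipe_kwargs["dataset_dir"] = args.dataset_dir
    # The recipe still expects the key "dataset_config_path".
    recipe_kwargs["dataset_config_path"] = args.dataset_config

print(recipe_kwargs)
# {'dataset_dir': '/data', 'dataset_config_path': 'configs/example.yaml'}
```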
@@ -747,7 +753,7 @@ def train(args: argparse.Namespace) -> None:
     cfg: ConfigContainer = pretrain_config(**recipe_kwargs)
 
     cfg.checkpoint.async_save = args.ckpt_async_save
-    cfg.checkpoint.ckpt_format = args.ckpt_format
+    cfg.checkpoint.ckpt_format = args.ckpt_format if not args.fsdp else "fsdp_dtensor"
     cfg.checkpoint.save_interval = args.eval_interval
     cfg.checkpoint.save_optim = True
     cfg.checkpoint.save_rng = True
@@ -828,6 +834,10 @@ def train(args: argparse.Namespace) -> None:
     cfg.ddp.overlap_grad_reduce = args.overlap_grad_reduce
     cfg.ddp.grad_reduce_in_fp32 = args.grad_reduce_in_fp32
     cfg.ddp.check_for_nan_in_grad = not args.no_check_for_nan_in_grad
+    if args.fsdp:
+        cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
+        cfg.ddp.use_megatron_fsdp = True
+        cfg.checkpoint.ckpt_format = "fsdp_dtensor"
     if args.use_megatron_comm_overlap_llama3_8k:
         # Pick the floating point appropriate config.
         fp8 = "fp8" in args.mixed_precision_recipe
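[Note: combined with the earlier checkpoint hunk, passing --fsdp selects the Megatron-FSDP data-parallel path and forces the checkpoint format to "fsdp_dtensor" regardless of --ckpt-format. A hedged sketch of the net effect on a stand-in config object; the real cfg is a ConfigContainer produced by pretrain_config, and "torch_dist" below is only a placeholder value:]

```python
from types import SimpleNamespace


def apply_fsdp_overrides(cfg, args):
    """Mirror of the two FSDP branches in this diff, applied to a stand-in cfg."""
    cfg.checkpoint.ckpt_format = args.ckpt_format if not args.fsdp else "fsdp_dtensor"
    if args.fsdp:
        cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
        cfg.ddp.use_megatron_fsdp = True
        # Redundant with the conditional above; kept for parity with the diff.
        cfg.checkpoint.ckpt_format = "fsdp_dtensor"
    return cfg


cfg = SimpleNamespace(checkpoint=SimpleNamespace(ckpt_format=None), ddp=SimpleNamespace())
args = SimpleNamespace(fsdp=True, ckpt_format="torch_dist")
apply_fsdp_overrides(cfg, args)
print(cfg.ddp.use_megatron_fsdp, cfg.checkpoint.ckpt_format)  # True fsdp_dtensor
```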