diff --git a/torchrec_dlrm/dlrm_main.py b/torchrec_dlrm/dlrm_main.py
index 6d32d6cb..5922dafc 100644
--- a/torchrec_dlrm/dlrm_main.py
+++ b/torchrec_dlrm/dlrm_main.py
@@ -36,6 +36,11 @@ from torchrec.optim.optimizers import in_backward_optimizer_filter
 from tqdm import tqdm
 
+try:
+    from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig
+except ImportError:
+    pass
+
 # OSS import
 try:
     # pyre-ignore[21]
@@ -80,6 +85,12 @@ def parse_args(argv: list[str]) -> argparse.Namespace:
         default=1,
         help="number of epochs to train",
     )
+    parser.add_argument(
+        "--precondition_frequency",
+        type=int,
+        default=100,
+        help="number of steps between preconditioner updates",
+    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -263,6 +274,16 @@ def parse_args(argv: list[str]) -> argparse.Namespace:
         action="store_true",
         help="Flag to determine if adagrad optimizer should be used.",
    )
+    parser.add_argument(
+        "--shampoo_embedding",
+        action="store_true",
+        help="Use DistributedShampoo optimizer for the embedding (sparse) parameters.",
+    )
+    parser.add_argument(
+        "--shampoo_dense",
+        action="store_true",
+        help="Use DistributedShampoo optimizer for the dense parameters.",
+    )
     parser.add_argument(
         "--interaction_type",
         type=InteractionType,
@@ -491,8 +512,8 @@ def train_val_test(
             args.limit_train_batches,
             args.limit_val_batches,
         )
-        val_auroc = _evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
-        results.val_aurocs.append(val_auroc)
+        # val_auroc = _evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
+        results.val_aurocs.append(0.0)
 
     test_auroc = _evaluate(args.limit_test_batches, pipeline, test_dataloader, "test")
     results.test_auroc = test_auroc
@@ -635,7 +656,7 @@ def main(argv: list[str]) -> None:
     )
     train_model = DLRMTrain(dlrm_model)
-    embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD
+    # embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD
     # This will apply the Adagrad optimizer in the backward pass for the embeddings (sparse_arch). This means that
     # the optimizer update will be applied in the backward pass, in this case through a fused op.
     # TorchRec will use the FBGEMM implementation of EXACT_ADAGRAD. For GPU devices, a fused CUDA kernel is invoked. For CPU, FBGEMM_GPU invokes CPU kernels
@@ -643,14 +664,15 @@ def main(argv: list[str]) -> None:
     # Note that lr_decay, weight_decay and initial_accumulator_value for Adagrad optimizer in FBGEMM v0.3.2
     # cannot be specified below. This equivalently means that all these parameters are hardcoded to zero.
 
-    optimizer_kwargs = {"lr": args.learning_rate}
-    if args.adagrad:
-        optimizer_kwargs["eps"] = args.eps
-    apply_optimizer_in_backward(
-        embedding_optimizer,
-        train_model.model.sparse_arch.parameters(),
-        optimizer_kwargs,
-    )
+    # optimizer_kwargs = {"lr": args.learning_rate}
+    # if args.adagrad:
+    #     optimizer_kwargs["eps"] = args.eps
+
+    # apply_optimizer_in_backward(
+    #     embedding_optimizer,
+    #     train_model.model.sparse_arch.parameters(),
+    #     optimizer_kwargs,
+    # )
     planner = EmbeddingShardingPlanner(
         topology=Topology(
             local_world_size=get_local_size(),
@@ -660,7 +682,7 @@ def main(argv: list[str]) -> None:
         batch_size=args.batch_size,
         # If experience OOM, increase the percentage. see
         # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation
-        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
+        storage_reservation=HeuristicalStorageReservation(percentage=0.2),
     )
     plan = planner.collective_plan(
         train_model, get_default_sharders(), dist.GroupMember.WORLD
@@ -678,18 +700,55 @@ def main(argv: list[str]) -> None:
                 print(table_name, "\n", plan, "\n")
 
     def optimizer_with_params():
-        if args.adagrad:
+        if args.shampoo_dense:
+            return lambda params: DistributedShampoo(
+                params,
+                lr=0.001,
+                betas=(0., 0.999),
+                epsilon=1e-12,
+                momentum=0.9,
+                weight_decay=1e-05,
+                max_preconditioner_dim=8192,
+                precondition_frequency=args.precondition_frequency,
+                grafting_config=SGDPreconditionerConfig(),
+            )
+        elif args.adagrad:
+            return lambda params: torch.optim.Adagrad(
+                params, lr=args.learning_rate, eps=args.eps
+            )
+        else:
+            return lambda params: torch.optim.SGD(params, lr=args.learning_rate)
+
+    def embedding_optimizer_with_params():
+        if args.shampoo_embedding:
+            return lambda params: DistributedShampoo(
+                params,
+                lr=args.learning_rate,
+                betas=(0., 0.999),
+                epsilon=args.eps,
+                momentum=0.9,
+                weight_decay=1e-05,
+                max_preconditioner_dim=8192,
+                precondition_frequency=args.precondition_frequency,
+                grafting_config=SGDPreconditionerConfig(),
+            )
+        elif args.adagrad:
             return lambda params: torch.optim.Adagrad(
                 params, lr=args.learning_rate, eps=args.eps
             )
         else:
             return lambda params: torch.optim.SGD(params, lr=args.learning_rate)
 
+    embedding_optimizer = KeyedOptimizerWrapper(
+        dict(in_backward_optimizer_filter(model.named_parameters(), include=True)),
+        embedding_optimizer_with_params(),
+    )
+
     dense_optimizer = KeyedOptimizerWrapper(
         dict(in_backward_optimizer_filter(model.named_parameters())),
         optimizer_with_params(),
     )
-    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
+    optimizer = CombinedOptimizer([embedding_optimizer, dense_optimizer])
     lr_scheduler = LRPolicyScheduler(
         optimizer, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps
     )
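Reviewer sketch (not part of the diff): with the fused in-backward Adagrad path commented out, both parameter groups are now plain KeyedOptimizerWrapper factories combined into one CombinedOptimizer. The snippet below is a minimal, hedged illustration of that wiring, assuming the same distributed_shampoo build imported in the try/except at the top of the diff (including SGDPreconditionerConfig) and that DistributedShampoo can run in a single process. ToyModel, shampoo_factory, and the manual parameter split are illustrative stand-ins; the real script builds the wrappers from in_backward_optimizer_filter() over the sharded DLRM model.

import torch
from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig  # assumed available, as in the diff
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper


class ToyModel(torch.nn.Module):
    """Tiny stand-in for the sharded DLRM model (embedding ~ sparse_arch, dense ~ over arch)."""

    def __init__(self) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 4)
        self.dense = torch.nn.Linear(4, 1)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.dense(self.embedding(ids).mean(dim=1))


model = ToyModel()


def shampoo_factory(params):
    # Same factory shape as optimizer_with_params() / embedding_optimizer_with_params():
    # KeyedOptimizerWrapper calls it with a flat list of parameter tensors.
    return DistributedShampoo(
        params,
        lr=0.001,
        betas=(0.0, 0.999),
        epsilon=1e-12,
        momentum=0.9,
        weight_decay=1e-05,
        max_preconditioner_dim=8192,
        precondition_frequency=100,  # mirrors --precondition_frequency
        grafting_config=SGDPreconditionerConfig(),
    )


# The real script splits parameters with in_backward_optimizer_filter(); here we split by hand.
embedding_optimizer = KeyedOptimizerWrapper(
    dict(model.embedding.named_parameters(prefix="embedding")), shampoo_factory
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(model.dense.named_parameters(prefix="dense")),
    lambda params: torch.optim.SGD(params, lr=0.01),
)
optimizer = CombinedOptimizer([embedding_optimizer, dense_optimizer])

# One toy step to show the combined optimizer drives both parameter groups.
loss = model(torch.randint(0, 10, (2, 3))).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()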