@@ -345,15 +345,15 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
345345 OptimInfo (
346346 name = 'sgd' ,
347347 opt_class = optim .SGD ,
348- description = 'Stochastic Gradient Descent with Nesterov momentum (default) ' ,
348+ description = 'torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum' ,
349349 has_eps = False ,
350350 has_momentum = True ,
351351 defaults = {'nesterov' : True }
352352 ),
353353 OptimInfo (
354354 name = 'momentum' ,
355355 opt_class = optim .SGD ,
356- description = 'Stochastic Gradient Descent with classical momentum' ,
356+ description = 'torch.Optim Stochastic Gradient Descent (SGD) with classical momentum' ,
357357 has_eps = False ,
358358 has_momentum = True ,
359359 defaults = {'nesterov' : False }
@@ -798,7 +798,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
798798
799799def get_optimizer_class (
800800 name : str ,
801- bind_defaults : bool = False ,
801+ bind_defaults : bool = True ,
802802) -> Union [Type [optim .Optimizer ], OptimizerCallable ]:
803803 """Get optimizer class by name with option to bind default arguments.
804804
@@ -821,17 +821,14 @@ def get_optimizer_class(
821821 ValueError: If optimizer name is not found in registry
822822
823823 Examples:
824- >>> # Get raw optimizer class
825- >>> Adam = get_optimizer_class('adam')
826- >>> opt = Adam(model.parameters(), lr=1e-3)
827-
828- >>> # Get optimizer with defaults bound
829- >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
830- >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
831-
832824 >>> # Get SGD with nesterov momentum default
833- >>> SGD = get_optimizer_class('sgd', bind_defaults=True ) # nesterov=True bound
825+ >>> SGD = get_optimizer_class('sgd') # nesterov=True bound
834826 >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
827+
828+ >>> # Get raw optimizer class
829+ >>> SGD = get_optimizer_class('sgd')
830+ >>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
831+
835832 """
836833 return default_registry .get_optimizer_class (name , bind_defaults = bind_defaults )
837834
0 commit comments