@@ -78,13 +78,13 @@ class FedOptRecipe(Recipe):
7878 server_expected_format (str): What format to exchange the parameters between server and client.
7979 source_model (str): ID of the source model component. Defaults to "model".
8080 optimizer_args (dict): Configuration for server-side optimizer with keys:
81- - path: Path to optimizer class (e.g., "torch.optim.SGD")
81+ - class_path: Fully qualified optimizer class (e.g., "torch.optim.SGD"). "path" is also accepted.
8282 - args: Dictionary of optimizer arguments (e.g., {"lr": 1.0, "momentum": 0.6})
83- - config_type: Type of configuration, typically "dict"
83+ - config_type: Optional; if omitted, set to "dict" so the config is not instantiated at load time.
8484 lr_scheduler_args (dict): Optional configuration for learning rate scheduler with keys:
85- - path: Path to scheduler class (e.g., "torch.optim.lr_scheduler.CosineAnnealingLR")
85+ - class_path: Fully qualified scheduler class (e.g., "torch.optim.lr_scheduler.CosineAnnealingLR"). "path" is also accepted.
8686 - args: Dictionary of scheduler arguments (e.g., {"T_max": 100, "eta_min": 0.9})
87- - config_type: Type of configuration, typically "dict"
87+ - config_type: Optional; if omitted, set to "dict" so the config is not instantiated at load time.
8888 device (str): Device to use for server-side optimization, e.g. "cpu" or "cuda:0".
8989 Defaults to None; will default to cuda if available and no device is specified.
9090 server_memory_gc_rounds: Run memory cleanup (gc.collect + malloc_trim) every N rounds on server.
@@ -102,12 +102,12 @@ class FedOptRecipe(Recipe):
102102 device="cpu",
103103 source_model="model",
104104 optimizer_args={
105- "path ": "torch.optim.SGD",
105+ "class_path": "torch.optim.SGD",
106106 "args": {"lr": 1.0, "momentum": 0.6},
107107 "config_type": "dict"
108108 },
109109 lr_scheduler_args={
110- "path ": "torch.optim.lr_scheduler.CosineAnnealingLR",
110+ "class_path": "torch.optim.lr_scheduler.CosineAnnealingLR",
111111 "args": {"T_max": "{num_rounds}", "eta_min": 0.9},
112112 "config_type": "dict"
113113 }
@@ -158,7 +158,7 @@ def __init__(
158158 self.initial_ckpt = v.initial_ckpt
159159
160160 # Validate inputs using shared utilities
161- from nvflare.recipe.utils import recipe_model_to_job_model, validate_ckpt
161+ from nvflare.recipe.utils import ensure_config_type_dict, recipe_model_to_job_model, validate_ckpt
162162
163163 validate_ckpt(self.initial_ckpt)
164164 if isinstance(self.model, dict):
@@ -174,8 +174,10 @@ def __init__(
174174 self.server_expected_format: ExchangeFormat = v.server_expected_format
175175 self.device = device
176176 self.source_model = source_model
177- self.optimizer_args = optimizer_args
178- self.lr_scheduler_args = lr_scheduler_args
177+ # Ensure config_type "dict" so the component builder does not try to instantiate
178+ # optimizer/scheduler at config load time (params/optimizer are set at runtime).
179+ self.optimizer_args = ensure_config_type_dict(optimizer_args)
180+ self.lr_scheduler_args = ensure_config_type_dict(lr_scheduler_args)
179181 self.server_memory_gc_rounds = v.server_memory_gc_rounds
180182
181183 # Replace {num_rounds} placeholder if present in lr_scheduler_args
0 commit comments