diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 4a539f909..212f8831a 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -99,6 +99,12 @@ def from_dict(cls, d: Mapping):
         valid_args = {k: v for k, v in d.items() if k in all_fields}
         return cls(**valid_args)
 
+    def asdict(self):
+        # asdict() turns guided_decoding into a Dict; restore the original object
+        ret = asdict(self)
+        ret["guided_decoding"] = self.guided_decoding
+        return ret
+
 
 @dataclass
 class EngineConfig(EngineArgs):
@@ -254,7 +260,7 @@ async def setup(self):
 
         # Setup sampling params
         self.sampling_params = get_default_sampling_params(
-            self.vllm_config, overrides=asdict(self.sampling_config)
+            self.vllm_config, overrides=self.sampling_config.asdict()
         )
 
         # Setup processors