Skip to content

Commit c871703

Browse files
fix symmetry
1 parent 875b1da commit c871703

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

rsl_rl/algorithms/ppo.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,14 @@ def __init__(
8383
if isinstance(symmetry_cfg["data_augmentation_func"], str):
8484
symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])
8585
# Check valid configuration
86-
if symmetry_cfg["use_data_augmentation"] and not callable(symmetry_cfg["data_augmentation_func"]):
86+
if not callable(symmetry_cfg["data_augmentation_func"]):
8787
raise ValueError(
88-
"Data augmentation enabled but the function is not callable:"
89-
f" {symmetry_cfg['data_augmentation_func']}"
88+
f"Symmetry configuration exists but the function is not callable: "
89+
f"{symmetry_cfg['data_augmentation_func']}"
9090
)
91+
# Check if the policy is compatible with symmetry
92+
if isinstance(policy, ActorCriticRecurrent):
93+
raise ValueError("Symmetry augmentation is not supported for recurrent policies.")
9194
# Store symmetry configuration
9295
self.symmetry = symmetry_cfg
9396
else:

0 commit comments

Comments
 (0)