File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -83,11 +83,14 @@ def __init__(
8383 if isinstance (symmetry_cfg ["data_augmentation_func" ], str ):
8484 symmetry_cfg ["data_augmentation_func" ] = string_to_callable (symmetry_cfg ["data_augmentation_func" ])
8585 # Check valid configuration
86- if symmetry_cfg [ "use_data_augmentation" ] and not callable (symmetry_cfg ["data_augmentation_func" ]):
86+ if not callable (symmetry_cfg ["data_augmentation_func" ]):
8787 raise ValueError (
88- "Data augmentation enabled but the function is not callable:"
89- f" { symmetry_cfg ['data_augmentation_func' ]} "
88+ f"Symmetry configuration exists but the function is not callable: "
89+ f"{ symmetry_cfg ['data_augmentation_func' ]} "
9090 )
91+ # Check if the policy is compatible with symmetry
92+ if isinstance (policy , ActorCriticRecurrent ):
93+ raise ValueError ("Symmetry augmentation is not supported for recurrent policies." )
9194 # Store symmetry configuration
9295 self .symmetry = symmetry_cfg
9396 else :
You can’t perform that action at this time.
0 commit comments