File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -83,11 +83,14 @@ def __init__(
8383            if  isinstance (symmetry_cfg ["data_augmentation_func" ], str ):
8484                symmetry_cfg ["data_augmentation_func" ] =  string_to_callable (symmetry_cfg ["data_augmentation_func" ])
8585            # Check valid configuration 
86-             if  symmetry_cfg [ "use_data_augmentation" ]  and   not  callable (symmetry_cfg ["data_augmentation_func" ]):
86+             if  not  callable (symmetry_cfg ["data_augmentation_func" ]):
8787                raise  ValueError (
88-                     "Data augmentation enabled  but the function is not callable:"
89-                     f"  { symmetry_cfg ['data_augmentation_func' ]}  " 
88+                     f"Symmetry configuration exists  but the function is not callable:  "
89+                     f"{ symmetry_cfg ['data_augmentation_func' ]}  " 
9090                )
91+             # Check if the policy is compatible with symmetry 
92+             if  isinstance (policy , ActorCriticRecurrent ):
93+                 raise  ValueError ("Symmetry augmentation is not supported for recurrent policies." )
9194            # Store symmetry configuration 
9295            self .symmetry  =  symmetry_cfg 
9396        else :
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments