@@ -2875,19 +2875,35 @@ def test_data_accessor(n_envs, batched, tol):
28752875 # * Call 'Get' -> Call 'Set' with 'Get' output -> Call 'Get'
28762876 # * Compare first 'Get' output with last 'Get' output
28772877 # * Compare last 'Get' output with corresponding slice of non-masking 'Get' output
2878- def get_all_supported_masks (i ):
2878+ def get_all_supported_masks (i , max_length ):
2879+ if max_length <= 0 or i > max_length - 1 :
2880+ return (None ,)
2881+ if i == max_length - 1 :
2882+ return (
2883+ i ,
2884+ [i ],
2885+ slice (i , i + 1 ),
2886+ range (i , i + 1 ),
2887+ np .array ([i ], dtype = np .int32 ),
2888+ torch .tensor ([i ], dtype = torch .int64 ),
2889+ torch .tensor ([i ], dtype = gs .tc_int , device = gs .device ),
2890+ )
28792891 return (
2880- i ,
2881- [i ],
2882- slice (i , i + 1 ),
2883- range (i , i + 1 ),
2884- np .array ([i ], dtype = np .int32 ),
2885- torch .tensor ([i ], dtype = torch .int64 ),
2886- torch .tensor ([i ], dtype = gs .tc_int , device = gs .device ),
2892+ [i , i + 1 ],
2893+ slice (i , i + 2 ),
2894+ range (i , i + 2 ),
2895+ np .array ([i , i + 1 ], dtype = np .int32 ),
2896+ torch .tensor ([i , i + 1 ], dtype = torch .int64 ),
2897+ torch .tensor ([i , i + 1 ], dtype = gs .tc_int , device = gs .device ),
28872898 )
28882899
2889- def must_cast (value ):
2890- return not (isinstance (value , torch .Tensor ) and value .dtype == gs .tc_int and value .device == gs .device )
2900+ def must_cast (value , dtype ):
2901+ return not (
2902+ isinstance (value , torch .Tensor )
2903+ and value .is_contiguous ()
2904+ and value .dtype == dtype
2905+ and value .device == gs .device
2906+ )
28912907
28922908 for arg1_max , arg2_max , getter_or_spec , setter , ti_data in (
28932909 # SOLVER
@@ -3010,66 +3026,73 @@ def must_cast(value):
30103026
30113027 # Check getter and setter for all possible combinations of row and column masking
30123028 for i in range (arg1_max ) if arg1_max > 0 else (None ,):
3013- for arg1 in get_all_supported_masks (i ) if arg1_max > 0 else (None ,):
3029+ if i is not None :
3030+ mask_i = [i , i + 1 ] if i < arg1_max - 1 else [i ]
3031+ for arg1 in get_all_supported_masks (i , arg1_max ):
30143032 for j in range (max (arg2_max , 1 )) if arg2_max >= 0 else (None ,):
3015- for arg2 in get_all_supported_masks (j ) if arg2_max > 0 else (None ,):
3033+ if j is not None :
3034+ mask_j = [j , j + 1 ] if j < arg2_max - 1 else [j ]
3035+ for arg2 in get_all_supported_masks (j , arg2_max ):
30163036 if arg1 is None and arg2 is not None :
3017- unsafe = not must_cast (arg2 )
3037+ unsafe = not must_cast (arg2 , gs . tc_int )
30183038 if getter is not None :
30193039 data = deepcopy (getter (arg2 , unsafe = unsafe ))
30203040 else :
30213041 if is_tuple :
3022- data = [torch .ones ((1 , * shape )) for shape in spec ]
3042+ data = [torch .ones ((len ( mask_j ) , * shape )) for shape in spec ]
30233043 else :
3024- data = torch .ones ((1 , * spec ))
3044+ data = torch .ones ((len ( mask_j ) , * spec ))
30253045 if setter is not None :
3046+ unsafe &= not must_cast (data , gs .tc_float )
30263047 setter (data , arg2 , unsafe = unsafe )
30273048 if n_envs :
30283049 if is_tuple :
3029- data_ = [val [[ j ] ] for val in datas ]
3050+ data_ = [val [mask_j ] for val in datas ]
30303051 else :
3031- data_ = datas [[ j ] ]
3052+ data_ = datas [mask_j ]
30323053 else :
30333054 data_ = datas
30343055 elif arg1 is not None and arg2 is None :
3035- unsafe = not must_cast (arg1 )
3056+ unsafe = not must_cast (arg1 , gs . tc_int )
30363057 if getter is not None :
30373058 data = deepcopy (getter (arg1 , unsafe = unsafe ))
30383059 else :
30393060 if is_tuple :
3040- data = [torch .ones ((1 , * shape )) for shape in spec ]
3061+ data = [torch .ones ((len ( mask_i ) , * shape )) for shape in spec ]
30413062 else :
3042- data = torch .ones ((1 , * spec ))
3063+ data = torch .ones ((len ( mask_i ) , * spec ))
30433064 if setter is not None :
3065+ unsafe &= not must_cast (data , gs .tc_float )
30443066 if is_tuple :
30453067 setter (* data , arg1 , unsafe = unsafe )
30463068 else :
30473069 setter (data , arg1 , unsafe = unsafe )
30483070 if is_tuple :
3049- data_ = [val [[ i ] ] for val in datas ]
3071+ data_ = [val [mask_i ] for val in datas ]
30503072 else :
3051- data_ = datas [[ i ] ]
3073+ data_ = datas [mask_i ]
30523074 else :
3053- unsafe = not any (map ( must_cast , (arg1 , arg2 ) ))
3075+ unsafe = not any (must_cast ( arg , gs . tc_int ) for arg in (arg1 , arg2 ))
30543076 if getter is not None :
30553077 data = deepcopy (getter (arg1 , arg2 , unsafe = unsafe ))
30563078 else :
30573079 if is_tuple :
3058- data = [torch .ones ((1 , 1 , * shape )) for shape in spec ]
3080+ data = [torch .ones ((len ( mask_j ), len ( mask_i ) , * shape )) for shape in spec ]
30593081 else :
3060- data = torch .ones ((1 , 1 , * spec ))
3082+ data = torch .ones ((len ( mask_j ), len ( mask_i ) , * spec ))
30613083 if setter is not None :
3084+ unsafe &= not must_cast (data , gs .tc_float )
30623085 setter (data , arg1 , arg2 , unsafe = unsafe )
30633086 if is_tuple :
3064- data_ = [val [[ j ] , :][:, [ i ] ] for val in datas ]
3087+ data_ = [val [mask_j , :][:, mask_i ] for val in datas ]
30653088 else :
3066- data_ = datas [[ j ] , :][:, [ i ] ]
3089+ data_ = datas [mask_j , :][:, mask_i ]
30673090 # FIXME: Not sure why tolerance must be increased for tests to pass
30683091 assert_allclose (data_ , data , tol = (5.0 * tol ))
30693092
3070- for dofs_idx in (* get_all_supported_masks (0 ), None ):
3071- for envs_idx in (* (get_all_supported_masks (0 ) if n_envs > 0 else ()), None ):
3072- unsafe = not any (map ( must_cast , (dofs_idx , envs_idx ) ))
3093+ for dofs_idx in (* get_all_supported_masks (0 , gs_s . n_dofs ), None ):
3094+ for envs_idx in (* (get_all_supported_masks (0 , gs_s . n_dofs ) if n_envs > 0 else ()), None ):
3095+ unsafe = not any (must_cast ( arg , gs . tc_int ) for arg in (dofs_idx , envs_idx ))
30733096 dofs_pos = gs_s .get_dofs_position (dofs_idx , envs_idx )
30743097 dofs_vel = gs_s .get_dofs_velocity (dofs_idx , envs_idx )
30753098 gs_s .control_dofs_position (dofs_pos , dofs_idx , envs_idx )
0 commit comments