@@ -1234,6 +1234,37 @@ def clone_contiguous(x):
1234
1234
with self .assertRaisesRegex (RuntimeError , msg ):
1235
1235
vmap (lambda x : x .clone (memory_format = torch .channels_last_3d ))(torch .randn (B0 ))
1236
1236
1237
+ @parametrize ("case" ,
1238
+ (
1239
+ (torch .clamp_min_ , TensorFactory .randn ),
1240
+ (torch .clamp_max_ , TensorFactory .randn ),
1241
+ ), name_fn = lambda x : x [0 ].__name__ )
1242
+ def test_clamp_inplace_variant (self , case ):
1243
+ test = self ._vmap_test
1244
+
1245
+ def get_number (getter ):
1246
+ return getter ([]).item ()
1247
+
1248
+ op , getter = case
1249
+ device = 'cpu'
1250
+ B0 , B1 = 7 , 11
1251
+
1252
+ # Single vmap: op(Tensor, Tensor)
1253
+ test (op , (getter ([B0 , 3 ], device ), getter ([B0 , 3 ], device )), check_propagates_grad = False )
1254
+ test (op , (getter ([B0 ], device ), getter ([B0 ], device )), check_propagates_grad = False )
1255
+ test (op , (getter ([2 , B0 , 3 ], device ), getter ([2 , B0 , 3 ], device )), in_dims = (1 , 1 ), check_propagates_grad = False )
1256
+ test (op , (getter ([B0 , 2 , 3 ], device ), getter ([2 , B0 , 3 ], device )),
1257
+ in_dims = (0 , 1 ), out_dims = 1 , check_propagates_grad = False )
1258
+ test (op , (getter ([B0 , 2 , 3 ], device ), getter ([1 , 1 ], device )), in_dims = (0 , None ), check_propagates_grad = False )
1259
+ test (op , (getter ([B0 , 3 ], device ), getter ([B0 , 3 ], device )), in_dims = (0 , 0 ), check_propagates_grad = False )
1260
+
1261
+ # Nested vmap: op(Tensor, Tensor)
1262
+ test (vmap (op ), (getter ([B0 , B1 , 2 , 3 ], device ), getter ([B0 , B1 , 1 , 3 ], device )), check_propagates_grad = False )
1263
+
1264
+ # Python number overload: op(Tensor, Number)
1265
+ number = get_number (getter )
1266
+ self ._test_unary (lambda t : op (t , number ), getter , device , check_propagates_grad = False )
1267
+
1237
1268
@parametrize ('case' , [
1238
1269
subtest (_make_case (torch .clamp_min ), name = 'clamp_min' ),
1239
1270
subtest (_make_case (torch .clamp_max ), name = 'clamp_max' ),
@@ -1255,7 +1286,7 @@ def get_number(getter):
1255
1286
test (op , (getter ([B0 ], device ), getter ([2 , B0 , 3 ], device )),
1256
1287
in_dims = (0 , 1 ), out_dims = 1 )
1257
1288
test (op , (getter ([B0 ], device ), getter ([2 , 3 ], device )), in_dims = (0 , None ))
1258
- test (op , (getter ([2 , 3 ], device ), getter ([B0 , 3 ], device )), in_dims = (0 , None ))
1289
+ test (op , (getter ([2 , 3 ], device ), getter ([B0 , 3 ], device )), in_dims = (None , 0 ))
1259
1290
1260
1291
# Nested vmap: op(Tensor, Tensor)
1261
1292
test (vmap (op ), (getter ([B0 , B1 , 2 , 3 ], device ), getter ([B0 , B1 , 3 ], device )))
@@ -3069,7 +3100,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
3069
3100
xfail ('hstack' ),
3070
3101
xfail ('linalg.multi_dot' ),
3071
3102
xfail ('nanmean' ),
3072
- xfail ('nn.functional.cosine_similarity' ),
3073
3103
xfail ('nn.functional.layer_norm' ),
3074
3104
xfail ('nn.functional.nll_loss' ),
3075
3105
xfail ('vstack' ),
0 commit comments