@@ -942,9 +942,7 @@ def false_fn(x):
942942 b = torch .randn (4 , requires_grad = True )
943943 c = torch .randn (4 , requires_grad = True )
944944
945- for pred , fn in zip (
946- [torch .tensor (False ), torch .tensor (True )], [false_fn , true_fn ]
947- ):
945+ for pred in [torch .tensor (False ), torch .tensor (True )]:
948946 with self .assertRaisesRegex (
949947 torch ._dynamo .exc .UncapturedHigherOrderOpError ,
950948 "Cond doesn't work unless it is captured completely with torch.compile" ,
@@ -3066,13 +3064,9 @@ def run_test_and_get_grads_loss(model, initial_hs, inputs):
30663064 ).to (DEVICE )
30673065
30683066 # Test 3 models: RNNScanList, RNNScanTensor, RNNLoop
3069- models = [
3070- ("ScanList" , RNNScanList ),
3071- ("ScanTensor" , RNNScanTensor ),
3072- ("Loop" , RNNLoop ),
3073- ]
3067+ models = [RNNScanList , RNNScanTensor , RNNLoop ]
30743068
3075- for model_name , model_class in models :
3069+ for model_class in models :
30763070 # Create uncompiled model
30773071 model_uc = model_class ().to (DEVICE )
30783072 uncompiled_grads , uncompiled_loss = run_test_and_get_grads_loss (
@@ -7538,7 +7532,7 @@ def foo(x):
75387532
75397533 inps = (torch .ones (3 , 4 ), torch .ones (3 , 5 ), torch .ones (5 , 4 ), torch .ones (5 , 3 ))
75407534 for inp in inps :
7541- gm = make_fx (foo , tracing_mode = "symbolic" )(torch . ones ( 3 , 4 ) )
7535+ gm = make_fx (foo , tracing_mode = "symbolic" )(inp )
75427536 self .assertExpectedInline (
75437537 gm .code .strip (),
75447538 """\
0 commit comments