@@ -1169,48 +1169,9 @@ def fn():
11691169 yield model [2 ].bias .grad
11701170 model .zero_grad ()
11711171
1172+ # TODO(jansel): we should be able to get this count to 1
11721173 self .check_output_and_recompiles (fn )
11731174
1174- def test_dynamic_shapes_from_forward (self ):
1175- class ToyModel (nn .Module ):
1176- def __init__ (self , in_feat = 10 , hidden_feat = 50 , out_feat = 5 ):
1177- super ().__init__ ()
1178- self .linear1 = nn .Linear (in_feat , hidden_feat )
1179- self .linear2 = nn .Linear (hidden_feat , hidden_feat )
1180- self .linear3 = nn .Linear (hidden_feat , out_feat )
1181- self .mse_loss = torch .nn .MSELoss ()
1182-
1183- def forward (self , inputs , output ):
1184- out1 = self .linear1 (inputs )
1185- out2 = self .linear2 (out1 )
1186- out3 = self .linear3 (out2 )
1187- return self .mse_loss (out3 , output )
1188-
1189- m = ToyModel ()
1190- m = torch .compile (m )
1191-
1192- def run (i ):
1193- torch ._dynamo .utils .counters .clear ()
1194- inp = torch .randn (i , 10 )
1195- target = torch .randn (i , 5 )
1196- loss = m (inp , target )
1197- with compiled_autograd ._enable (make_compiler_fn (dynamic = None )):
1198- loss .backward ()
1199-
1200- counters = torch ._dynamo .utils .counters
1201- run (3 )
1202- self .assertEqual (counters ["compiled_autograd" ]["captures" ], 1 )
1203- self .assertEqual (counters ["compiled_autograd" ]["compiles" ], 1 )
1204- run (4 )
1205- self .assertEqual (counters ["compiled_autograd" ]["captures" ], 1 )
1206- self .assertEqual (counters ["compiled_autograd" ]["compiles" ], 1 )
1207- run (5 )
1208- self .assertEqual (counters ["compiled_autograd" ]["captures" ], 0 )
1209- self .assertEqual (counters ["compiled_autograd" ]["compiles" ], 1 ) # should be 0
1210- run (6 )
1211- self .assertEqual (counters ["compiled_autograd" ]["captures" ], 0 )
1212- self .assertEqual (counters ["compiled_autograd" ]["compiles" ], 0 )
1213-
12141175 def test_dynamic_shapes_eager_node (self ):
12151176 # Here, we have no way of marking the symbolic sizes using in SumBackward as dynamic
12161177 def fn ():
@@ -3395,21 +3356,13 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
33953356 getitem_22 = sizes[17]
33963357 getitem_23 = sizes[18]
33973358 getitem_24 = sizes[19]; sizes = None
3398- unwrap_maybe_dynamic_int = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_5); getitem_5 = None
3399- unwrap_maybe_dynamic_int_1 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_6); getitem_6 = None
3400- unwrap_maybe_dynamic_int_2 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_7); getitem_7 = None
3401- unwrap_maybe_dynamic_int_3 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_8); getitem_8 = None
3402- unwrap_maybe_dynamic_int_16 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_21); getitem_21 = None
3403- unwrap_maybe_dynamic_int_17 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_22); getitem_22 = None
3404- unwrap_maybe_dynamic_int_18 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_23); getitem_23 = None
3405- unwrap_maybe_dynamic_int_19 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_24); getitem_24 = None
34063359
34073360 validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
34083361 getitem_25 = validate_outputs[0]; validate_outputs = None
34093362
3410- sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_25], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1 ]); getitem_25 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
3363+ sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_25], [True], [getitem_5, getitem_6 ]); getitem_25 = getitem_5 = getitem_6 = None
34113364 getitem_26 = sum_backward0[0]; sum_backward0 = None
3412- validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3 ], True)]); getitem_26 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
3365+ validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_26], [((None, None, device(type='cpu'), 6, 0, None), [getitem_7, getitem_8 ], True)]); getitem_26 = getitem_7 = getitem_8 = None
34133366 getitem_27 = validate_outputs_1[0]; validate_outputs_1 = None
34143367
34153368 getitem_28 = hooks[0]; getitem_28 = None
@@ -3431,7 +3384,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
34313384 call_backward = torch__dynamo_external_utils_call_backward(getitem_33, (), make_subclass); getitem_33 = make_subclass = None
34323385 getitem_36 = call_backward[0]
34333386 getitem_37 = call_backward[1]; call_backward = None
3434- validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17 ], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19 ], False)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None
3387+ validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [getitem_21, getitem_22 ], False), ((None, None, device(type='cpu'), 6, 0, None), [getitem_23, getitem_24 ], False)]); getitem_36 = getitem_37 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None
34353388 getitem_39 = validate_outputs_2[0]
34363389
34373390 accumulate_grad__default_1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_39); getitem_4 = getitem_39 = accumulate_grad__default_1 = None
@@ -3648,25 +3601,13 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
36483601 getitem_11 = sizes[9]
36493602 getitem_12 = sizes[10]
36503603 getitem_13 = sizes[11]; sizes = None
3651- unwrap_maybe_dynamic_int = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_2); getitem_2 = None
3652- unwrap_maybe_dynamic_int_1 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_3); getitem_3 = None
3653- unwrap_maybe_dynamic_int_2 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_4); getitem_4 = None
3654- unwrap_maybe_dynamic_int_3 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_5); getitem_5 = None
3655- unwrap_maybe_dynamic_int_4 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_6); getitem_6 = None
3656- unwrap_maybe_dynamic_int_5 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_7); getitem_7 = None
3657- unwrap_maybe_dynamic_int_6 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_8); getitem_8 = None
3658- unwrap_maybe_dynamic_int_7 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_9); getitem_9 = None
3659- unwrap_maybe_dynamic_int_8 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_10); getitem_10 = None
3660- unwrap_maybe_dynamic_int_9 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_11); getitem_11 = None
3661- unwrap_maybe_dynamic_int_10 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_12); getitem_12 = None
3662- unwrap_maybe_dynamic_int_11 = torch__dynamo_external_utils_unwrap_maybe_dynamic_int(getitem_13); getitem_13 = None
36633604
36643605 validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
36653606 getitem_14 = validate_outputs[0]; validate_outputs = None
36663607
3667- sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_14], [True], [unwrap_maybe_dynamic_int, unwrap_maybe_dynamic_int_1 ]); getitem_14 = unwrap_maybe_dynamic_int = unwrap_maybe_dynamic_int_1 = None
3608+ sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_14], [True], [getitem_2, getitem_3 ]); getitem_14 = getitem_2 = getitem_3 = None
36683609 getitem_15 = sum_backward0[0]; sum_backward0 = None
3669- validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_2, unwrap_maybe_dynamic_int_3 ], False)]); getitem_15 = unwrap_maybe_dynamic_int_2 = unwrap_maybe_dynamic_int_3 = None
3610+ validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [getitem_4, getitem_5 ], False)]); getitem_15 = getitem_4 = getitem_5 = None
36703611 getitem_16 = validate_outputs_1[0]; validate_outputs_1 = None
36713612
36723613 getitem_17 = hooks[0]
@@ -3678,7 +3619,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
36783619 mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_16], [True, True], call_hook, 6, call_hook_1, 6); getitem_16 = call_hook = call_hook_1 = None
36793620 getitem_21 = mul_backward0[0]
36803621 getitem_22 = mul_backward0[1]; mul_backward0 = None
3681- validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_4, unwrap_maybe_dynamic_int_5 ], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_6, unwrap_maybe_dynamic_int_7 ], False)]); getitem_21 = getitem_22 = unwrap_maybe_dynamic_int_4 = unwrap_maybe_dynamic_int_5 = unwrap_maybe_dynamic_int_6 = unwrap_maybe_dynamic_int_7 = None
3622+ validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_21, getitem_22], [((None, None, device(type='cpu'), 6, 0, None), [getitem_6, getitem_7 ], False), ((None, None, device(type='cpu'), 6, 0, None), [getitem_8, getitem_9 ], False)]); getitem_21 = getitem_22 = getitem_6 = getitem_7 = getitem_8 = getitem_9 = None
36823623 getitem_23 = validate_outputs_2[0]
36833624 getitem_24 = validate_outputs_2[1]; validate_outputs_2 = None
36843625
@@ -3687,7 +3628,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
36873628 call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_25, getitem_26, hook_type = 'unpack_hook'); getitem_25 = getitem_26 = None
36883629 cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_24], [True], call_hook_2); getitem_24 = call_hook_2 = None
36893630 getitem_27 = cos_backward0[0]; cos_backward0 = None
3690- validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_8, unwrap_maybe_dynamic_int_9 ], False)]); getitem_27 = unwrap_maybe_dynamic_int_8 = unwrap_maybe_dynamic_int_9 = None
3631+ validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_27], [((None, None, device(type='cpu'), 6, 0, None), [getitem_10, getitem_11 ], False)]); getitem_27 = getitem_10 = getitem_11 = None
36913632 getitem_28 = validate_outputs_3[0]; validate_outputs_3 = None
36923633 add = torch.add(getitem_23, getitem_28); getitem_23 = getitem_28 = None
36933634
@@ -3696,7 +3637,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data):
36963637 call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_29, getitem_30, hook_type = 'unpack_hook'); getitem_29 = getitem_30 = None
36973638 sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
36983639 getitem_31 = sin_backward0[0]; sin_backward0 = None
3699- validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11 ], False)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None
3640+ validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [getitem_12, getitem_13 ], False)]); getitem_31 = getitem_12 = getitem_13 = None
37003641 getitem_32 = validate_outputs_4[0]; validate_outputs_4 = None
37013642
37023643 accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_32); getitem_1 = getitem_32 = accumulate_grad__default = None
0 commit comments