Skip to content

Commit cc829d3

Browse files
Juntian Liufacebook-github-bot
authored andcommitted
Implemented Runtime Intermediate Output Extraction Based on Corresponding AOT Operators (#12212)
Summary: This PR adds functionality to extract the last n recorded values as runtime intermediate outputs, where n is the size of the intermediate output for the corresponding AOT operator during the final stage of mapping between AOT and runtime. By doing this, it adds support to delegate because currently, we log all the arguments of the delegate call. Differential Revision: D77712318
1 parent 756bf2f commit cc829d3

File tree

3 files changed

+70
-3
lines changed

3 files changed

+70
-3
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,19 @@ def map_runtime_aot_intermediate_outputs(
732732
# runtime follow the same format as aot, so it's safe to convert to tuple
733733
if isinstance(runtime_intermediate_output, list):
734734
runtime_intermediate_output = tuple(runtime_intermediate_output)
735+
736+
# Currently, runtime_intermediate_output logs all delegate call arguments.
737+
# Process here to extract only the outputs.
738+
if isinstance(aot_intermediate_output, tuple):
739+
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
740+
if isinstance(runtime_intermediate_output, tuple):
741+
runtime_intermediate_output = runtime_intermediate_output[
742+
-len(aot_intermediate_output) :
743+
]
744+
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
745+
elif isinstance(runtime_intermediate_output, tuple):
746+
runtime_intermediate_output = runtime_intermediate_output[-1]
747+
735748
# Create a mapping between runtime and aot
736749
aot_runtime_mapping[
737750
(aot_combined_debug_handle, aot_intermediate_output)

devtools/inspector/tests/inspector_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,15 +571,15 @@ def test_get_runtime_intermediate_outputs_and_op_names(self):
571571
self.assertIn((4,), runtime_outputs)
572572
self.assertIn((4,), op_names)
573573
self.assertTrue(
574-
torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
574+
torch.allclose(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
575575
)
576576
self.assertEqual(op_names[(4,)], "op_3")
577577

578578
# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
579579
for key in range(5, 9):
580580
self.assertIn((key,), runtime_outputs)
581581
self.assertIn((key,), op_names)
582-
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
582+
self.assertEqual(runtime_outputs[(key,)][0].size(0), RAW_DATA_SIZE)
583583
self.assertEqual(op_names[(key,)], f"op_{key-1}")
584584

585585
def test_calculate_numeric_gap(self):
@@ -659,7 +659,7 @@ def _gen_random_float_list(self) -> List[float]:
659659
def _gen_random_runtime_output(
660660
self,
661661
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
662-
return list(torch.randn(RAW_DATA_SIZE))
662+
return [torch.randn(RAW_DATA_SIZE)]
663663

664664
def _gen_random_events(self) -> List[Event]:
665665
events = []

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,60 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
343343
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
344344
self.assertEqual(actual, expected)
345345

346+
def test_map_runtime_aot_intermediate_outputs_delegated(self):
347+
# Currently, runtime_intermediate_output logs all delegate call arguments
348+
# Test that the map function correctly extracted out the delegated outputs
349+
aot_intermediate_outputs = {
350+
(1, 2): torch.tensor([4, 5]),
351+
(3, 4): torch.tensor([10, 11, 12]),
352+
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
353+
}
354+
runtime_intermediate_outputs = {
355+
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
356+
(3, 4): [
357+
torch.tensor([6, 7, 8, 9]),
358+
torch.tensor(1),
359+
torch.tensor([10, 11, 12]),
360+
],
361+
(5, 6): [
362+
torch.tensor([1]),
363+
torch.tensor([2]),
364+
torch.tensor([13, 14, 15, 16, 17]),
365+
],
366+
}
367+
actual = map_runtime_aot_intermediate_outputs(
368+
aot_intermediate_outputs, runtime_intermediate_outputs
369+
)
370+
expected = {
371+
((1, 2), torch.tensor([4, 5])): ((1, 2), torch.tensor([4, 5])),
372+
((3, 4), torch.tensor([10, 11, 12])): ((3, 4), torch.tensor([10, 11, 12])),
373+
((5, 6), torch.tensor([13, 14, 15, 16, 17])): (
374+
(5, 6),
375+
torch.tensor([13, 14, 15, 16, 17]),
376+
),
377+
}
378+
self.assertEqual(len(actual), len(expected))
379+
380+
for (exp_aot_key, exp_aot_value), (
381+
exp_runtime_key,
382+
exp_runtime_value,
383+
) in expected.items():
384+
found = False
385+
for (act_aot_key, act_aot_value), (
386+
act_runtime_key,
387+
act_runtime_value,
388+
) in actual.items():
389+
if exp_aot_key == act_aot_key and torch.allclose(
390+
exp_aot_value, act_aot_value
391+
):
392+
found = True
393+
self.assertEqual(exp_runtime_key, act_runtime_key)
394+
self.assertTrue(
395+
torch.allclose(exp_runtime_value, act_runtime_value)
396+
)
397+
break
398+
self.assertTrue(found)
399+
346400
def test_convert_input_to_tensor_convertible_inputs(self):
347401
# Scalar -> tensor
348402
actual_output1 = convert_to_float_tensor(5)

0 commit comments

Comments
 (0)