Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,19 @@ def map_runtime_aot_intermediate_outputs(
# runtime follow the same format as aot, so it's safe to convert to tuple
if isinstance(runtime_intermediate_output, list):
runtime_intermediate_output = tuple(runtime_intermediate_output)

# Currently, runtime_intermediate_output logs all delegate call arguments.
# Process here to extract only the outputs.
if isinstance(aot_intermediate_output, tuple):
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
if isinstance(runtime_intermediate_output, tuple):
runtime_intermediate_output = runtime_intermediate_output[
-len(aot_intermediate_output) :
]
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
elif isinstance(runtime_intermediate_output, tuple):
runtime_intermediate_output = runtime_intermediate_output[-1]

# Create a mapping between runtime and aot
aot_runtime_mapping[
(aot_combined_debug_handle, aot_intermediate_output)
Expand Down
6 changes: 3 additions & 3 deletions devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,15 +571,15 @@ def test_get_runtime_intermediate_outputs_and_op_names(self):
self.assertIn((4,), runtime_outputs)
self.assertIn((4,), op_names)
self.assertTrue(
torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
torch.allclose(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
)
self.assertEqual(op_names[(4,)], "op_3")

# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
for key in range(5, 9):
self.assertIn((key,), runtime_outputs)
self.assertIn((key,), op_names)
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
self.assertEqual(runtime_outputs[(key,)][0].size(0), RAW_DATA_SIZE)
self.assertEqual(op_names[(key,)], f"op_{key-1}")

def test_calculate_numeric_gap(self):
Expand Down Expand Up @@ -659,7 +659,7 @@ def _gen_random_float_list(self) -> List[float]:
def _gen_random_runtime_output(
self,
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
return list(torch.randn(RAW_DATA_SIZE))
return [torch.randn(RAW_DATA_SIZE)]

def _gen_random_events(self) -> List[Event]:
events = []
Expand Down
54 changes: 54 additions & 0 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,60 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_delegated(self):
# Currently, runtime_intermediate_output logs all delegate call arguments
# Test that the map function correctly extracted out the delegated outputs
aot_intermediate_outputs = {
(1, 2): torch.tensor([4, 5]),
(3, 4): torch.tensor([10, 11, 12]),
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
}
runtime_intermediate_outputs = {
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
(3, 4): [
torch.tensor([6, 7, 8, 9]),
torch.tensor(1),
torch.tensor([10, 11, 12]),
],
(5, 6): [
torch.tensor([1]),
torch.tensor([2]),
torch.tensor([13, 14, 15, 16, 17]),
],
}
actual = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)
expected = {
((1, 2), torch.tensor([4, 5])): ((1, 2), torch.tensor([4, 5])),
((3, 4), torch.tensor([10, 11, 12])): ((3, 4), torch.tensor([10, 11, 12])),
((5, 6), torch.tensor([13, 14, 15, 16, 17])): (
(5, 6),
torch.tensor([13, 14, 15, 16, 17]),
),
}
self.assertEqual(len(actual), len(expected))

for (exp_aot_key, exp_aot_value), (
exp_runtime_key,
exp_runtime_value,
) in expected.items():
found = False
for (act_aot_key, act_aot_value), (
act_runtime_key,
act_runtime_value,
) in actual.items():
if exp_aot_key == act_aot_key and torch.allclose(
exp_aot_value, act_aot_value
):
found = True
self.assertEqual(exp_runtime_key, act_runtime_key)
self.assertTrue(
torch.allclose(exp_runtime_value, act_runtime_value)
)
break
self.assertTrue(found)

def test_convert_input_to_tensor_convertible_inputs(self):
# Scalar -> tensor
actual_output1 = convert_to_float_tensor(5)
Expand Down
Loading