Skip to content

Commit 56f2cda

Browse files
Juntian Liufacebook-github-bot
authored andcommitted
Add function for input preprocessing in numerical comparator (#11739)
Summary: Pull Request resolved: #11739 This PR adds the process_input_to_tensor function to convert inputs to torch.Tensor on CPU with torch.float32 dtype for preprocessing inputs used later in the numerical comparator. Differential Revision: D76745314
1 parent 3a6c664 commit 56f2cda

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,3 +690,38 @@ def map_runtime_aot_intermediate_outputs(
690690
)
691691

692692
return aot_runtime_mapping
693+
694+
695+
def process_input_to_tensor(input_data: Any) -> torch.Tensor:
696+
"""
697+
Convert input_data into a torch.Tensor on CPU with dtype torch.float32.
698+
This function handles the following types of input:
699+
- Scalar (int or float): Converts to a 1D tensor with a single element.
700+
- Tensor: Converts to a float32 tensor on CPU.
701+
- List of Tensors: Stacks the tensors into a single float32 tensor on CPU.
702+
The resulting tensor is detached, moved to CPU, and cast to torch.float32.
703+
Parameters:
704+
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
705+
a tensor, or a list of tensors.
706+
Returns:
707+
torch.Tensor: A tensor on CPU with dtype torch.float32.
708+
Raises:
709+
ValueError: If the input_data cannot be converted to a tensor.
710+
"""
711+
try:
712+
# Check if the input is a scalar
713+
if isinstance(input_data, (int, float)):
714+
input_tensor = torch.tensor([input_data], dtype=torch.float32)
715+
# Check if the input is a list of tensors
716+
elif isinstance(input_data, list) and all(
717+
isinstance(i, torch.Tensor) for i in input_data
718+
):
719+
input_tensor = torch.stack(input_data).to(torch.float32)
720+
else:
721+
input_tensor = torch.as_tensor(input_data, dtype=torch.float32)
722+
except Exception as e:
723+
raise ValueError(
724+
f"Cannot convert value of type {type(input_data)} to a tensor: {e}"
725+
)
726+
input_tensor = input_tensor.detach().cpu().float()
727+
return input_tensor

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_inference_output_equal,
3737
map_runtime_aot_intermediate_outputs,
3838
merge_overlapping_debug_handles,
39+
process_input_to_tensor,
3940
TimeScale,
4041
)
4142

@@ -317,6 +318,43 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
317318
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
318319
self.assertEqual(actual, expected)
319320

321+
def test_process_input_convertible_inputs(self):
322+
# Scalar -> tensor
323+
actual_output1 = process_input_to_tensor(5)
324+
self.assertIsInstance(actual_output1, torch.Tensor)
325+
self.assertEqual(actual_output1.dtype, torch.float32)
326+
self.assertEqual(tuple(actual_output1.shape), (1,))
327+
self.assertTrue(torch.allclose(actual_output1, torch.tensor([5.0])))
328+
self.assertEqual(actual_output1.device.type, "cpu")
329+
330+
# Tensor of ints -> float32 CPU
331+
t_int = torch.tensor([4, 5, 6], dtype=torch.int32)
332+
actual_output2 = process_input_to_tensor(t_int)
333+
self.assertIsInstance(actual_output2, torch.Tensor)
334+
self.assertEqual(actual_output2.dtype, torch.float32)
335+
self.assertTrue(torch.allclose(actual_output2, torch.tensor([4.0, 5.0, 6.0])))
336+
self.assertEqual(actual_output2.device.type, "cpu")
337+
338+
# List of tensors -> stacked tensor float32 CPU
339+
t_list = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])]
340+
actual_output3 = process_input_to_tensor(t_list)
341+
self.assertIsInstance(actual_output3, torch.Tensor)
342+
self.assertEqual(actual_output3.dtype, torch.float32)
343+
self.assertEqual(tuple(actual_output3.shape), (3, 1))
344+
self.assertTrue(
345+
torch.allclose(actual_output3, torch.tensor([[1.0], [2.0], [3.0]]))
346+
)
347+
self.assertEqual(actual_output3.device.type, "cpu")
348+
349+
def test_process_input_tensor_non_convertible_raises(self):
350+
class X:
351+
pass
352+
353+
with self.assertRaises(ValueError) as cm:
354+
process_input_to_tensor(X())
355+
msg = str(cm.exception)
356+
self.assertIn("Cannot convert value of type", msg)
357+
320358

321359
def gen_mock_operator_graph_with_expected_map() -> (
322360
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)