| 
29 | 29 |     calculate_mse,  | 
30 | 30 |     calculate_snr,  | 
31 | 31 |     calculate_time_scale_factor,  | 
 | 32 | +    convert_to_float_tensor,  | 
32 | 33 |     create_debug_handle_to_op_node_mapping,  | 
33 | 34 |     EDGE_DIALECT_GRAPH_KEY,  | 
34 | 35 |     find_populated_event,  | 
@@ -317,6 +318,52 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):  | 
317 | 318 |         expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}  | 
318 | 319 |         self.assertEqual(actual, expected)  | 
319 | 320 | 
 
  | 
 | 321 | +    def test_convert_input_to_tensor_convertible_inputs(self):  | 
 | 322 | +        # Scalar -> tensor  | 
 | 323 | +        actual_output1 = convert_to_float_tensor(5)  | 
 | 324 | +        self.assertIsInstance(actual_output1, torch.Tensor)  | 
 | 325 | +        self.assertEqual(actual_output1.dtype, torch.float64)  | 
 | 326 | +        self.assertEqual(tuple(actual_output1.shape), ())  | 
 | 327 | +        self.assertTrue(  | 
 | 328 | +            torch.allclose(actual_output1, torch.tensor([5.0], dtype=torch.float64))  | 
 | 329 | +        )  | 
 | 330 | +        self.assertEqual(actual_output1.device.type, "cpu")  | 
 | 331 | + | 
 | 332 | +        # Tensor of ints -> float32 CPU  | 
 | 333 | +        t_int = torch.tensor([4, 5, 6], dtype=torch.int32)  | 
 | 334 | +        actual_output2 = convert_to_float_tensor(t_int)  | 
 | 335 | +        self.assertIsInstance(actual_output2, torch.Tensor)  | 
 | 336 | +        self.assertEqual(actual_output2.dtype, torch.float64)  | 
 | 337 | +        self.assertTrue(  | 
 | 338 | +            torch.allclose(  | 
 | 339 | +                actual_output2, torch.tensor([4.0, 5.0, 6.0], dtype=torch.float64)  | 
 | 340 | +            )  | 
 | 341 | +        )  | 
 | 342 | +        self.assertEqual(actual_output2.device.type, "cpu")  | 
 | 343 | + | 
 | 344 | +        # List of tensors -> stacked tensor float32 CPU  | 
 | 345 | +        t_list = [torch.tensor([1, 2]), torch.tensor([2, 3]), torch.tensor([3, 4])]  | 
 | 346 | +        actual_output3 = convert_to_float_tensor(t_list)  | 
 | 347 | +        self.assertIsInstance(actual_output3, torch.Tensor)  | 
 | 348 | +        self.assertEqual(actual_output3.dtype, torch.float64)  | 
 | 349 | +        self.assertEqual(tuple(actual_output3.shape), (3, 2))  | 
 | 350 | +        self.assertTrue(  | 
 | 351 | +            torch.allclose(  | 
 | 352 | +                actual_output3,  | 
 | 353 | +                torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float64),  | 
 | 354 | +            )  | 
 | 355 | +        )  | 
 | 356 | +        self.assertEqual(actual_output3.device.type, "cpu")  | 
 | 357 | + | 
 | 358 | +    def test_convert_input_to_tensor_non_convertible_raises(self):  | 
 | 359 | +        class X:  | 
 | 360 | +            pass  | 
 | 361 | + | 
 | 362 | +        with self.assertRaises(ValueError) as cm:  | 
 | 363 | +            convert_to_float_tensor(X())  | 
 | 364 | +        msg = str(cm.exception)  | 
 | 365 | +        self.assertIn("Cannot convert value of type", msg)  | 
 | 366 | + | 
320 | 367 | 
 
  | 
321 | 368 | def gen_mock_operator_graph_with_expected_map() -> (  | 
322 | 369 |     Tuple[OperatorGraph, Dict[int, OperatorNode]]  | 
 | 
0 commit comments