@@ -770,32 +770,29 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
770770 This function handles the following types of input:
771771 - Scalar (int or float): Converts to a tensor with a single element.
772772 - Tensor: Converts to a float64 tensor on CPU.
773- - Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
774773 The resulting tensor is detached, moved to CPU, and cast to torch.float64.
775774 Parameters:
776- input_data (Any): The input data to be converted to a tensor. It can be a scalar,
777- a tensor, or a list of tensors .
775+ input_data (Any): The input data to be converted to a tensor. It can be a scalar
776+ or a tensor .
778777 Returns:
779778 torch.Tensor: A tensor on CPU with dtype torch.float64.
780- Raises:
781- ValueError: If the input_data cannot be converted to a tensor.
779+ Raises error if the input is not a scalar or a tensor
782780 """
781+ # Assert that the input is not a Sequence
782+ assert not isinstance (input_data , Sequence )
783783 try :
784- # Check if the input is a Sequence of tensors
785- if isinstance (input_data , Sequence ):
786- input_tensor = torch .stack ([convert_to_float_tensor (a ) for a in input_data ])
787784 # Try to convert the input to a tensor
788- else :
789- input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
785+ input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
790786 except Exception as e :
791787 raise ValueError (
792788 f"Cannot convert value of type { type (input_data )} to a tensor: { e } "
793789 )
794- input_tensor = input_tensor .detach ().cpu ().double ()
795790
791+ input_tensor = input_tensor .detach ().cpu ().double ()
796792 # Convert NaN to 0.0
797793 if torch .isnan (input_tensor ).any ():
798794 input_tensor = torch .nan_to_num (input_tensor )
795+
799796 return input_tensor
800797
801798
@@ -845,3 +842,33 @@ def find_op_names(
845842 result .append (op_name )
846843
847844 return result
845+
846+
847+ def compare_intermediate_outputs (a : Any , b : Any , comparator ) -> List [float ]:
848+ """
849+ Compare two outputs, handling both sequence and non-sequence cases,
850+ and return a list of comparison results.
851+ Parameters:
852+ a: The first intermediate output to compare.
853+ b: The second intermediate output to compare.
854+ comparator: A comparator object with a `compare` method.
855+ Returns:
856+ List[float]: A list of comparison results.
857+ Raises:
858+ ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
859+ """
860+ is_a_sequence = isinstance (a , Sequence )
861+ is_b_sequence = isinstance (b , Sequence )
862+ if is_a_sequence and is_b_sequence :
863+ # Ensure both sequences have the same length
864+ if len (a ) != len (b ):
865+ raise ValueError ("Sequences must have the same length for comparison." )
866+
867+ # Compare each element in the sequences and return the list of results
868+ return [comparator .compare (x , y ) for x , y in zip (a , b )]
869+ elif not is_a_sequence and not is_b_sequence :
870+ # Compare non-sequence items and return the result in a list
871+ return [comparator .compare (a , b )]
872+ else :
873+ # Raise an error if one is a sequence and the other is not
874+ raise ValueError ("Both inputs must be sequences or both must be non-sequences." )
0 commit comments