@@ -762,32 +762,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
762762 This function handles the following types of input:
763763 - Scalar (int or float): Converts to a tensor with a single element.
764764 - Tensor: Converts to a float64 tensor on CPU.
765- - Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
766765 The resulting tensor is detached, moved to CPU, and cast to torch.float64.
767766 Parameters:
768- input_data (Any): The input data to be converted to a tensor. It can be a scalar,
769- a tensor, or a list of tensors .
767+ input_data (Any): The input data to be converted to a tensor. It can be a scalar
768+ or a tensor .
770769 Returns:
771770 torch.Tensor: A tensor on CPU with dtype torch.float64.
772771 Raises:
773772 ValueError: If the input_data cannot be converted to a tensor.
773+ AssertionError: If the input_data is a Sequence.
774774 """
775+ # Assert that the input is not a Sequence
776+ assert not isinstance (input_data , Sequence )
775777 try :
776- # Check if the input is a Sequence of tensors
777- if isinstance (input_data , Sequence ):
778- input_tensor = torch .stack ([convert_to_float_tensor (a ) for a in input_data ])
779778 # Try to convert the input to a tensor
780- else :
781- input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
779+ input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
782780 except Exception as e :
783781 raise ValueError (
784782 f"Cannot convert value of type { type (input_data )} to a tensor: { e } "
785783 )
786- input_tensor = input_tensor .detach ().cpu ().double ()
787784
785+ input_tensor = input_tensor .detach ().cpu ().double ()
788786 # Convert NaN to 0.0
789787 if torch .isnan (input_tensor ).any ():
790788 input_tensor = torch .nan_to_num (input_tensor )
789+
791790 return input_tensor
792791
793792
@@ -837,3 +836,33 @@ def find_op_names(
837836 result .append (op_name )
838837
839838 return result
839+
840+
841+ def compare_intermediate_outputs (a : Any , b : Any , comparator ) -> List [float ]:
842+ """
843+ Compare two outputs, handling both sequence and non-sequence cases,
844+ and return a list of comparison results.
845+ Parameters:
846+ a: The first intermediate output to compare.
847+ b: The second intermediate output to compare.
848+ comparator: A comparator object with a `compare` method.
849+ Returns:
850+ List[float]: A list of comparison results.
851+ Raises:
852+ ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
853+ """
854+ is_a_sequence = isinstance (a , Sequence )
855+ is_b_sequence = isinstance (b , Sequence )
856+ if is_a_sequence and is_b_sequence :
857+ # Ensure both sequences have the same length
858+ if len (a ) != len (b ):
859+ raise ValueError ("Sequences must have the same length for comparison." )
860+
861+ # Compare each element in the sequences and return the list of results
862+ return [comparator .compare (x , y ) for x , y in zip (a , b )]
863+ elif not is_a_sequence and not is_b_sequence :
864+ # Compare non-sequence items and return the result in a list
865+ return [comparator .compare (a , b )]
866+ else :
867+ # Raise an error if one is a sequence and the other is not
868+ raise ValueError ("Both inputs must be sequences or both must be non-sequences." )
0 commit comments