@@ -762,32 +762,29 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
762762 This function handles the following types of input:
763763 - Scalar (int or float): Converts to a tensor with a single element.
764764 - Tensor: Converts to a float64 tensor on CPU.
765- - Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
766765 The resulting tensor is detached, moved to CPU, and cast to torch.float64.
767766 Parameters:
768- input_data (Any): The input data to be converted to a tensor. It can be a scalar,
769- a tensor, or a list of tensors .
767+ input_data (Any): The input data to be converted to a tensor. It can be a scalar
768+ or a tensor .
770769 Returns:
771770 torch.Tensor: A tensor on CPU with dtype torch.float64.
772- Raises:
773- ValueError: If the input_data cannot be converted to a tensor.
771+ Raises error if the input is not a scalar or a tensor
774772 """
773+ # Assert that the input is not a Sequence
774+ assert not isinstance (input_data , Sequence )
775775 try :
776- # Check if the input is a Sequence of tensors
777- if isinstance (input_data , Sequence ):
778- input_tensor = torch .stack ([convert_to_float_tensor (a ) for a in input_data ])
779776 # Try to convert the input to a tensor
780- else :
781- input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
777+ input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
782778 except Exception as e :
783779 raise ValueError (
784780 f"Cannot convert value of type { type (input_data )} to a tensor: { e } "
785781 )
786- input_tensor = input_tensor .detach ().cpu ().double ()
787782
783+ input_tensor = input_tensor .detach ().cpu ().double ()
788784 # Convert NaN to 0.0
789785 if torch .isnan (input_tensor ).any ():
790786 input_tensor = torch .nan_to_num (input_tensor )
787+
791788 return input_tensor
792789
793790
@@ -837,3 +834,33 @@ def find_op_names(
837834 result .append (op_name )
838835
839836 return result
837+
838+
839+ def compare_intermediate_outputs (a : Any , b : Any , comparator ) -> List [float ]:
840+ """
841+ Compare two outputs, handling both sequence and non-sequence cases,
842+ and return a list of comparison results.
843+ Parameters:
844+ a: The first intermediate output to compare.
845+ b: The second intermediate output to compare.
846+ comparator: A comparator object with a `compare` method.
847+ Returns:
848+ List[float]: A list of comparison results.
849+ Raises:
850+ ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
851+ """
852+ is_a_sequence = isinstance (a , Sequence )
853+ is_b_sequence = isinstance (b , Sequence )
854+ if is_a_sequence and is_b_sequence :
855+ # Ensure both sequences have the same length
856+ if len (a ) != len (b ):
857+ raise ValueError ("Sequences must have the same length for comparison." )
858+
859+ # Compare each element in the sequences and return the list of results
860+ return [comparator .compare (x , y ) for x , y in zip (a , b )]
861+ elif not is_a_sequence and not is_b_sequence :
862+ # Compare non-sequence items and return the result in a list
863+ return [comparator .compare (a , b )]
864+ else :
865+ # Raise an error if one is a sequence and the other is not
866+ raise ValueError ("Both inputs must be sequences or both must be non-sequences." )
0 commit comments