@@ -770,32 +770,29 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
770
770
This function handles the following types of input:
771
771
- Scalar (int or float): Converts to a tensor with a single element.
772
772
- Tensor: Converts to a float64 tensor on CPU.
773
- - Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
774
773
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
775
774
Parameters:
776
- input_data (Any): The input data to be converted to a tensor. It can be a scalar,
777
- a tensor, or a list of tensors .
775
+ input_data (Any): The input data to be converted to a tensor. It can be a scalar
776
+ or a tensor .
778
777
Returns:
779
778
torch.Tensor: A tensor on CPU with dtype torch.float64.
780
- Raises:
781
- ValueError: If the input_data cannot be converted to a tensor.
779
+ Raises error if the input is not a scalar or a tensor
782
780
"""
781
+ # Assert that the input is not a Sequence
782
+ assert not isinstance (input_data , Sequence )
783
783
try :
784
- # Check if the input is a Sequence of tensors
785
- if isinstance (input_data , Sequence ):
786
- input_tensor = torch .stack ([convert_to_float_tensor (a ) for a in input_data ])
787
784
# Try to convert the input to a tensor
788
- else :
789
- input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
785
+ input_tensor = torch .as_tensor (input_data , dtype = torch .float64 )
790
786
except Exception as e :
791
787
raise ValueError (
792
788
f"Cannot convert value of type { type (input_data )} to a tensor: { e } "
793
789
)
794
- input_tensor = input_tensor .detach ().cpu ().double ()
795
790
791
+ input_tensor = input_tensor .detach ().cpu ().double ()
796
792
# Convert NaN to 0.0
797
793
if torch .isnan (input_tensor ).any ():
798
794
input_tensor = torch .nan_to_num (input_tensor )
795
+
799
796
return input_tensor
800
797
801
798
@@ -845,3 +842,33 @@ def find_op_names(
845
842
result .append (op_name )
846
843
847
844
return result
845
+
846
+
847
+ def compare_intermediate_outputs (a : Any , b : Any , comparator ) -> List [float ]:
848
+ """
849
+ Compare two outputs, handling both sequence and non-sequence cases,
850
+ and return a list of comparison results.
851
+ Parameters:
852
+ a: The first intermediate output to compare.
853
+ b: The second intermediate output to compare.
854
+ comparator: A comparator object with a `compare` method.
855
+ Returns:
856
+ List[float]: A list of comparison results.
857
+ Raises:
858
+ ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
859
+ """
860
+ is_a_sequence = isinstance (a , Sequence )
861
+ is_b_sequence = isinstance (b , Sequence )
862
+ if is_a_sequence and is_b_sequence :
863
+ # Ensure both sequences have the same length
864
+ if len (a ) != len (b ):
865
+ raise ValueError ("Sequences must have the same length for comparison." )
866
+
867
+ # Compare each element in the sequences and return the list of results
868
+ return [comparator .compare (x , y ) for x , y in zip (a , b )]
869
+ elif not is_a_sequence and not is_b_sequence :
870
+ # Compare non-sequence items and return the result in a list
871
+ return [comparator .compare (a , b )]
872
+ else :
873
+ # Raise an error if one is a sequence and the other is not
874
+ raise ValueError ("Both inputs must be sequences or both must be non-sequences." )
0 commit comments