def calculate_rmse(original: np.ndarray, converted: np.ndarray) -> float:
    """
    Calculate Root Mean Square Error between original and converted tensors.

    Args:
        original: Original tensor data
        converted: Converted tensor data; must have the same shape as original

    Returns:
        The RMSE as a built-in float, computed in float64. Two empty tensors
        of matching shape yield 0.0.

    Raises:
        ValueError: If the two tensors have different shapes
    """
    if original.shape != converted.shape:
        raise ValueError(f"Shape mismatch: {original.shape} vs {converted.shape}")

    # np.mean over zero elements returns NaN (and emits a RuntimeWarning);
    # with nothing to compare there is no error to report.
    if original.size == 0:
        return 0.0

    # Upcast both operands before subtracting so unsigned or low-precision
    # dtypes cannot wrap around or lose precision in the difference.
    diff = original.astype(np.float64) - converted.astype(np.float64)
    mse = np.mean(diff ** 2)

    # Return a plain float, as annotated: a NumPy scalar would turn later
    # threshold comparisons into np.bool_, which json cannot serialize.
    return float(np.sqrt(mse))
def calculate_max_error(original: np.ndarray, converted: np.ndarray) -> float:
    """
    Calculate maximum absolute error between original and converted tensors.

    Args:
        original: Original tensor data
        converted: Converted tensor data; must have the same shape as original

    Returns:
        The largest element-wise absolute difference as a built-in float,
        computed in float64. Two empty tensors of matching shape yield 0.0.

    Raises:
        ValueError: If the two tensors have different shapes
    """
    if original.shape != converted.shape:
        raise ValueError(f"Shape mismatch: {original.shape} vs {converted.shape}")

    # np.max has no identity for a zero-size array and would raise an
    # unrelated ValueError; with nothing to compare the error is zero.
    if original.size == 0:
        return 0.0

    # Upcast both operands before subtracting so unsigned or low-precision
    # dtypes cannot wrap around or lose precision in the difference.
    diff = np.abs(original.astype(np.float64) - converted.astype(np.float64))

    # Return a plain float, as annotated: a NumPy scalar would turn later
    # threshold comparisons into np.bool_, which json cannot serialize.
    return float(np.max(diff))
3434
@@ -43,41 +43,41 @@ def validate_tensor_conversion(
4343) -> tuple [bool , dict [str , float ]]:
4444 """
4545 Validate accuracy of a single tensor conversion.
46-
46+
4747 Args:
4848 tensor_name: Name of the tensor being validated
4949 original_data: Original tensor data
5050 converted_data: Converted tensor data (after GGUF conversion)
5151 max_rmse_threshold: Maximum allowed RMSE
5252 max_error_threshold: Maximum allowed absolute error
5353 verbose: Whether to print detailed validation results
54-
54+
5555 Returns:
5656 Tuple of (passed: bool, metrics: dict)
5757 """
5858 try :
5959 rmse = calculate_rmse (original_data , converted_data )
6060 max_err = calculate_max_error (original_data , converted_data )
61-
61+
6262 passed = rmse <= max_rmse_threshold and max_err <= max_error_threshold
63-
63+
6464 metrics = {
6565 "rmse" : float (rmse ),
6666 "max_error" : float (max_err ),
6767 "rmse_threshold" : max_rmse_threshold ,
6868 "max_error_threshold" : max_error_threshold ,
6969 "passed" : passed
7070 }
71-
71+
7272 if verbose or not passed :
7373 status = "✓" if passed else "✗"
7474 logger .info (
7575 f"{ status } { tensor_name } : RMSE={ rmse :.6f} (threshold={ max_rmse_threshold } ), "
7676 f"MaxErr={ max_err :.6f} (threshold={ max_error_threshold } )"
7777 )
78-
78+
7979 return passed , metrics
80-
80+
8181 except Exception as e :
8282 logger .error (f"Error validating { tensor_name } : { e } " )
8383 return False , {"error" : str (e )}
@@ -91,35 +91,35 @@ def validate_model_conversion(
9191) -> dict [str , Any ]:
9292 """
9393 Validate accuracy of entire model conversion.
94-
94+
9595 Args:
9696 original_tensors: Dictionary of original tensor names to data
9797 converted_tensors: Dictionary of converted tensor names to data
9898 quantization_type: Type of quantization used (affects thresholds)
9999 verbose: Whether to print detailed validation results
100-
100+
101101 Returns:
102102 Dictionary with validation results and statistics
103103 """
104104 thresholds = get_quantization_thresholds (quantization_type )
105-
105+
106106 results = {
107107 "total_tensors" : 0 ,
108108 "passed_tensors" : 0 ,
109109 "failed_tensors" : [],
110110 "metrics" : {},
111111 "overall_passed" : True
112112 }
113-
113+
114114 common_tensors = set (original_tensors .keys ()) & set (converted_tensors .keys ())
115-
115+
116116 if not common_tensors :
117117 logger .warning ("No common tensors found between original and converted models" )
118118 results ["overall_passed" ] = False
119119 return results
120-
120+
121121 results ["total_tensors" ] = len (common_tensors )
122-
122+
123123 for tensor_name in sorted (common_tensors ):
124124 passed , metrics = validate_tensor_conversion (
125125 tensor_name ,
@@ -129,32 +129,32 @@ def validate_model_conversion(
129129 max_error_threshold = thresholds ["max_error" ],
130130 verbose = verbose
131131 )
132-
132+
133133 results ["metrics" ][tensor_name ] = metrics
134-
134+
135135 if passed :
136136 results ["passed_tensors" ] += 1
137137 else :
138138 results ["failed_tensors" ].append (tensor_name )
139139 results ["overall_passed" ] = False
140-
140+
141141 if verbose :
142142 logger .info (
143143 f"\n Validation Summary: { results ['passed_tensors' ]} /{ results ['total_tensors' ]} tensors passed"
144144 )
145145 if results ["failed_tensors" ]:
146146 logger .warning (f"Failed tensors: { ', ' .join (results ['failed_tensors' ])} " )
147-
147+
148148 return results
149149
150150
151151def get_quantization_thresholds (quantization_type : str ) -> dict [str , float ]:
152152 """
153153 Get appropriate error thresholds for different quantization types.
154-
154+
155155 Args:
156156 quantization_type: Type of quantization (f32, f16, q4_0, q8_0, etc.)
157-
157+
158158 Returns:
159159 Dictionary with "rmse" and "max_error" thresholds
160160 """
@@ -173,25 +173,25 @@ def get_quantization_thresholds(quantization_type: str) -> dict[str, float]:
173173 "q5_k" : {"rmse" : 8e-3 , "max_error" : 8e-2 },
174174 "q6_k" : {"rmse" : 5e-3 , "max_error" : 5e-2 },
175175 }
176-
176+
177177 default = {"rmse" : 1e-2 , "max_error" : 1e-1 }
178-
178+
179179 return thresholds_map .get (quantization_type .lower (), default )
180180
181181
def save_validation_report(results: dict[str, Any], output_path: Path) -> None:
    """
    Save validation results to a JSON file.

    NumPy scalars and arrays found anywhere in ``results`` are converted to
    their built-in Python equivalents so they can be serialized.

    Args:
        results: Validation results dictionary
        output_path: Path to save the report

    Raises:
        TypeError: If ``results`` contains a value that is neither
            JSON-serializable nor convertible via ``tolist()``
        OSError: If the output file cannot be opened or written
    """
    import json

    def _to_builtin(value: Any) -> Any:
        """Convert a NumPy scalar or array to built-in Python objects."""
        # Both np.ndarray and NumPy scalar types (np.bool_, np.float64, ...)
        # expose tolist(), which yields plain Python values.
        to_list = getattr(value, "tolist", None)
        if callable(to_list):
            return to_list()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    # Serialize before opening the file so a serialization failure cannot
    # leave a truncated report on disk.
    payload = json.dumps(results, indent=2, default=_to_builtin)

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(payload)

    logger.info(f"Validation report saved to {output_path}")
196196
197197
0 commit comments