
Commit 27c40a6

Fix trailing whitespace in Python validation file and test-backend-ops.cpp
Co-Authored-By: Alex Peng <[email protected]>
1 parent b75b820 commit 27c40a6

2 files changed: 28 additions & 28 deletions


gguf-py/gguf/conversion_validation.py

Lines changed: 27 additions & 27 deletions
@@ -18,7 +18,7 @@ def calculate_rmse(original: np.ndarray, converted: np.ndarray) -> float:
     """Calculate Root Mean Square Error between original and converted tensors."""
     if original.shape != converted.shape:
         raise ValueError(f"Shape mismatch: {original.shape} vs {converted.shape}")
-
+
     diff = original.astype(np.float64) - converted.astype(np.float64)
     mse = np.mean(diff ** 2)
     return np.sqrt(mse)
@@ -28,7 +28,7 @@ def calculate_max_error(original: np.ndarray, converted: np.ndarray) -> float:
     """Calculate maximum absolute error between original and converted tensors."""
     if original.shape != converted.shape:
         raise ValueError(f"Shape mismatch: {original.shape} vs {converted.shape}")
-
+
     diff = np.abs(original.astype(np.float64) - converted.astype(np.float64))
     return np.max(diff)

@@ -43,41 +43,41 @@ def validate_tensor_conversion(
 ) -> tuple[bool, dict[str, float]]:
     """
     Validate accuracy of a single tensor conversion.
-
+
     Args:
         tensor_name: Name of the tensor being validated
         original_data: Original tensor data
         converted_data: Converted tensor data (after GGUF conversion)
         max_rmse_threshold: Maximum allowed RMSE
         max_error_threshold: Maximum allowed absolute error
         verbose: Whether to print detailed validation results
-
+
     Returns:
         Tuple of (passed: bool, metrics: dict)
     """
     try:
         rmse = calculate_rmse(original_data, converted_data)
         max_err = calculate_max_error(original_data, converted_data)
-
+
         passed = rmse <= max_rmse_threshold and max_err <= max_error_threshold
-
+
         metrics = {
             "rmse": float(rmse),
             "max_error": float(max_err),
             "rmse_threshold": max_rmse_threshold,
             "max_error_threshold": max_error_threshold,
             "passed": passed
         }
-
+
         if verbose or not passed:
             status = "✓" if passed else "✗"
             logger.info(
                 f"{status} {tensor_name}: RMSE={rmse:.6f} (threshold={max_rmse_threshold}), "
                 f"MaxErr={max_err:.6f} (threshold={max_error_threshold})"
             )
-
+
         return passed, metrics
-
+
     except Exception as e:
         logger.error(f"Error validating {tensor_name}: {e}")
         return False, {"error": str(e)}
@@ -91,35 +91,35 @@ def validate_model_conversion(
 ) -> dict[str, Any]:
     """
     Validate accuracy of entire model conversion.
-
+
     Args:
         original_tensors: Dictionary of original tensor names to data
         converted_tensors: Dictionary of converted tensor names to data
         quantization_type: Type of quantization used (affects thresholds)
         verbose: Whether to print detailed validation results
-
+
     Returns:
         Dictionary with validation results and statistics
     """
     thresholds = get_quantization_thresholds(quantization_type)
-
+
     results = {
         "total_tensors": 0,
         "passed_tensors": 0,
         "failed_tensors": [],
         "metrics": {},
         "overall_passed": True
     }
-
+
     common_tensors = set(original_tensors.keys()) & set(converted_tensors.keys())
-
+
     if not common_tensors:
         logger.warning("No common tensors found between original and converted models")
         results["overall_passed"] = False
         return results
-
+
     results["total_tensors"] = len(common_tensors)
-
+
     for tensor_name in sorted(common_tensors):
         passed, metrics = validate_tensor_conversion(
             tensor_name,
@@ -129,32 +129,32 @@ def validate_model_conversion(
             max_error_threshold=thresholds["max_error"],
             verbose=verbose
         )
-
+
         results["metrics"][tensor_name] = metrics
-
+
         if passed:
             results["passed_tensors"] += 1
         else:
             results["failed_tensors"].append(tensor_name)
             results["overall_passed"] = False
-
+
     if verbose:
         logger.info(
             f"\nValidation Summary: {results['passed_tensors']}/{results['total_tensors']} tensors passed"
         )
     if results["failed_tensors"]:
         logger.warning(f"Failed tensors: {', '.join(results['failed_tensors'])}")
-
+
     return results


 def get_quantization_thresholds(quantization_type: str) -> dict[str, float]:
     """
     Get appropriate error thresholds for different quantization types.
-
+
     Args:
         quantization_type: Type of quantization (f32, f16, q4_0, q8_0, etc.)
-
+
     Returns:
         Dictionary with "rmse" and "max_error" thresholds
     """
@@ -173,25 +173,25 @@ def get_quantization_thresholds(quantization_type: str) -> dict[str, float]:
         "q5_k": {"rmse": 8e-3, "max_error": 8e-2},
         "q6_k": {"rmse": 5e-3, "max_error": 5e-2},
     }
-
+
     default = {"rmse": 1e-2, "max_error": 1e-1}
-
+
     return thresholds_map.get(quantization_type.lower(), default)


 def save_validation_report(results: dict[str, Any], output_path: Path) -> None:
     """
     Save validation results to a JSON file.
-
+
     Args:
         results: Validation results dictionary
         output_path: Path to save the report
     """
     import json
-
+
     with open(output_path, 'w') as f:
         json.dump(results, f, indent=2)
-
+
     logger.info(f"Validation report saved to {output_path}")

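For reference, a minimal sketch of how the helpers in gguf-py/gguf/conversion_validation.py might be driven; this is not part of the commit. It assumes the gguf package from gguf-py is importable and that validate_model_conversion takes the original and converted tensor dicts in the order its docstring lists them; the tensor name and the f16 round-trip below are invented for illustration.

# Hypothetical usage of the validation helpers; not from this commit.
# Assumes `gguf` (from gguf-py) is on the Python path and that the
# argument order of validate_model_conversion matches its docstring.
import logging
from pathlib import Path

import numpy as np

from gguf.conversion_validation import save_validation_report, validate_model_conversion

logging.basicConfig(level=logging.INFO)

rng = np.random.default_rng(0)
# Invented tensor name; an f16 round-trip stands in for a real GGUF conversion.
original = {"blk.0.attn_q.weight": rng.standard_normal((64, 64), dtype=np.float32)}
converted = {name: t.astype(np.float16).astype(np.float32) for name, t in original.items()}

results = validate_model_conversion(original, converted, quantization_type="f16", verbose=True)
print(results["passed_tensors"], "of", results["total_tensors"], "tensors passed")

save_validation_report(results, Path("validation_report.json"))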

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -5916,7 +5916,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
         GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K
     };
-
+
     for (ggml_type intermediate : quant_conversion_test_types) {
         for (ggml_type dst : quant_conversion_test_types) {
             if (intermediate != dst) {
