Skip to content

Commit ef13d35

Browse files
committed
Fix version check
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 9280dd6 commit ef13d35

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

onnxscript/rewriter/rules/fusion/_gqa_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ def test_basic_gqa_fusion(self):
8181
self.assertGreater(count, 0, "GQA fusion should have occurred")
8282

8383
# We can't yet test numerical equivalence because of a bug in the op spec/implementation.
84-
if version.parse(onnx.__version__) >= version.parse("1.19.1"):
84+
onnx_ver = version.parse(onnx.__version__)
85+
if onnx_ver >= version.parse("1.19.1") and not (
86+
onnx_ver.is_prerelease or onnx_ver.is_devrelease
87+
):
88+
# Only official releases >= 1.19.1
8589
onnxscript.optimizer.remove_unused_nodes(model)
8690
rewritten_model_proto = ir.serde.serialize_model(model)
8791
onnxscript.rewriter.testing.assert_numerically_equal(

onnxscript/rewriter/testing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,12 @@ def assert_numerically_equal(
102102
for i, (orig, rewritten) in enumerate(zip(original_outputs, the_rewritten_outputs)):
103103
print(f"==== Output {i} ====")
104104
diff = np.abs(orig - rewritten)
105-
print(diff)
105+
for h in range(diff.shape[1]):
106+
subarray = diff[:, h, :, :] # Select along H
107+
if np.allclose(subarray, 0):
108+
print(f"H={h}: all zeros")
109+
else:
110+
print(f"H={h}: not all zeros")
106111

107112
np.testing.assert_allclose(
108113
original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True

0 commit comments

Comments
 (0)