Skip to content

Commit c5701d0

Browse files
justinchuby authored and pytorchmergebot committed
[ONNX] Create fake implementations for onnx ops; fix boolean mask in attention (pytorch#165780)
Previously we relied on the concrete implementation to generate the fake implementation. This made the fake implementation overly complicated and broke in some cases when there were dynamic shapes. This PR updates ONNX op registration to instead take a dedicated fake implementation. **Also fixed: when a boolean mask was supplied to torch SDPA, its negation was previously taken, which is incorrect.** Fixes pytorch#164909. Also takes changes from pytorch#156635. Pull Request resolved: pytorch#165780 Approved by: https://github.com/titaiwangms
1 parent 23669d0 commit c5701d0

File tree

3 files changed

+192
-132
lines changed

3 files changed

+192
-132
lines changed

test/onnx/exporter/test_small_models_e2e.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,16 @@
55

66
import logging
77

8-
import onnxruntime
98
import pytest
109
import transformers
1110
from onnxscript import ir
12-
from packaging import version
1311

1412
import torch
1513
from torch.onnx._internal.exporter import _testing as onnx_testing
1614
from torch.testing._internal import common_utils
1715
from torch.utils import _pytree as torch_pytree
1816

1917

20-
def has_onnxruntime_opset_23() -> bool:
21-
return version.parse(onnxruntime.__version__) >= version.parse("1.23")
22-
23-
2418
class _WithExport:
2519
def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
2620
onnx_program = torch.onnx.export(
@@ -746,13 +740,7 @@ def forward(self, query, key, value):
746740
onnx_program = self.export(Model(), (query, key, value), opset_version=23)
747741
self.assertEqual(["Attention"], [n.op_type for n in onnx_program.model.graph])
748742

749-
if has_onnxruntime_opset_23():
750-
onnx_testing.assert_onnx_program(onnx_program, atol=1e-2, rtol=1)
751-
else:
752-
# Test with reference evaluator because ORT does not support the op as of version 1.22
753-
onnx_testing.assert_onnx_program(
754-
onnx_program, atol=1e-2, rtol=1, backend="reference"
755-
)
743+
onnx_testing.assert_onnx_program(onnx_program, atol=1e-2, rtol=1)
756744

757745
def test_rms_norm(self):
758746
"""Test RMS normalization with various configurations."""
@@ -789,8 +777,7 @@ def forward(self, x):
789777

790778
onnx_program = self.export(RMSNormWithWeight(), (x,), opset_version=23)
791779

792-
# Test with reference evaluator because ORT does not support the op as of version 1.22
793-
onnx_testing.assert_onnx_program(onnx_program, backend="reference")
780+
onnx_testing.assert_onnx_program(onnx_program)
794781

795782
def test_rms_norm_with_eps(self):
796783
"""Test RMS normalization with custom epsilon."""
@@ -803,7 +790,6 @@ def forward(self, x):
803790

804791
onnx_program = self.export(RMSNormWithEps(), (x,), opset_version=23)
805792

806-
# Test with reference evaluator because ORT does not support the op as of version 1.22
807793
onnx_testing.assert_onnx_program(onnx_program, backend="reference")
808794

809795
def test_enable_gqa_in_attention_23_with_dropout(self):

0 commit comments

Comments
 (0)