Commit 3e831f2

Fix torch.onnx._internal imports for torch 2.9 (#356)
Signed-off-by: Keval Morabia <[email protected]>
1 parent e59ef52 commit 3e831f2

3 files changed: +81 -63 lines changed
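In torch 2.9 the TorchScript exporter internals moved from torch.onnx._internal (and torch.onnx._type_utils / torch.onnx.symbolic_opset14) to torch.onnx._internal.torchscript_exporter, so the diffs below replace the direct imports with version-guarded ones and quote the GraphContext annotations. A minimal sketch of that pattern, assuming only the module locations changed between versions (the function name quantize_symbolic is illustrative and not part of the commit):

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Imported only for type checking; at runtime the annotation stays a string.
    if hasattr(torch.onnx._internal, "jit_utils"):
        from torch.onnx._internal.jit_utils import GraphContext
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext


def quantize_symbolic(g: "GraphContext", inputs: "torch._C.Value"):
    """Illustrative symbolic function: defer exporter-internal imports to call time."""
    if hasattr(torch.onnx, "_type_utils"):
        from torch.onnx._type_utils import JitScalarType
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter import JitScalarType

    # Example use of the guarded import: inspect the JIT scalar type of an input value.
    return JitScalarType.from_value(inputs)

Deferring the runtime imports into the function body (and keeping GraphContext behind TYPE_CHECKING) lets the module import cleanly on both older and newer torch; the version check only runs when a symbolic function is actually invoked during ONNX export.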

modelopt/torch/quantization/export_onnx.py

Lines changed: 65 additions & 50 deletions
@@ -103,13 +103,18 @@
 """Utility to export a quantized torch model to quantized ONNX."""
 
 import contextlib
+from typing import TYPE_CHECKING
 
 import onnx
 import torch
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
-from torch.onnx._internal import jit_utils
-from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+
+if TYPE_CHECKING:
+    if hasattr(torch.onnx._internal, "jit_utils"):
+        from torch.onnx._internal.jit_utils import GraphContext
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext
 
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
@@ -125,7 +130,7 @@
 
 
 def export_int8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -184,7 +189,7 @@ def export_int8(
 
 
 def export_int4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -208,7 +213,7 @@ def export_int4(
 
 
 def _fp8_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -236,7 +241,7 @@ def _fp8_quantize(
 
 
 def _fp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -263,7 +268,7 @@ def _fp8_dequantize(
 
 
 def export_fp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: float,
     trt_high_precision_dtype: str | None,
@@ -279,21 +284,29 @@ def export_fp8(
 
 
 def scaled_dot_product_attention(
-    g: jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )
 
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
@@ -327,22 +340,20 @@ def scaled_dot_product_attention(
 
     if symbolic_helper._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")
 
     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
 
@@ -357,14 +368,14 @@ def scaled_dot_product_attention(
 
 
 def export_fp8_mha(
-    g: torch.onnx._internal.jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
     q_quantized_scale: float = 1.0,
     k_quantized_scale: float = 1.0,
     v_quantized_scale: float = 1.0,
@@ -396,12 +407,18 @@ def export_fp8_mha(
           |
          Cast
     """
-    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
-
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )
 
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
@@ -452,22 +469,20 @@ def export_fp8_mha(
 
     if sym_help._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")
 
     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
 
@@ -495,7 +510,7 @@ def export_fp8_mha(
 
 
 def _fp4_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float,
     trt_high_precision_dtype: str | None,
@@ -531,7 +546,7 @@ def _fp4_dynamic_quantize(
 
 
 def _fp4_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float | torch.Value,
     trt_high_precision_dtype: str | None,
@@ -546,7 +561,7 @@ def _fp4_dequantize(
 
 
 def _fp4_dequantize_2(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
    inputs: torch.Value,
     dyn_scale: torch.Value,
     block_size: int,
@@ -557,7 +572,7 @@ def _fp4_dequantize_2(
 
 
 def _mxfp8_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     axis: int = -1,
@@ -575,7 +590,7 @@ def _mxfp8_dynamic_quantize(
 
 
 def _mxfp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: torch.Value,
     block_size: int,
@@ -593,7 +608,7 @@ def _mxfp8_dequantize(
 
 
 def export_mxfp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Tensor,
     onnx_quantizer_type: str,
     block_size: int,
@@ -611,7 +626,7 @@ def export_mxfp8(
 
 
 def export_fp4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     amax: torch.Value,

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 3 additions & 3 deletions
@@ -548,7 +548,7 @@ def _get_amax(self, inputs):
 
     def _validate_amax(self, amax):
         # Dynamic control flow is not supported by torch dynamo
-        if not is_torch_export_mode() and not torch._dynamo.is_compiling():
+        if not is_torch_export_mode() and not torch.compiler.is_compiling():
             assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), (
                 f"Got invalid amax: {amax}"
             )
@@ -880,7 +880,7 @@ def forward(self, inputs):
         """
         if hasattr(torch.onnx, "_globals"):
             from torch.onnx._globals import GLOBALS
-        else:
+        else:  # torch >= 2.9
             from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
 
         if DTensor is not None and isinstance(inputs, DTensor):
@@ -914,7 +914,7 @@ def forward(self, inputs):
 
         if (
             not is_torch_export_mode()
-            and not torch._dynamo.is_compiling()
+            and not torch.compiler.is_compiling()
             and GLOBALS.in_onnx_export
         ):
             # GLOBALS could break TorchDynamo for some Pytorch versions (i.e., 2.3.0)
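The tensor_quantizer changes also swap torch._dynamo.is_compiling() for torch.compiler.is_compiling(), the public API for detecting whether a graph is being traced by torch.compile, so the data-dependent asserts are skipped during graph capture. A small standalone sketch of that guard, assuming the same semantics as in the diff (the helper name is hypothetical):

import torch


def validate_amax(amax: torch.Tensor) -> None:
    # Skip the data-dependent assert while a graph is being traced:
    # dynamic control flow on tensor values is not supported by Dynamo.
    if not torch.compiler.is_compiling():
        assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), (
            f"Got invalid amax: {amax}"
        )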

modelopt/torch/quantization/plugins/diffusers.py

Lines changed: 13 additions & 10 deletions
@@ -15,10 +15,10 @@
 
 """Support quantization of diffusers layers."""
 
-import functools
 from collections.abc import Callable, Iterator
 from functools import partial
 from types import ModuleType
+from typing import TYPE_CHECKING
 
 import onnx
 import torch
@@ -27,7 +27,12 @@
 from torch.autograd import Function
 from torch.nn import functional as F
 from torch.onnx import symbolic_helper
-from torch.onnx._internal import jit_utils, registration
+
+if TYPE_CHECKING:
+    if hasattr(torch.onnx._internal, "jit_utils"):
+        from torch.onnx._internal.jit_utils import GraphContext
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext
 
 from ..export_onnx import export_fp8_mha
 from ..nn import (
@@ -40,8 +45,6 @@
 )
 from .custom import _QuantFunctionalMixin
 
-_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
-
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
@@ -205,14 +208,14 @@ def forward(
     @staticmethod
     @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
     def symbolic(
-        g: jit_utils.GraphContext,
-        query: torch._C.Value,
-        key: torch._C.Value,
-        value: torch._C.Value,
-        attn_mask: torch._C.Value | None = None,
+        g: "GraphContext",
+        query: "torch._C.Value",
+        key: "torch._C.Value",
+        value: "torch._C.Value",
+        attn_mask: "torch._C.Value | None" = None,
         dropout_p: float = 0.0,
         is_causal: bool = False,
-        scale: torch._C.Value | None = None,
+        scale: "torch._C.Value | None" = None,
         q_quantized_scale: float = 1.0,
         k_quantized_scale: float = 1.0,
         v_quantized_scale: float = 1.0,
