103103"""Utility to export a quantized torch model to quantized ONNX."""
104104
105105import contextlib
106+ from typing import TYPE_CHECKING
106107
107108import onnx
108109import torch
109110from torch .onnx import symbolic_helper
110111from torch .onnx import symbolic_helper as sym_help
111- from torch .onnx ._internal import jit_utils
112- from torch .onnx .symbolic_opset14 import _attention_scale , _causal_attention_mask
112+
113+ if TYPE_CHECKING :
114+ if hasattr (torch .onnx ._internal , "jit_utils" ):
115+ from torch .onnx ._internal .jit_utils import GraphContext
116+ else : # torch >= 2.9
117+ from torch .onnx ._internal .torchscript_exporter .jit_utils import GraphContext
113118
114119onnx_dtype_map = {
115120 "BFloat16" : onnx .TensorProto .BFLOAT16 ,
@@ -125,7 +130,7 @@
 
 
 def export_int8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -184,7 +189,7 @@ def export_int8(
 
 
 def export_int4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -208,7 +213,7 @@ def export_int4(
 
 
 def _fp8_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -236,7 +241,7 @@ def _fp8_quantize(
 
 
 def _fp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -263,7 +268,7 @@ def _fp8_dequantize(
 
 
 def export_fp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: float,
     trt_high_precision_dtype: str | None,
@@ -279,21 +284,29 @@ def export_fp8(
 
 
 def scaled_dot_product_attention(
-    g: jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )
 
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
@@ -327,22 +340,20 @@ def scaled_dot_product_attention(
 
     if symbolic_helper._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")
 
     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
 
@@ -357,14 +368,14 @@ def scaled_dot_product_attention(
 
 
 def export_fp8_mha(
-    g: torch.onnx._internal.jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
     q_quantized_scale: float = 1.0,
     k_quantized_scale: float = 1.0,
     v_quantized_scale: float = 1.0,
@@ -396,12 +407,18 @@ def export_fp8_mha(
           |
          Cast
     """
-    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
-
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )
 
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
@@ -452,22 +469,20 @@ def export_fp8_mha(
 
     if sym_help._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")
 
     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
 
@@ -495,7 +510,7 @@ def export_fp8_mha(
 
 
 def _fp4_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float,
     trt_high_precision_dtype: str | None,
@@ -531,7 +546,7 @@ def _fp4_dynamic_quantize(
 
 
 def _fp4_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float | torch.Value,
     trt_high_precision_dtype: str | None,
@@ -546,7 +561,7 @@ def _fp4_dequantize(
 
 
 def _fp4_dequantize_2(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     dyn_scale: torch.Value,
     block_size: int,
@@ -557,7 +572,7 @@ def _fp4_dequantize_2(
 
 
 def _mxfp8_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     axis: int = -1,
@@ -575,7 +590,7 @@ def _mxfp8_dynamic_quantize(
 
 
 def _mxfp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: torch.Value,
     block_size: int,
@@ -593,7 +608,7 @@ def _mxfp8_dequantize(
 
 
 def export_mxfp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Tensor,
     onnx_quantizer_type: str,
     block_size: int,
@@ -611,7 +626,7 @@ def export_mxfp8(
 
 
 def export_fp4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     amax: torch.Value,
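
For readers applying the same pattern elsewhere, below is a minimal sketch (not part of this commit) of the version-guarded imports the change relies on: GraphContext is imported only under TYPE_CHECKING and referenced through string annotations, while JitScalarType is resolved lazily at call time. The names `dummy_symbolic` and `_jit_scalar_type` are illustrative assumptions; the import paths mirror the ones taken in the diff above.

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Only type checkers evaluate this block, so neither path runs at import time.
    if hasattr(torch.onnx._internal, "jit_utils"):
        from torch.onnx._internal.jit_utils import GraphContext  # torch < 2.9
    else:
        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext


def _jit_scalar_type():
    # Resolve JitScalarType lazily so the module imports on both torch layouts.
    if hasattr(torch.onnx, "_type_utils"):
        from torch.onnx._type_utils import JitScalarType
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter import JitScalarType
    return JitScalarType


def dummy_symbolic(g: "GraphContext", x: "torch._C.Value") -> "torch._C.Value":
    # String annotations mean GraphContext is only needed by type checkers;
    # at runtime, g is whatever graph context the TorchScript exporter passes in.
    return g.op("Identity", x)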