 """Utility to export a quantized torch model to quantized ONNX."""

 import contextlib
+from typing import TYPE_CHECKING

 import onnx
 import torch
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
-from torch.onnx._internal import jit_utils
-from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+
+if TYPE_CHECKING:
+    if hasattr(torch.onnx._internal, "jit_utils"):
+        from torch.onnx._internal.jit_utils import GraphContext
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext

 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,


 def export_int8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -184,7 +189,7 @@ def export_int8(


 def export_int4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -208,7 +213,7 @@ def export_int4(


 def _fp8_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -236,7 +241,7 @@ def _fp8_quantize(


 def _fp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -263,7 +268,7 @@ def _fp8_dequantize(


 def export_fp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: float,
     trt_high_precision_dtype: str | None,
@@ -279,21 +284,29 @@ def export_fp8(


 def scaled_dot_product_attention(
-    g: jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )

     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
@@ -327,22 +340,20 @@ def scaled_dot_product_attention(

     if symbolic_helper._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")

     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)

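For reference, the `Constant`/`Where`/`Add` sequence above is the standard conversion of a boolean attention mask into an additive float mask ahead of the softmax. The eager-mode sketch below computes the same thing; the function and variable names are illustrative, and the exporter emits the equivalent ONNX nodes instead of calling torch directly.

```python
import torch


def masked_softmax(scores: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    # Boolean mask: True keeps a position, False masks it with -inf, matching
    # attn_mask.masked_fill(not attn_mask, -float('inf')) from the comment above.
    if attn_mask.dtype == torch.bool:
        attn_mask = torch.where(
            attn_mask, torch.tensor(0.0), torch.tensor(-float("inf"))
        )
    # Float masks (fp32/fp16/bf16) are simply added to the scaled QK^T scores.
    return torch.softmax(scores + attn_mask, dim=-1)


scores = torch.randn(2, 4, 4)
causal = torch.ones(4, 4).tril().bool()
probs = masked_softmax(scores, causal)  # positions above the diagonal get zero weight
```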
@@ -357,14 +368,14 @@ def scaled_dot_product_attention(


 def export_fp8_mha(
-    g: torch.onnx._internal.jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
     q_quantized_scale: float = 1.0,
     k_quantized_scale: float = 1.0,
     v_quantized_scale: float = 1.0,
@@ -396,12 +407,18 @@ def export_fp8_mha(
            |
          Cast
     """
-    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
-
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )

     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
@@ -452,22 +469,20 @@ def export_fp8_mha(

     if sym_help._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")

     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)

@@ -495,7 +510,7 @@ def export_fp8_mha(


 def _fp4_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float,
     trt_high_precision_dtype: str | None,
@@ -531,7 +546,7 @@ def _fp4_dynamic_quantize(


 def _fp4_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float | torch.Value,
     trt_high_precision_dtype: str | None,
@@ -546,7 +561,7 @@ def _fp4_dequantize(


 def _fp4_dequantize_2(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     dyn_scale: torch.Value,
     block_size: int,
@@ -557,7 +572,7 @@ def _fp4_dequantize_2(


 def _mxfp8_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     axis: int = -1,
@@ -575,7 +590,7 @@ def _mxfp8_dynamic_quantize(


 def _mxfp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: torch.Value,
     block_size: int,
@@ -593,7 +608,7 @@ def _mxfp8_dequantize(


 def export_mxfp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Tensor,
     onnx_quantizer_type: str,
     block_size: int,
@@ -611,7 +626,7 @@ def export_mxfp8(


 def export_fp4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     amax: torch.Value,
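Context for readers outside this file: each helper above receives a `GraphContext` `g` that the TorchScript-based ONNX exporter supplies when it reaches a custom op's `symbolic()` during `torch.onnx.export`. The wiring sketch below is hypothetical; the class, the `Identity` op, and the overall shape are placeholders, not part of this change.

```python
import torch


class _FakeQuant(torch.autograd.Function):
    """Placeholder op showing where the GraphContext `g` comes from."""

    @staticmethod
    def forward(ctx, x, amax):
        # Real code would fake-quantize here; identity keeps the sketch small.
        return x

    @staticmethod
    def symbolic(g, x, amax):
        # `g` is the GraphContext passed by the exporter; a real implementation
        # would call one of the export_* helpers above to emit Q/DQ nodes.
        return g.op("Identity", x)
```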