@@ -75,6 +75,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
            },
            "dml": {},
            "webgpu": {},
+            "NvTensorRtRtx": {},
        }

        # Map input names to their types and shapes
@@ -343,6 +344,7 @@ def make_attention_init(self):
            ("dml", TensorProto.FLOAT16),
            ("webgpu", TensorProto.FLOAT16),
            ("webgpu", TensorProto.FLOAT),
+            ("NvTensorRtRtx", TensorProto.FLOAT16),
        ]
        if (self.ep, self.io_dtype) in valid_gqa_configurations:
            # Change model settings for GroupQueryAttention
@@ -757,6 +759,23 @@ def make_reduce_max(self, name, inputs, dtype, shape):
        self.make_node("ReduceMax", inputs=inputs, outputs=[output], name=name, keepdims=False)
        self.make_value_info(output, dtype, shape=shape)

+    def make_reduce_mean(self, name, inputs, dtype, shape, axes=[-1], keepdims=False):
+        output = f"{name}/output_0"
+        if self.quant_attrs["use_qdq"]:
+            # Opset 18 uses axes as input[1]
+            inputs.append(f"/model/constants/TensorProto.INT64/1D/{','.join(map(str, axes))}")
+            self.make_node("ReduceMean", inputs=inputs, outputs=[output], name=name, keepdims=keepdims)
+            self.make_value_info(output, dtype, shape=shape)
+        else:
+            # Opset 17 uses axes as an attribute
+            self.make_node("ReduceMean", inputs=inputs, outputs=[output], name=name, axes=axes, keepdims=keepdims)
+            self.make_value_info(output, dtype, shape=shape)
+
+    def make_sqrt(self, name, inputs, dtype, shape):
+        output = f"{name}/output_0"
+        self.make_node("Sqrt", inputs=inputs, outputs=[output], name=name)
+        self.make_value_info(output, dtype, shape=shape)
+
    def make_cast(self, name, root_input, dtype, shape):
        output = f"{name}/output_0"
        self.make_node("Cast", inputs=[root_input], outputs=[output], name=name, to=dtype)
@@ -1059,6 +1078,13 @@ def make_embedding(self, embedding):
        self.layernorm_attrs["skip_input"] = layernorm_attrs_value

    def make_layernorm(self, layer_id, layernorm, skip, simple, location):
+        if self.ep == "NvTensorRtRtx" and (skip or simple):
+            # NvTensorRtRtx EP doesn't support the SkipLayerNormalization, SimplifiedLayerNormalization, and SkipSimplifiedLayerNormalization contrib ops, so fall back to primitive ops
+            self._make_layernorm_op(layer_id, layernorm, skip, simple, location)
+        else:
+            self.make_layernorm_op(layer_id, layernorm, skip, simple, location)
+
+    def make_layernorm_op(self, layer_id, layernorm, skip, simple, location):
        root_input = self.layernorm_attrs["root_input"]
        skip_input = self.layernorm_attrs["skip_input"]

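For reference, the contrib ops that the primitive-op fallback above replaces compute a residual Add followed by a normalization. A hedged PyTorch sketch of that math (not the contrib kernels themselves), which the decomposed subgraphs later in this diff are meant to reproduce:

import torch

def skip_simplified_layer_norm_ref(x, skip, weight, eps):
    # SkipSimplifiedLayerNormalization: residual Add, then RMS normalization.
    h = x + skip
    variance = h.float().pow(2).mean(-1, keepdim=True)
    normed = (h.float() / torch.sqrt(variance + eps)).to(x.dtype)
    return normed * weight, h  # output (0) and the Add result exposed as output_3

def skip_layer_norm_ref(x, skip, weight, bias, eps):
    # SkipLayerNormalization: residual Add, then standard LayerNorm over the last axis.
    h = x + skip
    y = torch.nn.functional.layer_norm(h.float(), h.shape[-1:], weight.float(), bias.float(), eps)
    return y.to(x.dtype), h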
@@ -1112,6 +1138,68 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):
            # Assign output 3 of current SkipLayerNorm as root input to next SkipLayerNorm
            self.layernorm_attrs["root_input"] = output_3

+    def _make_layernorm_op(self, layer_id, layernorm, skip, simple, location):
+        root_input = self.layernorm_attrs["root_input"]
+        skip_input = self.layernorm_attrs["skip_input"]
+
+        # Get precision types to use
+        old_torch_dtype = self.to_torch_dtype[self.io_dtype]
+        old_io_dtype = self.io_dtype
+        new_torch_dtype = torch.float32 if self.layernorm_attrs["cast"]["use_fp32"] else self.to_torch_dtype[self.io_dtype]
+        new_io_dtype = self.to_onnx_dtype[new_torch_dtype]
+        cast = old_torch_dtype != new_torch_dtype
+
+        # Create weight and bias tensors
+        weight = f"model.layers.{layer_id}.{location}_layernorm.weight"
+        self.make_external_tensor((layernorm.weight.detach().cpu().to(new_torch_dtype) + self.layernorm_attrs["add_offset"]).contiguous(), weight)
+        bias = f"model.layers.{layer_id}.{location}_layernorm.bias"
+        if not simple:
+            self.make_external_tensor(layernorm.bias.detach().cpu().to(new_torch_dtype).contiguous(), bias)
+
+        # Create input names for op
+        inputs = [root_input, skip_input, weight] if skip else [root_input, weight]
+        if not simple:
+            inputs.append(bias)
+
+        name = f"/model/layers.{layer_id}/{location}_layernorm/{'Skip' if skip else ''}LayerNorm"
+        op_type = f"{'Skip' if skip else ''}{'Simplified' if simple else ''}LayerNormalization"
+        kwargs = {"epsilon": self.layernorm_attrs["epsilon"]}
+        if not skip:
+            kwargs.update({"axis": -1, "stash_type": 1})
+
+        # Create output names for op
+        output_0 = f"/model/layers.{layer_id}/{location}_layernorm/output_0"
+        output_3 = f"/model/layers.{layer_id}/{location}_layernorm/output_3"
+        if self.layernorm_attrs["last_layernorm"] and (self.include_hidden_states or self.exclude_lm_head):
+            output_0 = "hidden_states"
+        outputs = [output_0, "", "", output_3] if skip and not self.layernorm_attrs["last_layernorm"] else [output_0]
+
+        # Create Cast nodes for inputs and outputs if old_dtype != new_dtype
+        if cast:
+            inputs, outputs = self.make_layernorm_casts(name, inputs, outputs, old_io_dtype, new_io_dtype)
+            root_input = inputs[0]
+            skip_input = inputs[1] if skip else None
+
+        if op_type == "SimplifiedLayerNormalization":
+            self._make_simplified_layer_norm(name, root_input, weight, outputs[0], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+        elif op_type == "SkipSimplifiedLayerNormalization":
+            self._make_skip_simplified_layer_norm(name, root_input, skip_input, weight, outputs[0], output_3, new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+        elif op_type == "SkipLayerNormalization":
+            self._make_skip_layer_norm(name, root_input, skip_input, weight, bias, outputs[0], output_3, new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+        else:
+            raise ValueError(f"Invalid op_type: {op_type}")
+
+        if skip and not self.layernorm_attrs["last_layernorm"]:
+            self.make_value_info(outputs[3], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+        # Update LayerNorm attributes
+        self.layernorm_attrs["output_0"] = output_0
+        if skip and not self.layernorm_attrs["last_layernorm"]:
+            self.layernorm_attrs["output_3"] = output_3
+
+            # Assign output 3 of current SkipLayerNorm as root input to next SkipLayerNorm
+            self.layernorm_attrs["root_input"] = output_3
+
    def make_layernorm_casts(self, name, inputs, outputs, old_dtype, new_dtype):
        # Name = name of original LayerNorm op as if the cast nodes did not exist
        # Inputs = inputs into the original LayerNorm op as if the cast nodes did not exist
@@ -1354,6 +1442,110 @@ def make_rotary_embedding_multi_cache(self, **kwargs):
        self.make_value_info(cos_cache_name, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"])
        self.make_value_info(sin_cache_name, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"])

+    # This expansion of a contrib op can be updated / deprecated in the future.
+    def _make_skip_simplified_layer_norm(self, basename, root_input, skip_input, weight_name, output_0, output_3, io_dtype, shape):
+        #     root_input   skip_input
+        #         |            |
+        #         +-----+------+
+        #               |
+        #              Add ----------------> output (3)
+        #               |
+        #     SimplifiedLayerNorm ---------> output (0)
+        make_add_name = f"{basename}/Add"
+        output_3 = f"{basename}/Add/output_0" if output_3 is None else output_3
+        self.make_node("Add", inputs=[root_input, skip_input], outputs=[output_3], name=make_add_name)
+        self.make_value_info(output_3, io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+        make_simplified_layer_norm_name = f"{basename}/skip_simplified_layer_norm"
+        self._make_simplified_layer_norm(make_simplified_layer_norm_name, output_3, weight_name, output_0, io_dtype, shape=shape)
+
+    # This expansion of a contrib op can be updated / deprecated in the future.
+    def _make_skip_layer_norm(self, basename, root_input, skip_input, weight_name, bias_name, output_0, output_3, io_dtype, shape):
+        #     root_input   skip_input
+        #         |            |
+        #         +-----+------+
+        #               |
+        #              Add ----------------> output (3)
+        #               |
+        #      LayerNormalization ---------> output (0)
+        output_3 = f"{basename}/Add/output_0" if output_3 is None else output_3
+        make_add_name = f"{basename}/Add"
+        self.make_node("Add", inputs=[root_input, skip_input], outputs=[output_3], name=make_add_name)
+        self.make_value_info(output_3, io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+        make_layer_norm_name = f"{basename}/LayerNormalization"
+        inputs = [output_3, weight_name, bias_name]
+
+        kwargs = {"epsilon": self.layernorm_attrs["epsilon"]}
+        kwargs.update({"axis": -1, "stash_type": 1})
+
+        self.make_node("LayerNormalization", inputs=inputs, outputs=[output_0], name=make_layer_norm_name, **kwargs)
+        self.make_value_info(output_0, io_dtype, shape=shape)
+
+    # This expansion of a contrib op can be updated / deprecated in the future.
+    def _make_simplified_layer_norm(self, basename, root_input, weight_name, output_0, io_dtype, shape):
+        #    Cast (to float32) - most of the calculation happens in higher precision
+        #           |
+        #    +------+------+
+        #    |             |
+        #   Pow            |
+        #    |             |
+        # ReduceMean       |
+        #    |             |
+        #   Add            |
+        #    |             |
+        #   Sqrt           |
+        #    |             |
+        #   Div            |
+        #    |             |
+        #    +------+------+
+        #           |
+        #          Mul
+        #           |
+        #        Cast_1 (back to io_dtype, e.g. float16)
+        #           |
+        #         Mul_1
+
+        make_cast_name = f"{basename}/Cast"
+        self.make_cast(make_cast_name, root_input, TensorProto.FLOAT, shape=shape)
+
+        make_pow_name = f"{basename}/Pow"
+        make_pow_inputs = [f"{make_cast_name}/output_0", "/model/constants/TensorProto.FLOAT/0D/2"]
+        self.make_node("Pow", inputs=make_pow_inputs, outputs=[f"{make_pow_name}/output_0"], name=make_pow_name, domain="")
+        self.make_value_info(f"{make_pow_name}/output_0", TensorProto.FLOAT, shape=shape)
+
+        make_reducemean_name = f"{basename}/ReduceMean"
+        make_reducemean_inputs = [f"{make_pow_name}/output_0"]
+        self.make_reduce_mean(make_reducemean_name, make_reducemean_inputs, TensorProto.FLOAT, keepdims=True, axes=[-1], shape=shape)
+
+        make_add_name = f"{basename}/Add"
+        make_add_inputs = [f"{make_reducemean_name}/output_0", f"/model/constants/TensorProto.FLOAT/0D/{self.layernorm_attrs['epsilon']}"]
+        self.make_add(make_add_name, make_add_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_sqrt_name = f"{basename}/Sqrt"
+        make_sqrt_inputs = [f"{make_add_name}/output_0"]
+        self.make_sqrt(make_sqrt_name, make_sqrt_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_div_name = f"{basename}/Div"
+        make_div_inputs = ["/model/constants/TensorProto.FLOAT/0D/1", f"{make_sqrt_name}/output_0"]
+        self.make_div(make_div_name, make_div_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_mul_name = f"{basename}/Mul"
+        make_mul_inputs = [f"{make_div_name}/output_0", f"{make_cast_name}/output_0"]
+        self.make_mul(make_mul_name, make_mul_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_cast_1_name = f"{basename}/Cast_1"
+        self.make_cast(make_cast_1_name, f"{make_mul_name}/output_0", dtype=io_dtype, shape=shape)
+
+        make_mul_1_name = f"{basename}/Mul_1"
+        make_mul_1_inputs = [f"{make_cast_1_name}/output_0", weight_name]
+        self.make_node("Mul", inputs=make_mul_1_inputs, outputs=[output_0], name=make_mul_1_name)
+        self.make_value_info(output_0, dtype=io_dtype, shape=shape)
+
    def make_qk_norm(self, layer_id, attention):
        # Make subgraph to compute SimplifiedLayerNorm after Q and K MatMuls in attention:
        #
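As a sanity check on the expansion above, the Cast → Pow → ReduceMean → Add → Sqrt → Div → Mul → Cast_1 → Mul_1 chain is RMS normalization written out with primitive ops. A small numpy sketch of the same arithmetic, assuming float16 I/O and float32 internal math as in the diagram:

import numpy as np

def simplified_layer_norm_ref(x_fp16, weight_fp16, eps):
    x = x_fp16.astype(np.float32)                   # Cast
    sq = np.power(x, 2.0)                           # Pow
    mean = sq.mean(axis=-1, keepdims=True)          # ReduceMean
    rms = np.sqrt(mean + eps)                       # Add + Sqrt
    normed = (1.0 / rms) * x                        # Div + Mul
    return normed.astype(np.float16) * weight_fp16  # Cast_1 + Mul_1

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8)).astype(np.float16)
w = rng.standard_normal(8).astype(np.float16)
out = simplified_layer_norm_ref(x, w, eps=1e-5)
assert out.shape == x.shape and out.dtype == np.float16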
@@ -2190,17 +2382,47 @@ def make_activation_with_mul(self, layer_id, root_input, activation, domain):
        return mul_act_name

    def make_gelu(self, layer_id, root_input, activation):
+        # NvTensorRtRtx EP (opset 21) uses the standard "Gelu" op in place of the "Gelu" and "FastGelu" contrib ops; other EPs fall back to the contrib ops
+        if self.ep == "NvTensorRtRtx" and activation in ["Gelu", "FastGelu"]:
+            return self._make_gelu_op(layer_id, root_input, activation)
+        else:
+            return self.make_gelu_op(layer_id, root_input, activation)
+
+    def make_gelu_op(self, layer_id, root_input, activation):
        # Make nodes for this activation subgraph
        #
        #       root_input (Add)
        #            |
        #         GeluAct
        gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
        output = f"{gelu_name}/output_0"
+
        self.make_node(activation, inputs=[root_input], outputs=[output], name=gelu_name, domain="com.microsoft")
        self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size])

        return gelu_name
+
+    # This expansion of a contrib op can be updated / deprecated in the future.
+    def _make_gelu_op(self, layer_id, root_input, activation):
+        # Make nodes for this activation subgraph
+        #
+        #       root_input (Add)
+        #            |
+        #         GeluAct
+        gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
+        output = f"{gelu_name}/output_0"
+
+        # The standard "Gelu" op covers both contrib ops via the "approximate" attribute
+        if activation == "Gelu":
+            self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="none")
+        elif activation == "FastGelu":
+            self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="tanh")
+        else:
+            raise NotImplementedError(f"The {activation} activation function is not currently supported.")
+
+        self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size])
+
+        return gelu_name

    def make_relu(self, layer_id, root_input, activation):
        relu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
@@ -3447,6 +3669,9 @@ def check_extra_options(kv_pairs):
        # 'include_hidden_states' is for when 'hidden_states' are outputted and 'logits' are outputted
        raise ValueError(f"Both 'exclude_lm_head' and 'include_hidden_states' cannot be used together. Please use only one of them at once.")

+    # NvTensorRtRtx EP requires opset 21, so force use_qdq, which controls the opset choice.
+    if args.execution_provider == "NvTensorRtRtx":
+        kv_pairs["use_qdq"] = True

def parse_extra_options(kv_items):
    """
@@ -3640,7 +3865,7 @@ def get_args():
        "-e",
        "--execution_provider",
        required=True,
-        choices=["cpu", "cuda", "rocm", "dml", "webgpu"],
+        choices=["cpu", "cuda", "rocm", "dml", "webgpu", "NvTensorRtRtx"],
        help="Execution provider to target with precision of model (e.g. FP16 CUDA, INT4 CPU, INT4 WEBGPU)",
    )

@@ -3714,7 +3939,7 @@ def get_args():
    )

    args = parser.parse_args()
-    print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WEBGPU")
+    print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, FP16 NvTensorRtRtx, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WEBGPU")
    return args

if __name__ == '__main__':