
Commit fe9657c

NvTensorRtRtx EP option in GenAI - model builder (microsoft#1453)
1 parent e965694 commit fe9657c

File tree: 2 files changed (+228 -3 lines)

- README.md
- src/python/py/models/builder.py

README.md

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ See documentation at https://onnxruntime.ai/docs/genai.
 |API| Python <br/>C# <br/>C/C++ <br/> Java ^ |Objective-C||
 |Platform| Linux <br/> Windows <br/>Mac ^ <br/>Android ^ ||iOS |||
 |Architecture|x86 <br/> x64 <br/> Arm64 ~ ||||
-|Hardware Acceleration|CUDA<br/>DirectML<br/>|QNN <br/> OpenVINO <br/> ROCm ||
+|Hardware Acceleration|CUDA<br/>DirectML<br/>|QNN <br/> OpenVINO <br/> ROCm | NvTensorRtRtx |
 |Features|MultiLoRA <br/> Continuous decoding (session continuation)^ | Constrained decoding | Speculative decoding |
 
 \* The Llama model architecture supports similar model families such as CodeLlama, Vicuna, Yi, and more.

src/python/py/models/builder.py

Lines changed: 227 additions & 2 deletions
@@ -75,6 +75,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             },
             "dml": {},
             "webgpu": {},
+            "NvTensorRtRtx": {},
         }
 
         # Map input names to their types and shapes
@@ -343,6 +344,7 @@ def make_attention_init(self):
             ("dml", TensorProto.FLOAT16),
             ("webgpu", TensorProto.FLOAT16),
             ("webgpu", TensorProto.FLOAT),
+            ("NvTensorRtRtx", TensorProto.FLOAT16),
         ]
         if (self.ep, self.io_dtype) in valid_gqa_configurations:
             # Change model settings for GroupQueryAttention
@@ -757,6 +759,23 @@ def make_reduce_max(self, name, inputs, dtype, shape):
         self.make_node("ReduceMax", inputs=inputs, outputs=[output], name=name, keepdims=False)
         self.make_value_info(output, dtype, shape=shape)
 
+    def make_reduce_mean(self, name, inputs, dtype, shape, axes=[-1], keepdims=False):
+        output = f"{name}/output_0"
+        if self.quant_attrs["use_qdq"]:
+            # Opset 18 uses axes as input[1]
+            inputs.append(f"/model/constants/TensorProto.INT64/1D/{','.join(map(str, axes))}")
+            self.make_node("ReduceMean", inputs=inputs, outputs=[output], name=name, keepdims=keepdims)
+            self.make_value_info(output, dtype, shape=shape)
+        else:
+            # Opset 17 uses axes as attribute
+            self.make_node("ReduceMean", inputs=inputs, outputs=[output], name=name, axes=axes, keepdims=keepdims)
+            self.make_value_info(output, dtype, shape=shape)
+
+    def make_sqrt(self, name, inputs, dtype, shape):
+        output = f"{name}/output_0"
+        self.make_node("Sqrt", inputs=inputs, outputs=[output], name=name)
+        self.make_value_info(output, dtype, shape=shape)
+
     def make_cast(self, name, root_input, dtype, shape):
         output = f"{name}/output_0"
         self.make_node("Cast", inputs=[root_input], outputs=[output], name=name, to=dtype)
@@ -1059,6 +1078,13 @@ def make_embedding(self, embedding):
         self.layernorm_attrs["skip_input"] = layernorm_attrs_value
 
     def make_layernorm(self, layer_id, layernorm, skip, simple, location):
+        if self.ep == "NvTensorRtRtx" and (skip or simple):
+            # NvTensorRtRtx EP doesn't support Skip/SimplifiedLayerNormalization and SkipLayerNormalization, so we fall back to primitive ops
+            self._make_layernorm_op(layer_id, layernorm, skip, simple, location)
+        else:
+            self.make_layernorm_op(layer_id, layernorm, skip, simple, location)
+
+    def make_layernorm_op(self, layer_id, layernorm, skip, simple, location):
         root_input = self.layernorm_attrs["root_input"]
         skip_input = self.layernorm_attrs["skip_input"]

@@ -1112,6 +1138,68 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):
             # Assign output 3 of current SkipLayerNorm as root input to next SkipLayerNorm
             self.layernorm_attrs["root_input"] = output_3
 
+    def _make_layernorm_op(self, layer_id, layernorm, skip, simple, location):
+        root_input = self.layernorm_attrs["root_input"]
+        skip_input = self.layernorm_attrs["skip_input"]
+
+        # Get precision types to use
+        old_torch_dtype = self.to_torch_dtype[self.io_dtype]
+        old_io_dtype = self.io_dtype
+        new_torch_dtype = torch.float32 if self.layernorm_attrs["cast"]["use_fp32"] else self.to_torch_dtype[self.io_dtype]
+        new_io_dtype = self.to_onnx_dtype[new_torch_dtype]
+        cast = old_torch_dtype != new_torch_dtype
+
+        # Create weight and bias tensors
+        weight = f"model.layers.{layer_id}.{location}_layernorm.weight"
+        self.make_external_tensor((layernorm.weight.detach().cpu().to(new_torch_dtype) + self.layernorm_attrs["add_offset"]).contiguous(), weight)
+        bias = f"model.layers.{layer_id}.{location}_layernorm.bias"
+        if not simple:
+            self.make_external_tensor(layernorm.bias.detach().cpu().to(new_torch_dtype).contiguous(), bias)
+
+        # Create input names for op
+        inputs = [root_input, skip_input, weight] if skip else [root_input, weight]
+        if not simple:
+            inputs.append(bias)
+
+        name = f"/model/layers.{layer_id}/{location}_layernorm/{'Skip' if skip else ''}LayerNorm"
+        op_type = f"{'Skip' if skip else ''}{'Simplified' if simple else ''}LayerNormalization"
+        kwargs = {"epsilon": self.layernorm_attrs["epsilon"]}
+        if not skip:
+            kwargs.update({"axis": -1, "stash_type": 1})
+
+        # Create output names for op
+        output_0 = f"/model/layers.{layer_id}/{location}_layernorm/output_0"
+        output_3 = f"/model/layers.{layer_id}/{location}_layernorm/output_3"
+        if self.layernorm_attrs["last_layernorm"] and (self.include_hidden_states or self.exclude_lm_head):
+            output_0 = "hidden_states"
+        outputs = [output_0, "", "", output_3] if skip and not self.layernorm_attrs["last_layernorm"] else [output_0]
+
+        # Create Cast nodes for inputs and outputs if old_dtype != new_dtype
+        if cast:
+            inputs, outputs = self.make_layernorm_casts(name, inputs, outputs, old_io_dtype, new_io_dtype)
+            root_input = inputs[0]
+            skip_input = inputs[1] if skip else None
+
+        if op_type == "SimplifiedLayerNormalization":
+            self._make_simplified_layer_norm(name, root_input, weight, outputs[0], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+        elif op_type == "SkipSimplifiedLayerNormalization":
+            self._make_skip_simplified_layer_norm(name, root_input, skip_input, weight, outputs[0], output_3, new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+        elif op_type == "SkipLayerNormalization":
+            self._make_skip_layer_norm(name, root_input, skip_input, weight, bias, outputs[0], output_3, new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+        else:
+            raise ValueError(f"Invalid op_type: {op_type}")
+
+        if skip and not self.layernorm_attrs["last_layernorm"]:
+            self.make_value_info(outputs[3], new_io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+        # Update LayerNorm attributes
+        self.layernorm_attrs["output_0"] = output_0
+        if skip and not self.layernorm_attrs["last_layernorm"]:
+            self.layernorm_attrs["output_3"] = output_3
+
+            # Assign output 3 of current SkipLayerNorm as root input to next SkipLayerNorm
+            self.layernorm_attrs["root_input"] = output_3
+
     def make_layernorm_casts(self, name, inputs, outputs, old_dtype, new_dtype):
         # Name = name of original LayerNorm op as if the cast nodes did not exist
         # Inputs = inputs into the original LayerNorm op as if the cast nodes did not exist
@@ -1354,6 +1442,110 @@ def make_rotary_embedding_multi_cache(self, **kwargs):
         self.make_value_info(cos_cache_name, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"])
         self.make_value_info(sin_cache_name, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"])
 
+    # This expansion of contrib-op can be updated / deprecated in future.
+    def _make_skip_simplified_layer_norm(self, basename, root_input, skip_input, weight_name, output_0, output_3, io_dtype, shape):
+        #      root_input   skip_input
+        #           |            |
+        #           +------------+
+        #                  |
+        #                 Add ----------------> output (1)
+        #                  |
+        #        SimplifiedLayerNorm ---------> output (0)
+        make_add_name = f"{basename}/Add"
+        output_3 = f"{basename}/Add/output_0" if output_3 is None else output_3
+        self.make_node("Add", inputs=[root_input, skip_input], outputs=[output_3], name=make_add_name)
+        self.make_value_info(output_3, io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+        make_simplified_layer_norm_name = f"{basename}/skip_simplified_layer_norm"
+        self._make_simplified_layer_norm(make_simplified_layer_norm_name, output_3, weight_name, output_0, io_dtype, shape=shape)
+
+    # This expansion of contrib-op can be updated / deprecated in future.
+    def _make_skip_layer_norm(self, basename, root_input, skip_input, weight_name, bias_name, output_0, output_3, io_dtype, shape):
+        #      root_input   skip_input
+        #           |            |
+        #           +------------+
+        #                  |
+        #                 Add ----------------> output (1)
+        #                  |
+        #        LayerNormalization ----------> output (0)
+        output_3 = f"{basename}/Add/output_0" if output_3 is None else output_3
+        make_add_name = f"{basename}/Add"
+        self.make_node("Add", inputs=[root_input, skip_input], outputs=[output_3], name=make_add_name)
+        self.make_value_info(output_3, io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+        make_layer_norm_name = f"{basename}/LayerNormalization"
+        inputs = [output_3, weight_name, bias_name]
+
+        kwargs = {"epsilon": self.layernorm_attrs["epsilon"]}
+        kwargs.update({"axis": -1, "stash_type": 1})
+
+        self.make_node("LayerNormalization", inputs=inputs, outputs=[output_0], name=make_layer_norm_name, **kwargs)
+        self.make_value_info(output_0, io_dtype, shape=shape)
+
+    # This expansion of contrib-op can be updated / deprecated in future.
+    def _make_simplified_layer_norm(self, basename, root_input, weight_name, output_0, io_dtype, shape):
+
+        #  Cast (float32) - most calc happens in higher precision
+        #        |
+        #  +-----+------+
+        #  |            |
+        # Pow           |
+        #  |            |
+        # ReduceMean    |
+        #  |            |
+        # Add           |
+        #  |            |
+        # Sqrt          |
+        #  |            |
+        # Div           |
+        #  |            |
+        #  +-----+------+
+        #        |
+        #       Mul
+        #        |
+        #      Cast_1 (io_dtype - float16)
+        #        |
+        #      Mul_1
+
+        make_cast_name = f"{basename}/Cast"
+        self.make_cast(make_cast_name, root_input, TensorProto.FLOAT, shape=shape)
+
+        make_pow_name = f"{basename}/Pow"
+        make_pow_inputs = [f"{make_cast_name}/output_0", f"/model/constants/TensorProto.FLOAT/0D/2"]
+
+        self.make_node("Pow", inputs=make_pow_inputs, outputs=[f"{make_pow_name}/output_0"], name=make_pow_name, domain="")
+        self.make_value_info(f"{make_pow_name}/output_0", TensorProto.FLOAT, shape=shape)
+
+        make_reducemean_name = f"{basename}/ReduceMean"
+        make_reducemean_inputs = [f"{make_pow_name}/output_0"]
+        self.make_reduce_mean(make_reducemean_name, make_reducemean_inputs, TensorProto.FLOAT, keepdims=True, axes=[-1], shape=shape)
+
+        make_add_name = f"{basename}/Add"
+        make_add_inputs = [f"{make_reducemean_name}/output_0", f"/model/constants/TensorProto.FLOAT/0D/{self.layernorm_attrs['epsilon']}"]
+        self.make_add(make_add_name, make_add_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_sqrt_name = f"{basename}/Sqrt"
+        make_sqrt_inputs = [f"{make_add_name}/output_0"]
+        self.make_sqrt(make_sqrt_name, make_sqrt_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_div_name = f"{basename}/Div"
+        make_div_inputs = [f"/model/constants/TensorProto.FLOAT/0D/1", f"{make_sqrt_name}/output_0"]
+        self.make_div(make_div_name, make_div_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_mul_name = f"{basename}/Mul"
+        make_mul_inputs = [f"{make_div_name}/output_0", f"{make_cast_name}/output_0"]
+        self.make_mul(make_mul_name, make_mul_inputs, TensorProto.FLOAT, shape=shape)
+
+        make_cast_1_name = f"{basename}/Cast_1"
+        self.make_cast(make_cast_1_name, f"{make_mul_name}/output_0", dtype=io_dtype, shape=shape)
+
+        make_mul_1_name = f"{basename}/Mul_1"
+        make_mul_1_inputs = [f"{make_cast_1_name}/output_0", weight_name]
+
+        self.make_node("Mul", inputs=make_mul_1_inputs, outputs=[output_0], name=make_mul_1_name)
+        self.make_value_info(output_0, dtype=io_dtype, shape=shape)
+
+
     def make_qk_norm(self, layer_id, attention):
         # Make subgraph to compute SimplifiedLayerNorm after Q and K MatMuls in attention:
         #
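
As a sanity check on the expansions above: `_make_simplified_layer_norm` builds an RMS normalization, i.e. x / sqrt(mean(x², axis=-1) + ε) scaled by the weight, with the normalization computed in float32 before casting back to the model's I/O dtype, while the Skip variants add the residual first and expose the sum as an extra output. A minimal NumPy reference of that math (a sketch, not builder code; shapes and ε are illustrative):

```python
# NumPy reference for the primitive-op expansion of (Skip)SimplifiedLayerNormalization.
import numpy as np

def simplified_layer_norm(x, weight, eps):
    # Mirrors the Cast -> Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul -> Cast_1 -> Mul_1 chain.
    x32 = x.astype(np.float32)                                        # Cast to float32
    inv_rms = 1.0 / np.sqrt(np.mean(x32 ** 2, axis=-1, keepdims=True) + eps)
    return (x32 * inv_rms).astype(x.dtype) * weight                   # cast back, then scale by weight

def skip_simplified_layer_norm(x, skip, weight, eps):
    # The Skip variant adds the residual first; the sum is also returned (output 3).
    added = x + skip
    return simplified_layer_norm(added, weight, eps), added

x = np.random.randn(1, 4, 8).astype(np.float16)
residual = np.random.randn(1, 4, 8).astype(np.float16)
w = np.ones(8, dtype=np.float16)
y, skip_out = skip_simplified_layer_norm(x, residual, w, eps=1e-5)
```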
@@ -2190,17 +2382,47 @@ def make_activation_with_mul(self, layer_id, root_input, activation, domain):
         return mul_act_name
 
     def make_gelu(self, layer_id, root_input, activation):
+        # NvTensorRtRtx (Opset 21) uses standard "Gelu" replacing "Gelu" & "FastGelu" contrib ops, otherwise fallback to contrib ops
+        if self.ep == "NvTensorRtRtx" and activation in ["Gelu", "FastGelu"]:
+            return self._make_gelu_op(layer_id, root_input, activation)
+        else:
+            return self.make_gelu_op(layer_id, root_input, activation)
+
+    def make_gelu_op(self, layer_id, root_input, activation):
         # Make nodes for this activation subgraph
         #
         #       root_input (Add)
         #            |
         #         GeluAct
         gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
         output = f"{gelu_name}/output_0"
+
         self.make_node(activation, inputs=[root_input], outputs=[output], name=gelu_name, domain="com.microsoft")
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size])
 
         return gelu_name
+
+    # This expansion of contrib-op can be updated / deprecated in future.
+    def _make_gelu_op(self, layer_id, root_input, activation):
+        # Make nodes for this activation subgraph
+        #
+        #       root_input (Add)
+        #            |
+        #         GeluAct
+        gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
+        output = f"{gelu_name}/output_0"
+
+        # NvTensorRtRtx (Opset 21) uses standard "Gelu" replacing "Gelu" & "FastGelu" contrib ops, otherwise fallback to contrib ops
+        if activation == "Gelu":
+            self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="none")
+        elif activation == "FastGelu":
+            self.make_node("Gelu", inputs=[root_input], outputs=[output], name=gelu_name, approximate="tanh")
+        else:
+            raise NotImplementedError(f"The {activation} activation function is not currently supported.")
+
+        self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size])
+
+        return gelu_name
 
     def make_relu(self, layer_id, root_input, activation):
         relu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
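
For reference, the standard ONNX `Gelu` op (introduced in opset 20) folds both contrib activations into one op via its `approximate` attribute: `approximate="none"` is the exact erf form (the contrib `Gelu`) and `approximate="tanh"` is the tanh approximation (the contrib `FastGelu`). A small NumPy sketch of the two formulas, illustrative only:

```python
# The two formulations selected by the standard Gelu op's "approximate" attribute.
import math
import numpy as np

def gelu_none(x):
    # approximate="none": 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # approximate="tanh": 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

x = np.linspace(-4.0, 4.0, 9)
print(np.max(np.abs(gelu_none(x) - gelu_tanh(x))))  # the approximation error stays small
```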
@@ -3447,6 +3669,9 @@ def check_extra_options(kv_pairs):
         # 'include_hidden_states' is for when 'hidden_states' are outputted and 'logits' are outputted
         raise ValueError(f"Both 'exclude_lm_head' and 'include_hidden_states' cannot be used together. Please use only one of them at once.")
 
+    # NvTensorRtRtx EP requires Opset 21, so force use_qdq which controls it.
+    if args.execution_provider == "NvTensorRtRtx":
+        kv_pairs["use_qdq"] = True
 
 def parse_extra_options(kv_items):
     """
@@ -3640,7 +3865,7 @@ def get_args():
         "-e",
         "--execution_provider",
         required=True,
-        choices=["cpu", "cuda", "rocm", "dml", "webgpu"],
+        choices=["cpu", "cuda", "rocm", "dml", "webgpu", "NvTensorRtRtx"],
         help="Execution provider to target with precision of model (e.g. FP16 CUDA, INT4 CPU, INT4 WEBGPU)",
     )

@@ -3714,7 +3939,7 @@ def get_args():
     )
 
     args = parser.parse_args()
-    print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WEBGPU")
+    print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, FP16 NvTensorRtRtx, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WEBGPU")
     return args
 
 if __name__ == '__main__':
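
With these changes, the new EP is selected through the builder's `-e/--execution_provider` flag. A hypothetical invocation for reference (the `-m`, `-o`, and `-p` flags for model name, output directory, and precision are the builder's usual options and are assumed here; they are not part of this diff):

```
python src/python/py/models/builder.py -m <hf_model_name> -o <output_dir> -p fp16 -e NvTensorRtRtx
```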
