-
Notifications
You must be signed in to change notification settings - Fork 169
[5455919] Insert cast nodes for 'FP32 required' ops #363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
da24fe6
fe837a8
1ec0af0
1a52354
3ea85b7
8a0b1d4
890b12e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -367,10 +367,12 @@ def interpret_trt_plugins_precision_flag( | |||||||||||||||||||||||||||||||
if trt_plugin_precision.count(":") == 1: | ||||||||||||||||||||||||||||||||
if precision not in supported_precisions: | ||||||||||||||||||||||||||||||||
logger.warning(f"Precision {precision} is not supported. Skipping.") | ||||||||||||||||||||||||||||||||
if precision == "fp16": | ||||||||||||||||||||||||||||||||
custom_ops_to_cast[op_type] = { | ||||||||||||||||||||||||||||||||
"inp": list(range(num_inps)), | ||||||||||||||||||||||||||||||||
"out": list(range(num_outs)), | ||||||||||||||||||||||||||||||||
if precision in ["fp16", "fp32"]: | ||||||||||||||||||||||||||||||||
custom_ops_to_cast[precision] = { | ||||||||||||||||||||||||||||||||
op_type: { | ||||||||||||||||||||||||||||||||
"inp": list(range(num_inps)), | ||||||||||||||||||||||||||||||||
"out": list(range(num_outs)), | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||
if precision in ["int8", "fp8"]: | ||||||||||||||||||||||||||||||||
if precision != quantize_mode: | ||||||||||||||||||||||||||||||||
|
@@ -408,10 +410,14 @@ def interpret_trt_plugins_precision_flag( | |||||||||||||||||||||||||||||||
f"Setting the custom op precision to be the same as quantize mode." | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# Will cast the inputs to FP16 and the outputs back to FP32 | ||||||||||||||||||||||||||||||||
inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == "fp16"] | ||||||||||||||||||||||||||||||||
out_precision_cast = [i for i, p in enumerate(out_precision) if p in ["fp16", "fp32"]] | ||||||||||||||||||||||||||||||||
custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast} | ||||||||||||||||||||||||||||||||
# Will cast the inputs to FP16/FP32 and the outputs back to FP32 | ||||||||||||||||||||||||||||||||
for precision in ["fp16", "fp32"]: | ||||||||||||||||||||||||||||||||
inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision] | ||||||||||||||||||||||||||||||||
out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision] | ||||||||||||||||||||||||||||||||
if inp_precision_cast: | ||||||||||||||||||||||||||||||||
custom_ops_to_cast[precision] = { | ||||||||||||||||||||||||||||||||
op_type: {"inp": inp_precision_cast, "out": out_precision_cast} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Comment on lines
+413
to
421
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Drops output-only casting and overwrites maps.
- # Will cast the inputs to FP16/FP32 and the outputs back to FP32
- for precision in ["fp16", "fp32"]:
- inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
- out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
- if inp_precision_cast:
- custom_ops_to_cast[precision] = {
- op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
- }
+ # Will cast requested inputs to FP16/FP32 and outputs back to FP32
+ for precision in ["fp16", "fp32"]:
+ inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
+ out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
+ if inp_precision_cast or out_precision_cast:
+ ops_map = custom_ops_to_cast.setdefault(precision, {})
+ ops_map[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast} 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||
# Will add Q/DQ nodes in the requested I/O indices | ||||||||||||||||||||||||||||||||
inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]] | ||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.