
Commit b4d25ad

lanluo-nvidia, Chen Fu, and cehongwang authored
broadcast_remove - cherry pick 3700 (#3757)
Co-authored-by: Chen Fu <[email protected]>
Co-authored-by: cehongwang <[email protected]>
1 parent 0fa67cc commit b4d25ad

File tree

8 files changed: +67 -13 lines changed


examples/apps/flux_demo.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def forward_loop(mod):
         "enabled_precisions": enabled_precisions,
         "truncate_double": True,
         "min_block_size": 1,
-        "use_python_runtime": False,
+        "use_python_runtime": True,
         "immutable_weights": False,
         "offload_module_to_cpu": args.low_vram_mode,
         "use_explicit_typing": use_explicit_typing,

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 0 deletions
@@ -533,6 +533,7 @@ def aten_ops_gelu(


 @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 8 additions & 2 deletions
@@ -1,8 +1,8 @@
+import logging
 import operator
 import warnings
 from typing import Any, Callable, Optional, Union

-import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -20,6 +20,8 @@
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor

+logger = logging.getLogger(__name__)
+

 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -148,7 +150,11 @@ def convert_binary_elementwise(
         ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
     )

-    if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
+    if len(lhs_val.shape) == len(rhs_val.shape) and all(
+        a == b or a == 1 or b == 1 for a, b in zip(lhs_val.shape, rhs_val.shape)
+    ):
+        logger.info(f"skip broadcast for {name}")
+    elif has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
        lhs_val, rhs_val = broadcast(
            ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
        )
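
The new guard skips the explicit broadcast() call only when both operands already share a rank and every dimension pair is equal or 1, a case TensorRT's elementwise layer can broadcast implicitly. A minimal standalone sketch of the condition (the shapes below are illustrative, not taken from this commit):

def can_skip_broadcast(lhs_shape, rhs_shape):
    # Same rank, and each dimension pair is equal or one side is 1.
    return len(lhs_shape) == len(rhs_shape) and all(
        a == b or a == 1 or b == 1 for a, b in zip(lhs_shape, rhs_shape)
    )

print(can_skip_broadcast((2, 1, 4), (2, 3, 4)))  # True: broadcast() is skipped
print(can_skip_broadcast((3, 4), (2, 3, 4)))     # False: ranks differ, broadcast() still runs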

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@
     aten.upsample_trilinear3d.vec,
     aten.upsample_bicubic2d.vec,
     aten.linear.default,
+    aten.matmul.default,
 }

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 12 additions & 8 deletions
@@ -9,7 +9,6 @@
     _get_decomp_for_cia,
 )
 from torch._ops import OpOverload
-
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.utils import to_torch_device
@@ -423,8 +422,8 @@ def instance_norm_decomposition(

 @register_torch_trt_decomposition(
     torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
-)  # type: ignore
-def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
+)
+def full_like_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     input = args[0]
     shape = args[0].shape
     fill_value = args[1]
@@ -454,11 +453,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)

     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

     if attn_mask is not None:
@@ -471,17 +472,20 @@ def scaled_dot_product_attention_decomposition(
     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))

     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
         attn_weight = attn_weight / scale
     else:
         attn_weight = attn_weight * scale

-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        # Only add attn_bias when we have to; otherwise it hurts performance even when it is 0.
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)


 @register_torch_trt_decomposition(
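
Taken together, the decomposition now builds attention entirely from torch.matmul and materializes attn_bias only when is_causal or an attn_mask can actually change the result. A simplified, standalone sketch of the resulting computation (not the exact source; the masking and scaling follow the standard SDPA reference formulation):

import math
import torch

def sdpa_sketch(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # Matmul-based attention; the bias is created and added only when needed.
    L, S = query.size(-2), key.size(-2)
    attn_weight = torch.matmul(query, key.transpose(-2, -1))
    attn_weight = attn_weight * (scale if scale is not None else 1.0 / math.sqrt(query.size(-1)))
    if is_causal or attn_mask is not None:
        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=query.device)
        if is_causal:
            keep = torch.ones((L, S), dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias = attn_bias.masked_fill(keep.logical_not(), float("-inf"))
        if attn_mask is not None:
            attn_bias = attn_bias + attn_mask
        attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)

q = k = v = torch.randn(1, 8, 16, 64)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(sdpa_sketch(q, k, v, is_causal=True), ref, atol=1e-5))  # expected: True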

py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,18 @@


 def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Splits all `torch.ops.aten.addmm.default` nodes in the FX graph into separate
+    `add` and `mm` nodes. This is useful for passes that want to insert additional
+    logic (such as FP32 accumulation) specifically around the matrix multiplication
+    operation, rather than the fused addmm.
+
+    Args:
+        gm (torch.fx.GraphModule): The FX graph module to transform.
+
+    Returns:
+        torch.fx.GraphModule: The modified FX graph module with addmm nodes split.
+    """
     target = torch.ops.aten.addmm.default
     addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
     for addmm_node in addmm_nodes:
@@ -52,6 +64,7 @@ def accumulate_fp32_matmul(
     matmul_targets = [
         torch.ops.aten.mm.default,
         torch.ops.aten.bmm.default,
+        torch.ops.aten.matmul.default,
     ]

     # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
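
split_addmm_nodes itself appears here only through its new docstring. Below is a hedged, self-contained FX sketch of the same idea; it traces plain torch.addmm rather than the lowered aten.addmm.default nodes the real pass targets, but it shows how an addmm is rewritten into add(bias, mm(a, b)) so later logic can wrap just the mm node:

import torch
import torch.fx

class AddmmModule(torch.nn.Module):
    def forward(self, bias, a, b):
        return torch.addmm(bias, a, b)

gm = torch.fx.symbolic_trace(AddmmModule())

# Rewrite each addmm(bias, a, b) into add(bias, mm(a, b)).
for node in list(gm.graph.nodes):
    if node.target is torch.addmm:
        bias, mat1, mat2 = node.args
        with gm.graph.inserting_before(node):
            mm = gm.graph.call_function(torch.mm, (mat1, mat2))
            out = gm.graph.call_function(torch.add, (bias, mm))
        node.replace_all_uses_with(out)
        gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()

x = torch.randn(4, 4)
print(torch.allclose(gm(x, x, x), torch.addmm(x, x, x)))  # True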

tools/perf/Flux/benchmark.sh

Lines changed: 12 additions & 1 deletion
@@ -1,9 +1,20 @@
 #TODO: Enter the HF Token
 huggingface-cli login --token HF_TOKEN

+nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> pytorch_fp16_gpu_utilization.txt &
+NVIDIA_SMI_PID=$!
+python flux_perf.py --pytorch --max_batch_size 3 > pytorch_fp16_benchmark.txt
+kill $NVIDIA_SMI_PID
+
 nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp8_gpu_utilization.txt &
 NVIDIA_SMI_PID=$!
-python flux_perf.py --dtype fp8 --low_vram_mode> fp8_benchmark.txt
+python flux_perf.py --dtype fp8 --max_batch_size 3 > fp8_benchmark.txt
+kill $NVIDIA_SMI_PID
+
+
+nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp16_gpu_utilization.txt &
+NVIDIA_SMI_PID=$!
+python flux_perf.py --dtype fp16 --max_batch_size 3 > fp16_benchmark.txt
 kill $NVIDIA_SMI_PID


tools/perf/Flux/flux_perf.py

Lines changed: 19 additions & 1 deletion
@@ -44,9 +44,22 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
     return


+from diffusers import FluxPipeline
+
+
 def main(args):
     print(f"Running flux_perfwith args: {args}")
-    pipe, backbone, trt_gm = compile_model(args)
+    if not args.pytorch:
+        pipe, backbone, trt_gm = compile_model(args)
+    else:
+        pipe = (
+            FluxPipeline.from_pretrained(
+                "black-forest-labs/FLUX.1-dev",
+                torch_dtype=torch.float16,
+            )
+            .to(torch.float16)
+            .to("cuda:0")
+        )

     benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3)

@@ -77,6 +90,11 @@ def main(args):
         action="store_true",
         help="Use dynamic shapes",
     )
+    parser.add_argument(
+        "--pytorch",
+        action="store_true",
+        help="Use pytorch runtime and no tensorrt",
+    )
     parser.add_argument("--max_batch_size", type=int, default=1)
     args = parser.parse_args()
     main(args)
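
With the new flag, a pure-PyTorch fp16 baseline can be collected alongside the Torch-TensorRT fp8/fp16 runs, as the updated benchmark.sh does with `python flux_perf.py --pytorch --max_batch_size 3`.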
