     _get_decomp_for_cia,
 )
 from torch._ops import OpOverload
-
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.utils import to_torch_device
@@ -423,8 +422,8 @@ def instance_norm_decomposition(
 
 @register_torch_trt_decomposition(
     torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
-)  # type: ignore
-def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
+)
+def full_like_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     input = args[0]
     shape = args[0].shape
     fill_value = args[1]
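Illustration (not part of this diff): a minimal sketch of the behavior the rewritten full_like decomposition is expected to preserve, assuming the aten.full_like(input, fill_value) schema. The helper name full_like_reference is hypothetical, and the registered decomposition may additionally handle dtype/device kwargs.

import torch

def full_like_reference(inp: torch.Tensor, fill_value: float) -> torch.Tensor:
    # Sketch: lower full_like to a plain full() over the input's shape,
    # keeping the input's dtype and device (default-kwarg case only).
    return torch.full(tuple(inp.shape), fill_value, dtype=inp.dtype, device=inp.device)

x = torch.randn(2, 3)
assert torch.equal(torch.full_like(x, 7.0), full_like_reference(x, 7.0))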
@@ -454,11 +453,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)
 
     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
 
     if attn_mask is not None:
@@ -471,17 +472,20 @@ def scaled_dot_product_attention_decomposition(
     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
 
-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))
 
     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
         attn_weight = attn_weight / scale
     else:
         attn_weight = attn_weight * scale
 
-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        # Only add attn_bias when it is actually needed; adding it unconditionally hurts performance even when it is all zeros.
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)
 
 
 @register_torch_trt_decomposition(
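Illustration (not part of this diff): a sketch checking that the decomposed math path with attn_mask=None and is_causal=False (the branch that now skips attn_bias entirely) matches the eager scaled_dot_product_attention reference. The function name sdpa_decomposed is hypothetical, and the scale is written directly as 1/sqrt(E) rather than via torch.scalar_tensor as in the patch; the two are numerically equivalent for the default-scale case.

import math

import torch
import torch.nn.functional as F

def sdpa_decomposed(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # No attn_bias term here: is_causal=False and attn_mask is None.
    attn_weight = torch.matmul(query, key.transpose(-2, -1))
    attn_weight = attn_weight / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)

q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
assert torch.allclose(
    sdpa_decomposed(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-5
)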