
Low precision of TensorRT 10.10.0.31 when running the pow function on GPU A100 #4465

@Alireza3242

Description

I want to convert and build an LLM model with tensorrt_llm. One of the layers is RmsNorm, which uses this function:

def rms_norm(_class, input: Tensor,
             normalized_shape: Union[int, Tuple[int]],
             num_groups: int = 1,
             weight: Optional[Tensor] = None,
             eps: float = 1e-06) -> Tensor:
    '''
    Add a RMS norm operation on a tensor.

    That operation applies the rms-normalization to its input tensor. In its
    simplest form, for large language models, the 'normalized_shape' should be
    set to the hidden dimension of the activation tensor. Otherwise, it is the
    shape of the normalized fraction of the tensor (starting from the
    right-most dimension).

    The 'weight' tensor corresponds to 'gamma' in the rms-norm formula.
    The 'eps' value is added to the variance before computing the square root.

    Parameters:
        input: Tensor
            The tensor to normalize.

        normalized_shape : Union[int, Tuple[int]]
            The shape of the sub-tensor that is normalized. Use 'hidden_dim' to
            normalize the inner-most dimension of an activation tensor in LLMs.

        num_groups: int = 1
            The number of groups.

        weight : Optional[Tensor] = None
            The 'gamma' term in the rms-norm formula. Its shape must be
            'normalized_shape'.

        eps : float
            The epsilon term added to the variance before computing the square root.

    Returns:
        The output tensor of that operation.
    '''
    normalized_shape = [normalized_shape] if isinstance(
        normalized_shape, int) else normalized_shape

    dim = tuple([-i - 1 for i in range(len(normalized_shape))])

    if num_groups > 1:
        assert len(normalized_shape) == 1
        num_channels = input.size()[-1]
        ndim = input.ndim()
        old_shape = shape(input)
        new_shape = concat([input.size(i) for i in range(ndim - 1)] +
                           [num_groups, num_channels // num_groups])
        input = input.view(new_shape)

    with precision("float32"):
        input_dtype = input.dtype
        fp32_input = cast(input, "float32")
        _class.register_network_output('fp32_input', fp32_input)
        varx = pow(fp32_input, 2.0)
        _class.register_network_output('varx1', varx)
        varx = varx.mean(dim=dim, keepdim=True)
        _class.register_network_output('varx2', varx)
        denom = varx + eps
        denom = denom.sqrt()
        fp32_y = fp32_input / denom
        
        if num_groups > 1:
            fp32_y = fp32_y.view(old_shape)
            
        if weight is not None:
            fp32_y = fp32_y * weight
        y = cast(fp32_y, input_dtype)

    return y

Pay attention to this line:
varx = pow(fp32_input, 2.0)

I compared the values of fp32_input and varx with the corresponding ones from PyTorch. The value of fp32_input matches the one in PyTorch, but varx is different.
Here, pow is defined as follows:
pow = partial(elementwise_binary, op=trt.ElementWiseOperation.POW)
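
For reference, a possible way to sidestep the POW elementwise op (a minimal sketch, not verified here) is to square by multiplication; the '*' operator between Tensors is already used further down in rms_norm (fp32_y = fp32_y * weight), so this stays within the same API:

# Possible workaround sketch: square via an elementwise multiply instead of POW.
varx = fp32_input * fp32_input   # instead of: varx = pow(fp32_input, 2.0)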

PyTorch values:

x.to(dtype=torch.float64).sum().item()
# 143.05388989299536

x.pow(2).to(dtype=torch.float64).sum().item()
# 36695.1599856188

TensorRT values:

self.debug_buffer["transformer.layers.0.input_layernorm.fp32_input"].to(dtype=torch.float64).sum().item()
# 143.05388989299536

self.debug_buffer["transformer.layers.0.input_layernorm.varx1"].to(dtype=torch.float64).sum().item()
# 36694.23137744599

As can be seen, the outputs are not identical. As the network runs, this error accumulates, gradually increasing the divergence between PyTorch and TensorRT and eventually producing different final outputs.
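
For illustration only, and only as a guess about where the gap could come from (a plain-PyTorch sketch, not a claim about how TensorRT actually implements its POW kernel): in float32, a pow routed through exp(y * log(x)) drifts from an exact multiply by a few ULPs per element, and the per-element differences become visible once many elements are summed:

import torch

torch.manual_seed(0)
x = torch.rand(4096, dtype=torch.float32) + 0.5   # positive values so log() is defined

exact = (x * x).to(torch.float64).sum().item()             # squaring by multiplication
via_pow = x.pow(2).to(torch.float64).sum().item()          # torch's pow
via_explog = torch.exp(2.0 * torch.log(x)).to(torch.float64).sum().item()

print(exact, via_pow, via_explog)
# The exp/log route typically differs from the multiply in the trailing digits,
# and the per-element error adds up as the number of summed elements grows.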

Environment

root@eaca564d5aff:/opt/tritonserver# pip list |grep tensorrt
tensorrt                           10.10.0.31
tensorrt_cu12                      10.10.0.31
tensorrt_cu12_bindings             10.10.0.31
tensorrt_cu12_libs                 10.10.0.31
tensorrt_llm                       0.20.0rc3

NVIDIA GPU: A100

NVIDIA Driver Version: 570.124.06

CUDA Version: 12.9

PyTorch Version (if applicable): 2.6.0+xpu
