Description
I want to convert and build an LLM model with tensorrt_llm. One of the layers is RmsNorm, which uses this function:
```python
def rms_norm(_class,
             input: Tensor,
             normalized_shape: Union[int, Tuple[int]],
             num_groups: int = 1,
             weight: Optional[Tensor] = None,
             eps: float = 1e-06) -> Tensor:
    '''
    Add a RMS norm operation on a tensor.

    That operation applies the rms-normalization to its input tensor. In its
    simplest form, for large language models, the 'normalized_shape' should be
    set to the hidden dimension of the activation tensor. Otherwise, it is the
    shape of the normalized fraction of the tensor (starting from the
    right-most dimension).

    The 'weight' tensor corresponds to 'gamma' in the rms-norm formula.
    The 'eps' value is added to the variance before computing the square root.

    Parameters:
        input : Tensor
            The tensor to normalize.

        normalized_shape : Union[int, Tuple[int]]
            The shape of the sub-tensor that is normalized. Use 'hidden_dim'
            to normalize the inner-most dimension of an activation tensor in
            LLMs.

        num_groups : int = 1
            The group size.

        weight : Optional[Tensor] = None
            The 'gamma' term in layer-norm. Its shape must be
            'normalized_shape'.

        eps : float
            The epsilon term to be added to the variance in the square root.

    Returns:
        The output tensor of that operation.
    '''
    normalized_shape = [normalized_shape] if isinstance(
        normalized_shape, int) else normalized_shape
    dim = tuple([-i - 1 for i in range(len(normalized_shape))])

    if num_groups > 1:
        assert len(normalized_shape) == 1
        num_channels = input.size()[-1]
        ndim = input.ndim()
        old_shape = shape(input)
        new_shape = concat([input.size(i) for i in range(ndim - 1)] +
                           [num_groups, num_channels // num_groups])
        input = input.view(new_shape)

    with precision("float32"):
        input_dtype = input.dtype
        fp32_input = cast(input, "float32")
        _class.register_network_output('fp32_input', fp32_input)
        varx = pow(fp32_input, 2.0)
        _class.register_network_output('varx1', varx)
        varx = varx.mean(dim=dim, keepdim=True)
        _class.register_network_output('varx2', varx)
        denom = varx + eps
        denom = denom.sqrt()
        fp32_y = fp32_input / denom
        if num_groups > 1:
            fp32_y = fp32_y.view(old_shape)
        if weight is not None:
            fp32_y = fp32_y * weight
        y = cast(fp32_y, input_dtype)

    return y
```
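For parity checking, here is a minimal PyTorch re-implementation of the same float32 computation. This is my own sketch for debugging, not code from either library, and it covers only the `num_groups == 1` path:

```python
from typing import Optional

import torch


def rms_norm_ref(x: torch.Tensor,
                 weight: Optional[torch.Tensor] = None,
                 eps: float = 1e-6) -> torch.Tensor:
    # Upcast to float32, mirroring the TensorRT-LLM code above.
    input_dtype = x.dtype
    fp32_x = x.to(torch.float32)
    # Mean of squares over the last (hidden) dimension, kept for broadcasting.
    varx = fp32_x.pow(2).mean(dim=-1, keepdim=True)
    fp32_y = fp32_x / torch.sqrt(varx + eps)
    if weight is not None:
        fp32_y = fp32_y * weight
    return fp32_y.to(input_dtype)
```

Comparing each intermediate of this reference against the registered network outputs (`fp32_input`, `varx1`, `varx2`) is how I isolated the divergence to the squaring step.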
Pay attention to this line:

```python
varx = pow(fp32_input, 2.0)
```
I compared the values of fp32_input and varx with the corresponding tensors from PyTorch. fp32_input matches PyTorch exactly, but varx does not.
Here, pow is defined as follows:

```python
pow = partial(elementwise_binary, op=trt.ElementWiseOperation.POW)
```
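One plausible explanation, which I have not confirmed in the TensorRT sources, is that a generic POW kernel evaluates x^y as exp(y · ln|x|) even when y == 2.0, instead of specializing to a single multiply. That kind of lowering loses a few ULPs per element relative to x * x, and the drift can be reproduced in plain PyTorch:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096, dtype=torch.float32)

# Exact squaring: one float32 multiply per element.
sq_mul = (x * x).to(torch.float64).sum().item()

# Hypothetical generic pow lowering: exp(2 * ln|x|), all in float32.
# This illustrates the kind of per-element drift reported below.
sq_explog = torch.exp(2.0 * torch.log(x.abs())).to(torch.float64).sum().item()

print(sq_mul, sq_explog, sq_mul - sq_explog)
```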
PyTorch values:

```python
x.to(dtype=torch.float64).sum().item()
# 143.05388989299536
x.pow(2).to(dtype=torch.float64).sum().item()
# 36695.1599856188
```

TensorRT values:

```python
self.debug_buffer["transformer.layers.0.input_layernorm.fp32_input"].to(dtype=torch.float64).sum().item()
# 143.05388989299536
self.debug_buffer["transformer.layers.0.input_layernorm.varx1"].to(dtype=torch.float64).sum().item()
# 36694.23137744599
```
As can be seen, the sums of the squared values are not identical. This error compounds layer by layer, gradually increasing the divergence between PyTorch and TensorRT and producing different final outputs.
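If the generic POW lowering is indeed the cause, one workaround to experiment with is squaring by multiplication. This is a hypothetical patch, not an upstream fix; it assumes tensorrt_llm's Tensor overloads * as an elementwise product, which the surrounding code already relies on for `fp32_y * weight`:

```python
# Hypothetical change inside the precision("float32") block of rms_norm:
# replace the generic power op
#   varx = pow(fp32_input, 2.0)
# with an elementwise product, which lowers to a single PROD op:
varx = fp32_input * fp32_input
```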
Environment

```
root@eaca564d5aff:/opt/tritonserver# pip list | grep tensorrt
tensorrt                 10.10.0.31
tensorrt_cu12            10.10.0.31
tensorrt_cu12_bindings   10.10.0.31
tensorrt_cu12_libs       10.10.0.31
tensorrt_llm             0.20.0rc3
```

NVIDIA GPU: A100
NVIDIA Driver Version: 570.124.06
CUDA Version: 12.9
PyTorch Version (if applicable): 2.6.0+xpu