
fp8 is slower than fp16 #445

@WangQiangItachi

Description


Hi experts, I tried to use Transformer Engine for an inference test of llama2 on an H800 and found that FP8 was much slower than FP16. Below is a small reproduction that only contains Linear layers. The results show the speed ordering te.Linear with FP8 < te.Linear without FP8 < nn.Linear, i.e. the FP8 path is the slowest. Can anyone tell me if this test is reasonable? Thanks.

Docker image: nvcr.io/nvidia/pytorch:23.08-py3

Test code:

import os
import numpy as np
import torch
import transformer_engine.pytorch as te
from time import perf_counter
from torch import nn
from transformer_engine.common import recipe

os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# (batch size, hidden size, output size) configurations to sweep
batch_sizes = [1024, 10240, 10240, 10240, 51200]
hidden_sizes = [1024, 1024, 5120, 10240, 10240]
output_sizes = [1024, 1024, 5120, 10240, 10240]

# delayed-scaling FP8 recipe with the E4M3 format; amax reduction is disabled
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
fp8_recipe.reduce_amax = False
# (use_te, use_fp8) combinations: plain nn.Linear, te.Linear without FP8, te.Linear with FP8
use_te_fp8s = [[False, False], [True, False], [True, True]]

class Model(torch.nn.Module):
    def __init__(self, hidden_size, output_size, use_te):
        super().__init__()
        if not use_te:
            self.linear = nn.Linear(hidden_size, output_size)
        else:
            self.linear = te.Linear(hidden_size, output_size)

    def forward(self, x):
        # apply the same linear layer 10 times per call
        # (requires hidden_size == output_size, which holds for every config above)
        for i in range(10):
            x = self.linear(x)
        return x

for i in range(len(batch_sizes)):
    for use_te_fp8 in use_te_fp8s:
        use_te = use_te_fp8[0]
        use_fp8 = use_te_fp8[1]
        batch_size = batch_sizes[i]
        hidden_size = hidden_sizes[i]
        output_size = output_sizes[i]
        model = Model(hidden_size, output_size, use_te)
        model.to('cuda')
        input = torch.randn(batch_size, hidden_size).to('cuda')

        times = []
        # run 20 iterations; the first 5 serve as warm-up and are not timed
        for j in range(20):
            if j >= 5:
                start = perf_counter()
            with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                out = model(input)
            if j >= 5:
                end = perf_counter()
                times.append(end - start)
        print(f'use_te: {str(use_te):5}, use_fp8: {str(use_fp8):5}, bs: {batch_size:5}, hid_size: {hidden_size:5}, out_size: {output_size:5}, cost_time: {1000*np.mean(times):5f}ms')
    print('=============================================')

Test result:

use_te: False, use_fp8: False, bs:  1024, hid_size:  1024, out_size:  1024, cost_time: 0.231356ms
use_te: True , use_fp8: False, bs:  1024, hid_size:  1024, out_size:  1024, cost_time: 0.874557ms
use_te: True , use_fp8: True , bs:  1024, hid_size:  1024, out_size:  1024, cost_time: 3.224556ms
=============================================
use_te: False, use_fp8: False, bs: 10240, hid_size:  1024, out_size:  1024, cost_time: 0.238261ms
use_te: True , use_fp8: False, bs: 10240, hid_size:  1024, out_size:  1024, cost_time: 0.903459ms
use_te: True , use_fp8: True , bs: 10240, hid_size:  1024, out_size:  1024, cost_time: 3.308409ms
=============================================
use_te: False, use_fp8: False, bs: 10240, hid_size:  5120, out_size:  5120, cost_time: 0.248760ms
use_te: True , use_fp8: False, bs: 10240, hid_size:  5120, out_size:  5120, cost_time: 0.917531ms
use_te: True , use_fp8: True , bs: 10240, hid_size:  5120, out_size:  5120, cost_time: 3.316881ms
=============================================
use_te: False, use_fp8: False, bs: 10240, hid_size: 10240, out_size: 10240, cost_time: 0.256033ms
use_te: True , use_fp8: False, bs: 10240, hid_size: 10240, out_size: 10240, cost_time: 0.922597ms
use_te: True , use_fp8: True , bs: 10240, hid_size: 10240, out_size: 10240, cost_time: 11.672506ms
=============================================
use_te: False, use_fp8: False, bs: 51200, hid_size: 10240, out_size: 10240, cost_time: 0.238444ms
use_te: True , use_fp8: False, bs: 51200, hid_size: 10240, out_size: 10240, cost_time: 0.908310ms
use_te: True , use_fp8: True , bs: 51200, hid_size: 10240, out_size: 10240, cost_time: 52.545521ms
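
Note: the timings above are taken with time.perf_counter() without calling torch.cuda.synchronize(), so with CUDA's asynchronous execution they may largely reflect host-side kernel-launch overhead rather than actual GPU time. Below is a minimal sketch of a synchronized timing loop; the timed_forward helper is hypothetical and assumes the same Model, fp8_recipe, and inputs as in the reproduction above.

import torch
import transformer_engine.pytorch as te
from time import perf_counter

def timed_forward(model, x, use_fp8, fp8_recipe, warmup=5, iters=15):
    # Synchronize before and after each forward pass so the measured interval
    # covers the GPU execution itself, not just the kernel launches.
    times = []
    for j in range(warmup + iters):
        torch.cuda.synchronize()
        start = perf_counter()
        with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            out = model(x)
        torch.cuda.synchronize()
        if j >= warmup:
            times.append(perf_counter() - start)
    return 1000 * sum(times) / len(times)  # mean time per forward pass in ms

With synchronization the per-iteration times should grow with the problem size, which would make the FP8 vs. non-FP8 comparison easier to interpret.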

Labels

bug (Something isn't working)