
Commit 908dc9c

Fix mixed fused layer norm to mimic nn.LayerNorm for torch>1.11 (#281)
* If pytorch>=1.11 is available we can use nn.LayerNorm instead of MixedFusedLayerNorm
* Add MixedFusedLayerNorm fix
* Turns out LayerNorm for bf16 is slower using torch==1.11
* Test for LayerNorm
1 parent c85b7c2 commit 908dc9c

3 files changed: +77 −38 lines

megatron/fused_kernels/layer_norm_cuda_kernel.cu

1 addition & 1 deletion

```diff
@@ -317,7 +317,7 @@ void cuApplyLayerNorm(
     if (gamma != NULL && beta != NULL) {
       for (int i = thrx; i < n2; i+=numx) {
         U curr = static_cast<U>(lvals[i]);
-        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
+        ovals[i] = (curr - mu) * c_invvar * static_cast<U>(gamma[i]) + static_cast<U>(beta[i]);
       }
     } else {
       for (int i = thrx; i < n2; i+=numx) {
```
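The one-line kernel change moves the affine transform into the accumulation type `U` (fp32 for half-precision inputs): previously the normalized value was cast down to the output type `V` before `gamma` and `beta` were applied, adding an extra rounding step; now the whole expression is evaluated in `U` and rounded once on the store, which is what `nn.LayerNorm` does from torch 1.11 onwards. A rough PyTorch sketch of the two rounding orders (illustrative tensors, not the kernel itself):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1024)                            # "curr" values in U = fp32
gamma = torch.randn(1024, dtype=torch.bfloat16)  # affine params stored in V = bf16
beta = torch.randn(1024, dtype=torch.bfloat16)
mu = x.mean()
invvar = torch.rsqrt(((x - mu) ** 2).mean() + 1e-5)

# Old kernel: cast the normalized value down to V first, then apply
# gamma/beta in V precision (two rounding steps).
old = gamma * (invvar * (x - mu)).to(torch.bfloat16) + beta

# New kernel: evaluate everything in U and round to V once, on the store.
new = ((x - mu) * invvar * gamma.float() + beta.float()).to(torch.bfloat16)

print((old != new).sum().item())  # typically nonzero: the extra rounding shifts some values
```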

megatron/model/fused_layer_norm.py

29 additions & 15 deletions

```diff
@@ -18,11 +18,17 @@
 with some changes. """
 
 import numbers
+
+from packaging import version
 import torch
+from torch import nn
 from torch.nn.parameter import Parameter
+import torch.nn.functional as F
 from torch.nn import init
 import importlib
 
+from megatron import get_args
+
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
 
@@ -62,19 +68,26 @@ def backward(ctx, grad_output):
 class MixedFusedLayerNorm(torch.nn.Module):
 
   def __init__(self, normalized_shape, eps=1e-5):
-    super(MixedFusedLayerNorm, self).__init__()
+    super(MixedFusedLayerNorm, self).__init__()
+
+    global fused_mix_prec_layer_norm_cuda
+    fused_mix_prec_layer_norm_cuda = importlib.import_module(
+      "fused_mix_prec_layer_norm_cuda")
 
-    global fused_mix_prec_layer_norm_cuda
-    fused_mix_prec_layer_norm_cuda = importlib.import_module(
-      "fused_mix_prec_layer_norm_cuda")
+    if isinstance(normalized_shape, numbers.Integral):
+      normalized_shape = (normalized_shape,)
+    self.normalized_shape = torch.Size(normalized_shape)
+    self.eps = eps
+    self.weight = Parameter(torch.Tensor(*normalized_shape))
+    self.bias = Parameter(torch.Tensor(*normalized_shape))
+    self.reset_parameters()
 
-    if isinstance(normalized_shape, numbers.Integral):
-      normalized_shape = (normalized_shape,)
-    self.normalized_shape = torch.Size(normalized_shape)
-    self.eps = eps
-    self.weight = Parameter(torch.Tensor(*normalized_shape))
-    self.bias = Parameter(torch.Tensor(*normalized_shape))
-    self.reset_parameters()
+    args = get_args()
+
+    self.use_meg_ds_fused_layer_norm = (
+      args.bf16 # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm
+      or version.parse(torch.__version__) >= version.parse("1.11.0") # https://github.com/pytorch/pytorch/pull/66920
+    )
 
 
   def reset_parameters(self):
@@ -84,7 +97,8 @@ def reset_parameters(self):
 
 
   def forward(self, input):
-
-    return FusedLayerNormAffineFunction.apply(
-      input, self.weight, self.bias, self.normalized_shape,self.eps)
-
+    if self.use_meg_ds_fused_layer_norm:
+      return FusedLayerNormAffineFunction.apply(
+        input, self.weight, self.bias, self.normalized_shape, self.eps)
+    else:
+      return F.layer_norm(input, self.normalized_shape, self.weight, self.bias)
```
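The dispatch condition can be read in isolation: take the Meg-DS kernel when training in bf16 (where it is reported faster than `torch.nn.LayerNorm`) or on torch >= 1.11 (where, after the kernel fix above, both implementations apply the affine in fp32). A minimal standalone sketch, with `get_args()` replaced by a plain `bf16` flag since Megatron's argument plumbing is not available outside the framework:

```python
from packaging import version
import torch

def use_meg_ds_fused_layer_norm(bf16: bool) -> bool:
    # bf16: the Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm.
    # torch >= 1.11: nn.LayerNorm applies gamma/beta in fp32
    # (https://github.com/pytorch/pytorch/pull/66920), which the fixed kernel
    # now mimics, so preferring the fused path is numerically safe.
    return bf16 or version.parse(torch.__version__) >= version.parse("1.11.0")

# fp16/fp32 runs on torch older than 1.11 are the only configuration that
# falls back to F.layer_norm in forward().
print(use_meg_ds_fused_layer_norm(bf16=False))
```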

tests/test_model.py

47 additions & 22 deletions

```diff
@@ -4,9 +4,14 @@
 
 import deepspeed
 import torch
+from torch import nn
+import torch.nn.functional as F
+
+from megatron.model.fused_layer_norm import MixedFusedLayerNorm
+from packaging import version
 
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments
+from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal
 from megatron.training import setup_model_and_optimizer
 from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
 from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -51,9 +56,6 @@ def get_default_args():
     }
 
 
-
-
-
 def equal_vectors(tensor1, tensor2, dim=-1):
     """View tensor1 and tensor2 as a list of vectors, and compute equality"""
     return torch.linalg.norm(tensor1 - tensor2, dim=dim) == 0
@@ -109,9 +111,7 @@ def test_gpt(self):
                 output_changed = model(input_token_ids_changed, *input_batch[1:])
 
                 # All tokens in the past should be unchanged
-                self.assertTrue(
-                    torch.all(equal_vectors(output[:, :changed_index], output_changed[:, :changed_index]))
-                )
+                torch_assert_equal(output[:, :changed_index], output_changed[:, :changed_index])
                 # All tokens in the future should have changed
                 self.assertFalse(
                     torch.any(equal_vectors(output[:, changed_index:], output_changed[:, changed_index:]))
@@ -173,23 +173,15 @@ def test_prefix_lm_reset_attention_mask(self):
                 output_changed_target = model(token_ids_changed_target, *input_batch[1:])
 
                 # All tokens in the past should be unchanged
-                self.assertTrue(
-                    torch.all(
-                        equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
-                    )
-                )
+                torch_assert_equal(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
                 # All tokens in the future should have changed
                 self.assertFalse(
                     torch.any(
                         equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
                     )
                 )
                 # Unchanged rows should not change either
-                self.assertTrue(
-                    torch.all(
-                        equal_vectors(output[1, :], output_changed_target[1, :])
-                    )
-                )
+                torch_assert_equal(output[1, :], output_changed_target[1, :])
 
                 ## --------------- CHANGE AN INPUT TOKEN ---------------------------
                 # Let's change the last prefix token and make sure that the first token changed
@@ -212,11 +204,7 @@ def test_prefix_lm_reset_attention_mask(self):
                     )
                 )
                 # Unchanged rows should not change either
-                self.assertTrue(
-                    torch.all(
-                        equal_vectors(output[1, :], output_changed_input[1, :])
-                    )
-                )
+                torch_assert_equal(output[1, :], output_changed_input[1, :])
 
     def test_prefix_lm_wo_reset_attention_mask(self):
         """
@@ -282,6 +270,43 @@ def test_gpt_rotary_embeddings(self):
 
                 #TODO: Check all invariants
 
+    def test_fused_layer_norm(self):
+        command_args = get_default_args()
+
+        # Condition to use the custom cuda kernel
+        command_args["--bf16"] = ""
+        del command_args["--fp16"]
+
+        with patch('sys.argv', flatten_arguments(command_args)):
+            with mockenv_context(**self.dist_env_1_gpu):
+                initialize_megatron()
+                args = get_args()
+
+                dummy_input = torch.randn(args.micro_batch_size, args.seq_length, args.hidden_size, device="cuda", dtype=torch.bfloat16)
+
+                normalized_shape = (args.hidden_size,)
+                epsilon = 1e-5
+                mfln = MixedFusedLayerNorm(normalized_shape, eps=epsilon)
+
+                self.assertTrue(mfln.use_meg_ds_fused_layer_norm, "Expected model to use Megatron-DeepSpeed custom cuda kernel for LayerNorm.")
+                self.assertTrue(args.bf16, "Test has to be done in half precision.")
+
+                # We set the weight manually to simulate a state that's not the initialisation
+                weight = torch.randn(args.hidden_size, device="cuda", dtype=torch.bfloat16)
+                bias = torch.randn(args.hidden_size, device="cuda", dtype=torch.bfloat16)
+                mfln.weight = nn.Parameter(weight)
+                mfln.bias = nn.Parameter(bias)
+
+                mfln_output = mfln(dummy_input)
+                # We check that our layernorm matches pytorch 1.11 onwards
+                if version.parse(torch.__version__) >= version.parse("1.11.0"):
+                    torch_layer_norm_output = F.layer_norm(dummy_input, normalized_shape, weight, bias, eps=epsilon)
+                else:
+                    # In this case we can check that it corresponds to the fp32 version
+                    torch_layer_norm_output = F.layer_norm(dummy_input.float(), normalized_shape, weight.float(), bias.float(), eps=epsilon).to(torch.bfloat16)
+
+                torch_assert_equal(mfln_output, torch_layer_norm_output)
+
 
 if __name__ == '__main__':
     unittest.main()
```
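The new test compares the fused kernel bitwise against a torch reference, and picks the reference by version: on torch >= 1.11 `F.layer_norm` on the bf16 tensors already applies the affine in fp32 internally, while on older torch the same numbers are obtained by upcasting everything to fp32 explicitly and rounding the result back to bf16 once. A condensed sketch of that comparison outside the test harness (assumes a CUDA device with bf16 support; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

shape, eps = (1024,), 1e-5
x = torch.randn(4, 512, *shape, device="cuda", dtype=torch.bfloat16)
w = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
b = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

# torch >= 1.11 reference: affine applied in fp32 inside F.layer_norm.
ref_new = F.layer_norm(x, shape, w, b, eps=eps)

# Pre-1.11-style reference: upcast, normalize in fp32, round once to bf16.
ref_old = F.layer_norm(x.float(), shape, w.float(), b.float(), eps=eps).to(torch.bfloat16)

# The test's premise is that these agree exactly, so either can stand in
# for the fixed Meg-DS kernel's output.
torch.testing.assert_close(ref_new, ref_old, rtol=0, atol=0)
```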
