-
Notifications
You must be signed in to change notification settings - Fork 654
Expand file tree
/
Copy pathbase.py
More file actions
1652 lines (1472 loc) · 71.2 KB
/
base.py
File metadata and controls
1652 lines (1472 loc) · 71.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Base modules and utilities for TransformerEngine PyTorch API"""
import io
import math
import os
import pickle
import warnings
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing_extensions import Self
from contextlib import contextmanager
import logging
from types import MethodType
import torch
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
import transformer_engine_torch as tex
from ._common import _ParameterInitMeta, noop_cat
from ..quantization import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
NVFP4BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
)
from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from ..utils import (
is_non_tn_fp8_gemm_supported,
torch_get_autocast_gpu_dtype,
get_nvtx_range_context,
nvtx_range_push,
nvtx_range_pop,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
# Names exported by ``from <module> import *``.
__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]

# NOTE(review): presumably flags selecting split ("2x") accumulation for the forward,
# dgrad and wgrad GEMMs. They are not read anywhere in the visible part of this file --
# confirm against the modules that import them.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

# Cache used by ``get_dummy_wgrad``; keyed by ``(shape[0], shape[1], dtype)``.
_dummy_wgrads = {}

# Registry of Userbuffers communicators keyed by ``(name, UserBufferQuantizationMode)``.
# ``None`` means "not initialized"; ``initialize_ub`` fills it and ``destroy_ub`` resets it.
_ub_communicators = None

# Filled lazily from ``tex.get_stream_priority_range()`` the first time a default
# Userbuffers config is built inside ``initialize_ub``.
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None

# Layer names registered by ``initialize_ub`` as using atomic GEMM together with the
# ``ring_exchange`` overlap method; cleared again by ``destroy_ub``.
layers_atomic_ring_exchange = []
class UserBufferQuantizationMode(Enum):
    """Quantization setting a Userbuffers communicator is registered under.

    ``initialize_ub`` creates one communicator per ``(layer name, mode)`` pair and
    ``get_ub`` selects ``FP8`` or ``NONE`` from its ``use_fp8`` flag.
    """

    # No quantization: the communication buffer uses the regular (non-FP8) dtype.
    NONE = "none"
    # FP8 quantization: buffers flagged ``fp8_buf`` are allocated as ``torch.uint8``.
    FP8 = "fp8"
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
    """Return a cached placeholder weight-gradient tensor.

    One CUDA tensor is allocated per ``(shape[0], shape[1], dtype)`` combination and
    reused on every later call, so callers share the same underlying storage.

    Parameters
    ----------
    shape : list
        Two-element shape of the requested tensor.
    dtype : torch.dtype
        Data type of the requested tensor.
    zero : bool, default = False
        When true, the cached tensor is filled with zeros before it is returned;
        otherwise its contents are whatever the buffer currently holds.

    Returns
    -------
    torch.Tensor
        A detached view of the cached tensor.
    """
    assert len(shape) == 2
    cache_key = (shape[0], shape[1], dtype)
    if cache_key not in _dummy_wgrads:
        # First request for this shape/dtype: allocate the buffer (uninitialized).
        _dummy_wgrads[cache_key] = torch.empty(
            shape,
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
    cached = _dummy_wgrads[cache_key]
    if zero:
        cached.fill_(0)
    return cached.detach()
def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    quantization_modes: List[UserBufferQuantizationMode] = None,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[Union[dict, List[dict]]] = None,
    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` modules.

    One communicator is created per ``(layer name, quantization mode)`` pair and stored in a
    module-level registry that ``get_ub`` reads and ``destroy_ub`` clears.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a ``te.TransformerLayer`` forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: ``(sequence_length * batch_size, hidden_size)``
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs.
              DEPRECATED: Please use ``quantization_modes`` instead.
    quantization_modes : List[UserBufferQuantizationMode] = None
              if a list of UserBufferQuantizationMode is provided, a UB communicator is created
              for each quantization setting in the list.
              falls back to the legacy ``use_fp8`` parameter if ``None`` is provided.
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when ``use_fp8 = False``
    ub_cfgs : dict = None
              Configuration dictionary with the structure::

                    {
                        <gemm_name> : {
                            "method": <"ring_exchange" or "pipeline">,
                            "is_reduce_scatter": bool,
                            "num_sm": int,
                            "cga_size": int,
                            "set_sm_margin": bool,
                            "num_splits": int,
                            "aggregate": bool,
                            "atomic_gemm": bool,
                            "use_ce": bool,
                            "fp8_buf": bool,
                        }
                    }

              for ``te.TransformerLayer`` GEMM layers in ``["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
              "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad",
              "fc2_dgrad", "fc2_fprop", "fc2_wgrad"]``.
              a list may be provided to specify different overlap configurations for the
              different quantization settings in ``quantization_modes``. A single dict (or
              ``None``) is applied to every quantization mode.
    bootstrap_backend : str = None
                        ``torch.distributed`` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time.

    Raises
    ------
    AssertionError
        If arguments are inconsistent, Userbuffers is already initialized, the device lacks
        multicast support without ``UB_SKIPMC=1``, or an overlap configuration is invalid.
    ValueError
        If the ``pipeline`` method is requested for an all-gather overlap.
    """
    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    if not quantization_modes:
        warnings.warn(
            "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
            " instead.",
            DeprecationWarning,
        )
        quantization_modes = [
            UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
        ]
    else:
        assert isinstance(quantization_modes, list), "quantization_modes must be a list"
        assert all(
            isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
        ), "quantization_modes must be a list of UserBufferQuantizationMode"

    # Normalize ``ub_cfgs`` to one entry per quantization mode.
    if isinstance(ub_cfgs, dict) or ub_cfgs is None:
        ub_cfgs = [ub_cfgs] * len(quantization_modes)
    else:
        assert len(ub_cfgs) == len(
            quantization_modes
        ), "Number of ub_cfgs settings must match number of quantization configurations"

    global _ub_communicators
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            # Preference order: MPI, then Gloo, then NCCL.
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        num_domains = world_size // tp_size
        mydomain_idx = world_rank // tp_size
        if num_domains > 1:
            ranks_per_domain_list = [
                [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
            ]
            tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(tp_domain_group)
            tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)

            helper = tex.CommOverlapHelper(world_group, tp_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            tp_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "proj_wgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
        "fc2_wgrad",
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
    # Default overlap methods for layers
    methods = {
        "ring_exchange": [
            "qkv_fprop",
            "fc1_fprop",
            "proj_dgrad",
            "fc2_dgrad",
        ],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
        "external": ["proj_wgrad", "fc2_wgrad"],
    }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
    external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"}

    global layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []

    def get_method(name):
        """Return the overlap method currently assigned to layer ``name``."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        """Build the default ``add_ub`` keyword arguments for layer ``name``."""
        global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
            # Query the stream priority range only once per process.
            _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else 16,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": not method == "ring_exchange",
            "num_splits": tp_size if method == "ring_exchange" else 4,
            "aggregate": False,
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
            "comm_priority": _MAX_STREAM_PRIORITY,
            "gemm_priority": _MIN_STREAM_PRIORITY,
            "pipeline_rs_overlap_first_gemm": False,
        }
        return default_cfg

    def add_ub(
        name: str,
        quantization_mode: UserBufferQuantizationMode,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: bool = False,
        num_splits: int = 0,
        aggregate: bool = False,
        atomic_gemm: bool = False,
        use_ce: bool = True,
        fp8_buf: bool = False,
        comm_priority: int = 0,
        gemm_priority: int = 0,
        pipeline_rs_overlap_first_gemm: bool = False,
    ) -> None:
        """Validate one layer's overlap config and register its communicator."""
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert (
                quantization_mode == UserBufferQuantizationMode.FP8
            ), "Atomic GEMM overlap supported only for FP8 GEMM."
            if method in ("bulk", "external"):
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for `{method}` overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = False
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            else:
                if atomic_gemm and method == "ring_exchange":
                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        if name in external_gemm_to_overlap:
            assert method == "external", (
                f"At {name}, `external` overlap method is specified, but the selected method is"
                f" {method}"
            )
            assert external_gemm_to_overlap[name] in methods["ring_exchange"], (
                f"At {name}, `external` overlap method is specified, but the external gemm"
                f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
            )

        # FP8 payloads are stored as raw bytes in the communication buffer.
        buffer_dtype = (
            torch.uint8
            if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
            else dtype
        )
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
        _ub_communicators[(name, quantization_mode)] = ub_obj

    for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
        if user_ub_cfg is not None:
            for name in dgrad_reduce_scatter_overlap:
                if (
                    name in user_ub_cfg
                    and "method" in user_ub_cfg[name]
                    and user_ub_cfg[name]["method"] != "bulk"
                ):
                    wgrad_name = name.replace("dgrad", "wgrad")
                    assert wgrad_name not in user_ub_cfg
                    new_method = user_ub_cfg[name]["method"]
                    # The layer lists are shared by all quantization modes, so the same
                    # rewiring can be requested again on a later iteration (e.g. when one
                    # dict config is reused for every mode). Every step is therefore
                    # guarded so that repeating it is a no-op instead of a ValueError
                    # from ``list.remove``.
                    if wgrad_name in layers_reduce_scatter_overlap:
                        layers_reduce_scatter_overlap.remove(wgrad_name)
                    if name in layers_all_gather_overlap:
                        layers_all_gather_overlap.remove(name)
                    if name not in layers_reduce_scatter_overlap:
                        layers_reduce_scatter_overlap.append(name)
                    for method_layers in methods.values():
                        if name in method_layers:
                            method_layers.remove(name)
                    methods[new_method].append(name)

        for name in (
            methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
        ):
            ub_cfg = get_default_config(name)
            if user_ub_cfg is not None and name in user_ub_cfg:
                # All-gather layers always use an FP8-capable buffer; other layers honor
                # the user's ``fp8_buf`` only for the ``pipeline`` method.
                fp8_buf = (name in layers_all_gather_overlap) or (
                    user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
                )
                ub_cfg.update(user_ub_cfg[name])
                ub_cfg["fp8_buf"] = fp8_buf
            add_ub(name, quantization_mode, **ub_cfg)
def get_ub(name: str, use_fp8: bool):
    """Look up the Userbuffers communicator registered for a layer.

    Parameters
    ----------
    name : str
        Layer name the communicator was registered under.
    use_fp8 : bool
        Selects the ``FP8`` communicator when true, the ``NONE`` one otherwise.

    Raises
    ------
    AssertionError
        If Userbuffers has not been initialized or no communicator matches.
    """
    # The boolean flag mirrors how the calling modules are currently designed; it is
    # translated to the enum that the registry is actually keyed by.
    if use_fp8:
        mode = UserBufferQuantizationMode.FP8
    else:
        mode = UserBufferQuantizationMode.NONE
    registry_key = (name, mode)
    assert _ub_communicators is not None, "UB manager is not initialized."
    assert (
        registry_key in _ub_communicators
    ), f"UB for {name} with use_fp8={use_fp8} is not registered."
    return _ub_communicators[registry_key]
def destroy_ub():
    """Drop every registered Userbuffers communicator.

    Resets the communicator registry to its "not initialized" state (``None``) and
    clears the list of layers recorded as using atomic GEMM with ring exchange.
    """
    global _ub_communicators, layers_atomic_ring_exchange
    _ub_communicators = None
    layers_atomic_ring_exchange = []
def fill_userbuffers_buffer_for_all_gather(
    comm,
    local_tensor: torch.Tensor,
    quantizer: Optional[Quantizer],
    process_group,
) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]:
    """Fill local shard of Userbuffers buffer with data for all-gather

    Returns the full tensor and the local shard, both using the
    Userbuffers buffer as their underlying data. These tensors should
    be used carefully (e.g. only immediately before and after a
    Userbuffers operation) since the underlying data may be
    overwritten by other Userbuffers operations.

    May perform blocking communication if needed for the gathered
    tensor's metadata, e.g. scaling factors.

    Parameters
    ----------
    comm
        Userbuffers communicator whose buffer is filled.
    local_tensor : torch.Tensor or QuantizedTensorStorage
        This rank's shard. It is dequantized and/or (re)quantized as
        needed to match ``quantizer``.
    quantizer : Optional[Quantizer]
        ``None`` for unquantized data, otherwise an FP8 (delayed or
        current scaling) or MXFP8 quantizer.
    process_group
        Process group used to size the gathered tensor along dim 0 and
        to all-gather MXFP8 scaling factors.

    Raises
    ------
    ValueError
        If the local tensor has no dimensions, the quantizer type is
        unsupported, or MXFP8 constraints are violated.
    RuntimeError
        If the buffer precision does not match the requested data format.
    """

    # Tensor dimensions
    local_shape = local_tensor.size()
    if not local_shape:
        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
    process_group_size = torch.distributed.get_world_size(process_group)
    global_shape = list(local_shape)
    global_shape[0] *= process_group_size

    # Unquantized data
    if quantizer is None:
        if isinstance(local_tensor, QuantizedTensorStorage):
            local_tensor = local_tensor.dequantize()
        if comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather unquantized tensor, "
                "but Userbuffers is initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor, local_chunk=True)
        global_tensor = comm.get_buffer(shape=global_shape)
        return global_tensor, local_tensor

    # FP8 data
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        if not isinstance(local_tensor, Float8TensorStorage):
            if isinstance(local_tensor, QuantizedTensorStorage):
                # A tensor quantized in another format must go back to high precision
                # first. The dequantized result has to be kept: discarding it would
                # hand the still-quantized storage to the quantizer below.
                local_tensor = local_tensor.dequantize()
            quantizer.set_usage(rowwise=True, columnwise=False)
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather FP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor._data, local_chunk=True)
        global_tensor_data = comm.get_buffer(shape=global_shape)
        # The gathered tensor reuses the local shard's scale-inverse and FP8 dtype.
        global_tensor = Float8TensorStorage(
            data=global_tensor_data,
            fp8_scale_inv=local_tensor._scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # MXFP8 data
    if isinstance(quantizer, MXFP8Quantizer):

        # Cast to MXFP8 if needed
        if not isinstance(local_tensor, MXFP8TensorStorage):
            if isinstance(local_tensor, QuantizedTensorStorage):
                # Keep the dequantized result (see the FP8 branch above).
                local_tensor = local_tensor.dequantize()
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather MXFP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )

        # Check which MXFP8 buffer to communicate
        if quantizer.rowwise_usage == quantizer.columnwise_usage:
            raise ValueError(
                "Userbuffers can only communicate one MXFP8 buffer at a time, "
                f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
                f"columnwise_usage={quantizer.columnwise_usage}"
            )
        with_rowwise_data = quantizer.rowwise_usage

        # Copy MXFP8 data to local chunk of Userbuffers buffer
        local_data = (
            local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
        )
        comm.copy_into_buffer(local_data, local_chunk=True)

        # Gather scaling-inverses
        if math.prod(local_shape[:-1]) % 128 != 0:
            raise ValueError(
                "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
                f"but got MXFP8 tensor with shape={tuple(local_shape)}"
            )
        local_scale_inv = (
            local_tensor._rowwise_scale_inv
            if with_rowwise_data
            else local_tensor._columnwise_scale_inv
        )
        local_scale_inv_size = list(local_scale_inv.size())
        global_scale_inv = torch.empty(
            [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
            dtype=local_scale_inv.dtype,
            device=local_scale_inv.device,
        )
        # Blocking collective: scaling factors are not carried by the Userbuffers buffer.
        torch.distributed.all_gather_into_tensor(
            global_scale_inv,
            local_scale_inv,
            group=process_group,
        )

        # Construct MXFP8 tensor with Userbuffers buffer
        rowwise_data, rowwise_scale_inv = None, None
        columnwise_data, columnwise_scale_inv = None, None
        global_data = comm.get_buffer(shape=global_shape)
        if with_rowwise_data:
            rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
        else:
            columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
        global_tensor = MXFP8TensorStorage(
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # Unsupported data format
    raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
def __init__(self) -> None:
    """Set up the default bookkeeping state of a TE module.

    Requires CUDA to be available. All FP8, parallelism and FSDP related
    attributes start in their disabled/empty state; the debug state is
    initialized if debug mode is not already enabled.
    """
    super().__init__()
    assert torch.cuda.is_available(), "TransformerEngine needs CUDA."

    # Identification and debug scheduling
    self.name = None
    self.next_iter_when_debug_should_be_run = 0

    # FP8 flags, metadata and quantizers
    self.fp8_initialized = False
    self.fp8 = False
    self.fp8_calibration = False
    self.fp8_meta = {"fp8_checkpoint": False, "fp8_group": None}
    self.fp8_meta_tensors_initialized = False
    self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}

    # Tensor/sequence parallelism
    self.tp_group = None
    self.tp_size = 1
    self.sequence_parallel = False

    # Parameter initialization and FP8 parameter settings taken from global state
    self.param_init_meta = {}
    self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
    self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()

    # FSDP
    self.fsdp_wrapped = False
    self.fsdp_group = None

    # Workspaces, activation dtype and weight-gradient plumbing
    self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
    self.activation_dtype: Optional[torch.dtype] = None
    self.wgrad_accumulation_and_reduce_hooks = []
    self.wgrad_store = None

    if not TEDebugState.debug_enabled:
        TEDebugState.initialize()
def fast_setattr(self, name: str, value: Any) -> None:
    """
    Store ``value`` under ``name`` directly in the instance dictionary.

    This skips the Module's ``__setattr__`` machinery entirely, so it must only be
    used for regular attributes -- never for properties, parameters or buffers.
    """
    instance_namespace = vars(self)
    instance_namespace[name] = value
def module_setattr(self, name: str, value: Any) -> None:
    """
    Regular version of the Module's set attribute function.

    Should be used only when the fast version cannot be used - for the properties,
    parameters and buffers.

    Delegates straight to the parent class's ``__setattr__``, so the ``RuntimeWarning``
    emitted by this class's own ``__setattr__`` override is not triggered.
    """
    super().__setattr__(name, value)
def __setattr__(self, name: str, value: Any) -> None:
    """Set an attribute, warning once the module is flagged as initialized.

    When the instance dictionary holds a truthy ``_initialized`` entry, a
    ``RuntimeWarning`` points callers to ``fast_setattr`` / ``module_setattr``.
    The attribute is always set through the parent class afterwards.
    """
    warn_about_overhead = "_initialized" in self.__dict__ and self._initialized
    if warn_about_overhead:
        warnings.warn(
            """The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
            RuntimeWarning,
        )
    super().__setattr__(name, value)
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
    """
    Delayed scaling only.

    Increase or decrease size of amax history based on given `length`.

    Parameters
    ----------
    length : int
        Desired size of ``amax_history`` along its first dimension.
    fwd : Optional[bool]
        ``True`` adjusts only ``"scaling_fwd"``, ``False`` only ``"scaling_bwd"``;
        ``None`` (default) adjusts both.

    .. warning::
        This changes the underlying amax memory location.
    """
    if fwd is None:
        fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
    else:
        fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

    for meta_key in fp8_meta_tensor_keys:
        if meta_key not in self.fp8_meta:
            # Handles non-parameter FP8 modules, e.g. DPA.
            continue
        curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
        if length == curr_len:
            # Nothing to resize for this direction.
            continue
        if length < curr_len:
            # Shrink: keep the first `length` entries in a fresh tensor.
            self.fp8_meta[meta_key].amax_history = (
                self.fp8_meta[meta_key].amax_history[:length].clone()
            )
        elif length > curr_len:
            # Grow: pad `extra_rows` entries after the existing ones along the
            # second-to-last dimension (assumes a 2-D history -- TODO confirm).
            extra_rows = length - curr_len
            self.fp8_meta[meta_key].amax_history = F.pad(
                self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
            )

        # Update quantizers with new amax pointers.
        self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()
        # Make sure weight tensors has correct quantizers
        self._update_weight_quantizers()

        # Update the global buffers with new amax and history pointers.
        if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
            fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[
                FP8GlobalStateManager.get_buffer_info()
            ]
            # NOTE(review): both the forward and backward buffer slots are written with
            # the history of the current `meta_key` -- confirm this is intended.
            for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
                if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                    assert (
                        buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                    ), "TE internal error during amax history change."
                    FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[
                        meta_key
                    ].amax_history[0]
                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                        self.fp8_meta[meta_key].amax_history
                    )
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
    """Create, or keep, the recipe state and quantizers for one direction.

    Parameters
    ----------
    fwd : bool
        ``True`` for the forward (``"scaling_fwd"``) state, ``False`` for the
        backward (``"scaling_bwd"``) state.
    recipe : Recipe
        Quantization recipe the state must correspond to.

    If the meta tensors are already initialized and the stored state matches
    the recipe kind, the state is kept (for delayed scaling, only the amax
    history length is adjusted). Otherwise a new state is created.
    """
    direction_key = "scaling_fwd" if fwd else "scaling_bwd"

    if self.fp8_meta_tensors_initialized:
        current_state = self.fp8_meta[direction_key]

        # Delayed scaling keeps its state but may need a resized amax history.
        if recipe.delayed() and isinstance(current_state, DelayedScalingRecipeState):
            self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
            return

        # Remaining recipe kinds: a matching state is reused untouched. The recipe
        # predicates are looked up and called lazily, in this order.
        reusable_states = (
            ("mxfp8", MXFP8BlockScalingRecipeState),
            ("float8_current_scaling", Float8CurrentScalingRecipeState),
            ("float8_block_scaling", Float8BlockScalingRecipeState),
            ("nvfp4", NVFP4BlockScalingRecipeState),
        )
        for predicate_name, state_cls in reusable_states:
            if getattr(recipe, predicate_name)() and isinstance(current_state, state_cls):
                return

    # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
    # 2 (grad_output and grad_input) for bwd
    tensors_per_gemm = 3 if fwd else 2
    quantizer_count = self.fp8_meta["num_gemms"] * tensors_per_gemm

    # Build a fresh recipe state and derive its quantizers
    fresh_state = RecipeState.create(
        recipe,
        mode=("forward" if fwd else "backward"),
        num_quantizers=quantizer_count,
    )
    self.fp8_meta[direction_key] = fresh_state
    self.quantizers[direction_key] = fresh_state.make_quantizers()
def _update_weight_quantizers(self) -> None:
    """Attach the module's current weight quantizers to its quantized weights.

    Weights and quantizers are paired positionally. A pair is skipped when its
    quantizer is ``None`` or the weight is not a ``QuantizedTensorStorage``.

    Raises
    ------
    AssertionError
        If the number of weights and quantizers differ.
    """
    tensors = self._get_weight_tensors()
    quantizers = self._get_weight_quantizers()
    assert len(tensors) == len(quantizers), (
        f"Number of weight tensors ({len(tensors)}) and quantizers "
        f"({len(quantizers)}) must match"
    )
    for tensor, tensor_quantizer in zip(tensors, quantizers):
        if tensor_quantizer is None:
            continue
        if isinstance(tensor, QuantizedTensorStorage):
            tensor.update_quantizer(tensor_quantizer)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
    """Return the module's weight tensors.

    This base implementation always raises ``NotImplementedError``, naming the
    concrete class that lacks the method.
    """
    owner = self.__class__.__name__
    raise NotImplementedError(f"{owner} class does not implement _get_weight_tensors function")
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return the module's weight quantizers.

    This base implementation always raises ``NotImplementedError``, naming the
    concrete class that lacks the method.
    """
    owner = self.__class__.__name__
    raise NotImplementedError(
        f"{owner} class does not implement _get_weight_quantizers function"
    )
def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
    """Set up the forward and then the backward recipe state for ``recipe``.

    Afterwards ``fp8_meta_tensors_initialized`` is set to ``True``.
    """
    for is_forward in (True, False):
        self.set_meta_tensor(is_forward, recipe)
    self.fast_setattr("fp8_meta_tensors_initialized", True)
def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
    """Get copies of the scales and amax histories.

    Returns
    -------
    Optional[Dict[str, List[torch.Tensor]]]
        ``None`` if either ``"scaling_fwd"`` or ``"scaling_bwd"`` is missing from
        ``fp8_meta``. Otherwise a dict with those two keys, each mapping to
        ``[scale, amax_history]`` clones, in the layout accepted by
        ``reset_fp8_meta_tensors``.
    """
    fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
    if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
        return None

    # Clone under no_grad so the snapshot is detached from autograd.
    with torch.no_grad():
        return {
            key: [
                self.fp8_meta[key].scale.clone(),
                self.fp8_meta[key].amax_history.clone(),
            ]
            for key in (fwd_key, bwd_key)
        }
def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
    """Reset scales and amax histories in place.

    Parameters
    ----------
    fp8_meta_tensors : optional
        Mapping from ``"scaling_fwd"`` / ``"scaling_bwd"`` to ``[scale, amax_history]``.
        When ``None``, scales are reset to ones and amax histories to zeros;
        otherwise the given values are copied in.

    Directions missing from ``fp8_meta`` are skipped.

    Raises
    ------
    AssertionError
        If ``fp8_meta_tensors`` is given but lacks an entry for a present direction.
    """
    with torch.no_grad():
        for direction in ("scaling_fwd", "scaling_bwd"):
            if direction not in self.fp8_meta:
                continue
            state = self.fp8_meta[direction]
            if fp8_meta_tensors is not None:
                # Restore from the provided snapshot.
                assert direction in fp8_meta_tensors, "Cannot reset fp8 tensors."
                state.scale.copy_(fp8_meta_tensors[direction][0])
                state.amax_history.copy_(fp8_meta_tensors[direction][1])
            else:
                # Default reset: unit scales, empty history.
                state.scale.copy_(torch.ones_like(state.scale))
                state.amax_history.copy_(torch.zeros_like(state.amax_history))
def get_extra_state(self) -> torch.Tensor:
    """Pack the FP8 state into a CPU byte tensor for checkpointing.

    Returns
    -------
    torch.Tensor
        An empty ``uint8`` tensor when none of
        ``self.fp8_meta["fp8_checkpoint"]``, ``self.fp8`` and
        ``self.fp8_calibration`` is truthy. Otherwise a ``uint8`` tensor
        wrapping the pickled state dict, which holds the recipe, the
        scale/amax tensors (delayed recipes only) and the plain-Python
        entries of ``self.fp8_meta``.
    """
    # Why the state is a pickled byte tensor on the CPU:
    #
    # (1) PyTorch's "extra state" infrastructure might be able to
    #     support any picklable type, but it makes no guarantees.
    #     Non-tensor extra state has caused problems before (e.g. in
    #     ONNX export).
    # (2) PyTorch's checkpointing infrastructure does not remap
    #     devices for "extra state" like it does for "state dict".
    #     Keeping extra state off the GPU avoids loading it on the
    #     wrong device.
    # (3) The extra state consists of many small tensors. Copying
    #     them all to CPU has to avoid the overhead of many
    #     blocking GPU-CPU memory transfers.
    #
    # See: https://github.com/NVIDIA/TransformerEngine/pull/351
    # See: https://github.com/NVIDIA/TransformerEngine/pull/363

    def host_copy(src: torch.Tensor) -> torch.Tensor:
        """Start a non-blocking copy of ``src`` into a new CPU tensor.

        The transfer is asynchronous w.r.t. the host, so the GPU must be
        synchronized before the result is read.
        """
        host = torch.empty_like(src, device="cpu")
        host.copy_(src, non_blocking=True)
        return host

    # Nothing to store unless FP8 state is relevant for this module
    wants_fp8_state = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
    if not wants_fp8_state:
        return torch.empty(0, dtype=torch.uint8)

    # Copy tensors to CPU and store
    recipe = self.fp8_meta["recipe"]
    state = {"recipe": recipe}
    if recipe.delayed():
        for meta_key, scale_key, amax_key in (
            ("scaling_fwd", "scale_fwd", "amax_history_fwd"),
            ("scaling_bwd", "scale_bwd", "amax_history_bwd"),
        ):
            meta = self.fp8_meta[meta_key]
            state[scale_key] = host_copy(meta.scale)
            state[amax_key] = host_copy(meta.amax_history)

    # Store other picklable values
    plain_types = (bool, int, float, str, tuple, list)
    state["extra_fp8_variables"] = {
        name: value
        for name, value in self.fp8_meta.items()
        if name != "buffer_index_and_autocast_key" and isinstance(value, plain_types)
    }

    # Wait for the pending copies, then serialize state into a byte tensor
    torch.cuda.synchronize()
    payload = bytearray(pickle.dumps(state))
    return torch.frombuffer(payload, dtype=torch.uint8)
def set_extra_state(self, state: torch.Tensor) -> None:
    """Restore FP8 metadata from a checkpoint's extra state.

    Parameters
    ----------
    state : torch.Tensor, io.BytesIO or None
        ``None`` or an empty tensor means there is nothing to load. A
        non-empty tensor is treated as bytes of pickled data (the default
        format). An ``io.BytesIO`` object is the deprecated format and is
        read with ``torch.load``.

    Raises
    ------
    RuntimeError
        If ``state`` is of any other type.
    """
    # Maintain backwards compatibility with older checkpoints.
    if state is None:
        return

    # Decode the payload
    if isinstance(state, torch.Tensor):
        # An empty tensor signals "no FP8 state", so there is nothing to unpickle.
        if state.numel() == 0:
            return
        # Default format: byte tensor with pickled data
        # NOTE(review): pickle.loads must only ever see trusted checkpoints.
        raw_bytes = state.detach().cpu().numpy().tobytes()
        loaded = pickle.loads(raw_bytes)
    elif isinstance(state, io.BytesIO):
        # Deprecated format with io.BytesIO
        state.seek(0)
        loaded = torch.load(state, map_location="cuda")
    else:
        raise RuntimeError("Unsupported checkpoint format.")
    if loaded is None:
        return

    # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
    if "recipe" not in loaded:
        # TE 1.x only supported delayed scaling, which was the default recipe
        loaded["recipe"] = DelayedScaling()
        # TE 1.x also saved scale_inv, which is not needed with Recipe object
        for stale_key in ("scale_inv_fwd", "scale_inv_bwd"):
            loaded.pop(stale_key, None)

    # Load extra items
    self.fp8_meta.update(loaded["extra_fp8_variables"])
    self.fp8_meta["recipe"] = loaded["recipe"]
    self.fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

    # The destination tensors have to exist before they can be filled
    self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

    # Load tensors. The copies are asynchronous w.r.t. the host, hence the
    # synchronize at the end.
    if self.fp8_meta["recipe"].delayed():
        for saved_key, meta_key, field in (
            ("scale_fwd", "scaling_fwd", "scale"),
            ("amax_history_fwd", "scaling_fwd", "amax_history"),
            ("scale_bwd", "scaling_bwd", "scale"),
            ("amax_history_bwd", "scaling_bwd", "amax_history"),
        ):
            src = loaded[saved_key]
            dst = getattr(self.fp8_meta[meta_key], field)
            dst.copy_(src, non_blocking=True)
    torch.cuda.synchronize()
def set_activation_dtype(self, inp: torch.Tensor) -> None:
    """Record the activation dtype to use for this call.

    Inside ``torch.autocast`` the autocast dtype always wins. Otherwise the
    dtype of ``inp`` is used, after checking that every parameter shares
    it unless ``self.allow_different_data_and_param_types`` is set.

    Parameters
    ----------
    inp : torch.Tensor
        Input whose dtype is adopted outside of autocast.

    Raises
    ------
    AssertionError
        If a parameter dtype differs from the input dtype while such a
        mismatch is not allowed.
    """
    # Native AMP (`torch.autocast`) gets highest priority
    if torch.is_autocast_enabled():
        self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype())
        return

    inp_dtype = inp.dtype
    # A matching cached dtype means the checks below already passed once
    if self.activation_dtype == inp_dtype:
        return

    if not self.allow_different_data_and_param_types:
        for name, param in self.named_parameters():
            if param is None:
                continue
            assert inp_dtype == param.dtype, (
                "Data types for parameters must match when outside of autocasted region. "
                f" Found input dtype: {inp_dtype} and {name!r} dtype: {param.dtype}"
            )
    self.fast_setattr("activation_dtype", inp_dtype)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
    """
    Attach a tensor parallel process group to this module.

    Meant to be called before the forward pass is executed.

    Parameters
    ----------
    tp_group : ProcessGroup or None
               tensor parallel process group.
    """
    # Store the group and flag it as set; a None group still counts as set.
    self.tp_group, self.tp_group_initialized = tp_group, True
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
    """Collect this module's own trainable quantized parameters.

    Only direct parameters are inspected (``recurse=False``); a parameter
    qualifies when it is a ``QuantizedTensor`` with ``requires_grad`` set.

    Returns
    -------
    list of torch.Tensor or None
        The qualifying parameters in iteration order, or ``None`` when
        there are none.
    """
    quantized = [
        param
        for param in self.parameters(recurse=False)
        if isinstance(param, QuantizedTensor) and param.requires_grad
    ]
    # Callers get None rather than an empty list.
    return quantized if quantized else None
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
meta = self.fp8_meta
fp8 = FP8GlobalStateManager.is_fp8_enabled()
fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self.fast_setattr("fp8_parameters", fp8_parameters)
self.fast_setattr("fp8", fp8)
self.fast_setattr("fp8_calibration", fp8_calibration)
fp8_enabled = fp8 or fp8_calibration
meta["fp8_checkpoint"] = fp8_enabled
_original_recipe = None
if fp8_parameters or fp8_enabled:
_original_recipe = meta.get("recipe", None)
if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
# FP8 init has already been run and recipe is the same, don't do anything.
return