Commit 2fb6ee3

Simpler solution and fixes
Signed-off-by: Przemek Tredak <[email protected]>
1 parent 0fc2a62 commit 2fb6ee3

9 files changed: +124 -92 lines changed


tests/pytorch/test_numerics.py

Lines changed: 70 additions & 54 deletions
@@ -5,6 +5,7 @@
 import math
 import os
 from typing import Dict, List, Tuple, Optional
+import warnings
 import pytest
 import random

@@ -1296,14 +1297,15 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
     ).eval()

     # Share params
-    with torch.no_grad():
-        te_linear_ref.weight = Parameter(te_linear.weight.clone())
-        if bias:
-            te_linear_ref.bias = Parameter(te_linear.bias.clone())
-        if fuse_wgrad_accumulation:
-            weight = getattr(te_linear, f"weight")
-            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
-            te_linear_ref.weight.main_grad = weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            te_linear_ref.weight = Parameter(te_linear.weight.clone())
+            if bias:
+                te_linear_ref.bias = Parameter(te_linear.bias.clone())
+            if fuse_wgrad_accumulation:
+                weight = getattr(te_linear, f"weight")
+                weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+                te_linear_ref.weight.main_grad = weight.main_grad.clone()

     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True)
     te_outputs_ref = _test_granular_accuracy(
@@ -1359,12 +1361,13 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     ).eval()

     # Share params
-    with torch.no_grad():
-        te_linear_ref.weight = Parameter(te_linear.weight.clone())
-        if fuse_wgrad_accumulation:
-            weight = getattr(te_linear, f"weight")
-            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
-            te_linear_ref.weight.main_grad = weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            te_linear_ref.weight = Parameter(te_linear.weight.clone())
+            if fuse_wgrad_accumulation:
+                weight = getattr(te_linear, f"weight")
+                weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+                te_linear_ref.weight.main_grad = weight.main_grad.clone()

     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
     te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)
@@ -1601,17 +1604,18 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute(
     ).eval()

     # Share params
-    with torch.no_grad():
-        ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
-        if normalization != "RMSNorm":
-            ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
-        ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
-        if bias:
-            ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
-        if fuse_wgrad_accumulation:
-            weight = getattr(ln_linear, f"weight")
-            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
-            ln_linear_ref.weight.main_grad = weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
+            if normalization != "RMSNorm":
+                ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
+            ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
+            if bias:
+                ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
+            if fuse_wgrad_accumulation:
+                weight = getattr(ln_linear, f"weight")
+                weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+                ln_linear_ref.weight.main_grad = weight.main_grad.clone()

     te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True)
     te_outputs_ref = _test_granular_accuracy(
@@ -1739,19 +1743,24 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
     ).eval()

     # Share params
-    with torch.no_grad():
-        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
-        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
-        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
-        if bias:
-            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
-            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
-        if fuse_wgrad_accumulation:
-            ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
-            ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
-            ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
-            ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
+            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+            ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
+            ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
+            if bias:
+                ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
+                ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
+            if fuse_wgrad_accumulation:
+                ln_mlp.fc1_weight.main_grad = torch.rand_like(
+                    ln_mlp.fc1_weight, dtype=torch.float32
+                )
+                ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
+                ln_mlp.fc2_weight.main_grad = torch.rand_like(
+                    ln_mlp.fc2_weight, dtype=torch.float32
+                )
+                ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()

     te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
     te_outputs_ref = _test_granular_accuracy(
@@ -1796,14 +1805,15 @@ def test_layernorm_mlp_accuracy_checkpoint(
     ).eval()

     # Share params
-    with torch.no_grad():
-        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
-        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
-        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
-        if bias:
-            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
-            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
+            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+            ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
+            ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
+            if bias:
+                ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
+                ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())

     te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False)
     te_outputs_ref = _test_granular_accuracy(
@@ -1952,9 +1962,13 @@ def test_grouped_linear_accuracy(
     # Share params
     with torch.no_grad():
         for i in range(num_gemms):
-            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            sequential_linear[i].module_setattr(
+                "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            )
             if bias:
-                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                sequential_linear[i].module_setattr(
+                    "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                )
             if fuse_wgrad_accumulation:
                 weight_i = getattr(grouped_linear, f"weight{i}")
                 weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
@@ -2096,9 +2110,13 @@ def test_grouped_linear_accuracy_save_original_input(
     # Share params
     with torch.no_grad():
         for i in range(num_gemms):
-            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            sequential_linear[i].module_setattr(
+                "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            )
             if bias:
-                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                sequential_linear[i].module_setattr(
+                    "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                )
             if fuse_wgrad_accumulation:
                 weight_i = getattr(grouped_linear, f"weight{i}")
                 weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
@@ -2298,8 +2316,7 @@ def test_padding_grouped_linear_accuracy(
     with torch.no_grad():
         inner_grouped_linear = grouped_linear.linear_fn
         for i in range(num_gemms):
-            setattr(
-                ref_grouped_linear,
+            ref_grouped_linear.module_setattr(
                 f"weight{i}",
                 Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
             )
@@ -2375,8 +2392,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
     with torch.no_grad():
         inner_grouped_linear = grouped_linear.linear_fn
         for i in range(num_gemms):
-            setattr(
-                ref_grouped_linear,
+            ref_grouped_linear.module_setattr(
                 f"weight{i}",
                 Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )
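
The recurring change in this test file is that every block reassigning parameters on an already-constructed TE module is wrapped in warnings.catch_warnings(action="ignore", category=RuntimeWarning), presumably to silence the RuntimeWarning that the base module now emits when attributes are overwritten directly (see the _default_setattr/_initialized change in dot_product_attention.py below). A minimal, self-contained sketch of the suppression pattern; the emit_runtime_warning helper is hypothetical, and the keyword form of catch_warnings requires Python 3.11 or newer:

import warnings

def emit_runtime_warning() -> None:
    # Hypothetical stand-in for a module raising a RuntimeWarning
    # when one of its attributes is overwritten directly.
    warnings.warn("attribute overwritten on module", RuntimeWarning)

# Inside the block, RuntimeWarning is ignored; other warning categories still propagate.
with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
    emit_runtime_warning()  # silenced

emit_runtime_warning()  # outside the block the warning is emitted as usual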

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 6 additions & 6 deletions
@@ -482,7 +482,7 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse

         self.register_load_state_dict_post_hook(remove_extra_states_check)

-        self._default_setattr = self._warning_setattr
+        self._initialized = True

     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
@@ -678,9 +678,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        # assume attention uses the same fp8_group as GEMMs
        fp8_group = FP8GlobalStateManager.get_fp8_group()

-       self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-       self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-       self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+       self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
+       self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled())
+       self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration())
        fp8_enabled = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
        if self.fp8_parameters or fp8_enabled:
@@ -705,7 +705,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
            )
        else:
            # If fp8 isn't enabled, turn off and return.
-           self.fp8_initialized = False
+           self.fast_setattr("fp8_initialized", False)
            return

        if self.fp8_parameters and not self.fp8_initialized:
@@ -723,7 +723,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors(fp8_recipes)
-       self.fp8_initialized = True
+       self.fast_setattr("fp8_initialized", True)

        self.fp8_meta["recipe"] = fp8_recipe_dpa
        if fp8_recipe != fp8_recipe_dpa:
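
This diff swaps direct attribute assignments (self.fp8 = ...) for self.fast_setattr(...). The implementation of fast_setattr and module_setattr is not shown in this commit, so the sketch below is only an assumption about what such helpers typically look like: fast_setattr bypasses torch.nn.Module.__setattr__ (which checks the parameter, buffer and submodule registries on every assignment) via object.__setattr__, while module_setattr keeps the full registration path for values that must be tracked, such as Parameters.

import torch

class FastAttrModule(torch.nn.Module):
    """Illustrative only -- not the Transformer Engine implementation."""

    def fast_setattr(self, name: str, value) -> None:
        # Skip nn.Module.__setattr__ bookkeeping; suitable for plain
        # Python attributes such as flags (e.g. fp8_initialized).
        object.__setattr__(self, name, value)

    def module_setattr(self, name: str, value) -> None:
        # Full nn.Module path, so Parameters/buffers/submodules are
        # registered (and show up in state_dict, .parameters(), etc.).
        super().__setattr__(name, value)

m = FastAttrModule()
m.fast_setattr("fp8_initialized", False)                       # plain attribute, no registry scan
m.module_setattr("weight", torch.nn.Parameter(torch.ones(4)))  # registered parameter
print(list(m.state_dict().keys()))  # ['weight']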

transformer_engine/pytorch/distributed.py

Lines changed: 4 additions & 4 deletions
@@ -729,8 +729,8 @@ def checkpoint(
    if isinstance(function, TransformerEngineBaseModule):
        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
        # to scatter/gather activations that we will recompute anyway.
-       setattr(function, "fsdp_wrapped", False)
-       setattr(function, "fsdp_group", None)
+       function.fast_setattr("fsdp_wrapped", False)
+       function.fast_setattr("fsdp_group", None)

    # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
    # and execute TE's own checkpointing
@@ -2046,7 +2046,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
        )
    root_state = _get_module_fsdp_state(fsdp_root)
    assert root_state is not None, "Root module does not have a valid _FSDPState."
-   setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
+   fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)

    # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
    fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
@@ -2057,7 +2057,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
                "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
                "Please initialize your model without the te.quantized_model_init(...) context."
            )
        fsdp_module.module.fast_setattr("fsdp_group", state.process_group)


class FullyShardedDataParallel(FSDP):

transformer_engine/pytorch/graph.py

Lines changed: 5 additions & 3 deletions
@@ -935,16 +935,18 @@ def new_fwd(*user_args, **user_kwargs):

            forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules)
            if _order is None:
-               func.forward = forward
+               with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+                   func.forward = forward
                ret.append(func)
            else:
                ret.append(forward)
        else:
            ret.append(graphed)

        backward_dw_func, reset_func = make_graphed_attribute_functions(i)
-       setattr(ret[-1], "backward_dw", backward_dw_func)
-       setattr(ret[-1], "reset", reset_func)
+       with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+           setattr(ret[-1], "backward_dw", backward_dw_func)
+           setattr(ret[-1], "reset", reset_func)

    if just_one_callable:
        return ret[0]

0 commit comments