@@ -5,6 +5,7 @@
 import math
 import os
 from typing import Dict, List, Tuple, Optional
+import warnings
 import pytest
 import random
 
@@ -1296,14 +1297,15 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
     ).eval()
 
     # Share params
-    with torch.no_grad():
-        te_linear_ref.weight = Parameter(te_linear.weight.clone())
-        if bias:
-            te_linear_ref.bias = Parameter(te_linear.bias.clone())
-        if fuse_wgrad_accumulation:
-            weight = getattr(te_linear, f"weight")
-            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
-            te_linear_ref.weight.main_grad = weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            te_linear_ref.weight = Parameter(te_linear.weight.clone())
+            if bias:
+                te_linear_ref.bias = Parameter(te_linear.bias.clone())
+            if fuse_wgrad_accumulation:
+                weight = getattr(te_linear, f"weight")
+                weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+                te_linear_ref.weight.main_grad = weight.main_grad.clone()
 
     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True)
     te_outputs_ref = _test_granular_accuracy(
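The suppression pattern repeated throughout these hunks relies on the keyword form of warnings.catch_warnings, which exists only on Python 3.11 and later; a minimal standalone sketch of the new form and a pre-3.11 equivalent:

import warnings

# Python >= 3.11: catch_warnings takes action/category keywords directly.
with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
    warnings.warn("hidden", RuntimeWarning)

# Pre-3.11 equivalent: set the filter inside the restored-on-exit context.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    warnings.warn("also hidden", RuntimeWarning)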
@@ -1359,12 +1361,13 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     ).eval()
 
     # Share params
-    with torch.no_grad():
-        te_linear_ref.weight = Parameter(te_linear.weight.clone())
-        if fuse_wgrad_accumulation:
-            weight = getattr(te_linear, f"weight")
-            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
-            te_linear_ref.weight.main_grad = weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            te_linear_ref.weight = Parameter(te_linear.weight.clone())
+            if fuse_wgrad_accumulation:
+                weight = getattr(te_linear, f"weight")
+                weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+                te_linear_ref.weight.main_grad = weight.main_grad.clone()
 
     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
     te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)
@@ -1601,17 +1604,18 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute(
     ).eval()
 
     # Share params
-    with torch.no_grad():
-        ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
-        if normalization != "RMSNorm":
-            ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
-        ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
-        if bias:
-            ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
-        if fuse_wgrad_accumulation:
-            weight = getattr(ln_linear, f"weight")
-            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
-            ln_linear_ref.weight.main_grad = weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
+            if normalization != "RMSNorm":
+                ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
+            ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
+            if bias:
+                ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
+            if fuse_wgrad_accumulation:
+                weight = getattr(ln_linear, f"weight")
+                weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+                ln_linear_ref.weight.main_grad = weight.main_grad.clone()
 
     te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True)
     te_outputs_ref = _test_granular_accuracy(
@@ -1739,19 +1743,24 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
     ).eval()
 
     # Share params
-    with torch.no_grad():
-        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
-        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
-        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
-        if bias:
-            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
-            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
-        if fuse_wgrad_accumulation:
-            ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
-            ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
-            ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
-            ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
+            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+            ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
+            ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
+            if bias:
+                ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
+                ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
+            if fuse_wgrad_accumulation:
+                ln_mlp.fc1_weight.main_grad = torch.rand_like(
+                    ln_mlp.fc1_weight, dtype=torch.float32
+                )
+                ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
+                ln_mlp.fc2_weight.main_grad = torch.rand_like(
+                    ln_mlp.fc2_weight, dtype=torch.float32
+                )
+                ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()
 
     te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
     te_outputs_ref = _test_granular_accuracy(
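Each fuse_wgrad_accumulation branch above seeds a Megatron-style fp32 main_grad buffer on the test module's weights and mirrors it onto the reference module; a minimal standalone sketch of that seeding, with plain torch.nn.Linear standing in for the TE modules:

import torch
from torch.nn import Linear

lin, ref = Linear(8, 8), Linear(8, 8)
# Attach an fp32 accumulation buffer to the weight (tensors accept ad-hoc
# attributes), then give the reference an identical copy so both modules
# accumulate wgrad from the same starting state.
lin.weight.main_grad = torch.rand_like(lin.weight, dtype=torch.float32)
ref.weight.main_grad = lin.weight.main_grad.clone()
assert torch.equal(lin.weight.main_grad, ref.weight.main_grad)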
@@ -1796,14 +1805,15 @@ def test_layernorm_mlp_accuracy_checkpoint(
     ).eval()
 
     # Share params
-    with torch.no_grad():
-        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
-        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
-        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
-        if bias:
-            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
-            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
+    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
+        with torch.no_grad():
+            ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
+            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+            ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
+            ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
+            if bias:
+                ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
+                ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
 
     te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False)
     te_outputs_ref = _test_granular_accuracy(
@@ -1952,9 +1962,13 @@ def test_grouped_linear_accuracy(
     # Share params
     with torch.no_grad():
         for i in range(num_gemms):
-            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            sequential_linear[i].module_setattr(
+                "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            )
             if bias:
-                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                sequential_linear[i].module_setattr(
+                    "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                )
             if fuse_wgrad_accumulation:
                 weight_i = getattr(grouped_linear, f"weight{i}")
                 weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
@@ -2096,9 +2110,13 @@ def test_grouped_linear_accuracy_save_original_input(
     # Share params
    with torch.no_grad():
         for i in range(num_gemms):
-            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            sequential_linear[i].module_setattr(
+                "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            )
             if bias:
-                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                sequential_linear[i].module_setattr(
+                    "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                )
             if fuse_wgrad_accumulation:
                 weight_i = getattr(grouped_linear, f"weight{i}")
                 weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
@@ -2298,8 +2316,7 @@ def test_padding_grouped_linear_accuracy(
     with torch.no_grad():
         inner_grouped_linear = grouped_linear.linear_fn
         for i in range(num_gemms):
-            setattr(
-                ref_grouped_linear,
+            ref_grouped_linear.module_setattr(
                 f"weight{i}",
                 Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
             )
@@ -2375,8 +2392,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
     with torch.no_grad():
         inner_grouped_linear = grouped_linear.linear_fn
         for i in range(num_gemms):
-            setattr(
-                ref_grouped_linear,
+            ref_grouped_linear.module_setattr(
                 f"weight{i}",
                 Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )