Skip to content

Commit 2e361d7

Browse files
authored
Fix FSDP2 breakage in nightly (#2684)
1 parent be40518 commit 2e361d7

File tree

4 files changed

+15
-0
lines changed

4 files changed

+15
-0
lines changed

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from torchao.quantization.quant_api import quantize_
2727
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
2828

29+
if common_utils.SEED is None:
30+
common_utils.SEED = 1234
31+
2932
try:
3033
import gemlite # noqa: F401
3134

test/dtypes/test_nf4.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
apply_activation_checkpointing,
2121
)
2222
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
23+
from torch.testing._internal import common_utils
2324
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
2425
from torch.testing._internal.common_fsdp import FSDPTest
2526
from torch.testing._internal.common_utils import (
@@ -29,6 +30,9 @@
2930
run_tests,
3031
)
3132

33+
if common_utils.SEED is None:
34+
common_utils.SEED = 1234
35+
3236
import torchao
3337
from packaging import version
3438
from torchao.dtypes._nf4tensor_api import nf4_weight_only

test/prototype/test_quantized_training.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616
import torch.distributed as dist
1717
import torch.nn.functional as F
18+
import torch.testing._internal.common_utils as common_utils
1819
from torch import nn
1920
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
2021
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -40,6 +41,9 @@
4041
)
4142
from torchao.quantization.quant_api import quantize_
4243

44+
if common_utils.SEED is None:
45+
common_utils.SEED = 1234
46+
4347
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
4448

4549

test/test_low_bit_optim.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
OffloadPolicy,
1717
fully_shard,
1818
)
19+
from torch.testing._internal import common_utils
1920
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
2021
from torch.testing._internal.common_fsdp import FSDPTest
2122
from torch.testing._internal.common_utils import (
@@ -25,6 +26,9 @@
2526
run_tests,
2627
)
2728

29+
if common_utils.SEED is None:
30+
common_utils.SEED = 1234
31+
2832
from packaging.version import Version
2933
from torchao import optim
3034
from torchao.optim.quant_utils import (

0 commit comments

Comments
 (0)