Skip to content

Commit cf2ea33

Browse files
committed
fix seeds..
1 parent a331838 commit cf2ea33

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/models/test_modeling_common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,6 +1775,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
17751775
fine.
17761776
"""
17771777
# create 2 adapters with different ranks and alphas
1778+
torch.manual_seed(0)
17781779
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
17791780
model = self.model_class(**init_dict).to(torch_device)
17801781

@@ -1809,7 +1810,8 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
18091810
del model
18101811

18111812
# load the first adapter
1812-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1813+
torch.manual_seed(0)
1814+
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
18131815
model = self.model_class(**init_dict).to(torch_device)
18141816

18151817
if do_compile or (rank0 != rank1):
@@ -1824,7 +1826,6 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
18241826
model = torch.compile(model, mode="reduce-overhead")
18251827

18261828
with torch.inference_mode():
1827-
torch.manual_seed(0)
18281829
output0_after = model(**inputs_dict)["sample"]
18291830
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
18301831

@@ -1833,7 +1834,6 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
18331834

18341835
# we need to call forward to potentially trigger recompilation
18351836
with torch.inference_mode():
1836-
torch.manual_seed(0)
18371837
output1_after = model(**inputs_dict)["sample"]
18381838
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
18391839

0 commit comments

Comments
 (0)