Skip to content

Commit ba8f1df

Browse files
committed
test_add_noise_device
1 parent b73d57c commit ba8f1df

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

.github/workflows/pr_tests_mps.yml

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,30 @@ jobs:
3333
fail-fast: false
3434
matrix:
3535
config:
36-
- name: Fast Pipelines MPS tests
37-
framework: pytorch_pipelines
38-
runner: macos-13-xlarge
39-
report: torch_mps_pipelines
40-
- name: Fast Models MPS tests
41-
framework: pytorch_models
42-
runner: macos-13-xlarge
43-
report: torch_mps_models
36+
# - name: Fast Pipelines MPS tests
37+
# framework: pytorch_pipelines
38+
# runner: macos-13-xlarge
39+
# report: torch_mps_pipelines
40+
# - name: Fast Models MPS tests
41+
# framework: pytorch_models
42+
# runner: macos-13-xlarge
43+
# report: torch_mps_models
4444
- name: Fast Schedulers MPS tests
4545
framework: pytorch_schedulers
4646
runner: macos-13-xlarge
4747
report: torch_mps_schedulers
48-
- name: Fast Others MPS tests
49-
framework: pytorch_others
50-
runner: macos-13-xlarge
51-
report: torch_mps_others
48+
# - name: Fast Others MPS tests
49+
# framework: pytorch_others
50+
# runner: macos-13-xlarge
51+
# report: torch_mps_others
5252
# - name: Fast Single File MPS tests
5353
# framework: pytorch_single_file
5454
# runner: macos-13-xlarge
5555
# report: torch_mps_single_file
56-
- name: Fast Lora MPS tests
57-
framework: pytorch_lora
58-
runner: macos-13-xlarge
59-
report: torch_mps_lora
56+
# - name: Fast Lora MPS tests
57+
# framework: pytorch_lora
58+
# runner: macos-13-xlarge
59+
# report: torch_mps_lora
6060
# - name: Fast Quantization MPS tests
6161
# framework: pytorch_quantization
6262
# runner: macos-13-xlarge

tests/schedulers/test_schedulers.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -711,21 +711,21 @@ def test_add_noise_device(self):
711711
scheduler = scheduler_class(**scheduler_config)
712712
scheduler.set_timesteps(self.default_num_inference_steps)
713713

714-
sample = self.dummy_sample.to(torch_device)
715-
if scheduler_class == CMStochasticIterativeScheduler:
716-
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
717-
scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
718-
scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
719-
elif scheduler_class == EDMEulerScheduler:
720-
scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
721-
else:
722-
scaled_sample = scheduler.scale_model_input(sample, 0.0)
723-
self.assertEqual(sample.shape, scaled_sample.shape)
724-
725-
noise = torch.randn_like(scaled_sample).to(torch_device)
726-
t = scheduler.timesteps[5][None]
727-
noised = scheduler.add_noise(scaled_sample, noise, t)
728-
self.assertEqual(noised.shape, scaled_sample.shape)
714+
# sample = self.dummy_sample.to(torch_device)
715+
# if scheduler_class == CMStochasticIterativeScheduler:
716+
# # Get valid timestep based on sigma_max, which should always be in timestep schedule.
717+
# scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
718+
# scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
719+
# elif scheduler_class == EDMEulerScheduler:
720+
# scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
721+
# else:
722+
# scaled_sample = scheduler.scale_model_input(sample, 0.0)
723+
# self.assertEqual(sample.shape, scaled_sample.shape)
724+
725+
# noise = torch.randn_like(scaled_sample).to(torch_device)
726+
# t = scheduler.timesteps[5][None]
727+
# noised = scheduler.add_noise(scaled_sample, noise, t)
728+
# self.assertEqual(noised.shape, scaled_sample.shape)
729729

730730
def test_deprecated_kwargs(self):
731731
for scheduler_class in self.scheduler_classes:

0 commit comments

Comments
 (0)