Skip to content

Commit 902774a

Browse files
akihironittacarmocca
authored andcommitted
Specify Trainer(benchmark=False) in parity benchmarks (#13182)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent bd50b26 commit 902774a

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

tests/benchmarks/test_basic_parity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def vanilla_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
149149

150150
def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
151151
seed_everything(idx)
152-
torch.backends.cudnn.deterministic = True
153152

154153
model = cls_model()
155154
# init model parts
@@ -162,6 +161,7 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
162161
gpus=1 if device_type == "cuda" else 0,
163162
logger=False,
164163
replace_sampler_ddp=False,
164+
benchmark=False,
165165
)
166166
trainer.fit(model)
167167

tests/benchmarks/test_sharded_parity.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,31 @@ def plugin_parity_test(
137137
ddp_model = model_cls()
138138
use_cuda = gpus > 0
139139

140-
trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, strategy="ddp_spawn")
140+
trainer = Trainer(
141+
fast_dev_run=True,
142+
max_epochs=1,
143+
accelerator="gpu",
144+
devices=gpus,
145+
precision=precision,
146+
strategy="ddp_spawn",
147+
benchmark=False,
148+
)
141149

142150
max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda)
143151

144152
# Reset and train Custom DDP
145153
seed_everything(seed)
146154
custom_plugin_model = model_cls()
147155

148-
trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, strategy="ddp_sharded_spawn")
156+
trainer = Trainer(
157+
fast_dev_run=True,
158+
max_epochs=1,
159+
accelerator="gpu",
160+
devices=gpus,
161+
precision=precision,
162+
strategy="ddp_sharded_spawn",
163+
benchmark=False,
164+
)
149165
assert isinstance(trainer.strategy, DDPSpawnShardedStrategy)
150166

151167
max_memory_custom, custom_model_time = record_ddp_fit_model_stats(

0 commit comments

Comments
 (0)