How to use Apex DistributedDataParallel with Lightning? #10922
Answered by awaelchli
mostafaelaraby asked this question in DDP / multi-GPU / multi-node
I was wondering if there is a way to use apex.parallel.DistributedDataParallel instead of the PyTorch native DistributedDataParallel. (I am trying to reproduce a paper that used Apex DDP and Apex mixed precision, and I am getting lower results with the PyTorch native one.)
Answered by awaelchli on Dec 5, 2021
Here is a quick draft of what you could try:

```python
from torch.nn import Module
from pytorch_lightning.plugins.training_type import DDPPlugin
from apex.parallel import DistributedDataParallel


class ApexDDPPlugin(DDPPlugin):
    def _setup_model(self, model: Module):
        # Wrap the model with apex DDP instead of the native torch DistributedDataParallel
        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)

    @property
    def lightning_module(self):
        # Unwrap the LightningModule from the apex DDP wrapper
        return self.module.module
```

I'm not sure if apex DistributedDataParallel supports device ids (it seems not?), so you may need to remove that argument.

Use it in the trainer:

```python
from pytorch_lightning import Trainer

trainer = Trainer(gpus=2, strategy=ApexDDPPlugin(), precision=...)
trainer.fit(model)
```
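In case device_ids is indeed not accepted by apex DDP, here is a minimal sketch of the same plugin without it (untested; it assumes Lightning has already moved the model to its GPU before `_setup_model` is called, and that any `_ddp_kwargs` passed through are arguments apex DDP understands):

```python
from torch.nn import Module
from pytorch_lightning.plugins.training_type import DDPPlugin
from apex.parallel import DistributedDataParallel


class ApexDDPPlugin(DDPPlugin):
    def _setup_model(self, model: Module):
        # Assumption: apex DDP takes no device_ids; the model must already be on the right GPU
        return DistributedDataParallel(module=model, **self._ddp_kwargs)

    @property
    def lightning_module(self):
        # Unwrap the LightningModule from the apex DDP wrapper
        return self.module.module
```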
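For the Apex mixed precision part of the question: on Lightning releases from around this time (1.5.x), the apex AMP backend could, as far as I know, be selected directly through Trainer flags rather than a custom plugin. A hedged example, assuming apex is installed and that `"O2"` matches the optimization level used in the paper:

```python
from pytorch_lightning import Trainer

# Combine the custom apex DDP plugin sketched above with Lightning's apex AMP backend.
# amp_level "O2" is only an example; use whatever level the paper reports.
trainer = Trainer(
    gpus=2,
    strategy=ApexDDPPlugin(),
    precision=16,
    amp_backend="apex",
    amp_level="O2",
)
trainer.fit(model)
```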
Answer selected by rohitgr7