
How to use Apex DistributedDataParallel with Lightning? #10922


Here is a quick draft of what you could try:

from apex.parallel import DistributedDataParallel
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.training_type import DDPPlugin
from torch.nn import Module


class ApexDDPPlugin(DDPPlugin):

    def _setup_model(self, model: Module) -> DistributedDataParallel:
        # Wrap with apex's DDP instead of torch.nn.parallel.DistributedDataParallel.
        # Unlike the torch version, apex's wrapper takes no device_ids argument,
        # so determine_ddp_device_ids() is not used here.
        return DistributedDataParallel(module=model, **self._ddp_kwargs)

    @property
    def lightning_module(self):
        # self._model holds the apex wrapper; unwrap its .module attribute to
        # get back to the original LightningModule.
        return unwrap_lightning_module(self._model.module)

Note that apex's DistributedDataParallel does not appear to accept a device_ids argument the way the torch version does, so the wrapper above omits it; the model is expected to already be on the correct GPU before wrapping.
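If you want to confirm which keyword arguments your installed apex version actually accepts, you can print the constructor signature (a quick sanity check, assuming apex is installed):

import inspect
from apex.parallel import DistributedDataParallel

# Show the keyword arguments apex's DDP constructor accepts.
print(inspect.signature(DistributedDataParallel.__init__))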

Use it in the trainer:

from pytorch_lightning import Trainer

trainer = Trainer(gpus=2, strategy=ApexDDPPlugin(), precision=...)
trainer.fit(model)
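For context, here is a minimal end-to-end sketch, assuming apex is installed and two GPUs are available; LitModel and its random dataset are hypothetical stand-ins for your own LightningModule and data:

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 2))
model = LitModel()
trainer = Trainer(gpus=2, strategy=ApexDDPPlugin(), max_epochs=1)
trainer.fit(model, DataLoader(dataset, batch_size=8))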
