In this example why isn't pjit giving a greater impact? #9297
In your Colab you are not actually splitting the input batch (PartitionSpec = None). Partitioning the activations is generally more important than partitioning the weights. Usually you would split between batch and model parallelism in some way. As a side note: micro-benchmarks like this are tricky. The trivial partitioning of replicating all weights and using only data parallelism will usually be the fastest and simplest solution. Model parallelism is only necessary when you cannot fit the weights and/or the activations at batch_size=1 on a single device; in that case you accept the overhead of splitting weights and activations over multiple devices.
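As a minimal sketch of what "splitting the input batch" means, the snippet below shards only the batch dimension of the input across a 1-D device mesh while keeping the weights replicated. The model, shapes, and names (`mlp`, `params`, `x`) are hypothetical stand-ins for whatever the Colab uses, and depending on your JAX version the keywords may be `in_shardings`/`out_shardings` instead of `in_axis_resources`/`out_axis_resources`, with `Mesh`/`PartitionSpec` imported from `jax.experimental.maps`/`jax.experimental.pjit` in older releases.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.pjit import pjit

# Hypothetical two-layer MLP standing in for the one in the Colab.
def mlp(params, x):
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2

devices = np.array(jax.devices())            # e.g. 8 GPUs
mesh = Mesh(devices, axis_names=('data',))   # 1-D mesh: pure data parallelism

# Replicate the weights (None), but shard the batch dimension of the
# input and the output across the 'data' mesh axis.
pjit_mlp = pjit(
    mlp,
    in_axis_resources=(None, P('data', None)),
    out_axis_resources=P('data', None),
)

params = (jnp.ones((512, 1024)), jnp.ones((1024, 10)))
x = jnp.ones((4096, 512))                    # batch of 4096 examples

with mesh:
    y = pjit_mlp(params, x)                  # each device processes 4096 / n_devices rows
```

With `PartitionSpec = None` on the input, every device instead computes the full batch, so adding more devices cannot reduce the runtime.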
Pjit-ing a function ought to have a major impact on its runtime.
I created a Colab implementing a very simple MLP that gets evaluated under different PartitionSpecs.
However, the difference in runtime that I measure is marginal (tested on a Xeon machine with 8 Titan X GPUs). Could you help me understand whether there is an error in the implementation?
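One common reason for marginal measured differences in JAX benchmarks is timing the compilation or forgetting that dispatch is asynchronous. The helper below is a hedged sketch of how such a measurement could be set up; `pjit_mlp`, `params`, `x`, and `mesh` refer to the hypothetical sketch above, not the Colab's actual code.

```python
import time

def time_fn(fn, *args, n_iters=100):
    # Warm-up call: the first invocation triggers XLA compilation and
    # should not be included in the measurement.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_iters):
        out = fn(*args)
    out.block_until_ready()   # flush JAX's asynchronous dispatch queue
    return (time.perf_counter() - start) / n_iters

# Using the hypothetical pjit_mlp / params / x / mesh from the sketch above.
with mesh:
    print("mean step time (s):", time_fn(pjit_mlp, params, x))
```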