fix nccl subcommunicator bug with asarray() #142
Conversation
Very good! I just left a comment about a bit of code that now seems to be repeated in two places; it may be good to wrap it into a function... I leave it to you to choose whether you'd like to do that here or later on. Just let me know if you want me to merge this PR 😄
@tharittk I left some comments, I am not so convinced about the code... it feels like you are repeating some operations that could be avoided (at least in the previous code you did not do them) 🤔
pylops_mpi/DistributedArray.py (Outdated)
    if nccl_comm == self.sub_comm:
        all_tuples = self._allgather_subcomm(self.local_shape).get()
    else:
        assert (nccl_comm == self.base_comm_nccl)
I do not like asserts in code, they should only be used in tests... why not pass comm and masked and then have the if masked... as in the code you deleted, so we don't do the same checks twice?
I am also not so sure why before we had
else:
    local_shapes = self.local_shapes
but now we repeat the creation of local_shapes also for the case without a subcomm.
Earlier, I had this
else:
    local_shapes = self.local_shapes
so that I could have a single return statement, return local_shapes, but that may look unnecessary.
Now I have changed the code to look like this:
    def _nccl_local_shapes(self, masked: bool):
        if masked:
            all_tuples = self._allgather_subcomm(self.local_shape).get()
        else:
            all_tuples = self._allgather(self.local_shape).get()
        tuple_len = len(self.local_shape)
        local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
        return local_shapes
It takes only masked: bool. We can assume that if masked=True we will use the subcomm, so we don't have to pass it.
And then local_shapes() calls _nccl_local_shapes:
    @property
    def local_shapes(self):
        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
            return self._nccl_local_shapes(False)
        else:
            return self._allgather(self.local_shape)
This is because I want this unpacking to appear in only one place:
        tuple_len = len(self.local_shape)
        local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
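A minimal sketch of what this unpacking does, with illustrative numbers: on 2 ranks that each hold a (3, 4) block, the NCCL allgather returns the shape tuples flattened into one buffer, and the comprehension regroups them.

    # Flat buffer as returned by allgather on 2 ranks, each with local_shape (3, 4)
    all_tuples = [3, 4, 3, 4]
    tuple_len = 2  # len(self.local_shape)
    local_shapes = [tuple(all_tuples[i: i + tuple_len])
                    for i in range(0, len(all_tuples), tuple_len)]
    print(local_shapes)  # [(3, 4), (3, 4)]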
@tharittk perfect, this version looks great 😄
In issue #140, the new test uncovered a bug in the NCCL backend when used with partitioned communicators (sub-communicators). The issue arises because the NCCL collective calls need a recv buffer and subsequent unrolling after allgather (while MPI can send arbitrary Python objects and pack them in a list).

Fix:
- The recv buffer allocation in _nccl.py was previously sized as MPI.COMM_WORLD.Get_size() * send_buf.size. This breaks in the case of sub-communicators.
- Manual unrolling of local_shapes in the case of sub-communicators after _allgather_subcomm().
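For illustration, a rough sketch of the sizing fix using CuPy's NCCL bindings (the function name nccl_allgather_sketch is hypothetical and the fixed float32 dtype is an assumption; the actual _nccl.py code differs):

    import cupy as cp
    from cupy.cuda import nccl

    def nccl_allgather_sketch(nccl_comm, send_buf):
        # Size the recv buffer by the communicator actually in use
        # (sub- or base-), not by MPI.COMM_WORLD.Get_size(): with a
        # sub-communicator the two sizes differ and the old sizing breaks
        recv_buf = cp.zeros(nccl_comm.size() * send_buf.size,
                            dtype=send_buf.dtype)
        nccl_comm.allGather(
            send_buf.data.ptr, recv_buf.data.ptr, send_buf.size,
            nccl.NCCL_FLOAT32,  # assumes a float32 payload
            cp.cuda.Stream.null.ptr,
        )
        return recv_buf

The caller then unrolls the flat recv_buf into per-rank local_shapes, as in _nccl_local_shapes above.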