diff --git a/prody/atomic/functions.py b/prody/atomic/functions.py index 12964c96a..fb36c27dd 100644 --- a/prody/atomic/functions.py +++ b/prody/atomic/functions.py @@ -285,7 +285,7 @@ def sortAtoms(atoms, label, reverse=False): return AtomMap(ag, sort, acsi) -def sliceAtoms(atoms, select): +def sliceAtoms(atoms, select, allowSame=False): """Slice *atoms* using the selection defined by *select*. :arg atoms: atoms to be selected from @@ -297,6 +297,8 @@ def sliceAtoms(atoms, select): """ if atoms == select: + if allowSame: + return atoms._getSubset('all'), atoms.all raise ValueError('atoms and select arguments are the same') try: diff --git a/prody/ensemble/ensemble.py b/prody/ensemble/ensemble.py index 2e0c20361..9556ee1ba 100644 --- a/prody/ensemble/ensemble.py +++ b/prody/ensemble/ensemble.py @@ -217,7 +217,7 @@ def getAtoms(self, selected=True): return self._atoms return self._atoms[self._indices] - def setAtoms(self, atoms): + def setAtoms(self, atoms, allowSame=False): """Set *atoms* or specify a selection of atoms to be considered in calculations and coordinate requests. When a selection is set, corresponding subset of coordinates will be considered in, for @@ -238,7 +238,7 @@ def setAtoms(self, atoms): n_atoms = self._n_atoms if n_atoms: - if atoms.numAtoms() > n_atoms: + if atoms.numAtoms() > n_atoms and atoms.ca.numAtoms() > n_atoms: raise ValueError('atoms must be same size or smaller than ' 'the ensemble') @@ -261,11 +261,19 @@ def setAtoms(self, atoms): ag = atoms.getAtomGroup() except AttributeError: ag = atoms - if ag.numAtoms() != n_atoms: + try: + self_ag = self._atoms.getAtomGroup() + except AttributeError: + self_ag = self._atoms + try: + self_ag_n_atoms = self_ag.numAtoms() + except AttributeError: + self_ag_n_atoms = 0 + if ag.numAtoms() != n_atoms and ag.numAtoms() != self_ag_n_atoms and self_ag_n_atoms != 0: raise ValueError('size mismatch between this ensemble ({0} atoms) and atoms ({1} atoms)' .format(n_atoms, ag.numAtoms())) self._atoms = ag - self._indices, _ = sliceAtoms(self._atoms, atoms) + self._indices, _ = sliceAtoms(self._atoms, atoms, allowSame=allowSame) else: # if assigning atoms to a new ensemble self._n_atoms = atoms.numAtoms() @@ -294,12 +302,32 @@ def getCoords(self, selected=True): return None if self._indices is None or not selected: return self._coords.copy() - return self._coords[self._indices].copy() + + selids = self._indices + if self.hasSelectionIssue(): + selids = self.getIndices(calphas=True) + return self._coords[selids].copy() - def getIndices(self): + def getIndices(self, calphas=False): """Returns a copy of indices of selected columns""" + if calphas: + return array([list(self._atoms.ca.getIndices()).index(idx) + for idx in self._indices]) return copy(self._indices) + + def hasSelectionIssue(self): + if self._atoms is None or self._atoms.ca is None: + return False + + selids = self._indices + if selids is None: + return False + + if (selids.max() > self._coords.shape[0] + and set(selids).issubset(set(self._atoms.ca.getIndices()))): + return True + return False def setIndices(self, value): if not isListLike(value): @@ -359,10 +387,17 @@ def getWeights(self, selected=True): return None if self._indices is None or not selected: return self._weights.copy() + + + if self.hasSelectionIssue(): + selids = self.getIndices(calphas=True) + else: + selids = self._indices + if self._weights.ndim == 2: - return self._weights[self._indices].copy() + return self._weights[selids].copy() else: - return self._weights[:, self._indices].copy() + return self._weights[:, selids].copy() def _getWeights(self, selected=True): diff --git a/prody/ensemble/functions.py b/prody/ensemble/functions.py index cbe142310..f3f108568 100644 --- a/prody/ensemble/functions.py +++ b/prody/ensemble/functions.py @@ -215,7 +215,7 @@ def loadEnsemble(filename, **kwargs): data[key] = arr else: atoms = None - ensemble.setAtoms(atoms) + ensemble.setAtoms(atoms, allowSame=True) if '_indices' in attr_dict: indices = attr_dict['_indices'] diff --git a/prody/ensemble/pdbensemble.py b/prody/ensemble/pdbensemble.py index a141b77a8..4322de31b 100644 --- a/prody/ensemble/pdbensemble.py +++ b/prody/ensemble/pdbensemble.py @@ -459,6 +459,9 @@ def getMSA(self, indices=None, selected=True): atom_indices = self._indices if selected else slice(None, None, None) indices = indices if indices is not None else slice(None, None, None) + + if self.hasSelectionIssue(): + atom_indices = self.getIndices(calphas=True) return self._msa[indices, atom_indices] @@ -491,6 +494,8 @@ def getCoordsets(self, indices=None, selected=True): confs[i, which] = coords[which] else: selids = self._indices + if self.hasSelectionIssue(): + selids = self.getIndices(calphas=True) coords = coords[selids] confs = self._confs[indices, selids].copy() for i, w in enumerate(self._weights[indices]):