From 7d47d2c3b932befd4f633c3021d2a64d0cf95aa9 Mon Sep 17 00:00:00 2001 From: Armin Shayesteh Zadeh Date: Thu, 18 Jan 2024 17:03:29 -0600 Subject: [PATCH 1/5] initial commit of the nlist submodule. Contains initial implementations of cell list and verlet list methods for generating neighborlists. --- pysages/nlist/CellList.py | 185 ++++++++++++++++++++++++++++++++++ pysages/nlist/VerletList.py | 121 ++++++++++++++++++++++ pysages/nlist/__init__.py | 0 pysages/nlist/cellfuncs.py | 87 ++++++++++++++++ pysages/nlist/testbench.ipynb | 118 ++++++++++++++++++++++ 5 files changed, 511 insertions(+) create mode 100644 pysages/nlist/CellList.py create mode 100644 pysages/nlist/VerletList.py create mode 100644 pysages/nlist/__init__.py create mode 100644 pysages/nlist/cellfuncs.py create mode 100644 pysages/nlist/testbench.ipynb diff --git a/pysages/nlist/CellList.py b/pysages/nlist/CellList.py new file mode 100644 index 00000000..841a9272 --- /dev/null +++ b/pysages/nlist/CellList.py @@ -0,0 +1,185 @@ +from typing import Tuple + + +import jax +from jax import numpy as np +from jax import lax +from cellfuncs import _idx_to_tuple, _tuple_to_idx, _get_cell_ids, _get_neighbor + +class CellList: + """ + Cell list neighbor list algorithm for 3D systems implemented in JAX. + Loosely based on https://aiichironakano.github.io/cs596/01-1LinkedListCell.pdf + + Raises: + ValueError: If the cell list is not initialized before calling get_neighbor_ids() + + Returns: + CellList: CellList object containing the following attributes: + + box (Tuple): box size (3, ) + cutoff (float): cutoff distance for neighbor list (scalar) + cell_edge (jax.Array): number of cells in each dimension (3, ) + cell_cut (jax.Array): cell size in each dimension (3, ) + cell_num (int): total number of cells (scalar) + cell_idx (jax.Array): cell index for each particle (N, ) + buffer_size_cell (int): max number of neighbors per cell (scalar). If not set, it is set to 50% larger than the average number of particles per cell. + """ + box: Tuple + cutoff: float + cell_edge: jax.Array + cell_cut: jax.Array + cell_num: int + cell_idx: jax.Array + buffer_size_box: int + + def __init__(self, box: Tuple, cutoff: float, buffer_size_cell: int = None) -> None: + self.box = lax.stop_gradient(np.asarray(box)) # (3, ) + self.cutoff = lax.stop_gradient(cutoff) # scalar + self.cell_edge = np.floor(self.box/self.cutoff) # (3, ) + self.cell_cut = self.box/self.cell_edge # (3, ) + self.cell_num = np.prod(self.cell_edge) # scalar + self.cell_idx = None # (N, ) + self.buffer_size_cell = buffer_size_cell # scalar + + def set_buffer_size(self, buffer_size_cell: int) -> None: + """ + Setter for buffer_size_cell attribute. + + Args: + buffer_size_cell (int): max number of neighbors per cell + + Returns: + None + """ + self.buffer_size_cell = buffer_size_cell + return None + + def get_cell_ids(self, pos: jax.Array) -> jax.Array: + """ + Get cell ids for each particle in pos matrix. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + + Returns: + jax.Array: Array of cell ids for each particle (N, ). Also sets the cell_idx attribute. + """ + self.cell_idx = np.zeros(pos.shape[0], dtype=np.int32) + self.cell_idx = _get_cell_ids(pos, self.cell_cut, self.cell_edge) + + if self.buffer_size_cell is None: # can be set manually if needed + self.buffer_size_cell = np.int32(np.ceil(pos.shape[0]//self.cell_num * 1.5)) # set the size of the nl list per cell to 50% larger than the average number of particles per cell + return self.cell_idx + + def _get_neighbor_ids(self, idx: int) -> jax.Array: + """ + Get neighbor ids for a single particle. + + Args: + idx (int): index of the particle in the pos matrix (scalar) + + Raises: + ValueError: If the neighbor list overflows + + Returns: + jax.Array: Array of neighbor ids for the particle (N, ) + """ + cell_id = self.cell_idx[idx] # index of the cell that the particle is in scalar + cell_id = np.expand_dims(cell_id, axis=0) # scalar to (1, ) + + cell_tuple = _idx_to_tuple(cell_id, self.cell_edge) # tuple of the cell that the particle is in (1, dim) + + neighbor_tuples = [] + for i in [-1, 0, 1]: # loop over cells behind and ahead of the current cell in each dimension + for j in [-1, 0, 1]: + for k in [-1, 0, 1]: + neighbor_tuples.append(np.asarray([cell_tuple[0]+i, cell_tuple[1]+j, cell_tuple[2]+k])) + + neighbor_tuples = np.asarray(neighbor_tuples) # list to jax.Array (27, dim) + neighbor_tuples_wrapped = jax.vmap(_get_neighbor, in_axes=(0, None), out_axes=0)(neighbor_tuples, self.cell_edge) # wrap the cell ids of the neighbors (27, dim) + + # get scalar ids for the neighboring cells + neighbor_cell_ids = jax.vmap(_tuple_to_idx, (0, None))(neighbor_tuples_wrapped, self.cell_edge) + + neighbor_ids = [] # get ids of the particles in the neighboring cells. -1 is used as a filler for empty cells. + for cidx in neighbor_cell_ids: + neighbor_ids.append(np.where(self.cell_idx == cidx, fill_value=-1, size=self.buffer_size_cell)[0]) + + + # concatenate the neighbor ids into a single array. + neighbor_ids = np.concatenate(neighbor_ids, axis=-1) + return neighbor_ids + + def get_neighbor_ids(self, idxs: jax.Array, mask_self: bool= False) -> jax.Array: + """ + Get neighbor ids for a list of particles. Uses vmap to vectorize the _get_neighbor_ids function. + + Args: + idxs (jax.Array): Array of particle indices (n, ) + mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. + + Raises: + ValueError: If the cell list is not initialized before calling get_neighbor_ids() + + Returns: + jax.Array: Array of neighbor ids for each particle (n, buffer_size_nl)) + """ + if self.cell_idx is None: + raise ValueError("Cell list is not initialized. Call get_cell_ids() first.") + + # convert idxs to jax.Array if it is not already + if not isinstance(idxs, np.ndarray): + idxs = np.asarray(idxs) + # expand dims if idxs is a single particle + if len(idxs.shape) == 0: + idxs = np.expand_dims(idxs, axis=-1) + + if idxs.shape[0] == 1: # single particle case, no vmap + # get neighbor ids for the particle + n_ids = self._get_neighbor_ids(idxs[0]) + # check for overflow + min_buffer = np.count_nonzero(n_ids == -1, axis=-1) + if min_buffer < 27: # if there are less than 27 -1s in a row of the neighbor list, there is an overflow from buffer_size_cell + raise ValueError("Neighbor list overflow. Increase buffer_size_cell.") + # remove self from neighbor list if mask_self is True + if mask_self: + n_ids = n_ids[n_ids != idxs[0]] + # sort + n_ids = np.sort(n_ids)[::-1] + # truncate. Remove the -1s from the end of the neighbor list(smallest possible neighbor list). + n_ids = n_ids[:-min_buffer] + return n_ids + else: + # get neighbor ids for the particles + n_ids = jax.vmap(self._get_neighbor_ids)(idxs) + # check for overflow + min_buffer = np.min(np.count_nonzero(n_ids == -1, axis=-1)) + if min_buffer < 27: # if there are less than 27 -1s in a row of the neighbor list, there is an overflow from buffer_size_cell + raise ValueError("Neighbor list overflow. Increase buffer_size_cell.") + # remove self from neighbor list if mask_self is True + if mask_self: + # set the self index to -1 + n_ids = n_ids.at[..., n_ids == idxs[:, None]].set(-1) + # add one to the minimum buffer size to account for the -1 just added + min_buffer += 1 + # sort + n_ids = np.sort(n_ids, axis=-1)[:, ::-1] + # truncate. Remove the -1s from the end of the neighbor list so that the row with the least -1s will have none (smallest possible neighbor list). + n_ids = n_ids[:, :-min_buffer] + + return n_ids + + def update(self, pos: jax.Array) -> None: + """ + Update the cell list. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + + Returns: + None + """ + # update the cell ids by calling get_cell_ids + self.get_cell_ids(pos) + return None diff --git a/pysages/nlist/VerletList.py b/pysages/nlist/VerletList.py new file mode 100644 index 00000000..83baa4e9 --- /dev/null +++ b/pysages/nlist/VerletList.py @@ -0,0 +1,121 @@ +import jax +import jax.numpy as np +from jax import lax + +class VerletList: + """ + Verlet list neighbor list algorithm for 3D systems implemented in JAX. Originally implemented to work with the CellList class as a hybrid neighbor list algorithm. + + Returns: + VerletList: VerletList object containing the following attributes: + + cutoff (float): cutoff distance for neighbor list (scalar) + buffer_size (int): max number of neighbors per cell (scalar) + + """ + cutoff: float + buffer_size: int + + def __init__(self, cutoff: float) -> None: + self.cutoff = lax.stop_gradient(cutoff) # scalar + self.buffer_size = None # scalar + + def set_buffer_size(self, buffer_size: int) -> None: + """ + Setter for buffer_size attribute. + + Args: + buffer_size (int): max number of neighbors per cell + + Returns: + None + """ + self.buffer_size = buffer_size + return None + + def _get_dist(self, i: jax.Array, j: jax.Array) -> np.float32: + """ + Calculate the distance between two particles. (helper function for _pairwise_dist to use with vmap) + + Args: + i (jax.Array): Position of particle i (3, ) + j (jax.Array): Position of particle j (3, ) + + Returns: + np.float32: Distance between particles i and j (scalar) + """ + return np.linalg.norm(i-j) + + def _pairwise_dist(self, pos: jax.Array, ref: jax.Array) -> jax.Array: + """ + Calculate the pairwise distance between particles in pos and a single reference particle. (helper function for get_neighbor_ids to use with vmap) + + Args: + pos (jax.Array): position of particles (N, 3) + ref (jax.Array): position of reference particle (3, ) + + Returns: + jax.Array: array of distances between particles and reference particle (N, ) + """ + return jax.vmap(self._get_dist, (0, None))(pos, ref) + + def _is_neighbor(self, dist: jax.Array) -> jax.Array: + """ + Check if a particle is a neighbor of the reference particle. (helper function for get_neighbor_ids to use with vmap) + + Args: + dist (jax.Array): Array of distances between particles and reference particle (N, ) + + Returns: + jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) + """ + return dist < self.cutoff + + def get_neighbor_ids(self, pos: jax.Array, sparse: bool = False, mask_self: bool = False) -> jax.Array: + """ + Get neighbor ids for each particle in pos matrix. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + sparse (bool, optional): Whether to return the full (N, N) matrix of neighborhood or an Array. Defaults to False. + mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. + + Returns: + jax.Array: Array of neighbor ids for each particle (N, ) or (N, N) matrix of bools indicating whether a particle is a neighbor of another particle. + """ + # if buffer_size is not set, set it to the number of particles + if self.buffer_size is None: + self.buffer_size = pos.shape[0] + # calculate the pairwise distances between all particles + pair_dists = jax.vmap(self._pairwise_dist, (None, 0))(pos, pos) + # check if a particle is a neighbor of another particle based on the cutoff distance + is_neighbor = jax.vmap(self._is_neighbor)(pair_dists) + # remove self from neighbor list if mask_self is True + if mask_self: + i, j = np.diag_indices(is_neighbor.shape[0]) + is_neighbor = is_neighbor.at[..., i, j].set(False) + # return a list of arrays if sparse is True + if sparse: # return a list of arrays + neighbor_list = [] + for row in is_neighbor: + neighbor_list.append(np.where(row)[0]) + return neighbor_list + + return is_neighbor # return a NxN array of bools + + def get_neighborhood(self, pos: jax.Array, ref: jax.Array) -> jax.Array: + """ + Get the neighborhood of a specific particle. Implemented to work with the CellList class as a hybrid neighbor list algorithm. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + ref (jax.Array): position of reference particle (3, ) + + Returns: + jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) + """ + # calculate the pairwise distances between all particles and the reference particle + pair_dists = self._pairwise_dist(pos, ref) + # check if a particle is a neighbor of the reference particle based on the cutoff distance + is_neighbor = self._is_neighbor(pair_dists) + return is_neighbor \ No newline at end of file diff --git a/pysages/nlist/__init__.py b/pysages/nlist/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysages/nlist/cellfuncs.py b/pysages/nlist/cellfuncs.py new file mode 100644 index 00000000..690d7956 --- /dev/null +++ b/pysages/nlist/cellfuncs.py @@ -0,0 +1,87 @@ +import jax +from jax import jit +import jax.numpy as np + +@jit +def _idx_to_tuple(idx: int, cell_edge: jax.Array) -> jax.Array: + """ + Convert cell index from scalar to tuple. + + Args: + idx (int): Scalar index of cell + cell_edge (jax.Array): Number of cells in each dimension + + Returns: + jax.Array: [index in x, index in y, index in z] (3,) or (N, 3) + """ + x: np.int32 = idx//(cell_edge[1]*cell_edge[2]) + y: np.int32 = (idx//cell_edge[2])%cell_edge[1] + z: np.int32 = idx%cell_edge[2] + return np.concatenate([x, y, z], axis=-1, dtype=np.int32) + +@jit +def _tuple_to_idx(tup: jax.Array, cell_edge: jax.Array) -> np.int32 | jax.Array: + """ + Covnert cell index from tuple to scalar. + + Args: + tup (jax.Array): [index in x, index in y, index in z] + cell_edge (jax.Array): Number of cells in each dimension + + Returns: + np.int32 or jax.Array: Scalar index of cell or (N, ) array of scalar indices + """ + return np.int32(tup[...,0]*cell_edge[1]*cell_edge[2] + tup[...,1]*cell_edge[2] + tup[...,2]) + +@jit +def _get_cell_ids(pos: jax.Array, cell_cut: jax.Array, cell_edge: jax.Array) -> jax.Array: + """ + Get scalar cell ids for each particle (row) in pos matrix. + + Args: + pos (jax.Array): matrix of particle positions (N, 3) + cell_cut (jax.Array): Cut off distance for each cell + cell_edge (jax.Array): Number of cells in each dimension + + Returns: + jax.Array: Array of cell ids for each particle (N, ) + """ + cell_tuples: np.int32 = pos//cell_cut + cell_ids = _tuple_to_idx(cell_tuples, cell_edge) + return cell_ids + +@jit +def _wrap_cell_ids(cell_ids: jax.Array, cell_edge: np.int32) -> jax.Array: + """ + Wraps the cell ids of particles in edge cells. (single dimension) + + Args: + cell_ids (jax.Array): Array of tuple cell ids in the current dimension for each particle (N, 1) + cell_edge (np.int32): Number of cells in current dimension + + Returns: + jax.Array: Wrapped cell ids (tuple) for each particle (N, 3) + """ + out_of_bound_low = (cell_ids == -1) # if cell id is -1 (out of bound from below) + out_of_bound_high = (cell_ids == cell_edge) # if cell id equal to the number of cells in that dimension (out of bound from above) + cell_ids = np.where(out_of_bound_low, cell_edge-1, cell_ids) # if out of bound, then wrap around from below + cell_ids = np.where(out_of_bound_high, 0, cell_ids) # if out of bound, then wrap around from above + return cell_ids + +@jit +def _get_neighbor(ids: jax.Array, cell_edge: jax.Array) -> jax.Array: + """ + Wrap the tuple cell ids of particles for each neighbor. (helper function for _get_neighbor_ids to use with vmap) + + Args: + ids (jax.Array): Array of tuple cell ids (N, 3) + cell_edge (jax.Array): Array of number of cells in each dimension (3, ) + + Returns: + jax.Array: Wrapped tuple cell ids (N, 3) + """ + i, j, k = ids + x = _wrap_cell_ids(i, cell_edge[0]) + y = _wrap_cell_ids(j, cell_edge[1]) + z = _wrap_cell_ids(k, cell_edge[2]) + return np.asarray([x, y, z]) diff --git a/pysages/nlist/testbench.ipynb b/pysages/nlist/testbench.ipynb new file mode 100644 index 00000000..eab4fb03 --- /dev/null +++ b/pysages/nlist/testbench.ipynb @@ -0,0 +1,118 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from CellList import CellList\n", + "from VerletList import VerletList\n", + "import jax" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "box_size = 8\n", + "positions = jax.random.uniform(jax.random.PRNGKey(0), (int(1e3), 3))*box_size" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "cell_list = CellList(box = (box_size, box_size, box_size), cutoff = 2.0)\n", + "cell_list.get_cell_ids(positions)\n", + "idxs = jax.numpy.asarray(list(range(positions.shape[0])))\n", + "all_ns = cell_list.get_neighbor_ids(idxs, mask_self=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "verlet_list = VerletList(cutoff = 2.0)\n", + "atom0_n_ids = all_ns[0][all_ns[0] != -1]\n", + "neighbors = verlet_list.get_neighborhood(positions[atom0_n_ids, :], positions[0, :])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "atom0_real_neighnors = atom0_n_ids[neighbors]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([993, 988, 968, 961, 959, 951, 947, 923, 891, 872, 871, 862, 828,\n", + " 808, 770, 736, 735, 708, 681, 631, 629, 621, 600, 594, 571, 555,\n", + " 541, 529, 509, 506, 466, 443, 437, 433, 400, 344, 328, 314, 309,\n", + " 283, 250, 242, 197, 145, 119, 98, 95, 47, 40, 22], dtype=int32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "atom0_real_neighnors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From b27db90cb53938800b26311a27fea69de041425f Mon Sep 17 00:00:00 2001 From: Armin Shayesteh Zadeh Date: Thu, 18 Jan 2024 17:08:51 -0600 Subject: [PATCH 2/5] Reworked initialization and update of CellList object. Cleaner now. --- pysages/nlist/CellList.py | 38 +++++++++++++++++------------------ pysages/nlist/testbench.ipynb | 15 +++++++++++--- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/pysages/nlist/CellList.py b/pysages/nlist/CellList.py index 841a9272..4192baf0 100644 --- a/pysages/nlist/CellList.py +++ b/pysages/nlist/CellList.py @@ -55,23 +55,6 @@ def set_buffer_size(self, buffer_size_cell: int) -> None: self.buffer_size_cell = buffer_size_cell return None - def get_cell_ids(self, pos: jax.Array) -> jax.Array: - """ - Get cell ids for each particle in pos matrix. - - Args: - pos (jax.Array): Array of particle positions (N, 3) - - Returns: - jax.Array: Array of cell ids for each particle (N, ). Also sets the cell_idx attribute. - """ - self.cell_idx = np.zeros(pos.shape[0], dtype=np.int32) - self.cell_idx = _get_cell_ids(pos, self.cell_cut, self.cell_edge) - - if self.buffer_size_cell is None: # can be set manually if needed - self.buffer_size_cell = np.int32(np.ceil(pos.shape[0]//self.cell_num * 1.5)) # set the size of the nl list per cell to 50% larger than the average number of particles per cell - return self.cell_idx - def _get_neighbor_ids(self, idx: int) -> jax.Array: """ Get neighbor ids for a single particle. @@ -170,6 +153,23 @@ def get_neighbor_ids(self, idxs: jax.Array, mask_self: bool= False) -> jax.Array return n_ids + def initiate(self, pos: jax.Array) -> None: + """ + Initialize the cell list. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + + Returns: + None + """ + if self.buffer_size_cell is None: # can be set manually if needed + self.buffer_size_cell = np.int32(np.ceil(pos.shape[0]//self.cell_num * 1.5)) # set the size of the nl list per cell to 50% larger than the average number of particles per cell + # get the cell ids + self.cell_idx = np.zeros(pos.shape[0], dtype=np.int32) + self.cell_idx = _get_cell_ids(pos, self.cell_cut, self.cell_edge) + return None + def update(self, pos: jax.Array) -> None: """ Update the cell list. @@ -180,6 +180,6 @@ def update(self, pos: jax.Array) -> None: Returns: None """ - # update the cell ids by calling get_cell_ids - self.get_cell_ids(pos) + # update the cell ids by calling _get_cell_ids + self.cell_idx = _get_cell_ids(pos, self.cell_cut, self.cell_edge) return None diff --git a/pysages/nlist/testbench.ipynb b/pysages/nlist/testbench.ipynb index eab4fb03..2100da70 100644 --- a/pysages/nlist/testbench.ipynb +++ b/pysages/nlist/testbench.ipynb @@ -23,9 +23,18 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1705619253.005336 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n" + ] + } + ], "source": [ "box_size = 8\n", "positions = jax.random.uniform(jax.random.PRNGKey(0), (int(1e3), 3))*box_size" @@ -38,7 +47,7 @@ "outputs": [], "source": [ "cell_list = CellList(box = (box_size, box_size, box_size), cutoff = 2.0)\n", - "cell_list.get_cell_ids(positions)\n", + "cell_list.initiate(positions)\n", "idxs = jax.numpy.asarray(list(range(positions.shape[0])))\n", "all_ns = cell_list.get_neighbor_ids(idxs, mask_self=True)" ] From d700a37c37f2803c81d3002c7e6210f09ce952fd Mon Sep 17 00:00:00 2001 From: Armin Shayesteh Zadeh Date: Thu, 8 Feb 2024 14:32:57 -0600 Subject: [PATCH 3/5] Rewrite of the nlist branch in functional form. --- pysages/nlist/CellList.py | 185 ------------------- pysages/nlist/VerletList.py | 121 ------------ pysages/nlist/cell_list.py | 167 +++++++++++++++++ pysages/nlist/cellfuncs.py | 87 --------- pysages/nlist/testbench.ipynb | 339 ++++++++++++++++++++++++++++++---- pysages/nlist/verlet_list.py | 93 ++++++++++ 6 files changed, 559 insertions(+), 433 deletions(-) delete mode 100644 pysages/nlist/CellList.py delete mode 100644 pysages/nlist/VerletList.py create mode 100644 pysages/nlist/cell_list.py delete mode 100644 pysages/nlist/cellfuncs.py create mode 100644 pysages/nlist/verlet_list.py diff --git a/pysages/nlist/CellList.py b/pysages/nlist/CellList.py deleted file mode 100644 index 4192baf0..00000000 --- a/pysages/nlist/CellList.py +++ /dev/null @@ -1,185 +0,0 @@ -from typing import Tuple - - -import jax -from jax import numpy as np -from jax import lax -from cellfuncs import _idx_to_tuple, _tuple_to_idx, _get_cell_ids, _get_neighbor - -class CellList: - """ - Cell list neighbor list algorithm for 3D systems implemented in JAX. - Loosely based on https://aiichironakano.github.io/cs596/01-1LinkedListCell.pdf - - Raises: - ValueError: If the cell list is not initialized before calling get_neighbor_ids() - - Returns: - CellList: CellList object containing the following attributes: - - box (Tuple): box size (3, ) - cutoff (float): cutoff distance for neighbor list (scalar) - cell_edge (jax.Array): number of cells in each dimension (3, ) - cell_cut (jax.Array): cell size in each dimension (3, ) - cell_num (int): total number of cells (scalar) - cell_idx (jax.Array): cell index for each particle (N, ) - buffer_size_cell (int): max number of neighbors per cell (scalar). If not set, it is set to 50% larger than the average number of particles per cell. - """ - box: Tuple - cutoff: float - cell_edge: jax.Array - cell_cut: jax.Array - cell_num: int - cell_idx: jax.Array - buffer_size_box: int - - def __init__(self, box: Tuple, cutoff: float, buffer_size_cell: int = None) -> None: - self.box = lax.stop_gradient(np.asarray(box)) # (3, ) - self.cutoff = lax.stop_gradient(cutoff) # scalar - self.cell_edge = np.floor(self.box/self.cutoff) # (3, ) - self.cell_cut = self.box/self.cell_edge # (3, ) - self.cell_num = np.prod(self.cell_edge) # scalar - self.cell_idx = None # (N, ) - self.buffer_size_cell = buffer_size_cell # scalar - - def set_buffer_size(self, buffer_size_cell: int) -> None: - """ - Setter for buffer_size_cell attribute. - - Args: - buffer_size_cell (int): max number of neighbors per cell - - Returns: - None - """ - self.buffer_size_cell = buffer_size_cell - return None - - def _get_neighbor_ids(self, idx: int) -> jax.Array: - """ - Get neighbor ids for a single particle. - - Args: - idx (int): index of the particle in the pos matrix (scalar) - - Raises: - ValueError: If the neighbor list overflows - - Returns: - jax.Array: Array of neighbor ids for the particle (N, ) - """ - cell_id = self.cell_idx[idx] # index of the cell that the particle is in scalar - cell_id = np.expand_dims(cell_id, axis=0) # scalar to (1, ) - - cell_tuple = _idx_to_tuple(cell_id, self.cell_edge) # tuple of the cell that the particle is in (1, dim) - - neighbor_tuples = [] - for i in [-1, 0, 1]: # loop over cells behind and ahead of the current cell in each dimension - for j in [-1, 0, 1]: - for k in [-1, 0, 1]: - neighbor_tuples.append(np.asarray([cell_tuple[0]+i, cell_tuple[1]+j, cell_tuple[2]+k])) - - neighbor_tuples = np.asarray(neighbor_tuples) # list to jax.Array (27, dim) - neighbor_tuples_wrapped = jax.vmap(_get_neighbor, in_axes=(0, None), out_axes=0)(neighbor_tuples, self.cell_edge) # wrap the cell ids of the neighbors (27, dim) - - # get scalar ids for the neighboring cells - neighbor_cell_ids = jax.vmap(_tuple_to_idx, (0, None))(neighbor_tuples_wrapped, self.cell_edge) - - neighbor_ids = [] # get ids of the particles in the neighboring cells. -1 is used as a filler for empty cells. - for cidx in neighbor_cell_ids: - neighbor_ids.append(np.where(self.cell_idx == cidx, fill_value=-1, size=self.buffer_size_cell)[0]) - - - # concatenate the neighbor ids into a single array. - neighbor_ids = np.concatenate(neighbor_ids, axis=-1) - return neighbor_ids - - def get_neighbor_ids(self, idxs: jax.Array, mask_self: bool= False) -> jax.Array: - """ - Get neighbor ids for a list of particles. Uses vmap to vectorize the _get_neighbor_ids function. - - Args: - idxs (jax.Array): Array of particle indices (n, ) - mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. - - Raises: - ValueError: If the cell list is not initialized before calling get_neighbor_ids() - - Returns: - jax.Array: Array of neighbor ids for each particle (n, buffer_size_nl)) - """ - if self.cell_idx is None: - raise ValueError("Cell list is not initialized. Call get_cell_ids() first.") - - # convert idxs to jax.Array if it is not already - if not isinstance(idxs, np.ndarray): - idxs = np.asarray(idxs) - # expand dims if idxs is a single particle - if len(idxs.shape) == 0: - idxs = np.expand_dims(idxs, axis=-1) - - if idxs.shape[0] == 1: # single particle case, no vmap - # get neighbor ids for the particle - n_ids = self._get_neighbor_ids(idxs[0]) - # check for overflow - min_buffer = np.count_nonzero(n_ids == -1, axis=-1) - if min_buffer < 27: # if there are less than 27 -1s in a row of the neighbor list, there is an overflow from buffer_size_cell - raise ValueError("Neighbor list overflow. Increase buffer_size_cell.") - # remove self from neighbor list if mask_self is True - if mask_self: - n_ids = n_ids[n_ids != idxs[0]] - # sort - n_ids = np.sort(n_ids)[::-1] - # truncate. Remove the -1s from the end of the neighbor list(smallest possible neighbor list). - n_ids = n_ids[:-min_buffer] - return n_ids - else: - # get neighbor ids for the particles - n_ids = jax.vmap(self._get_neighbor_ids)(idxs) - # check for overflow - min_buffer = np.min(np.count_nonzero(n_ids == -1, axis=-1)) - if min_buffer < 27: # if there are less than 27 -1s in a row of the neighbor list, there is an overflow from buffer_size_cell - raise ValueError("Neighbor list overflow. Increase buffer_size_cell.") - # remove self from neighbor list if mask_self is True - if mask_self: - # set the self index to -1 - n_ids = n_ids.at[..., n_ids == idxs[:, None]].set(-1) - # add one to the minimum buffer size to account for the -1 just added - min_buffer += 1 - # sort - n_ids = np.sort(n_ids, axis=-1)[:, ::-1] - # truncate. Remove the -1s from the end of the neighbor list so that the row with the least -1s will have none (smallest possible neighbor list). - n_ids = n_ids[:, :-min_buffer] - - return n_ids - - def initiate(self, pos: jax.Array) -> None: - """ - Initialize the cell list. - - Args: - pos (jax.Array): Array of particle positions (N, 3) - - Returns: - None - """ - if self.buffer_size_cell is None: # can be set manually if needed - self.buffer_size_cell = np.int32(np.ceil(pos.shape[0]//self.cell_num * 1.5)) # set the size of the nl list per cell to 50% larger than the average number of particles per cell - # get the cell ids - self.cell_idx = np.zeros(pos.shape[0], dtype=np.int32) - self.cell_idx = _get_cell_ids(pos, self.cell_cut, self.cell_edge) - return None - - def update(self, pos: jax.Array) -> None: - """ - Update the cell list. - - Args: - pos (jax.Array): Array of particle positions (N, 3) - - Returns: - None - """ - # update the cell ids by calling _get_cell_ids - self.cell_idx = _get_cell_ids(pos, self.cell_cut, self.cell_edge) - return None diff --git a/pysages/nlist/VerletList.py b/pysages/nlist/VerletList.py deleted file mode 100644 index 83baa4e9..00000000 --- a/pysages/nlist/VerletList.py +++ /dev/null @@ -1,121 +0,0 @@ -import jax -import jax.numpy as np -from jax import lax - -class VerletList: - """ - Verlet list neighbor list algorithm for 3D systems implemented in JAX. Originally implemented to work with the CellList class as a hybrid neighbor list algorithm. - - Returns: - VerletList: VerletList object containing the following attributes: - - cutoff (float): cutoff distance for neighbor list (scalar) - buffer_size (int): max number of neighbors per cell (scalar) - - """ - cutoff: float - buffer_size: int - - def __init__(self, cutoff: float) -> None: - self.cutoff = lax.stop_gradient(cutoff) # scalar - self.buffer_size = None # scalar - - def set_buffer_size(self, buffer_size: int) -> None: - """ - Setter for buffer_size attribute. - - Args: - buffer_size (int): max number of neighbors per cell - - Returns: - None - """ - self.buffer_size = buffer_size - return None - - def _get_dist(self, i: jax.Array, j: jax.Array) -> np.float32: - """ - Calculate the distance between two particles. (helper function for _pairwise_dist to use with vmap) - - Args: - i (jax.Array): Position of particle i (3, ) - j (jax.Array): Position of particle j (3, ) - - Returns: - np.float32: Distance between particles i and j (scalar) - """ - return np.linalg.norm(i-j) - - def _pairwise_dist(self, pos: jax.Array, ref: jax.Array) -> jax.Array: - """ - Calculate the pairwise distance between particles in pos and a single reference particle. (helper function for get_neighbor_ids to use with vmap) - - Args: - pos (jax.Array): position of particles (N, 3) - ref (jax.Array): position of reference particle (3, ) - - Returns: - jax.Array: array of distances between particles and reference particle (N, ) - """ - return jax.vmap(self._get_dist, (0, None))(pos, ref) - - def _is_neighbor(self, dist: jax.Array) -> jax.Array: - """ - Check if a particle is a neighbor of the reference particle. (helper function for get_neighbor_ids to use with vmap) - - Args: - dist (jax.Array): Array of distances between particles and reference particle (N, ) - - Returns: - jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) - """ - return dist < self.cutoff - - def get_neighbor_ids(self, pos: jax.Array, sparse: bool = False, mask_self: bool = False) -> jax.Array: - """ - Get neighbor ids for each particle in pos matrix. - - Args: - pos (jax.Array): Array of particle positions (N, 3) - sparse (bool, optional): Whether to return the full (N, N) matrix of neighborhood or an Array. Defaults to False. - mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. - - Returns: - jax.Array: Array of neighbor ids for each particle (N, ) or (N, N) matrix of bools indicating whether a particle is a neighbor of another particle. - """ - # if buffer_size is not set, set it to the number of particles - if self.buffer_size is None: - self.buffer_size = pos.shape[0] - # calculate the pairwise distances between all particles - pair_dists = jax.vmap(self._pairwise_dist, (None, 0))(pos, pos) - # check if a particle is a neighbor of another particle based on the cutoff distance - is_neighbor = jax.vmap(self._is_neighbor)(pair_dists) - # remove self from neighbor list if mask_self is True - if mask_self: - i, j = np.diag_indices(is_neighbor.shape[0]) - is_neighbor = is_neighbor.at[..., i, j].set(False) - # return a list of arrays if sparse is True - if sparse: # return a list of arrays - neighbor_list = [] - for row in is_neighbor: - neighbor_list.append(np.where(row)[0]) - return neighbor_list - - return is_neighbor # return a NxN array of bools - - def get_neighborhood(self, pos: jax.Array, ref: jax.Array) -> jax.Array: - """ - Get the neighborhood of a specific particle. Implemented to work with the CellList class as a hybrid neighbor list algorithm. - - Args: - pos (jax.Array): Array of particle positions (N, 3) - ref (jax.Array): position of reference particle (3, ) - - Returns: - jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) - """ - # calculate the pairwise distances between all particles and the reference particle - pair_dists = self._pairwise_dist(pos, ref) - # check if a particle is a neighbor of the reference particle based on the cutoff distance - is_neighbor = self._is_neighbor(pair_dists) - return is_neighbor \ No newline at end of file diff --git a/pysages/nlist/cell_list.py b/pysages/nlist/cell_list.py new file mode 100644 index 00000000..d411065a --- /dev/null +++ b/pysages/nlist/cell_list.py @@ -0,0 +1,167 @@ +from typing import Union + + +import jax +from jax import numpy as np + +def _tuple_to_idx(tup: jax.Array, cell_edge: jax.Array) -> Union[np.int32, jax.Array]: + """ + Covnert cell index from tuple to scalar. + + Args: + tup (jax.Array): [index in x, index in y, index in z] + cell_edge (jax.Array): Number of cells in each dimension + + Returns: + np.int32 or jax.Array: Scalar index of cell or (N, ) array of scalar indices + """ + return np.int32(tup[...,0]*cell_edge[1]*cell_edge[2] + tup[...,1]*cell_edge[2] + tup[...,2]) + +def _idx_to_tuple(idx: int, cell_edge: jax.Array) -> jax.Array: + """ + Convert cell index from scalar to tuple. + + Args: + idx (int): Scalar index of cell + cell_edge (jax.Array): Number of cells in each dimension + + Returns: + jax.Array: [index in x, index in y, index in z] (3,) or (N, 3) + """ + x: np.int32 = idx//(cell_edge[1]*cell_edge[2]) + y: np.int32 = (idx//cell_edge[2])%cell_edge[1] + z: np.int32 = idx%cell_edge[2] + return np.concatenate([x, y, z], axis=-1, dtype=np.int32) + +def get_cell_list(pos: jax.Array, box_size: jax.Array, cutoff: float) -> jax.Array: + """ + Initialize the cell list. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + box_size (Tuple): box size (3, ) + cutoff (float): cutoff distance for neighbor list (scalar) + + Returns: + cell_idx (jax.Array): cell index for each particle (N, ) + """ + #setup the box parameters + cell_edge = np.floor(box_size/cutoff) # (3, ) + cell_cut = box_size/cell_edge # (3, ) + # get the cell ids + cell_tuples = pos//cell_cut + cell_idx = _tuple_to_idx(cell_tuples, cell_edge) + return cell_idx + +def _wrap_cell_ids(cell_ids: jax.Array, cell_edge: np.int32) -> jax.Array: + """ + Wraps the cell ids of particles in edge cells. (single dimension) + + Args: + cell_ids (jax.Array): Array of tuple cell ids in the current dimension for each particle (N, 1) + cell_edge (np.int32): Number of cells in current dimension + + Returns: + jax.Array: Wrapped cell ids (tuple) for each particle (N, 3) + """ + out_of_bound_low = (cell_ids == -1) # if cell id is -1 (out of bound from below) + out_of_bound_high = (cell_ids == cell_edge) # if cell id equal to the number of cells in that dimension (out of bound from above) + cell_ids = np.where(out_of_bound_low, cell_edge-1, cell_ids) # if out of bound, then wrap around from below + cell_ids = np.where(out_of_bound_high, 0, cell_ids) # if out of bound, then wrap around from above + return cell_ids + +def _get_neighbor_box(ids: jax.Array, cell_edge: jax.Array) -> jax.Array: + """ + Wrap the tuple cell ids of particles for each neighbor. (helper function for get_neighbor_ids to use with vmap) + + Args: + ids (jax.Array): Array of tuple cell ids (N, 3) + cell_edge (jax.Array): Array of number of cells in each dimension (3, ) + + Returns: + jax.Array: Wrapped tuple cell ids (N, 3) + """ + i, j, k = ids + x = _wrap_cell_ids(i, cell_edge[0]) + y = _wrap_cell_ids(j, cell_edge[1]) + z = _wrap_cell_ids(k, cell_edge[2]) + return np.asarray([x, y, z]) + +def get_neighbor_ids(box_size: jax.Array, cutoff: float, cell_idx: jax.Array, idx: int, buffer_size_cell: int) -> jax.Array: + """ + Get neighbor ids for a single particle. + + Args: + box_size (Tuple): box size (3, ) + cutoff (float): cutoff distance for neighbor list (scalar) + cell_idx (jax.Array): cell index for each particle (N, ) + idx (int): index of the particle in the pos matrix (scalar) + buffer_size_cell (int): buffer size for the cell list (scalar) + + Raises: + ValueError: If the neighbor list overflows + + Returns: + jax.Array: Array of neighbor ids for the particle (N, ) + """ + cell_edge = np.floor(box_size/cutoff) # (3, ) + cell_id = cell_idx[idx] # index of the cell that the particle is in scalar + cell_id = np.expand_dims(cell_id, axis=0) # scalar to (1, ) + + cell_tuple = _idx_to_tuple(cell_id, cell_edge) # tuple of the cell that the particle is in (1, dim) + + neighbor_tuples = [] + for i in [-1, 0, 1]: # loop over cells behind and ahead of the current cell in each dimension + for j in [-1, 0, 1]: + for k in [-1, 0, 1]: + neighbor_tuples.append(np.asarray([cell_tuple[0]+i, cell_tuple[1]+j, cell_tuple[2]+k])) + + neighbor_tuples = np.asarray(neighbor_tuples) # list to jax.Array (27, dim) + neighbor_tuples_wrapped = jax.vmap(_get_neighbor_box, in_axes=(0, None), out_axes=0)(neighbor_tuples, cell_edge) # wrap the cell ids of the neighbors (27, dim) + + # get scalar ids for the neighboring cells + neighbor_cell_ids = jax.vmap(_tuple_to_idx, (0, None))(neighbor_tuples_wrapped, cell_edge) + + neighbor_ids = [] # get ids of the particles in the neighboring cells. -1 is used as a filler for empty cells. + for cidx in neighbor_cell_ids: + neighbor_ids.append(np.where(cell_idx == cidx, fill_value=-1, size=buffer_size_cell)[0]) + + + # concatenate the neighbor ids into a single array. + neighbor_ids = np.concatenate(neighbor_ids, axis=-1) + return neighbor_ids + +def get_neighbors_list(box_size: jax.Array, cutoff: float, cell_idx: jax.Array, idxs: jax.Array, buffer_size_cell: int, mask_self: bool= False) -> jax.Array: + """ + Get neighbor ids for a list of particles. Uses vmap to vectorize on get_neighbor_ids function. + + Args: + cell_idx (jax.Array): cell index for each particle (N, ) + idxs (jax.Array): Array of particle indices (n, ) + mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. + + Raises: + ValueError: If the cell list is not initialized before calling get_neighbor_ids() + + Returns: + jax.Array: Array of neighbor ids for each particle (n, buffer_size_nl)) + """ + + # get neighbor ids for the particles + n_ids = jax.vmap(get_neighbor_ids, in_axes=(None, None, None, 0, None))(box_size, cutoff, cell_idx, idxs, buffer_size_cell) + # check for overflow + min_buffer = np.min(np.count_nonzero(n_ids == -1, axis=-1)) + if min_buffer < 27: # if there are less than 27 -1s in a row of the neighbor list, there is an overflow from buffer_size_cell + raise ValueError("Neighbor list overflow. Increase buffer_size_cell.") + # remove self from neighbor list if mask_self is True + if mask_self: + # set the self index to -1 + n_ids = n_ids.at[..., n_ids == idxs[:, None]].set(-1) + # add one to the minimum buffer size to account for the -1 just added + min_buffer += 1 + # sort + n_ids = np.sort(n_ids, axis=-1)[:, ::-1] + # truncate. Remove the -1s from the end of the neighbor list so that the row with the least -1s will have none (smallest possible neighbor list). + n_ids = n_ids[:, :-min_buffer] + + return n_ids diff --git a/pysages/nlist/cellfuncs.py b/pysages/nlist/cellfuncs.py deleted file mode 100644 index 690d7956..00000000 --- a/pysages/nlist/cellfuncs.py +++ /dev/null @@ -1,87 +0,0 @@ -import jax -from jax import jit -import jax.numpy as np - -@jit -def _idx_to_tuple(idx: int, cell_edge: jax.Array) -> jax.Array: - """ - Convert cell index from scalar to tuple. - - Args: - idx (int): Scalar index of cell - cell_edge (jax.Array): Number of cells in each dimension - - Returns: - jax.Array: [index in x, index in y, index in z] (3,) or (N, 3) - """ - x: np.int32 = idx//(cell_edge[1]*cell_edge[2]) - y: np.int32 = (idx//cell_edge[2])%cell_edge[1] - z: np.int32 = idx%cell_edge[2] - return np.concatenate([x, y, z], axis=-1, dtype=np.int32) - -@jit -def _tuple_to_idx(tup: jax.Array, cell_edge: jax.Array) -> np.int32 | jax.Array: - """ - Covnert cell index from tuple to scalar. - - Args: - tup (jax.Array): [index in x, index in y, index in z] - cell_edge (jax.Array): Number of cells in each dimension - - Returns: - np.int32 or jax.Array: Scalar index of cell or (N, ) array of scalar indices - """ - return np.int32(tup[...,0]*cell_edge[1]*cell_edge[2] + tup[...,1]*cell_edge[2] + tup[...,2]) - -@jit -def _get_cell_ids(pos: jax.Array, cell_cut: jax.Array, cell_edge: jax.Array) -> jax.Array: - """ - Get scalar cell ids for each particle (row) in pos matrix. - - Args: - pos (jax.Array): matrix of particle positions (N, 3) - cell_cut (jax.Array): Cut off distance for each cell - cell_edge (jax.Array): Number of cells in each dimension - - Returns: - jax.Array: Array of cell ids for each particle (N, ) - """ - cell_tuples: np.int32 = pos//cell_cut - cell_ids = _tuple_to_idx(cell_tuples, cell_edge) - return cell_ids - -@jit -def _wrap_cell_ids(cell_ids: jax.Array, cell_edge: np.int32) -> jax.Array: - """ - Wraps the cell ids of particles in edge cells. (single dimension) - - Args: - cell_ids (jax.Array): Array of tuple cell ids in the current dimension for each particle (N, 1) - cell_edge (np.int32): Number of cells in current dimension - - Returns: - jax.Array: Wrapped cell ids (tuple) for each particle (N, 3) - """ - out_of_bound_low = (cell_ids == -1) # if cell id is -1 (out of bound from below) - out_of_bound_high = (cell_ids == cell_edge) # if cell id equal to the number of cells in that dimension (out of bound from above) - cell_ids = np.where(out_of_bound_low, cell_edge-1, cell_ids) # if out of bound, then wrap around from below - cell_ids = np.where(out_of_bound_high, 0, cell_ids) # if out of bound, then wrap around from above - return cell_ids - -@jit -def _get_neighbor(ids: jax.Array, cell_edge: jax.Array) -> jax.Array: - """ - Wrap the tuple cell ids of particles for each neighbor. (helper function for _get_neighbor_ids to use with vmap) - - Args: - ids (jax.Array): Array of tuple cell ids (N, 3) - cell_edge (jax.Array): Array of number of cells in each dimension (3, ) - - Returns: - jax.Array: Wrapped tuple cell ids (N, 3) - """ - i, j, k = ids - x = _wrap_cell_ids(i, cell_edge[0]) - y = _wrap_cell_ids(j, cell_edge[1]) - z = _wrap_cell_ids(k, cell_edge[2]) - return np.asarray([x, y, z]) diff --git a/pysages/nlist/testbench.ipynb b/pysages/nlist/testbench.ipynb index 2100da70..23106370 100644 --- a/pysages/nlist/testbench.ipynb +++ b/pysages/nlist/testbench.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "%load_ext autoreload\n", + "%load_ext line_profiler\n", "%autoreload 2" ] }, @@ -16,28 +17,69 @@ "metadata": {}, "outputs": [], "source": [ - "from CellList import CellList\n", - "from VerletList import VerletList\n", - "import jax" + "from cell_list import get_cell_list, get_neighbors_list, get_neighbor_ids\n", + "from verlet_list import get_neighbor_ids as get_verlet_neighbor_ids\n", + "from verlet_list import get_neighborhood, _pairwise_dist\n", + "from jax import random\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting up a simple box\n", + "\n", + "This will create a $3 \\times 3 \\times 3$ centered at 0 and puts particles on the nodes of a grid with edge size of 1. First particle is placed at the center of the box $(0, 0, 0)$ and the rest of the particles (26 more) are placed at $+/- 1$ of the center." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "I0000 00:00:1705619253.005336 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n" - ] - } - ], + "outputs": [], "source": [ - "box_size = 8\n", - "positions = jax.random.uniform(jax.random.PRNGKey(0), (int(1e3), 3))*box_size" + "box_edge = 3.0\n", + "box_size = jnp.array([box_edge, box_edge, box_edge])\n", + "#positions = random.uniform(random.PRNGKey(0), (int(1e2), 3))*box_edge\n", + "positions = jnp.array([[ 0.0, 0.0, 0.0],\n", + " [ 0.0, 1.0, 0.0],\n", + " [ 0.0, -1.0, 0.0],\n", + " [ 1.0, 0.0, 0.0],\n", + " [ 1.0, 1.0, 0.0], \n", + " [ 1.0, -1.0, 0.0],\n", + " [-1.0, 0.0, 0.0],\n", + " [-1.0, 1.0, 0.0],\n", + " [-1.0, -1.0, 0.0],\n", + " [ 0.0, 0.0, 1.0],\n", + " [ 0.0, 1.0, 1.0],\n", + " [ 0.0, -1.0, 1.0],\n", + " [ 1.0, 0.0, 1.0],\n", + " [ 1.0, 1.0, 1.0], \n", + " [ 1.0, -1.0, 1.0],\n", + " [-1.0, 0.0, 1.0],\n", + " [-1.0, 1.0, 1.0],\n", + " [-1.0, -1.0, 1.0],\n", + " [ 0.0, 0.0, -1.0],\n", + " [ 0.0, 1.0, -1.0],\n", + " [ 0.0, -1.0, -1.0],\n", + " [ 1.0, 0.0, -1.0],\n", + " [ 1.0, 1.0, -1.0], \n", + " [ 1.0, -1.0, -1.0],\n", + " [-1.0, 0.0, -1.0],\n", + " [-1.0, 1.0, -1.0],\n", + " [-1.0, -1.0, -1.0]])\n", + " \n", + "cutoff_c = 1.0\n", + "cutoff_v = 1.0\n", + "buffer = 30" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Particles are shifted so that all $x, y, z$ values are positive ($(0, 0, 0)$ at the corner instead of the center)." ] }, { @@ -46,10 +88,15 @@ "metadata": {}, "outputs": [], "source": [ - "cell_list = CellList(box = (box_size, box_size, box_size), cutoff = 2.0)\n", - "cell_list.initiate(positions)\n", - "idxs = jax.numpy.asarray(list(range(positions.shape[0])))\n", - "all_ns = cell_list.get_neighbor_ids(idxs, mask_self=True)" + "positions += box_edge/2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cell list method\n", + "This will create a cell list which breaks the box into 27 boxes of edge size 1 (`cutoff_c`). Because of how we placed the particles, each box will contain exactly one particle." ] }, { @@ -58,18 +105,61 @@ "metadata": {}, "outputs": [], "source": [ - "verlet_list = VerletList(cutoff = 2.0)\n", - "atom0_n_ids = all_ns[0][all_ns[0] != -1]\n", - "neighbors = verlet_list.get_neighborhood(positions[atom0_n_ids, :], positions[0, :])" + "cell_list = get_cell_list(positions, box_size, cutoff_c)\n", + "idxs = jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 1\n", + "\n", + "Using `get_neighbors_list` we can get the neighbors for a list of particles. Under the hood, this functions calls `vmap` on another function `get_neighbor_ids` which is implemented for a single particle (+ some postprocessing)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]\n", + " [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3\n", + " 2 1 0]]\n" + ] + } + ], "source": [ - "atom0_real_neighnors = atom0_n_ids[neighbors]" + "nbors = get_neighbors_list(box_size=box_size, cutoff=cutoff_c, cell_idx=cell_list, idxs=idxs, buffer_size_cell=buffer, mask_self=False)\n", + "print(nbors)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 2\n", + "\n", + "Using `get_neighbor_ids` directly with an external for loop and the postprocessing outside of the neighbor list." ] }, { @@ -78,29 +168,198 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "Array([993, 988, 968, 961, 959, 951, 947, 923, 891, 872, 871, 862, 828,\n", - " 808, 770, 736, 735, 708, 681, 631, 629, 621, 600, 594, 571, 555,\n", - " 541, 529, 509, 506, 466, 443, 437, 433, 400, 344, 328, 314, 309,\n", - " 283, 250, 242, 197, 145, 119, 98, 95, 47, 40, 22], dtype=int32)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "[Array([24, 6, 15, 25, 7, 16, 26, 8, 17, 18, 0, 9, 19, 10, 20, 2, 11,\n", + " 21, 3, 12, 22, 4, 13, 23, 5, 14], dtype=int32), Array([25, 7, 16, 26, 8, 17, 24, 6, 15, 19, 1, 10, 20, 11, 18, 0, 9,\n", + " 22, 4, 13, 23, 5, 14, 21, 3, 12], dtype=int32), Array([20, 2, 11, 18, 0, 9, 19, 1, 10, 23, 5, 14, 21, 12, 22, 4, 13,\n", + " 26, 8, 17, 24, 6, 15, 25, 7, 16], dtype=int32), Array([18, 0, 9, 19, 1, 10, 20, 2, 11, 21, 3, 12, 22, 13, 23, 5, 14,\n", + " 24, 6, 15, 25, 7, 16, 26, 8, 17], dtype=int32), Array([19, 1, 10, 20, 2, 11, 18, 0, 9, 22, 4, 13, 23, 14, 21, 3, 12,\n", + " 25, 7, 16, 26, 8, 17, 24, 6, 15], dtype=int32), Array([23, 5, 14, 21, 3, 12, 22, 4, 13, 26, 8, 17, 24, 15, 25, 7, 16,\n", + " 20, 2, 11, 18, 0, 9, 19, 1, 10], dtype=int32), Array([21, 3, 12, 22, 4, 13, 23, 5, 14, 24, 6, 15, 25, 16, 26, 8, 17,\n", + " 18, 0, 9, 19, 1, 10, 20, 2, 11], dtype=int32), Array([22, 4, 13, 23, 5, 14, 21, 3, 12, 25, 7, 16, 26, 17, 24, 6, 15,\n", + " 19, 1, 10, 20, 2, 11, 18, 0, 9], dtype=int32), Array([ 8, 17, 26, 6, 15, 24, 7, 16, 25, 2, 11, 20, 0, 18, 1, 10, 19,\n", + " 5, 14, 23, 3, 12, 21, 4, 13, 22], dtype=int32)]\n" + ] + } + ], + "source": [ + "nbors_2 = []\n", + "for i in idxs:\n", + " n_i = get_neighbor_ids(box_size=box_size, cutoff=cutoff_c, cell_idx=cell_list, idx=i, buffer_size_cell=buffer)\n", + " n_i = n_i[n_i != i]\n", + " n_i = n_i[n_i != -1]\n", + " nbors_2.append(n_i)\n", + "print(nbors_2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both methods return all 27 particles as neighbors for each particle. This is the expected result because of the way the positions are set up." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Verlet list\n", + "\n", + "Verlet list can be used together with a cell list to get exact cutoffs. Cell list will capture all particles in the 27 neighboring cell, regardless of exact cutoff selected. This hybrid method is more efficient than calculating pairwise distance for all particles and doesn't require a skin radius." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 1\n", + "Here we can get the particle ids of the neighbors of particle $0$ and get its exact neighborhood (`cutoff_v`) with a verlet list using `get_neighborhood`. This functions returns a list of `bool`s that shows whether the particles in the cell list neighbor list are within the verlet cutoff as well." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[18 9 6 3 2 1 0]\n" + ] + } + ], + "source": [ + "atom0_n_ids = nbors[0][nbors[0] != -1]\n", + "neighbors = get_neighborhood(positions[atom0_n_ids, :], positions[0, :], cutoff_v, box_size)\n", + "print(atom0_n_ids[jnp.where(neighbors)])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 2\n", + "We can also get the neighborhood for all particles in a position matrix. The result is a $N \\times N$ matrix (or a list of size $N$ if `sparse = True`)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ True True True True False False True False False True False False\n", + " False False False False False False True False False False False False\n", + " False False False]\n", + " [ True True False False True False False True False False True False\n", + " False False False False False False False True False False False False\n", + " False False False]\n", + " [ True True True False False True False False True False False True\n", + " False False False False False False False False True False False False\n", + " False False False]\n", + " [ True False False True True True False False False False False False\n", + " True False False False False False False False False True False False\n", + " False False False]\n", + " [False True False True True False False False False False False False\n", + " False True False False False False False False False False True False\n", + " False False False]\n", + " [False False True True True True False False False False False False\n", + " False False True False False False False False False False False True\n", + " False False False]\n", + " [ True False False True False False True True True False False False\n", + " False False False True False False False False False False False False\n", + " True False False]\n", + " [False True False False True False True True False False False False\n", + " False False False False True False False False False False False False\n", + " False True False]\n", + " [False False True False False True True True True False False False\n", + " False False False False False True False False False False False False\n", + " False False True]\n", + " [ True False False False False False False False False True True True\n", + " True False False True False False False False False False False False\n", + " False False False]\n", + " [False True False False False False False False False True True False\n", + " False True False False True False False False False False False False\n", + " False False False]\n", + " [False False True False False False False False False True True True\n", + " False False True False False True False False False False False False\n", + " False False False]\n", + " [False False False True False False False False False True False False\n", + " True True True False False False False False False False False False\n", + " False False False]\n", + " [False False False False True False False False False False True False\n", + " True True False False False False False False False False False False\n", + " False False False]\n", + " [False False False False False True False False False False False True\n", + " True True True False False False False False False False False False\n", + " False False False]\n", + " [False False False False False False True False False True False False\n", + " True False False True True True False False False False False False\n", + " False False False]\n", + " [False False False False False False False True False False True False\n", + " False True False True True False False False False False False False\n", + " False False False]\n", + " [False False False False False False False False True False False True\n", + " False False True True True True False False False False False False\n", + " False False False]\n", + " [ True False False False False False False False False True False False\n", + " False False False False False False True True True True False False\n", + " True False False]\n", + " [False True False False False False False False False False True False\n", + " False False False False False False True True False False True False\n", + " False True False]\n", + " [False False True False False False False False False False False True\n", + " False False False False False False True True True False False True\n", + " False False True]\n", + " [False False False True False False False False False False False False\n", + " True False False False False False True False False True True True\n", + " False False False]\n", + " [False False False False True False False False False False False False\n", + " False True False False False False False True False True True False\n", + " False False False]\n", + " [False False False False False True False False False False False False\n", + " False False True False False False False False True True True True\n", + " False False False]\n", + " [False False False False False False True False False False False False\n", + " False False False True False False True False False True False False\n", + " True True True]\n", + " [False False False False False False False True False False False False\n", + " False False False False True False False True False False True False\n", + " True True False]\n", + " [False False False False False False False False True False False False\n", + " False False False False False True False False True False False True\n", + " True True True]]\n" + ] } ], "source": [ - "atom0_real_neighnors" + "n_ids_verlet = get_verlet_neighbor_ids(positions, cutoff_v, box_size, sparse=False)\n", + "print(n_ids_verlet)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Array([ 0, 1, 2, 3, 6, 9, 18], dtype=int32), Array([ 0, 1, 4, 7, 10, 19], dtype=int32), Array([ 0, 1, 2, 5, 8, 11, 20], dtype=int32), Array([ 0, 3, 4, 5, 12, 21], dtype=int32), Array([ 1, 3, 4, 13, 22], dtype=int32), Array([ 2, 3, 4, 5, 14, 23], dtype=int32), Array([ 0, 3, 6, 7, 8, 15, 24], dtype=int32), Array([ 1, 4, 6, 7, 16, 25], dtype=int32), Array([ 2, 5, 6, 7, 8, 17, 26], dtype=int32), Array([ 0, 9, 10, 11, 12, 15], dtype=int32), Array([ 1, 9, 10, 13, 16], dtype=int32), Array([ 2, 9, 10, 11, 14, 17], dtype=int32), Array([ 3, 9, 12, 13, 14], dtype=int32), Array([ 4, 10, 12, 13], dtype=int32), Array([ 5, 11, 12, 13, 14], dtype=int32), Array([ 6, 9, 12, 15, 16, 17], dtype=int32), Array([ 7, 10, 13, 15, 16], dtype=int32), Array([ 8, 11, 14, 15, 16, 17], dtype=int32), Array([ 0, 9, 18, 19, 20, 21, 24], dtype=int32), Array([ 1, 10, 18, 19, 22, 25], dtype=int32), Array([ 2, 11, 18, 19, 20, 23, 26], dtype=int32), Array([ 3, 12, 18, 21, 22, 23], dtype=int32), Array([ 4, 13, 19, 21, 22], dtype=int32), Array([ 5, 14, 20, 21, 22, 23], dtype=int32), Array([ 6, 15, 18, 21, 24, 25, 26], dtype=int32), Array([ 7, 16, 19, 22, 24, 25], dtype=int32), Array([ 8, 17, 20, 23, 24, 25, 26], dtype=int32)]\n" + ] + } + ], + "source": [ + "n_ids_verlet_s = get_verlet_neighbor_ids(positions, cutoff_v, box_size, sparse=True)\n", + "print(n_ids_verlet_s)" + ] } ], "metadata": { @@ -119,7 +378,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.9.18" } }, "nbformat": 4, diff --git a/pysages/nlist/verlet_list.py b/pysages/nlist/verlet_list.py new file mode 100644 index 00000000..2515c7f2 --- /dev/null +++ b/pysages/nlist/verlet_list.py @@ -0,0 +1,93 @@ +import jax +import jax.numpy as np + +def _get_dist(i: jax.Array, j: jax.Array, box_size: jax.Array) -> np.float32: + """ + Calculate the distance between two particles. (helper function for _pairwise_dist to use with vmap) + + Args: + i (jax.Array): Position of particle i (3, ) + j (jax.Array): Position of particle j (3, ) + + Returns: + np.float32: Distance between particles i and j (scalar) + """ + dx = i[0] - j[0] + dx = np.where(dx > box_size[0]/2, dx - box_size[0], dx) + dy = i[1] - j[1] + dy = np.where(dy > box_size[1]/2, dy - box_size[1], dy) + dz = i[2] - j[2] + dz = np.where(dz > box_size[2]/2, dz - box_size[2], dz) + return np.sqrt(dx**2 + dy**2 + dz**2) + +def _pairwise_dist(pos: jax.Array, ref: jax.Array, box_size: jax.Array) -> jax.Array: + """ + Calculate the pairwise distance between particles in pos and a single reference particle. (helper function for get_neighbor_ids to use with vmap) + + Args: + pos (jax.Array): position of particles (N, 3) + ref (jax.Array): position of reference particle (3, ) + + Returns: + jax.Array: array of distances between particles and reference particle (N, ) + """ + return jax.vmap(_get_dist, (0, None, None))(pos, ref, box_size) + +def _is_neighbor(dist: jax.Array, cutoff: float) -> jax.Array: + """ + Check if a particle is a neighbor of the reference particle. (helper function for get_neighbor_ids to use with vmap) + + Args: + dist (jax.Array): Array of distances between particles and reference particle (N, ) + cutoff (float): cutoff distance for neighbor list (scalar) + + Returns: + jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) + """ + return dist <= cutoff + +def get_neighbor_ids(pos: jax.Array, cutoff: float, box_size: jax.Array, sparse: bool = False, mask_self: bool = False) -> jax.Array: + """ + Get neighbor ids for each particle in pos matrix. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + sparse (bool, optional): Whether to return the full (N, N) matrix of neighborhood or an Array. Defaults to False. + mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. + + Returns: + jax.Array: Array of neighbor ids for each particle (N, ) or (N, N) matrix of bools indicating whether a particle is a neighbor of another particle. + """ + # calculate the pairwise distances between all particles + pair_dists = jax.vmap(_pairwise_dist, (None, 0, None))(pos, pos, box_size) + # check if a particle is a neighbor of another particle based on the cutoff distance + is_neighbor = jax.vmap(_is_neighbor, (0, None))(pair_dists, cutoff) + # remove self from neighbor list if mask_self is True + if mask_self: + i, j = np.diag_indices(is_neighbor.shape[0]) + is_neighbor = is_neighbor.at[..., i, j].set(False) + # return a list of arrays if sparse is True + if sparse: # return a list of arrays + neighbor_list = [] + for row in is_neighbor: + neighbor_list.append(np.where(row)[0]) + return neighbor_list + + return is_neighbor # return a NxN array of bools + +def get_neighborhood(pos: jax.Array, ref: jax.Array, cutoff: float, box_size: jax.Array) -> jax.Array: + """ + Get the neighborhood of a reference particle. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + ref (jax.Array): position of reference particle (3, ) + + Returns: + jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) + """ + # calculate the pairwise distances between all particles and the reference particle + pair_dists = _pairwise_dist(pos, ref, box_size) + # check if a particle is a neighbor of the reference particle based on the cutoff distance + is_neighbor = _is_neighbor(pair_dists, cutoff) + return is_neighbor \ No newline at end of file From 67b28bc3087dbe1c20529158fbac4ee6caecc393 Mon Sep 17 00:00:00 2001 From: Armin Shayesteh Zadeh Date: Thu, 15 Feb 2024 10:10:31 -0600 Subject: [PATCH 4/5] Fix to minimum image convention in verlet list method. --- pysages/nlist/verlet_list.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pysages/nlist/verlet_list.py b/pysages/nlist/verlet_list.py index 2515c7f2..7e8216a8 100644 --- a/pysages/nlist/verlet_list.py +++ b/pysages/nlist/verlet_list.py @@ -13,11 +13,14 @@ def _get_dist(i: jax.Array, j: jax.Array, box_size: jax.Array) -> np.float32: np.float32: Distance between particles i and j (scalar) """ dx = i[0] - j[0] - dx = np.where(dx > box_size[0]/2, dx - box_size[0], dx) + dx = np.where(dx > box_size[0]/2, dx - box_size[0], dx) + dx = np.where(dx < -box_size[0]/2, dx + box_size[0], dx) dy = i[1] - j[1] - dy = np.where(dy > box_size[1]/2, dy - box_size[1], dy) + dy = np.where(dy > box_size[1]/2, dy - box_size[1], dy) + dy = np.where(dy < -box_size[1]/2, dy + box_size[1], dy) dz = i[2] - j[2] - dz = np.where(dz > box_size[2]/2, dz - box_size[2], dz) + dz = np.where(dz > box_size[2]/2, dz - box_size[2], dz) + dz = np.where(dz < -box_size[2]/2, dz + box_size[2], dz) return np.sqrt(dx**2 + dy**2 + dz**2) def _pairwise_dist(pos: jax.Array, ref: jax.Array, box_size: jax.Array) -> jax.Array: From cb79d638400041770b8c86c46471d7a657e5ed5a Mon Sep 17 00:00:00 2001 From: Armin Shayesteh Zadeh Date: Fri, 16 Feb 2024 11:40:58 -0600 Subject: [PATCH 5/5] Added jax_md comparisons. Changed the _is_neighbor criterion to jax.Array: Returns: jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) """ - return dist <= cutoff + return dist < cutoff def get_neighbor_ids(pos: jax.Array, cutoff: float, box_size: jax.Array, sparse: bool = False, mask_self: bool = False) -> jax.Array: """