help to use partition.neighbor_list #8738
Answered by hawkinsp
mariogeiger asked this question in Q&A
This is my first time trying to use `partition.neighbor_list`:

```python
import jax
import jax.numpy as jnp
from jax_md import partition, space

jax.config.update("jax_enable_x64", True)

(displacement_fn, shift_fn) = space.free()

box_size = 100.0
r_cutoff = 3.0
dr_threshold = 0.0

neighbor_list_fn = partition.neighbor_list(
    displacement_fn,
    box_size=box_size,
    r_cutoff=r_cutoff,
    dr_threshold=dr_threshold,
)

# All points are far from each other.
R = jnp.array(
    [
        [20.0, 20.0],
        [30.0, 30.0],
        [40.0, 40.0],
        [50.0, 50.0],
    ]
)
neighbors = neighbor_list_fn.allocate(R)
print(neighbors.idx)

# The first two points are now close to each other (they coincide).
R = jnp.array(
    [
        [20.0, 20.0],
        [20.0, 20.0],
        [40.0, 40.0],
        [50.0, 50.0],
    ]
)
neighbors = neighbors.update(R)
print(neighbors.idx)
print(neighbors.did_buffer_overflow)
```

This prints:
I expected to see a
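The thread never explains the output. One plausible reading, based on JAX MD's general fixed-size neighbor-list design (an assumption here, not something stated in the thread), is that `allocate` fixes the buffer's capacity from the positions it is given (in the first configuration there are no pairs, so zero slots), and `update` cannot grow that buffer; it only sets `did_buffer_overflow`, after which the list has to be allocated again. The toy model below mimics that behavior in plain NumPy. `pairs_within`, `toy_allocate` and `toy_update` are hypothetical names, the padding value `N` and `capacity_multiplier=1.25` are illustrative assumptions, and this is not JAX MD's implementation.

```python
import numpy as np


def pairs_within(R, r_cutoff):
    """Brute force: for each i, the indices j != i with |R_i - R_j| < r_cutoff."""
    R = np.asarray(R, dtype=float)
    d2 = np.sum((R[:, None, :] - R[None, :, :]) ** 2, axis=-1)
    close = d2 < r_cutoff ** 2
    np.fill_diagonal(close, False)  # a particle is not its own neighbor
    return [np.flatnonzero(row) for row in close]


def toy_allocate(R, r_cutoff, capacity_multiplier=1.25):
    """Hypothetical: size the buffer from the *current* configuration only."""
    max_occupancy = max(len(n) for n in pairs_within(R, r_cutoff))
    capacity = int(max_occupancy * capacity_multiplier)
    return toy_update(R, r_cutoff, capacity)


def toy_update(R, r_cutoff, capacity):
    """Hypothetical: refill a fixed-width buffer; flag overflow instead of growing."""
    N = len(R)
    idx = np.full((N, capacity), N, dtype=int)  # N marks an empty slot in this toy
    overflow = False
    for i, nbr in enumerate(pairs_within(R, r_cutoff)):
        if len(nbr) > capacity:
            overflow = True  # neighbors were dropped
        k = min(len(nbr), capacity)
        idx[i, :k] = nbr[:k]
    return idx, capacity, overflow


R_far = [[20.0, 20.0], [30.0, 30.0], [40.0, 40.0], [50.0, 50.0]]
R_close = [[20.0, 20.0], [20.0, 20.0], [40.0, 40.0], [50.0, 50.0]]

idx, capacity, overflow = toy_allocate(R_far, r_cutoff=3.0)
print(idx.shape, overflow)  # (4, 0) False: no pairs, so zero slots per particle

idx, capacity, overflow = toy_update(R_close, 3.0, capacity)
print(idx.shape, overflow)  # (4, 0) True: the pair (0, 1) does not fit

if overflow:  # re-allocate from the new positions
    idx, capacity, overflow = toy_allocate(R_close, r_cutoff=3.0)
print(idx)  # rows 0 and 1 now list each other; rows 2 and 3 hold only padding
```

In the JAX MD snippet, the analogous step would be to check `neighbors.did_buffer_overflow` after `update` and call `neighbor_list_fn.allocate(R)` again when it is set, rather than reading `neighbors.idx` directly.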
Answered by hawkinsp, Nov 30, 2021
I'd ask this in the JAX-MD repository! (https://github.com/google/jax-md)
Answer selected by mariogeiger