-
Hi, I'm trying to write a traversal loop for a data structure, and I ran into the problem of how to perform array accesses inside a Dr.Jit loop. I have `query_points`, and I want to parallelize the traversal of some `access_array` over them. The number of accesses and which elements are accessed depend on the `query_point`. I made a dummy example to illustrate my problem with the access pattern:
import drjit as dr
from drjit.auto import PCG32
from drjit.auto import Int, UInt, Float, Bool, TensorXf, Array3f, TensorXi
def euclidean_squared_distance(x: TensorXf, y: TensorXf):
    """Return ``dr.sum`` of the element-wise squared difference between ``x`` and ``y``."""
    delta = x - y
    squared = dr.power(delta, 2)
    return dr.sum(squared)
def euclidean_distance(x: TensorXf, y: TensorXf):
    """Return the square root of ``euclidean_squared_distance(x, y)``."""
    squared = euclidean_squared_distance(x, y)
    return dr.sqrt(squared)
@dr.syntax
def dummy_func(access_array: TensorXf, query_point: TensorXf, result: TensorXi, N: Int):
    """
    Toy traversal illustrating a data-dependent access pattern.

    Starting from ``counter = 0``, each iteration reads ``access_array[counter]``,
    measures its distance to ``query_point`` and, depending on that distance,
    either doubles or increments ``counter``. The running ``counter`` is
    accumulated into ``result`` until ``counter`` reaches ``N``.

    Args:
        access_array: Array that is read at the data-dependent index ``counter``.
        query_point: Point(s) the accessed entry is compared against.
        result: Accumulator; ``counter`` is added to it on every iteration.
        N: Exclusive upper bound for ``counter``; the loop stops once it is reached.

    Returns:
        The accumulated ``result``.
    """
    counter = Int(0)
    # The loop condition has to exist before the first evaluation of the
    # ``while``; previously ``flag`` was only assigned inside the loop body.
    flag = Bool(counter < N)
    while flag:
        # NOTE(review): this direct indexing with a JIT index is the access the
        # surrounding question is about; it is intentionally left unchanged.
        dist = euclidean_distance(access_array[counter], query_point)
        # NOTE(review): while ``counter`` is 0 the doubling branch leaves it at 0,
        # so the loop only makes progress there via the ``+ 1`` branch.
        counter = dr.select(dist > Float(0.5), counter * Int(2), counter + Int(1))
        result += counter
        flag = Bool(counter < N)
    return result
# Inputs for the dummy example: a (100, 3) tensor whose three columns are each
# filled from a separate ``rng.next_float32()`` call, and an all-zero
# (10000, 3) tensor of query points.
rng = PCG32(size=100, initseq=100)
rand_array = dr.zeros(TensorXf, (100, 3))
for column in range(3):
    rand_array[:, column] = rng.next_float32()
query_point = dr.zeros(TensorXf, (10000,3))
dummy_func(rand_array, query_point, dr.zeros(TensorXi, query_point.shape[0]), Int(50))
This raises an error.
Is there a way around this issue? The documentation has a section about local memory and computed thread-local indices, but is that needed if one only performs read accesses on the `access_array`?
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 1 reply
-
Hello @gd193, If I understand your code example correctly, you should be able to |
Beta Was this translation helpful? Give feedback.
-
Hey @merlinND, thanks for your suggestion. Here is a full reproducer (requires `pip install jaxkd`, see https://github.com/dodgebc/jaxkd):
import os
# Configure the GPU environment before the GPU libraries below are imported.
# Only the CUDA device with index 1 is made visible to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Limit the fraction of GPU memory the XLA (JAX) client preallocates to 0.5
# (presumably to leave room for Dr.Jit on the same device - confirm).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
import drjit as dr
from drjit.auto import Int, UInt, Float, Bool, Array3f, PCG32
import jax
import jax.numpy as jnp
import jax.random as jr
import jaxkd as jk
import timeit
# Query and count neighbors on random 3d points
n_points = 10**6
n_query_points = 10**6
n_neighbors = 1

# Draw the tree points and the query points from two independent subkeys.
key = jr.key(83)
key, subkey1, subkey2 = jr.split(key, 3)
points = jr.uniform(subkey1, shape=(n_points, 3), minval=0.0, maxval=1.0)
query_points = jr.uniform(subkey2, shape=(n_query_points, 3), minval=0.0, maxval=1.0)

# Reference result from jaxkd; ``neighbors`` is compared against the Dr.Jit
# result further below.
tree = jk.build_tree(points, optimize=False)
neighbors, distances = jk.query_neighbors(tree, query_points, k=n_neighbors)

# Time the jaxkd tree query. One un-timed call is made first so that the
# measured runs do not include the first invocation of the jitted function.
f_jit = jax.jit(jk.query_neighbors, static_argnames=["k"])
jax.block_until_ready(f_jit(tree, query_points, k=n_neighbors))

n_runs = 100
timings_jax = jnp.zeros(n_runs)
for run in range(n_runs):
    t_begin = timeit.default_timer()
    jax.block_until_ready(f_jit(tree, query_points, k=n_neighbors))
    t_finish = timeit.default_timer()
    timings_jax = timings_jax.at[run].set(t_finish - t_begin)

print(f"jaxkd execution time for {n_query_points} queries: {timings_jax.mean() * 1000:.2f} +- {timings_jax.std() * 1000:.2f} ms.")
print(f"That is {n_query_points / timings_jax.mean() *1e-6:.2f} * 10^6 queries/s.")

# Reorder the points by the tree's index array so that the k-d tree is stored
# in node order, then hand both point sets to Dr.Jit as Array3f (transposed so
# that the coordinate axis comes first).
tree_jax = tree.points[tree.indices, :]
tree_dr = dr.auto.Array3f(tree_jax.T)
query_points_dr = dr.auto.Array3f(query_points.T)
### DrJit implementation
def euclidean_squared_distance(x, y):
    """Return ``dr.sum`` of the component-wise squared difference of ``x`` and ``y``."""
    offset = x - y
    return dr.sum(dr.power(offset, 2))
def euclidean_distance(x, y):
    """Return the square root of ``euclidean_squared_distance(x, y)``."""
    sqr_dist = euclidean_squared_distance(x, y)
    return dr.sqrt(sqr_dist)
class StackEntry:
    """
    A stack entry for the stack based tree traversal.

    Each entry stores the id of a node that still has to be visited together
    with a squared distance that is later compared against the cull distance
    when the entry is popped.
    """
    # Field layout announced to Dr.Jit so instances can be used as structured
    # JIT values (e.g. with ``dr.zeros`` and ``dr.alloc_local``).
    DRJIT_STRUCT = {
        'nodeID': UInt,
        'sqrDist': Float,
    }

    def __init__(self, nodeID: "UInt | None" = None, sqrDist: "Float | None" = None):
        """
        Create a stack entry.

        Args:
            nodeID: Node id to store; a fresh ``UInt(0)`` when omitted.
            sqrDist: Squared distance to store; a fresh ``Float(0.0)`` when omitted.
        """
        # Defaults are created per instance. JIT arrays are mutable, so using
        # ``UInt(0)`` / ``Float(0.0)`` directly as default arguments would make
        # every default-constructed entry share the same two array objects.
        self.nodeID = UInt(0) if nodeID is None else nodeID
        self.sqrDist = Float(0.0) if sqrDist is None else sqrDist
def _get_coord3(point: Array3f, dim: UInt) -> Float:
    """
    Return the component of ``point`` selected by ``dim``.

    ``dim`` values 0, 1 and 2 select ``point.x``, ``point.y`` and ``point.z``
    respectively. The selection is dispatched through ``dr.switch``, which the
    original author used to work around slicing limitations of Array3f.
    """
    def _take_x(point):
        return point.x

    def _take_y(point):
        return point.y

    def _take_z(point):
        return point.z

    # Position in this list corresponds to the value of ``dim``.
    component_getters = [_take_x, _take_y, _take_z]
    return dr.switch(index=dim, targets=component_getters, point=point)
@dr.syntax
def traverse_stack_based3(treeArray: Array3f,
                          numPoints: Int,
                          queryPoint: Array3f,
                          result_closestid: UInt,
                          result_closestdist: Float,
                          dist_function,
                          num_dims: Int = Int(3),
                          get_coord: callable = _get_coord3,
                          stack_depth: int = 30,
                          ):
    """
    Traverse an implicitly stored KD-Tree in a stack-based manner.

    The tree has to be built beforehand, e.g. using the ``jaxkd.build_tree``
    function. Nodes are addressed by index: the children of node ``i`` are
    ``2*i + 1`` (left) and ``2*i + 2`` (right), and the split dimension of a
    node is ``level(i) % num_dims``.

    Dr.Jit port of the C++ code from
    https://github.com/ingowald/cudaKDTree/blob/master/cukd/traverse-default-stack-based.h

    Args:
        treeArray: Tree nodes; node ``i`` is fetched with ``dr.gather``.
        numPoints: Number of nodes; node ids ``>= numPoints`` are treated as
            "no node".
        queryPoint: Query points (one traversal per entry; the width is read
            from ``queryPoint.shape[1]``).
        result_closestid: Initial best node id per query.
        result_closestdist: Initial best distance per query; also the initial
            cull distance.
        dist_function: Called as ``dist_function(node, queryPoint)``. Its result
            is compared against squared plane distances, so a squared-distance
            function is expected.
        num_dims: Number of dimensions the split axis cycles through.
        get_coord: Called as ``get_coord(point, dim)`` to read one coordinate.
        stack_depth: Number of entries in the per-query local stack. There is
            no overflow check; it must be large enough for the tree depth.

    Returns:
        Tuple ``(result_closestid, result_closestdist)`` with the closest node
        id and its distance for every query.
    """
    stackID = UInt(0)
    stack = dr.alloc_local(StackEntry, stack_depth, value=dr.zeros(StackEntry))  # zero-init local stack
    # One traversal cursor per query, starting at the root (node 0).
    # NOTE(review): assumes queryPoint.shape is (3, n) - confirm.
    currNodeID = dr.zeros(UInt, queryPoint.shape[1])
    outerLoopFlag = Bool(True)
    cullDist = result_closestdist

    while outerLoopFlag:
        # Descend towards the query point until the cursor leaves the tree.
        while currNodeID < numPoints:
            # Level of the current node, computed with an integer log2. A
            # float32 log2 + floor is inexact for ids >= 2**24 and fragile for
            # ids just below a power of two, which would pick the wrong split
            # dimension.
            currLevel = dr.log2i(currNodeID + UInt(1))
            currDim = currLevel % num_dims  # cycle through dimensions in order
            currNode = dr.gather(type(treeArray), treeArray, currNodeID)
            currDist = dist_function(currNode, queryPoint)
            if currDist < result_closestdist:
                result_closestid = currNodeID
                result_closestdist = currDist
                cullDist = currDist

            nodeCoord = get_coord(currNode, currDim)
            queryCoord = get_coord(queryPoint, currDim)
            lChild = UInt(2)*currNodeID + UInt(1)
            rChild = lChild + UInt(1)
            # The child on the query's side of the splitting plane is visited
            # first; the other one is only a candidate for later.
            if queryCoord < nodeCoord:
                closeChild = lChild
                farChild = rChild
            else:
                closeChild = rChild
                farChild = lChild
            currNodeID = closeChild

            # Remember the far child only if the splitting plane is closer
            # than the current best and the child actually exists.
            sqrDistToPlane = dr.power(nodeCoord - queryCoord, 2)
            if sqrDistToPlane < cullDist:
                if farChild < numPoints:
                    stack[stackID] = StackEntry(nodeID=farChild, sqrDist=sqrDistToPlane)
                    stackID += UInt(1)

        # Pop entries until one is still worth visiting or the stack is empty.
        innerLoopFlag = Bool(True)
        while innerLoopFlag:
            if stackID == UInt(0):
                # Nothing left to visit: terminate the whole traversal.
                innerLoopFlag = Bool(False)
                outerLoopFlag = Bool(False)
            else:
                stackID -= UInt(1)
                if stack[stackID].sqrDist >= cullDist:
                    # Entry became obsolete since it was pushed; keep popping
                    # (the ``continue`` of the C++ original).
                    pass
                else:
                    currNodeID = stack[stackID].nodeID
                    innerLoopFlag = Bool(False)

    return result_closestid, result_closestdist
## test DrJit code
def _new_result_buffers(count):
    """Return ``(ids, dists)``: zero-filled UInt ids and Float distances offset by 1e10."""
    ids = dr.zeros(UInt, count)
    dists = dr.zeros(Float, count)
    dists += Float(1e10)
    return ids, dists


# Single run, compared against the jaxkd reference neighbors.
result_closestid, result_closestdist = _new_result_buffers(n_query_points)
result = traverse_stack_based3(tree_dr, Int(n_points), query_points_dr, result_closestid, result_closestdist, euclidean_squared_distance, num_dims=Int(3))
dr_result_points_jax = dr.gather(Array3f, tree_dr, result[0]).jax()
print("Check if results are equal: ", jnp.all(tree.points[neighbors[:, 0], :] == dr_result_points_jax.T))

# Timed runs: the first kernel history entry of each run supplies the timing.
timings_dr = jnp.zeros(n_runs)
for run in range(n_runs):
    result_closestid, result_closestdist = _new_result_buffers(n_query_points)
    # NOTE(review): nesting under the ``with`` is inferred; the evaluation has
    # to happen while the KernelHistory flag is set to be recorded.
    with dr.scoped_set_flag(dr.JitFlag.KernelHistory):
        result = traverse_stack_based3(tree_dr, Int(n_points), query_points_dr, result_closestid, result_closestdist, euclidean_squared_distance, num_dims=Int(3))
        dr.eval(result)
        hist = dr.kernel_history()[0]
    timings_dr = timings_dr.at[run].set(hist['execution_time'])  # execution time in ms

print(f"DrJIT execution time for {n_query_points} queries: {timings_dr.mean():.2f} +- {timings_dr.std():.2f} ms.")
print(f"That is {n_query_points / timings_dr.mean() *1e-3:.2f} * 10^6 queries/s.")
Beta Was this translation helpful? Give feedback.
Hello @gd193,
If I understand your code example correctly, you should be able to `dr.gather()` from the underlying `access_array.array`, computing the index to gather based on `counter`.
I would recommend switching from tensors of size `(n, 3)` to `mi.Point3f` of width `n`, which will allow you to write operations much more naturally.