Hi, I'd like to get information per ray per depth, i.e. per intersection. I have two needs.
For the first question, I'd like to get gradients per intersection:

```python
def sample(...):
    for i in range(self.max_depth):
        dr.backward_from(Lo)     # Lo depends on param and the ray; dr.shape(Lo) = [3, n_ray]
        grad_L = dr.grad(param)  # dr.shape(param) = [3, 1]; I want grad_L with shape [3, n_ray]
```

For the second question, I want to compute a second gradient from the same computation graph:

```python
def sample(...):
    for i in range(self.max_depth):
        dr.backward_from(bsdf_val, flags=dr.ADFlag.ClearVertices)
        grad_bsdf = dr.grad(param)
        dr.set_grad(opt[key], 0)
        dr.backward_from(Lo)
        grad_L = dr.grad(param)
```

But after `sample()`, `dr.grad(param)` returns 0, despite the values computed for `grad_bsdf` and `grad_L`. Why?

I also have some questions about how the JIT works and how multiple threads synchronize and share data.
My questions may be ambiguous. I can explain more if needed. Thank you in advance!
Replies: 1 comment 9 replies
Hi @lesphere

Gradients are accumulated (summed). The gradient always has the same shape and type as the original primal value. If you want to know the gradient per thread/ray, you can introduce a dummy variable that is simply a repetition of your original parameter by the width of your kernel (the number of threads/rays). Something like this:

```python
dummy_params = dr.repeat(my_param, dr.width(rays))
output = f(dummy_params)  # some differentiable computation
dr.set_grad(output, 1)
dr.backward_to(dummy_params)
grad = dr.grad(dummy_params)
```

For your second question, there is no implicit mechanism that will clear gradients. When are you checking the gradient?

Finally, regarding your questions about how the JIT works and how threads synchronize and share data:

If you haven't already, I'd recommend reading through this gentle introduction to Dr.Jit. You'll learn how to get log messages on every kernel launch, so that you know whether your code is running multiple kernels. For even more details, you can have a look at the paper or video.
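To make the accumulation behaviour concrete, here is a minimal sketch in plain Python rather than Dr.Jit. The per-ray function `L_i = t_i * param**2`, the helper `d_radiance`, and the sample values are all hypothetical; they only illustrate why a shared parameter receives one summed gradient while a repeated parameter keeps one gradient per ray.

```python
# Plain-Python stand-in for the idea above; this is NOT Dr.Jit code.
# Assumed (hypothetical) per-ray output: L_i = t_i * param**2,
# where t_i is some per-ray quantity such as throughput.

def d_radiance(param, throughput):
    """Analytic derivative dL_i/dparam = 2 * t_i * param."""
    return 2.0 * throughput * param

throughputs = [0.5, 1.0, 2.0]  # one entry per ray (made-up values)
param = 3.0                    # a single shared parameter

# Shared parameter: the contributions of all rays are summed,
# so the gradient has the same shape as `param` (one number).
shared_grad = sum(d_radiance(param, t) for t in throughputs)

# Repeated parameter: every ray gets its own copy of the parameter,
# so every copy keeps its own gradient (one number per ray).
dummy_params = [param] * len(throughputs)
per_ray_grad = [d_radiance(p, t) for p, t in zip(dummy_params, throughputs)]

print(shared_grad)   # 21.0
print(per_ray_grad)  # [3.0, 6.0, 12.0]
```

Summing `per_ray_grad` recovers `shared_grad`, which is the relationship between the gradient of the repeated dummy variable and the gradient of the original parameter.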
I don't think there is an easy workaround. The code is failing here because the plugin expects a parameter of size 1, not something wider. Maybe disabling symbolic vcalls (`dr.set_flag(dr.JitFlag.VCallRecord, False)`) will help, but even with that I would assume that some other parts of the existing code would break because of this unexpected change in parameter width.

I don't know what your final goal is, but I don't think there is a way in which you could make this per-ray gradient tracking work in a conventional Mitsuba setup (scene, plugins, etc.). You're better off writing whatever you need from scratch with Mitsuba "primitives". This might seem like a lot, but depending on what exactl…