@@ -299,14 +299,14 @@ def vector_jacobian_product(
299299 Returns:
300300 Dictionary containing VJP for the specified inputs.
301301 """
302- assert vjp_inputs == {"bar_params " }
302+ assert vjp_inputs == {"differentiable_parameters " }
303303 assert vjp_outputs == {"sdf" }
304304
305305 jac = jac_sdf_wrt_params (
306306 target = inputs .mesh_tesseract ,
307307 differentiable_parameters = inputs .differentiable_parameters ,
308308 non_differentiable_parameters = inputs .non_differentiable_parameters ,
309- geometry_ints = inputs .static_parameters ,
309+ static_parameters = inputs .static_parameters ,
310310 string_parameters = inputs .string_parameters ,
311311 grid_size = inputs .grid_size ,
312312 grid_elements = inputs .grid_elements ,
@@ -316,9 +316,10 @@ def vector_jacobian_product(
316316 n_elements = inputs .Nx * inputs .Ny * inputs .Nz
317317 jac = jac / n_elements
318318 # Reduce the cotangent vector to the shape of the Jacobian, to compute VJP by hand
319- vjp = np .einsum ("ijklmn ,lmn->ijk " , jac , cotangent_vector ["sdf" ]).astype (np .float32 )
319+ vjp = np .einsum ("klmn ,lmn->k " , jac , cotangent_vector ["sdf" ]).astype (np .float32 )
320320
321- return {"bar_params" : vjp }
321+ print (vjp .shape )
322+ return {"differentiable_parameters" : vjp }
322323
323324
324325def abstract_eval (abstract_inputs : InputSchema ) -> dict :
0 commit comments