1+ from concurrent .futures import ThreadPoolExecutor
12from functools import lru_cache
23from typing import Any
34
@@ -67,6 +68,10 @@ class InputSchema(BaseModel):
6768 default = False ,
6869 description = ("Whether to normalize the Jacobian by the number of elements" ),
6970 )
71+ precompute_jacobian : bool = Field (
72+ default = False ,
73+ description = ("Whether to precompute the Jacobian for faster VJP computation." ),
74+ )
7075 max_points : int = Field (
7176 default = 1000 ,
7277 description = ("Maximum number of points in the output mesh." ),
@@ -81,7 +86,7 @@ class TriangularMesh(BaseModel):
8186 """Triangular mesh representation with fixed-size arrays."""
8287
8388 points : Array [(None , 3 ), Float32 ] = Field (description = "Array of vertex positions." )
84- faces : Array [(None , 3 ), Float32 ] = Field (
89+ faces : Array [(None , 3 ), Int32 ] = Field (
8590 description = "Array of triangular faces defined by indices into the points array."
8691 )
8792 n_points : Int32 = Field (
@@ -237,6 +242,8 @@ def geometries_and_sdf(
237242#
238243# Tesseract endpoints
239244#
245+ jacobian_future = None
246+ executor = None
240247
241248
242249def apply (inputs : InputSchema ) -> OutputSchema :
@@ -264,13 +271,34 @@ def apply(inputs: InputSchema) -> OutputSchema:
264271 mesh = mesh [0 ]
265272
266273 points = np .zeros ((inputs .max_points , 3 ), dtype = np .float32 )
267- faces = np .zeros ((inputs .max_faces , 3 ), dtype = np .float32 )
274+ faces = np .zeros ((inputs .max_faces , 3 ), dtype = np .int32 )
268275
269276 points [: mesh .vertices .shape [0 ], :] = mesh .vertices .astype (np .float32 )
270277 faces [: mesh .faces .shape [0 ], :] = mesh .faces .astype (np .int32 )
271278
279+ # start a new thread to precompute the jacobian if requested
280+ if inputs .precompute_jacobian :
281+ print ("Starting Jacobian precomputation thread..." )
282+ executor = ThreadPoolExecutor (max_workers = 1 )
283+ future = executor .submit (
284+ jac_sdf_wrt_params ,
285+ target = inputs .mesh_tesseract ,
286+ differentiable_parameters = inputs .differentiable_parameters ,
287+ non_differentiable_parameters = inputs .non_differentiable_parameters ,
288+ static_parameters = inputs .static_parameters ,
289+ string_parameters = inputs .string_parameters ,
290+ scale_mesh = inputs .scale_mesh ,
291+ grid_size = inputs .grid_size ,
292+ grid_elements = inputs .grid_elements ,
293+ grid_center = inputs .grid_center ,
294+ epsilon = inputs .epsilon ,
295+ )
296+
297+ global jacobian_future
298+ jacobian_future = future
299+
272300 return OutputSchema (
273- sdf = sdf ,
301+ sdf = sdf . astype ( np . float32 ) ,
274302 mesh = TriangularMesh (
275303 points = points ,
276304 faces = faces ,
@@ -332,7 +360,7 @@ def jac_sdf_wrt_params(
332360 for i in range (n_params ):
333361 jac [i ] = (sdf_fields [i + 1 ] - sdf_fields [0 ]) / epsilon
334362
335- return jac
363+ return jac . astype ( np . float32 )
336364
337365
338366def vector_jacobian_product (
@@ -355,25 +383,36 @@ def vector_jacobian_product(
355383 assert vjp_inputs == {"differentiable_parameters" }
356384 assert vjp_outputs == {"sdf" }
357385
358- jac = jac_sdf_wrt_params (
359- target = inputs .mesh_tesseract ,
360- differentiable_parameters = inputs .differentiable_parameters ,
361- non_differentiable_parameters = inputs .non_differentiable_parameters ,
362- static_parameters = inputs .static_parameters ,
363- string_parameters = inputs .string_parameters ,
364- scale_mesh = inputs .scale_mesh ,
365- grid_size = inputs .grid_size ,
366- grid_elements = inputs .grid_elements ,
367- epsilon = inputs .epsilon ,
368- grid_center = inputs .grid_center ,
369- )
386+ # lets also check if the thread is still running
387+ if jacobian_future is not None :
388+ print ("Using precomputed Jacobian..." )
389+ if not jacobian_future .done ():
390+ print ("Waiting for Jacobian precomputation to finish..." )
391+ jac = jacobian_future .result ()
392+
393+ print ("Jacobian precomputation finished." )
394+ print (f"Jacobian shape: { jac .shape } and type: { jac .dtype } " )
395+ else :
396+ print ("Computing Jacobian..." )
397+ jac = jac_sdf_wrt_params (
398+ target = inputs .mesh_tesseract ,
399+ differentiable_parameters = inputs .differentiable_parameters ,
400+ non_differentiable_parameters = inputs .non_differentiable_parameters ,
401+ static_parameters = inputs .static_parameters ,
402+ string_parameters = inputs .string_parameters ,
403+ scale_mesh = inputs .scale_mesh ,
404+ grid_size = inputs .grid_size ,
405+ grid_elements = inputs .grid_elements ,
406+ epsilon = inputs .epsilon ,
407+ grid_center = inputs .grid_center ,
408+ )
370409 if inputs .normalize_jacobian :
371410 n_elements = (
372411 inputs .grid_elements [0 ] * inputs .grid_elements [1 ] * inputs .grid_elements [2 ]
373412 )
374413 jac = jac / n_elements
375414 # Reduce the cotangent vector to the shape of the Jacobian, to compute VJP by hand
376- vjp = np .einsum ("klmn,lmn->k" , jac , cotangent_vector ["sdf" ]). astype ( np . float32 )
415+ vjp = np .einsum ("klmn,lmn->k" , jac , cotangent_vector ["sdf" ])
377416
378417 return {"differentiable_parameters" : vjp }
379418
@@ -400,8 +439,12 @@ def abstract_eval(abstract_inputs: InputSchema) -> dict:
400439 "points" : ShapeDType (
401440 shape = (abstract_inputs .max_points , 3 ), dtype = "float32"
402441 ),
403- "faces" : ShapeDType (shape = (abstract_inputs .max_faces , 3 ), dtype = "float32 " ),
442+ "faces" : ShapeDType (shape = (abstract_inputs .max_faces , 3 ), dtype = "int32 " ),
404443 "n_points" : ShapeDType (shape = (), dtype = "int32" ),
405444 "n_faces" : ShapeDType (shape = (), dtype = "int32" ),
406445 },
407446 }
447+
448+
449+ if executor is not None :
450+ executor .shutdown ()
0 commit comments