@@ -195,8 +195,8 @@ def tesseract_dispatch_lowering(
195195 * array_args : ArrayLike | ShapedArray | Any ,
196196 static_args : tuple [_Hashable , ...],
197197 input_pytreedef : PyTreeDef ,
198- output_pytreedef : PyTreeDef ,
199- output_avals : tuple [ShapeDtypeStruct , ...],
198+ output_pytreedef : PyTreeDef | None ,
199+ output_avals : tuple [ShapeDtypeStruct , ...] | None ,
200200 is_static_mask : tuple [bool , ...],
201201 client : Jaxeract ,
202202 eval_func : str ,
@@ -344,11 +344,11 @@ def apply_tesseract(
344344 f"Got { type (tesseract_client )} instead."
345345 )
346346
347- if "abstract_eval" not in tesseract_client .available_endpoints :
348- raise ValueError (
349- "Given Tesseract object does not support abstract_eval, "
350- "which is required for compatibility with JAX."
351- )
347+ # if "abstract_eval" not in tesseract_client.available_endpoints:
348+ # raise ValueError(
349+ # "Given Tesseract object does not support abstract_eval, "
350+ # "which is required for compatibility with JAX."
351+ # )
352352
353353 client = Jaxeract (tesseract_client )
354354
@@ -357,28 +357,30 @@ def apply_tesseract(
357357 array_args , static_args = split_args (flat_args , is_static_mask )
358358 static_args = tuple (_make_hashable (arg ) for arg in static_args )
359359
360+
360361 # Get abstract values for outputs, so we can unflatten them later
361- output_pytreedef , avals = None , None
362- avals = client .abstract_eval (
363- array_args ,
364- static_args ,
365- input_pytreedef ,
366- output_pytreedef ,
367- avals ,
368- is_static_mask ,
369- )
362+ output_pytreedef , avals , flat_avals = None , None , None
363+ if "abstract_eval" in tesseract_client .available_endpoints :
364+ avals = client .abstract_eval (
365+ array_args ,
366+ static_args ,
367+ input_pytreedef ,
368+ output_pytreedef ,
369+ avals ,
370+ is_static_mask ,
371+ )
370372
371- is_aval = lambda x : isinstance (x , dict ) and "dtype" in x and "shape" in x
372- flat_avals , output_pytreedef = jax .tree .flatten (avals , is_leaf = is_aval )
373- for aval in flat_avals :
374- if not is_aval (aval ):
375- continue
376- _check_dtype (aval ["dtype" ])
373+ is_aval = lambda x : isinstance (x , dict ) and "dtype" in x and "shape" in x
374+ flat_avals , output_pytreedef = jax .tree .flatten (avals , is_leaf = is_aval )
375+ for aval in flat_avals :
376+ if not is_aval (aval ):
377+ continue
378+ _check_dtype (aval ["dtype" ])
377379
378- flat_avals = tuple (
379- jax .ShapeDtypeStruct (shape = tuple (aval ["shape" ]), dtype = aval ["dtype" ])
380- for aval in flat_avals
381- )
380+ flat_avals = tuple (
381+ jax .ShapeDtypeStruct (shape = tuple (aval ["shape" ]), dtype = aval ["dtype" ])
382+ for aval in flat_avals
383+ )
382384
383385 # Apply the primitive
384386 out = tesseract_dispatch_p .bind (
@@ -392,5 +394,16 @@ def apply_tesseract(
392394 eval_func = "apply" ,
393395 )
394396
397+ # else:
398+ # # Apply the primitive
399+ # out = tesseract_dispatch_p.bind(
400+ # *array_args,
401+ # static_args=static_args,
402+ # input_pytreedef=input_pytreedef,
403+ # is_static_mask=is_static_mask,
404+ # client=client,
405+ # eval_func="apply",
406+ # )
407+
395408 # Unflatten the output
396409 return jax .tree .unflatten (output_pytreedef , out )
0 commit comments