@@ -431,6 +431,8 @@ class Primitive:
431431 call_primitive : bool = False
432432 # set for map primitives processed in final style.
433433 map_primitive : bool = False
434+ # set for ref primitives
435+ ref_primitive : bool = False
434436
435437 def __init__ (self , name : str ):
436438 self .name = name
@@ -1882,6 +1884,7 @@ def __repr__(self) -> str: return 'Mutable' + repr(self[...])
18821884def mutable_array (init_val ):
18831885 return mutable_array_p .bind (init_val )
18841886mutable_array_p = Primitive ('mutable_array' )
1887+ mutable_array_p .ref_primitive = True
18851888
18861889class InternalMutableArrayEffect (effects .Effect ):
18871890 pass
@@ -1899,6 +1902,18 @@ def _mutable_array_impl(init_val):
18991902 aval = get_aval (init_val )
19001903 return MutableArray (AbstractRef (aval ), init_val )
19011904
1905+ def freeze (ref ):
1906+ return freeze_p .bind (ref )
1907+ freeze_p = Primitive ('freeze' )
1908+ freeze_p .ref_primitive = True
1909+
1910+ @freeze_p .def_effectful_abstract_eval
1911+ def freeze_abstract_eval (ref_aval ):
1912+ return ref_aval .inner_aval , {internal_mutable_array_effect }
1913+
1914+ @freeze_p .def_impl
1915+ def _freeze_impl (ref ):
1916+ return ref [()]
19021917
19031918class AbstractToken (AbstractValue ):
19041919 def str_short (self , short_dtypes = False ): return 'Tok'
@@ -2516,10 +2531,11 @@ def write(v: Var, a: AbstractValue) -> None:
25162531
25172532 # Check the computed effect type matches the eqn's annotation, and is
25182533 # included in the jaxpr's annotation.
2519- if prim is mutable_array_p :
2520- outvar , = eqn .outvars
2521- in_idx [outvar ] = None # type: ignore
2522- mut_arrays .add (outvar )
2534+ if prim .ref_primitive :
2535+ if prim is mutable_array_p :
2536+ outvar , = eqn .outvars
2537+ in_idx [outvar ] = None # type: ignore
2538+ mut_arrays .add (outvar )
25232539 if eqn .effects != eqn_effects :
25242540 raise JaxprTypeError ("Inferred effects do not match equation effects. "
25252541 f"Equation effects: { eqn .effects } . "
0 commit comments