1313# limitations under the License.
1414
1515import copy
16+ import enum
1617import hashlib
1718import io
1819import logging
@@ -62,11 +63,23 @@ def custom_hook() -> str:
6263 return ""
6364
6465
65- def get (module : ir .Module ,
66- devices : np .ndarray ,
67- compile_options : xla_client .CompileOptions ,
68- backend : xla_client .Client ,
69- compression_algorithm : str = "zstandard" ) -> str :
class IgnoreCallbacks(enum.IntEnum):
  """Selects which callback pointers are removed from precompiled IR.

  Members:
    NO: do not remove any callback pointers.
    ALL: remove all callback pointers.
    CUSTOM_PARTITIONING: remove only the custom_partitioning callback pointer.
  """

  # Explicit values match what enum.auto() assigns (1, 2, 3 in order).
  NO = 1
  ALL = 2
  CUSTOM_PARTITIONING = 3
73+
74+
75+ def get (
76+ module : ir .Module ,
77+ devices : np .ndarray ,
78+ compile_options : xla_client .CompileOptions ,
79+ backend : xla_client .Client ,
80+ compression_algorithm : str = "zstandard" ,
81+ ignore_callbacks : IgnoreCallbacks = IgnoreCallbacks .NO ,
82+ ) -> str :
7083 """Creates a hashed string to use as a key to the compilation cache.
7184
7285 Creates a cache key that is a hex-encoded string of a unique hash based on
@@ -79,28 +92,47 @@ def get(module: ir.Module,
7992 backend: description of the platform (e.g., TPU version)
8093 compression_algorithm: a string representing the compression algorithm used
8194 for the executable before persisting in the cache
95+ ignore_callbacks: which callback pointers, if any, to remove from the
96+ computation before it is hashed (see IgnoreCallbacks).
8297
8398 Typical return value example:
8499 'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
85100 """
86101 entries = [
87- ("computation" ,
88- lambda hash_obj : _hash_computation (hash_obj , module )),
89- ("jax_lib version" ,
90- lambda hash_obj : hash_obj .update (
91- bytes (jaxlib_version_str .encode ("utf-8" )))),
92- ("XLA flags" ,
93- lambda hash_obj : _hash_xla_flags (hash_obj , get_flag_prefixes ())),
94- ("compile_options" ,
95- lambda hash_obj : _hash_serialized_compile_options (
96- hash_obj , compile_options ,
97- # In case of GPU multi-process tasks we need to strip device
98- # assignment to use cache key as invariant between processes.
99- strip_device_assignment = (backend .platform == "gpu" ))),
100- ("accelerator_config" ,
101- lambda hash_obj : _hash_accelerator_config (hash_obj , devices , backend )),
102- ("compression" ,
103- lambda hash_obj : _hash_string (hash_obj , compression_algorithm )),
102+ (
103+ "computation" ,
104+ lambda hash_obj : _hash_computation (
105+ hash_obj , module , ignore_callbacks
106+ ),
107+ ),
108+ (
109+ "jax_lib version" ,
110+ lambda hash_obj : hash_obj .update (
111+ bytes (jaxlib_version_str .encode ("utf-8" ))
112+ ),
113+ ),
114+ (
115+ "XLA flags" ,
116+ lambda hash_obj : _hash_xla_flags (hash_obj , get_flag_prefixes ()),
117+ ),
118+ (
119+ "compile_options" ,
120+ lambda hash_obj : _hash_serialized_compile_options (
121+ hash_obj ,
122+ compile_options ,
123+ # In case of GPU multi-process tasks we need to strip device
124+ # assignment to use cache key as invariant between processes.
125+ strip_device_assignment = (backend .platform == "gpu" ),
126+ ),
127+ ),
128+ (
129+ "accelerator_config" ,
130+ lambda hash_obj : _hash_accelerator_config (hash_obj , devices , backend ),
131+ ),
132+ (
133+ "compression" ,
134+ lambda hash_obj : _hash_string (hash_obj , compression_algorithm ),
135+ ),
104136 ("custom_hook" , lambda hash_obj : _hash_string (hash_obj , custom_hook ())),
105137 ]
106138
@@ -131,45 +163,56 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
131163 )
132164
133165
def _remove_callbacks(m: ir.Module, ignore_callbacks: IgnoreCallbacks):
  """Overwrites callback-carrying ``backend_config`` attributes in ``m``.

  Python function pointers are not deterministic across executions, so the
  ``backend_config`` of each matching ``stablehlo.custom_call`` op is replaced
  with the fixed string "REMOVED".

  Args:
    m: module to modify in place.
    ignore_callbacks: ``IgnoreCallbacks.NO`` leaves ``m`` untouched. Any other
      value rewrites custom calls whose target is "CustomSPMDPartitioning";
      ``IgnoreCallbacks.ALL`` additionally rewrites custom calls whose target
      name ends with "callback".

  Returns:
    The same module object ``m``.
  """
  if ignore_callbacks == IgnoreCallbacks.NO:
    return m

  strip_all = ignore_callbacks == IgnoreCallbacks.ALL

  def _scrub_backend_config(op: ir.Operation) -> ir.WalkResult:
    # Only custom calls carry a call_target_name; skip everything else.
    if op.name != "stablehlo.custom_call":
      return ir.WalkResult.ADVANCE
    target = op.attributes["call_target_name"].value
    is_callback = strip_all and target.endswith("callback")
    if is_callback or target == "CustomSPMDPartitioning":
      op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
    return ir.WalkResult.ADVANCE

  m.operation.walk(_scrub_backend_config)
  return m
147187
148188
def _serialize_ir(m: ir.Module, ignore_callbacks: IgnoreCallbacks) -> bytes:
  """Returns the bytecode written out for ``m``.

  Args:
    m: module to serialize; it is not modified.
    ignore_callbacks: unless this is ``IgnoreCallbacks.NO``, the bytecode is
      produced from a clone of ``m`` that has been passed through
      ``_remove_callbacks``.

  Returns:
    The bytes emitted by ``write_bytecode``.
  """
  to_write = m
  if ignore_callbacks != IgnoreCallbacks.NO:
    # Scrub a clone so the caller's module keeps its original attributes.
    clone = type_cast(ir.Module, m.operation.clone())
    to_write = _remove_callbacks(clone, ignore_callbacks)

  buffer = io.BytesIO()
  to_write.operation.write_bytecode(file=buffer)
  return buffer.getvalue()
156197
157198
def _canonicalize_ir(
    m_original: ir.Module, ignore_callbacks: IgnoreCallbacks
) -> bytes:
  """Serializes a copy of ``m_original`` with debug info stripped.

  Args:
    m_original: module to canonicalize; the pass runs on a clone of it.
    ignore_callbacks: forwarded unchanged to ``_serialize_ir``.

  Returns:
    The bytes produced by ``_serialize_ir`` for the stripped clone.
  """
  with m_original.context:
    stripped = type_cast(ir.Module, m_original.operation.clone())
    strip_debuginfo = pm.PassManager.parse("builtin.module(strip-debuginfo)")
    strip_debuginfo.run(stripped.operation)
    return _serialize_ir(stripped, ignore_callbacks)
166209
167210
def _hash_computation(hash_obj, module, ignore_callbacks: IgnoreCallbacks):
  """Updates ``hash_obj`` with the serialized form of ``module``.

  When ``config.compilation_cache_include_metadata_in_key`` is set, the module
  is serialized directly with ``_serialize_ir``; otherwise it goes through
  ``_canonicalize_ir`` first.

  Args:
    hash_obj: hash object whose ``update`` method receives the bytes.
    module: the module to hash.
    ignore_callbacks: forwarded unchanged to the chosen serializer.
  """
  include_metadata = config.compilation_cache_include_metadata_in_key.value
  serialize = _serialize_ir if include_metadata else _canonicalize_ir
  hash_obj.update(serialize(module, ignore_callbacks))
174217
175218
0 commit comments