File tree Expand file tree Collapse file tree 2 files changed +12
-6
lines changed Expand file tree Collapse file tree 2 files changed +12
-6
lines changed Original file line number Diff line number Diff line change @@ -202,7 +202,7 @@ def module_bytecode_vhlo(self) -> bytes:
202202 target_version = stablehlo .get_minimum_version ()
203203 else :
204204 target_version = stablehlo .get_version_from_compatibility_requirement (
205- stablehlo .StablehloCompatibilityRequirement .WEEK_4
205+ stablehlo .StablehloCompatibilityRequirement .WEEK_12
206206 )
207207 module_bytecode = xla_extension .mlir .serialize_portable_artifact (
208208 self .module_bytecode , target_version
Original file line number Diff line number Diff line change @@ -104,20 +104,26 @@ def _extract_call_args(
104104def _wrap_as_tf_func (lowered , tf_state_dict ):
105105 """Build tf.function from lowered and tf_state_dict."""
106106
107- def inner (* args ):
107+ version = 6
108+ if hasattr (tfxla , "call_module_maximum_supported_version" ):
109+ version = tfxla .call_module_maximum_supported_version ()
110+
111+ def tf_func (* args ):
108112 t_outs = [torch_dtype_to_tf (sig .dtype ) for sig in lowered .output_signature ]
109113 s_outs = [_get_shape_with_dynamic (sig ) for sig in lowered .output_signature ]
110114 call_args = _extract_call_args (lowered , args , tf_state_dict )
111115 return tfxla .call_module (
112116 tuple (call_args ),
113- version = 5 ,
117+ version = version ,
114118 Tout = t_outs , # dtype information
115- Sout = s_outs , # Shape information
119+ Sout = s_outs , # shape information
116120 function_list = [],
117- module = lowered .module_bytecode ,
121+ module = lowered .module_bytecode_vhlo ,
122+ has_token_input_output = False ,
123+ platforms = ["CPU" ],
118124 )
119125
120- return inner
126+ return tf_func
121127
122128
123129def _make_input_signatures (
You can’t perform that action at this time.
0 commit comments