Skip to content

Commit 3b2afed

Browse files
chunnienccopybara-github
authored andcommitted
update mlir/shlo serialization for per-op testing
PiperOrigin-RevId: 711813259
1 parent 562c93d commit 3b2afed

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

ai_edge_torch/odml_torch/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def module_bytecode_vhlo(self) -> bytes:
202202
target_version = stablehlo.get_minimum_version()
203203
else:
204204
target_version = stablehlo.get_version_from_compatibility_requirement(
205-
stablehlo.StablehloCompatibilityRequirement.WEEK_4
205+
stablehlo.StablehloCompatibilityRequirement.WEEK_12
206206
)
207207
module_bytecode = xla_extension.mlir.serialize_portable_artifact(
208208
self.module_bytecode, target_version

ai_edge_torch/odml_torch/tf_integration.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,20 +104,26 @@ def _extract_call_args(
104104
def _wrap_as_tf_func(lowered, tf_state_dict):
105105
"""Build tf.function from lowered and tf_state_dict."""
106106

107-
def inner(*args):
107+
version = 6
108+
if hasattr(tfxla, "call_module_maximum_supported_version"):
109+
version = tfxla.call_module_maximum_supported_version()
110+
111+
def tf_func(*args):
108112
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
109113
s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
110114
call_args = _extract_call_args(lowered, args, tf_state_dict)
111115
return tfxla.call_module(
112116
tuple(call_args),
113-
version=5,
117+
version=version,
114118
Tout=t_outs, # dtype information
115-
Sout=s_outs, # Shape information
119+
Sout=s_outs, # shape information
116120
function_list=[],
117-
module=lowered.module_bytecode,
121+
module=lowered.module_bytecode_vhlo,
122+
has_token_input_output=False,
123+
platforms=["CPU"],
118124
)
119125

120-
return inner
126+
return tf_func
121127

122128

123129
def _make_input_signatures(

0 commit comments

Comments
 (0)