Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 0 additions & 39 deletions ai_edge_torch/_convert/test/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,45 +576,6 @@ def forward(self, x: torch.Tensor):
self.fail(f"Conversion failed with bloat16 inputs: {err}")
# pylint: enable=broad-except

def test_convert_model_with_torch_linspace_operation(self):
    """Tests converting a model whose forward emits a float64 torch.linspace.

    The number of steps is derived from the input tensor's leading dimension,
    so the converter must handle both an int64 input and a float64 output.
    Conversion is expected to succeed; any exception fails the test.
    """

    class SampleModel(nn.Module):
      """Model returning a float64 linspace sized by the input's first dim."""

      def forward(self, x: torch.Tensor):
        # float64 output exercises 64-bit float lowering in the converter.
        return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)

    model = SampleModel().eval()
    args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)

    # Broad catch is deliberate: any conversion error must fail the test
    # with a readable message rather than an unhandled traceback.
    # pylint: disable=broad-except
    try:
      ai_edge_torch.convert(model, args)
    except Exception as err:
      self.fail(f"Conversion failed with int64 inputs: {err}")
    # pylint: enable=broad-except

def test_convert_model_with_torch_div_operation(self):
    """Tests converting a model that divides two int64 tensors elementwise.

    int64 / int64 true division exercises 64-bit integer input lowering.
    Conversion is expected to succeed; any exception fails the test.
    """

    class SampleModel(nn.Module):
      """Model returning the elementwise true division of its two inputs."""

      def forward(self, x: torch.Tensor, y: torch.Tensor):
        return x / y

    model = SampleModel().eval()
    args = (
        torch.randint(0, 100, (10, 10), dtype=torch.int64),
        torch.randint(0, 100, (10, 10), dtype=torch.int64),
    )

    # Broad catch is deliberate: any conversion error must fail the test
    # with a readable message rather than an unhandled traceback.
    # pylint: disable=broad-except
    try:
      ai_edge_torch.convert(model, args)
    except Exception as err:
      self.fail(f"Conversion failed with int64 inputs: {err}")
    # pylint: enable=broad-except

def test_compile_model(self):
"""Tests AOT compilation of a simple Add module."""

Expand Down
4 changes: 4 additions & 0 deletions ai_edge_torch/lowertools/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def gather_state_dict(

for _, tensor, _ in _get_states(exported_programs, signatures):
unique_id = _tensor_unique_id(tensor)
if tensor.dtype == torch.int64:
tensor = tensor.to(torch.int32)
elif tensor.dtype == torch.float64:
tensor = tensor.to(torch.float32)
deduped_tensor_map[unique_id] = _torch_to_tf_variable(tensor)

state_dict = {}
Expand Down
21 changes: 18 additions & 3 deletions ai_edge_torch/lowertools/odml_torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ class MergedBundle:

def torch_dtype_to_tf(dtype):
return {
torch.double: tf.float64,
# torch.double: tf.float64,
torch.double: tf.float32,
torch.float32: tf.float32,
torch.half: tf.float16,
torch.long: tf.int64,
# torch.long: tf.int64,
torch.long: tf.int32,
torch.int32: tf.int32,
torch.int16: tf.int16,
torch.bool: tf.bool,
Expand Down Expand Up @@ -158,10 +160,23 @@ def merged_bundle_to_tfl_model(
tf_module.f = []

for tf_sig, func in zip(tf_signatures, tf_functions):
processed_tf_sig = []
for spec in tf_sig:
if spec.dtype == tf.int64:
processed_tf_sig.append(
tf.TensorSpec(shape=spec.shape, dtype=tf.int32, name=spec.name)
)
elif spec.dtype == tf.float64:
processed_tf_sig.append(
tf.TensorSpec(shape=spec.shape, dtype=tf.float32, name=spec.name)
)
else:
processed_tf_sig.append(spec)

tf_module.f.append(
tf.function(
func,
input_signature=tf_sig,
input_signature=processed_tf_sig,
)
)

Expand Down
15 changes: 14 additions & 1 deletion ai_edge_torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import numpy.typing as npt
import tensorflow as tf
import torch

from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import

Expand Down Expand Up @@ -132,7 +133,19 @@ def __call__(
)

# Gather the input dictionary based on the signature.
inputs = {f'args_{idx}': args[idx] for idx in range(len(args))}
inputs = {}
for idx in range(len(args)):
arg = args[idx]
arg_name = f'args_{idx}'
if hasattr(arg, 'dtype'):
print(f"Input '{arg_name}' has dtype: {arg.dtype}")
if arg.dtype == torch.int64:
arg = arg.to(torch.int32)
print(f"Converted '{arg_name}' to dtype: {arg.dtype}")
elif arg.dtype == torch.float64:
arg = arg.to(torch.float32)
print(f"Converted '{arg_name}' to dtype: {arg.dtype}")
inputs[arg_name] = arg
inputs = {**inputs, **kwargs}
outputs = runner(**inputs)

Expand Down
16 changes: 14 additions & 2 deletions ai_edge_torch/odml_torch/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
f"{type(arg)} (for {node.name}) does not have tensor meta"
)

if dataclasses.is_dataclass(tensor_meta):
if tensor_meta.dtype == torch.int64:
tensor_meta = dataclasses.replace(tensor_meta, dtype=torch.int32)
elif tensor_meta.dtype == torch.float64:
tensor_meta = dataclasses.replace(tensor_meta, dtype=torch.float32)

tensor_metas.append(tensor_meta)
# Assume all dynamic dimensions are unbounded.
# TODO: Add checks for ep.range_constraints in MLIR.
Expand Down Expand Up @@ -429,11 +435,17 @@ def exported_program_to_mlir(
# Assumption:
# All states comes first in the list of args, and user provided inputs
# comes later. Also there is no kwargs.
dtype = tensor_meta.dtype
if dtype == torch.int64:
dtype = torch.int32
elif dtype == torch.float64:
dtype = torch.float32

if input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT:
input_signature.append(
VariableSignature(
tensor_meta.shape,
tensor_meta.dtype,
dtype,
input_spec=InputSpec.user_input(user_inputs_cnt),
)
)
Expand All @@ -444,7 +456,7 @@ def exported_program_to_mlir(
input_signature.append(
VariableSignature(
tensor_meta.shape,
tensor_meta.dtype,
dtype,
input_spec=InputSpec.parameter(input_spec.target),
)
)
Expand Down
3 changes: 2 additions & 1 deletion ai_edge_torch/odml_torch/jax_bridge/_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
# i64/f64 tensors from Jax bridged lowerings. If not set properly, all the
# 64bit tensors would be truncated to 32bit dtype and potentially break the
# lowering.
# NOTE(review): x64 was flipped to False to force 32-bit lowering of i64/f64
# tensors, but the rationale comment above states 64-bit support must stay
# enabled or bridged lowerings may break on truncation — confirm all Jax-bridged
# ops tolerate 32-bit dtypes before landing this change.
jax.config.update("jax_enable_x64", False)


def _lower_to_ir_text(
Expand Down
6 changes: 4 additions & 2 deletions ai_edge_torch/odml_torch/jax_bridge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def t2j_dtype(dtype):
torch.half: jnp.float16,
torch.float32: jnp.float32,
torch.double: jnp.double,
torch.long: jnp.int64,
torch.int64: jnp.int64,
# torch.long: jnp.int64,
# torch.int64: jnp.int64,
torch.long: jnp.int32,
torch.int64: jnp.int32,
torch.int32: jnp.int32,
torch.int16: jnp.int16,
torch.int8: jnp.int8,
Expand Down
Loading
Loading