diff --git a/ai_edge_torch/_convert/test/test_convert.py b/ai_edge_torch/_convert/test/test_convert.py
index 382fc4d7..b613489d 100644
--- a/ai_edge_torch/_convert/test/test_convert.py
+++ b/ai_edge_torch/_convert/test/test_convert.py
@@ -576,45 +576,6 @@ def forward(self, x: torch.Tensor):
       self.fail(f"Conversion failed with bloat16 inputs: {err}")
     # pylint: enable=broad-except
 
-  def test_convert_model_with_torch_linspace_operation(self):
-    """Test converting a simple model with torch.linspace operation."""
-
-    class SampleModel(nn.Module):
-
-      def forward(self, x: torch.Tensor):
-        return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)
-
-    model = SampleModel().eval()
-    args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)
-
-    try:
-      # Expect this to fix the error during conversion
-      ai_edge_torch.convert(model, args)
-    except Exception as err:
-      self.fail(f"Conversion failed with int64 inputs: {err}")
-    # pylint: enable=broad-except
-
-  def test_convert_model_with_torch_div_operation(self):
-    """Test converting a simple model with torch.div operation."""
-
-    class SampleModel(nn.Module):
-
-      def forward(self, x: torch.Tensor, y: torch.Tensor):
-        return x / y
-
-    model = SampleModel().eval()
-    args = (
-        torch.randint(0, 100, (10, 10), dtype=torch.int64),
-        torch.randint(0, 100, (10, 10), dtype=torch.int64),
-    )
-
-    try:
-      # Expect this to fix the error during conversion
-      ai_edge_torch.convert(model, args)
-    except Exception as err:
-      self.fail(f"Conversion failed with int64 inputs: {err}")
-    # pylint: enable=broad-except
-
   def test_compile_model(self):
     """Tests AOT compilation of a simple Add module."""
 
diff --git a/ai_edge_torch/lowertools/common_utils.py b/ai_edge_torch/lowertools/common_utils.py
index 5e0e29c9..21fdf5d3 100644
--- a/ai_edge_torch/lowertools/common_utils.py
+++ b/ai_edge_torch/lowertools/common_utils.py
@@ -128,6 +128,10 @@ def gather_state_dict(
 
   for _, tensor, _ in _get_states(exported_programs, signatures):
     unique_id = _tensor_unique_id(tensor)
+    if tensor.dtype == torch.int64:
+      tensor = tensor.to(torch.int32)
+    elif tensor.dtype == torch.float64:
+      tensor = tensor.to(torch.float32)
     deduped_tensor_map[unique_id] = _torch_to_tf_variable(tensor)
 
   state_dict = {}
diff --git a/ai_edge_torch/lowertools/odml_torch_utils.py b/ai_edge_torch/lowertools/odml_torch_utils.py
index d97ddece..387bdb24 100644
--- a/ai_edge_torch/lowertools/odml_torch_utils.py
+++ b/ai_edge_torch/lowertools/odml_torch_utils.py
@@ -45,10 +45,12 @@ class MergedBundle:
 
 def torch_dtype_to_tf(dtype):
   return {
-      torch.double: tf.float64,
+      # torch.double: tf.float64,
+      torch.double: tf.float32,
       torch.float32: tf.float32,
       torch.half: tf.float16,
-      torch.long: tf.int64,
+      # torch.long: tf.int64,
+      torch.long: tf.int32,
       torch.int32: tf.int32,
       torch.int16: tf.int16,
       torch.bool: tf.bool,
@@ -158,10 +160,23 @@ def merged_bundle_to_tfl_model(
   tf_module.f = []
   for tf_sig, func in zip(tf_signatures, tf_functions):
+    processed_tf_sig = []
+    for spec in tf_sig:
+      if spec.dtype == tf.int64:
+        processed_tf_sig.append(
+            tf.TensorSpec(shape=spec.shape, dtype=tf.int32, name=spec.name)
+        )
+      elif spec.dtype == tf.float64:
+        processed_tf_sig.append(
+            tf.TensorSpec(shape=spec.shape, dtype=tf.float32, name=spec.name)
+        )
+      else:
+        processed_tf_sig.append(spec)
+
     tf_module.f.append(
         tf.function(
             func,
-            input_signature=tf_sig,
+            input_signature=processed_tf_sig,
         )
     )
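# Sketch (not part of the patch): how the TensorSpec rewrite above behaves in
# isolation. Standalone illustration under the assumption that `tf_sig` is a
# list of tf.TensorSpec; the names here are hypothetical, not repo code.
import tensorflow as tf

tf_sig = [tf.TensorSpec(shape=(10, 10), dtype=tf.int64, name="args_0")]
processed_tf_sig = []
for spec in tf_sig:
  if spec.dtype == tf.int64:
    # 64-bit specs are narrowed to 32 bits before tf.function tracing.
    spec = tf.TensorSpec(shape=spec.shape, dtype=tf.int32, name=spec.name)
  elif spec.dtype == tf.float64:
    spec = tf.TensorSpec(shape=spec.shape, dtype=tf.float32, name=spec.name)
  processed_tf_sig.append(spec)
assert processed_tf_sig[0].dtype == tf.int32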
diff --git a/ai_edge_torch/model.py b/ai_edge_torch/model.py
index 88466256..30ab838b 100644
--- a/ai_edge_torch/model.py
+++ b/ai_edge_torch/model.py
@@ -27,6 +27,7 @@
 
 import numpy.typing as npt
 import tensorflow as tf
+import torch
 
 from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
@@ -132,7 +133,19 @@ def __call__(
       )
 
     # Gather the input dictionary based on the signature.
-    inputs = {f'args_{idx}': args[idx] for idx in range(len(args))}
+    inputs = {}
+    for idx in range(len(args)):
+      arg = args[idx]
+      arg_name = f'args_{idx}'
+      if hasattr(arg, 'dtype'):
+        print(f"Input '{arg_name}' has dtype: {arg.dtype}")
+        if arg.dtype == torch.int64:
+          arg = arg.to(torch.int32)
+          print(f"Converted '{arg_name}' to dtype: {arg.dtype}")
+        elif arg.dtype == torch.float64:
+          arg = arg.to(torch.float32)
+          print(f"Converted '{arg_name}' to dtype: {arg.dtype}")
+      inputs[arg_name] = arg
     inputs = {**inputs, **kwargs}
 
     outputs = runner(**inputs)
diff --git a/ai_edge_torch/odml_torch/export.py b/ai_edge_torch/odml_torch/export.py
index 0a883721..b7b411a7 100644
--- a/ai_edge_torch/odml_torch/export.py
+++ b/ai_edge_torch/odml_torch/export.py
@@ -55,6 +55,12 @@ def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
           f"{type(arg)} (for {node.name}) does not have tensor meta"
       )
 
+    if dataclasses.is_dataclass(tensor_meta):
+      if tensor_meta.dtype == torch.int64:
+        tensor_meta = dataclasses.replace(tensor_meta, dtype=torch.int32)
+      elif tensor_meta.dtype == torch.float64:
+        tensor_meta = dataclasses.replace(tensor_meta, dtype=torch.float32)
+
     tensor_metas.append(tensor_meta)
     # Assume all dynamic dimensions are unbounded.
     # TODO: Add checks for ep.range_constraints in MLIR.
@@ -429,11 +435,17 @@
     # Assumption:
     # All states comes first in the list of args, and user provided inputs
     # comes later. Also there is no kwargs.
+    dtype = tensor_meta.dtype
+    if dtype == torch.int64:
+      dtype = torch.int32
+    elif dtype == torch.float64:
+      dtype = torch.float32
+
     if input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT:
       input_signature.append(
           VariableSignature(
               tensor_meta.shape,
-              tensor_meta.dtype,
+              dtype,
              input_spec=InputSpec.user_input(user_inputs_cnt),
          )
      )
@@ -444,7 +456,7 @@
      input_signature.append(
          VariableSignature(
              tensor_meta.shape,
-              tensor_meta.dtype,
+              dtype,
              input_spec=InputSpec.parameter(input_spec.target),
          )
      )
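# Sketch (not part of the patch): the tensor-meta narrowing above relies on
# dataclasses.replace returning a copy with only the dtype swapped. Minimal
# standalone illustration; `TensorMeta` here is a hypothetical stand-in for
# the exported program's tensor meta, not the real class.
import dataclasses
import torch

@dataclasses.dataclass(frozen=True)
class TensorMeta:
  shape: tuple
  dtype: torch.dtype

meta = TensorMeta(shape=(10, 10), dtype=torch.int64)
if meta.dtype == torch.int64:
  meta = dataclasses.replace(meta, dtype=torch.int32)  # shape is preserved
assert meta.dtype == torch.int32 and meta.shape == (10, 10)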
-jax.config.update("jax_enable_x64", True) +# jax.config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", False) def _lower_to_ir_text( diff --git a/ai_edge_torch/odml_torch/jax_bridge/utils.py b/ai_edge_torch/odml_torch/jax_bridge/utils.py index d9b0449c..ccb933ee 100644 --- a/ai_edge_torch/odml_torch/jax_bridge/utils.py +++ b/ai_edge_torch/odml_torch/jax_bridge/utils.py @@ -27,8 +27,10 @@ def t2j_dtype(dtype): torch.half: jnp.float16, torch.float32: jnp.float32, torch.double: jnp.double, - torch.long: jnp.int64, - torch.int64: jnp.int64, + # torch.long: jnp.int64, + # torch.int64: jnp.int64, + torch.long: jnp.int32, + torch.int64: jnp.int32, torch.int32: jnp.int32, torch.int16: jnp.int16, torch.int8: jnp.int8, diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py index 0cf1982c..41cdf033 100644 --- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py @@ -78,6 +78,8 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten._unsafe_view) lower_by_torch_xla2(torch.ops.aten.acos) lower_by_torch_xla2(torch.ops.aten.acosh) +lower_by_torch_xla2(torch.ops.aten.add.Scalar) +lower_by_torch_xla2(torch.ops.aten.add.Tensor) lower_by_torch_xla2(torch.ops.aten.addbmm.default) lower_by_torch_xla2(torch.ops.aten.addmm) lower_by_torch_xla2(torch.ops.aten.addmv) @@ -116,6 +118,7 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.cumsum) lower_by_torch_xla2(torch.ops.aten.detach) lower_by_torch_xla2(torch.ops.aten.diagonal) +lower_by_torch_xla2(torch.ops.aten.div) lower_by_torch_xla2(torch.ops.aten.dot) lower_by_torch_xla2(torch.ops.aten.embedding) lower_by_torch_xla2(torch.ops.aten.empty) @@ -156,6 +159,7 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.logical_not) lower_by_torch_xla2(torch.ops.aten.logical_or) lower_by_torch_xla2(torch.ops.aten.logical_xor) +lower_by_torch_xla2(torch.ops.aten.lt) lower_by_torch_xla2(torch.ops.aten.max) lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices) lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward) @@ -166,6 +170,8 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.min) lower_by_torch_xla2(torch.ops.aten.minimum) lower_by_torch_xla2(torch.ops.aten.mm) +lower_by_torch_xla2(torch.ops.aten.mul.Scalar) +lower_by_torch_xla2(torch.ops.aten.mul.Tensor) lower_by_torch_xla2(torch.ops.aten.native_batch_norm) lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward) lower_by_torch_xla2(torch.ops.aten.ne) @@ -209,6 +215,8 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.squeeze) lower_by_torch_xla2(torch.ops.aten.squeeze_copy) lower_by_torch_xla2(torch.ops.aten.stack) +lower_by_torch_xla2(torch.ops.aten.sub.Scalar) +lower_by_torch_xla2(torch.ops.aten.sub.Tensor) lower_by_torch_xla2(torch.ops.aten.sum) lower_by_torch_xla2(torch.ops.aten.t) lower_by_torch_xla2(torch.ops.aten.tan) @@ -236,6 +244,7 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.view_copy) lower_by_torch_xla2(torch.ops.aten.where.ScalarOther) lower_by_torch_xla2(torch.ops.aten.where.ScalarSelf) +lower_by_torch_xla2(torch.ops.aten.where.self) lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim) lower_by_torch_xla2(torch.ops.prims.var) @@ -250,243 +259,6 @@ def _aten_copy(self, src, **kwargs): return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src) -@registry.lower(torch.ops.aten.add.Scalar) -def _aten_add_scalar(lctx: LoweringContext, self, 
diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
index 0cf1982c..41cdf033 100644
--- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
+++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -78,6 +78,8 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten._unsafe_view)
 lower_by_torch_xla2(torch.ops.aten.acos)
 lower_by_torch_xla2(torch.ops.aten.acosh)
+lower_by_torch_xla2(torch.ops.aten.add.Scalar)
+lower_by_torch_xla2(torch.ops.aten.add.Tensor)
 lower_by_torch_xla2(torch.ops.aten.addbmm.default)
 lower_by_torch_xla2(torch.ops.aten.addmm)
 lower_by_torch_xla2(torch.ops.aten.addmv)
@@ -116,6 +118,7 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.cumsum)
 lower_by_torch_xla2(torch.ops.aten.detach)
 lower_by_torch_xla2(torch.ops.aten.diagonal)
+lower_by_torch_xla2(torch.ops.aten.div)
 lower_by_torch_xla2(torch.ops.aten.dot)
 lower_by_torch_xla2(torch.ops.aten.embedding)
 lower_by_torch_xla2(torch.ops.aten.empty)
@@ -156,6 +159,7 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.logical_not)
 lower_by_torch_xla2(torch.ops.aten.logical_or)
 lower_by_torch_xla2(torch.ops.aten.logical_xor)
+lower_by_torch_xla2(torch.ops.aten.lt)
 lower_by_torch_xla2(torch.ops.aten.max)
 lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices)
 lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
@@ -166,6 +170,8 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.min)
 lower_by_torch_xla2(torch.ops.aten.minimum)
 lower_by_torch_xla2(torch.ops.aten.mm)
+lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
+lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
 lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
 lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
 lower_by_torch_xla2(torch.ops.aten.ne)
@@ -209,6 +215,8 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.squeeze)
 lower_by_torch_xla2(torch.ops.aten.squeeze_copy)
 lower_by_torch_xla2(torch.ops.aten.stack)
+lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
+lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
 lower_by_torch_xla2(torch.ops.aten.sum)
 lower_by_torch_xla2(torch.ops.aten.t)
 lower_by_torch_xla2(torch.ops.aten.tan)
@@ -236,6 +244,7 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.view_copy)
 lower_by_torch_xla2(torch.ops.aten.where.ScalarOther)
 lower_by_torch_xla2(torch.ops.aten.where.ScalarSelf)
+lower_by_torch_xla2(torch.ops.aten.where.self)
 lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim)
 lower_by_torch_xla2(torch.ops.prims.var)
@@ -250,243 +259,6 @@ def _aten_copy(self, src, **kwargs):
   return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
 
 
-@registry.lower(torch.ops.aten.add.Scalar)
-def _aten_add_scalar(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.add.Scalar)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    return jnp.add(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.add.Tensor)
-def _aten_add_tensor(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.add.Tensor)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    promoted_type = jnp.promote_types(self.dtype, other.dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    return jnp.add(self.astype(promoted_type), other.astype(promoted_type))
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.sub.Scalar)
-def _aten_sub_scalar(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.sub.Scalar)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    return jnp.subtract(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.sub.Tensor)
-def _aten_sub_tensor(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.sub.Tensor)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    promoted_type = jnp.promote_types(self.dtype, other.dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    return jnp.subtract(self.astype(promoted_type), other.astype(promoted_type))
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.lt.Scalar)
-def _aten_lt_scalar(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.lt.Scalar)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    return jnp.less(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.lt.Tensor)
-def _aten_lt_tensor(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.lt.Tensor)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    promoted_type = jnp.promote_types(self.dtype, other.dtype)
-    return jnp.less(self.astype(promoted_type), other.astype(promoted_type))
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.mul.Scalar)
-def _aten_mul_scalar(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.mul.Scalar)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    elif promoted_type == jnp.int64:
-      promoted_type = jnp.int32
-    return jnp.multiply(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.mul.Tensor)
-def _aten_mul_tensor(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.mul.Tensor)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    elif promoted_type == jnp.int64:
-      promoted_type = jnp.int32
-    return jnp.multiply(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.div.Scalar)
-def _aten_div_scalar(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.div.Scalar)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    elif promoted_type == jnp.int64:
-      promoted_type = jnp.int32
-    return jnp.divide(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.div.Scalar_mode)
-def _aten_div_scalar_mode(lctx: LoweringContext, self, other, rounding_mode=""):
-  _log_usage(torch.ops.aten.div.Scalar_mode)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    elif promoted_type == jnp.int64:
-      promoted_type = jnp.int32
-    if rounding_mode == "floor":
-      return jnp.floor_divide(
-          self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-      )
-    result = jnp.divide(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-    if rounding_mode == "trunc":
-      result = jnp.trunc(result)
-    return result
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.div.Tensor)
-def _aten_div_tensor(lctx: LoweringContext, self, other):
-  _log_usage(torch.ops.aten.div.Tensor)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    elif promoted_type == jnp.int64:
-      promoted_type = jnp.int32
-    return jnp.divide(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.div.Tensor_mode)
-def _aten_div_tensor_mode(lctx: LoweringContext, self, other, rounding_mode=""):
-  _log_usage(torch.ops.aten.div.Tensor_mode)
-
-  @jax_bridge.wrap
-  def jax_lowering(self, other):
-    other_dtype = jnp.result_type(other)
-    promoted_type = jnp.promote_types(self.dtype, other_dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    elif promoted_type == jnp.int64:
-      promoted_type = jnp.int32
-    if rounding_mode == "floor":
-      return jnp.floor_divide(
-          self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-      )
-    result = jnp.divide(
-        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
-    )
-    if rounding_mode == "trunc":
-      result = jnp.trunc(result)
-    return result
-
-  return jax_lowering(lctx, self, other)
-
-
-@registry.lower(torch.ops.aten.where.self)
-def _aten_where_self(lctx: LoweringContext, condition, self, other):
-  _log_usage(torch.ops.aten.where.self)
-
-  @jax_bridge.wrap
-  def jax_lowering(condition, self, other):
-    promoted_type = jnp.promote_types(self.dtype, other.dtype)
-    if promoted_type == jnp.float64:
-      promoted_type = jnp.float32
-    return jnp.where(
-        condition,
-        self.astype(promoted_type),
-        other.astype(promoted_type),
-    )
-
-  return jax_lowering(lctx, condition, self, other)
-
-
 # Schema:
 #  - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
 #    -> Tensor
diff --git a/ai_edge_torch/odml_torch/lowerings/utils.py b/ai_edge_torch/odml_torch/lowerings/utils.py
index 02fa290c..0d2768d4 100644
--- a/ai_edge_torch/odml_torch/lowerings/utils.py
+++ b/ai_edge_torch/odml_torch/lowerings/utils.py
@@ -29,10 +29,12 @@ def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
   """Builds ir.Type from torch dtype."""
   ty_get = {
-      torch.double: ir.F64Type.get,
+      # torch.double: ir.F64Type.get,
+      torch.double: ir.F32Type.get,
       torch.float32: ir.F32Type.get,
       torch.half: ir.F16Type.get,
-      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      # torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      torch.long: functools.partial(ir.IntegerType.get_signless, 32),
       torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
       torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
       torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
@@ -122,6 +124,7 @@ def get_broadcast_dimensions(
   for val in range(len(shape_to) - len(shape_from), len(shape_to)):
     ret.append(val)
+  # check?
   return ir.DenseI64ArrayAttr.get(np.asarray(ret, np.int64))
@@ -221,7 +224,8 @@ def convert_int_to_float(t: ir.Value) -> ir.Value:
     )
   elif elty.width == 64:
     return stablehlo.convert(
-        ir.RankedTensorType.get(t.type.shape, ir.F64Type.get()), t
+        # ir.RankedTensorType.get(t.type.shape, ir.F64Type.get()), t
+        ir.RankedTensorType.get(t.type.shape, ir.F32Type.get()), t
     )
@@ -235,14 +239,17 @@ def convert_int_to_float(t: ir.Value) -> ir.Value:
     np.dtype(np.int8): functools.partial(ir.IntegerType.get_signless, 8),
     np.dtype(np.int16): functools.partial(ir.IntegerType.get_signless, 16),
     np.dtype(np.int32): functools.partial(ir.IntegerType.get_signless, 32),
-    np.dtype(np.int64): functools.partial(ir.IntegerType.get_signless, 64),
+    # np.dtype(np.int64): functools.partial(ir.IntegerType.get_signless, 64),
+    np.dtype(np.int64): functools.partial(ir.IntegerType.get_signless, 32),
     np.dtype(np.uint8): functools.partial(ir.IntegerType.get_unsigned, 8),
     np.dtype(np.uint16): functools.partial(ir.IntegerType.get_unsigned, 16),
     np.dtype(np.uint32): functools.partial(ir.IntegerType.get_unsigned, 32),
-    np.dtype(np.uint64): functools.partial(ir.IntegerType.get_unsigned, 64),
+    # np.dtype(np.uint64): functools.partial(ir.IntegerType.get_unsigned, 64),
+    np.dtype(np.uint64): functools.partial(ir.IntegerType.get_unsigned, 32),
     np.dtype(np.float16): ir.F16Type.get,
     np.dtype(np.float32): ir.F32Type.get,
-    np.dtype(np.float64): ir.F64Type.get,
+    # np.dtype(np.float64): ir.F64Type.get,
+    np.dtype(np.float64): ir.F32Type.get,
     np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
     np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
 }
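# Sketch (not part of the patch): a quick end-to-end check of the i64 -> i32
# narrowing, mirroring the div test this change removes. Assumes ai_edge_torch
# is installed; the model and shapes are illustrative only.
import torch
from torch import nn
import ai_edge_torch

class DivModel(nn.Module):

  def forward(self, x: torch.Tensor, y: torch.Tensor):
    return x / y

model = DivModel().eval()
args = (
    torch.randint(0, 100, (10, 10), dtype=torch.int64),
    torch.randint(1, 100, (10, 10), dtype=torch.int64),  # nonzero divisors
)
edge_model = ai_edge_torch.convert(model, args)  # should no longer fail on int64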