Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion ivy/functional/backends/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@
import ivy
from ivy.func_wrapper import _dtype_from_version

# Complete set of ivy dtypes exposed by the torch backend.
# NOTE(review): uint16/uint32/uint64 are absent — presumably because torch
# lacks full native support for them; confirm against ivy's dtype tables.
all_dtypes = (
    ivy.int8,
    ivy.int16,
    ivy.int32,
    ivy.int64,
    ivy.uint8,
    ivy.bfloat16,
    ivy.float16,
    ivy.float32,
    ivy.float64,
    ivy.complex64,
    ivy.complex128,
    ivy.bool,
)

backend_version = {"version": torch.__version__.split("+")[0]}

# Registering ivy.Array as trackable submodule
Expand Down Expand Up @@ -70,7 +85,6 @@
)
}


valid_numeric_dtypes = {
"2.2 and below": (
ivy.int8,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@
),
},
"numpy": {},
"torch": {
"PreTrainedModel": (
"PreTrainedModel",
"from transformers.modeling_utils import PreTrainedModel",
),
"BaseModelOutput": (
"BaseModelOutput",
"from transformers.modeling_outputs import BaseModelOutput",
),
},
}

IVY_DEFAULT_DTYPE_MAPPING = {
Expand Down Expand Up @@ -231,6 +241,12 @@
},
"jax": {None: None},
"numpy": {None: None},
"torch": {
"torch.dtype": "torch.dtype",
"torch.Size": "torch.Size",
"torch.Tensor": "torch.Tensor",
"torch.device": "torch.device",
},
}

IVY_GLOBS = {
Expand Down
2 changes: 1 addition & 1 deletion ivy/transpiler/main_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@

# Frameworks whose code S2S supports currently
# Frameworks the S2S transpiler currently accepts as translation sources.
SUPPORTED_S2S_SOURCES = ["torch", "ivy"]
# Frameworks the S2S transpiler can currently translate to.
# (Duplicate pre-change assignment from the diff paste removed; the
# final state includes "torch".)
SUPPORTED_S2S_TARGETS = ["tensorflow", "jax", "numpy", "ivy", "torch"]
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,29 @@ def __init__(
"itemsize",
"ndim",
)
elif self.transformer.target == "torch":
# PyTorch tensor properties to ignore
self.properties_to_ignore = (
"data",
"shape",
"dtype",
"device",
"strides",
"size",
"ndim",
"T",
"real",
"imag",
)
else:
# Default properties to ignore for other targets
self.properties_to_ignore = (
"data",
"shape",
"dtype",
"device",
"strides",
)

def transform(self):
# no need to transform method calls for backend impl in ivy.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# global
import ivy


# local
import gast
from ...configurations.base_transformer_config import (
BaseTransformerConfig,
)
from ...transformer import Transformer
from ....utils.ast_utils import (
ast_to_source_code,
)
from ....utils.api_utils import (
get_native_array_str_from_backend,
get_native_module_str_from_backend,
)
from .ivy_postprocessing_transformer import (
IvyCodePostProcessor,
)


class IvyToTorchCodePostProcessor(IvyCodePostProcessor):
    """
    Perform post-processing for the PyTorch backend.

    Rewrites ivy-specific constructs in the transpiled AST into native torch
    equivalents: ``ivy.Array`` -> the backend's native array type,
    ``ivy.Variable`` -> ``torch.nn.Parameter``, ``ivy.Module`` -> the
    translated native module class name, isinstance type checks, and the
    body of ivy's ``inplace_update`` helper.
    """

    def __init__(
        self,
        root,
        transformer: Transformer,
        configuration: BaseTransformerConfig,
        new_name="tensor",
    ) -> None:
        super().__init__(root, transformer, configuration, new_name=new_name)
        self.root = root
        self.transformer = transformer
        self.configuration = configuration

    def _handle_ivy_array(self, node):
        # ivy.Array -> "<backend>.<native array>", e.g. "torch.Tensor".
        new_name = get_native_array_str_from_backend(ivy.backend)
        return gast.parse(f"{ivy.backend}.{new_name}").body[0].value

    def _handle_ivy_variable(self, node):
        # ivy.Variable -> torch.nn.Parameter
        return gast.parse("torch.nn.Parameter").body[0].value

    def _handle_ivy_module(self, node):
        # ivy.Module -> the translated native module class name; dots are
        # flattened to underscores so the result is a single identifier.
        new_name = get_native_module_str_from_backend(
            backend_str=ivy.backend,
            is_root_obj=self.transformer.object_like.is_root_obj,
            depth=self.transformer.object_like.depth,
        )
        new_name = new_name.replace(".", "_")
        return gast.parse(f"{new_name}").body[0].value

    def _handle_assign_transform(self, node):
        # Rewrite the assigned call into torch.nn.Parameter(<original args>).
        # BUG FIX: this previously built a single gast.Attribute with
        # attr="nn.Parameter". A dotted string is not a valid attribute
        # identifier, so the node unparsed incorrectly; parsing the dotted
        # path instead yields the proper nested Attribute chain.
        return gast.Call(
            func=gast.parse("torch.nn.Parameter").body[0].value,
            args=node.value.args,
            keywords=node.value.keywords,
        )

    def _transform_isinstance_check(self, node):
        """
        isinstance(module, torch_nn_Module)
        --> isinstance(module, (torch_nn_Module, torch.nn.Module))
        """
        node.args = [
            node.args[0],
            gast.Tuple(
                elts=[
                    node.args[1],
                    gast.parse("torch.nn.Module").body[0].value,
                ],
                ctx=gast.Load(),
            ),
        ]
        return node

    def _get_forward_name(self, node):
        # torch modules use "forward" as the call entry point.
        return "forward"

    def _maybe_convert_device_attribute(self, node):
        # `.device` is a native tensor attribute in PyTorch, so the node
        # needs no rewriting.
        return node

    def _maybe_replace_with_native_array_calls(self, node):
        # torch.Tensor(...) / Tensor(...) / ivy.Array(...) -> torch.tensor(...),
        # the recommended data-based tensor constructor.
        func_str = ast_to_source_code(node.func).strip()
        if func_str in ("torch.Tensor", "Tensor", "ivy.Array"):
            new_func = gast.parse("torch.tensor").body[0].value
            node.func = gast.fix_missing_locations(new_func)
        return node

    def _replace_ivy_array_pattern(self, elts):
        """
        Replace any consecutive (ivy.Array, ivy.Array) pair inside an
        isinstance type-check tuple with (torch.Tensor, torch.nn.Parameter).
        """
        # BUG FIX: the pattern nodes were previously hand-built gast.Name /
        # gast.Attribute instances lacking the `annotation`/`type_comment`
        # fields that parsed nodes carry; on most Python versions gast.dump()
        # includes those None-valued fields for parsed nodes but skips the
        # missing ones on hand-built nodes, so the dumps never compared equal
        # and the pattern never matched. Building the pattern via gast.parse
        # keeps both sides structurally identical.
        pattern_dump = [
            gast.dump(gast.parse("ivy.Array").body[0].value) for _ in range(2)
        ]

        transformed_elts = []
        i = 0
        while i < len(elts):
            # Compare the current two-element window against the pattern.
            elts_dump = [gast.dump(node) for node in elts[i : i + 2]]
            if elts_dump == pattern_dump:
                # Fresh nodes per match so no AST node is shared between
                # positions in the output tuple.
                transformed_elts.append(gast.parse("torch.Tensor").body[0].value)
                transformed_elts.append(
                    gast.parse("torch.nn.Parameter").body[0].value
                )
                i += 2  # skip the matched pair
            else:
                transformed_elts.append(elts[i])
                i += 1

        return transformed_elts

    def _maybe_modify_inplace_update_fn(self, node):
        # Specialise ivy's inplace_update implementation for torch:
        #   1) default keep_input_dtype to True,
        #   2) make `x = ...` assignments use val_native on the RHS,
        #   3) replace the is_ivy_array conditionals with `x = x_native`.
        if "inplace_update" in node.name:
            self._modify_keep_input_dtype_kwarg(node)
            self._modify_assignments_to_val_native(node)
            self._replace_conditional_blocks_for_torch(node)
        return node

    def _modify_keep_input_dtype_kwarg(self, node):
        """Step 1: flip the `keep_input_dtype` keyword-only default to True."""
        for kwarg, default in zip(node.args.kwonlyargs, node.args.kw_defaults):
            if ast_to_source_code(kwarg).strip() == "keep_input_dtype":
                # Only mutate a plain constant default; leave anything else.
                if isinstance(default, gast.Constant):
                    default.value = True
                break

    def _modify_assignments_to_val_native(self, node):
        """Step 2: rewrite `x = ...` assignments to use val_native on the RHS."""

        class AssignVisitor(gast.NodeTransformer):
            def visit_Assign(self, assign_node):
                for target in assign_node.targets:
                    if ast_to_source_code(target).strip() == "x":
                        # Parsed node carries all gast fields (annotation,
                        # type_comment) required for later dump/unparse.
                        val_native_node = gast.parse("val_native").body[0].value
                        if isinstance(assign_node.value, gast.Call):
                            # Keep the call, but feed it val_native as the
                            # first positional argument.
                            assign_node.value.args[0] = val_native_node
                        else:
                            # Replace the whole RHS with val_native.
                            assign_node.value = val_native_node
                self.generic_visit(assign_node)
                return assign_node

        AssignVisitor().visit(node)

    def _replace_conditional_blocks_for_torch(self, node):
        """Step 3: replace the is_ivy_array conditionals with `x = x_native`."""

        class IfVisitor(gast.NodeTransformer):
            def visit_If(self, if_node):
                condition_str = ast_to_source_code(if_node.test).strip()
                if (
                    "torch_is_ivy_array_bknd" in condition_str
                    or "is_ivy_array_bknd" in condition_str
                ):
                    # Parsing the replacement yields an Assign node with all
                    # required gast fields populated.
                    return gast.parse("x = x_native").body[0]
                self.generic_visit(if_node)
                return if_node

        IfVisitor().visit(node)
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
from ...transformations.transformers.postprocessing_transformer.ivy_to_numpy_postprocessing_transformer import (
IvyToNumpyCodePostProcessor,
)
from ...transformations.transformers.postprocessing_transformer.ivy_to_torch_postprocessing_transformer import (
IvyToTorchCodePostProcessor,
)
from ...transformations.transformers.recursive_transformer.ivy_recursive_transformer import (
IvyRecurser,
)
Expand All @@ -75,6 +78,7 @@ class IvyToSourceTranslatorConfig(BaseTranslatorConfig):
IvyToTFCodePostProcessor: IvyCodePostProcessorConfig,
IvyToJAXCodePostProcessor: IvyCodePostProcessorConfig,
IvyToNumpyCodePostProcessor: IvyCodePostProcessorConfig,
IvyToTorchCodePostProcessor: IvyCodePostProcessorConfig,
}

def __init__(self, source="ivy", target="tensorflow", base_output_dir="") -> None:
Expand Down Expand Up @@ -122,6 +126,21 @@ def __init__(self, source="ivy", target="tensorflow", base_output_dir="") -> Non
PytorchToFlaxLayer,
HFPretrainedFlaxTransformer,
]
elif target == "torch":
self.transformers: List[BaseTransformer] = [
IvyNodeDeleter,
IvyDecoratorRemover,
# BaseTypeHintRemover,
BaseDocstringRemover,
# BaseTypeAnnotationRemover,
IvyMethodToFunctionConverter,
BaseDundersTransformer,
IvyCodePreProcessor,
BaseNameCanonicalizer,
BaseGlobalsTransformer,
IvyRecurser,
IvyToTorchCodePostProcessor,
]
elif target == "numpy":
self.transformers: List[BaseTransformer] = [
IvyNodeDeleter,
Expand Down
1 change: 1 addition & 0 deletions ivy/transpiler/utils/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"tensorflow": {True: "tensorflow.keras.Model", False: "tensorflow.keras.Layer"},
"jax": {True: "flax.nnx.Module", False: "flax.nnx.Module"},
"numpy": {True: "type", False: "type"},
"torch": {True: "torch.nn.Module", False: "torch.nn.Module"},
}
TRANSLATED_OBJ_PREFIX = [
"Translated_",
Expand Down
9 changes: 9 additions & 0 deletions ivy/transpiler/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ def _maybe_create_stateful_layers_module(target: str, output_dir: str):

source = FlaxNativeLayers
module_name = "jax__stateful_layers"
elif target == "torch":
# For torch target, we don't have specific native layers yet
# Just create an empty module for now
source = "# PyTorch stateful layers module\n"
module_name = "torch__stateful_layers"
else:
# For any other targets, return without creating a module
return

stateful_file = os.path.join(output_dir, f"{module_name}.py")

f = open(stateful_file, "w", encoding="utf-8", newline="\n")
Expand Down