Merged
13 changes: 10 additions & 3 deletions thunder/dynamo/splitter.py
@@ -19,6 +19,7 @@
checkpoint_converter,
_get_example_inputs_from_placeholder,
_ThunderSplitGraphModule,
translate_dtensor_ops,
)

if TYPE_CHECKING:
@@ -98,6 +99,7 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
split_reasons: list[SplitReason] = []

nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
translate_dtensor_ops(gm)

def callback(node) -> int:
nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
@@ -119,9 +121,14 @@ def callback(node) -> int:
)
split_reasons.append(split_reason)
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if split_reason is not None:
split_reasons.append(split_reason)
# Support the dynamo-generated prims produced for `parallelize_module`:
# `translate_dtensor_ops` marks a node's target as thunder-supported if it is a DTensor operation.
if hasattr(node.target, "thunder_supported") and node.target.thunder_supported:
is_thunder_supported, split_reason = True, None
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if split_reason is not None:
split_reasons.append(split_reason)

if prev_value == is_thunder_supported: # We are in the same region.
return partition_cnt
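
For reference, a minimal sketch (not part of this PR; `make_marked_wrapper` and the `SimpleNamespace` node are made-up stand-ins) of how a target marked by `translate_dtensor_ops` short-circuits the support check in the callback above:

```python
from types import SimpleNamespace


def make_marked_wrapper(prim):
    def wrapper(x):
        return prim(x)

    wrapper.thunder_supported = True  # the attribute the splitter callback looks for
    return wrapper


# Stand-in for a torch.fx.Node whose target was rewritten by `translate_dtensor_ops`.
node = SimpleNamespace(target=make_marked_wrapper(lambda x: x))

# Mirrors the branch above: marked targets are treated as thunder-supported
# without consulting `is_node_supported_by_thunder`.
assert getattr(node.target, "thunder_supported", False)
```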
75 changes: 75 additions & 0 deletions thunder/dynamo/utils.py
@@ -1050,3 +1050,78 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement])
raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}")
return sorted_compiled_gm_to_measurement[0].compiled_fn


def translate_dtensor_ops(gm: torch.fx.GraphModule):
# We need this function because:
#
# For a program like the following, where `model` is a module whose sub-module `fc1` is `nn.Linear(hidden_size, hidden_size, bias=False)`:
# ```
# parallel_model = parallelize_module(model, mesh, {"fc1": ColwiseParallel()})
# model.fc1.weight.requires_grad = False
#
# # `parallelize_module` will handle the conversion to DTensor.
# i = torch.randn(hidden_size, hidden_size)
# out = parallel_model(i)
# ```
#
# Dynamo captures an FX-Graph like:
# ```
# def forward(self, L_x_: "f32[16, 16]", L_self_modules_fc1_parameters_weight_: "f32[16, 16]"):
# l_x_ = L_x_
# l_self_modules_fc1_parameters_weight_ = L_self_modules_fc1_parameters_weight_
#
# input_tensor: "f32[16, 16]" = torch__dynamo_variables_torch_prim_from_local(l_x_); l_x_ = None
#
# linear: "f32[16, 16]" = torch._C._nn.linear(input_tensor, l_self_modules_fc1_parameters_weight_, None); input_tensor = l_self_modules_fc1_parameters_weight_ = None
#
# outputs: "f32[16, 16]" = torch__dynamo_variables_tensor_prim_redistribute(linear); linear = None
#
# hook_result: "f32[16, 8]" = torch__dynamo_variables_tensor_prim_to_local(outputs); outputs = None
# return (hook_result,)
# ```
# where:
# 1. In the FX Graph, the Tensor Parallel computation is decomposed into primitive operations such as `torch__dynamo_variables_torch_prim_from_local`, `torch__dynamo_variables_tensor_prim_redistribute`, and others.
# 2. These decompositions capture (close over) values such as `placements` and other metadata.
#    For example, to find the placements that `torch__dynamo_variables_tensor_prim_redistribute` will redistribute its output to,
#    we inspect the captured values with `inspect.getclosurevars(node.target)`.
#    The closure is created here:
#    https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/_dynamo/variables/tensor.py#L1186-L1210

from thunder.torch.experimental.dtensor_torch_and_prims import (
    dtensor_from_local_prim,
    dtensor_redistribute_prim,
    dtensor_to_local_prim,
)

for node in gm.graph.nodes:
    try:
        # The dynamo-generated DTensor prims are closures; recover the metadata
        # (mesh, placements, ...) they capture and swap in a thunder-friendly wrapper.
        closure_vars = inspect.getclosurevars(node.target)

        if "from_local" in node.target.__name__:
            mesh = closure_vars.nonlocals["args_as_value"][0]
            placements = closure_vars.nonlocals["args_as_value"][1]

            def dtensor_from_local_prim_wrapper(x, mesh=mesh, placements=placements):
                return dtensor_from_local_prim(x, mesh, placements)

            dtensor_from_local_prim_wrapper.thunder_supported = True
            node.target = dtensor_from_local_prim_wrapper
        elif "redistribute" in node.target.__name__:
            placements = closure_vars.nonlocals["kwargs_as_value"]["placements"]

            def dtensor_redistribute_prim_wrapper(x, placements=placements):
                return dtensor_redistribute_prim(x, placements=placements)

            dtensor_redistribute_prim_wrapper.thunder_supported = True
            node.target = dtensor_redistribute_prim_wrapper
        elif "to_local" in node.target.__name__:

            def dtensor_to_local_prim_wrapper(x):
                return dtensor_to_local_prim(x)

            dtensor_to_local_prim_wrapper.thunder_supported = True
            node.target = dtensor_to_local_prim_wrapper
    except Exception:
        # Node targets that are not the dynamo-generated DTensor closures
        # (e.g. regular ATen callables or string targets of call_method nodes)
        # do not support `inspect.getclosurevars`; leave them untouched.
        pass
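
A standalone sketch of the closure inspection `translate_dtensor_ops` relies on; the `args_as_value` name mirrors the dynamo-generated closure, while `make_from_local_prim` and the placeholder values are purely illustrative:

```python
import inspect


# Illustrative stand-in for the closure dynamo generates for `DTensor.from_local`:
# it closes over (mesh, placements) via `args_as_value`, like the real prim does.
def make_from_local_prim(args_as_value):
    def prim_from_local(x):
        mesh, placements = args_as_value  # the real closure forwards these to DTensor.from_local
        return x

    return prim_from_local


prim = make_from_local_prim(("mesh_placeholder", ["Shard(0)"]))
closure_vars = inspect.getclosurevars(prim)
mesh, placements = closure_vars.nonlocals["args_as_value"]
assert placements == ["Shard(0)"]
```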
39 changes: 39 additions & 0 deletions thunder/tests/distributed/test_dtensor.py
@@ -15,6 +15,10 @@
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed.tensor.placement_types import Shard, Replicate
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
)

from torch.testing._internal import common_utils

@@ -23,6 +27,7 @@
from thunder.tests.utils import is_output_differentiable, filter_differentiable_outputs
import thunder.core.dtypes as dtypes
from thunder.core.pytree import tree_flatten
from thunder.dynamo import thunderfx


# NOTE: We run all these similar functions separately
@@ -272,6 +277,40 @@ def fn(x):

torch.testing.assert_close(actual, expected)

@common_utils.parametrize("jit_fn", (thunder.jit, thunderfx), name_fn=lambda jit_fn: jit_fn.__name__)
def test_dtensor_columnwise_parallel(self, jit_fn):
if jit_fn == thunder.jit:
# File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 444, in _general_jit_getattr_lookaside
# obj.original_value.__dict__,
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
# AttributeError: 'object' object has no attribute '__dict__'. Did you mean: '__dir__'?
raise unittest.SkipTest("thunder.jit fails with AttributeError")

num_devices = self.world_size
mesh = DeviceMesh("cuda", list(range(num_devices)))
dim_size = 16
in_dtensor = torch.randn(dim_size, dim_size, requires_grad=False)
m = torch.nn.Linear(dim_size, dim_size)
m.requires_grad_(False)

parallelized_model = parallelize_module(m, mesh, ColwiseParallel())

# `parallelize_module` resets `requires_grad` to True, so set it to False again.
parallelized_model.requires_grad_(False)

actual = parallelized_model(in_dtensor)
expected = m(in_dtensor)
torch.testing.assert_close(actual, expected)

tmodel = jit_fn(parallelized_model, nv_enable_linear=True)
actual = tmodel(in_dtensor)
torch.testing.assert_close(actual, expected)

if jit_fn == thunderfx:
assert len(tmodel._backend.subgraph_infos) == 1
assert len(tmodel._backend.subgraph_infos[0].thunder_compiled_fns) == 1
assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0

@common_utils.parametrize("executor", tuple(executors_map.keys()))
@common_utils.parametrize(
"input_shardings",
35 changes: 35 additions & 0 deletions thunder/torch/experimental/dtensor_proxy.py
@@ -69,6 +69,41 @@ def placements(self):
def device_mesh(self):
return self.spec._o.device_mesh

@staticmethod
def from_local(
x,
mesh,
placements,
*,
run_check: bool = False,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
):
import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims

res = dtensor_torch_and_prims.dtensor_from_local_prim(
x, mesh, placements, run_check=run_check, shape=shape, stride=stride
)
return res

def redistribute(
self,
device_mesh: "Optional[DeviceMesh]" = None,
placements: "Optional[Sequence[Placement]]" = None,
*,
async_op: bool = False,
) -> "DTensor":
import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims

res = dtensor_torch_and_prims.dtensor_redistribute_prim(self, device_mesh, placements, async_op=async_op)
return res

def to_local(self, *, grad_placements: "Optional[Sequence[Placement]]" = None):
import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims

res = dtensor_torch_and_prims.dtensor_to_local_prim(self, grad_placements=grad_placements)
return res

def replace(self, **changes):
r"""Return a copy of the TensorProxy object with new values for the specified fields as given to the constructor as arguments.
Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
95 changes: 95 additions & 0 deletions thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -1,6 +1,7 @@
from functools import partial
from collections.abc import Callable
from enum import auto, Enum
from collections.abc import Sequence
from looseversion import LooseVersion

from thunder.torch import torchsymbol, TensorLike, register_function
@@ -371,6 +372,100 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
)


if torch.distributed.is_available():
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Placement, DeviceMesh

def dtensor_from_local_meta(
x,
mesh,
placements,
*,
run_check: bool = False,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
):
res = run_with_fake_tensor(
DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride
)
from thunder.torch.experimental.dtensor_proxy import proxify_dtensor

res = proxify_dtensor(res)
return res

dtensor_from_local_prim = make_prim("dtensor_from_local", "dtensor_from_local", meta=dtensor_from_local_meta)

dtensor_from_local_prim_impl = pytorchex.register_operator(
"dtensor_from_local", like=dtensor_from_local_prim, fn=DTensor.from_local
)

pytorchex.register_implementation(dtensor_from_local_prim, dtensor_from_local_prim_impl)

@dtensor_torchsymbol(DTensor.from_local, id="dtensor.torch.from_local")
def dtensor_from_local(
x,
mesh,
placements,
*,
run_check: bool = False,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
) -> DTensorProxy | None:
return dtensor_from_local_prim(x, mesh, placements, run_check=run_check, shape=shape, stride=stride)

def dtensor_redistribute_meta(
dtensor,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
) -> DTensorProxy | None:
res = run_with_fake_tensor(DTensor.redistribute, dtensor, device_mesh, placements, async_op=async_op)
from thunder.torch.experimental.dtensor_proxy import proxify_dtensor

res = proxify_dtensor(res)
return res

dtensor_redistribute_prim = make_prim(
"dtensor_redistribute", "dtensor_redistribute", meta=dtensor_redistribute_meta
)

dtensor_redistribute_prim_impl = pytorchex.register_operator(
"dtensor_redistribute", like=dtensor_redistribute_prim, fn=DTensor.redistribute
)

@dtensor_torchsymbol(DTensor.redistribute, id="dtensor.torch.redistribute")
def dtensor_redistribute(
dtensor,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
) -> DTensorProxy | None:
return dtensor_redistribute_prim(dtensor, device_mesh, placements, async_op=async_op)

pytorchex.register_implementation(dtensor_redistribute_prim, dtensor_redistribute_prim_impl)

def dtensor_to_local_meta(dtensor, *, grad_placements: Sequence[Placement] | None = None):
res = run_with_fake_tensor(DTensor.to_local, dtensor, grad_placements=grad_placements)
from thunder.core.proxies import proxy

res = proxy(res)
return res

dtensor_to_local_prim = make_prim("dtensor_to_local", "dtensor_to_local", meta=dtensor_to_local_meta)

dtensor_to_local_prim_impl = pytorchex.register_operator(
"dtensor_to_local", like=dtensor_to_local_prim, fn=DTensor.to_local
)

pytorchex.register_implementation(dtensor_to_local_prim, dtensor_to_local_prim_impl)

@dtensor_torchsymbol(DTensor.to_local, id="dtensor.torch.to_local")
def dtensor_to_local(dtensor, *, grad_placements: Sequence[Placement] | None = None) -> DTensorProxy | None:
return dtensor_to_local_prim(dtensor, grad_placements=grad_placements)


expand = partial(expand_impl, broadcast_prim=dtensor_broadcast_in_dim_prim)
maybe_broadcast = partial(maybe_broadcast_impl, expand_fn=expand)
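
For orientation, these are the eager PyTorch calls that the three prims above model (no thunder involved). A minimal sketch, assuming a recent PyTorch, an already-initialized process group (e.g. launched via `torchrun`), and one CUDA device per rank:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard

# One-dimensional mesh over all ranks; assumes init_process_group was already called.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
local = torch.randn(8, 8, device="cuda")

dt = DTensor.from_local(local, mesh, [Shard(0)])  # modeled by dtensor_from_local_prim
dt = dt.redistribute(mesh, [Replicate()])         # modeled by dtensor_redistribute_prim
out = dt.to_local()                               # modeled by dtensor_to_local_prim
```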
