13 changes: 10 additions & 3 deletions thunder/dynamo/splitter.py
@@ -19,6 +19,7 @@
checkpoint_converter,
_get_example_inputs_from_placeholder,
_ThunderSplitGraphModule,
translate_dtensor_ops,
)

if TYPE_CHECKING:
@@ -98,6 +99,7 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
split_reasons: list[SplitReason] = []

nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
translate_dtensor_ops(gm)

def callback(node) -> int:
nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
@@ -119,9 +121,14 @@ def callback(node) -> int:
)
split_reasons.append(split_reason)
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if split_reason is not None:
split_reasons.append(split_reason)
# To support Dynamo-generated prims for `parallelize_module`.
# `translate_dtensor_ops` will mark the target as thunder supported if it is a DTensor operation.
if hasattr(node.target, "thunder_supported") and node.target.thunder_supported:
is_thunder_supported, split_reason = True, None
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if split_reason is not None:
split_reasons.append(split_reason)

if prev_value == is_thunder_supported: # We are in the same region.
return partition_cnt
75 changes: 75 additions & 0 deletions thunder/dynamo/utils.py
@@ -1050,3 +1050,78 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement])
raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}")
return sorted_compiled_gm_to_measurement[0].compiled_fn


def translate_dtensor_ops(gm: torch.fx.GraphModule):
# We need this function because:
#
# For a program like:
# ```
# model = nn.Linear(hidden_size, hidden_size, bias=False)
# parallel_model = parallelize_module(model, mesh, {"fc1": ColwiseParallel()})
# model.fc1.weight.requires_grad = False

# # parallelize_module will handle the conversion to DTensor
# i = torch.randn(hidden_size, hidden_size)
# ```
#
# Dynamo captures an FX-Graph like:
# ```
# def forward(self, L_x_: "f32[16, 16]", L_self_modules_fc1_parameters_weight_: "f32[16, 16]"):
# l_x_ = L_x_
# l_self_modules_fc1_parameters_weight_ = L_self_modules_fc1_parameters_weight_
#
# input_tensor: "f32[16, 16]" = torch__dynamo_variables_torch_prim_from_local(l_x_); l_x_ = None
#
# linear: "f32[16, 16]" = torch._C._nn.linear(input_tensor, l_self_modules_fc1_parameters_weight_, None); input_tensor = l_self_modules_fc1_parameters_weight_ = None
#
# outputs: "f32[16, 16]" = torch__dynamo_variables_tensor_prim_redistribute(linear); linear = None
#
# hook_result: "f32[16, 8]" = torch__dynamo_variables_tensor_prim_to_local(outputs); outputs = None
# return (hook_result,)
# ```
# where:
# 1. In the FX Graph, the Tensor Parallel computation is decomposed into primitive operations such as `torch__dynamo_variables_torch_prim_from_local`, `torch__dynamo_variables_tensor_prim_redistribute`, and others.
# 2. It is important to note that these decompositions actually capture (close over) values such as `placements` and other metadata.
# For example, to understand the placements to which the output will be redistributed using `torch__dynamo_variables_tensor_prim_redistribute`,
# we need to use `inspect.getclosurevars(node.target)` to examine the values (like placements) that are captured and used during execution.
# The reference for where this closure is created can be found at:
# https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/_dynamo/variables/tensor.py#L1186-L1210

from thunder.torch.experimental.dtensor_torch_and_prims import (
dtensor_from_local_prim,
dtensor_redistribute_prim,
dtensor_to_local_prim,
)

for node in gm.graph.nodes:
try:
closure_vars = inspect.getclosurevars(node.target)

if "from_local" in node.target.__name__:
mesh = closure_vars.nonlocals["args_as_value"][0]
placements = closure_vars.nonlocals["args_as_value"][1]

def dtensor_from_local_prim_wrapper(x, mesh=mesh, placements=placements):
return dtensor_from_local_prim(x, mesh, placements)

dtensor_from_local_prim_wrapper.thunder_supported = True
node.target = dtensor_from_local_prim_wrapper
if "redistribute" in node.target.__name__:
kwargs = closure_vars.nonlocals["kwargs_as_value"]
placements = kwargs["placements"]

def dtensor_redistribute_prim_wrapper(x, placements=placements):
return dtensor_redistribute_prim(x, placements=placements)

dtensor_redistribute_prim_wrapper.thunder_supported = True
node.target = dtensor_redistribute_prim_wrapper
if "to_local" in node.target.__name__:

def dtensor_to_local_prim_wrapper(x):
return dtensor_to_local_prim(x)

dtensor_to_local_prim_wrapper.thunder_supported = True
node.target = dtensor_to_local_prim_wrapper
except Exception:
pass
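The comments above rely on `inspect.getclosurevars` to recover the mesh and placements that Dynamo's generated prims close over. A minimal standalone sketch of that mechanism, with made-up names and not part of this PR:

```python
import inspect

# Stand-in for the closures Dynamo builds around DTensor prims: the inner
# function closes over `kwargs_as_value`, like
# torch__dynamo_variables_tensor_prim_redistribute does in the captured graph.
def make_prim_wrapper(kwargs_as_value):
    def prim_redistribute(x):
        return x, kwargs_as_value["placements"]

    return prim_redistribute

wrapper = make_prim_wrapper({"placements": ["Shard(0)"]})

# getclosurevars exposes the captured values without calling the function.
print(inspect.getclosurevars(wrapper).nonlocals["kwargs_as_value"])
# {'placements': ['Shard(0)']}
```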
42 changes: 42 additions & 0 deletions thunder/tests/distributed/test_dtensor.py
@@ -15,6 +15,10 @@
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed.tensor.placement_types import Shard, Replicate
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
)

from torch.testing._internal import common_utils

Expand All @@ -23,6 +27,7 @@
from thunder.tests.utils import is_output_differentiable, filter_differentiable_outputs
import thunder.core.dtypes as dtypes
from thunder.core.pytree import tree_flatten
from thunder.dynamo import thunderfx


# NOTE: We run all these similar functions separately
@@ -272,6 +277,43 @@ def fn(x):

torch.testing.assert_close(actual, expected)

@common_utils.parametrize("jit_fn", (thunder.jit, thunderfx), name_fn=lambda jit_fn: jit_fn.__name__)
def test_dtensor_columnwise_parallel(self, jit_fn):
num_devices = self.world_size
mesh = DeviceMesh("cuda", list(range(num_devices)))
dim_size = 16
in_dtensor = torch.randn(dim_size, dim_size, requires_grad=False)
m = torch.nn.Linear(dim_size, dim_size)
m.requires_grad_(False)

parallelized_model = parallelize_module(m, mesh, ColwiseParallel())

# `parallelize_module` sets `requires_grad` to True, so set it to False again.
parallelized_model.requires_grad_(False)

actual = parallelized_model(in_dtensor)
expected = m(in_dtensor)
torch.testing.assert_close(actual, expected)

tmodel = jit_fn(parallelized_model, nv_enable_linear=True)
Collaborator: why is nv_enable_linear True? I'm not objecting, I'm just curious

Collaborator Author: I mainly wanted to verify it for nvFuser multi-device.

if jit_fn == thunder.jit:
# Original error caught by the interpreter:
# File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 444, in _general_jit_getattr_lookaside
# obj.original_value.__dict__,
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
# AttributeError: 'object' object has no attribute '__dict__'. Did you mean: '__dir__'?
with self.assertRaises(thunder.core.interpreter.InterpreterError):
actual = tmodel(in_dtensor)
else:
actual = tmodel(in_dtensor)
torch.testing.assert_close(actual, expected)

if jit_fn == thunderfx:
assert len(tmodel._backend.subgraph_infos) == 1
assert len(tmodel._backend.subgraph_infos[0].thunder_compiled_fns) == 1
assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0

@common_utils.parametrize("executor", tuple(executors_map.keys()))
@common_utils.parametrize(
"input_shardings",
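For context on the shapes in this test (and the `f32[16, 8]` local output in the FX graph shown in `translate_dtensor_ops`): `ColwiseParallel` shards the Linear weight along dim 0, so each rank keeps a slice of the output features. A hedged sketch of that sharding, illustrative only and not part of this PR; it assumes a 2-rank CUDA mesh launched via torchrun:

```python
# Run with: torchrun --nproc-per-node=2 colwise_sharding_sketch.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed.tensor.placement_types import Shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

# nn.Linear(16, 16) stores its weight as [out_features, in_features] = [16, 16].
# ColwiseParallel shards that weight on dim 0, so with 2 ranks each holds an
# [8, 16] shard and the layer's local output is [16, 8], matching hook_result
# in the FX graph above.
weight = torch.randn(16, 16, device="cuda")
sharded_weight = distribute_tensor(weight, mesh, [Shard(0)])
print(sharded_weight.to_local().shape)  # torch.Size([8, 16]) on each rank

dist.destroy_process_group()
```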
35 changes: 35 additions & 0 deletions thunder/torch/experimental/dtensor_proxy.py
@@ -69,6 +69,41 @@ def placements(self):
def device_mesh(self):
return self.spec._o.device_mesh

@staticmethod
def from_local(
x,
mesh,
placements,
*,
run_check: bool = False,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
):
import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims

res = dtensor_torch_and_prims.dtensor_from_local_prim(
x, mesh, placements, run_check=run_check, shape=shape, stride=stride
)
return res

def redistribute(
self,
device_mesh: "Optional[DeviceMesh]" = None,
placements: "Optional[Sequence[Placement]]" = None,
*,
async_op: bool = False,
) -> "DTensor":
import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims

res = dtensor_torch_and_prims.dtensor_redistribute_prim(self, device_mesh, placements, async_op=async_op)
return res

def to_local(self, *, grad_placements: "Optional[Sequence[Placement]]" = None):
import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims

res = dtensor_torch_and_prims.dtensor_to_local_prim(self, grad_placements=grad_placements)
return res

def replace(self, **changes):
r"""Return a copy of the TensorProxy object with new values for the specified fields as given to the constructor as arguments.
Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
95 changes: 95 additions & 0 deletions thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -1,6 +1,7 @@
from functools import partial
from collections.abc import Callable
from enum import auto, Enum
from collections.abc import Sequence
from looseversion import LooseVersion

from thunder.torch import torchsymbol, TensorLike, register_function
@@ -371,6 +372,100 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
)


if torch.distributed.is_available():
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Placement, DeviceMesh

def dtensor_from_local_meta(
x,
mesh,
placements,
*,
run_check: bool = False,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
):
res = run_with_fake_tensor(
DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride
)
from thunder.torch.experimental.dtensor_proxy import proxify_dtensor

res = proxify_dtensor(res)
return res

dtensor_from_local_prim = make_prim("dtensor_from_local", "dtensor_from_local", meta=dtensor_from_local_meta)

dtensor_from_local_prim_impl = pytorchex.register_operator(
"dtensor_from_local", like=dtensor_from_local_prim, fn=DTensor.from_local
)

pytorchex.register_implementation(dtensor_from_local_prim, dtensor_from_local_prim_impl)

@dtensor_torchsymbol(DTensor.from_local, id="dtensor.torch.from_local")
def dtensor_from_local(
x,
mesh,
placements,
*,
run_check: bool = False,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
) -> DTensorProxy | None:
return dtensor_from_local_prim(x, mesh, placements, run_check=run_check, shape=shape, stride=stride)

def dtensor_redistribute_meta(
dtensor,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
) -> DTensorProxy | None:
res = run_with_fake_tensor(DTensor.redistribute, dtensor, device_mesh, placements, async_op=async_op)
from thunder.torch.experimental.dtensor_proxy import proxify_dtensor

res = proxify_dtensor(res)
return res

dtensor_redistribute_prim = make_prim(
"dtensor_redistribute", "dtensor_redistribute", meta=dtensor_redistribute_meta
)

dtensor_redistribute_prim_impl = pytorchex.register_operator(
"dtensor_redistribute", like=dtensor_redistribute_prim, fn=DTensor.redistribute
)

@dtensor_torchsymbol(DTensor.redistribute, id="dtensor.torch.redistribute")
def dtensor_redistribute(
dtensor,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
) -> DTensorProxy | None:
return dtensor_redistribute_prim(dtensor, device_mesh, placements, async_op=async_op)

pytorchex.register_implementation(dtensor_redistribute_prim, dtensor_redistribute_prim_impl)

def dtensor_to_local_meta(dtensor, *, grad_placements: Sequence[Placement] | None = None):
res = run_with_fake_tensor(DTensor.to_local, dtensor, grad_placements=grad_placements)
from thunder.core.proxies import proxy

res = proxy(res)
return res

dtensor_to_local_prim = make_prim("dtensor_to_local", "dtensor_to_local", meta=dtensor_to_local_meta)

dtensor_to_local_prim_impl = pytorchex.register_operator(
"dtensor_to_local", like=dtensor_to_local_prim, fn=DTensor.to_local
)

pytorchex.register_implementation(dtensor_to_local_prim, dtensor_to_local_prim_impl)

@dtensor_torchsymbol(DTensor.to_local, id="dtensor.torch.to_local")
def dtensor_to_local(dtensor, *, grad_placements: Sequence[Placement] | None = None) -> DTensorProxy | None:
return dtensor_to_local_prim(dtensor, grad_placements=grad_placements)


expand = partial(expand_impl, broadcast_prim=dtensor_broadcast_in_dim_prim)
maybe_broadcast = partial(maybe_broadcast_impl, expand_fn=expand)

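The three prims registered above (`dtensor_from_local`, `dtensor_redistribute`, `dtensor_to_local`) mirror the eager `DTensor` entry points. A minimal sketch of those underlying calls, illustrative only and not part of this PR; it assumes a 2-rank CUDA setup launched via torchrun:

```python
# Run with: torchrun --nproc-per-node=2 dtensor_roundtrip_sketch.py
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.placement_types import Shard, Replicate

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

local = torch.randn(8, 16, device="cuda")
dt = DTensor.from_local(local, mesh, [Shard(0)])  # each rank's [8, 16] shard forms one logical [16, 16] tensor
dt = dt.redistribute(mesh, [Replicate()])         # change placements; runs the required collective
out = dt.to_local()                               # back to a plain torch.Tensor: a full [16, 16] replica per rank

dist.destroy_process_group()
```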