Commit d7040e6

Revert "[dynamo][guards] 1/N Guard selectively for DTensor (pytorch#165824)"
This reverts commit ee7434b. Reverted pytorch#165824 on behalf of https://github.com/anijain2305 due to an internal job failure ([comment](pytorch#165824 (comment))).
1 parent 35f3572 commit d7040e6

File tree

4 files changed: +14 -93 lines changed

  test/distributed/tensor/test_dtensor_compile.py
  torch/_dynamo/guards.py
  torch/_dynamo/variables/builder.py
  torch/distributed/tensor/_api.py

test/distributed/tensor/test_dtensor_compile.py

Lines changed: 0 additions & 19 deletions
@@ -464,25 +464,6 @@ def g(x):
         run(g, 64, 8)
         self.assertEqual(cnt.frame_count, 2)
 
-    def test_dtensor_requires_grad_recompile(self):
-        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        @torch.compile(backend=cnt, fullgraph=True)
-        def f(x):
-            y = x * x
-            return y.to_local()
-
-        full_x = torch.randn(8, 8, requires_grad=False)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        full_x = torch.randn(8, 8, requires_grad=True)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        self.assertEqual(cnt.frame_count, 2)
-
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
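For orientation, the removed test verified that flipping requires_grad on the DTensor input forces a second compilation. Below is a minimal single-process sketch of the same recompile-counting pattern, using a plain tensor so no DeviceMesh or process group is needed; the "eager" backend string is just an example choice, not what the deleted test used.

    import torch
    import torch._dynamo.testing

    # Counts how many frames Dynamo compiles; same helper the removed test used.
    cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")

    @torch.compile(backend=cnt, fullgraph=True)
    def f(x):
        return x * x

    f(torch.randn(8, 8, requires_grad=False))
    # Changing requires_grad invalidates the installed tensor guards,
    # so this call triggers a recompile.
    f(torch.randn(8, 8, requires_grad=True))
    assert cnt.frame_count == 2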

torch/_dynamo/guards.py

Lines changed: 0 additions & 13 deletions
@@ -2150,19 +2150,6 @@ def metadata_checker(x: Any) -> bool:
             metadata_checker, get_verbose_code_parts(global_name, guard)
         )
 
-    def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
-        # Copied from DTensor __metadata_guard__
-        # TODO - Consider moving this to C++ if stable
-        value = deepcopy(self.get(guard.name))
-
-        def guard_fn(x: Any) -> bool:
-            return x._check_equals(value, skip_shapes=True)
-
-        code = f"__dtensor_spec_{id(guard_fn)}"
-        self.get_guard_manager(guard).add_lambda_guard(
-            guard_fn, get_verbose_code_parts(code, guard)
-        )
-
     def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
         ref = self.arg_ref(guard)
         val = self.get(guard.name)
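For context, the removed DTENSOR_SPEC_MATCH guard followed a snapshot-and-compare closure pattern: deep-copy the spec when the guard is installed, then re-check the live object against that snapshot on every call via _check_equals(..., skip_shapes=True). A plain-Python sketch of that pattern follows; Spec and make_spec_guard are hypothetical stand-ins, not Dynamo or DTensor APIs.

    from copy import deepcopy
    from dataclasses import dataclass

    @dataclass
    class Spec:  # hypothetical stand-in for DTensorSpec
        placements: tuple

        def _check_equals(self, other, skip_shapes=False):
            # Toy comparison; the real DTensorSpec check is richer.
            return self.placements == other.placements

    def make_spec_guard(spec):
        snapshot = deepcopy(spec)  # frozen at guard-installation time

        def guard_fn(live_spec):
            return live_spec._check_equals(snapshot, skip_shapes=True)

        return guard_fn

    guard = make_spec_guard(Spec(placements=("Shard(0)",)))
    print(guard(Spec(placements=("Shard(0)",))))   # True  -> reuse compiled code
    print(guard(Spec(placements=("Replicate",))))  # False -> would recompile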

torch/_dynamo/variables/builder.py

Lines changed: 14 additions & 59 deletions
@@ -2229,70 +2229,25 @@ def wrap_tensor(self, value: torch.Tensor):
         if isinstance(source, GradSource) and is_from_optimizer_source(source):
             guard_type = GuardBuilder.NOT_NONE_MATCH
 
-        is_dtensor = torch.distributed.is_available() and isinstance(
-            value, torch.distributed.tensor.DTensor
-        )
-        if not is_dtensor:
-            # We guard on the _local_tensor and the _spec, and therefore we dont
-            # have to guard on the outer DTensor.
-            self.install_guards(
-                functools.partial(
-                    guard_type,
-                    value=(
-                        value
-                        if isinstance(source, NumpyTensorSource)
-                        else TensorWeakRef(value)
-                    ),
-                )
+        self.install_guards(
+            functools.partial(
+                guard_type,
+                value=(
+                    value
+                    if isinstance(source, NumpyTensorSource)
+                    else TensorWeakRef(value)
+                ),
             )
+        )
 
         # We install TYPE_MATCH guards for traceable wrapper subclass object,
         # and recursively install corresponding guard for each inner attribute.
         if is_traceable_wrapper_subclass(value):
-            # Tensor subclass guards are very expensive because they are
-            # implemented in Python. Since DTensor is PyTorch-maintained class,
-            # we can skip a lot of these guards.
-            if is_dtensor:
-                self.install_guards(GuardBuilder.TYPE_MATCH)
-
-                # The inner tensor name is always _local_tensor. If its not, we
-                # raise assertion to update the check accordingly.
-                inner_tensor_name = value.__tensor_flatten__()[0][0]
-                if inner_tensor_name != "_local_tensor":
-                    raise RuntimeError(
-                        "Expecting Dtensor inner tensor name to be _local_tensor"
-                    )
-
-                # Now selectively guard on the flattening context
-                flattening_ctx = value.__tensor_flatten__()[1]
-                # This is supposed to be (self._spec, self.requires_grad)
-                if not (
-                    len(flattening_ctx) == 2
-                    and flattening_ctx[0] == value._spec
-                    and flattening_ctx[1] == value.requires_grad
-                ):
-                    # If not, raise an assertion to update to the new guards
-                    raise RuntimeError(
-                        "Expecting Dtensor flattening ctx to be _spec, requires_grad"
-                    )
-                # Guard on the dtensor spec
-                install_guard(
-                    AttrSource(self.source, "_spec").make_guard(
-                        GuardBuilder.DTENSOR_SPEC_MATCH
-                    )
-                )
-                # Move this to C++
-                install_guard(
-                    AttrSource(self.source, "requires_grad").make_guard(
-                        GuardBuilder.EQUALS_MATCH
-                    )
-                )
-            else:
-                self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
-                self.install_guards(GuardBuilder.TYPE_MATCH)
-                install_guard(
-                    SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
-                )
+            self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
+            self.install_guards(GuardBuilder.TYPE_MATCH)
+            install_guard(
+                SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
+            )
 
         attrs, _ = value.__tensor_flatten__()
         for attr in attrs:
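The code above leans on the traceable wrapper subclass contract: __tensor_flatten__() returns a pair of (inner attribute names, flattening context), which is what the removed DTensor branch unpacked to find _local_tensor and (spec, requires_grad). Here is a minimal, self-contained sketch of that contract with a toy wrapper, not DTensor; the class and its dispatch handling are illustrative assumptions.

    import torch
    from torch.utils._pytree import tree_map_only

    class WrapperTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, inner):
            return torch.Tensor._make_wrapper_subclass(
                cls, inner.shape, dtype=inner.dtype, device=inner.device,
                requires_grad=inner.requires_grad,
            )

        def __init__(self, inner):
            self._inner = inner

        def __tensor_flatten__(self):
            # (inner attribute names, opaque context): the pair the builder
            # code unpacks via `attrs, _ = value.__tensor_flatten__()`.
            return ["_inner"], (self.requires_grad,)

        @staticmethod
        def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
            return WrapperTensor(inner_tensors["_inner"])

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # Unwrap inputs, run the op on the inner tensors, and rewrap outputs.
            args = tree_map_only(WrapperTensor, lambda t: t._inner, args)
            kwargs = tree_map_only(WrapperTensor, lambda t: t._inner, kwargs)
            return tree_map_only(torch.Tensor, WrapperTensor, func(*args, **kwargs))

    w = WrapperTensor(torch.randn(4))
    attrs, ctx = w.__tensor_flatten__()
    print(attrs, ctx)  # ['_inner'] (False,)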

torch/distributed/tensor/_api.py

Lines changed: 0 additions & 2 deletions
@@ -671,8 +671,6 @@ def __get_tensor_shard__(self, index):
     def __metadata_guard__(
         cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
     ) -> bool:
-        # TODO - delete this - This is now unused after the PR -
-        # https://github.com/pytorch/pytorch/pull/165824
         orig_spec, orig_requires_grad = orig
         other_spec, other_requires_grad = other
         return (
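The diff truncates the body at `return (`. Based on the DTENSOR_SPEC_MATCH guard removed above, which was noted as "Copied from DTensor __metadata_guard__", the comparison plausibly reduces to the standalone sketch below; this is an illustrative reconstruction, not the file's verbatim contents.

    # Hypothetical sketch: each metadata tuple is (DTensorSpec, requires_grad);
    # the guard passes when the specs match ignoring shapes and the
    # requires_grad flags agree.
    def metadata_guard(orig, other):
        orig_spec, orig_requires_grad = orig
        other_spec, other_requires_grad = other
        return (
            orig_spec._check_equals(other_spec, skip_shapes=True)
            and orig_requires_grad == other_requires_grad
        )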
