Skip to content

Commit 5c36e99

Browse files
authored
Drop torch.compile fullgraph test (#19166)
1 parent 3b1643c commit 5c36e99

File tree

1 file changed

+0
-39
lines changed

1 file changed

+0
-39
lines changed

tests/tests_fabric/test_fabric.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15-
import sys
1615
from re import escape
1716
from unittest import mock
1817
from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
@@ -35,7 +34,6 @@
3534
)
3635
from lightning.fabric.strategies.strategy import _Sharded
3736
from lightning.fabric.utilities.exceptions import MisconfigurationException
38-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
3937
from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
4038
from lightning.fabric.utilities.warnings import PossibleUserWarning
4139
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
@@ -1204,40 +1202,3 @@ def test_verify_launch_called():
12041202
fabric.launch()
12051203
assert fabric._launched
12061204
fabric._validate_launched()
1207-
1208-
1209-
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
1210-
@RunIf(dynamo=True)
1211-
@pytest.mark.parametrize(
1212-
"kwargs",
1213-
[
1214-
{},
1215-
pytest.param({"precision": "16-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")),
1216-
pytest.param({"precision": "64-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")),
1217-
],
1218-
)
1219-
def test_fabric_with_torchdynamo_fullgraph(kwargs):
1220-
class MyModel(torch.nn.Module):
1221-
def __init__(self):
1222-
super().__init__()
1223-
self.l = torch.nn.Linear(10, 10)
1224-
1225-
def forward(self, x):
1226-
# forward gets compiled
1227-
assert torch._dynamo.is_compiling()
1228-
return self.l(x)
1229-
1230-
def fn(model, x):
1231-
assert torch._dynamo.is_compiling()
1232-
a = x * 10
1233-
return model(a)
1234-
1235-
fabric = Fabric(devices=1, accelerator="cpu", **kwargs)
1236-
model = MyModel()
1237-
fmodel = fabric.setup(model)
1238-
# we are compiling a function that calls model.forward() inside
1239-
cfn = torch.compile(fn, fullgraph=True)
1240-
x = torch.randn(10, 10, device=fabric.device)
1241-
# pass the fabric wrapped model to the compiled function, so that it gets compiled too
1242-
out = cfn(fmodel, x)
1243-
assert isinstance(out, torch.Tensor)

0 commit comments

Comments
 (0)