|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
15 | | -import sys |
16 | 15 | from re import escape |
17 | 16 | from unittest import mock |
18 | 17 | from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call |
|
35 | 34 | ) |
36 | 35 | from lightning.fabric.strategies.strategy import _Sharded |
37 | 36 | from lightning.fabric.utilities.exceptions import MisconfigurationException |
38 | | -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 |
39 | 37 | from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything |
40 | 38 | from lightning.fabric.utilities.warnings import PossibleUserWarning |
41 | 39 | from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer |
@@ -1204,40 +1202,3 @@ def test_verify_launch_called(): |
1204 | 1202 | fabric.launch() |
1205 | 1203 | assert fabric._launched |
1206 | 1204 | fabric._validate_launched() |
1207 | | - |
1208 | | - |
1209 | | -@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1") |
1210 | | -@RunIf(dynamo=True) |
1211 | | -@pytest.mark.parametrize( |
1212 | | - "kwargs", |
1213 | | - [ |
1214 | | - {}, |
1215 | | - pytest.param({"precision": "16-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")), |
1216 | | - pytest.param({"precision": "64-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")), |
1217 | | - ], |
1218 | | -) |
1219 | | -def test_fabric_with_torchdynamo_fullgraph(kwargs): |
1220 | | - class MyModel(torch.nn.Module): |
1221 | | - def __init__(self): |
1222 | | - super().__init__() |
1223 | | - self.l = torch.nn.Linear(10, 10) |
1224 | | - |
1225 | | - def forward(self, x): |
1226 | | - # forward gets compiled |
1227 | | - assert torch._dynamo.is_compiling() |
1228 | | - return self.l(x) |
1229 | | - |
1230 | | - def fn(model, x): |
1231 | | - assert torch._dynamo.is_compiling() |
1232 | | - a = x * 10 |
1233 | | - return model(a) |
1234 | | - |
1235 | | - fabric = Fabric(devices=1, accelerator="cpu", **kwargs) |
1236 | | - model = MyModel() |
1237 | | - fmodel = fabric.setup(model) |
1238 | | - # we are compiling a function that calls model.forward() inside |
1239 | | - cfn = torch.compile(fn, fullgraph=True) |
1240 | | - x = torch.randn(10, 10, device=fabric.device) |
1241 | | - # pass the fabric wrapped model to the compiled function, so that it gets compiled too |
1242 | | - out = cfn(fmodel, x) |
1243 | | - assert isinstance(out, torch.Tensor) |
0 commit comments