|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | import os
|
15 |
| -import sys |
16 | 15 | from re import escape
|
17 | 16 | from unittest import mock
|
18 | 17 | from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
|
|
35 | 34 | )
|
36 | 35 | from lightning.fabric.strategies.strategy import _Sharded
|
37 | 36 | from lightning.fabric.utilities.exceptions import MisconfigurationException
|
38 |
| -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 |
39 | 37 | from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
|
40 | 38 | from lightning.fabric.utilities.warnings import PossibleUserWarning
|
41 | 39 | from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
|
@@ -1204,40 +1202,3 @@ def test_verify_launch_called():
|
1204 | 1202 | fabric.launch()
|
1205 | 1203 | assert fabric._launched
|
1206 | 1204 | fabric._validate_launched()
|
1207 |
| - |
1208 |
| - |
1209 |
| -@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1") |
1210 |
| -@RunIf(dynamo=True) |
1211 |
| -@pytest.mark.parametrize( |
1212 |
| - "kwargs", |
1213 |
| - [ |
1214 |
| - {}, |
1215 |
| - pytest.param({"precision": "16-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")), |
1216 |
| - pytest.param({"precision": "64-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")), |
1217 |
| - ], |
1218 |
| -) |
1219 |
| -def test_fabric_with_torchdynamo_fullgraph(kwargs): |
1220 |
| - class MyModel(torch.nn.Module): |
1221 |
| - def __init__(self): |
1222 |
| - super().__init__() |
1223 |
| - self.l = torch.nn.Linear(10, 10) |
1224 |
| - |
1225 |
| - def forward(self, x): |
1226 |
| - # forward gets compiled |
1227 |
| - assert torch._dynamo.is_compiling() |
1228 |
| - return self.l(x) |
1229 |
| - |
1230 |
| - def fn(model, x): |
1231 |
| - assert torch._dynamo.is_compiling() |
1232 |
| - a = x * 10 |
1233 |
| - return model(a) |
1234 |
| - |
1235 |
| - fabric = Fabric(devices=1, accelerator="cpu", **kwargs) |
1236 |
| - model = MyModel() |
1237 |
| - fmodel = fabric.setup(model) |
1238 |
| - # we are compiling a function that calls model.forward() inside |
1239 |
| - cfn = torch.compile(fn, fullgraph=True) |
1240 |
| - x = torch.randn(10, 10, device=fabric.device) |
1241 |
| - # pass the fabric wrapped model to the compiled function, so that it gets compiled too |
1242 |
| - out = cfn(fmodel, x) |
1243 |
| - assert isinstance(out, torch.Tensor) |
0 commit comments