Skip to content

Commit b50d46e

Browse files
committed
Fix format
1 parent 7183944 commit b50d46e

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

examples/fabric/tensor_parallel/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4-
from data import RandomTokenDataset
54
from lightning.fabric.strategies import ModelParallelStrategy
65
from model import ModelArgs, Transformer
76
from parallelism import parallelize
87
from torch.distributed.tensor.parallel import loss_parallel
98
from torch.utils.data import DataLoader
109

10+
from data import RandomTokenDataset
11+
1112

1213
def train():
1314
strategy = ModelParallelStrategy(

examples/pytorch/tensor_parallel/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4-
from data import RandomTokenDataset
54
from lightning.pytorch.strategies import ModelParallelStrategy
65
from model import ModelArgs, Transformer
76
from parallelism import parallelize
87
from torch.distributed.tensor.parallel import loss_parallel
98
from torch.utils.data import DataLoader
109

10+
from data import RandomTokenDataset
11+
1112

1213
class Llama3(L.LightningModule):
1314
def __init__(self):

0 commit comments

Comments
 (0)