Commit 2cfe685

[example] add vit missing functions (#2154)
1 parent a7d95b7

2 files changed: 61 additions, 4 deletions

examples/images/vit/test_vit.py

Lines changed: 34 additions & 2 deletions
@@ -1,23 +1,55 @@
+import os
+import random
 from functools import partial

+import numpy as np
 import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from utils.util import set_seed, tensor_equal, tensor_shard_equal
 from vit import get_training_components

 import colossalai
+from colossalai.context import ParallelMode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext


+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def tensor_equal(A, B):
+    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard: the shapes differ along exactly one dimension
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise NotImplementedError
+
+
 # Only Linear layers use a 1d_row split, because Linear weights are transposed during computation.
 # All other layers use a 1d_col split.
 # Layernorm is not supported for now.
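The tensor_shard_equal helper added here compares a full tensor against the 1D shard a given rank holds: it locates the single dimension where the two shapes differ, chunks the full tensor along it, and compares the rank-local chunk. Below is a minimal single-process sketch of that logic in plain PyTorch; the world_size and rank values are hypothetical stand-ins for what gpc would return inside the test.

    import torch

    world_size, rank = 4, 1                      # hypothetical stand-ins for gpc's values
    full = torch.randn(8, 16)                    # the unsharded reference tensor
    shard = full.chunk(world_size, dim=0)[rank]  # what rank 1 holds under a 1d_row split

    # Find the single dimension where the shapes differ (dim 0 here: 8 vs. 2),
    # then compare the rank-local chunk against the shard, as tensor_shard_equal does.
    dim = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape)).item()
    assert torch.allclose(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)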

examples/images/vit/vit.py

Lines changed: 27 additions & 2 deletions
@@ -1,9 +1,34 @@
+from abc import ABC, abstractmethod
+
 import torch
 import torch.nn as nn
-from utils.dummy_data_generator import DummyDataGenerator
+from transformers import ViTConfig, ViTForImageClassification

 from colossalai.utils.cuda import get_current_device
-from transformers import ViTConfig, ViTForImageClassification
+
+
+class DummyDataGenerator(ABC):
+
+    def __init__(self, length=10):
+        self.length = length
+
+    @abstractmethod
+    def generate(self):
+        pass
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length


 class DummyDataLoader(DummyDataGenerator):
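DummyDataGenerator is a small iterator base class: a subclass only implements generate(), and __iter__/__next__/__len__ turn it into a finite iterable of synthetic batches. A minimal sketch of a concrete subclass follows; the batch shape and label count are illustrative assumptions, not taken from the example code.

    import torch

    class RandomViTBatch(DummyDataGenerator):
        # Hypothetical subclass for illustration only.

        def generate(self):
            pixel_values = torch.rand(4, 3, 224, 224)   # a batch of 4 fake RGB images
            labels = torch.randint(0, 10, (4,))         # 10 classes assumed
            return pixel_values, labels

    for pixel_values, labels in RandomViTBatch(length=2):
        print(pixel_values.shape, labels.shape)         # yields exactly `length` batches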
