-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_basic.py
More file actions
81 lines (68 loc) · 2.62 KB
/
test_basic.py
File metadata and controls
81 lines (68 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from distconv import DCTensor, DistConvDDP, ParallelStrategy
from utils import cleanup_parallel_strategy, fp32_allclose
@pytest.fixture(scope="module")
def parallel_strategy(device: torch.device):
    """
    Provide one parallel strategy shared by the tests of this module.

    Args:
        device (torch.device): Torch device; its ``type`` is forwarded as the
            strategy's ``device_type``.

    Yields:
        ParallelStrategy: Strategy constructed with ``num_shards=4``.
    """
    strategy = ParallelStrategy(num_shards=4, device_type=device.type)
    yield strategy
    # Teardown: executed after the tests using this fixture have finished.
    cleanup_parallel_strategy(strategy)
def generate_configs():
    """
    Build the parametrization used by the distributed convolution test.

    Returns:
        tuple: The comma-separated argument names followed by the list of
            ``(ndims, shard_dim, kernel_size)`` cases, in the form expected
            when unpacked into ``pytest.mark.parametrize``.
    """
    dimensionalities = (1, 2, 3)
    kernel_sizes = (1, 3, 5)
    # For an N-dimensional case every index in range(N) is used as shard_dim,
    # and each (ndims, shard_dim) pair is combined with every kernel size.
    cases = [
        (ndims, shard_dim, kernel_size)
        for ndims in dimensionalities
        for shard_dim in range(ndims)
        for kernel_size in kernel_sizes
    ]
    return "ndims,shard_dim,kernel_size", cases
@pytest.mark.parametrize(*generate_configs())
def test_basic(
    parallel_strategy: ParallelStrategy,
    ndims: int,
    shard_dim: int,
    kernel_size: int,
    device: torch.device,
):
    """
    Test distributed convolution with different number of dimensions and shard dimensions.

    Checks the output and gradients of the distributed convolution against the
    non-distributed convolution.

    Args:
        parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
        ndims (int): Number of dimensions for the convolution (1, 2, or 3).
        shard_dim (int): Dimension along which the tensor is sharded.
        kernel_size (int): Size of the convolution kernel.
        device (torch.device): Torch device to run test with.
    """
    # Set the shard dimension for the parallel strategy. The input below has
    # two leading axes of size 1 and 4 before the `ndims` axes of size 64, so
    # the offset of 2 maps `shard_dim` onto one of those trailing axes.
    parallel_strategy.shard_dim = 2 + shard_dim

    # Initialize the input tensor and convolution layer
    shape = [1, 4] + [64] * ndims
    x = torch.randn(*shape, device=device, requires_grad=True)
    conv_class = getattr(nn, f"Conv{ndims}d")
    conv = conv_class(4, 8, kernel_size=kernel_size, padding=kernel_size // 2).to(
        device
    )

    # Perform forward and backward pass for reference (non-distributed) convolution
    conv.zero_grad()
    ref_y = conv(x)
    ref_y.square().mean().backward()
    # Snapshot the reference gradients instead of keeping aliases of the live
    # `.grad` tensors. The same `conv` (and `x`) are reused for the distributed
    # pass below; if `zero_grad()` zeroes in place (set_to_none=False) or the
    # second backward accumulates into the same tensor, an alias would end up
    # holding the distributed result and the comparison would pass vacuously.
    ref_x_grad = x.grad.detach().clone()
    ref_conv_grad = conv.weight.grad.detach().clone()

    # Perform forward and backward pass for distributed convolution
    conv.zero_grad()
    ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
    dcx = DCTensor.distribute(x, parallel_strategy)
    dcy = ddp_conv(dcx)
    ddpy = dcy.to_ddp()
    ddpy.square().mean().backward()
    x_grad = dcx.grad.to_ddp()
    dc_conv_grad = conv.weight.grad

    # Validate the results. The output is compared against the reference on
    # rank 0 only; every other rank is expected to hold an empty tensor.
    if dist.get_rank() == 0:
        assert fp32_allclose(ref_y, ddpy)
    else:
        assert ddpy.numel() == 0
    assert fp32_allclose(ref_x_grad, x_grad)
    assert fp32_allclose(ref_conv_grad, dc_conv_grad)