
Commit d95e2fd

fixed unit tests
Signed-off-by: Suguna Velury <[email protected]>
1 parent f81c370 commit d95e2fd

2 files changed: +114 -113 lines changed

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 1 deletion
@@ -26,7 +26,11 @@
 
 import torch
 import torch.nn as nn
-from accelerate import Accelerator
+
+try:
+    from accelerate import Accelerator
+except ImportError:  # pragma: no cover
+    Accelerator = None
 from safetensors.torch import save_file
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
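The guard above makes accelerate an optional dependency at import time. A minimal sketch of how the fallback is typically consumed downstream (an assumed pattern, not an excerpt from unified_export_hf.py) is to fail lazily, only when the Accelerator-backed path is actually taken:

def _require_accelerator():
    # Hypothetical helper: raise only when the accelerate-dependent export path runs.
    if Accelerator is None:
        raise ImportError(
            "accelerate is required for this export path; install it with `pip install accelerate`."
        )
    return Accelerator()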

tests/gpu/torch/export/test_fsdp2_export.py

Lines changed: 109 additions & 112 deletions
@@ -30,78 +30,73 @@
 from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, patch_fsdp_mp_dtypes
 
 
-@pytest.fixture(autouse=True)
-def patch_fsdp_dtypes():
-    """Automatically patch FSDP mixed precision dtypes for all tests in this module."""
-    with patch_fsdp_mp_dtypes():
-        yield
-
-
 def _update_weight_test(rank, size):
     """Test fsdp2 weight update context for weight update -> only value changed"""
     from torch.distributed._composable.fsdp import fully_shard
 
-    # Define and shard model
-    model = ToyModel(dims=[4, 4], bias=False).to("cuda")
+    with patch_fsdp_mp_dtypes():
+        # Define and shard model
+        model = ToyModel(dims=[4, 4], bias=False).to("cuda")
 
-    assert not torch.equal(
-        model.linears.weight.data,
-        torch.zeros(4, 4).to(model.linears.weight.device).to(model.linears.weight.dtype),
-    )
+        assert not torch.equal(
+            model.linears.weight.data,
+            torch.zeros(4, 4).to(model.linears.weight.device).to(model.linears.weight.dtype),
+        )
 
-    fully_shard(model.linears)
-    fully_shard(model)
+        fully_shard(model.linears)
+        fully_shard(model)
 
-    torch.distributed.barrier()
+        torch.distributed.barrier()
 
-    for name, module in model.named_modules():
-        if "linears" in name:
-            with fsdp2_aware_weight_update(model, module):
-                module.weight.data = torch.zeros_like(module.weight.data)
+        for name, module in model.named_modules():
+            if "linears" in name:
+                with fsdp2_aware_weight_update(model, module):
+                    module.weight.data = torch.zeros_like(module.weight.data)
 
-    torch.distributed.barrier()
-    model.linears.unshard()
+        torch.distributed.barrier()
+        model.linears.unshard()
 
-    # Check if weights are as expected after unshard
-    for param in model.parameters():
-        assert torch.allclose(
-            torch.zeros(4, 4).to(param.data.device).to(param.data.dtype), param.data
-        )
+        # Check if weights are as expected after unshard
+        for param in model.parameters():
+            assert torch.allclose(
+                torch.zeros(4, 4).to(param.data.device).to(param.data.dtype), param.data
+            )
 
-    # Check if forward pass is as expected
-    model.linears.reshard()
-    output = model(torch.randn(4, 4).to(model.linears.weight.device))
-    assert torch.allclose(torch.zeros(4, 4).to(output.device).to(output.dtype), output)
+        # Check if forward pass is as expected
+        model.linears.reshard()
+        output = model(torch.randn(4, 4).to(model.linears.weight.device))
+        assert torch.allclose(torch.zeros(4, 4).to(output.device).to(output.dtype), output)
 
 
 def _compress_weight_test(rank, size):
     """Test fsdp2 weight update context for weight compression -> only value, shape and dtype changed"""
     from torch.distributed._composable.fsdp import fully_shard
 
-    # Define and shard model
-    model = ToyModel(dims=[6, 6], bias=False).to("cuda")
+    with patch_fsdp_mp_dtypes():
+        # Define and shard model
+        model = ToyModel(dims=[6, 6], bias=False).to("cuda")
 
-    assert not torch.equal(
-        model.linears.weight.data,
-        torch.zeros(6, 6).to(model.linears.weight.device).to(model.linears.weight.dtype),
-    )
+        assert not torch.equal(
+            model.linears.weight.data,
+            torch.zeros(6, 6).to(model.linears.weight.device).to(model.linears.weight.dtype),
+        )
 
-    fully_shard(model.linears)
-    fully_shard(model)
-    torch.distributed.barrier()
+        fully_shard(model.linears)
+        fully_shard(model)
+        torch.distributed.barrier()
 
-    for name, module in model.named_modules():
-        if "linears" in name:
-            with fsdp2_aware_weight_update(model, module):
-                module.weight.data = (
-                    torch.zeros(2, 2).to(torch.float8_e4m3fn).to(module.weight.data.device)
-                )
+        for name, module in model.named_modules():
+            if "linears" in name:
+                with fsdp2_aware_weight_update(model, module):
+                    module.weight.data = (
+                        torch.zeros(2, 2).to(torch.float8_e4m3fn).to(module.weight.data.device)
+                    )
 
-    torch.distributed.barrier()
-    model.linears.unshard()
-    # Check if weights are as expected after unshard
-    for param in model.parameters():
-        assert param.data.dtype == torch.float8_e4m3fn
+        torch.distributed.barrier()
+        model.linears.unshard()
+        # Check if weights are as expected after unshard
+        for param in model.parameters():
+            assert param.data.dtype == torch.float8_e4m3fn
 
 
 def _compare_parameters_and_buffers(model1, model2):
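The autouse patch_fsdp_dtypes fixture is removed because these _*_test(rank, size) helpers run in worker processes spawned for the distributed test; a fixture applied in the parent pytest process does not carry over into those workers, so each helper now enters patch_fsdp_mp_dtypes() itself. A minimal sketch of the scoping issue, assuming a plain torch.multiprocessing.spawn launcher rather than the repo's actual test harness:

import torch.multiprocessing as mp

from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes


def _worker(rank, world_size):
    # Runs in a freshly spawned process: any autouse fixture (or monkeypatch)
    # set up by pytest in the parent process is not active here, so the dtype
    # patch must be applied inside the worker itself.
    with patch_fsdp_mp_dtypes():
        ...  # build, shard, and exercise the model on this rank


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)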
@@ -126,97 +121,99 @@ def _fuse_layers(rank, size, quant_config):
 
     from torch.distributed._composable.fsdp import fully_shard
 
-    # Initialize model
-    model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
-    model.eval()
-    non_fsdp_model.eval()
+    with patch_fsdp_mp_dtypes():
+        # Initialize model
+        model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
+        model.eval()
+        non_fsdp_model.eval()
 
-    _compare_parameters_and_buffers(model, non_fsdp_model)
+        _compare_parameters_and_buffers(model, non_fsdp_model)
 
-    # Create calibration data ONCE
-    calib_data = torch.randn(1, 32, device="cuda")
+        # Create calibration data ONCE
+        calib_data = torch.randn(1, 32, device="cuda")
 
-    def calib_fn(x):
-        return x(calib_data)
+        def calib_fn(x):
+            return x(calib_data)
 
-    # Shard model
-    fully_shard(model)
-    torch.distributed.barrier()
+        # Shard model
+        fully_shard(model)
+        torch.distributed.barrier()
 
-    # Quantize model
-    mtq.quantize(model, quant_config, calib_fn)
-    mtq.quantize(non_fsdp_model, quant_config, calib_fn)
+        # Quantize model
+        mtq.quantize(model, quant_config, calib_fn)
+        mtq.quantize(non_fsdp_model, quant_config, calib_fn)
 
-    torch.distributed.barrier()
+        torch.distributed.barrier()
 
-    model.apply_embed = True
-    non_fsdp_model.apply_embed = True
+        model.apply_embed = True
+        non_fsdp_model.apply_embed = True
 
-    requantize_resmooth_fused_llm_layers(model)
-    requantize_resmooth_fused_llm_layers(non_fsdp_model)
+        requantize_resmooth_fused_llm_layers(model)
+        requantize_resmooth_fused_llm_layers(non_fsdp_model)
 
-    torch.distributed.barrier()
+        torch.distributed.barrier()
 
-    # Unshard model
-    model.unshard()
+        # Unshard model
+        model.unshard()
 
-    _compare_parameters_and_buffers(model, non_fsdp_model)
+        _compare_parameters_and_buffers(model, non_fsdp_model)
 
 
 def _export_quantized_weight_test(rank, size, quant_config):
     import copy
 
     from torch.distributed._composable.fsdp import fully_shard
 
-    # Initialize model
-    model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
-    non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
-    model.eval()
-    non_fsdp_model.eval()
-    _compare_parameters_and_buffers(model, non_fsdp_model)
+    with patch_fsdp_mp_dtypes():
+        # Initialize model
+        model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
+        model.eval()
+        non_fsdp_model.eval()
+        _compare_parameters_and_buffers(model, non_fsdp_model)
 
-    # Create calibration data ONCE
-    calib_data = torch.randn(1, 32, device="cuda")
+        # Create calibration data ONCE
+        calib_data = torch.randn(1, 32, device="cuda")
 
-    def calib_fn(x):
-        return x(calib_data)
+        def calib_fn(x):
+            return x(calib_data)
 
-    # Shard model
-    fully_shard(model)
-    torch.distributed.barrier()
+        # Shard model
+        fully_shard(model)
+        torch.distributed.barrier()
 
-    # Quantize model
-    mtq.quantize(model, quant_config, calib_fn)
-    mtq.quantize(non_fsdp_model, quant_config, calib_fn)
+        # Quantize model
+        mtq.quantize(model, quant_config, calib_fn)
+        mtq.quantize(non_fsdp_model, quant_config, calib_fn)
 
-    torch.distributed.barrier()
+        torch.distributed.barrier()
 
-    model.apply_embed = True
-    non_fsdp_model.apply_embed = True
+        model.apply_embed = True
+        non_fsdp_model.apply_embed = True
 
-    requantize_resmooth_fused_llm_layers(model)
-    requantize_resmooth_fused_llm_layers(non_fsdp_model)
+        requantize_resmooth_fused_llm_layers(model)
+        requantize_resmooth_fused_llm_layers(non_fsdp_model)
 
-    torch.distributed.barrier()
+        torch.distributed.barrier()
 
-    for name, sub_module in model.named_modules():
-        if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+        for name, sub_module in model.named_modules():
+            if is_quantlinear(sub_module):
+                with fsdp2_aware_weight_update(model, sub_module):
+                    _export_quantized_weight(sub_module, torch.float16)
 
-    for name, sub_module in non_fsdp_model.named_modules():
-        if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+        for name, sub_module in non_fsdp_model.named_modules():
+            if is_quantlinear(sub_module):
+                with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
+                    _export_quantized_weight(sub_module, torch.float16)
 
-    torch.distributed.barrier()
-    # Unshard model
-    model.unshard()
+        torch.distributed.barrier()
+        # Unshard model
+        model.unshard()
 
-    _compare_parameters_and_buffers(model, non_fsdp_model)
+        _compare_parameters_and_buffers(model, non_fsdp_model)
 
 
 @pytest.mark.parametrize("device_count", [2])
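For context, the parametrized tests below this hunk dispatch the per-rank helpers across device_count GPUs. A rough, hypothetical driver is sketched here; the real test module presumably uses a shared multi-process test utility, and the setup below is just standard PyTorch distributed boilerplate, not repo code:

import os

import torch
import torch.multiprocessing as mp


def _run(rank, world_size):
    # Standard single-node process-group setup; address and port are placeholders.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
    _update_weight_test(rank, world_size)  # or any of the per-rank helpers above
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_run, args=(2,), nprocs=2)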
