Commit 5f898a1

add some basic tests
1 parent 341fbfc commit 5f898a1

4 files changed (+84, -4 lines)

src/diffusers/models/unets/unet_2d.py

Lines changed: 3 additions & 2 deletions
@@ -291,7 +291,8 @@ def forward(
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        # TODO(aryan): Need to have this reviewed
+        t_emb = t_emb.to(dtype=sample.dtype)
         emb = self.time_embedding(t_emb)

         if self.class_embedding is not None:
@@ -301,7 +302,7 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)

-            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
             emb = emb + class_emb
         elif self.class_embedding is None and class_labels is not None:
             raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

src/diffusers/models/unets/unet_3d_condition.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
     """

     _supports_gradient_checkpointing = False
-    _always_upcast_modules = ["norm.*"]
+    _always_upcast_modules = ["norm.*", "time_embedding"]

     @register_to_config
     def __init__(
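`_always_upcast_modules` lists name patterns for submodules that should stay in the compute dtype rather than low-precision storage; adding "time_embedding" presumably keeps the time embedding upcast in this model, whose forward was not changed in this commit. A minimal sketch of how such patterns might be matched against submodule names (this is an assumption, not the diffusers implementation; the exact matching semantics may differ):

import re

_always_upcast_modules = ["norm.*", "time_embedding"]

def should_skip_low_precision(name: str) -> bool:
    # Treat each entry as a regex anchored at the start of the module name.
    return any(re.match(pattern, name) is not None for pattern in _always_upcast_modules)

assert should_skip_low_precision("time_embedding")
assert should_skip_low_precision("norm_out")      # matched by "norm.*"
assert not should_skip_low_precision("conv_in")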

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 2 additions & 1 deletion
@@ -2132,7 +2132,8 @@ def forward(
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        # TODO(aryan): Need to have this reviewed
+        t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
         aug_emb = None
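This is the same cast fix as in unet_2d.py. For reference, a minimal usage sketch of the entry point the new tests exercise; the checkpoint path is a placeholder, and the keyword names mirror the test code below:

import torch
from diffusers import UNetMotionModel

# Load a motion UNet (placeholder path), then enable layerwise upcasting:
# weights are stored in fp8 and upcast per layer to fp32 for compute.
model = UNetMotionModel.from_pretrained("path/to/checkpoint", subfolder="unet")
model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)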

tests/models/test_modeling_common.py

Lines changed: 78 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.

 import copy
+import gc
 import inspect
 import json
 import os
@@ -56,6 +57,7 @@
     CaptureLogger,
     get_python_version,
     is_torch_compile,
+    numpy_cosine_similarity_distance,
     require_torch_2,
     require_torch_accelerator_with_training,
     require_torch_gpu,
@@ -1331,6 +1333,82 @@ def test_variant_sharded_ckpt_right_format(self):
         # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
         assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)

+    def test_layerwise_upcasting_inference(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+
+        # fp16-fp32
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+        model.enable_layerwise_upcasting(storage_dtype=torch.float16, compute_dtype=torch.float32)
+        layerwise_upcast_slice_fp16 = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+
+        # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
+        # We just want to make sure that the layerwise upcasting is working as expected.
+        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp16) < 1.0)
+
+        # fp8_e4m3-fp32
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
+        layerwise_upcast_slice_fp8_e4m3 = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+
+        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e4m3) < 1.0)
+
+        # fp8_e5m2-fp32
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e5m2, compute_dtype=torch.float32)
+        layerwise_upcast_slice_fp8_e5m2 = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+
+        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e5m2) < 1.0)
+
+    @require_torch_gpu
+    def test_layerwise_upcasting_memory(self):
+        # fp32
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+        model(**inputs_dict)
+        base_memory_footprint = model.get_memory_footprint()
+        base_max_memory = torch.cuda.max_memory_allocated()
+
+        model.to("cpu")
+        del model
+
+        # fp8_e4m3-fp32
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
+        model(**inputs_dict)
+        fp8_e4m3_memory_footprint = model.get_memory_footprint()
+        fp8_e4m3_max_memory = torch.cuda.max_memory_allocated()
+
+        self.assertTrue(fp8_e4m3_memory_footprint < base_memory_footprint)
+        self.assertTrue(fp8_e4m3_max_memory < base_max_memory)

 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
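The inference tests compare flattened outputs with `numpy_cosine_similarity_distance` from the test utilities. A sketch consistent with how the helper is used here (not necessarily its exact library body): the distance is 1 minus the cosine similarity of the two arrays, so 0 means identical direction and the tests only require it to stay below 1.0, i.e. the upcast outputs remain positively correlated with the fp32 baseline.

import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of two flattened output arrays.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)

base = np.array([1.0, 2.0, 3.0])
other = np.array([1.0, 2.0, 3.1])   # slightly perturbed, as low-precision outputs would be
assert cosine_similarity_distance(base, other) < 1.0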
