@@ -1338,6 +1338,36 @@ def test_variant_sharded_ckpt_right_format(self):
13381338 # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
13391339 assert all (f .split ("." )[1 ].split ("-" )[0 ] == variant for f in shard_files )
13401340
def test_layerwise_casting_training(self):
    """Run one forward/backward pass in train mode with layerwise casting enabled.

    Each case builds a fresh model, moves it to ``torch_device`` in the compute
    dtype, enables layerwise casting with the given storage/compute dtypes, and
    back-propagates an MSE loss against random noise. Nothing is asserted on
    the values; the test passes as long as no step raises.
    """

    def run_training_step(storage_dtype, compute_dtype):
        device_type = torch.device(torch_device).type

        # The bfloat16-compute case is not exercised on CPU; it returns silently.
        if compute_dtype == torch.bfloat16 and device_type == "cpu":
            return

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device, dtype=compute_dtype)
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        model.train()

        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)

        # NOTE(review): forward pass and loss are assumed to run under autocast,
        # with only the backward pass outside it - confirm against the original
        # indentation.
        with torch.amp.autocast(device_type=device_type):
            prediction = model(**inputs_dict)

            # Dict-like model outputs expose the primary tensor as the first tuple entry.
            if isinstance(prediction, dict):
                prediction = prediction.to_tuple()[0]

            batch_size = inputs_dict[self.main_input_name].shape[0]
            target_shape = (batch_size,) + self.output_shape
            target = torch.randn(target_shape).to(torch_device)
            target = cast_maybe_tensor_dtype(target, torch.float32, compute_dtype)
            loss = torch.nn.functional.mse_loss(prediction, target)

        loss.backward()

    run_training_step(storage_dtype=torch.float16, compute_dtype=torch.float32)
    run_training_step(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
    run_training_step(storage_dtype=torch.float8_e5m2, compute_dtype=torch.float32)
    run_training_step(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

13411371 def test_layerwise_casting_inference (self ):
13421372 from diffusers .hooks .layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN , SUPPORTED_PYTORCH_LAYERS
13431373
0 commit comments