@@ -1338,6 +1338,36 @@ def test_variant_sharded_ckpt_right_format(self):
         # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
         assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
 
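+    # Smoke test: training still works with layerwise casting enabled, i.e. with weights
+    # stored in a low-precision storage dtype and upcast to the compute dtype on the fly.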
+    def test_layerwise_casting_training(self):
+        def test_fn(storage_dtype, compute_dtype):
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+            model = self.model_class(**init_dict)
+            model = model.to(torch_device, dtype=compute_dtype)
+            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+            model.train()
+
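+            # Inputs are prepared in float32 by the common helper; cast them to the compute dtype.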
+            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
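+            # Dummy MSE loss against random noise; backward() verifies that gradients
+            # flow through the layers that are upcast on the fly.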
+            input_tensor = inputs_dict[self.main_input_name]
+            noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
+            loss = torch.nn.functional.mse_loss(output, noise)
+            loss.backward()
+
+        test_fn(torch.float16, torch.float32)
+        test_fn(torch.float8_e4m3fn, torch.float32)
+        test_fn(torch.float8_e5m2, torch.float32)
+        test_fn(torch.float8_e4m3fn, torch.bfloat16)
+
     def test_layerwise_casting_inference(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
 