@@ -261,18 +261,32 @@ def test_hf_policy_init(policy_setup, num_gpus):
 
 
 @pytest.fixture
-def training_setup(tokenizer, num_gpus):
-    """Setup and teardown specifically for training tests."""
+def training_setup(tokenizer, request, num_gpus):
+    """
+    Setup and teardown specifically for training tests.
+
+    When used without parameterization, uses the default config.
+    When parameterized, takes any config updates as a dictionary in request.param
+    and applies them to the basic config.
+    """
     policy = None
     cluster = None
     data = None
     loss_fn = None
 
+    # Get config updates from request.param if available
+    config_updates = {}
+    config_suffix = ""
+    if hasattr(request, "param") and request.param is not None:
+        config_updates = request.param
+        config_suffix = "-" + "-".join([f"{k}={v}" for k, v in config_updates.items()])
+
     try:
         # Create resources with unique name
-        cluster_name = f"test-train-{num_gpus}gpu"
+        cluster_name = f"test-train-{num_gpus}gpu{config_suffix}"
         print(
-            f"Creating training virtual cluster '{cluster_name}' for {num_gpus} GPUs..."
+            f"Creating training virtual cluster '{cluster_name}' for {num_gpus} GPUs"
+            f"{' with config updates: ' + str(config_updates) if config_updates else ''}"
         )
 
         cluster = RayVirtualCluster(
@@ -283,7 +297,10 @@ def training_setup(tokenizer, num_gpus):
             max_colocated_worker_groups=1,
         )
 
-        config = basic_llama_test_config
+        # Create a config with optional modifications
+        config = deepcopy(basic_llama_test_config)
+        if config_updates:
+            config.update(config_updates)
 
         print("Creating training HfPolicy...")
         policy = HfPolicy(
@@ -341,8 +358,23 @@ def get_max_gpu_utilization(policy):
 
 
 @pytest.mark.timeout(180)
-@pytest.mark.parametrize("num_gpus", [1, 2], ids=["1gpu", "2gpu"])
-def test_hf_policy_training(training_setup, tracker, num_gpus):
+@pytest.mark.parametrize(
+    "num_gpus, training_setup, config_name",
+    [
+        (1, None, "default"),
+        (2, None, "default"),
+        (2, {"fsdp_offload_enabled": True}, "fsdp_offload"),
+        (2, {"activation_checkpointing_enabled": True}, "activation_checkpointing"),
+    ],
+    indirect=["training_setup"],
+    ids=[
+        "1gpu_default",
+        "2gpu_default",
+        "2gpu_fsdp_offload",
+        "2gpu_activation_checkpointing",
+    ],
+)
+def test_hf_policy_training(training_setup, tracker, num_gpus, config_name):
     def verify_loss_tensor(loss_tensor):
         assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN"
         assert not torch.isinf(loss_tensor).any(), "Loss should not be Inf"
@@ -357,7 +389,9 @@ def verify_loss_tensor(loss_tensor):
     assert loss_fn is not None, "Loss function was not created properly"
 
     # Call prepare_for_training if available
-    print("\nPreparing for training...")
+    print(
+        f"\nPreparing for training with {num_gpus} GPU(s) and {config_name} config..."
+    )
     policy.prepare_for_training()
 
     losses = []
@@ -370,7 +404,9 @@ def verify_loss_tensor(loss_tensor):
         verify_loss_tensor(loss_tensor)
         losses.append(loss_tensor[-1].item())
 
-        print(f"Training loss: {results['loss']}")
+        print(
+            f"Training loss with {num_gpus} GPU(s) and {config_name} config: {results['loss']}"
+        )
 
     policy.finish_training()
     assert losses[0] > losses[-1], "Loss should decrease over training iterations"
@@ -379,35 +415,46 @@ def verify_loss_tensor(loss_tensor):
         policy
     )
     print(
-        f"Max GPU Utilization after training: {after_training_mem_allocated:,.1f} MB allocated, "
+        f"Max GPU Utilization after training with {num_gpus} GPU(s) and {config_name} config: {after_training_mem_allocated:,.1f} MB allocated, "
         f"{after_training_mem_reserved:,.1f} MB reserved"
     )
     tracker.track(
-        f"after_training_mem_allocated_{num_gpus}gpu", after_training_mem_allocated
+        f"{num_gpus}gpu_{config_name}_after_training_mem_allocated",
+        after_training_mem_allocated,
     )
     tracker.track(
-        f"after_training_mem_reserved_{num_gpus}gpu", after_training_mem_reserved
+        f"{num_gpus}gpu_{config_name}_after_training_mem_reserved",
+        after_training_mem_reserved,
     )
 
     policy.offload_after_refit()
     after_offload_mem_allocated, after_offload_mem_reserved = get_max_gpu_utilization(
         policy
     )
     print(
-        f"Max GPU Utilization after offload: {after_offload_mem_allocated:,.1f} MB allocated, "
+        f"Max GPU Utilization after offload with {num_gpus} GPU(s) and {config_name} config: {after_offload_mem_allocated:,.1f} MB allocated, "
         f"{after_offload_mem_reserved:,.1f} MB reserved"
     )
     tracker.track(
-        f"after_offload_mem_allocated_{num_gpus}gpu", after_offload_mem_allocated
+        f"{num_gpus}gpu_{config_name}_after_offload_mem_allocated",
+        after_offload_mem_allocated,
     )
     tracker.track(
-        f"after_offload_mem_reserved_{num_gpus}gpu", after_offload_mem_reserved
+        f"{num_gpus}gpu_{config_name}_after_offload_mem_reserved",
+        after_offload_mem_reserved,
     )
 
     # Compare memory after offload to memory after training
-    assert after_training_mem_allocated > 10_000, (
-        "Memory after training should be more than 10GB"
-    )
+    if config_name == "fsdp_offload":
+        # With FSDP offload, memory usage after training should already be low
+        assert after_training_mem_allocated < 1_200, (
+            "FSDP offload after training should be less than 1.2GB"
+        )
+    else:
+        assert after_training_mem_allocated > 10_000, (
+            f"Memory after training with {config_name} config should be more than 10GB"
+        )
+
     assert after_offload_mem_allocated < 1_200, (
         "Memory after offload should be less than 1.2GB"
     )
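
Note (not part of the diff above): the changed test relies on pytest's indirect parameterization, where the value listed in @pytest.mark.parametrize is delivered to the fixture on request.param instead of being passed straight to the test. The following minimal sketch, with hypothetical fixture and test names, shows that mechanism in isolation, including the unparameterized fallback used by the fixture in this PR.

import pytest


@pytest.fixture
def training_setup_sketch(request):
    # With indirect parameterization, pytest places the parametrize value on
    # request.param; fall back to no overrides when the test is not parameterized.
    overrides = getattr(request, "param", None) or {}
    config = {"fsdp_offload_enabled": False, "activation_checkpointing_enabled": False}
    config.update(overrides)
    yield config  # teardown code would follow the yield in a real fixture


@pytest.mark.parametrize(
    "training_setup_sketch, config_name",
    [(None, "default"), ({"fsdp_offload_enabled": True}, "fsdp_offload")],
    indirect=["training_setup_sketch"],  # only this argument is routed through the fixture
    ids=["default", "fsdp_offload"],
)
def test_config_overrides(training_setup_sketch, config_name):
    # The fixture returns the merged config; config_name arrives directly.
    assert isinstance(training_setup_sketch, dict)
    assert training_setup_sketch["fsdp_offload_enabled"] == (config_name == "fsdp_offload")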