@@ -187,6 +187,27 @@ def test_optim_default_dtype_bf16(self, optim_name, device):
         finally:
             torch.set_default_dtype(old_dtype)
 
+    @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
+    @parametrize("device", _DEVICES)
+    def test_param_groups(self, optim_name, device):
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+        model.to(device=device)
+        param_groups = [
+            dict(params=list(model[0].parameters()), lr=1e-4),
+            dict(params=list(model[2].parameters()), lr=1e-5),
+        ]
+        optimizer = getattr(optim, optim_name)(param_groups)
+
+        x = torch.randn(4, 32, device=device)
+        loss = model(x).sum()
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.