@@ -382,3 +382,97 @@ def test_update_after_step():
             break

     assert weights_changed, "EMA weights did not change after update_after_step"
+
+
+def test_same_output_as_reference_different_step():
+    """Test that our implementation produces outputs identical to the reference when synthesizing at a different step."""
+    # Create a simple model
+    net = nn.Linear(512, 512)
+
+    # Initialize both implementations with the same parameters
+    sigma_rels = (0.03, 0.20)
+    update_every = 10
+    checkpoint_every = 10
+
+    print("\nInitializing with parameters:")
+    print(f"sigma_rels: {sigma_rels}")
+    print(f"update_every: {update_every}")
+    print(f"checkpoint_every: {checkpoint_every}")
+
+    # Create both implementations
+    ref_emas = ReferencePostHocEMA(
+        net,
+        sigma_rels=sigma_rels,
+        update_every=update_every,
+        checkpoint_every_num_steps=checkpoint_every,
+        checkpoint_folder="./test-checkpoints-ref",
+        checkpoint_dtype=torch.float32,
+    )
+
+    our_emas = OurPostHocEMA.from_model(
+        model=net,
+        checkpoint_dir="./test-checkpoints-our",
+        update_every=update_every,
+        checkpoint_every=checkpoint_every,
+        sigma_rels=sigma_rels,
+        checkpoint_dtype=torch.float32,
+        update_after_step=0,  # Start immediately to match reference behavior
+    )
+
+    # Train both with identical updates
+    torch.manual_seed(42)  # For reproducibility
+    net.train()
+
+    print("\nTraining:")
+    for step in range(100):
+        # Apply identical mutations to the network
+        with torch.no_grad():
+            net.weight.copy_(torch.randn_like(net.weight))
+            net.bias.copy_(torch.randn_like(net.bias))
+
+        # Update both EMA wrappers
+        ref_emas.update()
+        our_emas.update_(net)
+
+        if step % 10 == 0:
+            print(f"Step {step}: Updated model and EMAs")
+
+    # Synthesize EMA models with the same parameters at step 50 (the middle of training)
+    target_sigma = 0.15
+    target_step = 50
+    print(f"\nSynthesizing with target_sigma={target_sigma} at step {target_step}")
+
+    # List the reference implementation's checkpoints
+    ref_checkpoints = sorted(Path("./test-checkpoints-ref").glob("*.pt"))
+    print("\nReference checkpoints:")
+    for cp in ref_checkpoints:
+        print(f"  {cp.name}")
+
+    # List our implementation's checkpoints
+    our_checkpoints = sorted(Path("./test-checkpoints-our").glob("*.pt"))
+    print("\nOur checkpoints:")
+    for cp in our_checkpoints:
+        print(f"  {cp.name}")
+
+    ref_synth = ref_emas.synthesize_ema_model(sigma_rel=target_sigma, step=target_step)
+
+    with our_emas.model(net, target_sigma, step=target_step) as our_synth:
+        # Test with the same input
+        data = torch.randn(1, 512)
+        ref_output = ref_synth(data)
+        our_output = our_synth(data)
+
+        print("\nComparing outputs:")
+        print(f"Reference output mean: {ref_output.mean().item():.4f}")
+        print(f"Our output mean: {our_output.mean().item():.4f}")
+        print(f"Max difference: {(ref_output - our_output).abs().max().item():.4f}")
+
+        # Verify outputs match
+        assert torch.allclose(
+            ref_output, our_output, rtol=1e-4, atol=1e-4
+        ), "Output from our implementation doesn't match reference"
+
+    # Clean up checkpoint directories
+    for path in ["./test-checkpoints-ref", "./test-checkpoints-our"]:
+        if Path(path).exists():
+            shutil.rmtree(path)
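To run just this test with its diagnostic prints visible, a standard pytest invocation along these lines should work (the tests/test_posthoc_ema.py path is an assumption; substitute the file this diff actually touches — the -s flag disables pytest's output capture so the print statements show):

    pytest tests/test_posthoc_ema.py -k test_same_output_as_reference_different_step -s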