@@ -648,7 +648,10 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
648
648
"simulation's output. If this is unexpected, please review your "
649
649
"setup or contact customer support for assistance."
650
650
)
651
- return {k : 0 * v for k , v in sim_fields_original .items ()}
651
+ return {
652
+ k : (type (v )(0 * x for x in v ) if isinstance (v , (list , tuple )) else 0 * v )
653
+ for k , v in sim_fields_original .items ()
654
+ }
652
655
653
656
# Run adjoint simulations in batch
654
657
task_names_adj = [f"{ task_name } _adjoint_{ i } " for i in range (len (sims_adj ))]
@@ -670,17 +673,16 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
670
673
)
671
674
td .log .info ("Completed local batch adjoint simulations" )
672
675
673
- # sum partial derivatives from each adjoint simulation
676
+ # Process results from local gradient computation
677
+ vjp_fields_dict = {}
674
678
for task_name_adj , sim_data_adj in batch_data_adj .items ():
675
679
td .log .info (f"Processing VJP contribution from { task_name_adj } " )
676
- vjp_fields = postprocess_adj (
680
+ vjp_fields_dict [ task_name_adj ] = postprocess_adj (
677
681
sim_data_adj = sim_data_adj ,
678
682
sim_data_orig = sim_data_orig ,
679
683
sim_data_fwd = sim_data_fwd ,
680
684
sim_fields_keys = sim_fields_keys ,
681
685
)
682
- for k , v in vjp_fields .items ():
683
- vjp_traced_fields [k ] = vjp_traced_fields .get (k , 0 ) + v
684
686
else :
685
687
td .log .info ("Starting server-side batch of adjoint simulations ..." )
686
688
@@ -699,15 +701,24 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
699
701
tname_adj : sim .updated_copy (simulation_type = "autograd_bwd" , deep = False )
700
702
for tname_adj , sim in sims_adj_dict .items ()
701
703
}
702
- vjp_traced_fields_dict = _run_async_tidy3d_bwd (
704
+ vjp_fields_dict = _run_async_tidy3d_bwd (
703
705
simulations = sims_adj_dict ,
704
706
** run_kwargs ,
705
707
)
706
708
td .log .info ("Completed server-side batch of adjoint simulations." )
707
709
708
- for fields in vjp_traced_fields_dict .values ():
709
- for k , v in fields .items ():
710
- vjp_traced_fields [k ] = vjp_traced_fields .get (k , 0 ) + v
710
+ # Accumulate gradients from all adjoint simulations
711
+ for task_name_adj , vjp_fields in vjp_fields_dict .items ():
712
+ td .log .info (f"Processing VJP contribution from { task_name_adj } " )
713
+ for k , v in vjp_fields .items ():
714
+ if k in vjp_traced_fields :
715
+ val = vjp_traced_fields [k ]
716
+ if isinstance (val , (list , tuple )) and isinstance (v , (list , tuple )):
717
+ vjp_traced_fields [k ] = type (val )(x + y for x , y in zip (val , v ))
718
+ else :
719
+ vjp_traced_fields [k ] += v
720
+ else :
721
+ vjp_traced_fields [k ] = v
711
722
712
723
td .log .debug (f"Computed gradients for { len (vjp_traced_fields )} fields" )
713
724
return vjp_traced_fields
@@ -765,7 +776,8 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
765
776
if not sims_adj :
766
777
td .log .debug (f"Adjoint simulation for task '{ task_name } ' contains no sources." )
767
778
sim_fields_vjp_dict [task_name ] = {
768
- k : 0 * v for k , v in sim_fields_original_dict [task_name ].items ()
779
+ k : (type (v )(0 * x for x in v ) if isinstance (v , (list , tuple )) else 0 * v )
780
+ for k , v in sim_fields_original_dict [task_name ].items ()
769
781
}
770
782
continue
771
783
@@ -781,6 +793,9 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
781
793
)
782
794
return sim_fields_vjp_dict
783
795
796
+ # Dictionary to store VJP results from all adjoint simulations
797
+ vjp_results = {}
798
+
784
799
if local_gradient :
785
800
# Run all adjoint simulations in a single batch
786
801
path_dir = Path (run_async_kwargs .pop ("path_dir" ))
@@ -791,28 +806,20 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
791
806
all_sims_adj , path_dir = str (path_dir_adj ), ** run_async_kwargs
792
807
)
793
808
794
- # Process results for each original task
809
+ # Process results for each adjoint task
795
810
for adj_task_name , sim_data_adj in batch_data_adj .items ():
796
811
task_name = task_name_mapping [adj_task_name ]
797
812
sim_data_orig = sim_data_orig_dict [task_name ]
798
813
sim_data_fwd = sim_data_fwd_dict [task_name ]
799
814
sim_fields_keys = sim_fields_keys_dict [task_name ]
800
815
801
816
# Compute VJP contribution
802
- sim_fields_vjp = postprocess_adj (
817
+ vjp_results [ adj_task_name ] = postprocess_adj (
803
818
sim_data_adj = sim_data_adj ,
804
819
sim_data_orig = sim_data_orig ,
805
820
sim_data_fwd = sim_data_fwd ,
806
821
sim_fields_keys = sim_fields_keys ,
807
822
)
808
-
809
- # Sum contributions for each original task
810
- if task_name in sim_fields_vjp_dict :
811
- for k , v in sim_fields_vjp .items ():
812
- sim_fields_vjp_dict [task_name ][k ] += v
813
- else :
814
- sim_fields_vjp_dict [task_name ] = sim_fields_vjp
815
-
816
823
else :
817
824
# Set up parent tasks mapping for all adjoint simulations
818
825
parent_tasks = {}
@@ -830,19 +837,27 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
830
837
}
831
838
832
839
# Run all adjoint simulations in a single batch
833
- sim_fields_vjp_dict_adj = _run_async_tidy3d_bwd (
840
+ vjp_results = _run_async_tidy3d_bwd (
834
841
simulations = all_sims_adj ,
835
842
** run_async_kwargs ,
836
843
)
837
844
838
- # Combine results for each original task
839
- for adj_task_name , fields in sim_fields_vjp_dict_adj .items ():
840
- task_name = task_name_mapping [adj_task_name ]
841
- if task_name in sim_fields_vjp_dict :
842
- for k , v in fields .items ():
845
+ # Accumulate gradients from all adjoint simulations
846
+ for adj_task_name , vjp_fields in vjp_results .items ():
847
+ task_name = task_name_mapping [adj_task_name ]
848
+
849
+ if task_name not in sim_fields_vjp_dict :
850
+ sim_fields_vjp_dict [task_name ] = {}
851
+
852
+ for k , v in vjp_fields .items ():
853
+ if k in sim_fields_vjp_dict [task_name ]:
854
+ val = sim_fields_vjp_dict [task_name ][k ]
855
+ if isinstance (val , (list , tuple )) and isinstance (v , (list , tuple )):
856
+ sim_fields_vjp_dict [task_name ][k ] = type (val )(x + y for x , y in zip (val , v ))
857
+ else :
843
858
sim_fields_vjp_dict [task_name ][k ] += v
844
859
else :
845
- sim_fields_vjp_dict [task_name ] = fields
860
+ sim_fields_vjp_dict [task_name ][ k ] = v
846
861
847
862
return sim_fields_vjp_dict
848
863
0 commit comments