@@ -125,7 +125,7 @@ class TuningConfig:
125125 inputs_pre_hook : Callable = None
126126 use_cold_l2_cache : bool = False
127127 use_cuda_graph : bool = True
128- distributed_tuning_strategy : DistributedTuningStrategy = DistributedTuningStrategy .BROADCAST
128+ distributed_tuning_strategy : DistributedTuningStrategy = DistributedTuningStrategy .INDEPENDENT
129129
130130
131131@dataclass (unsafe_hash = True )
@@ -1524,16 +1524,9 @@ def cache_sync_pp_recv(self):
15241524 self .profiling_cache .merge_cache_data (profiling_cache )
15251525
def cache_sync_pp_send(self):
    """Send the entire profiling cache to the next pp rank.

    The object sent is ``self.profiling_cache.cache`` as a whole; no
    per-op filtering by distributed tuning strategy is applied.

    Nothing is sent when ``self.mapping.has_pp()`` is false, or when
    ``self.mapping.is_last_pp_rank`` is truthy (the last rank has no
    successor to forward to).
    """
    # NOTE(review): `is_last_pp_rank` is read as a plain attribute here,
    # while `has_pp` and `next_pp_rank` are called. If it is actually a
    # method, `not <bound method>` is always False and nothing would ever
    # be sent -- confirm against the mapping class.
    if self.mapping.has_pp() and not self.mapping.is_last_pp_rank:
        self._dist.send_object(
            self.profiling_cache.cache,
            self.mapping.next_pp_rank(),
        )
0 commit comments