1818from torchrec .distributed .embedding_types import EmbeddingComputeKernel
1919from torchrec .distributed .logger import _torchrec_method_logger
2020from torchrec .distributed .planner .constants import (
21+ A2A_INVERSE_BANDWITH_COEFFICIENT ,
2122 BATCHED_COPY_PERF_FACTOR ,
2223 BIGINT_DTYPE ,
2324 DP_ELEMENTWISE_KERNELS_PERF_FACTOR ,
@@ -461,6 +462,49 @@ def _get_expected_cache_prefetch_time(
461462 prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
462463 return prefetch_bytes / hbm_to_ddr_mem_bw
463464
@classmethod
def _input_dist_expected_latency(
    cls,
    batch_sizes: List[int],
    world_size: int,
    local_world_size: int,
    num_poolings: List[float],
    input_lengths: List[float],
    fwd_a2a_comm_data_type_size: float,
    comms_bandwidths: GeneralizedCommsBandwidth,
) -> float:
    """
    Estimate how long the all-to-all input distribution is expected to take.

    The per-feature input volume (input length * poolings * batch size) is
    summed over all features, scaled by the world size and the element size
    to get a byte count, and then divided by the all-to-all bandwidth
    reported by ``comms_bandwidths``. The result is finally scaled by
    ``A2A_INVERSE_BANDWITH_COEFFICIENT``.

    Args:
        batch_sizes (List[int]): Batch size for each input feature.
        world_size (int): Total number of devices in the distributed setup.
        local_world_size (int): Number of devices on a single host.
        num_poolings (List[float]): Number of poolings per sample for each
            input feature.
        input_lengths (List[float]): Average number of lookups for each
            input feature.
        fwd_a2a_comm_data_type_size (float): Size in bytes of one element
            sent through the forward all-to-all.
        comms_bandwidths (GeneralizedCommsBandwidth): Object queried (via
            ``get_bw``) for the all-to-all bandwidth.

    Returns:
        float: Expected input-dist latency (seconds, assuming ``get_bw``
        reports bytes per second).
    """
    # The three lists are walked in lockstep; zip stops at the shortest one.
    total_lookups = sum(
        length * poolings * batch_size
        for length, poolings, batch_size in zip(
            input_lengths, num_poolings, batch_sizes
        )
    )

    # Bytes moved by the collective, rounded up to a whole byte.
    payload_bytes = math.ceil(
        total_lookups * world_size * fwd_a2a_comm_data_type_size
    )

    a2a_bandwidth = comms_bandwidths.get_bw(
        world_size=world_size,
        local_world_size=local_world_size,
        collective_type=CollectiveType.ALL_TO_ALL,
    )

    # Transfer time for the payload, then the empirical A2A correction factor.
    transfer_time = payload_bytes / a2a_bandwidth
    return transfer_time * A2A_INVERSE_BANDWITH_COEFFICIENT
507+
464508 @classmethod
465509 def _get_tw_sharding_perf (
466510 cls ,
@@ -551,6 +595,15 @@ def _get_tw_sharding_perf(
551595 hbm_to_ddr_mem_bw , expected_cache_fetches , emb_dim , table_data_type_size
552596 )
553597
598+ input_dist_comms = cls ._input_dist_expected_latency (
599+ batch_sizes = batch_sizes ,
600+ world_size = world_size ,
601+ local_world_size = local_world_size ,
602+ num_poolings = num_poolings ,
603+ input_lengths = input_lengths ,
604+ fwd_a2a_comm_data_type_size = input_data_type_size ,
605+ comms_bandwidths = comms_bandwidths ,
606+ )
554607 # in order of model parallel execution, starting with:
555608 # BWD DP -> BWD MP ... FWD MP -> FWD DP
556609 return Perf (
@@ -559,6 +612,7 @@ def _get_tw_sharding_perf(
559612 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
560613 bwd_comms = bwd_comms ,
561614 prefetch_compute = prefetch_compute ,
615+ input_dist_comms = input_dist_comms ,
562616 )
563617
564618 @classmethod
@@ -658,13 +712,23 @@ def _get_rw_sharding_perf(
658712 emb_dim ,
659713 table_data_type_size ,
660714 )
715+ input_dist_comms = cls ._input_dist_expected_latency (
716+ batch_sizes = batch_sizes ,
717+ world_size = world_size ,
718+ local_world_size = local_world_size ,
719+ num_poolings = num_poolings ,
720+ input_lengths = input_lengths ,
721+ fwd_a2a_comm_data_type_size = input_data_type_size ,
722+ comms_bandwidths = comms_bandwidths ,
723+ )
661724
662725 return Perf (
663726 fwd_compute = fwd_compute ,
664727 fwd_comms = fwd_comms ,
665728 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
666729 bwd_comms = bwd_comms + bwd_batched_copy ,
667730 prefetch_compute = prefetch_compute ,
731+ input_dist_comms = input_dist_comms ,
668732 )
669733
670734 @classmethod
@@ -790,13 +854,23 @@ def _get_twrw_sharding_perf(
790854 emb_dim ,
791855 table_data_type_size ,
792856 )
857+ input_dist_comms = cls ._input_dist_expected_latency (
858+ batch_sizes = batch_sizes ,
859+ world_size = world_size ,
860+ local_world_size = local_world_size ,
861+ num_poolings = num_poolings ,
862+ input_lengths = input_lengths ,
863+ fwd_a2a_comm_data_type_size = input_data_type_size ,
864+ comms_bandwidths = comms_bandwidths ,
865+ )
793866
794867 return Perf (
795868 fwd_compute = fwd_compute ,
796869 fwd_comms = fwd_comms ,
797870 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
798871 bwd_comms = bwd_comms + bwd_batched_copy ,
799872 prefetch_compute = prefetch_compute ,
873+ input_dist_comms = input_dist_comms ,
800874 )
801875
802876 @classmethod
0 commit comments