Skip to content

Commit 556bdb2

Browse files
isururanawakameta-codesync[bot]
authored and committed
Objectives with input dist latencies
Summary: This diff introduces two objectives for the LP Planner that account for input_dist in the critical path.
- BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_INPUT_DIST: max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype) + sum_{module} max(input_dist_comms for module)
- BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_COMBINED_FWD_COMMS_INPUT_DIST: max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms + input_dist_comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype)
Differential Revision: D87389540
1 parent 791373f commit 556bdb2

File tree

4 files changed

+135
-0
lines changed

4 files changed

+135
-0
lines changed

torchrec/distributed/planner/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
WEIGHTED_KERNEL_MULTIPLIER: float = 1.1 # empirical studies
4343
DP_ELEMENTWISE_KERNELS_PERF_FACTOR: float = 9.22 # empirical studies
4444

45+
# TODO: This can be hardware dependent, need more empirical results to verify
46+
A2A_INVERSE_BANDWITH_COEFFICIENT: float = 1 # empirical studies
47+
4548

4649
def kernel_bw_lookup(
4750
compute_device: str,

torchrec/distributed/planner/shard_estimators.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
1919
from torchrec.distributed.logger import _torchrec_method_logger
2020
from torchrec.distributed.planner.constants import (
21+
A2A_INVERSE_BANDWITH_COEFFICIENT,
2122
BATCHED_COPY_PERF_FACTOR,
2223
BIGINT_DTYPE,
2324
DP_ELEMENTWISE_KERNELS_PERF_FACTOR,
@@ -461,6 +462,49 @@ def _get_expected_cache_prefetch_time(
461462
prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
462463
return prefetch_bytes / hbm_to_ddr_mem_bw
463464

465+
@classmethod
def _input_dist_expected_latency(
    cls,
    batch_sizes: List[int],
    world_size: int,
    local_world_size: int,
    num_poolings: List[float],
    input_lengths: List[float],
    fwd_a2a_comm_data_type_size: float,
    comms_bandwidths: GeneralizedCommsBandwidth,
) -> float:
    """
    Calculates the expected latency for the all-to-all (A2A) input dist.

    Args:
        batch_sizes (List[int]): The batch size for each input feature.
        world_size (int): The total number of devices in the distributed setup.
        local_world_size (int): The number of devices on a single host.
        num_poolings (List[float]): Number of poolings per sample for each
            input feature.
        input_lengths (List[float]): Average number of lookups per input
            feature.
        fwd_a2a_comm_data_type_size (float): Data type size (in bytes) for
            forward all-to-all communication.
        comms_bandwidths (GeneralizedCommsBandwidth): Object to query
            communication bandwidths.

    Returns:
        float: The expected latency for input distribution, computed as
            message bytes divided by the queried A2A bandwidth and scaled by
            A2A_INVERSE_BANDWITH_COEFFICIENT. The time unit therefore follows
            the unit of the bandwidth returned by ``comms_bandwidths``
            (presumably seconds — confirm against GeneralizedCommsBandwidth).
    """
    # Total number of pooled lookup values across all features:
    # lookups/feature * poolings/sample * samples.
    batch_inputs = sum(
        length * pooling * batch
        for length, pooling, batch in zip(input_lengths, num_poolings, batch_sizes)
    )
    # Message size in whole bytes for the forward A2A: the per-rank batch is
    # exchanged with all `world_size` ranks.
    input_read_size = math.ceil(
        batch_inputs * world_size * fwd_a2a_comm_data_type_size
    )

    comms_bw = comms_bandwidths.get_bw(
        world_size=world_size,
        local_world_size=local_world_size,
        collective_type=CollectiveType.ALL_TO_ALL,
    )
    message_bw = input_read_size / comms_bw
    # NOTE(review): the constant's name carries an upstream typo ("BANDWITH");
    # kept as-is for compatibility with torchrec.distributed.planner.constants.
    input_dist_latency = message_bw * A2A_INVERSE_BANDWITH_COEFFICIENT

    return input_dist_latency
507+
464508
@classmethod
465509
def _get_tw_sharding_perf(
466510
cls,
@@ -551,6 +595,15 @@ def _get_tw_sharding_perf(
551595
hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size
552596
)
553597

598+
input_dist_comms = cls._input_dist_expected_latency(
599+
batch_sizes=batch_sizes,
600+
world_size=world_size,
601+
local_world_size=local_world_size,
602+
num_poolings=num_poolings,
603+
input_lengths=input_lengths,
604+
fwd_a2a_comm_data_type_size=input_data_type_size,
605+
comms_bandwidths=comms_bandwidths,
606+
)
554607
# in order of model parallel execution, starting with:
555608
# BWD DP -> BWD MP ... FWD MP -> FWD DP
556609
return Perf(
@@ -559,6 +612,7 @@ def _get_tw_sharding_perf(
559612
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
560613
bwd_comms=bwd_comms,
561614
prefetch_compute=prefetch_compute,
615+
input_dist_comms=input_dist_comms,
562616
)
563617

564618
@classmethod
@@ -658,13 +712,23 @@ def _get_rw_sharding_perf(
658712
emb_dim,
659713
table_data_type_size,
660714
)
715+
input_dist_comms = cls._input_dist_expected_latency(
716+
batch_sizes=batch_sizes,
717+
world_size=world_size,
718+
local_world_size=local_world_size,
719+
num_poolings=num_poolings,
720+
input_lengths=input_lengths,
721+
fwd_a2a_comm_data_type_size=input_data_type_size,
722+
comms_bandwidths=comms_bandwidths,
723+
)
661724

662725
return Perf(
663726
fwd_compute=fwd_compute,
664727
fwd_comms=fwd_comms,
665728
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
666729
bwd_comms=bwd_comms + bwd_batched_copy,
667730
prefetch_compute=prefetch_compute,
731+
input_dist_comms=input_dist_comms,
668732
)
669733

670734
@classmethod
@@ -790,13 +854,23 @@ def _get_twrw_sharding_perf(
790854
emb_dim,
791855
table_data_type_size,
792856
)
857+
input_dist_comms = cls._input_dist_expected_latency(
858+
batch_sizes=batch_sizes,
859+
world_size=world_size,
860+
local_world_size=local_world_size,
861+
num_poolings=num_poolings,
862+
input_lengths=input_lengths,
863+
fwd_a2a_comm_data_type_size=input_data_type_size,
864+
comms_bandwidths=comms_bandwidths,
865+
)
793866

794867
return Perf(
795868
fwd_compute=fwd_compute,
796869
fwd_comms=fwd_comms,
797870
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
798871
bwd_comms=bwd_comms + bwd_batched_copy,
799872
prefetch_compute=prefetch_compute,
873+
input_dist_comms=input_dist_comms,
800874
)
801875

802876
@classmethod

0 commit comments

Comments
 (0)